diff --git a/pyproject.toml b/pyproject.toml index 1d8d93f..f1428f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "djctools" -version = "0.1.1" +version = "0.1.2" description = "A package for logging, loss management, and multi-GPU training. Follows the core ideas of DeepJetCore but in torch." authors = [ { name = "Jan Kieseler", email = "jan.kieseler@cern.ch" } diff --git a/src/djctools/dataparallel.py b/src/djctools/dataparallel.py new file mode 100644 index 0000000..eb6a181 --- /dev/null +++ b/src/djctools/dataparallel.py @@ -0,0 +1,277 @@ +""" +Prototype single-host multi-GPU trainer skeleton for local reasoning. + +Goals: +- one master model + optimizer +- one replica per GPU +- uneven batches allowed +- semantics intended to match concatenated global batch, assuming local losses + are represented as sample-summed contributions before final normalization +- per-forward thread-local loss context +- loss modules disappear cleanly when truth is None + +This is intentionally a compact prototype, not production code. +""" + +import torch +from djctools.parallel import make_replicas, check_replicas_equal, train_step_threaded +import torch.nn as nn +from djctools.module_extensions import LossModule +from typing import Sequence + +from djctools.threading_context import forward_context # just for testing here + + +# --------------------------------------------------------------------------- +# Example toy loss module +# --------------------------------------------------------------------------- + +class MSELossModule(LossModule): + def compute_loss(self, pred, truth=None): + ''' + to be put into the documentation: loss *must* be mean over batch size + ''' + if truth is None: + return None + # sample-summed contribution + return ((pred - truth) ** 2).mean() + + +# --------------------------------------------------------------------------- +# Example toy model +# --------------------------------------------------------------------------- + +class ToyModel(nn.Module): + def __init__(self, in_features: int, out_features: int): + super().__init__() + self.linear = nn.Linear(in_features, out_features) + self.aux_loss = MSELossModule("mse") + + def forward(self, batch): + x, truth = batch + pred = self.linear(x) + self.aux_loss(pred, truth=truth) + return pred + + +# --------------------------------------------------------------------------- +# Example usage +# --------------------------------------------------------------------------- + +def example_setup(num_gpus: int = 2): + if not torch.cuda.is_available(): + raise RuntimeError("CUDA required for this prototype") + + devices = [torch.device(f"cuda:{i}") for i in range(num_gpus)] + master_model = ToyModel(2, 2).to(devices[0]) + replicas = make_replicas(master_model, devices) + optimizer = torch.optim.Adam(replicas[0].parameters(), lr=1e-3) + + # check replica sync + print("Checking replica synchronization... before training") + check_replicas_equal(replicas) + print("Replicas are initially synchronized.") + + return replicas, optimizer, devices + + +def example_batches(devices: Sequence[torch.device]): + batches = [] + for i, _dev in enumerate(devices): + bs = 32 + i # intentionally uneven + x = torch.randn(bs, 2) + y = torch.randn(bs, 2) + batches.append((x, y)) + return batches + + + +def run_equivalence_tests(num_steps: int = 4, atol: float = 1e-6, rtol: float = 1e-6) -> None: + if not torch.cuda.is_available(): + raise RuntimeError("CUDA required for equivalence tests") + + if torch.cuda.device_count() < 2: + raise RuntimeError("At least 2 CUDA devices required for test 2") + + def set_seed(seed: int) -> None: + import random + random.seed(seed) + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + + @torch.no_grad() + def assert_models_close(model_a: nn.Module, model_b: nn.Module, msg: str) -> None: + params_a = dict(model_a.named_parameters()) + params_b = dict(model_b.named_parameters()) + bufs_a = dict(model_a.named_buffers()) + bufs_b = dict(model_b.named_buffers()) + + for name, p in params_a.items(): + q = params_b[name] + if not torch.allclose(p.detach().cpu(), q.detach().cpu(), atol=atol, rtol=rtol): + diff = (p.detach().cpu() - q.detach().cpu()).abs().max().item() + raise RuntimeError(f"{msg}: parameter '{name}' differs, max abs diff = {diff}") + + for name, b in bufs_a.items(): + q = bufs_b[name] + if not torch.allclose(b.detach().cpu(), q.detach().cpu(), atol=atol, rtol=rtol): + diff = (b.detach().cpu() - q.detach().cpu()).abs().max().item() + raise RuntimeError(f"{msg}: buffer '{name}' differs, max abs diff = {diff}") + + def clone_batches_to_device(batches, device): + out = [] + for x, y in batches: + x2 = x.clone().to(device) + y2 = None if y is None else y.clone().to(device) + out.append((x2, y2)) + return out + + def make_split_batches(step_idx: int): + # deterministic but not identical across steps + bs0 = 8 + (step_idx % 3) + bs1 = 11 + ((2 * step_idx) % 4) + x0 = torch.randn(bs0, 2) + y0 = torch.randn(bs0, 2) + x1 = torch.randn(bs1, 2) + y1 = torch.randn(bs1, 2) + return [(x0, y0), (x1, y1)] + + def make_concat_batch(split_batches): + xs = [b[0] for b in split_batches] + ys = [b[1] for b in split_batches] + return (torch.cat(xs, dim=0), torch.cat(ys, dim=0)) + + def single_gpu_reference_step(model, optimizer, batch, device): + x, y = batch + x = x.to(device) + y = None if y is None else y.to(device) + + optimizer.zero_grad(set_to_none=True) + with forward_context(sample_count=len(x)) as ctx: + _ = model((x,y)) + loss = ctx.total_loss() + if loss is not None: + loss = loss + loss.backward() + + ## print all parameters (all of them, not just max or min) before step + #for name, p in model.named_parameters(): + # print(f"Before step: {name} param value: {p.detach().cpu()}") + ## print all gradients before step + #for name, p in model.named_parameters(): + # print(f"Before step: {name} grad value: {p.grad.detach().cpu() if p.grad is not None else None}") + optimizer.step() + + # ------------------------------------------------------------------ + # Test 1: plain single-GPU reference vs threaded path with 1 GPU + # ------------------------------------------------------------------ + print("Running test 1: single-GPU reference vs threaded 1-GPU") + + seed = 12345 + set_seed(seed) + ref_model_1 = ToyModel(2, 2).to("cuda:0") + ref_opt_1 = torch.optim.Adam(ref_model_1.parameters(), lr=1e-2) + + set_seed(seed) + threaded_model_1 = ToyModel(2, 2).to("cuda:0") + + + + threaded_replicas_1 = make_replicas(threaded_model_1, [torch.device("cuda:0")]) + threaded_opt_1 = torch.optim.Adam(threaded_model_1.parameters(), lr=1e-2) + + #sanity check if models are same at this stage + check_replicas_equal([ref_model_1]+ threaded_replicas_1) + print("models are the same to begin with") + + #check optimisers + + set_seed(seed + 1) + shared_batches_test1 = [] + for step in range(num_steps): + bs = 10 # + (step % 4) + x = torch.randn(bs, 2) + y = torch.randn(bs, 2) + shared_batches_test1.append((x, y)) + + for batch in shared_batches_test1: + single_gpu_reference_step(ref_model_1, ref_opt_1, batch, torch.device("cuda:0")) + + for batch in shared_batches_test1: + local_batches = clone_batches_to_device([batch], torch.device("cuda:0")) + train_step_threaded( + replicas=threaded_replicas_1, + optimizer=threaded_opt_1, + batches=local_batches, + devices=[torch.device("cuda:0")], + check_sync=True, + ) + + assert_models_close(ref_model_1, threaded_replicas_1[0], "Test 1 failed, models weights: "+str(list(ref_model_1.parameters())[0].detach().cpu()) + " vs " + str(list(threaded_replicas_1[0].parameters())[0].detach().cpu())) + print("Test 1 passed") + + # ------------------------------------------------------------------ + # Test 2: single-GPU concatenated reference vs multi-GPU split batches + # ------------------------------------------------------------------ + print("Running test 2: single-GPU concatenated reference vs 2-GPU split") + + seed = 67890 + set_seed(seed) + ref_model_2 = ToyModel(2, 2).to("cuda:0") + ref_opt_2 = torch.optim.Adam(ref_model_2.parameters(), lr=1e-2) + + set_seed(seed) + threaded_model_2 = ToyModel(2, 2).to("cuda:0") + devices_2 = [torch.device("cuda:0"), torch.device("cuda:1")] + threaded_replicas_2 = make_replicas(threaded_model_2, devices_2) + threaded_opt_2 = torch.optim.Adam(threaded_replicas_2[0].parameters(), lr=1e-2) + + set_seed(seed + 1) + split_batches_per_step = [make_split_batches(step) for step in range(num_steps)] + concat_batches_per_step = [make_concat_batch(split_batches) for split_batches in split_batches_per_step] + + print(concat_batches_per_step) + + for batch in concat_batches_per_step: + single_gpu_reference_step(ref_model_2, ref_opt_2, batch, torch.device("cuda:0")) + + for split_batches in split_batches_per_step: + local_batches = [ + (split_batches[0][0].clone().to("cuda:0"), split_batches[0][1].clone().to("cuda:0")), + (split_batches[1][0].clone().to("cuda:1"), split_batches[1][1].clone().to("cuda:1")), + ] + train_step_threaded( + replicas=threaded_replicas_2, + optimizer=threaded_opt_2, + batches=local_batches, + devices=devices_2, + check_sync=True, + ) + + assert_models_close(ref_model_2, threaded_replicas_2[0], "Test 2 failed") + print("Test 2 passed") + + print("All equivalence tests passed") + + + + +if __name__ == "__main__": + run_equivalence_tests() + exit() + replicas, optimizer, devices = example_setup(num_gpus=2) + + batches = example_batches(devices) + + print("Stepping") + + infos = train_step_threaded( + replicas=replicas, + optimizer=optimizer, + batches=batches, + devices=devices, + check_sync=True, + ) + + for i, info in enumerate(infos): + print(f"Replica {i}: batch_size={info.batch_size}, loss={info.loss_value}") diff --git a/src/djctools/module_extensions.py b/src/djctools/module_extensions.py index 25286a3..6652d87 100644 --- a/src/djctools/module_extensions.py +++ b/src/djctools/module_extensions.py @@ -1,12 +1,14 @@ # import as djctools.module_extensions from .wandb_tools import wandb_wrapper +from .threading_context import get_current_context import torch import threading import logging + # Configure logger logger = logging.getLogger(__name__) logging.basicConfig(level=logging.WARNING) # Set the default level to WARNING @@ -152,7 +154,6 @@ def __init__(self, name=None, logging_active=False, loss_active=True): instances within a given module. """ super(LossModule, self).__init__(name=name, logging_active=logging_active) - self._losses = [] # Instance-level list to store losses for this LossModule self.switch_loss_calculation(loss_active) @property @@ -163,7 +164,22 @@ def loss_active(self): def _compute_loss_and_record(self, *args, **kwargs): """Compute the loss and append to the instance's loss list.""" loss = self.compute_loss(*args, **kwargs) - self._losses.append(loss) + + if loss is None: + return None + + ctx = get_current_context() + if ctx is None: + raise RuntimeError( + f"LossModule '{self.__class__.__name__}' returned a loss but no ForwardContext is active" + ) + + ctx.add_loss(loss) + return loss + + def _no_op(self, *args, **kwargs): + """A no-op function used when loss calculation is disabled.""" + return None def compute_loss(self, *args, **kwargs): """ @@ -171,6 +187,7 @@ def compute_loss(self, *args, **kwargs): This function will be called by `forward` when the loss calculation is enabled. Must return a single scalar tensor representing the loss. + This tensor must be a mean over the batch size, not a sum, to ensure correct scaling when aggregating losses across replicas! Raises: NotImplementedError: If the subclass does not override this method. @@ -209,8 +226,9 @@ def switch_loss_calculation(self, loss_active): child.switch_loss_calculation(loss_active) def clear_losses(self): - """Clears the accumulated losses in this module's instance-level loss list.""" - self._losses.clear() + raise DeprecationWarning("This method should not be used anymore and is not needed.") + + @@ -355,42 +373,6 @@ def switch_all_losses(module : torch.nn.Module, loss_active : bool): if isinstance(child, LossModule): child.switch_loss_calculation(loss_active) -def sum_all_losses(module : torch.nn.Module): - """ - Recursively collects and sums all losses from LossModule instances within a given module. - - Args: - module (torch.nn.Module): The module to search through. - - Returns: - torch.Tensor: A single scalar tensor representing the sum of all accumulated losses. - - Note: - This method operates recursively across all levels of nested LossModule instances. - """ - if hasattr(module, 'parameters') and next(module.parameters(), None) is not None: - device = next(module.parameters()).device - else: - device = torch.device('cpu') - total_loss = torch.tensor(0.0, requires_grad=True).to(device) - - for child in module.modules(): - if isinstance(child, LossModule): - if child._losses: - total_loss = total_loss + sum([l.to(device) for l in child._losses]) - return total_loss - -def clear_all_losses(module : torch.nn.Module): - """ - Recursively clears all accumulated losses from LossModule instances within a given module. - - Args: - module (torch.nn.Module): The module to search through. - """ - for child in module.modules(): - if isinstance(child, LossModule): - child.clear_losses() - def switch_all_plotting(module: torch.nn.Module, plotting_active: bool): """ Searches through a given torch.nn.Module and applies switch_plotting to any diff --git a/src/djctools/parallel.py b/src/djctools/parallel.py new file mode 100644 index 0000000..3ad42ca --- /dev/null +++ b/src/djctools/parallel.py @@ -0,0 +1,310 @@ +import copy +from concurrent.futures import ThreadPoolExecutor +import torch +from dataclasses import dataclass +from typing import Any, List, Optional, Sequence, Tuple +import torch.nn as nn +from .threading_context import ForwardContext, forward_context + +# --------------------------------------------------------------------------- +# Replica management +# --------------------------------------------------------------------------- + +def make_replicas(master_model: nn.Module, devices: Sequence[torch.device]) -> List[nn.Module]: + """ + Create one replica per device. + Assumes master_model is already in the desired initial state. + """ + master_model.to('cpu') # ensure master starts on CPU for clean copying + replicas: List[nn.Module] = [] + + #check if devices exist and raise if not already here with proper error + for d in devices: + if d.type == 'cuda' and d.index >= torch.cuda.device_count(): + raise RuntimeError(f"Device {d} is not available. Only {torch.cuda.device_count()} CUDA devices detected.") + + for i, dev in enumerate(devices): + if i: + replica = copy.deepcopy(master_model) + else: + replica = master_model + replica.train(master_model.training) + if i > 0: + pass + #switch_all_logging(replica, False) + replicas.append(replica) + + #move all to device + for replica, dev in zip(replicas, devices): + replica.to(dev) + + sync_from_master(replicas) + + return replicas + + +# FIXME: set to non blocking at some point, and sync at the end of the function +@torch.no_grad() +def sync_from_master_no_working(replicas: Sequence[nn.Module], blocking: bool = True) -> None: + """ + Copy parameters and buffers from master to all other replicas. + replicas[0] may be master itself; it is skipped. + """ + master_params = dict(replicas[0].named_parameters()) + master_buffers = dict(replicas[0].named_buffers()) + + for replica in replicas[1:]: + replica_params = dict(replica.named_parameters()) + replica_buffers = dict(replica.named_buffers()) + + for name, p in master_params.items(): + replica_params[name].copy_(p, non_blocking = not blocking) + #replica_params[name].data.copy_(p.data, non_blocking=not blocking) + # check if they are actually the same + if not torch.equal(replica_params[name].cpu(), p.cpu()): #needs to be on same device + raise RuntimeError(f"sync_from_master: Parameter '{name}' differs between master and replica after sync_from_master: values: {p} vs {replica_params[name]}, shapes: {p.shape} vs {replica_params[name].shape}") + + for name, b in master_buffers.items(): + replica_buffers[name].copy_(b, non_blocking = not blocking) + # check if they are actually the same + if not torch.equal(replica_buffers[name].cpu(), b.cpu()): #needs to be on same device + raise RuntimeError(f"sync_from_master: Buffer '{name}' differs between master and replica after sync_from_master: values: {b} vs {replica_buffers[name]}, shapes: {b.shape} vs {replica_buffers[name].shape}") + +@torch.no_grad() +def sync_from_master(replicas: Sequence[nn.Module]) -> None: + """ + Copy parameters and buffers from master to all other replicas via CPU staging. + replicas[0] is the master. + """ + if len(replicas) <= 1: + return # nothing to do + master_params = dict(replicas[0].named_parameters()) + master_buffers = dict(replicas[0].named_buffers()) + + for replica in replicas[1:]: + replica_params = dict(replica.named_parameters()) + replica_buffers = dict(replica.named_buffers()) + + for name, p in master_params.items(): + tmp = p.detach().cpu() + replica_params[name].data.copy_(tmp.to(replica_params[name].device)) + if not torch.equal(replica_params[name].detach().cpu(), tmp): + raise RuntimeError( + f"Parameter '{name}' differs between master and replica after sync_from_master: " + f"values: {p} vs {replica_params[name]}, shapes: {p.shape} vs {replica_params[name].shape}" + ) + + for name, b in master_buffers.items(): + tmp = b.detach().cpu() + replica_buffers[name].copy_(tmp.to(replica_buffers[name].device)) + if not torch.equal(replica_buffers[name].detach().cpu(), tmp): + raise RuntimeError( + f"Buffer '{name}' differs between master and replica after sync_from_master: " + f"values: {b} vs {replica_buffers[name]}, shapes: {b.shape} vs {replica_buffers[name].shape}" + ) + +@torch.no_grad() +def check_replicas_equal(replicas: Sequence[nn.Module]) -> None: + master_params = dict(replicas[0].named_parameters()) + master_buffers = dict(replicas[0].named_buffers()) + + for i, replica in enumerate(replicas[1:], start=1): + replica_params = dict(replica.named_parameters()) + replica_buffers = dict(replica.named_buffers()) + + for name, p in master_params.items(): + other = replica_params[name] + if not torch.equal(p.detach().cpu(), other.detach().cpu()): + raise RuntimeError( + f"Replica {i} differs from master parameter '{name}': " + f"values: {p} vs {other}, shapes: {p.shape} vs {other.shape}" + ) + + for name, b in master_buffers.items(): + other = replica_buffers[name] + if not torch.equal(b.detach().cpu(), other.detach().cpu()): + raise RuntimeError( + f"Replica {i} differs from master buffer '{name}': " + f"values: {b} vs {other}, shapes: {b.shape} vs {other.shape}" + ) + +# --------------------------------------------------------------------------- +# Local worker step +# --------------------------------------------------------------------------- + +@dataclass +class LocalStepInfo: + batch_size: int + loss_value: Optional[float] + ctx: ForwardContext + + +def _infer_batch_size(x: Any) -> int: + if hasattr(x, "__len__"): + return len(x) + raise TypeError("Cannot infer batch size from input batch") + + +def local_worker(model: nn.Module, batch: Tuple[Any, Any], device: torch.device) -> LocalStepInfo: + x, y = batch + + bs = _infer_batch_size(x) + + model.zero_grad(set_to_none=True) + + with forward_context(sample_count=bs) as ctx: + _ = model(batch) + loss = ctx.total_loss() + + if loss is not None: + loss.backward() + + #print gradients for all parameters in this local worker + #for name, p in model.named_parameters(): + # print(f"Local worker on device {device}: parameter '{name}' grad value: {p.grad.detach().cpu() if p.grad is not None else None}") + + return LocalStepInfo( + batch_size=bs, + loss_value=None if loss is None else float(loss.detach().cpu()), + ctx=ctx, + ) + + +def threaded_local_steps( + replicas: Sequence[nn.Module], + batches: Sequence[Tuple[Any, Any]], + devices: Sequence[torch.device], + max_workers: Optional[int] = None, +) -> List[LocalStepInfo]: + assert len(replicas) == len(batches) == len(devices) + + with ThreadPoolExecutor(max_workers=max_workers or len(devices)) as pool: + futures = [ + pool.submit(local_worker, model, batch, dev) + for model, batch, dev in zip(replicas, batches, devices) + ] + infos = [f.result() for f in futures] + + return infos + + +# --------------------------------------------------------------------------- +# Gradient aggregation and optimizer step +# --------------------------------------------------------------------------- + + +def create_splitbatch_scalers(batch_sizes: Sequence[int]) -> List[float]: + total_bs = sum(batch_sizes) + if total_bs <= 0: + raise RuntimeError("Total batch size must be > 0") + scalers = [bs / total_bs for bs in batch_sizes] + return scalers + +@torch.no_grad() +def aggregate_grads_to_master( + replicas: Sequence[nn.Module], + batch_sizes: Sequence[int], + scalers: Optional[Sequence[float]] = None, +) -> None: + """ + Aggregate gradients from all replicas onto replicas[0]. + + Assumes: + - local losses correspond to sample-summed contributions + - final desired semantics are global mean over all samples + """ + if len(replicas) != len(batch_sizes): + raise ValueError("Length of replicas and batch_sizes must match") + if len(replicas) == 1: + return [1.0] # nothing to do + + master = replicas[0] + total_bs = sum(batch_sizes) + if total_bs <= 0: + raise RuntimeError("Total batch size must be > 0") + + if scalers is None: + scalers = create_splitbatch_scalers(batch_sizes) + + master_params = list(master.parameters()) + replica_params = [list(m.parameters()) for m in replicas] + + for p_idx, master_p in enumerate(master_params): + acc = None + + for r_idx in range(len(replicas)): + g = replica_params[r_idx][p_idx].grad + if g is None: + continue + + g_on_master = g.detach().cpu().to(master_p.device) #room for optimisation here if peer access is available + g_on_master.mul_(scalers[r_idx]) + + if acc is None: + acc = g_on_master.clone() + else: + acc.add_(g_on_master) + + if acc is None: + continue + + master_p.grad = acc + + + +def master_step( optimizer: torch.optim.Optimizer) -> None: + optimizer.step() + optimizer.zero_grad(set_to_none=True) + + +# --------------------------------------------------------------------------- +# One full multi-GPU train step; also works for single GPU or CPU +# --------------------------------------------------------------------------- + +def train_step_threaded( + replicas: Sequence[nn.Module], + optimizer: torch.optim.Optimizer, + batches: Sequence[Tuple[Any, Any]], + devices: Sequence[torch.device], + check_sync: bool = False, +) -> List[LocalStepInfo]: + infos = threaded_local_steps(replicas, batches, devices) + batch_sizes = [info.batch_size for info in infos] + + scalers = create_splitbatch_scalers(batch_sizes) + #torch.cuda.synchronize() + aggregate_grads_to_master(replicas, batch_sizes, scalers) + #print master parameters + #for name, p in replicas[0].named_parameters(): + #print(f"Before master step: {name} param value: {p.detach().cpu()}") + #print(f"Before master step: {name} grad value: {p.grad.detach().cpu() if p.grad is not None else None}") + master_step(optimizer) + #torch.cuda.synchronize() + sync_from_master(replicas) + torch.cuda.synchronize() #check if needed + + if check_sync: + check_replicas_equal(replicas) + + #scale the info loss values accordingly + for info, scaler in zip(infos, scalers): + if info.loss_value is not None: + info.loss_value *= scaler + + return infos + + +def val_step_threaded( + replicas: Sequence[nn.Module], + batches: Sequence[Tuple[Any, Any]], + devices: Sequence[torch.device], +) -> List[LocalStepInfo]: + infos = threaded_local_steps(replicas, batches, devices) + batch_sizes = [info.batch_size for info in infos] + scalers = scalers = create_splitbatch_scalers(batch_sizes) + for info, scaler in zip(infos, scalers): + if info.loss_value is not None: + info.loss_value *= scaler + + return infos + diff --git a/src/djctools/threading_context.py b/src/djctools/threading_context.py new file mode 100644 index 0000000..b96bf0c --- /dev/null +++ b/src/djctools/threading_context.py @@ -0,0 +1,53 @@ + +import torch +import threading +from contextlib import contextmanager +# --------------------------------------------------------------------------- +# Forward context +# --------------------------------------------------------------------------- + +class ForwardContext: + def __init__(self, sample_count=0): + self.sample_count = sample_count + self.losses = [] + self._closed = False + + def add_loss(self, value): + if self._closed: + raise RuntimeError("Cannot add loss to closed ForwardContext") + if not isinstance(value, torch.Tensor): + raise TypeError("Loss must be a torch.Tensor") + if value.dim() != 0: + raise ValueError("Loss must be a scalar tensor") + self.losses.append(value) + + def total_loss(self): + if not self.losses: + return None + out = self.losses[0] + for l in self.losses[1:]: + out = out + l + return out + + def close(self): + self._closed = True + + +_tls = threading.local() + +def get_current_context(): + return getattr(_tls, "ctx", None) + + +@contextmanager +def forward_context(sample_count=0): + if getattr(_tls, "ctx", None) is not None: + raise RuntimeError("Nested ForwardContext is not allowed") + + ctx = ForwardContext(sample_count=sample_count) + _tls.ctx = ctx + try: + yield ctx + finally: + _tls.ctx = None + ctx.close() diff --git a/src/djctools/training.py b/src/djctools/training.py index 790ce1f..07b0c92 100644 --- a/src/djctools/training.py +++ b/src/djctools/training.py @@ -1,69 +1,28 @@ # import as from djctools.training import torch -from torch.nn.parallel import DistributedDataParallel as DDP -from .module_extensions import sum_all_losses, clear_all_losses, flush_all_plotting +from .module_extensions import flush_all_plotting from .wandb_tools import wandb_wrapper import numpy as np import os from torch.nn import DataParallel from typing import Any, Dict, Optional, Sequence, Tuple, Union - -class _CustomDataParallel(DataParallel): - def scatter( - self, - inputs: list, - kwargs: Optional[Dict[str, Any]], - device_ids: Sequence[Union[int, torch.device]], - ) -> Any: - """ - Custom scatter method that handles the input structures provided by the Trainer class, - such as lists of dictionaries or lists of lists of tensors. - - Args: - inputs (Tuple[Any, ...]): The input to be scattered - here this is a tuple with one entry. - The latter entry is the list mentioned above. - kwargs (Optional[Dict[str, Any]]): Keyword arguments. - device_ids (Sequence[Union[int, torch.device]]): Target devices. - - Returns: - Tuple of scattered inputs and kwargs for each device. - """ - # Implement custom logic to split `inputs` and `kwargs` based on your structure. - # Example: if inputs is a list of dicts, scatter each dict entry to the devices. - - # Example pseudo-code: - scattered_inputs = [] - scattered_kwargs = [] - #print('inputs len',len(inputs)) - #print('inputs types',[type(i) for i in inputs]) - ##nested - #print('inputs[0]',[type(i) for i in inputs[0]]) - - for i, device_id in enumerate(device_ids): - # Create device-specific slices of the input. - device_input = (inputs[0][i],) # ensure each replica receives a tuple of inputs - device_kwargs = kwargs if kwargs is not None else {} - - scattered_inputs.append(device_input) - scattered_kwargs.append(device_kwargs) - - return tuple(scattered_inputs), tuple(scattered_kwargs) +from .parallel import make_replicas, check_replicas_equal, train_step_threaded, val_step_threaded class Trainer: """ - Trainer class for multi-GPU training using PyTorch Distributed Data Parallel (DDP). + Trainer class for multi-GPU training using custom parallel utilities. - This Trainer class handles the initialization of DDP, manual batch distribution, + This Trainer class handles the initialization, manual batch distribution, and model synchronization across multiple GPUs, allowing flexibility for complex data structures and control over data loading. Compatible with both single and multi-GPU configurations, and can fall back to CPU if no GPU is available or `num_gpus=0` is specified. Attributes: - model (torch.nn.Module): The main model for training, wrapped in DDP if using multiple GPUs. + model (torch.nn.Module): The main model for training. optimizer (torch.optim.Optimizer): The optimizer for updating model parameters. num_gpus (int): Number of GPUs to use for training. Set to 0 for CPU training. device_ids (list of int): List of GPU device IDs to use for training. Defaults to `[0, 1, ..., num_gpus-1]`. @@ -85,14 +44,10 @@ class Trainer: Executes the validation loop, computing and logging validation losses. Runs without gradient updates. save_model(filepath): - Saves the model weights to a file. For DDP-wrapped models, uses `model.module.state_dict()`. + Saves the model to a file. load_model(filepath): - Loads model weights from a file. For DDP-wrapped models, loads weights into `model.module`. - - cleanup(): - Cleans up the DDP process group after training. Recommended when using multiple training sessions - in a single script to release GPU resources properly. + Loads model from a file. train_batch_callback(model, batch_number, batch_data): Callback function that is called after each batch is processed during training. @@ -120,19 +75,14 @@ class Trainer: >>> trainer.val_loop(val_loader) >>> trainer.save_model("model_weights.pth") - >>> trainer.cleanup() # Call when using multi-GPU to release resources Notes: ------ - The `Trainer` class assumes single-process execution. Each batch is moved manually to the correct device, allowing full control over batch distribution. - - For DDP, the model is wrapped with `DistributedDataParallel`, which handles gradient synchronization and - weight updates across GPUs. Manual gradient averaging is not required. - This class is optimized for cases where each batch may consist of complex nested structures (e.g., lists of dictionaries or tuples). It can be used with both standard PyTorch data loaders and custom data iterators. - - `DistributedDataParallel` uses `nccl` backend by default for multi-GPU setups. If running on a single GPU - or CPU, DDP is bypassed, and the model is trained in a standard non-parallel setup. """ def __init__(self, model, optimizer, num_gpus=1, device_ids=None, verbose_level=0): @@ -151,12 +101,54 @@ def __init__(self, model, optimizer, num_gpus=1, device_ids=None, verbose_level= self.num_gpus = 0 print("Warning: CUDA not available or num_gpus=0. Using CPU.") - # Move model to device and wrap with DDP - model.to(self.device) - self.model = _CustomDataParallel(model, device_ids=self.device_ids) if len(self.device_ids) > 1 else model + #make devices a Sequence of torch.device objects + self.devices = [torch.device(d) for d in self.devices] + + # if 'model' is a valid file path and not a torch.nn.Module, load the model from the file + if isinstance(model, str) and os.path.isfile(model): + print(f"Loading model from file: {model}") + model = torch.load(model) + elif isinstance(model, torch.nn.Module): + pass + else: + raise ValueError("Model must be either a torch.nn.Module or a valid file path to a saved model.") + + print(f"Creating replicas using devices: {self.devices}") + + self.model_replicas = make_replicas(model, self.devices) self.optimizer = optimizer self.verbose_level = verbose_level + def save_model(self, filepath): + """ + Saves the model (not just weights) to a file. + + Args: + filepath (str): The path to the file where the model weights will be saved. + """ + torch.save(self.model_replicas[0], filepath) #the first replica is always the master + + def load_model(self, filepath): + """ + Loads model from a file. slim wrapper + + Args: + filepath (str): The path to the file from which the model weights will be loaded. + """ + loaded_model = torch.load(filepath, map_location=self.device) + self.model_replicas = make_replicas(loaded_model, self.devices) + + @property + def model(self): + """ + Returns the master model (the first replica). + + Returns: + torch.nn.Module: The master model. + """ + return self.model_replicas[0] + + def _data_to_device(self, data, device): """ Moves data to the specified device. @@ -206,32 +198,27 @@ def train_loop(self, train_loader): Args: train_loader (DataLoader): The data loader for training data. """ - self.model.train() + for r in self.model_replicas: + r.train() # Set all replicas to training mode + + self.optimizer.zero_grad() # Zero gradients before starting the epoch data_iterator = iter(train_loader) batch_idx = 0 while True: - self.optimizer.zero_grad() batches = self.create_batches(data_iterator) if not batches: break # End of epoch - if len(batches) == 1: - batches = batches[0] - - outputs = self.model(batches) # DataParallel handles passing data to each GPU - loss = sum_all_losses(self.model) - - loss.backward() - self.optimizer.step() - clear_all_losses(self.model) - flush_all_plotting(self.model) - self.train_batch_callback(self.model, batch_idx, batches) + info = train_step_threaded(self.model_replicas, self.optimizer, batches, self.devices, check_sync=False) + flush_all_plotting(self.model_replicas[0]) + self.train_batch_callback(self.model_replicas[0], batch_idx, batches) + loss = sum([i.loss_value for i in info if i.loss_value is not None])#ok, they have been scaled for this to work before # Logging and printing - wandb_wrapper.log("total_loss", loss.item()) + wandb_wrapper.log("total_loss", loss) if self.verbose_level > 0 and batch_idx % 10 == 0: - print(f'Batch {batch_idx}: Loss {loss.item()}') + print(f'Batch {batch_idx}: Loss {loss}') batch_idx += 1 wandb_wrapper.flush() @@ -242,7 +229,8 @@ def val_loop(self, val_loader): Args: val_loader (DataLoader): The data loader for validation data. """ - self.model.eval() + for r in self.model_replicas: + r.eval() # Set all replicas to evaluation mode data_iterator = iter(val_loader) batch_idx = 0 @@ -251,37 +239,36 @@ def val_loop(self, val_loader): batches = self.create_batches(data_iterator) if not batches: break # End of epoch - if len(batches) == 1: #no multi gpu - batches = batches[0] - - outputs = self.model(batches) # DataParallel handles passing data to each GPU - loss = sum_all_losses(self.model) - clear_all_losses(self.model) - flush_all_plotting(self.model) - self.val_batch_callback(self.model, batch_idx, batches) + + info = val_step_threaded(self.model_replicas, batches, self.devices) + flush_all_plotting(self.model_replicas[0]) + self.val_batch_callback(self.model_replicas[0], batch_idx, batches) + loss = sum([i.loss_value for i in info if i.loss_value is not None]) #ok, they have been scaled for this to work before # Logging and printing - wandb_wrapper.log("total_loss", loss.item()) + wandb_wrapper.log("total_loss", loss) if self.verbose_level > 0 and batch_idx % 10 == 0: - print(f'Validation Batch {batch_idx}: Loss {loss.item()}') + print(f'Validation Batch {batch_idx}: Loss {loss}') batch_idx += 1 wandb_wrapper.flush(prefix="val_") - def save_model(self, filepath): - """Saves the model weights to a file.""" - torch.save(self.model.module.state_dict() if hasattr(self.model, "module") else self.model.state_dict(), filepath) - - def load_model(self, filepath): - """Loads model weights from a file.""" - state_dict = torch.load(filepath) - self.model.module.load_state_dict(state_dict) if hasattr(self.model, "module") else self.model.load_state_dict(state_dict) - - def cleanup(self): - pass - def train_batch_callback(self, model, batch_number, batch_data): + """ + Callback function that is called after each batch is processed during training. + The function should take the model, the batch number, and the batch data as arguments. + This function can be used to perform custom operations on the model or the data after each batch + and should be implemented by the user through inheritance. Please do not use for logging purposes, + use the wandb_wrapper.log() function instead. + """ pass def val_batch_callback(self, model, batch_number, batch_data): - pass + """ + Callback function that is called after each batch is processed during validation. + The function should take the model, the batch number, and the batch data as arguments. + This function can be used to perform custom operations on the model or the data after each batch + and should be implemented by the user through inheritance. Please do not use for logging purposes, + use the wandb_wrapper.log() function instead. + """ + pass \ No newline at end of file diff --git a/tests/test_loss_module.py b/tests/test_loss_module.py index 9c95459..4bab6b6 100644 --- a/tests/test_loss_module.py +++ b/tests/test_loss_module.py @@ -1,6 +1,6 @@ import unittest import torch -from djctools.module_extensions import LossModule, sum_all_losses, clear_all_losses, switch_all_losses, switch_all_logging +from djctools.module_extensions import LossModule, switch_all_losses, switch_all_logging # Define a simple custom loss class for testing purposes class TestLossModule(LossModule): @@ -72,42 +72,7 @@ def test_switch_loss_calculation(self): except: self.fail("Model should work without targets") - def test_sum_all_losses(self): - """Test that all accumulated losses are correctly summed.""" - predictions = torch.randn(10, 5) - targets = torch.randn(10, 5) - - # Compute losses for both modules - self.model.loss1(predictions, targets) - self.model.loss2(predictions, targets) - - # Sum all losses - total_loss = sum_all_losses(self.model) - clear_all_losses(self.model) - - # Check that total_loss is a single scalar tensor - self.assertTrue(isinstance(total_loss, torch.Tensor)) - self.assertEqual(total_loss.shape, torch.Size([])) - self.assertGreater(total_loss.item(), 0) - - def test_clear_all_losses(self): - """Test that all accumulated losses are cleared correctly.""" - predictions = torch.randn(10, 5) - targets = torch.randn(10, 5) - - # Compute losses for both modules - self.model(predictions, targets) - - # Ensure there are losses - self.assertEqual(len(self.model.loss1._losses), 1) - self.assertEqual(len(self.model.loss2._losses), 1) - - # Clear all losses - clear_all_losses(self.model) - - # Verify losses are cleared - self.assertEqual(len(self.model.loss1._losses), 0) - self.assertEqual(len(self.model.loss2._losses), 0) + def test_loss_active_property(self): """Test that the loss_active property correctly reflects the module's active state.""" diff --git a/tests/test_trainer.py b/tests/test_trainer.py index e17f578..923c069 100644 --- a/tests/test_trainer.py +++ b/tests/test_trainer.py @@ -103,7 +103,7 @@ def test_single_gpu_training(self): self.trainer.train_loop(self.train_loader) def test_multi_gpu_training(self): - self._setUp(num_gpus=3) + self._setUp(num_gpus=2) """Test training loop on a single GPU or CPU.""" self.trainer.train_loop(self.train_loader) self.trainer.train_loop(self.train_loader) @@ -115,7 +115,7 @@ def test_validation_loop(self): self.trainer.val_loop(self.val_loader) def test_multi_gpu_validation(self): - self._setUp(num_gpus=3) + self._setUp(num_gpus=2) """Test validation loop execution and logging of validation losses.""" self.trainer.val_loop(self.val_loader) self.trainer.val_loop(self.val_loader) @@ -129,8 +129,9 @@ def test_save_and_load_model(self): # Create a new instance and load weights model2 = SimpleModel() optimizer2 = optim.SGD(model2.parameters(), lr=0.01) - trainer2 = Trainer(model=model2, optimizer=optimizer2, num_gpus=0, verbose_level=1) - trainer2.load_model(filepath) + trainer2 = Trainer(model=filepath, optimizer=optimizer2, num_gpus=0, verbose_level=1) + + model2 = trainer2.model # Get the loaded model from the trainer # Check if weights are loaded correctly for param1, param2 in zip(self.model.parameters(), model2.parameters()): @@ -138,6 +139,9 @@ def test_save_and_load_model(self): def do_test_with_djcdata(self, num_gpus): + #check if num gpus exists otherwise skip + if num_gpus > torch.cuda.device_count(): + self.skipTest(f"Not enough GPUs available for this test: required {num_gpus}, available {torch.cuda.device_count()}") self._setUp(num_gpus=num_gpus) #overwrite the data loaders from djcdata import TrainDataGenerator