diff --git a/.github/workflows/linters.yaml b/.github/workflows/linters.yaml new file mode 100644 index 00000000..98b727ed --- /dev/null +++ b/.github/workflows/linters.yaml @@ -0,0 +1,24 @@ +name: Lint Python + +on: [push, pull_request] + +jobs: + build: + runs-on: ubuntu-latest + strategy: + matrix: + python-version: [3.8, 3.9] + + steps: + - uses: actions/checkout@v2 + - name: Set up Python ${{ matrix.python-version }} + uses: actions/setup-python@v2 + with: + python-version: ${{ matrix.python-version }} + - name: Install dependencies + run: | + python -m pip install --upgrade pip + pip install pydocstyle + - name: Docstyle linting + run: | + pydocstyle --convention=google --add-ignore=D200,D210,D212,D415 \ No newline at end of file diff --git a/satflow/__init__.py b/satflow/__init__.py index 58f3ace6..b653c901 100644 --- a/satflow/__init__.py +++ b/satflow/__init__.py @@ -1 +1,2 @@ +"""satflow package""" from .version import __version__ diff --git a/satflow/baseline/__init__.py b/satflow/baseline/__init__.py index e69de29b..d3739e07 100644 --- a/satflow/baseline/__init__.py +++ b/satflow/baseline/__init__.py @@ -0,0 +1 @@ +"""Evaluation of baseline models""" \ No newline at end of file diff --git a/satflow/baseline/optical_flow.py b/satflow/baseline/optical_flow.py index bd78dd44..ab5e3874 100644 --- a/satflow/baseline/optical_flow.py +++ b/satflow/baseline/optical_flow.py @@ -1,3 +1,4 @@ +"""Evaluation of baseline models""" import cv2 from satflow.data.datasets import OpticalFlowDataset, SatFlowDataset import webdataset as wds @@ -7,6 +8,7 @@ def load_config(config_file): + """Load a config file from disk""" with open(config_file, "r") as cfg: return yaml.load(cfg, Loader=yaml.FullLoader)["config"] @@ -21,6 +23,15 @@ def load_config(config_file): def warp_flow(img, flow): + """ + Get the previous image by inverting the optical flow and applying it to the current image + + Args: + img: the current image + flow: the optical flow + + Returns: the resulting image + """ h, w = flow.shape[:2] flow = -flow flow[:, :, 0] += np.arange(w) diff --git a/satflow/core/__init__.py b/satflow/core/__init__.py index e69de29b..cb3122cd 100644 --- a/satflow/core/__init__.py +++ b/satflow/core/__init__.py @@ -0,0 +1 @@ +"""Core utility functions""" \ No newline at end of file diff --git a/satflow/core/utils.py b/satflow/core/utils.py index aeed815b..b726298b 100644 --- a/satflow/core/utils.py +++ b/satflow/core/utils.py @@ -1,3 +1,4 @@ +"""Core utility functions""" import logging import typing as Dict from nowcasting_dataset.config.load import load_yaml_configuration @@ -6,12 +7,23 @@ def load_config(file_path: str) -> Dict: + """Load yaml config file from file path""" with open(file_path, "r") as f: config = yaml.load(f) return config def make_logger(name: str, level=logging.DEBUG) -> logging.Logger: + """ + Get a named logger at a specified level + + Args: + name: name of the logger + level: level of the logger. Default is logging.DEBUG + + Returns: + The logger + """ logger = logging.getLogger(name) logger.setLevel(level=level) return logger @@ -30,7 +42,6 @@ def make_logger(name: str, level=logging.DEBUG) -> logging.Logger: def get_logger(name=__name__, level=logging.INFO) -> logging.Logger: """Initializes multi-GPU-friendly python logger.""" - logger = logging.getLogger(name) logger.setLevel(level) @@ -43,19 +54,21 @@ def get_logger(name=__name__, level=logging.INFO) -> logging.Logger: def extras(config: DictConfig) -> None: - """A couple of optional utilities, controlled by main config file: - - disabling warnings - - easier access to debug mode - - forcing debug friendly configuration - - forcing multi-gpu friendly configuration - - Ensure correct number of timesteps/etc for all of them + """ + A couple of optional utilities, controlled by main config file + + Utilities include + - disabling warnings + - easier access to debug mode + - forcing debug friendly configuration + - forcing multi-gpu friendly configuration + - Ensure correct number of timesteps/etc for all of them Modifies DictConfig in place. Args: config (DictConfig): Configuration composed by Hydra. """ - log = get_logger() # enable adding new keys to config @@ -150,10 +163,9 @@ def print_config( Args: config (DictConfig): Configuration composed by Hydra. fields (Sequence[str], optional): Determines which main fields from config will - be printed and in what order. + be printed and in what order. resolve (bool, optional): Whether to resolve reference fields of DictConfig. """ - style = "dim" tree = rich.tree.Tree(":gear: CONFIG", style=style, guide_style=style) @@ -180,12 +192,17 @@ def log_hyperparameters( model: pl.LightningModule, trainer: pl.Trainer, ) -> None: - """This method controls which parameters from Hydra config are saved by Lightning loggers. + """ + This method controls which parameters from Hydra config are saved by Lightning loggers. Additionaly saves: - number of trainable model parameters - """ + Args: + config (DictConfig): Configuration composed by Hydra. + model (pl.LightningModule): the model with parameters to save + trainer (pl.Trainer): the trainer with hyperparams + """ hparams = {} # choose which parts of hydra config will be saved to loggers diff --git a/satflow/data/__init__.py b/satflow/data/__init__.py index e69de29b..bddc332e 100644 --- a/satflow/data/__init__.py +++ b/satflow/data/__init__.py @@ -0,0 +1 @@ +"""satflow data""" \ No newline at end of file diff --git a/satflow/data/datamodules.py b/satflow/data/datamodules.py index 46a8e2f3..5adc067e 100644 --- a/satflow/data/datamodules.py +++ b/satflow/data/datamodules.py @@ -1,3 +1,4 @@ +"""A DataModule that encapsulates the steps to process the data""" import os from nowcasting_dataset.dataset.datasets import worker_init_fn from nowcasting_dataset.config.load import load_yaml_configuration @@ -27,6 +28,8 @@ class SatFlowDataModule(LightningDataModule): """ + A SatFlow DataModule + Example of LightningDataModule for NETCDF dataset. A DataModule implements 5 key methods: - prepare_data (things to do on 1 GPU/TPU, not on every GPU/TPU in distributed mode) @@ -67,7 +70,21 @@ def __init__( forecast_minutes: Optional[int] = None, ): """ - fake_data: random data is created and used instead. This is useful for testing + Initialize a satflow DataModule + + Args: + temp_path: a file path to store temporary data + n_train_data: default is 24900 + n_val_data: default is 1000 + cloud: name of cloud provider. Default is "aws". + num_workers: default is 8 + pin_memory: default is true + configuration_filename: a file path + fake_data: random data is created and used instead. This is useful for testing. + Default is false. + required_keys: tuple or list of keys required in the example for it to be considered usable + history_minutes: how many past minutes of data to use, if subsetting the batch. Default is None. + forecast_minutes: how many future minutes of data to use, if reducing the amount of forecast time. Default is None. """ super().__init__() @@ -95,6 +112,7 @@ def __init__( ) def train_dataloader(self): + """A data loader for the training data""" if self.fake_data: train_dataset = FakeDataset( history_minutes=self.history_minutes, forecast_minutes=self.forecast_minutes @@ -114,6 +132,7 @@ def train_dataloader(self): return torch.utils.data.DataLoader(train_dataset, **self.dataloader_config) def val_dataloader(self): + """A data loader for the validation data""" if self.fake_data: val_dataset = FakeDataset( history_minutes=self.history_minutes, forecast_minutes=self.forecast_minutes @@ -133,6 +152,7 @@ def val_dataloader(self): return torch.utils.data.DataLoader(val_dataset, **self.dataloader_config) def test_dataloader(self): + """A data loader for the testing data""" if self.fake_data: test_dataset = FakeDataset( history_minutes=self.history_minutes, forecast_minutes=self.forecast_minutes @@ -154,7 +174,7 @@ def test_dataloader(self): class FakeDataset(torch.utils.data.Dataset): - """Fake dataset.""" + """Fake dataset with random data, useful for testing.""" def __init__( self, @@ -166,6 +186,7 @@ def __init__( history_minutes=30, forecast_minutes=30, ): + """Initialize a fake dataset""" self.batch_size = batch_size if history_minutes is None or forecast_minutes is None: history_minutes = 30 # Half an hour @@ -179,13 +200,15 @@ def __init__( self.length = length def __len__(self): + """Length of dataset""" return self.length def per_worker_init(self, worker_id: int): + """Not implemented""" pass def __getitem__(self, idx): - + """Get data at the index""" x = { SATELLITE_DATA: torch.randn( self.batch_size, self.seq_length, self.width, self.height, self.number_sat_channels diff --git a/satflow/data/datasets.py b/satflow/data/datasets.py index 7717b608..61fe1600 100644 --- a/satflow/data/datasets.py +++ b/satflow/data/datasets.py @@ -1,3 +1,4 @@ +"""SatFlowDataset""" from typing import Tuple, Union, List, Optional import numpy as np @@ -18,9 +19,7 @@ class SatFlowDataset(NetCDFDataset): - """Loads data saved by the `prepare_ml_training_data.py` script. - Adapted from predict_pv_yield - """ + """Loads data saved by the `prepare_ml_training_data.py` script. Adapted from predict_pv_yield""" def __init__( self, @@ -45,13 +44,18 @@ def __init__( combine_inputs: bool = False, ): """ + Initialize SatFlowDataSet + Args: - n_batches: Number of batches available on disk. - src_path: The full path (including 'gs://') to the data on - Google Cloud storage. - tmp_path: The full path to the local temporary directory - (on a local filesystem). - batch_size: Batch size, if requested, will subset data along batch dimension + n_batches: Number of batches available on disk. + src_path: The full path (including 'gs://') to the data on Google Cloud storage. + tmp_path: The full path to the local temporary directory (on a local filesystem). + configuration: configuration values + cloud: name of cloud provider. Default is "gcp". + required_keys: Tuple or list of keys required in the example for it to be considered usable + history_minutes: How many past minutes of data to use, if subsetting the batch. Default is 30. + forecast_minutes: How many future minutes of data to use, if reducing the amount of forecast time. Default is 60. + combine_inputs: Default is False. """ super().__init__( n_batches, @@ -69,6 +73,7 @@ def __init__( self.current_timestep_index = (history_minutes // 5) + 1 def __getitem__(self, batch_idx: int): + """Get data at the index""" batch = super().__getitem__(batch_idx) # Need to partition out past and future sat images here, along with the rest of the data diff --git a/satflow/data/utils/__init__.py b/satflow/data/utils/__init__.py index e69de29b..7c84b20a 100644 --- a/satflow/data/utils/__init__.py +++ b/satflow/data/utils/__init__.py @@ -0,0 +1 @@ +"""Data utility functions""" \ No newline at end of file diff --git a/satflow/data/utils/utils.py b/satflow/data/utils/utils.py index bf896524..61de5d49 100644 --- a/satflow/data/utils/utils.py +++ b/satflow/data/utils/utils.py @@ -1,3 +1,4 @@ +"""Utility functions for loading and processing data""" import datetime import io import re @@ -17,9 +18,9 @@ def eumetsat_filename_to_datetime(inner_tar_name): - """Takes a file from the EUMETSAT API and returns - the date and time part of the filename""" - + """ + Takes a file from the EUMETSAT API and returns the date and time part of the filename + """ p = re.compile("^MSG[23]-SEVI-MSG15-0100-NA-(\d*)\.") title_match = p.match(inner_tar_name) date_str = title_match.group(1) @@ -27,13 +28,17 @@ def eumetsat_filename_to_datetime(inner_tar_name): def eumetsat_name_to_datetime(filename: str): + """Parse eumetsat name as a datetime""" date_str = filename.split("0100-0100-")[-1].split(".")[0] return datetime.datetime.strptime(date_str, "%Y%m%d%H%M%S") def retrieve_pixel_value(geo_coord, data_source): - """Return floating-point value that corresponds to given point. - Taken from https://gis.stackexchange.com/questions/221292/retrieve-pixel-value-with-geographic-coordinate-as-input-with-gdal""" + """ + Return floating-point value that corresponds to given point. + + Taken from https://gis.stackexchange.com/questions/221292/retrieve-pixel-value-with-geographic-coordinate-as-input-with-gdal + """ x, y = geo_coord[0], geo_coord[1] forward_transform = affine.Affine.from_gdal(*data_source.GetGeoTransform()) reverse_transform = ~forward_transform @@ -68,13 +73,14 @@ def map_satellite_to_mercator( ): """ Opens, transforms to Transverse Mercator over Europe, and optionally saves it to files on disk. - :param native_satellite: - :param grib_files: - :param bufr_files: - :param bands: - :param save_scene: - :param save_loc: Save location - :return: + + Args: + native_satellite: file path. Default is None. + grib_files: file path. Default is None. + bufr_files: file path. Default is None. + bands: list of bands to load + save_scene: name of the writer to use when writing data to disk. Default is "geotiff". + save_loc: Save location """ if not _SAT_LIBS: raise EnvironmentError("Pyresample or Satpy are not installed, please install them first") @@ -91,7 +97,7 @@ def map_satellite_to_mercator( # By default resamples to 3km, as thats the native resolution of all bands other than HRV scene = scene.resample(areas[0]) if save_loc is not None: - # Now the relvant data is all together, just need to save it somehow, or return it to the calling process + # Now the relevant data is all together, just need to save it somehow, or return it to the calling process scene.save_datasets(writer=save_scene, base_dir=save_loc, enhance=False) return scene @@ -106,6 +112,7 @@ def create_time_layer(dt: datetime.datetime, shape): def load_np(data): + """Load data from binary stream into numpy""" import numpy.lib.format stream = io.BytesIO(data) @@ -123,10 +130,13 @@ def create_pixel_coord_layers(x_dim: int, y_dim: int, with_r: bool = False) -> n """ Creates Coord layer for CoordConv model - :param x_dim: size of x dimension for output - :param y_dim: size of y dimension for output - :param with_r: Whether to include polar coordinates from center - :return: (2, x_dim, y_dim) or (3, x_dim, y_dim) array of the pixel coordinates + Args: + x_dim: size of x dimension for output + y_dim: size of y dimension for output + with_r: Whether to include polar coordinates from center + + Returns: + (2, x_dim, y_dim) or (3, x_dim, y_dim) array of the pixel coordinates """ xx_ones = np.ones([1, x_dim], dtype=np.int32) xx_ones = np.expand_dims(xx_ones, -1) @@ -162,14 +172,17 @@ def create_pixel_coord_layers(x_dim: int, y_dim: int, with_r: bool = False) -> n def check_channels(config: dict) -> int: """ + Determine the number of channels needed + Checks the number of channels needed per timestep, to use for preallocating the numpy array Is not the same as the one for training, as that includes the number of channels after the array is partly flattened + Args: - config: + config: configuration values Returns: - + The number of channels """ channels = len(config.get("bands", [])) channels = channels + 1 if config.get("use_mask", False) else channels @@ -197,5 +210,6 @@ def crop_center(img: np.ndarray, cropx: int, cropy: int) -> np.ndarray: def load_config(config_file): + """Load a config file from a file path""" with open(config_file, "r") as cfg: return yaml.load(cfg, Loader=yaml.FullLoader)["config"] diff --git a/satflow/examples/metnet_example.py b/satflow/examples/metnet_example.py index 69647745..f0b0fc44 100644 --- a/satflow/examples/metnet_example.py +++ b/satflow/examples/metnet_example.py @@ -1,9 +1,16 @@ +"""Example of metnet model""" from satflow.models import LitMetNet import torch import urllib.request def get_input_target(number: int): + """ + Load a single input + + Args: + number: input number + """ url = f"https://github.com/openclimatefix/satflow/releases/download/v0.0.3/input_{number}.pth" filename, headers = urllib.request.urlretrieve(url, filename=f"input_{number}.pth") input_data = torch.load(filename) diff --git a/satflow/experiments/__init__.py b/satflow/experiments/__init__.py index e69de29b..bce8636d 100644 --- a/satflow/experiments/__init__.py +++ b/satflow/experiments/__init__.py @@ -0,0 +1 @@ +"""Training pipeline for satflow models""" \ No newline at end of file diff --git a/satflow/experiments/train.py b/satflow/experiments/train.py index fa41b13a..c3ae821e 100644 --- a/satflow/experiments/train.py +++ b/satflow/experiments/train.py @@ -1,3 +1,4 @@ +"""The training pipeline for a satflow model""" from typing import List, Optional import hydra @@ -19,8 +20,7 @@ def train(config: DictConfig) -> Optional[float]: - """Contains training pipeline. - Instantiates all PyTorch Lightning objects from config. + """Contains training pipeline. Instantiates all PyTorch Lightning objects from config. Args: config (DictConfig): Configuration composed by Hydra. @@ -28,7 +28,6 @@ def train(config: DictConfig) -> Optional[float]: Returns: Optional[float]: Metric score for hyperparameter optimization. """ - # Set seed for random number generators in pytorch, numpy and python.random if "seed" in config: seed_everything(config.seed, workers=True) diff --git a/satflow/models/__init__.py b/satflow/models/__init__.py index 63a609a0..cd748feb 100644 --- a/satflow/models/__init__.py +++ b/satflow/models/__init__.py @@ -1,3 +1,4 @@ +"""Different model architectures""" from nowcasting_utils.models.base import get_model, create_model from .conv_lstm import EncoderDecoderConvLSTM, ConvLSTM from .pl_metnet import LitMetNet diff --git a/satflow/models/attention_unet.py b/satflow/models/attention_unet.py index 47486b6e..c42d6dc9 100644 --- a/satflow/models/attention_unet.py +++ b/satflow/models/attention_unet.py @@ -1,3 +1,4 @@ +"""Attention gates integrated into U-Net model architectures""" from typing import Union from satflow.models.layers.RUnetLayers import * import pytorch_lightning as pl @@ -10,6 +11,7 @@ @register_model class AttentionUnet(pl.LightningModule): + """A model with attention gates integrated into a U-Net architecture: https://arxiv.org/abs/1804.03999""" def __init__( self, input_channels: int = 12, @@ -20,6 +22,18 @@ def __init__( conv_type: str = "standard", pretrained: bool = False, ): + """ + Initialize the model + + Args: + input_channels: default is 12 + forecast_steps: number of timesteps to forecast. default is 12. + loss: name of the loss function or torch.nn.Module. Default is "mse" + lr: learning rate. default is 0.001. + visualize: add a visualization step. default is False + conv_type: one of "standard", "coord", "antialiased", or "3d" + pretrained: not implemented. default is False. + """ super().__init__() self.lr = lr self.visualize = visualize @@ -32,14 +46,26 @@ def __init__( self.criterion = get_loss(loss) def forward(self, x): + """A forward step of the model""" return self.model.forward(x) def configure_optimizers(self): + """Get the optimizer with initialized parameters""" # DeepSpeedCPUAdam provides 5x to 7x speedup over torch.optim.adam(w) # optimizer = torch.optim.adam() return torch.optim.Adam(self.parameters(), lr=self.lr) def training_step(self, batch, batch_idx): + """ + Perform a training step of the model + + Args: + batch: tuple of (x, y) + batch_idx: used to visualize the results of the training step + + Returns: + The loss for the training step + """ x, y = batch x = x.float() y_hat = self(x) @@ -59,6 +85,16 @@ def training_step(self, batch, batch_idx): return loss def validation_step(self, batch, batch_idx): + """ + Perform a validation step of the model + + Args: + batch: tuple of (x, y) + batch_idx: not implemented + + Returns: + The loss for the validation step + """ x, y = batch x = x.float() y_hat = self(x) @@ -73,6 +109,16 @@ def validation_step(self, batch, batch_idx): return val_loss def test_step(self, batch, batch_idx): + """ + Perform a testing step of the model + + Args: + batch: tuple of (x, y) + batch_idx: not implemented + + Returns: + The loss for the testing step + """ x, y = batch x = x.float() y_hat = self(x) @@ -80,6 +126,16 @@ def test_step(self, batch, batch_idx): return loss def visualize_step(self, x, y, y_hat, batch_idx, step): + """ + Visualize the results of a step of the model + + Args: + x: input data + y: output + y_hat: prediction + batch_idx: (int) the global step to record for this batch + step: name of the step type. Default is "train" + """ # the logger you used (in this case tensorboard) tensorboard = self.logger.experiment[0] # Add all the different timesteps for a single prediction, 0.1% of the time @@ -99,6 +155,7 @@ def visualize_step(self, x, y, y_hat, batch_idx, step): @register_model class AttentionRUnet(pl.LightningModule): + """A model with attention gates integrated into a RU-Net architecture""" def __init__( self, input_channels: int = 12, @@ -109,6 +166,18 @@ def __init__( lr: float = 0.001, pretrained: bool = False, ): + """ + Initialize the model + + Args: + input_channels: default is 12 + forecast_steps: number of timesteps to forecast. default is 12. + recurrent_blocks: default is 2 + visualize: add a visualization step. default is False + loss: name of the loss function or torch.nn.Module. Default is "mse" + lr: learning rate. default is 0.001. + pretrained: not implemented. default is False. + """ super().__init__() self.lr = lr self.input_channels = input_channels @@ -121,14 +190,26 @@ def __init__( self.criterion = get_loss(loss) def forward(self, x): + """A forward step of the model""" return self.model.forward(x) def configure_optimizers(self): + """Get the optimizer with initialized parameters""" # DeepSpeedCPUAdam provides 5x to 7x speedup over torch.optim.adam(w) # optimizer = torch.optim.adam() return torch.optim.Adam(self.parameters(), lr=self.lr) def training_step(self, batch, batch_idx): + """ + Perform a training step of the model + + Args: + batch: tuple of (x, y) + batch_idx: used to visualize the results of the training step + + Returns: + The loss for the training step + """ x, y = batch x = x.float() y_hat = self(x) @@ -148,6 +229,16 @@ def training_step(self, batch, batch_idx): return loss def validation_step(self, batch, batch_idx): + """ + Perform a validation step of the model + + Args: + batch: tuple of (x, y) + batch_idx: not implemented + + Returns: + The loss for the validation step + """ x, y = batch x = x.float() y_hat = self(x) @@ -162,6 +253,16 @@ def validation_step(self, batch, batch_idx): return val_loss def test_step(self, batch, batch_idx): + """ + Perform a testing step of the model + + Args: + batch: tuple of (x, y) + batch_idx: not implemented + + Returns: + The loss for the testing step + """ x, y = batch x = x.float() y_hat = self(x) @@ -169,6 +270,16 @@ def test_step(self, batch, batch_idx): return loss def visualize_step(self, x, y, y_hat, batch_idx, step): + """ + Visualize the results of a step of the model + + Args: + x: input data + y: output + y_hat: prediction + batch_idx: (int) the global step to record for this batch + step: name of the step type. Default is "train" + """ # the logger you used (in this case tensorboard) tensorboard = self.logger.experiment[0] # Add all the different timesteps for a single prediction, 0.1% of the time @@ -187,7 +298,16 @@ def visualize_step(self, x, y, y_hat, batch_idx, step): class AttU_Net(nn.Module): + """An network of attention gates and convolutional blocks""" def __init__(self, input_channels=3, output_channels=1, conv_type: str = "standard"): + """ + Initialize the module + + Args: + input_channels: default is 3 + output_channels: default is 1 + conv_type: one of "standard", "coord", "antialiased", or "3d" + """ super(AttU_Net, self).__init__() self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2) @@ -217,6 +337,7 @@ def __init__(self, input_channels=3, output_channels=1, conv_type: str = "standa self.Conv_1x1 = nn.Conv2d(64, output_channels, kernel_size=1, stride=1, padding=0) def forward(self, x): + """Perform the encoding, decoding, and concatenation""" # encoding path x1 = self.Conv1(x) @@ -259,7 +380,17 @@ def forward(self, x): class R2AttU_Net(nn.Module): + """A recurrent residual network of attention gates and convolutional blocks""" def __init__(self, input_channels=3, output_channels=1, t=2, conv_type: str = "standard"): + """ + Initialize the module + + Args: + input_channels: default is 3 + output_channels: default is 1 + t: number of recurrent blocks. default is 2 + conv_type: one of "standard", "coord", "antialiased", or "3d" + """ super(R2AttU_Net, self).__init__() self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=2) @@ -294,6 +425,7 @@ def __init__(self, input_channels=3, output_channels=1, t=2, conv_type: str = "s self.Conv_1x1 = nn.Conv2d(64, output_channels, kernel_size=1, stride=1, padding=0) def forward(self, x): + """Perform the encoding, decoding, and concatenation""" # encoding path x1 = self.RRCNN1(x) diff --git a/satflow/models/cloudgan.py b/satflow/models/cloudgan.py index 46cab944..b157095c 100644 --- a/satflow/models/cloudgan.py +++ b/satflow/models/cloudgan.py @@ -1,3 +1,4 @@ +"""Creates CloudGAN, based off of https://www.climatechange.ai/papers/icml2021/54""" import pytorch_lightning as pl import torch from torch.optim import lr_scheduler @@ -12,6 +13,7 @@ class CloudGAN(pl.LightningModule): + """Creates CloudGAN, based off of https://www.climatechange.ai/papers/icml2021/54""" def __init__( self, forecast_steps: int = 48, @@ -36,6 +38,7 @@ def __init__( ): """ Creates CloudGAN, based off of https://www.climatechange.ai/papers/icml2021/54 + Changes include allowing outputs for all timesteps, optionally conditioning on time for single timestep output @@ -58,6 +61,7 @@ def __init__( l1_loss: Loss to use for the L1 in the slides, default is L1, also SSIM is available channels_per_timestep: Channels per input timestep condition_time: Whether to condition on a future timestep, similar to MetNet + pretrained: not implemented. default is False. """ super().__init__() self.lr = lr @@ -120,18 +124,20 @@ def train_per_timestep( self, images: torch.Tensor, future_images: torch.Tensor, optimizer_idx: int, batch_idx: int ): """ - For training with conditioning on time, so when the model is giving a single output + For training with conditioning on time + + When the model is giving a single output, this goes through every timestep + in forecast_steps and runs the training - This goes through every timestep in forecast_steps and runs the training Args: images: (Batch, Timestep, Channels, Width, Height) future_images: (Batch, Timestep, Channels, Width, Height) - optimizer_idx: int, the optiimizer to use + optimizer_idx: int, the optimizer to use + batch_idx: int Returns: - + A dictionary with loss information """ - if optimizer_idx == 0: # generate images total_loss = 0 @@ -191,14 +197,15 @@ def train_all_timestep( ): """ Train on all timesteps, instead of single timestep at a time. No conditioning on future timestep + Args: - images: - future_images: - optimizer_idx: - batch_idx: + images: (Batch, Timestep, Channels, Width, Height) + future_images: (Batch, Timestep, Channels, Width, Height) + optimizer_idx: int, the optimizer to use + batch_idx: int Returns: - + A dictionary with loss information """ if optimizer_idx == 0: # generate images @@ -240,6 +247,16 @@ def train_all_timestep( return output def training_step(self, batch, batch_idx, optimizer_idx): + """ + Perform a training step of the model + + Args: + batch: tuple of (x, y) + batch_idx: int + optimizer_idx: int, the optimizer to use + Returns: + A dictionary with loss information + """ images, future_images = batch if self.condition_time: return self.train_per_timestep(images, future_images, optimizer_idx, batch_idx) @@ -247,6 +264,17 @@ def training_step(self, batch, batch_idx, optimizer_idx): return self.train_all_timestep(images, future_images, optimizer_idx, batch_idx) def val_all_timestep(self, images, future_images, batch_idx): + """ + Validate on all timesteps, instead of single timestep at a time. No conditioning on future timestep + + Args: + images: (Batch, Timestep, Channels, Width, Height) + future_images: (Batch, Timestep, Channels, Width, Height) + batch_idx: int + + Returns: + A dictionary with loss information + """ # generate images generated_images = self(images) fake = torch.cat((images, generated_images), 1) @@ -279,6 +307,20 @@ def val_all_timestep(self, images, future_images, batch_idx): return output def val_per_timestep(self, images, future_images, batch_idx): + """ + For validation with conditioning on time + + When the model is giving a single output, this goes through every timestep + in forecast_steps and runs the validation + + Args: + images: (Batch, Timestep, Channels, Width, Height) + future_images: (Batch, Timestep, Channels, Width, Height) + batch_idx: int + + Returns: + A dictionary with loss information + """ total_g_loss = 0 total_d_loss = 0 vis_step = True if np.random.random() < 0.01 else False @@ -323,6 +365,16 @@ def val_per_timestep(self, images, future_images, batch_idx): return output def validation_step(self, batch, batch_idx): + """ + Perform a validation step of the model + + Args: + batch: tuple of (x, y) + batch_idx: int + optimizer_idx: int, the optimizer to use + Returns: + A dictionary with loss information + """ images, future_images = batch if self.condition_time: return self.val_per_timestep(images, future_images, batch_idx) @@ -330,9 +382,16 @@ def validation_step(self, batch, batch_idx): return self.val_all_timestep(images, future_images, batch_idx) def forward(self, x, **kwargs): + """A forward step of the generator""" return self.generator.forward(x, **kwargs) def configure_optimizers(self): + """ + Get the optimizers and the learning rate schedulers for the generator and discriminator + + Returns: + A tuple of [g_optimizer, d_optimizer], [g_scheduler, d_scheduler] + """ lr = self.lr b1 = self.b1 b2 = self.b2 @@ -364,6 +423,16 @@ def configure_optimizers(self): def visualize_step( self, x: torch.Tensor, y: torch.Tensor, y_hat: torch.Tensor, batch_idx: int, step: str ): + """ + Visualize the results of a step of the model + + Args: + x: input data + y: output + y_hat: prediction + batch_idx: (int) the global step to record for this batch + step: name of the step type. + """ # the logger you used (in this case tensorboard) tensorboard = self.logger.experiment[0] # Image input is either (B, C, H, W) or (B, T, C, H, W) diff --git a/satflow/models/conv_lstm.py b/satflow/models/conv_lstm.py index 1c30629c..72fb273f 100644 --- a/satflow/models/conv_lstm.py +++ b/satflow/models/conv_lstm.py @@ -1,3 +1,4 @@ +"""An Autoencoder model with Convolutional LSTM cells""" from typing import Any, Dict, Union import pytorch_lightning as pl @@ -14,6 +15,7 @@ @register_model class EncoderDecoderConvLSTM(pl.LightningModule): + """An Autoencoder model with Convolutional LSTM cells""" def __init__( self, hidden_dim: int = 64, @@ -26,6 +28,20 @@ def __init__( pretrained: bool = False, conv_type: str = "standard", ): + """ + Initialize the model + + Args: + hidden_dim: number of channels of hidden state. default is 64 + input_channels: number of channels of input tensor. default is 12 + out_channels: number of channels in output. default is 1 + forecast_steps: number of timesteps to forecast. default is 48. + lr: learning rate. default is 0.001 + visualize: not implemented. default is False. + loss: name of the loss function or torch.nn.Module. Default is "mse" + pretrained: not implemented. default is False + conv_type: one of "standard", "coord", "antialiased", or "3d" + """ super(EncoderDecoderConvLSTM, self).__init__() self.forecast_steps = forecast_steps self.criterion = get_loss(loss) @@ -36,6 +52,7 @@ def __init__( @classmethod def from_config(cls, config): + """Initialize EncoderDecoderConvLSTM model from configuration values""" return EncoderDecoderConvLSTM( hidden_dim=config.get("num_hidden", 64), input_channels=config.get("in_channels", 12), @@ -45,14 +62,33 @@ def from_config(cls, config): ) def forward(self, x, future_seq=0, hidden_state=None): + """ + A forward step of the model + + Args: + x: 5-D Tensor of shape (batch, time, channel, height, width) + future_seq: number of timesteps to forecast. default is 0. + hidden_state: not implemented. default is None + """ return self.model.forward(x, future_seq, hidden_state) def configure_optimizers(self): + """Get the optimizer with the initialized parameters""" # DeepSpeedCPUAdam provides 5x to 7x speedup over torch.optim.adam(w) # optimizer = torch.optim.adam() return torch.optim.Adam(self.parameters(), lr=self.lr) def training_step(self, batch, batch_idx): + """ + Perform a training step of the model + + Args: + batch: tuple of (x, y) + batch_idx: not implemented + + Returns: + The loss for the training step + """ x, y = batch y_hat = self(x, self.forecast_steps) y_hat = torch.permute(y_hat, dims=(0, 2, 1, 3, 4)) @@ -72,6 +108,16 @@ def training_step(self, batch, batch_idx): return loss def validation_step(self, batch, batch_idx): + """ + Perform a validation step of the model + + Args: + batch: tuple of (x, y) + batch_idx: not implemented + + Returns: + The loss for the validation step + """ x, y = batch y_hat = self(x, self.forecast_steps) y_hat = torch.permute(y_hat, dims=(0, 2, 1, 3, 4)) @@ -87,12 +133,32 @@ def validation_step(self, batch, batch_idx): return val_loss def test_step(self, batch, batch_idx): + """ + Perform a testing step of the model + + Args: + batch: tuple of (x, y) + batch_idx: not implemented + + Returns: + The loss for the testing step + """ x, y = batch y_hat = self(x, self.forecast_steps) loss = self.criterion(y_hat, y) return loss def visualize_step(self, x, y, y_hat, batch_idx, step="train"): + """ + Visualize the results of a step of the model + + Args: + x: input data + y: output + y_hat: prediction + batch_idx: the global step to record for this batch + step: name of the step type. Default is "train" + """ tensorboard = self.logger.experiment[0] # Add all the different timesteps for a single prediction, 0.1% of the time if len(x.shape) == 5: @@ -123,13 +189,20 @@ def visualize_step(self, x, y, y_hat, batch_idx, step="train"): class ConvLSTM(torch.nn.Module): def __init__(self, input_channels, hidden_dim, out_channels, conv_type: str = "standard"): super().__init__() - """ ARCHITECTURE - - # Encoder (ConvLSTM) - # Encoder Vector (final hidden state of encoder) - # Decoder (ConvLSTM) - takes Encoder Vector as input - # Decoder (3D CNN) - produces regression predictions for our model - + """ + Initialize the ConvSTM module + + ARCHITECTURE + - Encoder (ConvLSTM) + - Encoder Vector (final hidden state of encoder) + - Decoder (ConvLSTM) - takes Encoder Vector as input + - Decoder (3D CNN) - produces regression predictions for our model + + Args: + input_channels: number of channels of input tensor + hidden_dim: number of channels of hidden state + out_channels: number of channels in output + conv_type: one of "standard", "coord", "antialiased", or "3d" """ self.encoder_1_convlstm = ConvLSTMCell( input_dim=input_channels, @@ -171,7 +244,22 @@ def __init__(self, input_channels, hidden_dim, out_channels, conv_type: str = "s ) def autoencoder(self, x, seq_len, future_step, h_t, c_t, h_t2, c_t2, h_t3, c_t3, h_t4, c_t4): - + """ + Compute the forward pass of an autoencoder on the hidden and cell states of lstm cells + + Args: + x: 5-D Tensor of shape (batch, time, channel, height, width) + seq_len: size of time dimension in x + future_step: number of timesteps to forecast + h_t: hidden state for first encoder + c_t: cell state for first encoder + h_t2: hidden state for second encoder + c_t2: cell state for second encoder + h_t3: hidden state for first decoder + c_t3: cell state for first decoder + h_t4: hidden state for second decoder + c_t4: cell state for second decoder + """ outputs = [] # encoder @@ -205,14 +293,14 @@ def autoencoder(self, x, seq_len, future_step, h_t, c_t, h_t2, c_t2, h_t3, c_t3, return outputs def forward(self, x, forecast_steps=0, hidden_state=None): - - """ - Parameters - ---------- - input_tensor: - 5-D Tensor of shape (b, t, c, h, w) # batch, time, channel, height, width """ + Compute the forward pass + Args: + x: 5-D Tensor of shape (batch, time, channel, height, width) + forecast_steps: number of timesteps to forecast. default is 0. + hidden_state: not implemented + """ # find size of different input dimensions b, seq_len, _, h, w = x.size() diff --git a/satflow/models/deeplabv3.py b/satflow/models/deeplabv3.py index adefe2be..c5edd521 100644 --- a/satflow/models/deeplabv3.py +++ b/satflow/models/deeplabv3.py @@ -1,3 +1,4 @@ +"""A semantic segmentation architecture""" import torch import torch.nn.functional as F import pytorch_lightning as pl @@ -10,6 +11,7 @@ @register_model class DeeplabV3(pl.LightningModule): + """A semantic segmentation architecture""" def __init__( self, forecast_steps: int = 48, @@ -21,6 +23,19 @@ def __init__( pretrained: bool = False, aux_loss: bool = False, ): + """ + Initialize the model + + Args: + forecast_steps: number of timesteps to forecast. default is 48. + input_channels: default is 12 + lr: learning rate. default is 0.001 + make_vis: whether to add a visualization step. default is False. + loss: name of the loss function or torch.nn.Module. Default is "mse" + backbone: the name of the backbone model. default is "resnet50". + pretrained: Whether to use a model pre-trained on other data, Default is False + aux_loss: Whether to use an auxilary loss. Default is False + """ super(DeeplabV3, self).__init__() self.lr = lr assert loss in ["mse", "bce", "binary_crossentropy", "crossentropy", "focal"] @@ -50,6 +65,7 @@ def __init__( @classmethod def from_config(cls, config): + """Initialize model from configuration values""" return DeeplabV3( forecast_steps=config.get("forecast_steps", 12), input_channels=config.get("in_channels", 12), @@ -60,14 +76,26 @@ def from_config(cls, config): ) def forward(self, x): + """A forward step of the model""" return self.model.forward(x) def configure_optimizers(self): + """Get the optimizer with the initialized parameters""" # DeepSpeedCPUAdam provides 5x to 7x speedup over torch.optim.adam(w) # optimizer = torch.optim.adam() return torch.optim.Adam(self.parameters(), lr=self.lr) def training_step(self, batch, batch_idx): + """ + Perform a training step of the model + + Args: + batch: tuple of (x, y) + batch_idx: used to visualize the results of the training step + + Returns: + The loss for the training step + """ x, y = batch y_hat = self(x) @@ -81,6 +109,16 @@ def training_step(self, batch, batch_idx): return loss def validation_step(self, batch, batch_idx): + """ + Perform a validation step of the model + + Args: + batch: tuple of (x, y) + batch_idx: not implemented + + Returns: + The loss for the validation step + """ x, y = batch y_hat = self(x) val_loss = self.criterion(y_hat, y) @@ -88,12 +126,31 @@ def validation_step(self, batch, batch_idx): return val_loss def test_step(self, batch, batch_idx): + """ + Perform a testing step of the model + + Args: + batch: tuple of (x, y) + batch_idx: not implemented + + Returns: + The loss for the testing step + """ x, y = batch y_hat = self(x, self.forecast_steps) loss = self.criterion(y_hat, y) return loss def visualize(self, x, y, y_hat, batch_idx): + """ + Visualize the results of a step of the model + + Args: + x: input data + y: output + y_hat: prediction + batch_idx: (int) the global step to record for this batch + """ # the logger you used (in this case tensorboard) tensorboard = self.logger.experiment # Add all the different timesteps for a single prediction, 0.1% of the time diff --git a/satflow/models/fcn.py b/satflow/models/fcn.py index d0ea4030..fa072e71 100644 --- a/satflow/models/fcn.py +++ b/satflow/models/fcn.py @@ -1,3 +1,4 @@ +"""A fully convolutional network""" import torch import torch.nn.functional as F import pytorch_lightning as pl @@ -10,6 +11,7 @@ @register_model class FCN(pl.LightningModule): + """A fully convolutional network""" def __init__( self, forecast_steps: int = 48, @@ -20,6 +22,18 @@ def __init__( backbone: str = "resnet50", pretrained: bool = False, ): + """ + Initialize the model + + Args: + forecast_steps: number of timesteps to forecast. default is 48. + input_channels: default is 12 + lr: learning rate. default is 0.001 + make_vis: whether to add a visualization step. default is False. + loss: name of the loss function or torch.nn.Module. Default is "mse" + backbone: the name of the backbone model. default is "resnet50". + pretrained: Whether to use a model pre-trained on other data, Default is False + """ super(FCN, self).__init__() self.lr = lr assert loss in ["mse", "bce", "binary_crossentropy", "crossentropy", "focal"] @@ -45,6 +59,7 @@ def __init__( @classmethod def from_config(cls, config): + """Initialize model from configuration values""" return DeeplabV3( forecast_steps=config.get("forecast_steps", 12), input_channels=config.get("in_channels", 12), @@ -55,14 +70,26 @@ def from_config(cls, config): ) def forward(self, x): + """A forward step of the model""" return self.model.forward(x) def configure_optimizers(self): + """Get the optimizer with the initialized parameters""" # DeepSpeedCPUAdam provides 5x to 7x speedup over torch.optim.adam(w) # optimizer = torch.optim.adam() return torch.optim.Adam(self.parameters(), lr=self.lr) def training_step(self, batch, batch_idx): + """ + Perform a training step of the model + + Args: + batch: tuple of (x, y) + batch_idx: used to visualize the results of the training step + + Returns: + The loss for the training step + """ x, y = batch y_hat = self(x) @@ -76,6 +103,16 @@ def training_step(self, batch, batch_idx): return loss def validation_step(self, batch, batch_idx): + """ + Perform a validation step of the model + + Args: + batch: tuple of (x, y) + batch_idx: not implemented + + Returns: + The loss for the validation step + """ x, y = batch y_hat = self(x) val_loss = self.criterion(y_hat, y) @@ -83,12 +120,31 @@ def validation_step(self, batch, batch_idx): return val_loss def test_step(self, batch, batch_idx): + """ + Perform a testing step of the model + + Args: + batch: tuple of (x, y) + batch_idx: not implemented + + Returns: + The loss for the testing step + """ x, y = batch y_hat = self(x, self.forecast_steps) loss = self.criterion(y_hat, y) return loss def visualize(self, x, y, y_hat, batch_idx): + """ + Visualize the results of a step of the model + + Args: + x: input data + y: output + y_hat: prediction + batch_idx: (int) the global step to record for this batch + """ # the logger you used (in this case tensorboard) tensorboard = self.logger.experiment # Add all the different timesteps for a single prediction, 0.1% of the time diff --git a/satflow/models/gan/__init__.py b/satflow/models/gan/__init__.py index 4b8307af..d47adeec 100644 --- a/satflow/models/gan/__init__.py +++ b/satflow/models/gan/__init__.py @@ -1,2 +1,3 @@ +"""discriminators and generators for GANs""" from .discriminators import GANLoss, PixelDiscriminator, NLayerDiscriminator, define_discriminator from .generators import define_generator diff --git a/satflow/models/gan/common.py b/satflow/models/gan/common.py index 394392e0..b7c07967 100644 --- a/satflow/models/gan/common.py +++ b/satflow/models/gan/common.py @@ -1,3 +1,4 @@ +"""Common functions used for GANs""" import functools import torch from torch.nn import init @@ -6,11 +7,11 @@ def get_norm_layer(norm_type="instance"): """Return a normalization layer - Parameters: - norm_type (str) -- the name of the normalization layer: batch | instance | none - For BatchNorm, we use learnable affine parameters and track running statistics (mean/stddev). For InstanceNorm, we do not use learnable affine parameters. We do not track running statistics. + + Args: + norm_type (str): the name of the normalization layer: batch | instance | none """ if norm_type == "batch": norm_layer = functools.partial(torch.nn.BatchNorm2d, affine=True, track_running_stats=True) @@ -31,15 +32,14 @@ def norm_layer(x): def init_weights(net, init_type="normal", init_gain=0.02): """Initialize network weights. - Parameters: - net (network) -- network to be initialized - init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal - init_gain (float) -- scaling factor for normal, xavier and orthogonal. - We use 'normal' in the original pix2pix and CycleGAN paper. But xavier and kaiming might work better for some applications. Feel free to try yourself. - """ + Args: + net (network): network to be initialized + init_type (str): the name of an initialization method: normal | xavier | kaiming | orthogonal + init_gain (float) : scaling factor for normal, xavier and orthogonal. + """ def init_func(m): # define the initialization function classname = m.__class__.__name__ if hasattr(m, "weight") and ( @@ -70,14 +70,15 @@ def init_func(m): # define the initialization function def init_net(net, init_type="normal", init_gain=0.02): - """Initialize a network: 1. register CPU/GPU device (with multi-GPU support); 2. initialize the network weights - Parameters: - net (network) -- the network to be initialized - init_type (str) -- the name of an initialization method: normal | xavier | kaiming | orthogonal - gain (float) -- scaling factor for normal, xavier and orthogonal. - gpu_ids (int list) -- which GPUs the network runs on: e.g., 0,1,2 - - Return an initialized network. + """Initialize a network + + Args: + net (network): the network to be initialized + init_type (str): the name of an initialization method: normal | xavier | kaiming | orthogonal + init_gain (float): scaling factor for normal, xavier and orthogonal. + + Returns: + the initialized network """ init_weights(net, init_type, init_gain=init_gain) return net @@ -88,16 +89,17 @@ def cal_gradient_penalty( ): """Calculate the gradient penalty loss, used in WGAN-GP paper https://arxiv.org/abs/1704.00028 - Arguments: - netD (network) -- discriminator network - real_data (tensor array) -- real images - fake_data (tensor array) -- generated images from the generator - device (str) -- GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') - type (str) -- if we mix real and fake data or not [real | fake | mixed]. - constant (float) -- the constant used in formula ( ||gradient||_2 - constant)^2 - lambda_gp (float) -- weight for this loss - - Returns the gradient penalty loss + Args: + netD (network): discriminator network + real_data (tensor array): real images + fake_data (tensor array): generated images from the generator + device (str): GPU / CPU: from torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') + type (str): if we mix real and fake data or not [real | fake | mixed]. + constant (float): the constant used in formula ( ||gradient||_2 - constant)^2 + lambda_gp (float): weight for this loss + + Returns: + a tuple (the gradient penalty loss, the gradients) """ if lambda_gp > 0.0: if type == "real": # either use real images, fake images, or a linear interpolation of two. diff --git a/satflow/models/gan/discriminators.py b/satflow/models/gan/discriminators.py index 07b5e16a..005ce088 100644 --- a/satflow/models/gan/discriminators.py +++ b/satflow/models/gan/discriminators.py @@ -1,3 +1,4 @@ +"""Implement various discriminators for GANs""" import functools import torch from torch import nn as nn @@ -16,18 +17,8 @@ def define_discriminator( init_gain=0.02, conv_type: str = "standard", ): - """Create a discriminator - - Parameters: - input_nc (int) -- the number of channels in input images - ndf (int) -- the number of filters in the first conv layer - netD (str) -- the architecture's name: basic | n_layers | pixel - n_layers_D (int) -- the number of conv layers in the discriminator; effective when netD=='n_layers' - norm (str) -- the type of normalization layers used in the network. - init_type (str) -- the name of the initialization method. - init_gain (float) -- scaling factor for normal, xavier and orthogonal. - - Returns a discriminator + """ + Create a discriminator Our current implementation provides three types of discriminators: [basic]: 'PatchGAN' classifier described in the original pix2pix paper. @@ -43,6 +34,19 @@ def define_discriminator( It encourages greater color diversity but has no effect on spatial statistics. The discriminator has been initialized by . It uses Leakly RELU for non-linearity. + + Args: + input_nc (int): the number of channels in input images + ndf (int): the number of filters in the first conv layer + netD (str): the architecture's name: basic | n_layers | pixel + n_layers_D (int): the number of conv layers in the discriminator; effective when netD=='n_layers' + norm (str): the type of normalization layers used in the network. + init_type (str): the name of the initialization method. + init_gain (float): scaling factor for normal, xavier and orthogonal. + conv_type (str): one of "standard", "coord", "antialiased", or "3d" + + Returns: + a discriminator """ net = None norm_layer = get_norm_layer(norm_type=norm) @@ -75,13 +79,13 @@ class GANLoss(nn.Module): def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0): """Initialize the GANLoss class. - Parameters: - gan_mode (str) - - the type of GAN objective. It currently supports vanilla, lsgan, and wgangp. - target_real_label (bool) - - label for a real image - target_fake_label (bool) - - label of a fake image - Note: Do not use sigmoid as the last layer of Discriminator. LSGAN needs no sigmoid. vanilla GANs will handle it with BCEWithLogitsLoss. + + Args: + gan_mode (str): the type of GAN objective. It currently supports vanilla, lsgan, and wgangp. + target_real_label (bool): label for a real image + target_fake_label (bool): label of a fake image """ super(GANLoss, self).__init__() self.register_buffer("real_label", torch.tensor(target_real_label)) @@ -100,8 +104,8 @@ def get_target_tensor(self, prediction, target_is_real): """Create label tensors with the same size as the input. Parameters: - prediction (tensor) - - tpyically the prediction from a discriminator - target_is_real (bool) - - if the ground truth label is for real images or fake images + prediction (tensor): typically the prediction from a discriminator + target_is_real (bool): if the ground truth label is for real images or fake images Returns: A label tensor filled with ground truth label, and with the size of the input @@ -117,8 +121,8 @@ def __call__(self, prediction, target_is_real): """Calculate loss given Discriminator's output and grount truth labels. Parameters: - prediction (tensor) - - tpyically the prediction output from a discriminator - target_is_real (bool) - - if the ground truth label is for real images or fake images + prediction (tensor): typically the prediction output from a discriminator + target_is_real (bool): if the ground truth label is for real images or fake images Returns: the calculated loss. @@ -143,10 +147,11 @@ def __init__( """Construct a PatchGAN discriminator Parameters: - input_nc (int) -- the number of channels in input images - ndf (int) -- the number of filters in the last conv layer - n_layers (int) -- the number of conv layers in the discriminator - norm_layer -- normalization layer + input_nc (int): the number of channels in input images + ndf (int): the number of filters in the last conv layer + n_layers (int): the number of conv layers in the discriminator + norm_layer: normalization layer + conv_type (str): one of "standard", "coord", "antialiased", or "3d" """ super(NLayerDiscriminator, self).__init__() if ( @@ -230,9 +235,10 @@ def __init__(self, input_nc, ndf=64, norm_layer=nn.BatchNorm2d, conv_type: str = """Construct a 1x1 PatchGAN discriminator Parameters: - input_nc (int) -- the number of channels in input images - ndf (int) -- the number of filters in the last conv layer - norm_layer -- normalization layer + input_nc (int): the number of channels in input images + ndf (int): the number of filters in the last conv layer + norm_layer: normalization layer + conv_type (str): one of "standard", "coord", "antialiased", or "3d" """ super(PixelDiscriminator, self).__init__() if ( @@ -261,7 +267,14 @@ def forward(self, input): class CloudGANBlock(nn.Module): + """Implement a block for the CloudGANDiscriminator""" def __init__(self, input_channels, conv_type: str = "standard"): + """Initialize the block + + Args: + input_channels: the number of channels in the input + conv_type (str): one of "standard", "coord", "antialiased", or "3d" + """ super().__init__() conv2d = get_conv_layer(conv_type) self.conv = conv2d(input_channels, input_channels * 2, kernel_size=(3, 3)) @@ -274,6 +287,10 @@ def __init__(self, input_channels, conv_type: str = "standard"): self.blurpool = torch.nn.Identity() def forward(self, x): + """Compute the foward pass + + A convolutional layer, RelU, max pooling, and blur pool (if conv_type == "antialiased") + """ x = self.conv(x) x = self.relu(x) x = self.pool(x) @@ -291,6 +308,15 @@ def __init__( num_stages: int = 3, conv_type: str = "standard", ): + """ + Initialize the module + + Args: + input_channels: default is 12 + num_filters: default is 64 + num_stages: the number of blocks to use. default is 3 + conv_type (str): one of "standard", "coord", "antialiased", or "3d" + """ super().__init__() conv2d = get_conv_layer(conv_type) self.conv_1 = conv2d(input_channels, num_filters, kernel_size=1, stride=1, padding=0) @@ -303,6 +329,7 @@ def __init__( self.fc = torch.nn.LazyLinear(1) # Real/Fake def forward(self, x): + """Compute the forward step""" x = self.conv_1(x) x = self.stages(x) x = self.flatten(x) diff --git a/satflow/models/gan/generators.py b/satflow/models/gan/generators.py index 794e13a4..3bb166dc 100644 --- a/satflow/models/gan/generators.py +++ b/satflow/models/gan/generators.py @@ -1,3 +1,4 @@ +"""Implement various generators for GANs""" import functools import torch from torch import nn as nn @@ -19,18 +20,6 @@ def define_generator( ): """Create a generator - Parameters: - input_nc (int) -- the number of channels in input images - output_nc (int) -- the number of channels in output images - ngf (int) -- the number of filters in the last conv layer - netG (str) -- the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128 - norm (str) -- the name of normalization layers used in the network: batch | instance | none - use_dropout (bool) -- if use dropout layers. - init_type (str) -- the name of our initialization method. - init_gain (float) -- scaling factor for normal, xavier and orthogonal. - - Returns a generator - Our current implementation provides two types of generators: U-Net: [unet_128] (for 128x128 input images) and [unet_256] (for 256x256 input images) The original U-Net paper: https://arxiv.org/abs/1505.04597 @@ -41,6 +30,19 @@ def define_generator( The generator has been initialized by . It uses RELU for non-linearity. + + Args: + input_nc (int): the number of channels in input images + output_nc (int): the number of channels in output images + ngf (int): the number of filters in the last conv layer + netG (str): the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128 + norm (str): the name of normalization layers used in the network: batch | instance | none + use_dropout (bool): if use dropout layers. + init_type (str): the name of our initialization method. + init_gain (float): scaling factor for normal, xavier and orthogonal. + + Returns: + a generator """ net = None norm_layer = get_norm_layer(norm_type=norm) @@ -70,7 +72,8 @@ def define_generator( class ResnetGenerator(nn.Module): """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations. - We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style) + We adapt Torch code and idea from Justin Johnson's neural style transfer project + (https://github.com/jcjohnson/fast-neural-style) """ def __init__( @@ -86,14 +89,15 @@ def __init__( ): """Construct a Resnet-based generator - Parameters: - input_nc (int) -- the number of channels in input images - output_nc (int) -- the number of channels in output images - ngf (int) -- the number of filters in the last conv layer - norm_layer -- normalization layer - use_dropout (bool) -- if use dropout layers - n_blocks (int) -- the number of ResNet blocks - padding_type (str) -- the name of padding layer in conv layers: reflect | replicate | zero + Args: + input_nc (int): the number of channels in input images + output_nc (int): the number of channels in output images + ngf (int): the number of filters in the last conv layer + norm_layer: normalization layer + use_dropout (bool): if use dropout layers + n_blocks (int): the number of ResNet blocks + padding_type (str): the name of padding layer in conv layers: reflect | replicate | zero + conv_type (str): conv_type: one of "standard", "coord", "antialiased", or "3d" """ assert n_blocks >= 0 super(ResnetGenerator, self).__init__() @@ -189,10 +193,17 @@ def __init__( ): """Initialize the Resnet block - A resnet block is a conv block with skip connections - We construct a conv block with build_conv_block function, - and implement skip connections in function. + A resnet block is a conv block with skip connections. We construct a conv block with + build_conv_block function, and implement skip connections in function. Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf + + Args: + dim: the number of channels in the conv layer + padding_type: the name of padding layer: reflect | replicate | zero + norm_layer: normalization layer + use_dropout: if use dropout layers. + use_bias: if the conv layer uses bias or not + conv_type (str): conv_type: one of "standard", "coord", "antialiased", or "3d" """ super(ResnetBlock, self).__init__() conv2d = get_conv_layer(conv_type) @@ -205,14 +216,16 @@ def build_conv_block( ): """Construct a convolutional block. - Parameters: - dim (int) -- the number of channels in the conv layer. - padding_type (str) -- the name of padding layer: reflect | replicate | zero - norm_layer -- normalization layer - use_dropout (bool) -- if use dropout layers. - use_bias (bool) -- if the conv layer uses bias or not + Args: + dim (int): the number of channels in the conv layer. + padding_type (str): the name of padding layer: reflect | replicate | zero + norm_layer: normalization layer + use_dropout (bool): if use dropout layers. + use_bias (bool): if the conv layer uses bias or not + conv2d: the convolutional layer to use - Returns a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU)) + Returns: + a conv block (with a conv layer, a normalization layer, and a non-linearity layer (ReLU)) """ conv_block = [] p = 0 @@ -269,16 +282,19 @@ def __init__( conv_type: str = "standard", ): """Construct a Unet generator - Parameters: - input_nc (int) -- the number of channels in input images - output_nc (int) -- the number of channels in output images - num_downs (int) -- the number of downsamplings in UNet. For example, # if |num_downs| == 7, - image of size 128x128 will become of size 1x1 # at the bottleneck - ngf (int) -- the number of filters in the last conv layer - norm_layer -- normalization layer We construct the U-Net from the innermost layer to the outermost layer. It is a recursive process. + + Args: + input_nc (int): the number of channels in input images + output_nc (int): the number of channels in output images + num_downs (int): the number of downsamplings in UNet. For example, # if |num_downs| == 7, + image of size 128x128 will become of size 1x1 # at the bottleneck + ngf (int): the number of filters in the last conv layer + norm_layer: normalization layer + use_dropout (bool): if use dropout layers + conv_type (str): conv_type: one of "standard", "coord", "antialiased", or "3d" """ super(UnetGenerator, self).__init__() # construct unet structure @@ -361,15 +377,16 @@ def __init__( ): """Construct a Unet submodule with skip connections. - Parameters: - outer_nc (int) -- the number of filters in the outer conv layer - inner_nc (int) -- the number of filters in the inner conv layer - input_nc (int) -- the number of channels in input images/features - submodule (UnetSkipConnectionBlock) -- previously defined submodules - outermost (bool) -- if this module is the outermost module - innermost (bool) -- if this module is the innermost module - norm_layer -- normalization layer - use_dropout (bool) -- if use dropout layers. + Args: + outer_nc (int): the number of filters in the outer conv layer + inner_nc (int): the number of filters in the inner conv layer + input_nc (int): the number of channels in input images/features + submodule (UnetSkipConnectionBlock): previously defined submodules + outermost (bool): if this module is the outermost module + innermost (bool): if this module is the innermost module + norm_layer: normalization layer + use_dropout (bool): if use dropout layers. + conv_type (str): conv_type: one of "standard", "coord", "antialiased", or "3d" """ super(UnetSkipConnectionBlock, self).__init__() self.outermost = outermost @@ -423,6 +440,7 @@ def __init__( self.model = nn.Sequential(*model) def forward(self, x): + """Compute the forward pass""" if self.outermost: return self.model(x) else: # add skip connections diff --git a/satflow/models/layers/Attention.py b/satflow/models/layers/Attention.py index 49b10aa8..6d089f4a 100644 --- a/satflow/models/layers/Attention.py +++ b/satflow/models/layers/Attention.py @@ -1,3 +1,4 @@ +"""Attention Layers""" import torch import torch.nn as nn from torch.nn import functional as F @@ -5,9 +6,20 @@ class SeparableAttn(nn.Module): + """A sequence of separable attention cells""" def __init__( self, in_dim, activation=F.relu, pooling_factor=2, padding_mode="constant", padding_value=0 ): + """ + Initialize the module + + Args: + in_dim: Number of input channels + activation: activation function to use. Default is torch.nn.functional.relu + pooling_factor: The stride of the window to use on the first dimension + padding_mode: not implemented. Default is "constant" + padding_value: not implemented. Default is 0 + """ super().__init__() self.model = nn.Sequential( SeparableAttnCell(in_dim, "T", activation, pooling_factor, padding_mode, padding_value), @@ -16,11 +28,12 @@ def __init__( ) def forward(self, x): - + """Compute the forward pass of the layer""" return self.model(x) class SeparableAttnCell(nn.Module): + """A separable attention cell""" def __init__( self, in_dim, @@ -30,6 +43,17 @@ def __init__( padding_mode="constant", padding_value=0, ): + """ + Initialize the module + + Args: + in_dim: Number of input channels + attn_id: The dimension to use. "T": timestep, "W": image width, "H": image height + activation: activation function to use. Default is torch.nn.functional.relu + pooling_factor: The stride of the window to use on the first dimension + padding_mode: not implemented. Default is "constant" + padding_value: not implemented. Default is 0 + """ super().__init__() self.attn_id = attn_id self.activation = activation @@ -51,11 +75,19 @@ def __init__( self.softmax = nn.Softmax(dim=-1) def init_conv(self, conv, glu=True): + """Initialize network weights""" init.xavier_uniform_(conv.weight) if conv.bias is not None: conv.bias.data.zero_() def forward(self, x): + """ + Compute the output of the separable attention cell + + Args: + x: a tensor with dimensions (batch_size, num_channels, timesteps, width, height). + timesteps, width, and height dimensions must all have even sizexw + """ batch_size, C, T, W, H = x.size() @@ -110,8 +142,16 @@ def forward(self, x): class SelfAttention(nn.Module): + """A self attention layer""" def __init__(self, in_dim, activation=F.relu, pooling_factor=2): # TODO for better compability + """ + Initialize the module + Args: + in_dim: Number of input channels + activation: activation function to use. Default is torch.nn.functional.relu + pooling_factor: the stride for all dimensions of the pooling layer + """ super(SelfAttention, self).__init__() self.activation = activation @@ -129,11 +169,19 @@ def __init__(self, in_dim, activation=F.relu, pooling_factor=2): # TODO for bet self.softmax = nn.Softmax(dim=-1) def init_conv(self, conv, glu=True): + """Initialize network weights""" init.xavier_uniform_(conv.weight) if conv.bias is not None: conv.bias.data.zero_() def forward(self, x): + """ + Compute the output of the layer + + Args: + x: a tensor with shape (batch_size, num_channels, width, height) or + (batch_size, num_channels, timesteps, width, height) + """ if len(x.size()) == 4: batch_size, C, W, H = x.size() @@ -171,8 +219,9 @@ def forward(self, x): class SelfAttention2d(nn.Module): - r"""Self Attention Module as proposed in the paper `"Self-Attention Generative Adversarial - Networks by Han Zhang et. al." `_ + """ + Self Attention Module as proposed in the paper `"Self-Attention Generative Adversarial Networks by Han Zhang et. al." ` + .. math:: attention = softmax((query(x))^T * key(x)) .. math:: output = \gamma * value(x) * attention + x where @@ -180,16 +229,20 @@ class SelfAttention2d(nn.Module): - :math:`key` : 2D Convolution Operation - :math:`value` : 2D Convolution Operation - :math:`x` : Input - Args: - input_dims (int): The input channel dimension in the input ``x``. - output_dims (int, optional): The output channel dimension. If ``None`` the output - channel value is computed as ``input_dims // 8``. So if the ``input_dims`` is **less - than 8** then the layer will give an error. - return_attn (bool, optional): Set it to ``True`` if you want the attention values to be - returned. """ def __init__(self, input_dims, output_dims=None, return_attn=False): + """ + Initialize the module + + Args: + input_dims (int): The input channel dimension in the input ``x``. + output_dims (int, optional): The output channel dimension. If ``None`` the output + channel value is computed as ``input_dims // 8``. So if the ``input_dims`` is **less + than 8** then the layer will give an error. + return_attn (bool, optional): Set it to ``True`` if you want the attention values to be + returned. Default is False + """ output_dims = input_dims // 8 if output_dims is None else output_dims if output_dims == 0: raise Exception( @@ -204,9 +257,11 @@ def __init__(self, input_dims, output_dims=None, return_attn=False): self.return_attn = return_attn def forward(self, x): - r"""Computes the output of the Self Attention Layer + """Computes the output of the Self Attention Layer + Args: x (torch.Tensor): A 4D Tensor with the channel dimension same as ``input_dims``. + Returns: A tuple of the ``output`` and the ``attention`` if ``return_attn`` is set to ``True`` else just the ``output`` tensor. diff --git a/satflow/models/layers/ConditionTime.py b/satflow/models/layers/ConditionTime.py index 97695ae6..5b48cfb0 100644 --- a/satflow/models/layers/ConditionTime.py +++ b/satflow/models/layers/ConditionTime.py @@ -1,9 +1,21 @@ +"""A layer to add time information to the input""" import torch from torch import nn as nn def condition_time(x, i=0, size=(12, 16), seq_len=15): - "create one hot encoded time image-layers, i in [1, seq_len]" + """ + create one hot encoded time image-layers, i in [1, seq_len] + + Args: + x: input tensor + i: index of the future observation to condition on + size: tuple of (height, width) of the image + seq_len: number of timesteps in the future that is output + + Returns: + A one-hot tensor of shape (seq_len, height, width), which is activated at index (i, *, *) + """ assert i < seq_len times = (torch.eye(seq_len, dtype=x.dtype, device=x.device)[i]).unsqueeze(-1).unsqueeze(-1) ones = torch.ones(1, *size, dtype=x.dtype, device=x.device) @@ -11,16 +23,33 @@ def condition_time(x, i=0, size=(12, 16), seq_len=15): class ConditionTime(nn.Module): - "Condition Time on a stack of images, adds `horizon` channels to image" - + """Condition Time on a stack of images, adds `horizon` channels to image""" def __init__(self, horizon, ch_dim=2, num_dims=5): + """ + Initialize module + + Args: + horizon: number of timesteps in the future to output + ch_dim: the dimension in the input tensor that represents the channels. Default is 2. + num_dims: number of dimensions in input tensor (4 or 5). Default is 5. + """ super().__init__() self.horizon = horizon self.ch_dim = ch_dim self.num_dims = num_dims def forward(self, x, fstep=0): - "x stack of images, fsteps" + """ + x stack of images, fsteps + + Args: + x: input tensor. Either (batch_size, timestep, channels, height, width) + or (batch size, height, width, channels) + fstep: the index of the future timestep to condition on + + Returns: + concatenation of x and one hot tensor + """ if self.num_dims == 5: bs, seq_len, ch, h, w = x.shape ct = condition_time(x, fstep, (h, w), seq_len=self.horizon).repeat(bs, seq_len, 1, 1, 1) diff --git a/satflow/models/layers/ConvLSTM.py b/satflow/models/layers/ConvLSTM.py index 39e204bf..0eaf2a66 100644 --- a/satflow/models/layers/ConvLSTM.py +++ b/satflow/models/layers/ConvLSTM.py @@ -1,25 +1,22 @@ +"""Layers for convolutional LSTM model""" import torch import torch.nn as nn from satflow.models.utils import get_conv_layer class ConvLSTMCell(nn.Module): + """Convolutional LSTM""" def __init__(self, input_dim, hidden_dim, kernel_size, bias, conv_type: str = "standard"): """ Initialize ConvLSTM cell. - Parameters - ---------- - input_dim: int - Number of channels of input tensor. - hidden_dim: int - Number of channels of hidden state. - kernel_size: (int, int) - Size of the convolutional kernel. - bias: bool - Whether or not to add the bias. + Args: + input_dim (int): Number of channels of input tensor. + hidden_dim (int): Number of channels of hidden state. + kernel_size (int, int): Size of the convolutional kernel. + bias (bool): Whether or not to add the bias. + conv_type: one of "standard", "coord", "antialiased", or "3d" """ - super(ConvLSTMCell, self).__init__() self.input_dim = input_dim @@ -39,6 +36,16 @@ def __init__(self, input_dim, hidden_dim, kernel_size, bias, conv_type: str = "s ) def forward(self, input_tensor, cur_state): + """ + Compute the forward pass + + Args: + input_tensor: 5-D Tensor of shape (batch, time, channel, height, width) + cur_state: a tuple of (current_hidden_state, current_cell_state) + + Returns: + a tuple of (next_hidden_state, next_cell_state) + """ h_cur, c_cur = cur_state combined = torch.cat([input_tensor, h_cur], dim=1) # concatenate along channel axis @@ -56,6 +63,16 @@ def forward(self, input_tensor, cur_state): return h_next, c_next def init_hidden(self, batch_size, image_size): + """ + Initialize states + + Args: + batch_size: dimension of each batch + image_size: tuple of (height, width) + + Returns: + A tuple of two tensors filled with zeros + """ height, width = image_size return ( torch.zeros(batch_size, self.hidden_dim, height, width, device=self.conv.weight.device), diff --git a/satflow/models/layers/CoordConv.py b/satflow/models/layers/CoordConv.py index 973327cf..55ba25fd 100644 --- a/satflow/models/layers/CoordConv.py +++ b/satflow/models/layers/CoordConv.py @@ -1,14 +1,25 @@ +"""A 2D convolution over an input tensor and coordinates on a grid""" import torch import torch.nn as nn class AddCoords(nn.Module): + """"Add input tensors for x and y dimensions that are evenly spaced coordinates from -1 to 1""" def __init__(self, with_r=False): + """ + Initialize module + + Args: + with_r: also add an input that is the distance from the center (polar coordinates) + Default if False + """ super().__init__() self.with_r = with_r def forward(self, input_tensor): """ + Compute the forward pass + Args: input_tensor: shape(batch, channel, x_dim, y_dim) """ @@ -42,7 +53,17 @@ def forward(self, input_tensor): class CoordConv(nn.Module): + """A 2D convolution over an input tensor and coordinates on a grid""" def __init__(self, in_channels, out_channels, with_r=False, **kwargs): + """ + A 2D convolution over an input tensor and coordinates on a grid + + Args: + in_chanels: number of input channels + out_channels: number of output channels + with_r: also add an input that is the distance from the center (polar coordinates) + Default if False + """ super().__init__() self.addcoords = AddCoords(with_r=with_r) in_size = in_channels + 2 @@ -51,6 +72,12 @@ def __init__(self, in_channels, out_channels, with_r=False, **kwargs): self.conv = nn.Conv2d(in_size, out_channels, **kwargs) def forward(self, x): + """ + Compute the forward pass + + Args: + x: shape(batch, channel, x_dim, y_dim) + """ ret = self.addcoords(x) ret = self.conv(ret) return ret diff --git a/satflow/models/layers/Discriminator.py b/satflow/models/layers/Discriminator.py index b4aa7503..ddad6661 100644 --- a/satflow/models/layers/Discriminator.py +++ b/satflow/models/layers/Discriminator.py @@ -13,10 +13,13 @@ ######################################################################################## class Spectral_Norm: + """Compute the spectral norm of a module""" def __init__(self, name): + """Initialize the module""" self.name = name def compute_weight(self, module): + """Compute the weight""" weight = getattr(module, self.name + "_orig") u = getattr(module, self.name + "_u") size = weight.size() @@ -34,6 +37,7 @@ def compute_weight(self, module): @staticmethod def apply(module, name): + """Apply the spectral norm to the module""" fn = Spectral_Norm(name) weight = getattr(module, name) @@ -49,18 +53,21 @@ def apply(module, name): return fn def __call__(self, module, input): + """Call this module""" weight_sn, u = self.compute_weight(module) setattr(module, self.name, weight_sn) setattr(module, self.name + "_u", u) def spectral_norm(module, name="weight"): + """Apply the Spectral_Norm module to the input""" Spectral_Norm.apply(module, name) return module def spectral_init(module, gain=1): + """Initialize network and apply a spectral norm""" init.xavier_uniform_(module.weight, gain) if module.bias is not None: module.bias.data.zero_() @@ -69,24 +76,33 @@ def spectral_init(module, gain=1): def init_linear(linear): + """Initialize a linear layer""" init.xavier_uniform_(linear.weight) linear.bias.data.zero_() def init_conv(conv, glu=True): + """Initialize a convolutional layer""" init.xavier_uniform_(conv.weight) if conv.bias is not None: conv.bias.data.zero_() def leaky_relu(input): + """A leaky relu layer with negative_slove = 0.2""" return F.leaky_relu(input, negative_slope=0.2) class SelfAttention(nn.Module): - """ Self attention Layer""" - + """Self attention Layer""" def __init__(self, in_dim, activation=F.relu): + """ + Initialize the module + + Args: + in_dim: number of input channels + activation: activation function to use. Default is torch.nn.functional.relu + """ super(SelfAttention, self).__init__() self.chanel_in = in_dim self.activation = activation @@ -104,11 +120,13 @@ def __init__(self, in_dim, activation=F.relu): def forward(self, x): """ - inputs : + Compute the output of the self attention layer + + Args: x : input feature maps( B X C X W X H) - returns : + + Returns: out : self attention value + input feature - attention: B X N X N (N is Width*Height) """ m_batchsize, C, width, height = x.size() proj_query = ( @@ -127,7 +145,15 @@ def forward(self, x): class ConditionalNorm(nn.Module): + """Conditional Norm""" def __init__(self, in_channel, n_condition=148): + """ + Initialize the module + + Args: + in_channel: number of input channels + n_condition: size of second dimension of class_id + """ super().__init__() self.bn = nn.BatchNorm2d(in_channel, affine=False) @@ -137,6 +163,7 @@ def __init__(self, in_channel, n_condition=148): self.embed.weight.data[:, in_channel:] = 0 def forward(self, input, class_id): + """Compute the conditional norm""" out = self.bn(input) # print(class_id.dtype) # print('class_id', class_id.size()) # torch.Size([4, 148]) @@ -168,6 +195,19 @@ def __init__( upsample=True, downsample=False, ): + """ + Args: + in_channel: Number of input channels + out_channel: Number of output channels + kernel_size: kernel size in convolutional blocks. Default [3, 3] + padding: padding in convolutional blocks. default is 1 + stride: stride in convolutional blocks. default is 1 + n_class: not implemented. default is 1 + bn: whether to add conditional norm to the output. Default is True + activation: activation function to use. Default is torch.nn.functional.relu + upsample: whether to add upsampling. Default is True + downsample: whether to add downsampling. Default is False. + """ super().__init__() gain = 2 ** 0.5 @@ -197,6 +237,7 @@ def __init__( self.HyperBN_1 = ConditionalNorm(out_channel, 148) def forward(self, input, condition=None): + """Compute the forward pass""" out = input if self.bn: @@ -231,7 +272,15 @@ def forward(self, input, condition=None): class SpatialDiscriminator(nn.Module): + """A spatial discriminator""" def __init__(self, chn=128, n_class=4): + """ + Initialize the module + + Args: + chn: number of channels. default is 128 + n_class number of classes. default is 4 + """ super().__init__() self.pre_conv = nn.Sequential( @@ -261,6 +310,7 @@ def __init__(self, chn=128, n_class=4): self.embed = SpectralNorm(self.embed) def forward(self, x, class_id): + """Compute the foward pass""" # reshape input tensor from BxTxCxHxW to BTxCxHxW batch_size, T, C, W, H = x.size() @@ -311,11 +361,19 @@ def forward(self, x, class_id): def conv3x3x3(in_planes, out_planes, stride=1): - # 3x3x3 convolution with padding + """ + 3x3x3 convolution with padding + + Args: + in_planes: Number of channels in the input image + out_planes: Number of channels produced by the convolution + stride: default is 1. + """ return nn.Conv3d(in_planes, out_planes, kernel_size=3, stride=stride, padding=1, bias=False) class Res3dBlock(nn.Module): + """A Res3dBlock""" def __init__( self, in_channel, @@ -329,6 +387,21 @@ def __init__( upsample=True, downsample=False, ): + """ + Initialize the module + + Args: + in_channel: number of input channels + out_channel: number of output channels + kernel_size: default is [3, 3, 3] + padding: default is 1 + stride: default is 1 + n_class: not implemented. Default is None + bn: whether to add ConditionalNorm to the module. Default is True. + activation: activation function to use. Default is torch.nn.functional.relu + upsample: whether to add upsampling. Default is True + downsample: whether to add downsampling. Default is False. + """ super().__init__() gain = 2 ** 0.5 @@ -358,6 +431,7 @@ def __init__( self.HyperBN_1 = ConditionalNorm(out_channel, 148) def forward(self, input, condition=None): + """Compute the Res3dBlock""" out = input if self.bn: @@ -392,7 +466,15 @@ def forward(self, input, condition=None): class TemporalDiscriminator(nn.Module): + """A temporal discriminator""" def __init__(self, chn=128, n_class=4): + """ + Initialize the module + + Args: + chn: number of channels. default is 128 + n_class: number of classes. default is 4 + """ super().__init__() gain = 2 ** 0.5 @@ -422,6 +504,7 @@ def __init__(self, chn=128, n_class=4): self.embed = SpectralNorm(self.embed) def forward(self, x, class_id): + """Compute the forward pass""" # pre-process with avg_pool2d to reduce tensor size # B, T, C, H, W = x.size() # x = F.avg_pool2d(x.view(B * T, C, H, W), kernel_size=2) diff --git a/satflow/models/layers/Generator.py b/satflow/models/layers/Generator.py index 510ce2a6..c438ba98 100644 --- a/satflow/models/layers/Generator.py +++ b/satflow/models/layers/Generator.py @@ -1,3 +1,4 @@ +"""Layers for generators""" import torch import torch.nn as nn from torch.nn import functional as F @@ -13,7 +14,19 @@ class Generator(nn.Module): + """A generator layer""" def __init__(self, in_dim=120, latent_dim=4, n_class=4, ch=32, n_frames=48, hierar_flag=False): + """ + Initialize the module + + Args: + in_dim: input dimension. default is 120 + latent_dim: latent dimension. default is 4 + n_class: number of classes. default is 4 + ch: number of channels. default is 32 + n_frames: number of frames + hierar_flag: default is False + """ super().__init__() self.in_dim = in_dim @@ -71,7 +84,7 @@ def __init__(self, in_dim=120, latent_dim=4, n_class=4, ch=32, n_frames=48, hier self.colorize = SpectralNorm(nn.Conv2d(2 * ch, 3, kernel_size=(3, 3), padding=1)) def forward(self, x, class_id): - + """Compute the foward pass""" if self.hierar_flag is True: noise_emb = torch.split(x, self.in_dim, dim=1) else: diff --git a/satflow/models/layers/RUnetLayers.py b/satflow/models/layers/RUnetLayers.py index 7ebbf02c..a8f9fa8a 100644 --- a/satflow/models/layers/RUnetLayers.py +++ b/satflow/models/layers/RUnetLayers.py @@ -1,3 +1,4 @@ +"""Layers for RUNet""" import torch import torch.nn as nn from satflow.models.utils import get_conv_layer @@ -5,6 +6,14 @@ def init_weights(net, init_type="normal", gain=0.02): + """ + Initialize network weights + + Args: + net: network to be initialized + init_type: options are "normal", "xavier", "kaiming", "orthogonal" + gain: scaling factor. Default is 0.02 + """ def init_func(m): classname = m.__class__.__name__ if hasattr(m, "weight") and ( @@ -33,7 +42,20 @@ def init_func(m): class conv_block(nn.Module): + """ + Convolutional block + + A twice-repeated chain of a convolutional layer, batch normalization, and ReLU + """ def __init__(self, ch_in, ch_out, conv_type: str = "standard"): + """ + Initialize the block + + Args: + ch_in: number of input channels + ch_out: number of output channels + conv_type: one of "standard", "coord", "antialiased", or "3d" + """ super(conv_block, self).__init__() conv2d = get_conv_layer(conv_type) self.conv = nn.Sequential( @@ -46,12 +68,32 @@ def __init__(self, ch_in, ch_out, conv_type: str = "standard"): ) def forward(self, x): + """ + Compute the forward pass + + Args: + x: shape(batch, channel, x_dim, y_dim) + """ x = self.conv(x) return x class up_conv(nn.Module): + """ + Convolutional block with upsampling + + A chain of an upsample layer with scale factor 2, a convolutional layer, + batch normalization, and ReLU + """ def __init__(self, ch_in, ch_out, conv_type: str = "standard"): + """ + Initialize the block + + Args: + ch_in: number of input channels + ch_out: number of output channels + conv_type: one of "standard", "coord", "antialiased", or "3d" + """ super(up_conv, self).__init__() conv2d = get_conv_layer(conv_type) self.up = nn.Sequential( @@ -62,12 +104,32 @@ def __init__(self, ch_in, ch_out, conv_type: str = "standard"): ) def forward(self, x): + """ + Compute the forward pass + + Args: + x: shape(batch, channel, x_dim, y_dim) + """ x = self.up(x) return x class Recurrent_block(nn.Module): + """ + Recurrent block + + A repeated chain of a convolutional layer, batch normalization, and ReLU, + where the output of the previous step is added to the input for the next step. + """ def __init__(self, ch_out, t=2, conv_type: str = "standard"): + """ + Initialize the block + + Args: + ch_out: number of channels for input and output + t: number of steps. Default is 2. + conv_type: one of "standard", "coord", "antialiased", or "3d" + """ super(Recurrent_block, self).__init__() conv2d = get_conv_layer(conv_type) self.t = t @@ -79,6 +141,12 @@ def __init__(self, ch_out, t=2, conv_type: str = "standard"): ) def forward(self, x): + """ + Compute the forward pass + + Args: + x: shape(batch, channel, x_dim, y_dim) + """ for i in range(self.t): if i == 0: @@ -89,7 +157,21 @@ def forward(self, x): class RRCNN_block(nn.Module): + """ + Recursive residual CNN block + + A chain of recurrent blocks with a skip connection of the input to the final output + """ def __init__(self, ch_in, ch_out, t=2, conv_type: str = "standard"): + """ + Initialize the block + + Args: + ch_out: number of input channels + ch_out: number of output channels + t: number of steps in the recurrent blocks. Default is 2. + conv_type: one of "standard", "coord", "antialiased", or "3d" + """ super(RRCNN_block, self).__init__() conv2d = get_conv_layer(conv_type) self.RCNN = nn.Sequential( @@ -99,13 +181,32 @@ def __init__(self, ch_in, ch_out, t=2, conv_type: str = "standard"): self.Conv_1x1 = conv2d(ch_in, ch_out, kernel_size=1, stride=1, padding=0) def forward(self, x): + """ + Compute the forward pass + + Args: + x: shape(batch, channel, x_dim, y_dim) + """ x = self.Conv_1x1(x) x1 = self.RCNN(x) return x + x1 class single_conv(nn.Module): + """ + Convolutional block + + A chain of a convolutional layer, batch normalization, and ReLU + """ def __init__(self, ch_in, ch_out, conv_type: str = "standard"): + """ + Initialize the block + + Args: + ch_in: number of input channels + ch_out: number of output channels + conv_type: one of "standard", "coord", "antialiased", or "3d" + """ super(single_conv, self).__init__() conv2d = get_conv_layer(conv_type) self.conv = nn.Sequential( @@ -115,6 +216,12 @@ def __init__(self, ch_in, ch_out, conv_type: str = "standard"): ) def forward(self, x): + """ + Compute the forward pass + + Args: + x: shape(batch, channel, x_dim, y_dim) + """ x = self.conv(x) return x diff --git a/satflow/models/layers/SpatioTemporalLSTMCell_memory_decoupling.py b/satflow/models/layers/SpatioTemporalLSTMCell_memory_decoupling.py index 63e26d14..0fff16f2 100644 --- a/satflow/models/layers/SpatioTemporalLSTMCell_memory_decoupling.py +++ b/satflow/models/layers/SpatioTemporalLSTMCell_memory_decoupling.py @@ -11,7 +11,19 @@ class SpatioTemporalLSTMCell(nn.Module): + """A SpatioTemporalLSTMCell. PredNN v2 adapted from https://github.com/thuml/predrnn-pytorch""" def __init__(self, in_channel, num_hidden, width, filter_size, stride, layer_norm): + """ + Initialize the module + + Args: + in_channel: number of input channels + num_hidden: size of hidden layer + width: width of image + filter_size: kernel size in convolutions + stride: stride in convolutions + layer_norm: whether to add LayerNorm after each convolution + """ super(SpatioTemporalLSTMCell, self).__init__() self.num_hidden = num_hidden @@ -108,6 +120,7 @@ def __init__(self, in_channel, num_hidden, width, filter_size, stride, layer_nor ) def forward(self, x_t, h_t, c_t, m_t): + """Compute the forward pass""" x_concat = self.conv_x(x_t) h_concat = self.conv_h(h_t) m_concat = self.conv_m(m_t) diff --git a/satflow/models/layers/TimeDistributed.py b/satflow/models/layers/TimeDistributed.py index a45b5a05..02cde120 100644 --- a/satflow/models/layers/TimeDistributed.py +++ b/satflow/models/layers/TimeDistributed.py @@ -1,25 +1,40 @@ +"""Apply a module over the time dimension identically for each step""" import torch import torch.nn as nn def _stack_tups(tuples, stack_dim=1): - "Stack tuple of tensors along `stack_dim`" + """Stack tuple of tensors along `stack_dim`""" return tuple( torch.stack([t[i] for t in tuples], dim=stack_dim) for i in list(range(len(tuples[0]))) ) class TimeDistributed(nn.Module): - "Applies `module` over `tdim` identically for each step, use `low_mem` to compute one at a time." + """Applies `module` over `tdim` identically for each step, use `low_mem` to compute one at a time.""" def __init__(self, module, low_mem=False, tdim=1): + """ + Initialize the module + + Args: + module: a module to apply to the data + low_mem: if True, use a memory efficient implementation. default is False + tdim: The index of the input tensor that represents the time dimension. default is 1. + """ super().__init__() self.module = module self.low_mem = low_mem self.tdim = tdim def forward(self, *tensors, **kwargs): - "input x with shape:(bs,seq_len,channels,width,height)" + """ + Compute the forward pass + + Args: + tensors: shape(batch_size, seq_len, channels, width, height) + kwargs: key word arguments to the module being applied + """ if self.low_mem or self.tdim != 1: return self.low_mem_forward(*tensors, **kwargs) else: @@ -30,7 +45,13 @@ def forward(self, *tensors, **kwargs): return self.format_output(out, bs, seq_len) def low_mem_forward(self, *tensors, **kwargs): - "input x with shape:(bs,seq_len,channels,width,height)" + """ + Compute the forward pass with a memory efficient method + + Args: + tensors: shape(batch_size, seq_len, channels, width, height) + kwargs: key word arguments to the module being applied + """ seq_len = tensors[0].shape[self.tdim] args_split = [torch.unbind(x, dim=self.tdim) for x in tensors] out = [] @@ -41,10 +62,11 @@ def low_mem_forward(self, *tensors, **kwargs): return torch.stack(out, dim=self.tdim) def format_output(self, out, bs, seq_len): - "unstack from batchsize outputs" + """unstack from batchsize outputs""" if isinstance(out, tuple): return tuple(out_i.view(bs, seq_len, *out_i.shape[1:]) for out_i in out) return out.view(bs, seq_len, *out.shape[1:]) def __repr__(self): + """Print the name of the module being applied""" return f"TimeDistributed({self.module})" diff --git a/satflow/models/layers/__init__.py b/satflow/models/layers/__init__.py index 7aa14461..f0a7f633 100644 --- a/satflow/models/layers/__init__.py +++ b/satflow/models/layers/__init__.py @@ -1,3 +1,4 @@ +"""Different layers to be used in model architectures""" from .TimeDistributed import TimeDistributed from .ConvLSTM import ConvLSTMCell from .SpatioTemporalLSTMCell_memory_decoupling import SpatioTemporalLSTMCell diff --git a/satflow/models/perceiver.py b/satflow/models/perceiver.py index e8c25321..39b366bd 100644 --- a/satflow/models/perceiver.py +++ b/satflow/models/perceiver.py @@ -1,3 +1,4 @@ +"""General perception with iterative attention: https://arxiv.org/pdf/2103.03206.pdf""" from perceiver_pytorch import MultiPerceiver from perceiver_pytorch.modalities import InputModality from perceiver_pytorch.encoders import ImageEncoder @@ -32,6 +33,7 @@ @register_model class Perceiver(BaseModel): + """General perception with iterative attention: https://arxiv.org/pdf/2103.03206.pdf""" def __init__( self, input_channels: int = 22, @@ -71,6 +73,48 @@ def __init__( generate_fourier_features: bool = True, temporally_consistent_fourier_features: bool = False, ): + """ + Initialize the model + + Args: + input_channels: default is 22 + sat_channels: number of satellite channels. default is 12 + nwp_channels: default is 10 + base_channels: default is 1 + forecast_steps: number of timesteps to forecast. + history_steps: + input_size: + lr: learning_rate. default is 5e-4 + visualize: add a visualization step. default is True. + max_frequency: used in Fourier features. should be tuned based on how fine the data is + depth: depth of the network. default is 6 + num_latents: number of latents, or induced set points, or centroids. different papers giving it different names + cross_heads: number of heads for cross attention. default is 1. + latent_heads: number of heads for latent self attention. default is 8. + cross_dim_heads: number of dimensions per cross attention head. + latent_dim: latent dimension. default is 512. + weight_tie_layers: whether to weight tie layers. Default is False. + decoder_ff: use a feed-forward decoder. default is True + dim: dimension of sequence to be encoded + logits_dim: dimension of final logits + queries_dim: dimension of decoder queries + latent_dim_heads: number of dimensions per latent self attention head + loss: loss: name of the loss function or torch.nn.Module. Default is "mse" + sin_only: whether to only use sine for Fourier encoding + encoder_fourier: Whether to encode position with Fourier features + preprocessor_type: an optional preprocessing step. Default is None. + postprocessor_type: an optional preprocessing step. Default is None. + encoder_kwargs: arguments to pass to the ImageEncoder chosen in the preprocessing step + decoder_kwargs: arguments to pass to the ImageDecoder chosen in the postprocessing step + pretrained: Default is False. + predict_timesteps_together: predict future timesteps all at once. + Otherwise, iterate over each forecast step. Default is False + nwp_modality: Whether to add NWP data as an input. Default is False + datetime_modality: Whether to add datetime features as an input. Default is False. + use_learnable_query: Whether to use the LearnableQuery. Default is True + generate_fourier_features: whether to use fourier features in the LearnableQuery. Default is True. + temporally_consistent_fourier_features: Default is False. + """ super(BaseModel, self).__init__() self.forecast_steps = forecast_steps self.input_channels = input_channels @@ -266,6 +310,15 @@ def __init__( self.save_hyperparameters() def encode_inputs(self, x: dict) -> Dict[str, torch.Tensor]: + """ + Perform preprocessing steps (if specified) on input data + + Args: + x: a dictionary mapping keys to input data as tensors + + Returns: + A dictionary with processed tensors + """ video_inputs = x[SATELLITE_DATA] nwp_inputs = x.get(NWP_DATA, []) base_inputs = x.get( @@ -290,6 +343,13 @@ def encode_inputs(self, x: dict) -> Dict[str, torch.Tensor]: return x def add_timestep(self, batch_size: int, timestep: int = 1) -> torch.Tensor: + """ + Add a timestep input + + Args: + batch_size: dimension of of each batch + timestep: the index of future timestep that is being predicted + """ times = (torch.eye(self.forecast_steps)[timestep]).unsqueeze(-1).unsqueeze(-1) ones = torch.ones(1, 1, 1) timestep_input = times * ones @@ -299,6 +359,14 @@ def add_timestep(self, batch_size: int, timestep: int = 1) -> torch.Tensor: return timestep_input def _train_or_validate_step(self, batch, batch_idx, is_training: bool = True): + """ + Perform a step of the model + + Args: + batch: tuple of (x, y) + batch_idx: not implemented + is_training: whether or not this is a training step. default is True. + """ x, y = batch batch_size = y[SATELLITE_DATA].size(0) # For each future timestep: @@ -338,6 +406,7 @@ def _train_or_validate_step(self, batch, batch_idx, is_training: bool = True): return loss def configure_optimizers(self): + """Get the optimizer and the learning rate scheduler for the initialized parameters""" # They use LAMB as the optimizer optimizer = optim.Lamb(self.parameters(), lr=self.lr, betas=(0.9, 0.999)) scheduler = LinearWarmupCosineAnnealingLR(optimizer, warmup_epochs=10, max_epochs=100) @@ -360,6 +429,12 @@ def configure_optimizers(self): return {"optimizer": optimizer, "lr_scheduler": lr_dict} def construct_query(self, x: dict): + """ + The LearnableQuery to pass to the model + + Args: + x: a dictionary mapping keys to input data as tensors + """ if self.use_learnable_query: if self.temporally_consistent_fourier_features: fourier_features = encode_position( @@ -395,4 +470,5 @@ def construct_query(self, x: dict): return y_query def forward(self, x, mask=None, query=None): + """A forward pass of the model""" return self.model.forward(x, mask=mask, queries=query) diff --git a/satflow/models/pix2pix.py b/satflow/models/pix2pix.py index 56e95cbb..e190b231 100644 --- a/satflow/models/pix2pix.py +++ b/satflow/models/pix2pix.py @@ -1,3 +1,4 @@ +"""A conditional GAN for image translation: https://arxiv.org/abs/1611.07004""" import torch import torchvision import numpy as np @@ -11,6 +12,7 @@ @register_model class Pix2Pix(pl.LightningModule): + """A conditional GAN for image translation: https://arxiv.org/abs/1611.07004""" def __init__( self, forecast_steps: int = 48, @@ -31,6 +33,29 @@ def __init__( channels_per_timestep: int = 12, pretrained: bool = False, ): + """ + Initialize the model + + Args: + forecast_steps: number of timesteps to forecast. default is 48. + input_channels: default is 12 + lr: learning rate. default is 0.0002 + beta1: first beta value for adam optimizer. default is 0.5 + beta2: second beta value for adam optimizer. default is 0.999 + num_filters: the number of filters in the last layer of the generator + and the first layer of the discriminator + generator_model: the architecture's name: resnet_9blocks | resnet_6blocks | unet_256 | unet_128 + norm: the name of normalization layers used in the network: batch | instance | none + use_dropout: whether to use dropout layers. default is False. + discriminator_model: the architecture's name: basic | n_layers | pixel + discriminator_layers: the number of conv layers in the discriminator; effective when discriminator_model=="n_layers" + loss: the type of GAN objective. It currently supports vanilla, lsgan, and wgangp. + scheduler: the method for the learning rate scheduler. One of "plateau" or "cosine" + lr_epochs: if scheduler == "cosine", this is the max number of iterations + lambda_l1: a penalty term that will be multiplied by the l1 loss. default is 100 + channels_per_timestep: used in the visualization step. the number of images per row. + pretrained: not implemented. default is False. + """ super().__init__() self.lr = lr self.b1 = beta1 @@ -64,9 +89,20 @@ def __init__( self.save_hyperparameters() def forward(self, x): + """A forward pass of the generator""" return self.generator(x) def visualize_step(self, x, y, y_hat, batch_idx, step): + """ + Visualize the results of a step of the model + + Args: + x: input data + y: output + y_hat: prediction + batch_idx: (int) the global step to record for this batch + step: name of the step type. Default is "train" + """ # the logger you used (in this case tensorboard) tensorboard = self.logger.experiment[0] # Add all the different timesteps for a single prediction, 0.1% of the time @@ -84,6 +120,17 @@ def visualize_step(self, x, y, y_hat, batch_idx, step): tensorboard.add_image(f"{step}/Generated_Image_Stack", image_grid, global_step=batch_idx) def training_step(self, batch, batch_idx, optimizer_idx): + """ + Perform a training step of the model + + Args: + batch: tuple of (images, future_images, future_masks) + batch_idx: used to visualize the results of the training step + optimizer_idx: the iteration number for the optimizer + + Returns: + A dictionary with information about the loss + """ images, future_images, future_masks = batch # train generator if optimizer_idx == 0: @@ -124,6 +171,16 @@ def training_step(self, batch, batch_idx, optimizer_idx): return output def validation_step(self, batch, batch_idx): + """ + Perform a validation step of the model + + Args: + batch: tuple of (images, future_images, future_masks) + batch_idx: used to visualize the results of the training step + + Returns: + A dictionary with information about the loss + """ images, future_images, future_masks = batch # generate images generated_images = self(images) @@ -158,6 +215,12 @@ def validation_step(self, batch, batch_idx): return output def configure_optimizers(self): + """ + Get the optimizers and the learning rate schedulers for the generator and discriminator + + Returns: + A tuple of [g_optimizer, d_optimizer], [g_scheduler, d_scheduler] + """ lr = self.lr b1 = self.b1 b2 = self.b2 diff --git a/satflow/models/pixel_cnn.py b/satflow/models/pixel_cnn.py index 4583fd4d..ab544d13 100644 --- a/satflow/models/pixel_cnn.py +++ b/satflow/models/pixel_cnn.py @@ -1,3 +1,4 @@ +"""A PixelCNN model: https://arxiv.org/pdf/1905.09272.pdf""" import torch import torch.nn.functional as F import pytorch_lightning as pl @@ -7,6 +8,7 @@ @register_model class PixelCNN(pl.LightningModule): + """A Pixel CNN model: https://arxiv.org/pdf/1905.09272.pdf""" def __init__( self, future_timesteps: int, @@ -16,6 +18,17 @@ def __init__( pretrained: bool = False, lr: float = 0.001, ): + """ + Initialize the model + + Args: + future_timesteps: not implemented + input_channels: default is 3 + num_layers: default is 5 + num_hidden: default is 64 + pretrained: not implemented. default is False + lr: learning rate. default is 0.001 + """ super(PixelCNN, self).__init__() self.lr = lr self.model = Pixcnn( @@ -24,6 +37,7 @@ def __init__( @classmethod def from_config(cls, config): + """Initialize PixelCNN model from configuration values""" return PixelCNN( future_timesteps=config.get("future_timesteps", 12), input_channels=config.get("in_channels", 12), @@ -34,14 +48,26 @@ def from_config(cls, config): ) def forward(self, x): + """A forward step of the model""" self.model.forward(x) def configure_optimizers(self): + """Get the optimizer with the initialized parameters""" # DeepSpeedCPUAdam provides 5x to 7x speedup over torch.optim.adam(w) # optimizer = torch.optim.adam() return torch.optim.Adam(self.parameters(), lr=self.lr) def training_step(self, batch, batch_idx): + """ + Perform a training step of the model + + Args: + batch: tuple of (x, y) + batch_idx: not implemented + + Returns: + The loss for the training step + """ x, y = batch y_hat = self(x) # Generally only care about the center x crop, so the model can take into account the clouds in the area without @@ -51,6 +77,16 @@ def training_step(self, batch, batch_idx): return loss def validation_step(self, batch, batch_idx): + """ + Perform a validation step of the model + + Args: + batch: tuple of (x, y) + batch_idx: not implemented + + Returns: + The loss for the validation step + """ x, y = batch y_hat = self(x) val_loss = F.mse_loss(y_hat, y) @@ -58,6 +94,16 @@ def validation_step(self, batch, batch_idx): return val_loss def test_step(self, batch, batch_idx): + """ + Perform a testing step of the model + + Args: + batch: tuple of (x, y) + batch_idx: not implemented + + Returns: + The loss for the testing step + """ x, y = batch y_hat = self(x, self.forecast_steps) loss = F.mse_loss(y_hat, y) diff --git a/satflow/models/pl_metnet.py b/satflow/models/pl_metnet.py index 006b618e..55c51024 100644 --- a/satflow/models/pl_metnet.py +++ b/satflow/models/pl_metnet.py @@ -1,3 +1,4 @@ +"""A network for weather forecasting: https://arxiv.org/abs/2003.12140""" import einops import numpy as np import torch @@ -25,6 +26,7 @@ @register_model class LitMetNet(BaseModel): + """A network for weather forecasting: https://arxiv.org/abs/2003.12140""" def __init__( self, image_encoder: str = "downsampler", @@ -44,6 +46,27 @@ def __init__( visualize: bool = False, loss: str = "mse", ): + """ + Initialize the model + + Args: + image_encoder: the method for encoding the image. default is "downsampler" + input_channels: default is 12 + sat_channels: number of satellite channels + input_size: the size of the image to use for center crop. default is 256 + output_channels: default is 12 + hidden_dim: number of hidden dimensions. default is 64 + kernel_size: size of the convolutional kernel. default is 3 + num_layers: number of convolutional layers to have. default is 1 + num_att_layers: number of attention layers to ahve. default is 1 + head: the name of the final layer of the network. default is "identity" + forecast_steps: number of timesteps to forecast. default is 48. + temporal_dropout: dropout rate to use in the temporal encoder + lr: learning rate. default is 0.001 + pretrained: default is False. + visualize: add a visualization step. default is False + loss: the name of the loss function to use. default is "mse" + """ super(BaseModel, self).__init__() self.forecast_steps = forecast_steps self.input_channels = input_channels @@ -73,9 +96,11 @@ def __init__( self.save_hyperparameters() def forward(self, imgs, **kwargs) -> Any: + """Compute the forward pass""" return self.model(imgs) def configure_optimizers(self): + """Get the optimizer and the learning rate scheduler for the initialized parameters""" # DeepSpeedCPUAdam provides 5x to 7x speedup over torch.optim.adam(w) # optimizer = torch.optim.adam() optimizer = torch.optim.Adam(self.parameters(), lr=self.lr) @@ -118,6 +143,14 @@ def _combine_data_sources(self, x: Dict[str, torch.Tensor]) -> torch.Tensor: return input_data def _train_or_validate_step(self, batch, batch_idx, is_training: bool = True): + """ + Perform a step of the model + + Args: + batch: tuple of (x, y) + batch_idx: not implemented + is_training: whether or not this is a training step. default is True. + """ x, y = batch y[SATELLITE_DATA] = y[SATELLITE_DATA].float() diff --git a/satflow/models/runet.py b/satflow/models/runet.py index 81a70286..e063a87d 100644 --- a/satflow/models/runet.py +++ b/satflow/models/runet.py @@ -1,3 +1,4 @@ +"""A recurrent CNN (RUnet) and a recurrent residual CNN (R2Unet) based on U-Net""" import antialiased_cnns from satflow.models.layers.RUnetLayers import * import pytorch_lightning as pl @@ -10,6 +11,7 @@ @register_model class RUnet(pl.LightningModule): + """A recurrent CNN based on U-Net""" def __init__( self, input_channels: int = 12, @@ -21,6 +23,19 @@ def __init__( conv_type: str = "standard", pretrained: bool = False, ): + """ + Initialize the model + + Args: + input_channels: default is 12 + forecast_steps: number of timesteps to forecast. default is 48. + recurrent_steps: default is 2 + loss: name of the loss function or torch.nn.Module. Default is "mse" + lr: learning rate. default is 0.001 + visualize: add a visualization step. default is False + conv_type: one of "standard", "coord", "antialiased", or "3d" + pretrained: not implemented. default is False. + """ super().__init__() self.input_channels = input_channels self.forecast_steps = forecast_steps @@ -36,6 +51,7 @@ def __init__( @classmethod def from_config(cls, config): + """Initialize RUnet model from configuration values""" return RUnet( forecast_steps=config.get("forecast_steps", 12), input_channels=config.get("in_channels", 12), @@ -43,14 +59,26 @@ def from_config(cls, config): ) def forward(self, x): + """A forward step of the model""" return self.model.forward(x) def configure_optimizers(self): + """Get the optimizer with the initialized parameters""" # DeepSpeedCPUAdam provides 5x to 7x speedup over torch.optim.adam(w) # optimizer = torch.optim.adam() return torch.optim.Adam(self.parameters(), lr=self.lr) def training_step(self, batch, batch_idx): + """ + Perform a training step of the model + + Args: + batch: tuple of (x, y) + batch_idx: used to visualize the results of the training step + + Returns: + The loss for the training step + """ x, y = batch x = x.float() y_hat = self(x) @@ -70,6 +98,16 @@ def training_step(self, batch, batch_idx): return loss def validation_step(self, batch, batch_idx): + """ + Perform a validation step of the model + + Args: + batch: tuple of (x, y) + batch_idx: not implemented + + Returns: + The loss for the validation step + """ x, y = batch x = x.float() y_hat = self(x) @@ -84,6 +122,16 @@ def validation_step(self, batch, batch_idx): return val_loss def test_step(self, batch, batch_idx): + """ + Perform a testing step of the model + + Args: + batch: tuple of (x, y) + batch_idx: not implemented + + Returns: + The loss for the testing step + """ x, y = batch x = x.float() y_hat = self(x) @@ -91,6 +139,16 @@ def test_step(self, batch, batch_idx): return loss def visualize_step(self, x, y, y_hat, batch_idx, step="train"): + """ + Visualize the results of a step of the model + + Args: + x: input data + y: output + y_hat: prediction + batch_idx: (int) the global step to record for this batch + step: name of the step type. Default is "train" + """ tensorboard = self.logger.experiment[0] # Add all the different timesteps for a single prediction, 0.1% of the time images = x[0].cpu().detach() @@ -108,7 +166,17 @@ def visualize_step(self, x, y, y_hat, batch_idx, step="train"): class R2U_Net(nn.Module): + """A recurrent residual CNN based on U-Net""" def __init__(self, img_ch=3, output_ch=1, t=2, conv_type: str = "standard"): + """ + Initialize the module + + Args: + img_ch: number of input channels. Default is 3 + output_ch: number of output channels. Default is 1 + t: number of recurrent steps. Default is 2 + conv_type: one of "standard", "coord", "antialiased", or "3d" + """ super(R2U_Net, self).__init__() if conv_type == "antialiased": self.Maxpool = nn.MaxPool2d(kernel_size=2, stride=1) @@ -147,6 +215,12 @@ def __init__(self, img_ch=3, output_ch=1, t=2, conv_type: str = "standard"): self.Conv_1x1 = nn.Conv2d(64, output_ch, kernel_size=1, stride=1, padding=0) def forward(self, x): + """ + Compute the forward pass + + Args: + x: shape(batch, channel, x_dim, y_dim) + """ # encoding path x1 = self.RRCNN1(x) diff --git a/satflow/models/unet.py b/satflow/models/unet.py index 29473c4d..944fae16 100644 --- a/satflow/models/unet.py +++ b/satflow/models/unet.py @@ -1,3 +1,4 @@ +"""A UNet CNN""" import torch import pytorch_lightning as pl from nowcasting_utils.models.base import register_model @@ -10,6 +11,7 @@ @register_model class Unet(pl.LightningModule): + """A UNet CNN""" def __init__( self, forecast_steps: int, @@ -22,6 +24,21 @@ def __init__( loss: Union[str, torch.nn.Module] = "mse", pretrained: bool = False, ): + """ + Initialize the model + + Args: + forecast_steps: number of timesteps to forecast. + input_channels: default is 3 + num_layers: default is 5 + hidden_dim: default is 64. + bilinear: Use bilinear interpolation for upsampling. + Default is False, which uses transposed convolutions. + lr: learning rate. default is 0.001 + visualize: add a visualization step. default is False + loss: loss: name of the loss function or torch.nn.Module. Default is "mse" + pretrained: Not implemented. Default is False + """ super(Unet, self).__init__() self.lr = lr self.input_channels = input_channels @@ -33,6 +50,7 @@ def __init__( @classmethod def from_config(cls, config): + """Initialize Unet model from configuration values""" return Unet( forecast_steps=config.get("forecast_steps", 12), input_channels=config.get("in_channels", 12), @@ -43,14 +61,26 @@ def from_config(cls, config): ) def forward(self, x): + """A forward step of the model""" return self.model.forward(x) def configure_optimizers(self): + """Get the optimizer with the initialized parameters""" # DeepSpeedCPUAdam provides 5x to 7x speedup over torch.optim.adam(w) # optimizer = torch.optim.adam() return torch.optim.Adam(self.parameters(), lr=self.lr) def training_step(self, batch, batch_idx): + """ + Perform a training step of the model + + Args: + batch: tuple of (x, y) + batch_idx: used to visualize the results of the training step + + Returns: + The loss for the training step + """ x, y = batch x = x.float() y_hat = self(x) @@ -70,6 +100,16 @@ def training_step(self, batch, batch_idx): return loss def validation_step(self, batch, batch_idx): + """ + Perform a validation step of the model + + Args: + batch: tuple of (x, y) + batch_idx: not implemented + + Returns: + The loss for the validation step + """ x, y = batch x = x.float() y_hat = self(x) @@ -84,6 +124,16 @@ def validation_step(self, batch, batch_idx): return val_loss def test_step(self, batch, batch_idx): + """ + Perform a testing step of the model + + Args: + batch: tuple of (x, y) + batch_idx: not implemented + + Returns: + The loss for the testing step + """ x, y = batch x = x.float() y_hat = self(x) @@ -91,6 +141,16 @@ def test_step(self, batch, batch_idx): return loss def visualize_step(self, x, y, y_hat, batch_idx, step="train"): + """ + Visualize the results of a step of the model + + Args: + x: input data + y: output + y_hat: prediction + batch_idx: (int) the global step to record for this batch + step: name of the step type. Default is "train" + """ tensorboard = self.logger.experiment[0] # Add all the different timesteps for a single prediction, 0.1% of the time images = x[0].cpu().detach() diff --git a/satflow/models/utils.py b/satflow/models/utils.py index 269de0af..4a4219de 100644 --- a/satflow/models/utils.py +++ b/satflow/models/utils.py @@ -1,3 +1,4 @@ +"""Utility functions for manipulating inputs""" import torch import einops import numpy as np @@ -5,6 +6,15 @@ def get_conv_layer(conv_type: str = "standard") -> torch.nn.Module: + """ + Get the desired convolutional layer + + Args: + conv_type: one of "standard", "coord", "antialiased", or "3d" + + Returns: + A convolutional layer + """ if conv_type == "standard": conv_layer = torch.nn.Conv2d elif conv_type == "coord": @@ -22,7 +32,17 @@ def get_conv_layer(conv_type: str = "standard") -> torch.nn.Module: def reverse_space_to_depth( frames: np.ndarray, temporal_block_size: int = 1, spatial_block_size: int = 1 ) -> np.ndarray: - """Reverse space to depth transform.""" + """ + Reverse space to depth transform. + + Args: + frames: input array + temporal_block_size: default is 1 + spatial_block_size: default is 1 + + Returns: + The transformed frames + """ if len(frames.shape) == 4: return einops.rearrange( frames, @@ -48,7 +68,17 @@ def reverse_space_to_depth( def space_to_depth( frames: np.ndarray, temporal_block_size: int = 1, spatial_block_size: int = 1 ) -> np.ndarray: - """Space to depth transform.""" + """ + Space to depth transform. + + Args: + frames: input array + temporal_block_size: default is 1 + spatial_block_size: default is 1 + + Returns: + The transformed frames + """ if len(frames.shape) == 4: return einops.rearrange( frames, diff --git a/satflow/run.py b/satflow/run.py index 828112d9..6f7729fb 100644 --- a/satflow/run.py +++ b/satflow/run.py @@ -1,3 +1,4 @@ +"""Command line entrypoint to train a satflow model from a config file""" import os os.environ["HYDRA_FULL_ERROR"] = "1" @@ -12,7 +13,17 @@ @hydra.main(config_path="configs/", config_name="config.yaml") def main(config: DictConfig): + """ + Train a satflow model + + https://hydra.cc/docs/intro/ + Args: + config: the configuration values will be provided by hydra based + on how the script is executed from the command line + + Returns: the output of model training + """ # Imports should be nested inside @hydra.main to optimize tab completion # Read more here: https://github.com/facebookresearch/hydra/issues/934 from satflow.core import utils diff --git a/satflow/version.py b/satflow/version.py index 260c070a..12c67b9f 100644 --- a/satflow/version.py +++ b/satflow/version.py @@ -1 +1,2 @@ +"""The version of the satflow package""" __version__ = "0.3.1" diff --git a/setup.py b/setup.py index 66816a72..b14cd8e4 100644 --- a/setup.py +++ b/setup.py @@ -1,3 +1,4 @@ +"""Setup file for the satflow package""" from distutils.core import setup from pathlib import Path