diff --git a/minecraft_copilot_ml/data_loader.py b/minecraft_copilot_ml/data_loader.py index 89d88ea..59bbba9 100644 --- a/minecraft_copilot_ml/data_loader.py +++ b/minecraft_copilot_ml/data_loader.py @@ -47,7 +47,7 @@ "10220.schematic", "5096.schematic", "14191.schematic", - "10188.schematic" + "10188.schematic", ] @@ -188,8 +188,10 @@ class MinecraftSchematicsDataset(Dataset): def __init__( self, schematics_list_files: List[str], + unique_blocks_dict: Dict[str, int], ) -> None: self.schematics_list_files = schematics_list_files + self.unique_blocks_dict = unique_blocks_dict def __len__(self) -> int: return len(self.schematics_list_files) @@ -211,7 +213,8 @@ def __getitem__(self, idx: int) -> MinecraftSchematicsDatasetItemType: random_y_height_value : random_y_height_value + minimum_height, random_roll_z_value : random_roll_z_value + minimum_depth, ] = True - return block_map, block_map_mask + block_map_int = np.vectorize(self.unique_blocks_dict.get)(block_map) + return block_map_int, block_map_mask def list_schematic_files_in_folder(path_to_schematics: str) -> list[str]: diff --git a/minecraft_copilot_ml/metrics_graph.ipynb b/minecraft_copilot_ml/metrics_graph.ipynb index b140785..162c239 100644 --- a/minecraft_copilot_ml/metrics_graph.ipynb +++ b/minecraft_copilot_ml/metrics_graph.ipynb @@ -344,7 +344,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.12.3" + "version": "3.8.19" } }, "nbformat": 4, diff --git a/minecraft_copilot_ml/notebook.ipynb b/minecraft_copilot_ml/notebook.ipynb index ec2c1f1..013d574 100644 --- a/minecraft_copilot_ml/notebook.ipynb +++ b/minecraft_copilot_ml/notebook.ipynb @@ -219,7 +219,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.10.6" + "version": "3.10.12" } }, "nbformat": 4, diff --git a/minecraft_copilot_ml/vae.py b/minecraft_copilot_ml/vae.py index 6f7bccd..39de46e 100644 --- a/minecraft_copilot_ml/vae.py +++ b/minecraft_copilot_ml/vae.py @@ -1,299 +1,163 @@ -from typing import Tuple, Dict -from sklearn.model_selection import train_test_split +import os +from typing import List, Tuple + +import numpy as np import torch import torch.nn as nn -import torch.nn.functional as F -import numpy as np -import lightning as pl -from improved_diffusion.unet import UNetModel, ClassicResBlock, Downsample, Upsample, conv_nd # type: ignore[import-untyped] +import torch.optim as optim +from sklearn.model_selection import train_test_split +from torch.utils.data import DataLoader -from minecraft_copilot_ml.data_loader import MinecraftSchematicsDatasetItemType +from minecraft_copilot_ml.data_loader import ( + MinecraftSchematicsDataset, + MinecraftSchematicsDatasetItemType, + get_working_files_and_unique_blocks, + list_schematic_files_in_folder, +) +# Define the Encoder class Encoder(nn.Module): - def __init__(self, channels: int, hidden_channels: int): + def __init__(self, input_dim: int, latent_dim: int): super(Encoder, self).__init__() - self.input_conv = conv_nd( - dims=3, in_channels=channels, out_channels=hidden_channels, kernel_size=3, stride=1, padding=1 - ) - self.res_1 = ClassicResBlock(channels=hidden_channels, dropout=0.1, out_channels=hidden_channels * 2, dims=3) - self.down_1 = Downsample(channels=hidden_channels * 2, use_conv=True, dims=3) - self.res_2 = ClassicResBlock( - channels=hidden_channels * 2, dropout=0.1, out_channels=hidden_channels * 3, dims=3 - ) - self.down_2 = Downsample(hidden_channels * 3, True, dims=3) - self.res_3 = ClassicResBlock( - channels=hidden_channels * 3, dropout=0.1, out_channels=hidden_channels * 4, dims=3 - ) - self.down_3 = Downsample(hidden_channels * 4, True, dims=3) - self.mu_logvar = conv_nd( - dims=3, - in_channels=hidden_channels * 4, - out_channels=hidden_channels * 4 * 2, - kernel_size=3, - stride=1, - padding=1, + self.encoder = nn.Sequential( + nn.Conv3d(input_dim, 32, kernel_size=3, stride=2, padding=1), + nn.ReLU(), + nn.Conv3d(32, 64, kernel_size=3, stride=2, padding=1), + nn.ReLU(), + nn.Flatten(), ) + self.fc_mu = nn.Linear(64 * 4 * 4 * 4, latent_dim) # Adjust dimensions + self.fc_logvar = nn.Linear(64 * 4 * 4 * 4, latent_dim) def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: - x = self.input_conv(x) - x = self.res_1(x) - x = self.down_1(x) - x = self.res_2(x) - x = self.down_2(x) - x = self.res_3(x) - x = self.down_3(x) - mu_logvar = self.mu_logvar(x) - mu, logvar = torch.chunk(mu_logvar, 2, dim=1) + x = self.encoder(x) + mu = self.fc_mu(x) + logvar = self.fc_logvar(x) return mu, logvar +# Define the Decoder class Decoder(nn.Module): - def __init__(self, channels: int, hidden_channels: int, activation: nn.Module = nn.Sigmoid()): + def __init__(self, latent_dim: int, output_dim: int): super(Decoder, self).__init__() - self.res_1 = ClassicResBlock( - channels=hidden_channels * 4, dropout=0.1, out_channels=hidden_channels * 3, dims=3 - ) - self.up_1 = Upsample(hidden_channels * 3, True, dims=3) - self.res_2 = ClassicResBlock( - channels=hidden_channels * 3, dropout=0.1, out_channels=hidden_channels * 2, dims=3 + self.fc = nn.Linear(latent_dim, 64 * 4 * 4 * 4) # Adjust dimensions + self.decoder = nn.Sequential( + nn.ConvTranspose3d(64, 32, kernel_size=3, stride=2, padding=1, output_padding=1), + nn.ReLU(), + nn.ConvTranspose3d(32, output_dim, kernel_size=3, stride=2, padding=1, output_padding=1), + nn.Sigmoid(), ) - self.up_2 = Upsample(hidden_channels * 2, True, dims=3) - self.res_3 = ClassicResBlock(channels=hidden_channels * 2, dropout=0.1, out_channels=hidden_channels, dims=3) - self.up_3 = Upsample(hidden_channels, True, dims=3) - self.output_conv = conv_nd( - dims=3, - in_channels=hidden_channels, - out_channels=channels, - kernel_size=3, - stride=1, - padding=1, - ) - self.activation = activation - def forward(self, x: torch.Tensor) -> torch.Tensor: - x = self.res_1(x) - x = self.up_1(x) - x = self.res_2(x) - x = self.up_2(x) - x = self.res_3(x) - x = self.up_3(x) - x = self.output_conv(x) - x = self.activation(x) + def forward(self, z: torch.Tensor) -> torch.Tensor: + z = self.fc(z).view(-1, 64, 4, 4, 4) # Adjust dimensions + x = self.decoder(z) return x +# Define the VAE class VAE(nn.Module): - def __init__(self, channels: int, hidden_channels: int): + def __init__(self, input_dim: int, latent_dim: int): super(VAE, self).__init__() - self.encoder = Encoder(channels, hidden_channels) - self.decoder = Decoder(channels, hidden_channels) + self.encoder = Encoder(input_dim, latent_dim) + self.decoder = Decoder(latent_dim, input_dim) def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: mu, logvar = self.encoder(x) - z = self.reparameterize(mu, logvar) - reconstructed = self.decoder(z) - return reconstructed, mu, logvar - - def reparameterize(self, mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor: std = torch.exp(0.5 * logvar) - eps = torch.randn_like(std) - return mu + eps * std - - -class VAETrainer(pl.LightningModule): - def __init__(self, model: VAE, unique_blocks_dict: Dict[str, int]): # type: ignore[no-any-unimported] - super(VAETrainer, self).__init__() - self.model = model - self.recon_loss = nn.CrossEntropyLoss(reduction="none") - self.unique_blocks_dict = unique_blocks_dict - - def forward(self, x: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]: - return self.model(x) # type: ignore[no-any-return] - - def pre_process(self, x: np.ndarray) -> torch.Tensor: - vectorized_x = np.vectorize(lambda x: self.unique_blocks_dict.get(x, self.unique_blocks_dict["minecraft:air"]))( - x - ) - vectorized_x = vectorized_x.astype(np.int64) - x_tensor = torch.from_numpy(vectorized_x) - x_tensor = x_tensor.to("cuda" if torch.cuda.is_available() else "cpu") - x_tensor = F.one_hot(x_tensor, num_classes=len(self.unique_blocks_dict)).permute(0, 4, 1, 2, 3).float() - return x_tensor - - def configure_optimizers(self) -> torch.optim.Optimizer: - return torch.optim.AdamW(self.parameters(), lr=1e-4) - - def training_step(self, batch: MinecraftSchematicsDatasetItemType, batch_idx: int) -> torch.Tensor: - return self.step(batch, batch_idx, "train") - - def validation_step(self, batch: MinecraftSchematicsDatasetItemType, batch_idx: int) -> torch.Tensor: - return self.step(batch, batch_idx, "val") - - def step(self, batch: MinecraftSchematicsDatasetItemType, batch_idx: int, mode: str) -> torch.Tensor: - block_maps, block_map_masks = batch - pre_processed_block_maps = self.pre_process(block_maps) - recon_x, mu, logvar = self(pre_processed_block_maps) - BCE = self.recon_loss(recon_x, pre_processed_block_maps) - tensor_block_map_masks = torch.from_numpy(block_map_masks).to(self.device) - BCE = BCE * tensor_block_map_masks - BCE = BCE.mean() - KLD = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp(), dim=1) - KLD = KLD.mean() - loss = BCE + KLD - accuracy = (recon_x.argmax(dim=1) == pre_processed_block_maps.argmax(dim=1)).float() - accuracy = accuracy * tensor_block_map_masks - accuracy = accuracy.mean() - loss_dict = { - "loss": loss, - "loss_bce": BCE, - "loss_kld": KLD, - "accuracy": accuracy, - "learning_rate": self.trainer.optimizers[0].param_groups[0]["lr"], - } - for name, value in loss_dict.items(): - self.log( - f"{mode}_{name}", - value, - on_step=True, - on_epoch=True, - prog_bar=True, - logger=True, - batch_size=block_maps.shape[0], - ) - return loss # type: ignore[no-any-return] - - def on_train_start(self) -> None: - print(self) - -import argparse -import json -import os -import subprocess -from typing import List, Optional, Set, Tuple - -import boto3 -import lightning as pl -import numpy as np -import torch -from improved_diffusion.unet import UNetModel # type: ignore[import-untyped] -from lightning.pytorch.callbacks import ModelCheckpoint, TQDMProgressBar -from lightning.pytorch.loggers import CSVLogger -from loguru import logger -from torch.utils.data import DataLoader - -from minecraft_copilot_ml.data_loader import ( - MinecraftSchematicsDataset, - MinecraftSchematicsDatasetItemType, - get_working_files_and_unique_blocks, - list_schematic_files_in_folder, -) -from minecraft_copilot_ml.model import MinecraftCopilotTrainer - -if torch.cuda.is_available(): - device_name = torch.cuda.get_device_name() - if device_name is not None and device_name == "GeForce RTX 3090": - torch.set_float32_matmul_precision("medium") -else: - logger.warning("No CUDA device found.") - - -def main(argparser: argparse.ArgumentParser) -> None: - path_to_schematics: str = argparser.parse_args().path_to_schematics - path_to_output: str = argparser.parse_args().path_to_output - epochs: int = argparser.parse_args().epochs - batch_size: int = argparser.parse_args().batch_size - dataset_limit: Optional[int] = argparser.parse_args().dataset_limit - dataset_start: Optional[int] = argparser.parse_args().dataset_start - - if not os.path.exists(path_to_output): - os.makedirs(path_to_output) - - sync_command = f"aws s3 sync s3://minecraft-schematics-raw {path_to_schematics} --acl public-read --no-sign-request" - subprocess.run(sync_command, shell=True, check=True) - + z = mu + std * torch.randn_like(std) # Reparameterization trick + recon_x = self.decoder(z) + return recon_x, mu, logvar + + +# Define the masked loss function +def masked_vae_loss( + recon_x: torch.Tensor, + x: torch.Tensor, + mu: torch.Tensor, + logvar: torch.Tensor, + mask: torch.Tensor, +) -> torch.Tensor: + # Ensure mask matches recon_x dimensions + mask = mask.float() + recon_loss = ((recon_x - x) ** 2) * mask + recon_loss = recon_loss.sum() / mask.sum() # Normalize by the number of valid voxels + kl_loss = -0.5 * torch.sum(1 + logvar - mu.pow(2) - logvar.exp()) + return recon_loss + kl_loss + + +# Training loop +def train_vae( + model: VAE, + dataloader: DataLoader, + optimizer: optim.Adam, + epochs: int = 10, + device: str = "cuda", +) -> None: + model.to(device) + model.train() + for epoch in range(epochs): + total_loss: float = 0 + for batch in dataloader: + x, mask = batch # Ensure your DataLoader returns (data, mask) + x, mask = torch.from_numpy(x), torch.from_numpy(mask) + x, mask = x.to(device), mask.to(device) + x, mask = x.float(), mask.float() + optimizer.zero_grad() + recon_x, mu, logvar = model(x) + loss = masked_vae_loss(recon_x, x, mu, logvar, mask) + loss.backward() + optimizer.step() + total_loss += loss.item() + print(f"Epoch {epoch + 1}, Loss: {total_loss / len(dataloader)}") + + +# Main +if __name__ == "__main__": + latent_dim = 16 + batch_size = 32 + epochs = 10 + learning_rate = 1e-3 + + # Replace with your dataset + limit = 2 + path_to_schematics = "/home/mehdi/minecraft-copilot-ml/schematics_data_2_3_10" schematics_list_files = list_schematic_files_in_folder(path_to_schematics) schematics_list_files = sorted(schematics_list_files) - start = 0 - end = len(schematics_list_files) - if dataset_start is not None: - start = dataset_start - if dataset_limit is not None: - end = dataset_limit - schematics_list_files = schematics_list_files[start:end] - # Set the dictionary size to the number of unique blocks in the dataset. - # And also select the right files to load. unique_blocks_dict, loaded_schematic_files = get_working_files_and_unique_blocks(schematics_list_files) + input_dim = len(unique_blocks_dict) - logger.info(f"Unique blocks: {unique_blocks_dict}") - logger.info(f"Number of unique blocks: {len(unique_blocks_dict)}") - logger.info(f"Number of loaded schematics files: {len(loaded_schematic_files)}") - - train_loaded_schematic_files, test_loaded_schematic_files = train_test_split( - loaded_schematic_files, test_size=0.2, random_state=42 - ) - - train_schematics_dataset = MinecraftSchematicsDataset(train_loaded_schematic_files) - test_schematics_dataset = MinecraftSchematicsDataset(test_loaded_schematic_files) + dataset = MinecraftSchematicsDataset(schematics_list_files[:limit], unique_blocks_dict) + train_dataset, test_dataset = train_test_split(dataset, test_size=0.2, random_state=42) - def collate_fn(batch: List[MinecraftSchematicsDatasetItemType]) -> MinecraftSchematicsDatasetItemType: + def collate_fn( + batch: List[MinecraftSchematicsDatasetItemType], + ) -> MinecraftSchematicsDatasetItemType: block_map, block_map_mask = zip(*batch) return np.stack(block_map), np.stack(block_map_mask) - num_workers = os.cpu_count() - if num_workers is None: - num_workers = 0 - - train_schematics_dataloader = DataLoader( - train_schematics_dataset, + train_dataloader = DataLoader( + train_dataset, batch_size=batch_size, shuffle=True, + num_workers=0, collate_fn=collate_fn, - num_workers=num_workers, ) - test_schematics_dataloader = DataLoader( - test_schematics_dataset, + test_dataloader = DataLoader( + test_dataset, batch_size=batch_size, shuffle=False, + num_workers=0, collate_fn=collate_fn, - num_workers=num_workers, ) - unet_model = VAE( - channels=len(unique_blocks_dict), - hidden_channels=32, - ) - model = VAETrainer(unet_model, unique_blocks_dict) - csv_logger = CSVLogger(save_dir=path_to_output) - model_checkpoint = ModelCheckpoint(path_to_output, save_last=True, mode="min") - trainer = pl.Trainer( - logger=csv_logger, - callbacks=[model_checkpoint], - max_epochs=epochs, - log_every_n_steps=1, - accelerator="gpu" if torch.cuda.is_available() else "auto", - ) - trainer.fit(model, train_schematics_dataloader, test_schematics_dataloader) + vae = VAE(input_dim=input_dim, latent_dim=latent_dim) + optimizer = optim.Adam(vae.parameters(), lr=learning_rate) - # Save the best and last model locally - last_model = VAETrainer.load_from_checkpoint( - model_checkpoint.last_model_path, - unet_model=unet_model, - unique_blocks_dict=unique_blocks_dict, - save_dir=path_to_output, + train_vae( + vae, + train_dataloader, + optimizer, + epochs=epochs, + device="cuda" if torch.cuda.is_available() else "cpu", ) - torch.save(last_model, os.path.join(path_to_output, "last_model.pth")) - with open(os.path.join(path_to_output, "unique_blocks_dict.json"), "w") as f: - json.dump(unique_blocks_dict, f) - - -if __name__ == "__main__": - argparser = argparse.ArgumentParser() - argparser.add_argument("--path-to-schematics", type=str, required=True) - argparser.add_argument("--path-to-output", type=str, required=True) - argparser.add_argument("--epochs", type=int, required=True) - argparser.add_argument("--batch-size", type=int, required=True) - argparser.add_argument("--dataset-limit", type=int) - argparser.add_argument("--dataset-start", type=int) - - main(argparser)