Skip to content
121 changes: 62 additions & 59 deletions minecraft_copilot_ml/metrics_graph.ipynb

Large diffs are not rendered by default.

50 changes: 23 additions & 27 deletions minecraft_copilot_ml/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,20 @@
import torch.nn.functional as F


class ResBlock(nn.Module):
def __init__(self, channels_count: int):
super(ResBlock, self).__init__()
self.channels_count = channels_count
self.block = nn.Sequential(
nn.Conv3d(channels_count, channels_count, kernel_size=3, padding=1),
nn.BatchNorm3d(channels_count),
nn.SiLU(),
)

def forward(self, x: torch.Tensor) -> torch.Tensor:
return x + self.block(x)


class VAE(pl.LightningModule):
def __init__(self, unique_blocks_dict, unique_counts_coefficients=None, latent_dim=64):
super(VAE, self).__init__()
Expand All @@ -25,18 +39,10 @@ def __init__(self, unique_blocks_dict, unique_counts_coefficients=None, latent_d
nn.Conv3d(1, len(unique_blocks_dict), kernel_size=3, padding=1),
nn.BatchNorm3d(len(unique_blocks_dict)),
nn.LeakyReLU(),
nn.Conv3d(len(unique_blocks_dict), len(unique_blocks_dict), kernel_size=3, padding=1),
nn.BatchNorm3d(len(unique_blocks_dict)),
nn.LeakyReLU(),
nn.Conv3d(len(unique_blocks_dict), len(unique_blocks_dict), kernel_size=3, padding=1),
nn.BatchNorm3d(len(unique_blocks_dict)),
nn.LeakyReLU(),
nn.Conv3d(len(unique_blocks_dict), len(unique_blocks_dict), kernel_size=3, padding=1),
nn.BatchNorm3d(len(unique_blocks_dict)),
nn.LeakyReLU(),
nn.Conv3d(len(unique_blocks_dict), len(unique_blocks_dict), kernel_size=3, padding=1),
nn.BatchNorm3d(len(unique_blocks_dict)),
nn.LeakyReLU(),
ResBlock(len(unique_blocks_dict)),
ResBlock(len(unique_blocks_dict)),
ResBlock(len(unique_blocks_dict)),
ResBlock(len(unique_blocks_dict)),
nn.Conv3d(len(unique_blocks_dict), len(unique_blocks_dict), kernel_size=3, padding=1),
# lets go
nn.Flatten(),
Expand All @@ -53,21 +59,11 @@ def __init__(self, unique_blocks_dict, unique_counts_coefficients=None, latent_d
nn.LeakyReLU(),
nn.Unflatten(1, (len(unique_blocks_dict), 16, 16, 16)),
# lets go
nn.Conv3d(len(unique_blocks_dict), len(unique_blocks_dict), kernel_size=3, padding=1),
nn.BatchNorm3d(len(unique_blocks_dict)),
nn.LeakyReLU(),
nn.Conv3d(len(unique_blocks_dict), len(unique_blocks_dict), kernel_size=3, padding=1),
nn.BatchNorm3d(len(unique_blocks_dict)),
nn.LeakyReLU(),
nn.Conv3d(len(unique_blocks_dict), len(unique_blocks_dict), kernel_size=3, padding=1),
nn.BatchNorm3d(len(unique_blocks_dict)),
nn.LeakyReLU(),
nn.Conv3d(len(unique_blocks_dict), len(unique_blocks_dict), kernel_size=3, padding=1),
nn.BatchNorm3d(len(unique_blocks_dict)),
nn.LeakyReLU(),
nn.Conv3d(len(unique_blocks_dict), len(unique_blocks_dict), kernel_size=3, padding=1),
nn.BatchNorm3d(len(unique_blocks_dict)),
nn.LeakyReLU(),
ResBlock(len(unique_blocks_dict)),
ResBlock(len(unique_blocks_dict)),
ResBlock(len(unique_blocks_dict)),
ResBlock(len(unique_blocks_dict)),
ResBlock(len(unique_blocks_dict)),
nn.Conv3d(len(unique_blocks_dict), len(unique_blocks_dict), kernel_size=3, padding=1),
)

Expand Down