diff --git a/models_mae.py b/models_mae.py
index 880e28f822..d973b1b4a9 100644
--- a/models_mae.py
+++ b/models_mae.py
@@ -27,6 +27,7 @@ def __init__(self, img_size=224, patch_size=16, in_chans=3,
                  decoder_embed_dim=512, decoder_depth=8, decoder_num_heads=16,
                  mlp_ratio=4., norm_layer=nn.LayerNorm, norm_pix_loss=False):
         super().__init__()
+        self.in_chans = in_chans
 
         # --------------------------------------------------------------------------
         # MAE encoder specifics
@@ -94,30 +95,30 @@ def _init_weights(self, m):
 
     def patchify(self, imgs):
         """
-        imgs: (N, 3, H, W)
-        x: (N, L, patch_size**2 *3)
+        imgs: (N, in_chans, H, W)
+        x: (N, L, patch_size**2 * in_chans)
         """
         p = self.patch_embed.patch_size[0]
         assert imgs.shape[2] == imgs.shape[3] and imgs.shape[2] % p == 0
 
         h = w = imgs.shape[2] // p
-        x = imgs.reshape(shape=(imgs.shape[0], 3, h, p, w, p))
+        x = imgs.reshape(shape=(imgs.shape[0], self.in_chans, h, p, w, p))
         x = torch.einsum('nchpwq->nhwpqc', x)
-        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * 3))
+        x = x.reshape(shape=(imgs.shape[0], h * w, p**2 * self.in_chans))
         return x
 
     def unpatchify(self, x):
         """
-        x: (N, L, patch_size**2 *3)
-        imgs: (N, 3, H, W)
+        x: (N, L, patch_size**2 * in_chans)
+        imgs: (N, in_chans, H, W)
         """
         p = self.patch_embed.patch_size[0]
         h = w = int(x.shape[1]**.5)
         assert h * w == x.shape[1]
 
-        x = x.reshape(shape=(x.shape[0], h, w, p, p, 3))
+        x = x.reshape(shape=(x.shape[0], h, w, p, p, self.in_chans))
         x = torch.einsum('nhwpqc->nchpwq', x)
-        imgs = x.reshape(shape=(x.shape[0], 3, h * p, h * p))
+        imgs = x.reshape(shape=(x.shape[0], self.in_chans, h * p, h * p))
         return imgs
 
     def random_masking(self, x, mask_ratio):
@@ -197,8 +198,8 @@ def forward_decoder(self, x, ids_restore):
 
     def forward_loss(self, imgs, pred, mask):
         """
-        imgs: [N, 3, H, W]
-        pred: [N, L, p*p*3]
+        imgs: [N, in_chans, H, W]
+        pred: [N, L, p*p*in_chans]
         mask: [N, L], 0 is keep, 1 is remove, 
         """
         target = self.patchify(imgs)
@@ -215,7 +216,7 @@ def forward_loss(self, imgs, pred, mask):
 
     def forward(self, imgs, mask_ratio=0.75):
         latent, mask, ids_restore = self.forward_encoder(imgs, mask_ratio)
-        pred = self.forward_decoder(latent, ids_restore)  # [N, L, p*p*3]
+        pred = self.forward_decoder(latent, ids_restore)  # [N, L, p*p*in_chans]
         loss = self.forward_loss(imgs, pred, mask)
         return loss, pred, mask
 