Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -164,4 +164,5 @@ cython_debug/
*.pth
data/
model_weights/
*lock
*lock
.DS_Store
52 changes: 34 additions & 18 deletions MAESTER/infer_engine.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,8 @@
import argparse
import os
import time
import torch
import sys
from model import *
from torch.nn import functional as F
import time
import numpy as np


def define_embed_idx(side_patch_num, central_patch_num):
Expand All @@ -14,8 +12,8 @@ def define_embed_idx(side_patch_num, central_patch_num):
mid_point = int(side_patch_num // 2)
mid_centeral_num = int(central_patch_num // 2)
selected_idx = reshape_idx[
mid_point - mid_centeral_num : mid_point + mid_centeral_num,
mid_point - mid_centeral_num : mid_point + mid_centeral_num,
mid_point - mid_centeral_num: mid_point + mid_centeral_num,
mid_point - mid_centeral_num: mid_point + mid_centeral_num,
]

return selected_idx.flatten()
Expand All @@ -33,12 +31,13 @@ def define_idx(central_patch_num, pix_size):
return index_h, index_w


def place_res(res_stack, target_tensor, anchor_h, anchor_w, index_d, index_h, index_w, embed_idx):
def place_res(
res_stack, target_tensor, anchor_h, anchor_w, index_d, index_h, index_w,
embed_idx):
# exclude cls token and only select the central patch
res_stack = res_stack[:, 1:][:, embed_idx.long()]
col_num = res_stack.shape[0]


h = index_h + anchor_h
h = h.repeat(col_num, 1)
w = index_w + anchor_w
Expand All @@ -52,7 +51,7 @@ def place_res(res_stack, target_tensor, anchor_h, anchor_w, index_d, index_h, in

def split(a, n):
k, m = divmod(len(a), n)
return (a[i * k + min(i, m) : (i + 1) * k + min(i + 1, m)] for i in range(n))
return (a[i * k + min(i, m): (i + 1) * k + min(i + 1, m)] for i in range(n))


def run_inference(
Expand All @@ -77,7 +76,7 @@ def run_inference(
cfg["DATASET"]["patch_size"] + 1
if cfg["DATASET"]["patch_size"] % 2 == 0
else cfg["DATASET"]["patch_size"]
)
)
patch_size = cfg["DATASET"]["patch_size"]
OFFSET = patch_size // 2
PAD = int(
Expand All @@ -90,8 +89,10 @@ def run_inference(

scr = F.pad(scr, [PAD, PAD, PAD, PAD, 0, 0], "constant", 0).to(device)

t1 = time.time()
print("preparing tensor storage...")
feature_storage = torch.FloatTensor(
torch.FloatStorage.from_file(
torch.from_file(
embedding_storage_path,
shared=True,
size=orgD * (orgH+PAD) * (orgW+PAD) * cfg["MODEL"]["embed"],
Expand All @@ -100,31 +101,41 @@ def run_inference(

_, adjH, adjW = scr.shape


with torch.no_grad():
index_h, index_w = define_idx(central_patch, pix_size)
iter_list = list(split(range(0, orgD), ngpus_per_node))[rank]
embed_idx = define_embed_idx(target_size // patch_size, patch_size)

# print(f"rank {rank} ====> getting partition: [{iter_list[0]}, {iter_list[-1]}]")

iter_list = torch.tensor(iter_list)
d_list = iter_list.split(int(cfg["DATASET"]["batch_size"] * 2))

for i_d in d_list:
_win = vol_size // patch_size
h_patch = np.floor(adjH - _win + 1) - _win
w_patch = np.floor(adjW - _win + 1) - _win
total_patch = int((h_patch * w_patch + 2 * vol_size + 2 * _win) / _win)

for b, i_d in enumerate(d_list):
print(f"\ninferring batch #{b + 1} of {len(d_list)}")
h_count = 0
i_h = 0
patch_count = 1

while i_h < adjH - vol_size + 1:
i_w = 0
w_count = 0

while i_w < adjW - vol_size + 1:
print(f"\t patch #{patch_count} of ~ {total_patch}", end="\r")
sample = scr[
i_d, i_h : i_h + vol_size, i_w : i_w + vol_size
i_d, i_h: i_h + vol_size, i_w: i_w + vol_size
].to(device)
sample = sample.unsqueeze(1)
rep = model.infer_latent(sample)[:, 1:, :]
place_res(rep, feature_storage, i_h, i_w, i_d, index_h, index_w, embed_idx)
place_res(
rep, feature_storage, i_h, i_w, i_d, index_h, index_w, embed_idx
)
w_count += 1
patch_count += 1

del sample
del rep
Expand All @@ -148,4 +159,9 @@ def run_inference(
else:
i_h += 1

print(f"Rank {rank}: Inference finished!")
total_patch = patch_count # ;D

print(
f"\n\nRank {rank}: Inference finished in "
f"{(time.time() - t1) / 60: .2f} minutes."
)
64 changes: 33 additions & 31 deletions examples/inference_demo.ipynb

Large diffs are not rendered by default.