-
Notifications
You must be signed in to change notification settings - Fork 3
Expand file tree
/
Copy pathutils.py
More file actions
72 lines (52 loc) · 2.69 KB
/
utils.py
File metadata and controls
72 lines (52 loc) · 2.69 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import torch
from transformers import PreTrainedTokenizer
def find_substring_token_indices(prompt: str, substr: str, tokenizer: "PreTrainedTokenizer", model="pixart"):
    """Locate the token positions of ``substr`` inside the tokenized ``prompt``.

    Args:
        prompt: Full prompt string.
        substr: Substring whose token span should be located.
        tokenizer: The target model's tokenizer; must support ``__call__``
            (returning an object with ``input_ids``) and ``decode``.
        model: One of ``"pixart"``, ``"sana"``, ``"flux"``; selects how the
            tokenizer's special tokens are stripped from the substring ids.

    Returns:
        A list of indices into ``tokenizer(prompt).input_ids`` covering the
        first contiguous occurrence of the substring's tokens.

    Raises:
        AssertionError: If ``model`` is unsupported, or the substring's tokens
            do not occur contiguously in the prompt's tokens.
    """
    assert model in ["pixart", "sana", "flux"], f"Model {model} not supported"
    prompt_tokens = tokenizer(prompt).input_ids
    if model == "pixart":
        # Drop the trailing special token (presumably EOS — TODO confirm).
        substr_tokens = tokenizer(substr).input_ids[:-1]
    elif model == "sana":
        # TODO: Clean code [for sana, " X" and "X" result in different tokens]
        if "in the style of" in prompt:
            substr = " " + substr
        # Drop the leading special token (presumably BOS — TODO confirm).
        substr_tokens = tokenizer(substr).input_ids[1:]
    else:
        # model == "flux" is guaranteed here by the assert above, so no
        # further validation branch is needed. Drop both leading and
        # trailing special tokens.
        substr_tokens = tokenizer(substr).input_ids[1:-1]
    # Scan for the first contiguous occurrence of substr_tokens.
    start_idx = -1
    for i in range(len(prompt_tokens) - len(substr_tokens) + 1):
        if prompt_tokens[i:i + len(substr_tokens)] == substr_tokens:
            start_idx = i
            break
    assert start_idx != -1, "substr_tokens not found in tokens"
    token_indices = list(range(start_idx, start_idx + len(substr_tokens)))
    # Round-trip check: a mismatch is not fatal (tokenizers may normalize
    # whitespace or merge pieces), but warn so the caller can inspect it.
    decoded = tokenizer.decode([prompt_tokens[token_idx] for token_idx in token_indices])
    if decoded != substr:
        print("============================ Warning ============================")
        print("[Warning] tokenizer.decode([prompt_tokens[token_idx] for token_idx in token_indices]) != substr")
        print(f"[Warning] Decoded: {decoded}")
        print(f"[Warning] Expected: {substr}")
        print("=================================================================")
    return token_indices
def latents_to_images(pipe, latents):
    """Decode a batch of latents into PIL images via the pipeline's VAE.

    The latents are cast to the VAE's dtype and un-scaled by the VAE's
    configured scaling factor before decoding; gradients are disabled for
    the decode pass.
    """
    cast_latents = latents.to(pipe.vae.dtype)
    scale = pipe.vae.config.scaling_factor
    with torch.no_grad():
        decoded = pipe.vae.decode(cast_latents / scale, return_dict=False)[0]
    return pipe.image_processor.postprocess(decoded, output_type="pil")
def get_worker_list_chunk(arr, num_workers, worker_idx, print_log=True):
    """Return this worker's slice of ``arr`` under an even ceiling-division split.

    Args:
        arr: The full list of work items (e.g. prompts).
        num_workers: Total number of workers sharing ``arr``.
        worker_idx: This worker's 0-based index.
        print_log: When True, print the chosen index range and the first/last
            items of the chunk.

    Returns:
        ``arr[start:end]`` for this worker; may be empty when
        ``num_workers > len(arr)``.
    """
    arr_len = len(arr)
    # Ceiling division so every item is assigned to exactly one worker.
    chunk_size = (arr_len + num_workers - 1) // num_workers
    start_index = chunk_size * worker_idx
    end_index = min((worker_idx + 1) * chunk_size, arr_len)
    chunk = arr[start_index:end_index]
    if print_log:
        print(f"Choosing chunk ({start_index}:{end_index})")
        # Guard against empty chunks: indexing arr[start_index] when
        # start_index >= arr_len would raise IndexError.
        if chunk:
            print(f"First prompt of the chunk: \"{chunk[0]}\"")
            print(f"Last prompt of the chunk: \"{chunk[-1]}\"")
        else:
            print("Chunk is empty")
    return chunk
def print_arguments(args):
    """Print every attribute of a parsed-arguments namespace, one per line,
    framed by banner lines."""
    print("===================== Arguments =====================")
    for name in vars(args):
        print(f"{name}: {getattr(args, name)}")
    print("=====================================================")