forked from leeruibin/RORem
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathinference_RORem.py
More file actions
134 lines (123 loc) · 4.21 KB
/
inference_RORem.py
File metadata and controls
134 lines (123 loc) · 4.21 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
from diffusers import AutoPipelineForInpainting
import torch
import os
from diffusers import UNet2DConditionModel
import argparse
from myutils.img_util import dilate_mask
from diffusers.utils import load_image
def _str2bool(value):
    """Parse a command-line boolean such as ``true``/``false`` (case-insensitive).

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean, so a
            typo such as ``--use_CFG Ture`` is reported instead of silently
            meaning False.
    """
    normalized = value.strip().lower()
    if normalized in ("true", "t", "yes", "y", "1", "on"):
        return True
    if normalized in ("false", "f", "no", "n", "0", "off"):
        return False
    raise argparse.ArgumentTypeError(
        f"expected a boolean value (true/false), got {value!r}"
    )


def parse_args():
    """Parse the command-line options for single-image RORem object removal.

    Returns:
        argparse.Namespace with the fields ``pretrained_model``, ``RORem_unet``,
        ``image_path``, ``mask_path``, ``save_path``, ``inference_steps``,
        ``resolution``, ``dilate_size`` and ``use_CFG``.
    """
    parser = argparse.ArgumentParser(
        description="Remove the masked object from an image with RORem."
    )
    parser.add_argument(
        "--pretrained_model",
        type=str,
        default="diffusers/stable-diffusion-xl-1.0-inpainting-0.1",
        help="Path or huggingface.co/models identifier of the SDXL inpainting base model.",
    )
    parser.add_argument(
        "--RORem_unet",
        type=str,
        required=True,
        help="Path to pretrain RORem Unet",
    )
    # image_path / mask_path are mandatory: main() cannot run without them.
    parser.add_argument(
        "--image_path",
        type=str,
        required=True,
        help="Path to the input image.",
    )
    parser.add_argument(
        "--mask_path",
        type=str,
        required=True,
        help="Path to the mask image.",
    )
    parser.add_argument(
        "--save_path",
        type=str,
        default=None,
        help="Path to save the removal result.",
    )
    parser.add_argument(
        "--inference_steps",
        type=int,
        default=25,
        help="Number of denoising steps.",
    )
    parser.add_argument(
        "--resolution",
        default=512,
        type=int,
        help="Square side length the image and mask are resized to.",
    )
    parser.add_argument(
        "--dilate_size",
        default=20,
        type=int,
        help="dilate the mask",
    )
    parser.add_argument(
        "--use_CFG",
        # Strict parser: unknown strings raise instead of silently becoming False.
        type=_str2bool,
        default=True,
        help="whether to enable CFG, can reduce the artifacts in the mask region in our final test",
    )
    return parser.parse_args()
def _resolve_save_path(image_path, save_path):
    """Return the output file path and make sure its parent directory exists.

    When ``save_path`` is None the result goes to
    ``removal_result/<file name of image_path>``.
    """
    if save_path is None:
        save_path = os.path.join("removal_result", os.path.basename(image_path))
    save_folder = os.path.dirname(save_path)
    # os.makedirs("") raises, so only create a folder when the path has one.
    if save_folder:
        os.makedirs(save_folder, exist_ok=True)
    return save_path


def _load_pipeline(pretrained_model, unet_path):
    """Load the fp16 SDXL inpainting pipeline on CUDA with the RORem UNet swapped in."""
    pipe = AutoPipelineForInpainting.from_pretrained(
        pretrained_model,
        torch_dtype=torch.float16,
    )
    unet = UNet2DConditionModel.from_pretrained(unet_path).to("cuda", dtype=torch.float16)
    print(f"Finish loading unet from {unet_path}!!")
    pipe.unet = unet
    pipe.to("cuda")
    return pipe


def main(args):
    """Remove the masked object from ``args.image_path`` and save the result.

    Args:
        args: namespace as produced by ``parse_args``; reads ``pretrained_model``,
            ``RORem_unet``, ``image_path``, ``mask_path``, ``save_path``,
            ``resolution``, ``dilate_size``, ``inference_steps`` and ``use_CFG``.

    Side effects:
        Sets ``args.save_path`` to the resolved output path, creates its parent
        directory if needed, and writes the result image there.

    Raises:
        ValueError: if ``image_path`` or ``mask_path`` is missing (checked before
            the expensive model load).
    """
    if args.image_path is None or args.mask_path is None:
        raise ValueError("--image_path and --mask_path are required.")

    pretrain_path = (
        "diffusers/stable-diffusion-xl-1.0-inpainting-0.1"
        if args.pretrained_model is None
        else args.pretrained_model
    )
    pipe_edit = _load_pipeline(pretrain_path, args.RORem_unet)

    height = width = args.resolution
    args.save_path = _resolve_save_path(args.image_path, args.save_path)

    size = (args.resolution, args.resolution)
    input_image = load_image(args.image_path).resize(size)
    input_mask = load_image(args.mask_path).resize(size)
    if args.dilate_size != 0:
        # dilate_mask is project-local; assumed to take and return a PIL mask
        # image — confirm against myutils.img_util. Its result must replace
        # input_mask so the dilated mask is what the pipeline actually sees.
        input_mask = dilate_mask(input_mask, args.dilate_size)

    pipe_kwargs = dict(
        height=height,
        width=width,
        image=input_image,
        mask_image=input_mask,
        # NOTE(review): diffusers only applies classifier-free guidance when
        # guidance_scale > 1, so at 1.0 the negative prompt set below has no
        # effect and `use_CFG` only changes the positive prompt. Kept at 1.0
        # to preserve existing outputs — confirm the intended scale.
        guidance_scale=1.0,
        num_inference_steps=args.inference_steps,  # steps between 15 and 30 also work well
        strength=0.99,  # make sure to use `strength` below 1.0
    )
    if args.use_CFG:
        # we also find by adding these prompt, the model can work even better
        pipe_kwargs["prompt"] = "4K, high quality, masterpiece, Highly detailed, Sharp focus, Professional, photorealistic, realistic"
        pipe_kwargs["negative_prompt"] = "low quality, worst, bad proportions, blurry, extra finger, Deformed, disfigured, unclear background"
    else:
        pipe_kwargs["prompt"] = ""

    removal_result = pipe_edit(**pipe_kwargs).images[0]
    removal_result.save(args.save_path)
if __name__ == "__main__":
    # Script entry point: read the CLI flags and run one removal pass.
    main(parse_args())