@@ -26,16 +26,20 @@ def __init__(
         self.pil_image = None
         self.has_latent = False
         w, h = image_shape
-        try:
-            comp_adjusted = TF.resize(comp.clone(), (h, w))
-        except:
-            # comp_adjusted = comp.clone()
-            # Need to convert the latent to its image form
-            comp_adjusted = img_model.decode_tensor(comp.clone())
+        comp_adjusted = TF.resize(comp.clone(), (h, w))
+        # try:
+        #     comp_adjusted = TF.resize(comp.clone(), (h, w))
+        # except:
+        #     # comp_adjusted = comp.clone()
+        #     # Need to convert the latent to its image form
+        #     comp_adjusted = img_model.decode_tensor(comp.clone())
         self.direct_loss = MSELoss(comp_adjusted, weight, stop, name, image_shape)

     @torch.no_grad()
     def set_comp(self, pil_image, device=DEVICE):
+        """
+        sets the DIRECT loss anchor "comp" to the tensorized image.
+        """
         logger.debug(type(pil_image))
         self.pil_image = pil_image
         self.has_latent = False
@@ -47,6 +51,10 @@ def set_comp(self, pil_image, device=DEVICE):

     @classmethod
     def convert_input(cls, input, img):
+        """
+        Converts the input image tensor to the image representation of the image model.
+        E.g. if img is VQGAN, then the input tensor is converted to the latent representation.
+        """
         logger.debug(type(input))  # pretty sure this is gonna be tensor
         # return input  # this is the default MSE loss version
         return img.make_latent(input)
@@ -107,25 +115,62 @@ def get_loss(self, input, img):
         logger.debug(
             self.comp.shape
         )  # [1 1 1 1] -> from target image constructor when no input image provided
+
+        # why is the latent comp only set here? why not in the __init__ and set_comp?
         if not self.has_latent:
             # make_latent() encodes the image through a dummy class instance, returns the resulting fitted image representation
             # if get_image_tensor() is not implemented, then the returned 'latent' tensor is just the tensorized pil image
             latent = img.make_latent(self.pil_image)
             logger.debug(type(latent))  # EMAParametersDict
             logger.debug(type(self.comp))  # torch.Tensor
             with torch.no_grad():
-                self.comp.set_(latent.clone())
+                if type(latent) == type(self.comp):
+                    self.comp.set_(latent.clone())
+                # else:
+
             self.has_latent = True
+
         l1 = super().get_loss(img.get_latent_tensor(), img) / 2
         l2 = self.direct_loss.get_loss(input, img) / 10
         return l1 + l2


 ######################################################################

+# fuck it, let's just make a dip latent loss from scratch.
+
+
+# The issue we're resolving here is that by inheriting from the MSELoss,
+# I can't easily set the comp to the parameters of the image model.
+
+from pytti.LossAug.BaseLossClass import Loss
+from pytti.image_models.ema import EMAImage, EMAParametersDict
+from pytti.rotoscoper import Rotoscoper
+
+import deep_image_prior
+import deep_image_prior.models
+from deep_image_prior.models import (
+    get_hq_skip_net,
+    get_non_offset_params,
+    get_offset_params,
+)

-class LatentLossGeneric(LatentLoss):
-    # class LatentLoss(MSELoss):
+
+def load_dip(input_depth, num_scales, offset_type, offset_groups, device):
+    dip_net = get_hq_skip_net(
+        input_depth,
+        skip_n33d=192,
+        skip_n33u=192,
+        skip_n11=4,
+        num_scales=num_scales,
+        offset_type=offset_type,
+        offset_groups=offset_groups,
+    ).to(device)
+
+    return dip_net
+
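+# Illustrative sketch (not part of this commit): default_comp() below wraps the DIP
+# network returned by load_dip() in an EMAParametersDict so that the network's
+# parameters serve as the EMA-tracked "latent", roughly:
+#
+#     net = load_dip(input_depth=32, num_scales=7, offset_type="none",
+#                    offset_groups=4, device="cuda")
+#     comp = EMAParametersDict(z=net, decay=0.99, device="cuda")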
+
+class LatentLossDIP(Loss):
     @torch.no_grad()
     def __init__(
         self,
@@ -134,29 +179,109 @@ def __init__(
         stop=-math.inf,
         name="direct target loss",
         image_shape=None,
+        device=None,
     ):
-        super().__init__(comp, weight, stop, name, image_shape)
+        ##################################################################
+        super().__init__(weight, stop, name, device)
+        if image_shape is None:
+            raise
+        # height, width = comp.shape[-2:]
+        # image_shape = (width, height)
+        self.image_shape = image_shape
+        self.register_buffer("mask", torch.ones(1, 1, 1, 1, device=self.device))
+        self.use_mask = False
+        ##################################################################
         self.pil_image = None
         self.has_latent = False
-        w, h = image_shape
-        self.direct_loss = MSELoss(
-            TF.resize(comp.clone(), (h, w)), weight, stop, name, image_shape
+        logger.debug(type(comp))  # inits to image tensor
+        if comp is None:
+            comp = self.default_comp()
+        if isinstance(comp, EMAParametersDict):
+            logger.debug("initializing loss from latent")
+            self.register_module("comp", comp)
+            self.has_latent = True
+        else:
+            w, h = image_shape
+            comp_adjusted = TF.resize(comp.clone(), (h, w))
+            # try:
+            #     comp_adjusted = TF.resize(comp.clone(), (h, w))
+            # except:
+            #     # comp_adjusted = comp.clone()
+            #     # Need to convert the latent to its image form
+            #     comp_adjusted = img_model.decode_tensor(comp.clone())
+            self.direct_loss = MSELoss(comp_adjusted, weight, stop, name, image_shape)
+
+        ##################################################################
+
+        logger.debug(type(comp))
+
+    @classmethod
+    def default_comp(*args, **kargs):
+        logger.debug("default_comp")
+        device = kargs.get("device", "cuda") if torch.cuda.is_available() else "cpu"
+        net = load_dip(
+            input_depth=32,
+            num_scales=7,
+            offset_type="none",
+            offset_groups=4,
+            device=device,
         )
+        return EMAParametersDict(z=net, decay=0.99, device=device)
+
+    ###################################################################################

     @torch.no_grad()
     def set_comp(self, pil_image, device=DEVICE):
+        """
+        sets the DIRECT loss anchor "comp" to the tensorized image.
+        """
+        logger.debug(type(pil_image))
         self.pil_image = pil_image
         self.has_latent = False
-        self.direct_loss.set_comp(
-            pil_image.resize(self.image_shape, Image.LANCZOS)
+        im_resized = pil_image.resize(
+            self.image_shape, Image.LANCZOS
         )  # to do: ResizeRight
+        # self.direct_loss.set_comp(im_resized)
+
+        im_tensor = (
+            TF.to_tensor(pil_image)
+            .unsqueeze(0)
+            .to(device, memory_format=torch.channels_last)
+        )
+
+        if hasattr(self, "direct_loss"):
+            self.direct_loss.set_comp(im_tensor)
+        else:
+            self.direct_loss = MSELoss(
+                im_tensor, self.weight, self.stop, self.name, self.image_shape
+            )
+        # self.direct_loss.set_comp(im_resized)
+
+    @classmethod
+    def convert_input(cls, input, img):
+        """
+        Converts the input image tensor to the image representation of the image model.
+        E.g. if img is VQGAN, then the input tensor is converted to the latent representation.
+        """
+        logger.debug(type(input))  # pretty sure this is gonna be tensor
+        # return input  # this is the default MSE loss version
+        return img.make_latent(input)

     @classmethod
     @vram_usage_mode("Latent Image Loss")
     @torch.no_grad()
     def TargetImage(
-        cls, prompt_string, image_shape, pil_image=None, is_path=False, device=DEVICE
+        cls,
+        prompt_string,
+        image_shape,
+        pil_image=None,
+        is_path=False,
+        device=DEVICE,
+        img_model=None,
     ):
+        logger.debug(
+            type(pil_image)
+        )  # None. emitted prior to do_run:559 but after parse_scenes:122. Why even use this constructor if no pil_image?
         text, weight, stop = parse(
             prompt_string, r"(?<!^http)(?<!s):|:(?!/)", ["", "1", "-inf"]
         )
@@ -168,24 +293,69 @@ def TargetImage(
         comp = (
             MSELoss.make_comp(pil_image)
             if pil_image is not None
-            else torch.zeros(1, 1, 1, 1, device=device)
+            # else torch.zeros(1, 1, 1, 1, device=device)
+            else cls.default_comp(img_model=img_model)
         )
         out = cls(comp, weight, stop, text + " (latent)", image_shape)
         if pil_image is not None:
             out.set_comp(pil_image)
-        out.set_mask(mask)
+        if (
+            mask
+        ):  # this will break if there's no pil_image since the direct_loss won't be initialized
+            out.set_mask(mask)
         return out

     def set_mask(self, mask, inverted=False):
         self.direct_loss.set_mask(mask, inverted)
-        super().set_mask(mask, inverted)
+        # super().set_mask(mask, inverted)
+        # if device is None:
+        device = self.device
+        if isinstance(mask, str) and mask != "":
+            if mask[0] == "-":
+                mask = mask[1:]
+                inverted = True
+            if mask.strip()[-4:] == ".mp4":
+                r = Rotoscoper(mask, self)
+                r.update(0)
+                return
+            mask = Image.open(fetch(mask)).convert("L")
+        if isinstance(mask, Image.Image):
+            with vram_usage_mode("Masks"):
+                mask = (
+                    TF.to_tensor(mask)
+                    .unsqueeze(0)
+                    .to(device, memory_format=torch.channels_last)
+                )
+        if mask not in ["", None]:
+            self.mask.set_(mask if not inverted else (1 - mask))
+        self.use_mask = mask not in ["", None]

     def get_loss(self, input, img):
+        logger.debug(type(input))  # Tensor
+        logger.debug(input.shape)  # this is an image tensor
+        logger.debug(type(img))  # DIPImage
+        logger.debug(type(self.comp))  # EMAParametersDict
+        # logger.debug(
+        #     self.comp.shape
+        # )  # [1 1 1 1] -> from target image constructor when no input image provided
+
+        # why is the latent comp only set here? why not in the __init__ and set_comp?
         if not self.has_latent:
+            raise
+            # make_latent() encodes the image through a dummy class instance, returns the resulting fitted image representation
+            # if get_image_tensor() is not implemented, then the returned 'latent' tensor is just the tensorized pil image
             latent = img.make_latent(self.pil_image)
+            logger.debug(type(latent))  # EMAParametersDict
+            logger.debug(type(self.comp))  # torch.Tensor
             with torch.no_grad():
-                self.comp.set_(latent.clone())
+                if type(latent) == type(self.comp):
+                    self.comp.set_(latent.clone())
+                # else:
+
             self.has_latent = True
+
+        estimated_image = self.comp.get_image_tensor()
+
         l1 = super().get_loss(img.get_latent_tensor(), img) / 2
         l2 = self.direct_loss.get_loss(input, img) / 10
         return l1 + l2
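+
+
+# Usage sketch (illustrative only, not part of this commit; the prompt string, shape,
+# and variable names below are made up for demonstration):
+#
+#     loss = LatentLossDIP.TargetImage(
+#         "init image:1:-inf", image_shape=(512, 512), pil_image=init_pil, img_model=img
+#     )
+#     total_loss = loss.get_loss(current_image_tensor, img)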