diff --git a/depth_anything_v2/dpt.py b/depth_anything_v2/dpt.py index 18d3e6f..c44b535 100644 --- a/depth_anything_v2/dpt.py +++ b/depth_anything_v2/dpt.py @@ -215,7 +215,6 @@ def image2tensor(self, raw_image, input_size=518): image = transform({'image': image})['image'] image = torch.from_numpy(image).unsqueeze(0) - DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu' - image = image.to(DEVICE) + image = image.to(next(self.parameters()).device) return image, (h, w) diff --git a/metric_depth/depth_anything_v2/dpt.py b/metric_depth/depth_anything_v2/dpt.py index 6541304..9918b93 100644 --- a/metric_depth/depth_anything_v2/dpt.py +++ b/metric_depth/depth_anything_v2/dpt.py @@ -216,7 +216,6 @@ def image2tensor(self, raw_image, input_size=518): image = transform({'image': image})['image'] image = torch.from_numpy(image).unsqueeze(0) - DEVICE = 'cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else 'cpu' - image = image.to(DEVICE) + image = image.to(next(self.parameters()).device) return image, (h, w)