From 93e82e791625e8af3a5fac1f9d5a0e785d74b293 Mon Sep 17 00:00:00 2001 From: Zhang Jiayuan Date: Tue, 13 May 2025 18:26:12 +0800 Subject: [PATCH] update window compatibility MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit changed a few lines of code to replace the original ‘/’ which does not work on windows --- tinyimagenetloader.py | 13 +++++++++++-- 1 file changed, 11 insertions(+), 2 deletions(-) diff --git a/tinyimagenetloader.py b/tinyimagenetloader.py index d5362f2..af5e59b 100644 --- a/tinyimagenetloader.py +++ b/tinyimagenetloader.py @@ -39,7 +39,13 @@ def __getitem__(self, idx): image = read_image(img_path) if image.shape[0] == 1: image = read_image(img_path,ImageReadMode.RGB) - label = self.id_dict[img_path.split('/')[4]] + # label = self.id_dict[img_path.split('/')[4]] + + norm_path = os.path.normpath(img_path) + path_parts = norm_path.split(os.path.sep) + class_id = path_parts[-3] + label = self.id_dict[class_id] + if self.transform: image = self.transform(image.type(torch.FloatTensor)) return image, label @@ -64,7 +70,10 @@ def __getitem__(self, idx): image = read_image(img_path) if image.shape[0] == 1: image = read_image(img_path,ImageReadMode.RGB) - label = self.cls_dic[img_path.split('/')[-1]] + # label = self.cls_dic[img_path.split('/')[-1]] + + filename = os.path.basename(img_path) + label = self.cls_dic[filename] if self.transform: image = self.transform(image.type(torch.FloatTensor)) return image, label