| | import torch |
| | from torch.utils.data import Dataset |
| | import os |
| | from natsort import natsorted |
| | import cv2 |
| | import glob |
| | import numpy as np |
| | from PIL import Image |
| | from skimage import io as img |
| |
|
class ImageAndMaskData(Dataset):
    """Dataset pairing each image with its mask as one multi-channel PIL image.

    Entries of ``img_dir`` and ``mask_dir`` are natural-sorted and paired by
    position, so the i-th image goes with the i-th mask. Each sample is the
    image's channels followed by the first channel of the mask, converted with
    ``Image.fromarray`` and then passed through ``transform`` when one is given.
    """

    def __init__(self, img_dir, mask_dir, transform=None):
        """
        Args:
            img_dir: Directory whose entries are the input images.
            mask_dir: Directory whose entries are the masks, one per image.
            transform: Optional callable applied to each PIL sample.

        Raises:
            ValueError: If the directories hold different numbers of entries.
                Pairing is positional and ``zip`` would silently truncate,
                misaligning every pair after a missing file.
        """
        self.images = natsorted(glob.glob(img_dir + "/*"))
        self.masks = natsorted(glob.glob(mask_dir + "/*"))

        if len(self.images) != len(self.masks):
            raise ValueError(
                f"Found {len(self.images)} images in {img_dir!r} but "
                f"{len(self.masks)} masks in {mask_dir!r}; counts must match."
            )

        self.imgs_and_masks = list(zip(self.images, self.masks))
        self.transform = transform

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.imgs_and_masks)

    @staticmethod
    def _with_channel_axis(array):
        """Return ``array`` as H x W x C, adding a trailing axis to 2-D input."""
        return array[:, :, np.newaxis] if array.ndim == 2 else array

    def __getitem__(self, idx):
        """Load pair ``idx`` and return it as one stacked (optionally transformed) sample."""
        img_path, mask_path = self.imgs_and_masks[idx]

        # Context managers release the file handles as soon as pixels are read.
        with Image.open(img_path) as image_file:
            img = np.array(image_file)
        with Image.open(mask_path) as mask_file:
            mask = np.array(mask_file)

        # Single-channel files load as 2-D arrays; give them a channel axis so
        # they can be sliced and concatenated along axis 2.
        img = self._with_channel_axis(img)
        mask = self._with_channel_axis(mask)[:, :, 0:1]

        sample = Image.fromarray(np.concatenate((img, mask), axis=2))

        if self.transform:
            sample = self.transform(sample)

        return sample
| |
|
| |
|
| | |
| |
|
| | def make_4_chs_img(image_path, mask_path): |
| | im = img.imread(image_path) |
| | mask = img.imread(mask_path) |
| |
|
| | |
| | mask = (mask > 127)*255 |
| | |
| | |
| |
|
| | return np.concatenate((im, mask[:,:,0:1]), axis=2) |
| |
|
def norm(x):
    """Map values from [0, 1] to [-1, 1], clamping anything outside that range.

    ``x`` must provide ``clamp`` (e.g. a ``torch.Tensor``).
    """
    centered = x - 0.5
    stretched = centered * 2
    return stretched.clamp(-1, 1)
| |
|
def denorm(x):
    """Map values from [-1, 1] back to [0, 1], clamping anything outside that range.

    ``x`` must provide ``clamp`` (e.g. a ``torch.Tensor``).
    """
    shifted = x + 1
    halved = shifted / 2
    return halved.clamp(0, 1)
| |
|
def np2torch(x):
    """Convert an H x W x C (or H x W) numpy array to a normalized float tensor.

    Values are divided by 255 (so 0-255 input is assumed) and then mapped to
    [-1, 1] with ``norm``.

    Args:
        x: numpy array of shape (H, W, C), or (H, W) for a single channel.

    Returns:
        torch.FloatTensor of shape (C, H, W), or (1, H, W) for 2-D input.
    """
    if x.ndim == 2:
        x = x[:, :, np.newaxis]

    # Channels-last -> channels-first; the division also yields a float array.
    chw = x.transpose((2, 0, 1)) / 255

    tensor = torch.from_numpy(chw).type(torch.FloatTensor)
    return norm(tensor)
| |
|
| |
|
| |
|
class ImageAndMaskDataFromSinGAN(Dataset):
    """Dataset of image/mask pairs returned as normalized channel-first tensors.

    Entries of ``img_dir`` and ``mask_dir`` are natural-sorted and paired by
    position. Each sample is built with ``make_4_chs_img`` and ``np2torch``:
    image channels followed by the binarized mask, scaled to [-1, 1], reduced
    to four channels (three image channels plus the mask), then passed through
    ``transform`` when one is given.
    """

    def __init__(self, img_dir, mask_dir, transform=None):
        """
        Args:
            img_dir: Directory whose entries are the input images.
            mask_dir: Directory whose entries are the masks, one per image.
            transform: Optional callable applied to each tensor sample.

        Raises:
            ValueError: If the directories hold different numbers of entries.
                Pairing is positional and ``zip`` would silently truncate,
                misaligning every pair after a missing file.
        """
        self.images = natsorted(glob.glob(img_dir + "/*"))
        self.masks = natsorted(glob.glob(mask_dir + "/*"))

        if len(self.images) != len(self.masks):
            raise ValueError(
                f"Found {len(self.images)} images in {img_dir!r} but "
                f"{len(self.masks)} masks in {mask_dir!r}; counts must match."
            )

        self.imgs_and_masks = list(zip(self.images, self.masks))
        self.transform = transform

    def __len__(self):
        """Return the number of image/mask pairs."""
        return len(self.imgs_and_masks)

    @staticmethod
    def _rgb_and_mask(sample):
        """Reduce a C x H x W tensor to its first three channels plus the last one."""
        if sample.shape[0] <= 4:
            return sample
        # The mask is appended last. Slicing channels 0:4 would keep an
        # image's alpha channel and drop the mask for RGBA inputs.
        return torch.cat((sample[:3], sample[-1:]), dim=0)

    def __getitem__(self, idx):
        """Load pair ``idx`` and return it as a 4 x H x W (optionally transformed) tensor."""
        image_path, mask_path = self.imgs_and_masks[idx]

        sample = np2torch(make_4_chs_img(image_path, mask_path))
        sample = self._rgb_and_mask(sample)

        if self.transform:
            sample = self.transform(sample)

        return sample
| |
|
| |
|
| |
|
| |
|
if __name__ == "__main__":
    # Smoke check: build the SinGAN-style dataset and print one sample's shape.
    # The defaults reproduce the previously hard-coded paths and index.
    import argparse

    parser = argparse.ArgumentParser(
        description="Print the shape of one ImageAndMaskDataFromSinGAN sample."
    )
    parser.add_argument(
        "--img-dir",
        default="/work/vajira/DATA/kvasir_seg/real_images_root/real_images",
    )
    parser.add_argument(
        "--mask-dir",
        default="/work/vajira/DATA/kvasir_seg/real_masks_root/real_masks",
    )
    parser.add_argument("--index", type=int, default=1)
    args = parser.parse_args()

    dataset = ImageAndMaskDataFromSinGAN(args.img_dir, args.mask_dir)

    if not -len(dataset) <= args.index < len(dataset):
        parser.error(
            f"index {args.index} is out of range for a dataset of {len(dataset)} samples"
        )

    print(dataset[args.index].shape)
| |
|
| | |
| |
|
| |
|
| |
|