| import os |
| import numpy as np |
| import torch |
| import torch.utils.data as data |
| from torch.utils.data import Dataset |
| from PIL import Image |
| from copy import deepcopy |
| import shutil |
| import json |
|
|
def InfiniteSampler(n):
    """Yield dataset indices forever, reshuffling after every full pass.

    Each pass is a random permutation of ``range(n)``, so every index
    appears exactly once in each block of ``n`` consecutive draws.

    Args:
        n: Number of items in the dataset; must be positive.

    Yields:
        Indices in ``[0, n)`` as NumPy integer scalars.

    Raises:
        ValueError: on the first ``next()`` if ``n`` is not positive.
    """
    if n <= 0:
        raise ValueError(f'InfiniteSampler needs a positive size, got {n}')
    # The first pass uses the current global NumPy RNG state.
    order = np.random.permutation(n)
    while True:
        yield from order
        # Re-seed the global NumPy RNG from OS entropy before every later
        # pass (kept from the original design). Note this discards any seed
        # the caller set with np.random.seed(...).
        np.random.seed()
        order = np.random.permutation(n)
|
|
|
|
class InfiniteSamplerWrapper(data.sampler.Sampler):
    """Sampler that feeds a DataLoader an endless stream of shuffled indices.

    ``num_samples`` stores the real size of ``data_source``, which is used
    to draw the index permutations. ``__len__`` does not report that size.
    It reports the fixed value 2**31 instead.
    """

    def __init__(self, data_source):
        size = len(data_source)
        self.num_samples = size

    def __iter__(self):
        stream = InfiniteSampler(self.num_samples)
        return iter(stream)

    def __len__(self):
        return 1 << 31
|
|
|
|
def copy_G_params(model):
    """Return deep copies of every parameter tensor of ``model``.

    The copies are taken from ``p.data``, so they carry no autograd
    history. They are independent of the live parameters, and later
    in-place updates to ``model`` do not affect them. The list follows
    ``model.parameters()`` order.
    """
    snapshot = [param.data for param in model.parameters()]
    return deepcopy(snapshot)
| |
|
|
def load_params(model, new_param):
    """Copy tensors from ``new_param`` into ``model``'s parameters in place.

    Args:
        model: Module whose ``parameters()`` are overwritten.
        new_param: Iterable of tensors in the same order as
            ``model.parameters()``, e.g. the list returned by
            ``copy_G_params``.

    Raises:
        ValueError: if the number of tensors differs from the number of
            model parameters. A silent ``zip`` truncation would otherwise
            leave the trailing parameters unchanged.
    """
    params = list(model.parameters())
    new_param = list(new_param)
    if len(params) != len(new_param):
        raise ValueError(
            f'load_params: model has {len(params)} parameters but '
            f'{len(new_param)} tensors were given'
        )
    for p, new_p in zip(params, new_param):
        # Writing through .data updates the values without recording the
        # copy in the autograd graph.
        p.data.copy_(new_p)
|
|
|
|
def get_dir(args):
    """Create the output folders for a training run and snapshot its setup.

    The run directory is ``'train_results/' + args.name``. Inside it, this
    function creates the ``models`` and ``images`` subfolders. It copies
    every regular ``*.py`` file in the current working directory into the
    run directory. It also writes ``args.__dict__`` as indented JSON to
    ``args.txt`` there.

    Args:
        args: Namespace-like object with a ``name`` attribute; all of its
            attributes must be JSON-serializable.

    Returns:
        Tuple ``(saved_model_folder, saved_image_folder)``.
    """
    task_name = 'train_results/' + args.name
    saved_model_folder = os.path.join(task_name, 'models')
    saved_image_folder = os.path.join(task_name, 'images')

    os.makedirs(saved_model_folder, exist_ok=True)
    os.makedirs(saved_image_folder, exist_ok=True)

    # Snapshot the source files next to the results (presumably for
    # reproducibility). Match the suffix exactly and skip non-files. A
    # substring test would also catch e.g. ``.pyc`` files or a directory
    # named ``foo.py``, and shutil.copy fails on a directory.
    for fname in os.listdir('./'):
        if fname.endswith('.py') and os.path.isfile(fname):
            shutil.copy(fname, os.path.join(task_name, fname))

    with open(os.path.join(task_name, 'args.txt'), 'w') as f:
        json.dump(args.__dict__, f, indent=2)

    return saved_model_folder, saved_image_folder
|
|
|
|
class ImageFolder(Dataset):
    """Map-style dataset over the image files directly inside ``root``.

    The directory is listed non-recursively. Only files ending in ``.jpg``,
    ``.jpeg`` or ``.png`` are kept (case-insensitive), sorted by file name.
    Each item is loaded as an RGB PIL image and passed through
    ``transform`` when one is given.
    """

    # Lower-case file suffixes accepted by _parse_frame.
    _IMG_EXTENSIONS = ('.jpg', '.jpeg', '.png')

    def __init__(self, root, transform=None):
        super(ImageFolder, self).__init__()
        self.root = root
        self.frame = self._parse_frame()
        self.transform = transform

    def _parse_frame(self):
        """Return the sorted full paths of supported image files in root."""
        names = sorted(os.listdir(self.root))
        return [
            os.path.join(self.root, name)
            for name in names
            if name.lower().endswith(self._IMG_EXTENSIONS)
        ]

    def __len__(self):
        return len(self.frame)

    def __getitem__(self, idx):
        """Load image ``idx`` as RGB and apply ``transform`` if set."""
        path = self.frame[idx]
        # Close the file handle deterministically. convert() returns a new,
        # fully loaded image, so it stays valid after the source is closed.
        with Image.open(path) as src:
            img = src.convert('RGB')

        if self.transform:
            img = self.transform(img)

        return img
|
|
|
|
|
|
| from io import BytesIO |
| import lmdb |
| from torch.utils.data import Dataset |
|
|
|
|
class MultiResolutionDataset(Dataset):
    """Dataset of encoded images stored in an LMDB database.

    The database must contain a ``length`` entry holding the item count as
    a UTF-8 integer string. Each image is stored under the key
    ``'<resolution>-<index zero-padded to 5 digits>'``. Items are decoded
    with PIL and passed through ``transform``.
    """

    def __init__(self, path, transform, resolution=256):
        # NOTE(review): the LMDB environment is opened here, in the parent
        # process. The LMDB docs advise against using an environment after
        # fork(). Confirm this is safe with multi-process DataLoader workers.
        self.env = lmdb.open(
            path,
            max_readers=32,
            readonly=True,
            lock=False,
            readahead=False,
            meminit=False,
        )

        if not self.env:
            raise IOError('Cannot open lmdb dataset', path)

        with self.env.begin(write=False) as txn:
            raw_length = txn.get('length'.encode('utf-8'))
        if raw_length is None:
            raise KeyError(f"lmdb dataset at {path!r} has no 'length' entry")
        self.length = int(raw_length.decode('utf-8'))

        self.resolution = resolution
        self.transform = transform

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        """Decode the image stored for ``index`` and apply ``transform``.

        Raises:
            IndexError: if ``index`` is outside ``[0, len(self))``.
            KeyError: if the database has no entry for the computed key.
        """
        if not 0 <= index < self.length:
            raise IndexError(
                f'index {index} out of range for dataset of length {self.length}'
            )
        key = f'{self.resolution}-{str(index).zfill(5)}'.encode('utf-8')
        with self.env.begin(write=False) as txn:
            img_bytes = txn.get(key)
        # txn.get returns None for a missing key. Wrapping None in BytesIO
        # would give an empty buffer and a misleading PIL decode error.
        if img_bytes is None:
            raise KeyError(f"no image stored under key {key.decode('utf-8')!r}")

        buffer = BytesIO(img_bytes)
        img = Image.open(buffer)
        img = self.transform(img)

        return img
|
|
|
|