-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathdata.py
84 lines (67 loc) · 2.74 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# data.py
# by Umair Khan
# CS 410 - Spring 2020
# Define a custom PyTorch dataset incorporating the
# channel manipulation and transformations described
# in the original paper.
# Imports
import glob
import os

import torchvision.transforms as transforms
from PIL import Image, ImageFilter
from torch.utils.data import Dataset
# Sub-image (patch) edge length in pixels, as defined in the paper.
# SRDataset reduces this to the nearest lower multiple of the zoom factor
# before cutting patches.
CROP = 33
# Class definition
class SRDataset(Dataset):
    """Dataset of fixed-size Y-channel patches for super-resolution training.

    Every image found under ``dirname`` is converted to YCbCr, only its Y
    channel is kept, and that channel is cut into square patches on a regular
    grid. Each item is an ``(input, target)`` pair built from one patch: the
    target is the patch itself, the input is the same patch after a Gaussian
    blur, a downscale by ``zoom`` and a bicubic upscale back to patch size.
    """

    def __init__(self, dirname: str, zoom: int, stride: int = 14) -> None:
        """Load all images under ``dirname`` and decompose them into patches.

        Args:
            dirname: Directory to generate the dataset from, with or without
                a trailing path separator. For backward compatibility a value
                that is not an existing directory is still used as a glob
                prefix (``dirname + "*"``).
            zoom: Integer zoom factor for the dataset, between 1 and ``CROP``.
            stride: Step in pixels between neighbouring patch origins, both
                vertically and horizontally. Defaults to 14, the value that
                was previously hard-coded.

        Raises:
            ValueError: If ``stride`` is below 1 or ``zoom`` is outside
                ``[1, CROP]``.
        """
        # Initialize superclass
        super().__init__()

        # Reject parameters that would otherwise fail obscurely later
        # (zero division, zero-sized crops, or an invalid range() step).
        if stride < 1:
            raise ValueError(f"stride must be a positive integer, got {stride!r}")
        if zoom < 1:
            raise ValueError(f"zoom must be a positive integer, got {zoom!r}")
        if zoom > CROP:
            raise ValueError(f"zoom must not exceed {CROP}, got {zoom!r}")

        # Derive a valid crop from the zoom factor: the largest multiple of
        # zoom that does not exceed CROP, so crop // zoom is exact.
        crop = CROP - (CROP % zoom)

        # Build the glob pattern. A real directory is joined properly so a
        # missing trailing separator no longer matches sibling paths; anything
        # else keeps the historical prefix-concatenation behaviour.
        if os.path.isdir(dirname):
            pattern = os.path.join(dirname, "*")
        else:
            pattern = dirname + "*"

        # Keep regular files only (sub-directories cannot be opened as
        # images) and sort them so item indices are reproducible across runs.
        self.files = sorted(
            path for path in glob.glob(pattern) if os.path.isfile(path)
        )
        self.images = []

        # Define transforms for network inputs
        # (subsample and interpolate to synthesize low-resolution).
        # NOTE(review): the downscale relies on Resize's default interpolation
        # while the upscale is explicitly bicubic - confirm this is intended.
        self.tf_in = transforms.Compose([
            transforms.CenterCrop(crop),
            transforms.Resize(crop // zoom),
            transforms.Resize(crop, interpolation=Image.BICUBIC),
            transforms.ToTensor(),
        ])

        # Define transforms for network targets (no manipulation needed).
        self.tf_out = transforms.Compose([
            transforms.CenterCrop(crop),
            transforms.ToTensor(),
        ])

        # Loop through filepaths, load images, and decompose into patches.
        # Images smaller than the crop in either dimension yield an empty
        # range and therefore contribute no patches.
        for filename in self.files:
            raw = self.load(filename)
            for top in range(0, raw.height - crop + 1, stride):
                for left in range(0, raw.width - crop + 1, stride):
                    box = (left, top, left + crop, top + crop)
                    self.images.append(raw.crop(box))

    def load(self, filepath: str) -> "Image.Image":
        """Return the Y channel (YCbCr) of the image at ``filepath``.

        Args:
            filepath: Path to the image file.

        Returns:
            A single-band PIL image holding the Y channel.
        """
        # The context manager guarantees the file handle is released even if
        # the conversion fails; convert() produces an independent image.
        with Image.open(filepath) as img:
            return img.convert("YCbCr").split()[0]

    def __getitem__(self, i: int) -> tuple:
        """Return the ``(input, target)`` pair for patch ``i``.

        Args:
            i: Index of the patch to retrieve.

        Returns:
            A tuple ``(input_img, output_img)`` of the transformed input and
            target for the stored patch.
        """
        patch = self.images[i]

        # Paper specifies Gaussian blur on input. filter() returns a new
        # image, so the stored patch is left untouched and needs no copy.
        blurred = patch.filter(ImageFilter.GaussianBlur(1))

        input_img = self.tf_in(blurred)
        output_img = self.tf_out(patch)
        return input_img, output_img

    def __len__(self) -> int:
        """Return the number of patches in the dataset."""
        return len(self.images)