-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathdata.py
72 lines (49 loc) · 2.11 KB
/
data.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
import glob
import cv2
import random
import numpy as np
import pickle
import os
from torch.utils import data
class Dataset(data.Dataset):
    """Training dataset that synthesizes cloudy RGB+NIR inputs on the fly.

    For each entry in ``config.imlist`` the clean colour image is read from
    ``config.rgbnir_dir/RGB`` and the NIR image from ``config.rgbnir_dir/NIR``.
    A cloud image is then chosen at random from ``config.cloud_dir`` and
    alpha-blended over the colour image.

    Each item is a pair ``(x, t)`` of channel-first float arrays, linearly
    mapped from [0, 255] to [-1, 1]:

    * ``x``: cloudy colour image (3 channels) followed by NIR (1 channel).
    * ``t``: clean colour image (3 channels) followed by the cloud alpha
      mask (1 channel).

    Args:
        config: object exposing ``imlist`` (path to a pickled sequence of
            image file names), ``cloud_dir`` and ``rgbnir_dir`` attributes.

    Raises:
        FileNotFoundError: if ``config.cloud_dir`` contains no ``*.png`` files.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        # NOTE(review): pickle.load can execute arbitrary code -- only load
        # image lists from trusted sources. encoding='latin-1' is presumably
        # there to read pickles written by Python 2; confirm.
        with open(config.imlist, 'rb') as f:
            self.imlist = pickle.load(f, encoding='latin-1')
        # glob order is arbitrary; sort so the random index -> cloud file
        # mapping is reproducible across machines/filesystems.
        cloud_files = sorted(glob.glob(os.path.join(config.cloud_dir, '*.png')))
        if not cloud_files:
            # Fail fast: otherwise the first __getitem__ would die inside
            # random.randrange(0) with an unhelpful message.
            raise FileNotFoundError(
                f'no *.png cloud images found in {config.cloud_dir!r}')
        self.cloud_files = cloud_files
        self.n_cloud = len(cloud_files)

    @staticmethod
    def _imread(path, flags):
        """Read ``path`` with ``cv2.imread(path, flags)`` as a float32 array.

        ``cv2.imread`` signals failure by returning ``None`` instead of
        raising, so the result is checked to report which file failed.

        Raises:
            FileNotFoundError: if the image is missing or cannot be decoded.
        """
        img = cv2.imread(path, flags)
        if img is None:
            raise FileNotFoundError(f'could not read image {path!r}')
        return img.astype(np.float32)

    def __getitem__(self, index):
        """Return the ``(x, t)`` training pair for ``self.imlist[index]``.

        Raises:
            FileNotFoundError: if any of the three images cannot be read.
            ValueError: if the chosen cloud image has no alpha channel.
        """
        name = str(self.imlist[index])
        # flag 1 = colour (OpenCV returns BGR channel order), 0 = grayscale.
        rgb = self._imread(os.path.join(self.config.rgbnir_dir, 'RGB', name), 1)
        nir = self._imread(os.path.join(self.config.rgbnir_dir, 'NIR', name), 0)

        cloud_path = self.cloud_files[random.randrange(self.n_cloud)]
        # flag -1 (IMREAD_UNCHANGED) keeps the PNG's alpha channel.
        cloud = self._imread(cloud_path, -1)
        if cloud.ndim != 3 or cloud.shape[2] != 4:
            raise ValueError(
                f'cloud image {cloud_path!r} must have 4 channels '
                f'(colour + alpha), got shape {cloud.shape!r}')

        cloud_mask = cloud[:, :, 3]
        # Alpha scaled to [0, 1] (assumes 8-bit PNGs -- TODO confirm), kept as
        # (H, W, 1) so it broadcasts across the colour channels.
        alpha = cloud_mask[:, :, None] / 255.
        cloud_rgb = (1. - alpha) * rgb + alpha * cloud[:, :, :3]
        cloud_rgb = np.clip(cloud_rgb, 0., 255.)

        x = np.concatenate((cloud_rgb, nir[:, :, None]), axis=2)
        t = np.concatenate((rgb, cloud_mask[:, :, None]), axis=2)
        # Map [0, 255] -> [-1, 1] and reorder HWC -> CHW.
        x = (x / 127.5 - 1).transpose(2, 0, 1)
        t = (t / 127.5 - 1).transpose(2, 0, 1)
        return x, t

    def __len__(self):
        """Return the number of entries in the image list."""
        return len(self.imlist)
class TestDataset(data.Dataset):
    """Inference dataset of RGB+NIR image pairs read from disk.

    Files are discovered as ``test_dir/RGB/*.png``. For each one, the NIR
    image with the same file name is read from ``test_dir/NIR``. Each item is
    ``(x, filename)``:

    * ``x``: channel-first float array (colour channels, then NIR),
      linearly mapped from [0, 255] to [-1, 1].
    * ``filename``: base name of the RGB file.

    Args:
        test_dir: directory containing ``RGB`` and ``NIR`` subdirectories.
    """

    def __init__(self, test_dir):
        super().__init__()
        self.test_dir = test_dir
        # glob order is arbitrary; sort so item order (and thus the order in
        # which outputs are produced) is reproducible.
        self.test_files = sorted(glob.glob(os.path.join(test_dir, 'RGB', '*.png')))

    @staticmethod
    def _imread(path, flags):
        """Read ``path`` with ``cv2.imread(path, flags)`` as a float32 array.

        ``cv2.imread`` returns ``None`` on failure instead of raising, so the
        result is checked to report which file failed.

        Raises:
            FileNotFoundError: if the image is missing or cannot be decoded.
        """
        img = cv2.imread(path, flags)
        if img is None:
            raise FileNotFoundError(f'could not read image {path!r}')
        return img.astype(np.float32)

    def __getitem__(self, index):
        """Return ``(x, filename)`` for the ``index``-th RGB file.

        Raises:
            FileNotFoundError: if the RGB or matching NIR image cannot be read.
        """
        filename = os.path.basename(self.test_files[index])
        # flag 1 = colour (OpenCV returns BGR channel order), 0 = grayscale.
        cloud_rgb = self._imread(os.path.join(self.test_dir, 'RGB', filename), 1)
        nir = self._imread(os.path.join(self.test_dir, 'NIR', filename), 0)
        x = np.concatenate((cloud_rgb, nir[:, :, None]), axis=2)
        # Map [0, 255] -> [-1, 1] and reorder HWC -> CHW.
        x = (x / 127.5 - 1).transpose(2, 0, 1)
        return x, filename

    def __len__(self):
        """Return the number of RGB test files found."""
        return len(self.test_files)