-
Notifications
You must be signed in to change notification settings - Fork 7
/
Copy pathpredict.py
79 lines (62 loc) · 2.43 KB
/
predict.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import numpy as np
import argparse
from tqdm import tqdm
import yaml
from attrdict import AttrMap
import torch
from torch.autograd import Variable
from torch.utils.data import DataLoader
from data import TestDataset
from utils import gpu_manage, save_image
from models.gen.unet import UNet
def predict(config, args):
    """Run the pretrained generator over a test set and save comparison strips.

    For every sample yielded by ``TestDataset(args.test_dir)`` the generator is
    evaluated under ``torch.no_grad()`` and a 1x4 horizontal strip of tiles is
    handed to ``save_image``. From left to right the tiles are:

    0. input channel 3, replicated to three channels,
    1. input channels 0-2,
    2. output channels 0-2, clipped to [-1, 1],
    3. output channel 3, clipped to [-1, 1] and replicated to three channels.

    Args:
        config: Attribute-style configuration. Reads ``threads``,
            ``gen_model``, ``in_ch``, ``out_ch`` and ``size``.
        args: Parsed command-line namespace. Reads ``test_dir``,
            ``pretrained``, ``cuda``, ``gpu_ids`` and ``out_dir``.

    Raises:
        ValueError: If ``config.gen_model`` names an unsupported generator.
    """
    gpu_manage(args)

    dataset = TestDataset(args.test_dir)
    data_loader = DataLoader(
        dataset=dataset,
        num_workers=config.threads,
        batch_size=1,
        shuffle=False,
    )

    ### MODELS LOAD ###
    print('===> Loading models')

    if config.gen_model == 'unet':
        gen = UNet(in_ch=config.in_ch, out_ch=config.out_ch, gpu_ids=args.gpu_ids)
    else:
        # Fail with a clear message instead of an unbound-name error below.
        raise ValueError(
            'Unsupported gen_model: {!r} (expected \'unet\')'.format(config.gen_model)
        )

    # Without --cuda, remap the checkpoint's tensors onto the CPU so weights
    # that were saved from a GPU can still be loaded on a CPU-only machine.
    map_location = None if args.cuda else 'cpu'
    param = torch.load(args.pretrained, map_location=map_location)
    gen.load_state_dict(param)

    if args.cuda:
        gen = gen.cuda(0)

    # NOTE(review): the generator's train/eval mode is left untouched, as in
    # the original script -- confirm whether gen.eval() is wanted here.

    # Geometry of the saved strip: h rows x w tiles, each tile c x p x p.
    h = 1
    w = 4
    c = 3
    p = config.size  # assumes every sample is p x p pixels -- TODO confirm

    with torch.no_grad():
        for i, batch in enumerate(tqdm(data_loader)):
            x = batch[0]
            filename = batch[1][0]  # batch_size is 1, so take the only name
            if args.cuda:
                x = x.cuda()

            out = gen(x)

            allim = np.zeros((h, w, c, p, p))
            x_ = x.cpu().numpy()[0]
            out_ = out.cpu().numpy()[0]

            in_rgb = x_[:3]
            in_nir = x_[3]
            out_rgb = np.clip(out_[:3], -1, 1)
            out_cloud = np.clip(out_[3], -1, 1)

            # "* 127.5 + 127.5" rescales values from [-1, 1] to [0, 255];
            # single-channel maps are repeated so every tile has c channels.
            allim[0, 0, :] = np.repeat(in_nir[None, :, :], repeats=3, axis=0) * 127.5 + 127.5
            allim[0, 1, :] = in_rgb * 127.5 + 127.5
            allim[0, 2, :] = out_rgb * 127.5 + 127.5
            allim[0, 3, :] = np.repeat(out_cloud[None, :, :], repeats=3, axis=0) * 127.5 + 127.5

            # (h, w, c, p, p) -> (h, p, w, p, c) -> one (h*p, w*p, c) image.
            allim = allim.transpose(0, 3, 1, 4, 2)
            allim = allim.reshape((h * p, w * p, c))

            save_image(args.out_dir, allim, i, 1, filename=filename)
if __name__ == '__main__':
    # Command-line entry point: parse the arguments, load the YAML config and
    # run prediction over the test directory.
    parser = argparse.ArgumentParser()
    parser.add_argument('--config', type=str, required=True)
    parser.add_argument('--test_dir', type=str, required=True)
    parser.add_argument('--out_dir', type=str, required=True)
    parser.add_argument('--pretrained', type=str, required=True)
    parser.add_argument('--cuda', action='store_true')
    # nargs='+' keeps the parsed value a list of ints, matching the list
    # default; with a bare type=int, a user-supplied id became a single int.
    parser.add_argument('--gpu_ids', type=int, nargs='+', default=[0])
    parser.add_argument('--manualSeed', type=int, default=0)
    args = parser.parse_args()

    with open(args.config, 'r') as f:
        # safe_load: yaml.load() without an explicit Loader is rejected by
        # PyYAML >= 6 and can construct arbitrary objects on older versions.
        config = yaml.safe_load(f)
    config = AttrMap(config)

    predict(config, args)