Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

inference_fp16.py 4.6 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
  1. import argparse
  2. import cv2
  3. import glob
  4. import numpy as np
  5. import os
  6. import torch
  7. from drct.archs.DRCT_arch import *
  8. #from drct.data import *
  9. #from drct.models import *
  10. def main():
  11. parser = argparse.ArgumentParser()
  12. parser.add_argument(
  13. '--model_path',
  14. type=str,
  15. default= # noqa: E251
  16. "/work/u1657859/DRCT/experiments/train_DRCT-L_SRx4_finetune_from_ImageNet_pretrain/models/DRCT-L.pth" # noqa: E501
  17. )
  18. parser.add_argument('--input', type=str, default='datasets/Set14/LRbicx4', help='input test image folder')
  19. parser.add_argument('--output', type=str, default='results/DRCT-L', help='output folder')
  20. parser.add_argument('--scale', type=int, default=4, help='scale factor: 1, 2, 3, 4')
  21. #parser.add_argument('--window_size', type=int, default=16, help='16')
  22. parser.add_argument('--tile', type=int, default=None, help='Tile size, None for no tile during testing (testing as a whole)')
  23. parser.add_argument('--tile_overlap', type=int, default=32, help='Overlapping of different tiles')
  24. args = parser.parse_args()
  25. device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
  26. # set up model (DRCT-L)
  27. model = DRCT(upscale=4, in_chans=3, img_size= 64, window_size= 16, compress_ratio= 3,squeeze_factor= 30,
  28. conv_scale= 0.01, overlap_ratio= 0.5, img_range= 1., depths= [6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6],
  29. embed_dim= 180, num_heads= [6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6, 6], gc= 32,
  30. mlp_ratio= 2, upsampler= 'pixelshuffle', resi_connection= '1conv')
  31. model.load_state_dict(torch.load(args.model_path)['params'], strict=True)
  32. model.eval()
  33. model = model.to(device).half()
  34. print(model)
  35. window_size = 16
  36. os.makedirs(args.output, exist_ok=True)
  37. for idx, path in enumerate(sorted(glob.glob(os.path.join(args.input, '*')))):
  38. imgname = os.path.splitext(os.path.basename(path))[0]
  39. print('Testing', idx, imgname)
  40. # read image
  41. img = cv2.imread(path, cv2.IMREAD_COLOR).astype(np.float32) / 255.
  42. img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
  43. #img = torch.from_numpy(np.transpose(img[:, :, [2, 1, 0]], (2, 0, 1))).float()
  44. img = img.unsqueeze(0).half().to(device)
  45. #print(img.shape)
  46. # inference
  47. try:
  48. with torch.no_grad():
  49. #output = model(img)
  50. _, _, h_old, w_old = img.size()
  51. h_pad = (h_old // window_size + 1) * window_size - h_old
  52. w_pad = (w_old // window_size + 1) * window_size - w_old
  53. img = torch.cat([img, torch.flip(img, [2])], 2)[:, :, :h_old + h_pad, :]
  54. img = torch.cat([img, torch.flip(img, [3])], 3)[:, :, :, :w_old + w_pad]
  55. output = test(img, model, args, window_size)
  56. output = output[..., :h_old * args.scale, :w_old * args.scale]
  57. except Exception as error:
  58. print('Error', error, imgname)
  59. else:
  60. # save image
  61. output = output.data.squeeze().float().cpu().clamp_(0, 1).numpy()
  62. output = np.transpose(output[[2, 1, 0], :, :], (1, 2, 0))
  63. output = (output * 255.0).round().astype(np.uint8)
  64. cv2.imwrite(os.path.join(args.output, f'{imgname}_DRCT-L_X4.png'), output)
  65. def test(img_lq, model, args, window_size):
  66. if args.tile is None:
  67. # test the image as a whole
  68. output = model(img_lq)
  69. else:
  70. # test the image tile by tile
  71. b, c, h, w = img_lq.size()
  72. tile = min(args.tile, h, w)
  73. assert tile % window_size == 0, "tile size should be a multiple of window_size"
  74. tile_overlap = args.tile_overlap
  75. sf = args.scale
  76. stride = tile - tile_overlap
  77. h_idx_list = list(range(0, h-tile, stride)) + [h-tile]
  78. w_idx_list = list(range(0, w-tile, stride)) + [w-tile]
  79. E = torch.zeros(b, c, h*sf, w*sf).type_as(img_lq)
  80. W = torch.zeros_like(E)
  81. for h_idx in h_idx_list:
  82. for w_idx in w_idx_list:
  83. in_patch = img_lq[..., h_idx:h_idx+tile, w_idx:w_idx+tile]
  84. out_patch = model(in_patch)
  85. out_patch_mask = torch.ones_like(out_patch)
  86. E[..., h_idx*sf:(h_idx+tile)*sf, w_idx*sf:(w_idx+tile)*sf].add_(out_patch)
  87. W[..., h_idx*sf:(h_idx+tile)*sf, w_idx*sf:(w_idx+tile)*sf].add_(out_patch_mask)
  88. output = E.div_(W)
  89. return output
  90. if __name__ == '__main__':
  91. main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...