Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

REAL_DRCT_GAN_ONNX inference_fp16.py 5.5 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
  1. import argparse
  2. import cv2
  3. import glob
  4. import numpy as np
  5. import os
  6. import onnxruntime
  7. import time
  8. import math
  9. from tqdm import tqdm
  10. def main():
  11. parser = argparse.ArgumentParser()
  12. parser.add_argument('--model_path', type=str, default='model.onnx', help='Path to the ONNX model')
  13. parser.add_argument('--input', type=str, default='input', help='Input folder with images')
  14. parser.add_argument('--output', type=str, default='output', help='Output folder')
  15. parser.add_argument('--scale', type=int, default=4, help='Upscaling factor')
  16. parser.add_argument('--tile_size', type=int, default=512, help='Tile size for processing')
  17. parser.add_argument('--tile_pad', type=int, default=32, help='Padding around tiles')
  18. args = parser.parse_args()
  19. # Load the ONNX model with CUDA Execution Provider
  20. ort_session = onnxruntime.InferenceSession(args.model_path, providers=['CUDAExecutionProvider'])
  21. input_name = ort_session.get_inputs()[0].name
  22. # Create output folder if it doesn't exist
  23. os.makedirs(args.output, exist_ok=True)
  24. # Process each image in the input folder
  25. for image_path in tqdm(glob.glob(os.path.join(args.input, '*')), desc="Processing images", unit="image"):
  26. # Load image and normalize
  27. img = cv2.imread(image_path, cv2.IMREAD_COLOR).astype(np.float32) / 255.0
  28. original_height, original_width = img.shape[:2]
  29. # Upscale image using tiling
  30. output_img = tile_process(img, ort_session, input_name, args.scale, args.tile_size, args.tile_pad)
  31. # Convert to uint8 and save the upscaled image
  32. output_img = (output_img * 255.0).round().astype(np.uint8)
  33. output_img = cv2.cvtColor(output_img, cv2.COLOR_BGR2RGB) # Convert BGR to RGB
  34. # Construct output filename with suffix and .png extension
  35. filename, _ = os.path.splitext(os.path.basename(image_path))
  36. output_filename = f"{filename}_REAL_GAN_DRCT.png"
  37. cv2.imwrite(os.path.join(args.output, output_filename), output_img)
  38. def tile_process(img, ort_session, input_name, scale, tile_size, tile_pad):
  39. """Processes the image in tiles to avoid OOM errors."""
  40. height, width = img.shape[:2]
  41. output_height = height * scale
  42. output_width = width * scale
  43. output_shape = (output_height, output_width, 3)
  44. # Start with black image
  45. output_img = np.zeros(output_shape, dtype=np.float32)
  46. # Calculate number of tiles
  47. tiles_x = math.ceil(width / tile_size)
  48. tiles_y = math.ceil(height / tile_size)
  49. # Loop over all tiles
  50. for y in range(tiles_y):
  51. for x in range(tiles_x):
  52. # Extract tile from input image
  53. ofs_x = x * tile_size
  54. ofs_y = y * tile_size
  55. input_start_x = ofs_x
  56. input_end_x = min(ofs_x + tile_size, width)
  57. input_start_y = ofs_y
  58. input_end_y = min(ofs_y + tile_size, height)
  59. # Input tile area on total image with padding
  60. input_start_x_pad = max(input_start_x - tile_pad, 0)
  61. input_end_x_pad = min(input_end_x + tile_pad, width)
  62. input_start_y_pad = max(input_start_y - tile_pad, 0)
  63. input_end_y_pad = min(input_end_y + tile_pad, height)
  64. # Input tile dimensions
  65. input_tile_width = input_end_x - input_start_x
  66. input_tile_height = input_end_y - input_start_y
  67. tile_idx = y * tiles_x + x + 1
  68. input_tile = img[input_start_y_pad:input_end_y_pad, input_start_x_pad:input_end_x_pad, :]
  69. # Pad tile to be divisible by scaling factor
  70. input_tile = pad_image(input_tile, 16)
  71. # Convert to BGR, transpose to CHW, and add batch dimension
  72. input_tile = np.transpose(input_tile[:, :, [2, 1, 0]], (2, 0, 1))
  73. input_tile = np.expand_dims(input_tile, axis=0).astype(np.float16)
  74. # Run inference
  75. output_tile = ort_session.run(None, {input_name: input_tile})[0]
  76. # Post-process the output tile
  77. output_tile = np.clip(output_tile, 0, 1)
  78. output_tile = np.transpose(output_tile[0, :, :, :], (1, 2, 0))
  79. # Output tile area on total image
  80. output_start_x = input_start_x * scale
  81. output_end_x = input_end_x * scale
  82. output_start_y = input_start_y * scale
  83. output_end_y = input_end_y * scale
  84. # Output tile area without padding
  85. output_start_x_tile = (input_start_x - input_start_x_pad) * scale
  86. output_end_x_tile = output_start_x_tile + input_tile_width * scale
  87. output_start_y_tile = (input_start_y - input_start_y_pad) * scale
  88. output_end_y_tile = output_start_y_tile + input_tile_height * scale
  89. # Put tile into output image
  90. output_img[output_start_y:output_end_y, output_start_x:output_end_x, :] = output_tile[
  91. output_start_y_tile:output_end_y_tile, output_start_x_tile:output_end_x_tile, :
  92. ]
  93. print(f'\tTile {tile_idx}/{tiles_x * tiles_y}')
  94. return output_img
  95. def pad_image(img, factor):
  96. """Pads the image to be divisible by the given factor using reflection padding."""
  97. height, width = img.shape[:2]
  98. pad_height = (factor - (height % factor)) % factor
  99. pad_width = (factor - (width % factor)) % factor
  100. return cv2.copyMakeBorder(img, 0, pad_height, 0, pad_width, cv2.BORDER_REFLECT_101)
  101. if __name__ == '__main__':
  102. main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...