Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

video_inference.py 2.9 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
import cv2
import numpy as np
import torch
from matplotlib import pyplot as plt
from tqdm import tqdm

from loss.loss_discriminator import *
from loss.loss_generator import *
from network.blocks import *
from network.model import *
from webcam_demo.webcam_extraction_conversion import *
from params.params import path_to_chkpt
  11. """Init"""
  12. #Paths
  13. path_to_model_weights = 'finetuned_model.tar'
  14. path_to_embedding = 'e_hat_video.tar'
  15. path_to_mp4 = 'test_vid2.webm'
  16. device = torch.device("cuda:0")
  17. cpu = torch.device("cpu")
  18. checkpoint = torch.load(path_to_model_weights, map_location=cpu)
  19. e_hat = torch.load(path_to_embedding, map_location=cpu)
  20. e_hat = e_hat['e_hat'].to(device)
  21. G = Generator(256, finetuning=True, e_finetuning=e_hat)
  22. G.eval()
  23. """Training Init"""
  24. G.load_state_dict(checkpoint['G_state_dict'])
  25. G.to(device)
  26. """Main"""
  27. print('PRESS Q TO EXIT')
  28. cap = cv2.VideoCapture(path_to_mp4)
  29. n_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
  30. fps = int(cap.get(cv2.CAP_PROP_FPS))
  31. ret = True
  32. i = 0
  33. size = (256*3,256)
  34. #out = cv2.VideoWriter('project.mp4',cv2.VideoWriter_fourcc('M','P','4','2'), 30, size)
  35. video = cv2.VideoWriter('project.mp4',cv2.VideoWriter_fourcc(*'DIVX'), fps, size)
  36. with torch.no_grad():
  37. while ret:
  38. x, g_y, ret = generate_landmarks(cap=cap, device=device, pad=50)
  39. if ret:
  40. g_y = g_y.unsqueeze(0)/255
  41. x = x.unsqueeze(0)/255
  42. #forward
  43. # Calculate average encoding vector for video
  44. #f_lm_compact = f_lm.view(-1, f_lm.shape[-4], f_lm.shape[-3], f_lm.shape[-2], f_lm.shape[-1]) #BxK,2,3,224,224
  45. #train G
  46. x_hat = G(g_y, e_hat)
  47. plt.clf()
  48. out1 = x_hat.transpose(1,3)[0]
  49. #for img_no in range(1,x_hat.shape[0]):
  50. # out1 = torch.cat((out1, x_hat.transpose(1,3)[img_no]), dim = 1)
  51. out1 = out1.to(cpu).numpy()
  52. #plt.imshow(out1)
  53. #plt.show()
  54. #plt.clf()
  55. out2 = x.transpose(1,3)[0]
  56. #for img_no in range(1,x.shape[0]):
  57. # out2 = torch.cat((out2, x.transpose(1,3)[img_no]), dim = 1)
  58. out2 = out2.to(cpu).numpy()
  59. #plt.imshow(out2)
  60. #plt.show()
  61. #plt.clf()
  62. out3 = g_y.transpose(1,3)[0]
  63. #for img_no in range(1,g_y.shape[0]):
  64. # out3 = torch.cat((out3, g_y.transpose(1,3)[img_no]), dim = 1)
  65. out3 = out3.to(cpu).numpy()
  66. #plt.imshow(out3)
  67. #plt.show()
  68. fake = cv2.cvtColor(out1*255, cv2.COLOR_BGR2RGB)
  69. me = cv2.cvtColor(out2*255, cv2.COLOR_BGR2RGB)
  70. landmark = cv2.cvtColor(out3*255, cv2.COLOR_BGR2RGB)
  71. img = np.concatenate((me, landmark, fake), axis=1)
  72. img = img.astype('uint8')
  73. video.write(img)
  74. i+=1
  75. print(i,'/',n_frames)
  76. cap.release()
  77. video.release()
  78. """cv2.destroyAllWindows()"""
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...