embedder_inference.py
"""Main"""
import torch
import numpy as np
import face_alignment

from dataset.video_extraction_conversion import select_frames, select_images_frames, generate_cropped_landmarks
from network.blocks import *
from network.model import Embedder
from params.params import path_to_chkpt

"""Hyperparameters and config"""
device = torch.device("cuda:0")
cpu = torch.device("cpu")
path_to_e_hat_video = 'e_hat_video.tar'
path_to_e_hat_images = 'e_hat_images.tar'
path_to_video = 'test_vid.mp4'
path_to_images = 'examples/fine_tuning/test_images'
T = 32  # number of frames sampled per identity
face_aligner = face_alignment.FaceAlignment(face_alignment.LandmarksType._2D, flip_input=False, device='cuda:0')

"""Loading Embedder input"""
# Sample T frames from the video, crop each around the detected face, and
# pair every frame with its rasterized landmark image.
frame_mark_video = select_frames(path_to_video, T)
frame_mark_video = generate_cropped_landmarks(frame_mark_video, pad=50, face_aligner=face_aligner)
frame_mark_video = torch.from_numpy(np.array(frame_mark_video)).type(dtype=torch.float)  # T,2,256,256,3
frame_mark_video = frame_mark_video.transpose(2, 4).to(device) / 255  # T,2,3,256,256
f_lm_video = frame_mark_video.unsqueeze(0)  # 1,T,2,3,256,256

# Same preprocessing for the folder of still images.
frame_mark_images = select_images_frames(path_to_images)
frame_mark_images = generate_cropped_landmarks(frame_mark_images, pad=50, face_aligner=face_aligner)
frame_mark_images = torch.from_numpy(np.array(frame_mark_images)).type(dtype=torch.float)  # T,2,256,256,3
frame_mark_images = frame_mark_images.transpose(2, 4).to(device) / 255  # T,2,3,256,256
f_lm_images = frame_mark_images.unsqueeze(0)  # 1,T,2,3,256,256

E = Embedder(256).to(device)
E.eval()

"""Loading from past checkpoint"""
checkpoint = torch.load(path_to_chkpt, map_location=cpu)
E.load_state_dict(checkpoint['E_state_dict'])

"""Inference"""
with torch.no_grad():
    # Calculate the average encoding vector over the video frames.
    f_lm = f_lm_video
    f_lm_compact = f_lm.view(-1, f_lm.shape[-4], f_lm.shape[-3], f_lm.shape[-2], f_lm.shape[-1])  # BxT,2,3,256,256
    e_vectors = E(f_lm_compact[:, 0, :, :, :], f_lm_compact[:, 1, :, :, :])  # BxT,512,1
    e_vectors = e_vectors.view(-1, f_lm.shape[1], 512, 1)  # B,T,512,1
    e_hat_video = e_vectors.mean(dim=1)

    # Calculate the average encoding vector over the still images.
    f_lm = f_lm_images
    f_lm_compact = f_lm.view(-1, f_lm.shape[-4], f_lm.shape[-3], f_lm.shape[-2], f_lm.shape[-1])  # BxT,2,3,256,256
    e_vectors = E(f_lm_compact[:, 0, :, :, :], f_lm_compact[:, 1, :, :, :])  # BxT,512,1
    e_vectors = e_vectors.view(-1, f_lm.shape[1], 512, 1)  # B,T,512,1
    e_hat_images = e_vectors.mean(dim=1)

print('Saving e_hat...')
torch.save({'e_hat': e_hat_video}, path_to_e_hat_video)
torch.save({'e_hat': e_hat_images}, path_to_e_hat_images)
print('...Done saving')
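Each saved archive holds one averaged identity embedding under the 'e_hat' key, with shape B,512,1 (B = 1 for a single identity). As a minimal sketch of reading an embedding back for a downstream fine-tuning or generation step, assuming only the torch.save format used above:

import torch

# Read the embedding back exactly as it was written:
# a dict with a single 'e_hat' entry of shape B,512,1.
e_hat = torch.load('e_hat_video.tar', map_location=torch.device('cpu'))['e_hat']
print(e_hat.shape)  # torch.Size([1, 512, 1]) for a single identity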