Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

init_Wi.py 2.9 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
  1. """Main"""
  2. import torch
  3. import torch.nn as nn
  4. import torch.optim as optim
  5. from torch.utils.data import DataLoader
  6. from datetime import datetime
  7. import matplotlib
  8. #matplotlib.use('agg')
  9. from matplotlib import pyplot as plt
  10. plt.ion()
  11. import os
  12. import sys
  13. from dataset.dataset_class import PreprocessDataset
  14. from dataset.video_extraction_conversion import *
  15. from loss.loss_discriminator import *
  16. from loss.loss_generator import *
  17. from network.blocks import *
  18. from network.model import *
  19. from tqdm import tqdm
  20. from params.params import K, path_to_chkpt, path_to_backup, path_to_Wi, batch_size, path_to_preprocess, frame_shape
  # --- Create dataset and net ------------------------------------------------
  # When False, matplotlib is switched to the non-interactive 'agg' backend
  # before the re-embedding loop starts.
  display_training = False
  # Prefer the first GPU, but fall back to the CPU so the script can still run
  # (slowly) on machines without CUDA instead of failing at the first .to().
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
  cpu = torch.device("cpu")
  dataset = PreprocessDataset(K=K, path_to_preprocess=path_to_preprocess, path_to_Wi=path_to_Wi)
  # NOTE(review): batch_size is hard-coded to 32 even though params.batch_size
  # is imported; presumably deliberate (this pass runs without gradients, so a
  # larger batch fits) -- confirm.
  dataLoader = DataLoader(
      dataset,
      batch_size=32,
      shuffle=True,
      num_workers=16,
      # Pinned host memory only helps host-to-GPU copies.
      pin_memory=device.type == "cuda",
  )
  # DataParallel splits each batch across the visible GPUs; on a CPU-only
  # machine it simply forwards to the wrapped module.
  E = nn.DataParallel(Embedder(frame_shape).to(device))
  # --- Loop bookkeeping ------------------------------------------------------
  epochCurrent = epoch = i_batch = 0
  i_batch_current = 0
  num_epochs = 1

  # --- Load the trained embedder from the checkpoint -------------------------
  # Nothing can be embedded without a trained checkpoint, so abort early.
  if not os.path.isfile(path_to_chkpt):
      # Passing the message to sys.exit prints it to stderr and exits with
      # status 1; the former print() + bare sys.exit() reported success (0).
      sys.exit('Error loading model: file non-existant')
  # map_location=cpu lets a checkpoint saved on a GPU load on any machine;
  # load_state_dict then copies the weights into E's existing parameters.
  checkpoint = torch.load(path_to_chkpt, map_location=cpu)
  E.module.load_state_dict(checkpoint['E_state_dict'])
  num_vid = checkpoint['num_vid']
  # Inference only; eval() is the same as train(False).
  E.eval()
  # --- Make sure every video has a W_i file ----------------------------------
  # Each video index 0 .. num_vid-1 without a saved file gets a random
  # (512, 1) placeholder at <path_to_Wi>/W_<n>/W_<n>.tar.
  print('Initializing Discriminator weights')
  # makedirs also creates missing parent directories (os.mkdir would raise).
  os.makedirs(path_to_Wi, exist_ok=True)
  for vid_idx in tqdm(range(num_vid)):
      w_dir = path_to_Wi + '/W_' + str(vid_idx)
      w_file = w_dir + '/W_' + str(vid_idx) + '.tar'
      if not os.path.isfile(w_file):
          # The directory can already exist without its .tar (e.g. after an
          # interrupted run); exist_ok avoids the FileExistsError that a plain
          # os.mkdir raised in that case.
          os.makedirs(w_dir, exist_ok=True)
          torch.save({'W_i': torch.rand(512, 1)}, w_file)
  # --- Re-initialise W_i from the trained embedder ---------------------------
  # For every video the dataset yields, overwrite its W_i file with e_hat: the
  # mean embedder output over that video's K samples.
  pbar = tqdm(dataLoader, leave=True, initial=0)
  if not display_training:
      matplotlib.use('agg')
  with torch.no_grad():
      for epoch in range(num_epochs):
          if epoch > epochCurrent:
              i_batch_current = 0
              pbar = tqdm(dataLoader, leave=True, initial=0)
          pbar.set_postfix(epoch=epoch)
          # x, g_y and the current W_i are yielded by the dataset but unused here.
          for i_batch, (f_lm, x, g_y, i, W_i) in enumerate(pbar, start=0):
              f_lm = f_lm.to(device)
              # Fold dim 1 (K samples per video) into the batch dimension:
              # (B, K, *last4) -> (B*K, *last4). The author's note gives last4 as
              # (2, 3, 224, 224) -- assumed, not checked here.
              f_lm_compact = f_lm.view(-1, *f_lm.shape[-4:])
              # Index 0 and 1 along dim 1 are E's two inputs.
              e_vectors = E(f_lm_compact[:, 0, :, :, :], f_lm_compact[:, 1, :, :, :])  # (B*K, 512, 1)
              e_vectors = e_vectors.view(-1, f_lm.shape[1], 512, 1)  # (B, K, 512, 1)
              e_hat = e_vectors.mean(dim=1)  # (B, 512, 1)
              for enum, idx in enumerate(i):
                  vid = str(idx.item())
                  # Save a compact CPU copy of shape (512, 1):
                  # * torch.save preserves storage sharing, so saving a slice of
                  #   e_hat directly wrote the whole batch's storage into every
                  #   file; .cpu().clone() gives the saved tensor its own storage.
                  # * (512, 1) matches the torch.rand(512, 1) placeholders this
                  #   script creates; e_hat[enum, :].unsqueeze(0) wrote (1, 512, 1).
                  #   NOTE(review): confirm against the code that loads W_i.
                  torch.save(
                      {'W_i': e_hat[enum].cpu().clone()},
                      path_to_Wi + '/W_' + vid + '/W_' + vid + '.tar',
                  )
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...