Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

utils.py 3.3 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. import numpy as np
  5. def train_duq(model,epoch,device,train_loader,optimizer,lambda_):
  6. # train duq model for one epoch. applies gradient
  7. # penalty and updates centroids
  8. model.train()
  9. for i, (data,targets) in enumerate(train_loader):
  10. data, targets = data.to(device), targets.to(device)
  11. data.requires_grad_(True)
  12. targets = F.one_hot(targets,num_classes=10).float()
  13. optimizer.zero_grad()
  14. Wx = model(data)
  15. K = model.kernel(Wx)
  16. batch_loss = model.loss(K,targets)
  17. data_grad = lambda_ * grad_penalty(K,data)
  18. batch_loss += data_grad
  19. batch_loss.backward()
  20. optimizer.step()
  21. with torch.no_grad():
  22. model.update_centroids(Wx,targets)
  23. print('done epoch: {}'.format(epoch))
  24. def test_duq(model,epoch,batch_size,device,test_loader):
  25. # evaluate model for classification accuracy
  26. model.eval()
  27. correct = 0
  28. with torch.no_grad():
  29. for data,targets in test_loader:
  30. data, targets = data.to(device), targets.to(device)
  31. Wx = model(data)
  32. K = model.kernel(Wx)
  33. pred = K.argmax(dim = 1)
  34. correct += pred.eq(targets).sum().item()
  35. print('Test accuracy: {:.4f}'.
  36. format(correct/(len(test_loader)*batch_size)))
  37. def grad_penalty(K,data):
  38. # computes the gradient penalty of the sum of the
  39. # kernal distances wrt the data
  40. K_sum = K.sum(-1)
  41. dk_dx = torch.autograd.grad(
  42. outputs=K_sum,
  43. inputs=data,
  44. grad_outputs=torch.ones_like(K_sum),
  45. create_graph=True,
  46. retain_graph=True)[0]
  47. dk_dx = dk_dx.flatten(start_dim=1)
  48. grad_norm_sq = (dk_dx**2).sum(dim=1)
  49. grad_penalty = ((grad_norm_sq - 1) ** 2).mean()
  50. return grad_penalty
  51. def ood_detection_eval(model,device,ood_loader):
  52. # runs inference on duq model. scores
  53. # are defined as negative max kernel distance
  54. model.eval()
  55. eval_scores = []
  56. with torch.no_grad():
  57. for data,_ in ood_loader:
  58. data = data.to(device)
  59. Wx = model(data)
  60. K = model.kernel(Wx)
  61. scores = -torch.max(K,dim = 1)[0]
  62. eval_scores.append(scores.cpu().numpy())
  63. return np.concatenate(eval_scores)
  64. def train_standard(model,loss,epoch,device,train_loader,optimizer):
  65. # train one epoch for standard classification model
  66. model.train()
  67. running_loss = 0
  68. for i, (data,targets) in enumerate(train_loader):
  69. data, targets = data.to(device), targets.to(device)
  70. optimizer.zero_grad()
  71. output = model(data)
  72. batch_loss = loss(output,targets)
  73. batch_loss.backward()
  74. optimizer.step()
  75. running_loss += batch_loss.item()
  76. print('done epoch: {}'.format(epoch))
  77. def test_standard(model,loss,epoch,batch_size,device,test_loader):
  78. # evaluation for standard model
  79. model.eval()
  80. correct = 0
  81. with torch.no_grad():
  82. for data,targets in test_loader:
  83. data, targets = data.to(device), targets.to(device)
  84. output = model(data)
  85. pred = output.argmax(dim = 1)
  86. correct += pred.eq(targets).sum().item()
  87. print('Test accuracy: {:.4f}'.format(correct/(len(test_loader)*batch_size)))
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...