Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

init.py 2.4 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
  1. import torch.nn as nn
  2. import torch.nn.init as init
  3. def weight_init(m):
  4. '''
  5. Usage:
  6. model = Model()
  7. model.apply(weight_init)
  8. '''
  9. if isinstance(m, nn.Conv1d):
  10. init.normal_(m.weight.data)
  11. if m.bias is not None:
  12. init.normal_(m.bias.data)
  13. elif isinstance(m, nn.Conv2d):
  14. init.xavier_normal_(m.weight.data)
  15. if m.bias is not None:
  16. init.normal_(m.bias.data)
  17. elif isinstance(m, nn.Conv3d):
  18. init.xavier_normal_(m.weight.data)
  19. if m.bias is not None:
  20. init.normal_(m.bias.data)
  21. elif isinstance(m, nn.ConvTranspose1d):
  22. init.normal_(m.weight.data)
  23. if m.bias is not None:
  24. init.normal_(m.bias.data)
  25. elif isinstance(m, nn.ConvTranspose2d):
  26. init.xavier_normal_(m.weight.data)
  27. if m.bias is not None:
  28. init.normal_(m.bias.data)
  29. elif isinstance(m, nn.ConvTranspose3d):
  30. init.xavier_normal_(m.weight.data)
  31. if m.bias is not None:
  32. init.normal_(m.bias.data)
  33. elif isinstance(m, nn.BatchNorm1d):
  34. init.normal_(m.weight.data, mean=1, std=0.02)
  35. init.constant_(m.bias.data, 0)
  36. elif isinstance(m, nn.BatchNorm2d):
  37. init.normal_(m.weight.data, mean=1, std=0.02)
  38. init.constant_(m.bias.data, 0)
  39. elif isinstance(m, nn.BatchNorm3d):
  40. init.normal_(m.weight.data, mean=1, std=0.02)
  41. init.constant_(m.bias.data, 0)
  42. elif isinstance(m, nn.Linear):
  43. init.xavier_normal_(m.weight.data)
  44. init.normal_(m.bias.data)
  45. elif isinstance(m, nn.Parameter):
  46. init.uniform_(m, -0.01, 0.01)
  47. elif isinstance(m, nn.LSTM):
  48. for param in m.parameters():
  49. if len(param.shape) >= 2:
  50. init.orthogonal_(param.data)
  51. else:
  52. init.normal_(param.data)
  53. elif isinstance(m, nn.LSTMCell):
  54. for param in m.parameters():
  55. if len(param.shape) >= 2:
  56. init.orthogonal_(param.data)
  57. else:
  58. init.normal_(param.data)
  59. elif isinstance(m, nn.GRU):
  60. for param in m.parameters():
  61. if len(param.shape) >= 2:
  62. init.orthogonal_(param.data)
  63. else:
  64. init.normal_(param.data)
  65. elif isinstance(m, nn.GRUCell):
  66. for param in m.parameters():
  67. if len(param.shape) >= 2:
  68. init.orthogonal_(param.data)
  69. else:
  70. init.normal_(param.data)
Tip!

Press p or the left-arrow key to see the previous file, or n or the right-arrow key to see the next file

Comments

Loading...