mlp.py 1.4 KB

import torch
from torch import nn

from model.base import ModelBase


class Model(ModelBase):
    """Bag-of-embeddings MLP text classifier with a single sigmoid output."""

    def __init__(self, vocab_size, embed_dim, hidden_size, dropout=0.1, *args, **kwargs):
        super(Model, self).__init__()
        # EmbeddingBag pools the embeddings of all tokens in each sample into one vector.
        self.embedding = nn.EmbeddingBag(vocab_size, embed_dim, sparse=False)
        self.dropout_layer = nn.Dropout(dropout)
        self.fc1 = nn.Linear(embed_dim, hidden_size)
        self.fc2 = nn.Linear(hidden_size, hidden_size)
        self.out = nn.Linear(hidden_size, 1)
        self.init_weights()

    def init_weights(self):
        # Xavier for the embedding table, Kaiming for the ReLU hidden layers and the output.
        nn.init.xavier_normal_(self.embedding.weight)
        nn.init.kaiming_normal_(self.fc1.weight, mode='fan_out', nonlinearity='relu')
        nn.init.kaiming_normal_(self.fc2.weight, mode='fan_out', nonlinearity='relu')
        self.fc1.bias.data.zero_()
        self.fc2.bias.data.zero_()
        nn.init.kaiming_normal_(self.out.weight, mode='fan_out', nonlinearity='sigmoid')
        nn.init.constant_(self.out.bias, 0)

    def forward(self, text, offsets):
        # `text` is a flat 1-D tensor of token indices; `offsets` marks where each
        # sample starts within it (the standard nn.EmbeddingBag input format).
        embedded = self.embedding(text, offsets)
        x = self.dropout_layer(self.fc1(embedded)).relu()
        x = self.dropout_layer(self.fc2(x)).relu()
        return self.out(x).sigmoid()

    def load_model(self, model_path):
        self.load_state_dict(torch.load(model_path))
        self.eval()

    def save_model(self, model_path):
        torch.save(self.state_dict(), model_path)
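
A minimal usage sketch, not part of the file itself: it assumes mlp.py is importable as a module (and that model.base.ModelBase is available), and uses illustrative sizes and token indices to show the flat-tensor-plus-offsets input that nn.EmbeddingBag expects.

    import torch

    from mlp import Model  # assumed module name for the file above

    # Illustrative hyperparameters; real values depend on your vocabulary and task.
    model = Model(vocab_size=10_000, embed_dim=64, hidden_size=128, dropout=0.1)

    # Two samples packed into one flat tensor: tokens [1, 5, 9] and [2, 7].
    text = torch.tensor([1, 5, 9, 2, 7], dtype=torch.long)
    offsets = torch.tensor([0, 3], dtype=torch.long)  # start index of each sample

    model.eval()
    with torch.no_grad():
        probs = model(text, offsets)  # shape (2, 1), sigmoid outputs in (0, 1)
    print(probs)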
