Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

transductive_model_arxiv.py 1.9 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
  1. from pl_bolts.optimizers.lr_scheduler import LinearWarmupCosineAnnealingLR
  2. import torch
  3. from torch import nn
  4. from torch_geometric import nn as tgnn
  5. from gssl.loss import get_loss
  6. from gssl.transductive_model import Model
  7. class ArxivModel(Model):
  8. def __init__(
  9. self,
  10. feature_dim: int,
  11. emb_dim: int,
  12. loss_name: str,
  13. p_x: float,
  14. p_e: float,
  15. lr_base: float,
  16. total_epochs: int,
  17. warmup_epochs: int,
  18. ):
  19. self._device = torch.device(
  20. "cuda" if torch.cuda.is_available() else "cpu"
  21. )
  22. self._encoder = ThreeLayerGCNEncoder(
  23. in_dim=feature_dim, out_dim=emb_dim
  24. ).to(self._device)
  25. self._loss_fn = get_loss(loss_name=loss_name)
  26. self._optimizer = torch.optim.AdamW(
  27. params=self._encoder.parameters(),
  28. lr=lr_base,
  29. weight_decay=1e-5,
  30. )
  31. self._scheduler = LinearWarmupCosineAnnealingLR(
  32. optimizer=self._optimizer,
  33. warmup_epochs=warmup_epochs,
  34. max_epochs=total_epochs,
  35. )
  36. self._p_x = p_x
  37. self._p_e = p_e
  38. self._total_epochs = total_epochs
  39. self._use_pytorch_eval_model = True
  40. class ThreeLayerGCNEncoder(nn.Module):
  41. def __init__(self, in_dim: int, out_dim: int):
  42. super().__init__()
  43. self._conv1 = tgnn.GCNConv(in_dim, out_dim)
  44. self._conv2 = tgnn.GCNConv(out_dim, out_dim)
  45. self._conv3 = tgnn.GCNConv(out_dim, out_dim)
  46. self._bn1 = nn.BatchNorm1d(out_dim, momentum=0.01)
  47. self._bn2 = nn.BatchNorm1d(out_dim, momentum=0.01)
  48. self._act1 = nn.PReLU()
  49. self._act2 = nn.PReLU()
  50. def forward(self, x, edge_index):
  51. x = self._conv1(x, edge_index)
  52. x = self._bn1(x)
  53. x = self._act1(x)
  54. x = self._conv2(x, edge_index)
  55. x = self._bn2(x)
  56. x = self._act2(x)
  57. x = self._conv3(x, edge_index)
  58. return x
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...