Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

augment.py 2.0 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
  1. from typing import List, Optional, Tuple, Union
  2. import torch
  3. from torch_geometric.data import Data
  4. from torch_geometric.data.sampler import EdgeIndex
  5. class GraphAugmentor:
  6. """Masks node features (same for all nodes) and drops edges."""
  7. def __init__(
  8. self,
  9. p_x_1: float,
  10. p_e_1: float,
  11. p_x_2: Optional[float] = None,
  12. p_e_2: Optional[float] = None,
  13. ):
  14. self._p_x_1 = p_x_1
  15. self._p_e_1 = p_e_1
  16. self._p_x_2 = p_x_2 if p_x_2 is not None else p_x_1
  17. self._p_e_2 = p_e_2 if p_e_2 is not None else p_e_1
  18. def __call__(self, data: Data):
  19. """Augment full-batch graph."""
  20. x_a = mask_features(data.x, p=self._p_x_1)
  21. x_b = mask_features(data.x, p=self._p_x_2)
  22. edge_index_a = drop_edges(data.edge_index, p=self._p_e_1)
  23. edge_index_b = drop_edges(data.edge_index, p=self._p_e_2)
  24. return (x_a, edge_index_a), (x_b, edge_index_b)
  25. def augment_batch(
  26. self,
  27. x: torch.Tensor,
  28. adjs: List[EdgeIndex],
  29. ):
  30. """Augment batch from NeighborSampler."""
  31. x_a = mask_features(x, p=self._p_x_1)
  32. x_b = mask_features(x, p=self._p_x_2)
  33. edge_indexes_a = [
  34. drop_edges(adj.edge_index, p=self._p_e_1)
  35. for adj in adjs
  36. ]
  37. edge_indexes_b = [
  38. drop_edges(adj.edge_index, p=self._p_e_2)
  39. for adj in adjs
  40. ]
  41. return (x_a, edge_indexes_a), (x_b, edge_indexes_b)
  42. def mask_features(x: torch.Tensor, p: float) -> torch.Tensor:
  43. num_features = x.size(-1)
  44. device = x.device
  45. return bernoulli_mask(size=(1, num_features), prob=p).to(device) * x
  46. def drop_edges(edge_index: torch.Tensor, p: float) -> torch.Tensor:
  47. num_edges = edge_index.size(-1)
  48. device = edge_index.device
  49. mask = bernoulli_mask(size=num_edges, prob=p).to(device) == 1.
  50. return edge_index[:, mask]
  51. def bernoulli_mask(size: Union[int, Tuple[int, ...]], prob: float):
  52. return torch.bernoulli((1 - prob) * torch.ones(size))
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...