datasets.py
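Dataset-loading helpers for graph node-classification experiments: load_dataset() returns a PyTorch Geometric Data object together with its train/val/test node masks for the WikiCS, Amazon, Coauthor, and OGB benchmarks, and load_ppi() returns the predefined PPI graph splits.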

import os
from typing import Dict, List, Tuple

import torch
from ogb.nodeproppred import PygNodePropPredDataset
from torch_geometric import transforms as T
from torch_geometric.data import Data
from torch_geometric.datasets import Amazon, Coauthor, PPI, WikiCS
from torch_geometric.utils import to_undirected

from gssl import DATA_DIR


def load_dataset(name: str) -> Tuple[Data, List[Dict[str, torch.Tensor]]]:
    """Load a node-classification benchmark and its train/val/test masks."""
    ds_path = os.path.join(DATA_DIR, "datasets/", name)

    # Row-normalize node features; for datasets that ship without splits,
    # generate 20 random splits (10% val, 80% test, rest train) at download time.
    feature_norm = T.NormalizeFeatures()
    create_masks = T.AddTrainValTestMask(
        split="train_rest",
        num_splits=20,
        num_val=0.1,
        num_test=0.8,
    )

    if name == "WikiCS":
        # WikiCS already ships with its own standard splits.
        data = WikiCS(
            root=ds_path,
            transform=feature_norm,
        )[0]
    elif name == "Amazon-CS":
        data = Amazon(
            root=ds_path,
            name="computers",
            transform=feature_norm,
            pre_transform=create_masks,
        )[0]
    elif name == "Amazon-Photo":
        data = Amazon(
            root=ds_path,
            name="photo",
            transform=feature_norm,
            pre_transform=create_masks,
        )[0]
    elif name == "Coauthor-CS":
        data = Coauthor(
            root=ds_path,
            name="cs",
            transform=feature_norm,
            pre_transform=create_masks,
        )[0]
    elif name == "Coauthor-Physics":
        data = Coauthor(
            root=ds_path,
            name="physics",
            transform=feature_norm,
            pre_transform=create_masks,
        )[0]
    elif name == "ogbn-arxiv":
        data = read_ogb_dataset(name=name, path=ds_path)
        # ogbn-arxiv is a directed citation graph; symmetrize its edges.
        data.edge_index = to_undirected(data.edge_index, data.num_nodes)
    elif name == "ogbn-products":
        data = read_ogb_dataset(name=name, path=ds_path)
    else:
        raise ValueError(f"Unknown dataset: {name}")

    if name in ("ogbn-arxiv", "ogbn-products"):
        # OGB datasets provide a single official split.
        masks = [
            {
                "train": data.train_mask,
                "val": data.val_mask,
                "test": data.test_mask,
            }
        ]
    else:
        # 20 splits; WikiCS shares one test mask across all of them.
        masks = [
            {
                "train": data.train_mask[:, i],
                "val": data.val_mask[:, i],
                "test": (
                    data.test_mask
                    if name == "WikiCS"
                    else data.test_mask[:, i]
                ),
            }
            for i in range(20)
        ]

    return data, masks


def read_ogb_dataset(name: str, path: str) -> Data:
    """Load an OGB node-property-prediction dataset and attach boolean masks."""
    dataset = PygNodePropPredDataset(root=path, name=name)
    split_idx = dataset.get_idx_split()

    data = dataset[0]

    # Convert OGB's index-based splits into boolean node masks.
    data.train_mask = torch.zeros((data.num_nodes,), dtype=torch.bool)
    data.train_mask[split_idx["train"]] = True

    data.val_mask = torch.zeros((data.num_nodes,), dtype=torch.bool)
    data.val_mask[split_idx["valid"]] = True

    data.test_mask = torch.zeros((data.num_nodes,), dtype=torch.bool)
    data.test_mask[split_idx["test"]] = True

    # OGB labels have shape [num_nodes, 1]; flatten to [num_nodes].
    data.y = data.y.squeeze(dim=-1)

    return data


def load_ppi() -> Tuple[PPI, PPI, PPI]:
    """Load the PPI dataset with its predefined train/val/test graph splits."""
    ds_path = os.path.join(DATA_DIR, "datasets/PPI")

    feature_norm = T.NormalizeFeatures()

    train_ppi = PPI(root=ds_path, split="train", transform=feature_norm)
    val_ppi = PPI(root=ds_path, split="val", transform=feature_norm)
    test_ppi = PPI(root=ds_path, split="test", transform=feature_norm)

    return train_ppi, val_ppi, test_ppi
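
A minimal usage sketch, assuming this file is importable as the datasets module and that gssl.DATA_DIR points to a writable directory where PyTorch Geometric can download WikiCS:

# Usage sketch -- the module name `datasets` and a writable DATA_DIR are
# assumptions, not guaranteed by the file above.
from datasets import load_dataset

data, masks = load_dataset("WikiCS")

# One graph plus 20 split dictionaries of boolean node masks for WikiCS
# (a single dictionary for the OGB datasets).
print(data.num_nodes, data.num_edges, data.x.shape)
for i, split in enumerate(masks):
    print(
        f"split {i}: "
        f"{int(split['train'].sum())} train / "
        f"{int(split['val'].sum())} val / "
        f"{int(split['test'].sum())} test nodes"
    )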