Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

utils.py 2.1 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
import os
import random
from typing import *

import dgl
import numpy as np
import torch
  7. def split_dataset(dataset, split_mode, *args, **kwargs):
  8. assert split_mode in ['rand', 'ogb', 'wikics', 'preload']
  9. if split_mode == 'rand':
  10. assert 'train_ratio' in kwargs and 'test_ratio' in kwargs
  11. train_ratio = kwargs['train_ratio']
  12. test_ratio = kwargs['test_ratio']
  13. num_samples = dataset.x.size(0)
  14. train_size = int(num_samples * train_ratio)
  15. test_size = int(num_samples * test_ratio)
  16. indices = torch.randperm(num_samples)
  17. return {
  18. 'train': indices[:train_size],
  19. 'val': indices[train_size: test_size + train_size],
  20. 'test': indices[test_size + train_size:]
  21. }
  22. elif split_mode == 'ogb':
  23. return dataset.get_idx_split()
  24. elif split_mode == 'wikics':
  25. assert 'split_idx' in kwargs
  26. split_idx = kwargs['split_idx']
  27. return {
  28. 'train': dataset.train_mask[:, split_idx],
  29. 'test': dataset.test_mask,
  30. 'val': dataset.val_mask[:, split_idx]
  31. }
  32. elif split_mode == 'preload':
  33. assert 'preload_split' in kwargs
  34. assert kwargs['preload_split'] is not None
  35. train_mask, test_mask, val_mask = kwargs['preload_split']
  36. return {
  37. 'train': train_mask,
  38. 'test': test_mask,
  39. 'val': val_mask
  40. }
  41. def seed_everything(seed):
  42. random.seed(seed)
  43. os.environ['PYTHONHASHSEED'] = str(seed)
  44. np.random.seed(seed)
  45. torch.backends.cudnn.benchmark = False
  46. torch.backends.cudnn.deterministic = True
  47. torch.manual_seed(seed)
  48. torch.cuda.manual_seed_all(seed)
  49. def normalize(s):
  50. return (s.max() - s) / (s.max() - s.mean())
  51. def build_dgl_graph(edge_index: torch.Tensor) -> dgl.DGLGraph:
  52. row, col = edge_index
  53. return dgl.graph((row, col))
  54. def batchify_dict(dicts: List[dict], aggr_func=lambda x: x):
  55. res = dict()
  56. for d in dicts:
  57. for k, v in d.items():
  58. if k not in res:
  59. res[k] = [v]
  60. else:
  61. res[k].append(v)
  62. res = {k: aggr_func(v) for k, v in res.items()}
  63. return res
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...