Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

utils.py 2.4 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
  1. import torch
  2. import numpy as np
  3. import os
  4. import wandb
  5. from dagshub import DAGsHubLogger
  6. def one_hot(y, num_class):
  7. return torch.zeros((len(y), num_class)).scatter_(1, y.unsqueeze(1), 1)
  8. def DBindex(cl_data_file):
  9. class_list = cl_data_file.keys()
  10. cl_num= len(class_list)
  11. cl_means = []
  12. stds = []
  13. DBs = []
  14. for cl in class_list:
  15. cl_means.append( np.mean(cl_data_file[cl], axis = 0) )
  16. stds.append( np.sqrt(np.mean( np.sum(np.square( cl_data_file[cl] - cl_means[-1]), axis = 1))))
  17. mu_i = np.tile( np.expand_dims( np.array(cl_means), axis = 0), (len(class_list),1,1) )
  18. mu_j = np.transpose(mu_i,(1,0,2))
  19. mdists = np.sqrt(np.sum(np.square(mu_i - mu_j), axis = 2))
  20. for i in range(cl_num):
  21. DBs.append( np.max([ (stds[i]+ stds[j])/mdists[i,j] for j in range(cl_num) if j != i ]) )
  22. return np.mean(DBs)
  23. def sparsity(cl_data_file):
  24. class_list = cl_data_file.keys()
  25. cl_sparsity = []
  26. for cl in class_list:
  27. cl_sparsity.append(np.mean([np.sum(x!=0) for x in cl_data_file[cl] ]) )
  28. return np.mean(cl_sparsity)
  29. class RunningAverage():
  30. def __init__(self):
  31. self.count = 0
  32. self.sum = 0
  33. def update(self, value, n_items = 1):
  34. self.sum += value * n_items
  35. self.count += n_items
  36. def __call__(self):
  37. return self.sum/self.count
  38. class Logger():
  39. wandb = None
  40. DAGsHub = None
  41. logger = None
  42. @staticmethod
  43. def init(wandb=True, DAGsHub=True, ckpt_dir=None):
  44. Logger.wandb = wandb
  45. Logger.DAGsHub = DAGsHub
  46. if DAGsHub:
  47. Logger.logger = DAGsHubLogger(metrics_path=os.path.join(ckpt_dir, "metrics.csv"), hparams_path=os.path.join(ckpt_dir,"params.yml"))
  48. @staticmethod
  49. def log(dict, step=None):
  50. if Logger.wandb:
  51. wandb.log(dict, step=step) if step else wandb.log(dict)
  52. if Logger.DAGsHub:
  53. Logger.logger.log_metrics(dict, step_num=step) if step else Logger.logger.log_metrics(dict)
  54. @staticmethod
  55. def log_hyperparams(params):
  56. if Logger.DAGsHub:
  57. Logger.logger.log_hyperparams(params)
  58. def wandb_restore_models(ckpt_dir):
  59. last_model_path = os.path.join(ckpt_dir, "last_model.tar")
  60. fn = wandb.restore(last_model_path)
  61. # best_model_path = os.path.join(ckpt_dir, "best_model.tar")
  62. # wandb.restore(best_model_path)
  63. print("Restored models")
  64. return fn.name
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...