Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

Saveable.py 2.8 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
  1. import pickle
  2. from pathlib import Path
  3. from core import pathex
  4. import numpy as np
  5. from core.leras import nn
  6. tf = nn.tf
  7. class Saveable():
  8. def __init__(self, name=None):
  9. self.name = name
  10. #override
  11. def get_weights(self):
  12. #return tf tensors that should be initialized/loaded/saved
  13. return []
  14. #override
  15. def get_weights_np(self):
  16. weights = self.get_weights()
  17. if len(weights) == 0:
  18. return []
  19. return nn.tf_sess.run (weights)
  20. def set_weights(self, new_weights):
  21. weights = self.get_weights()
  22. if len(weights) != len(new_weights):
  23. raise ValueError ('len of lists mismatch')
  24. tuples = []
  25. for w, new_w in zip(weights, new_weights):
  26. if len(w.shape) != new_w.shape:
  27. new_w = new_w.reshape(w.shape)
  28. tuples.append ( (w, new_w) )
  29. nn.batch_set_value (tuples)
  30. def save_weights(self, filename, force_dtype=None):
  31. d = {}
  32. weights = self.get_weights()
  33. if self.name is None:
  34. raise Exception("name must be defined.")
  35. name = self.name
  36. for w in weights:
  37. w_val = nn.tf_sess.run (w).copy()
  38. w_name_split = w.name.split('/', 1)
  39. if name != w_name_split[0]:
  40. raise Exception("weight first name != Saveable.name")
  41. if force_dtype is not None:
  42. w_val = w_val.astype(force_dtype)
  43. d[ w_name_split[1] ] = w_val
  44. d_dumped = pickle.dumps (d, 4)
  45. pathex.write_bytes_safe ( Path(filename), d_dumped )
  46. def load_weights(self, filename):
  47. """
  48. returns True if file exists
  49. """
  50. filepath = Path(filename)
  51. if filepath.exists():
  52. result = True
  53. d_dumped = filepath.read_bytes()
  54. d = pickle.loads(d_dumped)
  55. else:
  56. return False
  57. weights = self.get_weights()
  58. if self.name is None:
  59. raise Exception("name must be defined.")
  60. try:
  61. tuples = []
  62. for w in weights:
  63. w_name_split = w.name.split('/')
  64. if self.name != w_name_split[0]:
  65. raise Exception("weight first name != Saveable.name")
  66. sub_w_name = "/".join(w_name_split[1:])
  67. w_val = d.get(sub_w_name, None)
  68. if w_val is None:
  69. #io.log_err(f"Weight {w.name} was not loaded from file {filename}")
  70. tuples.append ( (w, w.initializer) )
  71. else:
  72. w_val = np.reshape( w_val, w.shape.as_list() )
  73. tuples.append ( (w, w_val) )
  74. nn.batch_set_value(tuples)
  75. except:
  76. return False
  77. return True
  78. def init_weights(self):
  79. nn.init_weights(self.get_weights())
  80. nn.Saveable = Saveable
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...