Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

lora.py 8.9 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
  1. import ldm_patched.modules.utils
# Maps CLIP text-encoder submodule names (dot-separated, as they appear in the
# model state dict) to the underscore-joined form used inside LoRA weight key
# names (e.g. "lora_te_text_model_encoder_layers_{i}_{mapped}").
LORA_CLIP_MAP = {
    "mlp.fc1": "mlp_fc1",
    "mlp.fc2": "mlp_fc2",
    "self_attn.k_proj": "self_attn_k_proj",
    "self_attn.q_proj": "self_attn_q_proj",
    "self_attn.v_proj": "self_attn_v_proj",
    "self_attn.out_proj": "self_attn_out_proj",
}
  10. def load_lora(lora, to_load):
  11. patch_dict = {}
  12. loaded_keys = set()
  13. for x in to_load:
  14. alpha_name = "{}.alpha".format(x)
  15. alpha = None
  16. if alpha_name in lora.keys():
  17. alpha = lora[alpha_name].item()
  18. loaded_keys.add(alpha_name)
  19. regular_lora = "{}.lora_up.weight".format(x)
  20. diffusers_lora = "{}_lora.up.weight".format(x)
  21. transformers_lora = "{}.lora_linear_layer.up.weight".format(x)
  22. A_name = None
  23. if regular_lora in lora.keys():
  24. A_name = regular_lora
  25. B_name = "{}.lora_down.weight".format(x)
  26. mid_name = "{}.lora_mid.weight".format(x)
  27. elif diffusers_lora in lora.keys():
  28. A_name = diffusers_lora
  29. B_name = "{}_lora.down.weight".format(x)
  30. mid_name = None
  31. elif transformers_lora in lora.keys():
  32. A_name = transformers_lora
  33. B_name ="{}.lora_linear_layer.down.weight".format(x)
  34. mid_name = None
  35. if A_name is not None:
  36. mid = None
  37. if mid_name is not None and mid_name in lora.keys():
  38. mid = lora[mid_name]
  39. loaded_keys.add(mid_name)
  40. patch_dict[to_load[x]] = ("lora", (lora[A_name], lora[B_name], alpha, mid))
  41. loaded_keys.add(A_name)
  42. loaded_keys.add(B_name)
  43. ######## loha
  44. hada_w1_a_name = "{}.hada_w1_a".format(x)
  45. hada_w1_b_name = "{}.hada_w1_b".format(x)
  46. hada_w2_a_name = "{}.hada_w2_a".format(x)
  47. hada_w2_b_name = "{}.hada_w2_b".format(x)
  48. hada_t1_name = "{}.hada_t1".format(x)
  49. hada_t2_name = "{}.hada_t2".format(x)
  50. if hada_w1_a_name in lora.keys():
  51. hada_t1 = None
  52. hada_t2 = None
  53. if hada_t1_name in lora.keys():
  54. hada_t1 = lora[hada_t1_name]
  55. hada_t2 = lora[hada_t2_name]
  56. loaded_keys.add(hada_t1_name)
  57. loaded_keys.add(hada_t2_name)
  58. patch_dict[to_load[x]] = ("loha", (lora[hada_w1_a_name], lora[hada_w1_b_name], alpha, lora[hada_w2_a_name], lora[hada_w2_b_name], hada_t1, hada_t2))
  59. loaded_keys.add(hada_w1_a_name)
  60. loaded_keys.add(hada_w1_b_name)
  61. loaded_keys.add(hada_w2_a_name)
  62. loaded_keys.add(hada_w2_b_name)
  63. ######## lokr
  64. lokr_w1_name = "{}.lokr_w1".format(x)
  65. lokr_w2_name = "{}.lokr_w2".format(x)
  66. lokr_w1_a_name = "{}.lokr_w1_a".format(x)
  67. lokr_w1_b_name = "{}.lokr_w1_b".format(x)
  68. lokr_t2_name = "{}.lokr_t2".format(x)
  69. lokr_w2_a_name = "{}.lokr_w2_a".format(x)
  70. lokr_w2_b_name = "{}.lokr_w2_b".format(x)
  71. lokr_w1 = None
  72. if lokr_w1_name in lora.keys():
  73. lokr_w1 = lora[lokr_w1_name]
  74. loaded_keys.add(lokr_w1_name)
  75. lokr_w2 = None
  76. if lokr_w2_name in lora.keys():
  77. lokr_w2 = lora[lokr_w2_name]
  78. loaded_keys.add(lokr_w2_name)
  79. lokr_w1_a = None
  80. if lokr_w1_a_name in lora.keys():
  81. lokr_w1_a = lora[lokr_w1_a_name]
  82. loaded_keys.add(lokr_w1_a_name)
  83. lokr_w1_b = None
  84. if lokr_w1_b_name in lora.keys():
  85. lokr_w1_b = lora[lokr_w1_b_name]
  86. loaded_keys.add(lokr_w1_b_name)
  87. lokr_w2_a = None
  88. if lokr_w2_a_name in lora.keys():
  89. lokr_w2_a = lora[lokr_w2_a_name]
  90. loaded_keys.add(lokr_w2_a_name)
  91. lokr_w2_b = None
  92. if lokr_w2_b_name in lora.keys():
  93. lokr_w2_b = lora[lokr_w2_b_name]
  94. loaded_keys.add(lokr_w2_b_name)
  95. lokr_t2 = None
  96. if lokr_t2_name in lora.keys():
  97. lokr_t2 = lora[lokr_t2_name]
  98. loaded_keys.add(lokr_t2_name)
  99. if (lokr_w1 is not None) or (lokr_w2 is not None) or (lokr_w1_a is not None) or (lokr_w2_a is not None):
  100. patch_dict[to_load[x]] = ("lokr", (lokr_w1, lokr_w2, alpha, lokr_w1_a, lokr_w1_b, lokr_w2_a, lokr_w2_b, lokr_t2))
  101. #glora
  102. a1_name = "{}.a1.weight".format(x)
  103. a2_name = "{}.a2.weight".format(x)
  104. b1_name = "{}.b1.weight".format(x)
  105. b2_name = "{}.b2.weight".format(x)
  106. if a1_name in lora:
  107. patch_dict[to_load[x]] = ("glora", (lora[a1_name], lora[a2_name], lora[b1_name], lora[b2_name], alpha))
  108. loaded_keys.add(a1_name)
  109. loaded_keys.add(a2_name)
  110. loaded_keys.add(b1_name)
  111. loaded_keys.add(b2_name)
  112. w_norm_name = "{}.w_norm".format(x)
  113. b_norm_name = "{}.b_norm".format(x)
  114. w_norm = lora.get(w_norm_name, None)
  115. b_norm = lora.get(b_norm_name, None)
  116. if w_norm is not None:
  117. loaded_keys.add(w_norm_name)
  118. patch_dict[to_load[x]] = ("diff", (w_norm,))
  119. if b_norm is not None:
  120. loaded_keys.add(b_norm_name)
  121. patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (b_norm,))
  122. diff_name = "{}.diff".format(x)
  123. diff_weight = lora.get(diff_name, None)
  124. if diff_weight is not None:
  125. patch_dict[to_load[x]] = ("diff", (diff_weight,))
  126. loaded_keys.add(diff_name)
  127. diff_bias_name = "{}.diff_b".format(x)
  128. diff_bias = lora.get(diff_bias_name, None)
  129. if diff_bias is not None:
  130. patch_dict["{}.bias".format(to_load[x][:-len(".weight")])] = ("diff", (diff_bias,))
  131. loaded_keys.add(diff_bias_name)
  132. for x in lora.keys():
  133. if x not in loaded_keys:
  134. print("lora key not loaded", x)
  135. return patch_dict
  136. def model_lora_keys_clip(model, key_map={}):
  137. sdk = model.state_dict().keys()
  138. text_model_lora_key = "lora_te_text_model_encoder_layers_{}_{}"
  139. clip_l_present = False
  140. for b in range(32): #TODO: clean up
  141. for c in LORA_CLIP_MAP:
  142. k = "clip_h.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
  143. if k in sdk:
  144. lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
  145. key_map[lora_key] = k
  146. lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c])
  147. key_map[lora_key] = k
  148. lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
  149. key_map[lora_key] = k
  150. k = "clip_l.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
  151. if k in sdk:
  152. lora_key = text_model_lora_key.format(b, LORA_CLIP_MAP[c])
  153. key_map[lora_key] = k
  154. lora_key = "lora_te1_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
  155. key_map[lora_key] = k
  156. clip_l_present = True
  157. lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
  158. key_map[lora_key] = k
  159. k = "clip_g.transformer.text_model.encoder.layers.{}.{}.weight".format(b, c)
  160. if k in sdk:
  161. if clip_l_present:
  162. lora_key = "lora_te2_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #SDXL base
  163. key_map[lora_key] = k
  164. lora_key = "text_encoder_2.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
  165. key_map[lora_key] = k
  166. else:
  167. lora_key = "lora_te_text_model_encoder_layers_{}_{}".format(b, LORA_CLIP_MAP[c]) #TODO: test if this is correct for SDXL-Refiner
  168. key_map[lora_key] = k
  169. lora_key = "text_encoder.text_model.encoder.layers.{}.{}".format(b, c) #diffusers lora
  170. key_map[lora_key] = k
  171. return key_map
  172. def model_lora_keys_unet(model, key_map={}):
  173. sdk = model.state_dict().keys()
  174. for k in sdk:
  175. if k.startswith("diffusion_model.") and k.endswith(".weight"):
  176. key_lora = k[len("diffusion_model."):-len(".weight")].replace(".", "_")
  177. key_map["lora_unet_{}".format(key_lora)] = k
  178. diffusers_keys = ldm_patched.modules.utils.unet_to_diffusers(model.model_config.unet_config)
  179. for k in diffusers_keys:
  180. if k.endswith(".weight"):
  181. unet_key = "diffusion_model.{}".format(diffusers_keys[k])
  182. key_lora = k[:-len(".weight")].replace(".", "_")
  183. key_map["lora_unet_{}".format(key_lora)] = unet_key
  184. diffusers_lora_prefix = ["", "unet."]
  185. for p in diffusers_lora_prefix:
  186. diffusers_lora_key = "{}{}".format(p, k[:-len(".weight")].replace(".to_", ".processor.to_"))
  187. if diffusers_lora_key.endswith(".to_out.0"):
  188. diffusers_lora_key = diffusers_lora_key[:-2]
  189. key_map[diffusers_lora_key] = unet_key
  190. return key_map
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...