Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

conds.py 2.3 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
  1. import torch
  2. import math
  3. import ldm_patched.modules.utils
  4. class CONDRegular:
  5. def __init__(self, cond):
  6. self.cond = cond
  7. def _copy_with(self, cond):
  8. return self.__class__(cond)
  9. def process_cond(self, batch_size, device, **kwargs):
  10. return self._copy_with(ldm_patched.modules.utils.repeat_to_batch_size(self.cond, batch_size).to(device))
  11. def can_concat(self, other):
  12. if self.cond.shape != other.cond.shape:
  13. return False
  14. return True
  15. def concat(self, others):
  16. conds = [self.cond]
  17. for x in others:
  18. conds.append(x.cond)
  19. return torch.cat(conds)
  20. class CONDNoiseShape(CONDRegular):
  21. def process_cond(self, batch_size, device, area, **kwargs):
  22. data = self.cond[:,:,area[2]:area[0] + area[2],area[3]:area[1] + area[3]]
  23. return self._copy_with(ldm_patched.modules.utils.repeat_to_batch_size(data, batch_size).to(device))
  24. class CONDCrossAttn(CONDRegular):
  25. def can_concat(self, other):
  26. s1 = self.cond.shape
  27. s2 = other.cond.shape
  28. if s1 != s2:
  29. if s1[0] != s2[0] or s1[2] != s2[2]: #these 2 cases should not happen
  30. return False
  31. mult_min = math.lcm(s1[1], s2[1])
  32. diff = mult_min // min(s1[1], s2[1])
  33. if diff > 4: #arbitrary limit on the padding because it's probably going to impact performance negatively if it's too much
  34. return False
  35. return True
  36. def concat(self, others):
  37. conds = [self.cond]
  38. crossattn_max_len = self.cond.shape[1]
  39. for x in others:
  40. c = x.cond
  41. crossattn_max_len = math.lcm(crossattn_max_len, c.shape[1])
  42. conds.append(c)
  43. out = []
  44. for c in conds:
  45. if c.shape[1] < crossattn_max_len:
  46. c = c.repeat(1, crossattn_max_len // c.shape[1], 1) #padding with repeat doesn't change result
  47. out.append(c)
  48. return torch.cat(out)
  49. class CONDConstant(CONDRegular):
  50. def __init__(self, cond):
  51. self.cond = cond
  52. def process_cond(self, batch_size, device, **kwargs):
  53. return self._copy_with(self.cond)
  54. def can_concat(self, other):
  55. if self.cond != other.cond:
  56. return False
  57. return True
  58. def concat(self, others):
  59. return self.cond
Tip!

Press p to see the previous file, or n to see the next file

Comments

Loading...