Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

BlurPool.py 1.6 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
  1. import numpy as np
  2. from core.leras import nn
  3. tf = nn.tf
  4. class BlurPool(nn.LayerBase):
  5. def __init__(self, filt_size=3, stride=2, **kwargs ):
  6. if nn.data_format == "NHWC":
  7. self.strides = [1,stride,stride,1]
  8. else:
  9. self.strides = [1,1,stride,stride]
  10. self.filt_size = filt_size
  11. pad = [ int(1.*(filt_size-1)/2), int(np.ceil(1.*(filt_size-1)/2)) ]
  12. if nn.data_format == "NHWC":
  13. self.padding = [ [0,0], pad, pad, [0,0] ]
  14. else:
  15. self.padding = [ [0,0], [0,0], pad, pad ]
  16. if(self.filt_size==1):
  17. a = np.array([1.,])
  18. elif(self.filt_size==2):
  19. a = np.array([1., 1.])
  20. elif(self.filt_size==3):
  21. a = np.array([1., 2., 1.])
  22. elif(self.filt_size==4):
  23. a = np.array([1., 3., 3., 1.])
  24. elif(self.filt_size==5):
  25. a = np.array([1., 4., 6., 4., 1.])
  26. elif(self.filt_size==6):
  27. a = np.array([1., 5., 10., 10., 5., 1.])
  28. elif(self.filt_size==7):
  29. a = np.array([1., 6., 15., 20., 15., 6., 1.])
  30. a = a[:,None]*a[None,:]
  31. a = a / np.sum(a)
  32. a = a[:,:,None,None]
  33. self.a = a
  34. super().__init__(**kwargs)
  35. def build_weights(self):
  36. self.k = tf.constant (self.a, dtype=nn.floatx )
  37. def forward(self, x):
  38. k = tf.tile (self.k, (1,1,x.shape[nn.conv2d_ch_axis],1) )
  39. x = tf.pad(x, self.padding )
  40. x = tf.nn.depthwise_conv2d(x, k, self.strides, 'VALID', data_format=nn.data_format)
  41. return x
  42. nn.BlurPool = BlurPool
Tip!

Press p to see the previous file, or n to see the next file.

Comments

Loading...