1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
|
- from core.leras import nn
- tf = nn.tf  # module-level alias for the `tf` object exposed by `nn`; used below for zeros_like, concat and nn.sigmoid
class XSeg(nn.ModelBase):
    """Encoder / bottleneck / decoder network with skip connections.

    The encoder runs six stages of ``ConvBlock`` layers, each stage followed
    by an ``nn.BlurPool``.  The pooled result is flattened, passed through two
    dense layers (``... -> 512 -> ...``) and reshaped to a 4x4 feature map.
    The decoder mirrors the encoder with six ``UpConvBlock`` layers; after
    each one the matching encoder output is concatenated on the channel axis.

    ``forward`` returns both the raw logits and their sigmoid.
    """

    def on_build(self, in_ch, base_ch, out_ch):
        """Create every sub-layer of the network.

        Args:
            in_ch: number of channels of the input tensor.
            base_ch: channel width of the first stage; deeper stages use
                2x, 4x and 8x this value.
            out_ch: number of channels produced by the final convolution.

        NOTE(review): attribute names and their assignment order are kept
        exactly as they were; presumably ``nn.ModelBase`` discovers layers
        through instance attributes - confirm before renaming or reordering.
        """

        class ConvBlock(nn.ModelBase):
            """3x3 'SAME' ``nn.Conv2D`` followed by ``nn.FRNorm2D`` and ``nn.TLU``."""

            def on_build(self, in_ch, out_ch):
                self.conv = nn.Conv2D(in_ch, out_ch, kernel_size=3, padding='SAME')
                self.frn = nn.FRNorm2D(out_ch)
                self.tlu = nn.TLU(out_ch)

            def forward(self, x):
                return self.tlu(self.frn(self.conv(x)))

        class UpConvBlock(nn.ModelBase):
            """3x3 'SAME' ``nn.Conv2DTranspose`` followed by ``nn.FRNorm2D`` and ``nn.TLU``."""

            def on_build(self, in_ch, out_ch):
                self.conv = nn.Conv2DTranspose(in_ch, out_ch, kernel_size=3, padding='SAME')
                self.frn = nn.FRNorm2D(out_ch)
                self.tlu = nn.TLU(out_ch)

            def forward(self, x):
                return self.tlu(self.frn(self.conv(x)))

        # Channel widths used throughout the network.
        ch1 = base_ch
        ch2 = base_ch * 2
        ch4 = base_ch * 4
        ch8 = base_ch * 8
        ch12 = base_ch * 12  # up-conv output (4x) concatenated with an 8x skip

        # dense1 is sized for a 4x4 map with ch8 channels, so the input
        # resolution must produce exactly that after bp5.
        # presumably 256x256 if every BlurPool halves resolution - TODO confirm
        bottleneck_features = 4 * 4 * ch8

        self.base_ch = base_ch

        # ---- encoder -----------------------------------------------------
        self.conv01 = ConvBlock(in_ch, ch1)
        self.conv02 = ConvBlock(ch1, ch1)
        self.bp0 = nn.BlurPool(filt_size=4)

        self.conv11 = ConvBlock(ch1, ch2)
        self.conv12 = ConvBlock(ch2, ch2)
        self.bp1 = nn.BlurPool(filt_size=3)

        self.conv21 = ConvBlock(ch2, ch4)
        self.conv22 = ConvBlock(ch4, ch4)
        self.bp2 = nn.BlurPool(filt_size=2)

        self.conv31 = ConvBlock(ch4, ch8)
        self.conv32 = ConvBlock(ch8, ch8)
        self.conv33 = ConvBlock(ch8, ch8)
        self.bp3 = nn.BlurPool(filt_size=2)

        self.conv41 = ConvBlock(ch8, ch8)
        self.conv42 = ConvBlock(ch8, ch8)
        self.conv43 = ConvBlock(ch8, ch8)
        self.bp4 = nn.BlurPool(filt_size=2)

        self.conv51 = ConvBlock(ch8, ch8)
        self.conv52 = ConvBlock(ch8, ch8)
        self.conv53 = ConvBlock(ch8, ch8)
        self.bp5 = nn.BlurPool(filt_size=2)

        # ---- dense bottleneck --------------------------------------------
        self.dense1 = nn.Dense(bottleneck_features, 512)
        self.dense2 = nn.Dense(512, bottleneck_features)

        # ---- decoder -----------------------------------------------------
        # The first block after every up-conv takes (up-conv output + skip)
        # channels, because forward() concatenates the encoder output there.
        self.up5 = UpConvBlock(ch8, ch4)
        self.uconv53 = ConvBlock(ch12, ch8)
        self.uconv52 = ConvBlock(ch8, ch8)
        self.uconv51 = ConvBlock(ch8, ch8)

        self.up4 = UpConvBlock(ch8, ch4)
        self.uconv43 = ConvBlock(ch12, ch8)
        self.uconv42 = ConvBlock(ch8, ch8)
        self.uconv41 = ConvBlock(ch8, ch8)

        self.up3 = UpConvBlock(ch8, ch4)
        self.uconv33 = ConvBlock(ch12, ch8)
        self.uconv32 = ConvBlock(ch8, ch8)
        self.uconv31 = ConvBlock(ch8, ch8)

        self.up2 = UpConvBlock(ch8, ch4)
        self.uconv22 = ConvBlock(ch8, ch4)
        self.uconv21 = ConvBlock(ch4, ch4)

        self.up1 = UpConvBlock(ch4, ch2)
        self.uconv12 = ConvBlock(ch4, ch2)
        self.uconv11 = ConvBlock(ch2, ch2)

        self.up0 = UpConvBlock(ch2, ch1)
        self.uconv02 = ConvBlock(ch2, ch1)
        self.uconv01 = ConvBlock(ch1, ch1)

        self.out_conv = nn.Conv2D(ch1, out_ch, kernel_size=3, padding='SAME')

    def forward(self, inp, pretrain=False):
        """Run the network on ``inp``.

        Args:
            inp: input tensor fed to the first convolution block.
            pretrain: when true, every encoder skip tensor is replaced by
                zeros of the same shape before it is concatenated, so the
                decoder only receives what passed through the dense
                bottleneck.

        Returns:
            A ``(logits, tf.nn.sigmoid(logits))`` tuple.
        """

        def join_skip(upsampled, skip):
            # Concatenate the decoder tensor with its encoder counterpart on
            # the channel axis; in pretrain mode the skip is zeroed first.
            if pretrain:
                skip = tf.zeros_like(skip)
            return tf.concat([upsampled, skip], axis=nn.conv2d_ch_axis)

        # ---- encoder: keep the pre-pooling output of every stage ---------
        skip0 = self.conv02(self.conv01(inp))
        skip1 = self.conv12(self.conv11(self.bp0(skip0)))
        skip2 = self.conv22(self.conv21(self.bp1(skip1)))
        skip3 = self.conv33(self.conv32(self.conv31(self.bp2(skip2))))
        skip4 = self.conv43(self.conv42(self.conv41(self.bp3(skip3))))
        skip5 = self.conv53(self.conv52(self.conv51(self.bp4(skip4))))

        # ---- dense bottleneck --------------------------------------------
        feat = nn.flatten(self.bp5(skip5))
        feat = self.dense2(self.dense1(feat))
        feat = nn.reshape_4D(feat, 4, 4, self.base_ch * 8)

        # ---- decoder -----------------------------------------------------
        feat = self.uconv53(join_skip(self.up5(feat), skip5))
        feat = self.uconv51(self.uconv52(feat))

        feat = self.uconv43(join_skip(self.up4(feat), skip4))
        feat = self.uconv41(self.uconv42(feat))

        feat = self.uconv33(join_skip(self.up3(feat), skip3))
        feat = self.uconv31(self.uconv32(feat))

        feat = self.uconv22(join_skip(self.up2(feat), skip2))
        feat = self.uconv21(feat)

        feat = self.uconv12(join_skip(self.up1(feat), skip1))
        feat = self.uconv11(feat)

        feat = self.uconv02(join_skip(self.up0(feat), skip0))
        feat = self.uconv01(feat)

        logits = self.out_conv(feat)
        return logits, tf.nn.sigmoid(logits)
- nn.XSeg = XSeg  # attach the class to the `nn` namespace so it is reachable as `nn.XSeg`
|