Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#433 Feature/SG 143 black formatter

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:feature/SG-143-black-formatter
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
  1. import torch
  2. from torch import nn, Tensor
  3. import torch.nn.functional as F
  4. class SEBlock(nn.Module):
  5. """
  6. Spatial Squeeze and Channel Excitation Block (cSE).
  7. Figure 1, Variant a from https://arxiv.org/abs/1808.08127v1
  8. """
  9. def __init__(self, in_channels: int, internal_neurons: int):
  10. super(SEBlock, self).__init__()
  11. self.down = nn.Conv2d(in_channels=in_channels, out_channels=internal_neurons, kernel_size=1, stride=1, bias=True)
  12. self.up = nn.Conv2d(in_channels=internal_neurons, out_channels=in_channels, kernel_size=1, stride=1, bias=True)
  13. self.input_channels = in_channels
  14. def forward(self, inputs: Tensor) -> Tensor:
  15. x = F.avg_pool2d(inputs, kernel_size=inputs.size(3))
  16. x = self.down(x)
  17. x = F.relu(x)
  18. x = self.up(x)
  19. x = torch.sigmoid(x)
  20. x = x.view(-1, self.input_channels, 1, 1)
  21. return inputs * x
Discard
Tip!

Press p or to see the previous file or, n or to see the next file