Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

export_utils.py 1.3 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
  1. import torch
  2. import torch.nn as nn
  3. import torch.nn.functional as F
  4. class ExportableHardswish(nn.Module):
  5. '''
  6. Export-friendly version of nn.Hardswish()
  7. '''
  8. @staticmethod
  9. def forward(x):
  10. return x * F.hardtanh(x + 3, 0., 6.) / 6. # for torchscript, CoreML and ONNX
  11. def fuse_conv_bn(model: nn.Module, replace_bn_with_identity: bool = False):
  12. """
  13. Fuses consecutive nn.Conv2d and nn.BatchNorm2d layers recursively inplace in all of the model
  14. :param replace_bn_with_identity: if set to true, bn will be replaced with identity. otherwise, bn will be removed
  15. :param model: the target model
  16. :return: the number of fuses executed
  17. """
  18. children = list(model.named_children())
  19. counter = 0
  20. for i in range(len(children) - 1):
  21. if isinstance(children[i][1], torch.nn.Conv2d) and isinstance(children[i + 1][1], torch.nn.BatchNorm2d):
  22. setattr(model, children[i][0], torch.nn.utils.fuse_conv_bn_eval(children[i][1], children[i + 1][1]))
  23. if replace_bn_with_identity:
  24. setattr(model, children[i + 1][0], nn.Identity())
  25. else:
  26. delattr(model, children[i + 1][0])
  27. counter += 1
  28. for child_name, child in children:
  29. counter += fuse_conv_bn(child, replace_bn_with_identity)
  30. return counter
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...