Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#609 Ci fix

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:bugfix/infra-000_ci
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
  1. import torch
  2. import torch.nn as nn
  3. def fuse_conv_bn(model: nn.Module, replace_bn_with_identity: bool = False):
  4. """
  5. Fuses consecutive nn.Conv2d and nn.BatchNorm2d layers recursively inplace in all of the model
  6. :param replace_bn_with_identity: if set to true, bn will be replaced with identity. otherwise, bn will be removed
  7. :param model: the target model
  8. :return: the number of fuses executed
  9. """
  10. children = list(model.named_children())
  11. counter = 0
  12. for i in range(len(children) - 1):
  13. if isinstance(children[i][1], torch.nn.Conv2d) and isinstance(children[i + 1][1], torch.nn.BatchNorm2d):
  14. setattr(model, children[i][0], torch.nn.utils.fuse_conv_bn_eval(children[i][1], children[i + 1][1]))
  15. if replace_bn_with_identity:
  16. setattr(model, children[i + 1][0], nn.Identity())
  17. else:
  18. delattr(model, children[i + 1][0])
  19. counter += 1
  20. for child_name, child in children:
  21. counter += fuse_conv_bn(child, replace_bn_with_identity)
  22. return counter
Discard
Tip!

Press p or to see the previous file or, n or to see the next file