Thank you! We'll be in touch ASAP.
Something went wrong, please try again or contact us directly at contact@dagshub.com
Deci-AI:master
deci-ai:bugfix/infra-000_ci
import torch import torch.nn as nn def fuse_conv_bn(model: nn.Module, replace_bn_with_identity: bool = False): """ Fuses consecutive nn.Conv2d and nn.BatchNorm2d layers recursively inplace in all of the model :param replace_bn_with_identity: if set to true, bn will be replaced with identity. otherwise, bn will be removed :param model: the target model :return: the number of fuses executed """ children = list(model.named_children()) counter = 0 for i in range(len(children) - 1): if isinstance(children[i][1], torch.nn.Conv2d) and isinstance(children[i + 1][1], torch.nn.BatchNorm2d): setattr(model, children[i][0], torch.nn.utils.fuse_conv_bn_eval(children[i][1], children[i + 1][1])) if replace_bn_with_identity: setattr(model, children[i + 1][0], nn.Identity()) else: delattr(model, children[i + 1][0]) counter += 1 for child_name, child in children: counter += fuse_conv_bn(child, replace_bn_with_identity) return counter
Press p or to see the previous file or, n or to see the next file