regularization_utils.py 837 B

import torch
from torch import nn


class DropPath(nn.Module):
    """
    Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).
    Code taken from TIMM (https://github.com/rwightman/pytorch-image-models)
    Apache License 2.0
    """

    def __init__(self, drop_prob=None):
        super(DropPath, self).__init__()
        self.drop_prob = drop_prob

    def forward(self, x):
        # Identity when disabled (drop_prob is None or 0) or when in eval mode.
        # The falsy check also covers the default drop_prob=None, which would
        # otherwise raise a TypeError at `1 - self.drop_prob` during training.
        if not self.drop_prob or not self.training:
            return x
        keep_prob = 1 - self.drop_prob
        # One mask value per sample, broadcast over all remaining dims.
        shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
        random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device)
        random_tensor.floor_()  # binarize: 1 with probability keep_prob, else 0
        # Rescale kept samples by 1 / keep_prob so the expected output equals x.
        output = x.div(keep_prob) * random_tensor
        return output
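
A minimal usage sketch, assuming the module is meant to wrap the residual branch of a block as its docstring suggests. The `drop_path` helper and the `ResidualBlock` class are hypothetical names introduced here for illustration only and are not part of this file. The helper repeats the same masking logic in functional form so the snippet runs on its own.

```python
import torch
from torch import nn


def drop_path(x, drop_prob=0.0, training=True):
    """Functional form of the masking logic above (hypothetical helper)."""
    if not drop_prob or not training:
        return x
    keep_prob = 1.0 - drop_prob
    # One Bernoulli(keep_prob) draw per sample, broadcast over the other dims.
    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
    mask = torch.floor(keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device))
    return x.div(keep_prob) * mask


class ResidualBlock(nn.Module):
    """Hypothetical block computing y = x + drop_path(f(x))."""

    def __init__(self, dim, drop_prob=0.1):
        super().__init__()
        self.fn = nn.Linear(dim, dim)  # stand-in for the real residual branch
        self.drop_prob = drop_prob

    def forward(self, x):
        # Only the residual branch is dropped; the skip connection always passes.
        return x + drop_path(self.fn(x), self.drop_prob, self.training)


torch.manual_seed(0)
x = torch.ones(8, 4)

# Training mode: every sample (row) is either zeroed or scaled by 1 / keep_prob.
out = drop_path(x, drop_prob=0.5, training=True)

# Eval mode: the block reduces to a plain residual connection.
block = ResidualBlock(dim=4).eval()
eval_out = block(x)
```

In training mode each row of `out` is dropped or kept as a whole, which is what "per sample" means in the docstring. In eval mode nothing is dropped and no rescaling is applied.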