Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

#842 fixed version

Merged
Ghost merged 1 commits into Deci-AI:master from deci-ai:bugfix/SG-000_fix_version
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
  1. from typing import Union, Optional
  2. from torch import nn
  3. from super_gradients.common import UpsampleMode
  4. from super_gradients.common.data_types.enum import DownSampleMode
  5. from super_gradients.modules.anti_alias import AntiAliasDownsample
  6. def make_upsample_module(scale_factor: int, upsample_mode: Union[str, UpsampleMode], align_corners: Optional[bool] = None):
  7. """
  8. Factory method for creating upsampling modules.
  9. :param scale_factor: upsample scale factor
  10. :param upsample_mode: see UpsampleMode for supported options.
  11. :return: nn.Module
  12. """
  13. upsample_mode = upsample_mode.value if isinstance(upsample_mode, UpsampleMode) else upsample_mode
  14. if upsample_mode == UpsampleMode.NEAREST.value:
  15. # Prevent ValueError when passing align_corners with nearest mode.
  16. module = nn.Upsample(scale_factor=scale_factor, mode=upsample_mode)
  17. elif upsample_mode in [UpsampleMode.BILINEAR.value, UpsampleMode.BICUBIC.value]:
  18. module = nn.Upsample(scale_factor=scale_factor, mode=upsample_mode, align_corners=align_corners)
  19. else:
  20. raise NotImplementedError(f"Upsample mode: `{upsample_mode}` is not supported.")
  21. return module
  22. def make_downsample_module(in_channels: int, stride: int, downsample_mode: Union[str, DownSampleMode]):
  23. """
  24. Factory method for creating down-sampling modules.
  25. :param downsample_mode: see DownSampleMode for supported options.
  26. :return: nn.Module
  27. """
  28. downsample_mode = downsample_mode.value if isinstance(downsample_mode, DownSampleMode) else downsample_mode
  29. if downsample_mode == DownSampleMode.ANTI_ALIAS.value:
  30. return AntiAliasDownsample(in_channels, stride)
  31. if downsample_mode == DownSampleMode.MAX_POOL.value:
  32. return nn.MaxPool2d(kernel_size=stride, stride=stride)
  33. raise NotImplementedError(f"DownSample mode: `{downsample_mode}` is not supported.")
Discard
Tip!

Press p or to see the previous file or, n or to see the next file