multistage.py

from typing import Optional

import torch
from pytorch_lightning import Callback

from deadtrees.utils import utils

log = utils.get_logger(__name__)


class MultiStage(Callback):
    """Multi-stage training: freeze the encoder at the start, unfreeze it at a
    given epoch, and optionally lower the learning rate in a final stage."""

    def __init__(
        self,
        *,
        unfreeze_epoch: int,
        lr_reduce_epoch: Optional[int] = None,
        lr_reduce_fraction: Optional[float] = None,
    ):
        super().__init__()
        self.unfreeze_epoch = unfreeze_epoch  # epoch at which to unfreeze the encoder
        self.lr_reduce_epoch = lr_reduce_epoch  # epoch at which to reduce the learning rate
        self.lr_reduce_fraction = lr_reduce_fraction  # factor by which to reduce the learning rate

    def on_train_epoch_start(self, trainer, pl_module):
        if trainer.current_epoch == 0:
            if pl_module.encoder_weights is None:
                log.error(
                    "No encoder weights given but MultiStage encoder freeze requested"
                )
                exit()
            else:
                log.info(
                    f"Using pre-trained encoder weights: {pl_module.encoder_weights}"
                )

            # stage 1: freeze encoder
            log.info(f"NEW STAGE (epoch: {trainer.current_epoch}): Freeze encoder")
            pl_module.model.encoder.eval()
            for m in pl_module.model.encoder.modules():
                m.requires_grad_(False)

        if trainer.current_epoch == self.unfreeze_epoch:
            # stage 2: unfreeze encoder, keep default learning rate
            log.info(f"NEW STAGE (epoch: {trainer.current_epoch}): Unfreeze encoder")
            pl_module.model.encoder.train()
            for m in pl_module.model.encoder.modules():
                m.requires_grad_(True)

        if self.lr_reduce_epoch:
            # a reduction epoch without a reduction fraction is a configuration error
            assert self.lr_reduce_fraction is not None

            if trainer.current_epoch == self.lr_reduce_epoch:
                # stage 3: keep encoder unfrozen, lower the learning rate
                log.info(
                    f"NEW STAGE (epoch: {trainer.current_epoch}): "
                    f"Lower LR by factor {self.lr_reduce_fraction}"
                )

                new_optimizer = torch.optim.Adam(
                    pl_module.parameters(),
                    lr=pl_module.hparams.training.learning_rate
                    / self.lr_reduce_fraction,
                )
                new_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
                    new_optimizer, T_max=pl_module.hparams.training.cosineannealing_tmax
                )

                trainer.optimizers = [new_optimizer]
                trainer.lr_schedulers = trainer._configure_schedulers(
                    [new_scheduler], monitor=None, is_manual_optimization=False
                )
                trainer.optimizer_frequencies = []  # or optimizer frequencies if you have any
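
For reference, a minimal usage sketch of how this callback might be attached to a Trainer. The callback assumes a LightningModule that exposes encoder_weights, model.encoder, and hparams.training.learning_rate / hparams.training.cosineannealing_tmax; the epoch values and the model/datamodule names below are illustrative placeholders, not code from this repository.

import pytorch_lightning as pl

stages = MultiStage(
    unfreeze_epoch=10,       # epochs 0-9: frozen encoder, decoder-only training (assumed schedule)
    lr_reduce_epoch=20,      # from epoch 20: learning rate divided by lr_reduce_fraction
    lr_reduce_fraction=10.0,
)

trainer = pl.Trainer(max_epochs=30, callbacks=[stages])
# trainer.fit(model, datamodule=datamodule)  # model and datamodule are assumed to exist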