lstm.py 666 B

import yaml
import torch

from training.base import TrainerBase

# Load hyperparameters once at import time; they are keyed by training mode.
with open('params.yaml', 'r') as f:
    PARAMS = yaml.safe_load(f)


class Trainer(TrainerBase):
    def __init__(self, model, method, mode):
        super(Trainer, self).__init__(model, method, mode)
        # RMSprop over all model parameters; float() guards against values
        # that the YAML loader returned as strings.
        self._optimizer = torch.optim.RMSprop(self._model.parameters(), lr=float(PARAMS[mode]['optimizer']['lr']),
                                              weight_decay=float(PARAMS[mode]['optimizer']['weight_decay']))
        # Multiply the learning rate by gamma every step_lr scheduler steps.
        self._scheduler = torch.optim.lr_scheduler.StepLR(
            self._optimizer, PARAMS[mode]['optimizer']['step_lr'], gamma=PARAMS[mode]['optimizer']['gamma']
        )
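The trainer above reads four values per mode from `params.yaml` and hands two of them to `StepLR`. As a rough illustration of what that schedule does, here is a minimal pure-Python sketch. The `PARAMS` dictionary, the `"train"` mode name, and every value in it are assumptions made up for this example (the real `params.yaml` is not shown), and `step_lr_schedule` is a hypothetical helper that reproduces the step-decay rule `lr = base_lr * gamma ** (epoch // step_size)` rather than calling PyTorch.

```python
# Hypothetical shape of the parsed params.yaml. The keys mirror the lookups in
# Trainer.__init__, but the mode name and all values are invented for illustration.
PARAMS = {
    "train": {
        "optimizer": {
            "lr": "1e-3",            # kept as a string to show why float() is applied
            "weight_decay": "1e-5",
            "step_lr": 10,           # decay period, in scheduler steps
            "gamma": 0.5,            # multiplicative decay factor
        }
    }
}


def step_lr_schedule(base_lr, step_size, gamma, epoch):
    """Learning rate after `epoch` scheduler steps under step decay."""
    return base_lr * gamma ** (epoch // step_size)


opt = PARAMS["train"]["optimizer"]
base_lr = float(opt["lr"])  # same coercion the trainer performs

# Learning rate at a few sample epochs: it stays flat within each
# 10-step window and is halved at every window boundary.
rates = [
    step_lr_schedule(base_lr, opt["step_lr"], opt["gamma"], epoch)
    for epoch in (0, 9, 10, 25)
]
print(rates)
```

With these assumed values the rate holds at its base value through epoch 9, is halved at epoch 10, and halved again at epoch 20. In the actual trainer the decay only happens if the training loop calls `self._scheduler.step()`, which presumably lives in `TrainerBase`.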