fsdp_patch.py 1.3 KB

  1. """
  2. Monkeypatch to fix fsdp set state when no previous state was set
  3. https://github.com/OpenAccess-AI-Collective/axolotl/pull/400/files
  4. """
  5. import contextlib
  6. from typing import Generator, Optional
  7. import torch
  8. from torch import nn
  9. from torch.distributed.fsdp.api import (
  10. OptimStateDictConfig,
  11. StateDictConfig,
  12. StateDictType,
  13. )
  14. from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel
  15. @staticmethod
  16. @contextlib.contextmanager
  17. def state_dict_type_patch(
  18. module: nn.Module,
  19. state_dict_type: StateDictType,
  20. state_dict_config: Optional[StateDictConfig] = None,
  21. optim_state_dict_config: Optional[OptimStateDictConfig] = None,
  22. ) -> Generator:
  23. prev_state_dict_settings = FullyShardedDataParallel.set_state_dict_type(
  24. module,
  25. state_dict_type,
  26. state_dict_config,
  27. optim_state_dict_config,
  28. )
  29. yield
  30. if prev_state_dict_settings.state_dict_type:
  31. FullyShardedDataParallel.set_state_dict_type(
  32. module,
  33. prev_state_dict_settings.state_dict_type,
  34. prev_state_dict_settings.state_dict_config,
  35. prev_state_dict_settings.optim_state_dict_config,
  36. )
  37. def replace_fsdp_state_dict_type():
  38. torch.distributed.fsdp.fully_sharded_data_parallel.FullyShardedDataParallel.state_dict_type = (
  39. state_dict_type_patch
  40. )
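

if __name__ == "__main__":
    # --- Usage sketch (illustrative; not part of the original patch file) ---
    # The assumed pattern is to call replace_fsdp_state_dict_type() once at
    # startup, before any code enters FullyShardedDataParallel.state_dict_type(...),
    # e.g. in a training script (model construction and distributed setup are
    # hypothetical and omitted here):
    #
    #     replace_fsdp_state_dict_type()
    #     fsdp_model = FullyShardedDataParallel(model)
    #     with FullyShardedDataParallel.state_dict_type(
    #         fsdp_model, StateDictType.FULL_STATE_DICT
    #     ):
    #         checkpoint = fsdp_model.state_dict()
    #
    # Below we only apply the patch and confirm the class attribute now points
    # at the patched context manager.
    replace_fsdp_state_dict_type()
    assert FullyShardedDataParallel.state_dict_type.__name__ == "state_dict_type_patch"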