Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

__init__.py 2.2 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
  1. # Copyright (c) 2017-present, Facebook, Inc.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the LICENSE file in
  5. # the root directory of this source tree. An additional grant of patent rights
  6. # can be found in the PATENTS file in the same directory.
  7. import importlib
  8. import os
  9. from .fairseq_decoder import FairseqDecoder # noqa: F401
  10. from .fairseq_encoder import FairseqEncoder # noqa: F401
  11. from .fairseq_incremental_decoder import FairseqIncrementalDecoder # noqa: F401
  12. from .fairseq_model import FairseqModel # noqa: F401
  13. MODEL_REGISTRY = {}
  14. ARCH_MODEL_REGISTRY = {}
  15. ARCH_CONFIG_REGISTRY = {}
  16. def build_model(args, src_dict, dst_dict):
  17. return ARCH_MODEL_REGISTRY[args.arch].build_model(args, src_dict, dst_dict)
  18. def register_model(name):
  19. """Decorator to register a new model (e.g., LSTM)."""
  20. def register_model_cls(cls):
  21. if name in MODEL_REGISTRY:
  22. raise ValueError('Cannot register duplicate model ({})'.format(name))
  23. if not issubclass(cls, FairseqModel):
  24. raise ValueError('Model ({}: {}) must extend FairseqModel'.format(name, cls.__name__))
  25. MODEL_REGISTRY[name] = cls
  26. return cls
  27. return register_model_cls
  28. def register_model_architecture(model_name, arch_name):
  29. """Decorator to register a new model architecture (e.g., lstm_luong_wmt_en_de)."""
  30. def register_model_arch_fn(fn):
  31. if model_name not in MODEL_REGISTRY:
  32. raise ValueError('Cannot register model architecture for unknown model type ({})'.format(model_name))
  33. if arch_name in ARCH_MODEL_REGISTRY:
  34. raise ValueError('Cannot register duplicate model architecture ({})'.format(arch_name))
  35. if not callable(fn):
  36. raise ValueError('Model architecture must be callable ({})'.format(arch_name))
  37. ARCH_MODEL_REGISTRY[arch_name] = MODEL_REGISTRY[model_name]
  38. ARCH_CONFIG_REGISTRY[arch_name] = fn
  39. return fn
  40. return register_model_arch_fn
  41. # automatically import any Python files in the models/ directory
  42. for file in os.listdir(os.path.dirname(__file__)):
  43. if file.endswith('.py') and not file.startswith('_'):
  44. module = file[:file.find('.py')]
  45. importlib.import_module('fairseq.models.' + module)
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...