sinusoidal_positional_embedding.py

# Copyright (c) 2017-present, Facebook, Inc.
# All rights reserved.
#
# This source code is licensed under the license found in the LICENSE file in
# the root directory of this source tree. An additional grant of patent rights
# can be found in the PATENTS file in the same directory.

import math

import torch
from torch.autograd import Variable
import torch.nn as nn

from fairseq import utils


class SinusoidalPositionalEmbedding(nn.Module):
    """This module produces sinusoidal positional embeddings of any length.

    Padding symbols are ignored, but it is necessary to specify whether padding
    is added on the left side (left_pad=True) or right side (left_pad=False).
    """

    def __init__(self, embedding_dim, padding_idx, left_pad, init_size=1024):
        super().__init__()
        self.embedding_dim = embedding_dim
        self.padding_idx = padding_idx
        self.left_pad = left_pad
        self.weights = SinusoidalPositionalEmbedding.get_embedding(
            init_size,
            embedding_dim,
            padding_idx,
        )
        self.register_buffer('_float_tensor', torch.FloatTensor())

    @staticmethod
    def get_embedding(num_embeddings, embedding_dim, padding_idx=None):
        """Build sinusoidal embeddings.

        This matches the implementation in tensor2tensor, but differs slightly
        from the description in Section 3.5 of "Attention Is All You Need".
        """
        half_dim = embedding_dim // 2
        emb = math.log(10000) / (half_dim - 1)
        emb = torch.exp(torch.arange(half_dim) * -emb)
        emb = torch.arange(num_embeddings).unsqueeze(1) * emb.unsqueeze(0)
        emb = torch.cat([torch.sin(emb), torch.cos(emb)], dim=1).view(num_embeddings, -1)
        if embedding_dim % 2 == 1:
            # zero pad
            emb = torch.cat([emb, torch.zeros(num_embeddings, 1)], dim=1)
        if padding_idx is not None:
            emb[padding_idx, :] = 0
        return emb

    def forward(self, input, incremental_state=None):
        """Input is expected to be of size [bsz x seqlen]."""
        # recompute/expand embeddings if needed
        bsz, seq_len = input.size()
        max_pos = self.padding_idx + 1 + seq_len
        if max_pos > self.weights.size(0):
            self.weights = SinusoidalPositionalEmbedding.get_embedding(
                max_pos,
                self.embedding_dim,
                self.padding_idx,
            ).type_as(self.weights)
        self.weights = self.weights.type_as(self._float_tensor)
        weights = Variable(self.weights)

        if incremental_state is not None:
            # positions is the same for every token when decoding a single step
            return weights[self.padding_idx + seq_len, :].expand(bsz, 1, -1)

        positions = Variable(utils.make_positions(input.data, self.padding_idx, self.left_pad))
        return weights.index_select(0, positions.view(-1)).view(bsz, seq_len, -1)

    def max_positions(self):
        """Maximum number of supported positions."""
        return int(1e5)  # an arbitrary large number
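
For reference, the table that get_embedding builds can be reproduced standalone, without fairseq, as in the sketch below. This is an illustrative rewrite, not part of the original file: the function name sinusoidal_table and the example values are made up, and torch.arange is given an explicit float dtype because newer PyTorch versions require it (the file above predates that and also wraps tensors in the now-deprecated Variable). As the docstring notes, the layout places all sine channels before all cosine channels (the tensor2tensor convention) rather than interleaving them as described in Section 3.5 of "Attention Is All You Need".

import math

import torch


def sinusoidal_table(num_embeddings, embedding_dim, padding_idx=None):
    # Standalone sketch of SinusoidalPositionalEmbedding.get_embedding:
    # rows are positions, columns are [sin(pos * freq_0..k) | cos(pos * freq_0..k)].
    half_dim = embedding_dim // 2
    scale = math.log(10000) / (half_dim - 1)
    inv_freq = torch.exp(torch.arange(half_dim, dtype=torch.float) * -scale)
    angles = torch.arange(num_embeddings, dtype=torch.float).unsqueeze(1) * inv_freq.unsqueeze(0)
    table = torch.cat([torch.sin(angles), torch.cos(angles)], dim=1)
    if embedding_dim % 2 == 1:
        # zero-pad the last column when embedding_dim is odd
        table = torch.cat([table, torch.zeros(num_embeddings, 1)], dim=1)
    if padding_idx is not None:
        table[padding_idx, :] = 0  # the padding position embeds to all zeros
    return table


# Positions 0-9 of an 8-dimensional table, with position 1 reserved for padding.
table = sinusoidal_table(10, 8, padding_idx=1)
print(table.shape)           # torch.Size([10, 8])
print(table[1].abs().sum())  # tensor(0.) -- the padding row is zeroed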