Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test_utils.py 2.6 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
  1. # Copyright (c) 2017-present, Facebook, Inc.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the LICENSE file in
  5. # the root directory of this source tree. An additional grant of patent rights
  6. # can be found in the PATENTS file in the same directory.
  7. import unittest
  8. import torch
  9. from torch.autograd import Variable
  10. from fairseq import utils
  11. class TestUtils(unittest.TestCase):
  12. def test_convert_padding_direction(self):
  13. pad = 1
  14. left_pad = torch.LongTensor([
  15. [2, 3, 4, 5, 6],
  16. [1, 7, 8, 9, 10],
  17. [1, 1, 1, 11, 12],
  18. ])
  19. right_pad = torch.LongTensor([
  20. [2, 3, 4, 5, 6],
  21. [7, 8, 9, 10, 1],
  22. [11, 12, 1, 1, 1],
  23. ])
  24. self.assertAlmostEqual(
  25. right_pad,
  26. utils.convert_padding_direction(
  27. left_pad,
  28. pad,
  29. left_to_right=True,
  30. ),
  31. )
  32. self.assertAlmostEqual(
  33. left_pad,
  34. utils.convert_padding_direction(
  35. right_pad,
  36. pad,
  37. right_to_left=True,
  38. ),
  39. )
  40. def test_make_positions(self):
  41. pad = 1
  42. left_pad_input = torch.LongTensor([
  43. [9, 9, 9, 9, 9],
  44. [1, 9, 9, 9, 9],
  45. [1, 1, 1, 9, 9],
  46. ])
  47. left_pad_output = torch.LongTensor([
  48. [2, 3, 4, 5, 6],
  49. [1, 2, 3, 4, 5],
  50. [1, 1, 1, 2, 3],
  51. ])
  52. right_pad_input = torch.LongTensor([
  53. [9, 9, 9, 9, 9],
  54. [9, 9, 9, 9, 1],
  55. [9, 9, 1, 1, 1],
  56. ])
  57. right_pad_output = torch.LongTensor([
  58. [2, 3, 4, 5, 6],
  59. [2, 3, 4, 5, 1],
  60. [2, 3, 1, 1, 1],
  61. ])
  62. self.assertAlmostEqual(
  63. left_pad_output,
  64. utils.make_positions(left_pad_input, pad, left_pad=True),
  65. )
  66. self.assertAlmostEqual(
  67. right_pad_output,
  68. utils.make_positions(right_pad_input, pad, left_pad=False),
  69. )
  70. def test_make_variable(self):
  71. t = [{'k': torch.rand(5, 5)}]
  72. v = utils.make_variable(t)[0]['k']
  73. self.assertTrue(isinstance(v, Variable))
  74. self.assertFalse(v.data.is_cuda)
  75. v = utils.make_variable(t, cuda=True)[0]['k']
  76. self.assertEqual(v.data.is_cuda, torch.cuda.is_available())
  77. def assertAlmostEqual(self, t1, t2):
  78. self.assertEqual(t1.size(), t2.size(), "size mismatch")
  79. self.assertLess(utils.item((t1 - t2).abs().max()), 1e-4)
  80. if __name__ == '__main__':
  81. unittest.main()
Tip!

Press p or the left-arrow key to see the previous file, or n or the right-arrow key to see the next file

Comments

Loading...