Register
Login
Resources
Docs Blog Datasets Glossary Case Studies Tutorials & Webinars
Product
Data Engine LLMs Platform Enterprise
Pricing Explore
Connect to our Discord channel

test_convtbc.py 1.8 KB

You have to be logged in to leave a comment. Sign In
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
  1. # Copyright (c) 2017-present, Facebook, Inc.
  2. # All rights reserved.
  3. #
  4. # This source code is licensed under the license found in the LICENSE file in
  5. # the root directory of this source tree. An additional grant of patent rights
  6. # can be found in the PATENTS file in the same directory.
  7. import torch
  8. import unittest
  9. from fairseq.modules import ConvTBC
  10. import torch.nn as nn
  11. from torch.autograd import Variable
  12. class TestConvTBC(unittest.TestCase):
  13. def test_convtbc(self):
  14. # ksz, in_channels, out_channels
  15. conv_tbc = ConvTBC(4, 5, kernel_size=3, padding=1)
  16. # out_channels, in_channels, ksz
  17. conv1d = nn.Conv1d(4, 5, kernel_size=3, padding=1)
  18. conv_tbc.weight.data.copy_(conv1d.weight.data.transpose(0, 2))
  19. conv_tbc.bias.data.copy_(conv1d.bias.data)
  20. input_tbc = Variable(torch.randn(7, 2, 4), requires_grad=True)
  21. input1d = Variable(input_tbc.data.transpose(0, 1).transpose(1, 2), requires_grad=True)
  22. output_tbc = conv_tbc(input_tbc)
  23. output1d = conv1d(input1d)
  24. self.assertAlmostEqual(output_tbc.data.transpose(0, 1).transpose(1, 2), output1d.data)
  25. grad_tbc = torch.randn(output_tbc.size())
  26. grad1d = grad_tbc.transpose(0, 1).transpose(1, 2).contiguous()
  27. output_tbc.backward(grad_tbc)
  28. output1d.backward(grad1d)
  29. self.assertAlmostEqual(conv_tbc.weight.grad.data.transpose(0, 2), conv1d.weight.grad.data)
  30. self.assertAlmostEqual(conv_tbc.bias.grad.data, conv1d.bias.grad.data)
  31. self.assertAlmostEqual(input_tbc.grad.data.transpose(0, 1).transpose(1, 2), input1d.grad.data)
  32. def assertAlmostEqual(self, t1, t2):
  33. self.assertEqual(t1.size(), t2.size(), "size mismatch")
  34. self.assertLess((t1 - t2).abs().max(), 1e-4)
  35. if __name__ == '__main__':
  36. unittest.main()
Tip!

Press p or ← to see the previous file, or n or → to see the next file

Comments

Loading...