loss.py
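The function below computes a Barlow Twins style loss. This is a sketch of the objective it evaluates, written term by term from the code. The notation is assumed here and does not come from the file: N is the batch size, D is the feature dimension, ε is EPS, and z^A, z^B are the two N×D embeddings. The standard deviation uses N−1 because `torch.std` applies Bessel's correction by default.

```latex
\hat{z}_{n,i} = \frac{z_{n,i} - \mu_i}{\sigma_i + \epsilon}, \qquad
\mu_i = \frac{1}{N}\sum_{n=1}^{N} z_{n,i}, \qquad
\sigma_i = \sqrt{\frac{1}{N-1}\sum_{n=1}^{N} \left(z_{n,i} - \mu_i\right)^2}

C_{ij} = \frac{1}{N}\sum_{n=1}^{N} \hat{z}^{A}_{n,i}\,\hat{z}^{B}_{n,j}

\mathcal{L} = \sum_{i=1}^{D} \left(1 - C_{ii}\right)^2
  + \lambda \sum_{i=1}^{D} \sum_{j \neq i} C_{ij}^2,
\qquad \lambda = \frac{1}{D}
```

The first sum is the invariance term, which pulls each diagonal entry of C toward 1. The second is the redundancy-reduction term, which pushes off-diagonal entries toward 0.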

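import torch

EPS = 1e-15


def barlow_twins_loss(
    z_a: torch.Tensor,
    z_b: torch.Tensor,
) -> torch.Tensor:
    """Barlow Twins loss between two (batch_size, feature_dim) embeddings."""
    batch_size = z_a.size(0)
    feature_dim = z_a.size(1)
    # Weight of the redundancy-reduction (off-diagonal) term
    _lambda = 1 / feature_dim

    # Normalize each feature along the batch dimension (zero mean, unit std)
    z_a_norm = (z_a - z_a.mean(dim=0)) / (z_a.std(dim=0) + EPS)
    z_b_norm = (z_b - z_b.mean(dim=0)) / (z_b.std(dim=0) + EPS)

    # Cross-correlation matrix, shape (feature_dim, feature_dim)
    c = (z_a_norm.T @ z_b_norm) / batch_size

    # Invariance term pulls the diagonal toward 1; redundancy-reduction term
    # pushes the off-diagonal entries toward 0. The mask is created directly
    # on c's device to avoid a CPU-to-GPU copy when training on a GPU.
    off_diagonal_mask = ~torch.eye(feature_dim, dtype=torch.bool, device=c.device)
    loss = (
        (1 - c.diagonal()).pow(2).sum()
        + _lambda * c[off_diagonal_mask].pow(2).sum()
    )
    return loss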
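As a rough sanity check, this sketch calls the loss on two correlated "views" and on unrelated embeddings, then runs a backward pass. The function is repeated from loss.py so that the block runs on its own. The names `view_a`, `view_b` and `unrelated`, the 0.05 noise level, and the batch and feature sizes are illustrative assumptions and are not part of the original file.

```python
import torch

EPS = 1e-15


def barlow_twins_loss(z_a: torch.Tensor, z_b: torch.Tensor) -> torch.Tensor:
    # Same computation as loss.py, repeated so this sketch is self-contained.
    batch_size, feature_dim = z_a.shape
    _lambda = 1 / feature_dim
    z_a_norm = (z_a - z_a.mean(dim=0)) / (z_a.std(dim=0) + EPS)
    z_b_norm = (z_b - z_b.mean(dim=0)) / (z_b.std(dim=0) + EPS)
    c = (z_a_norm.T @ z_b_norm) / batch_size
    off_diagonal_mask = ~torch.eye(feature_dim, dtype=torch.bool, device=c.device)
    return (1 - c.diagonal()).pow(2).sum() + _lambda * c[off_diagonal_mask].pow(2).sum()


torch.manual_seed(0)
batch_size, feature_dim = 256, 32

# Hypothetical "two views" of one batch: view_b is view_a plus small noise.
view_a = torch.randn(batch_size, feature_dim)
view_b = view_a + 0.05 * torch.randn(batch_size, feature_dim)
# Independent embeddings for comparison.
unrelated = torch.randn(batch_size, feature_dim)

loss_similar = barlow_twins_loss(view_a, view_b)
loss_unrelated = barlow_twins_loss(view_a, unrelated)
print(f"similar views:   {loss_similar.item():.4f}")
print(f"unrelated views: {loss_unrelated.item():.4f}")

# The loss is differentiable, so it can drive an encoder's training step.
z = torch.randn(batch_size, feature_dim, requires_grad=True)
loss = barlow_twins_loss(z, view_b)
loss.backward()
print(z.grad.shape)
```

Correlated views should give a much smaller loss than unrelated ones, because their diagonal entries C_ii sit close to 1. The normalization uses the unbiased `std` but divides by `batch_size`. As a result, even identical inputs give C_ii = (N−1)/N rather than exactly 1, which leaves a small residual invariance term of D/N².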