
DenseNorm.py

from core.leras import nn
tf = nn.tf

class DenseNorm(nn.LayerBase):
    """Rescales each vector along the last axis to unit root-mean-square (RMS)."""
    def __init__(self, dense=False, eps=1e-06, dtype=None, **kwargs):
        # Stored on the layer, but not read by __call__ below.
        self.dense = dense
        if dtype is None:
            dtype = nn.floatx
        # Small constant added to the mean square to avoid division by zero.
        self.eps = tf.constant(eps, dtype=dtype, name="epsilon")
        super().__init__(**kwargs)

    def __call__(self, x):
        # x / sqrt(mean(x^2 over the last axis) + eps)
        return x * tf.rsqrt(tf.reduce_mean(tf.square(x), axis=-1, keepdims=True) + self.eps)

nn.DenseNorm = DenseNorm
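To make the arithmetic in `__call__` concrete, here is a minimal pure-Python sketch of the same formula, y = x / sqrt(mean(x^2) + eps), applied row by row to a 2-D input. The function name `dense_norm` and the list-of-lists input are illustrative assumptions; this is not part of leras and does not use TensorFlow.

```python
import math
from typing import List


def dense_norm(x: List[List[float]], eps: float = 1e-6) -> List[List[float]]:
    """Hypothetical pure-Python mirror of DenseNorm.__call__ for a 2-D input.

    Each row (the last axis) is multiplied by 1 / sqrt(mean(row^2) + eps).
    """
    out = []
    for row in x:
        mean_sq = sum(v * v for v in row) / len(row)  # tf.reduce_mean(tf.square(x), axis=-1)
        scale = 1.0 / math.sqrt(mean_sq + eps)        # tf.rsqrt(... + eps)
        out.append([v * scale for v in row])
    return out


y = dense_norm([[3.0, 4.0], [0.0, 0.0]])
print(y)  # first row ≈ [0.8485, 1.1314]; the all-zero row stays zero
```

Each non-zero output row has an RMS of approximately 1. The `eps` term keeps an all-zero row finite instead of dividing by zero.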