活性化関数¶
Provides custom activation functions.
Available in both ml_networks.torch.activations (PyTorch) and ml_networks.jax.activations (JAX).
In addition to the activation functions built into PyTorch, the following custom activation functions are available.
Activation¶
Activation ¶
Bases: Module
Generic activation function.
Source code in src/ml_networks/torch/activations.py
REReLU¶
Reparametrized ReLU: a ReLU whose backward pass follows GELU (or a similar smooth function). See paper
REReLU ¶
Bases: Module
Reparametrized ReLU activation function. Its backward pass is differentiable.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `reparametarize_fn` | `str` | Reparametrization function. Default is GELU. | `'gelu'` |
References
https://openreview.net/forum?id=lNCnZwcH5Z
Examples:
>>> rerelu = REReLU()
>>> x = torch.randn(1, 3)
>>> output = rerelu(x)
>>> output.shape
torch.Size([1, 3])
Source code in src/ml_networks/torch/activations.py
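To illustrate the idea (a conceptual sketch, not the library's implementation): the forward pass outputs an ordinary ReLU, while the gradient used during backpropagation comes from the reparametrization function, GELU by default. The two quantities involved can be computed in pure Python:

```python
import math

def relu(x):
    return max(x, 0.0)

def gelu_grad(x):
    # d/dx GELU(x) = Phi(x) + x * phi(x), where Phi/phi are the
    # standard normal CDF/PDF (exact, erf-based GELU).
    phi = math.exp(-x * x / 2.0) / math.sqrt(2.0 * math.pi)
    Phi = 0.5 * (1.0 + math.erf(x / math.sqrt(2.0)))
    return Phi + x * phi

# For a negative input, plain ReLU outputs 0 and would pass zero
# gradient; the reparametrized backward pass uses GELU's gradient,
# which stays nonzero and keeps the unit trainable.
x = -0.5
forward_value = relu(x)        # 0.0 -- ordinary ReLU output
surrogate_grad = gelu_grad(x)  # nonzero GELU gradient
```

The exact form of the surrogate used by REReLU is specified in the referenced paper; this sketch only demonstrates why a smooth backward function changes the gradient flow.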
SiGLU¶
SiLU + GLU: an activation function combining SiLU (Swish) with GLU. See paper
SiGLU ¶
Bases: Module
SiGLU activation function.
This is equivalent to SwiGLU (Swish variant of Gated Linear Unit) activation function.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `dim` | `int` | Dimension along which to split the tensor. Default is -1. | `-1` |
References
https://arxiv.org/abs/2102.11972
Examples:
>>> siglu = SiGLU()
>>> x = torch.randn(1, 4)
>>> output = siglu(x)
>>> output.shape
torch.Size([1, 2])
>>> siglu = SiGLU(dim=0)
>>> x = torch.randn(4, 1)
>>> output = siglu(x)
>>> output.shape
torch.Size([2, 1])
Source code in src/ml_networks/torch/activations.py
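SwiGLU is commonly defined by splitting the input into two halves `a` and `b` along `dim` and gating: `SiLU(a) * b`, which is why the examples above halve the size of the split dimension. A minimal pure-Python sketch of this split-and-gate pattern (a standalone illustration, not the module's code):

```python
import math

def silu(x):
    # SiLU / Swish: x * sigmoid(x)
    return x * (1.0 / (1.0 + math.exp(-x)))

def siglu(vec):
    # Split the vector into two equal halves, then gate the second
    # half with SiLU of the first: out_i = silu(a_i) * b_i.
    half = len(vec) // 2
    a, b = vec[:half], vec[half:]
    return [silu(ai) * bi for ai, bi in zip(a, b)]

out = siglu([1.0, -2.0, 0.5, 3.0])
# len(out) == 2: the output has half the input width
```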
CRReLU¶
Correction Regularized ReLU: a regularized variant of ReLU. See paper
CRReLU ¶
Bases: Module
Correction Regularized ReLU activation function. This is a variant of ReLU activation function.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `lr` | `float` | Learning rate. Default is 0.01. | `0.01` |
References
https://openreview.net/forum?id=7TZYM6Hm9p
Examples:
>>> crrelu = CRReLU()
>>> x = torch.randn(1, 3)
>>> output = crrelu(x)
>>> output.shape
torch.Size([1, 3])
Source code in src/ml_networks/torch/activations.py
TanhExp¶
Positioned as an improvement over Mish. See article
TanhExp ¶
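Per the referenced article, TanhExp is defined as TanhExp(x) = x · tanh(eˣ): smooth like Mish, near-identity for large positive inputs. A minimal pure-Python sketch (illustrative, not the module's code):

```python
import math

def tanh_exp(x):
    # TanhExp(x) = x * tanh(exp(x)); for large positive x,
    # tanh(exp(x)) saturates at 1, so the function approaches x.
    return x * math.tanh(math.exp(x))

y_zero = tanh_exp(0.0)   # 0.0, like ReLU at the origin
y_large = tanh_exp(5.0)  # approximately 5.0 (near-identity)
```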
L2Norm¶
L2 normalization layer. Projects features onto the unit hypersphere.
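L2 normalization divides each feature vector by its L2 norm, so every output has unit length. A pure-Python sketch of that projection (the epsilon guard here is an assumption for numerical stability, not taken from the module):

```python
import math

def l2_normalize(vec, eps=1e-12):
    # Project onto the unit hypersphere: v / max(||v||_2, eps).
    # eps avoids division by zero for all-zero inputs (assumed detail).
    norm = math.sqrt(sum(v * v for v in vec))
    return [v / max(norm, eps) for v in vec]

unit = l2_normalize([3.0, 4.0])  # [0.6, 0.8], a unit vector
```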