Activation Functions

Provides custom activation functions.

These are available in both ml_networks.torch.activations (PyTorch) and ml_networks.jax.activations (JAX).

In addition to the activation functions implemented in PyTorch, the following custom activation functions are available.

Activation

Activation(activation, **kwargs)

Bases: Module

Generic activation function.

Source code in src/ml_networks/torch/activations.py
def __init__(self, activation: str, **kwargs: Any) -> None:
    super().__init__()
    if "glu" not in activation.lower():
        kwargs.pop("dim", None)
    try:
        self.activation = getattr(nn, activation)(**kwargs)
    except AttributeError as err:
        if activation == "TanhExp":
            self.activation = TanhExp()
        elif activation == "REReLU":
            self.activation = REReLU(**kwargs)
        elif activation in {"SiGLU", "SwiGLU"}:
            self.activation = SiGLU(**kwargs)
        elif activation == "CRReLU":
            self.activation = CRReLU(**kwargs)
        elif activation == "L2Norm":
            self.activation = L2Norm()
        else:
            msg = f"Activation: '{activation}' is not implemented yet."
            raise NotImplementedError(msg) from err
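The lookup-then-fallback dispatch above can be sketched in plain Python. This is a minimal stand-in, assuming a toy `ns` namespace in place of `torch.nn`; `make_activation` is an illustrative name, not part of ml_networks:

```python
import math
import types

# Toy stand-in for torch.nn: a namespace of activation "classes"
# (each entry is a zero-argument factory returning a callable).
ns = types.SimpleNamespace(ReLU=lambda: (lambda x: max(x, 0.0)))

def make_activation(name: str):
    """Try the standard namespace first; fall back to custom implementations."""
    try:
        return getattr(ns, name)()
    except AttributeError:
        if name == "TanhExp":
            return lambda x: x * math.tanh(math.exp(x))
        msg = f"Activation: '{name}' is not implemented yet."
        raise NotImplementedError(msg)
```

The same pattern lets user code request any activation by string name without caring whether it is a stock PyTorch module or a custom one.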

Attributes

activation instance-attribute

activation = getattr(nn, activation)(**kwargs)

Functions

forward

forward(x)
Source code in src/ml_networks/torch/activations.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    return self.activation(x)

REReLU

Reparametrized ReLU: a ReLU whose backward pass follows a smoother function such as GELU. See paper

REReLU

REReLU(reparametarize_fn='gelu')

Bases: Module

Reparametrized ReLU activation function. The forward pass is an exact ReLU; the backward pass follows the differentiable reparametrization function.

Parameters:

reparametarize_fn (str): Reparametrization function. Default is 'gelu'.
References

https://openreview.net/forum?id=lNCnZwcH5Z

Examples:

>>> rerelu = REReLU()
>>> x = torch.randn(1, 3)
>>> output = rerelu(x)
>>> output.shape
torch.Size([1, 3])
Source code in src/ml_networks/torch/activations.py
def __init__(self, reparametarize_fn: str = "gelu") -> None:
    super().__init__()
    reparametarize_fn = reparametarize_fn.lower()
    self.reparametarize_fn = getattr(F, reparametarize_fn)

Attributes

reparametarize_fn instance-attribute

reparametarize_fn = getattr(functional, reparametarize_fn)

Functions

forward

forward(x)
Source code in src/ml_networks/torch/activations.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    return F.relu(x).detach() + self.reparametarize_fn(x) - self.reparametarize_fn(x).detach()
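The detach trick above makes the forward value exactly ReLU(x), while the gradient flows only through the reparametrization term. A dependency-free numeric sketch, using the tanh approximation of GELU (`rerelu_value` and `gelu_grad` are illustrative names):

```python
import math

def gelu(x: float) -> float:
    # tanh approximation of GELU
    return 0.5 * x * (1.0 + math.tanh(math.sqrt(2.0 / math.pi) * (x + 0.044715 * x**3)))

def relu(x: float) -> float:
    return max(x, 0.0)

def rerelu_value(x: float) -> float:
    # The two GELU terms cancel exactly, so the forward value is plain ReLU.
    return relu(x) + gelu(x) - gelu(x)

def gelu_grad(x: float, h: float = 1e-6) -> float:
    # Central finite difference: this is the gradient REReLU sends backward,
    # since the detached relu(x) and gelu(x) terms contribute no gradient.
    return (gelu(x + h) - gelu(x - h)) / (2.0 * h)
```

This gives ReLU's sparsity at inference time with GELU's smooth gradient during training.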

SiGLU

SiLU + GLU: an activation function that combines SiLU (Swish) with the Gated Linear Unit (GLU). See paper

SiGLU

SiGLU(dim=-1)

Bases: Module

SiGLU activation function.

This is equivalent to the SwiGLU (Swish variant of the Gated Linear Unit) activation function.

Parameters:

dim (int): Dimension along which to split the tensor. Default is -1.
References

https://arxiv.org/abs/2102.11972

Examples:

>>> siglu = SiGLU()
>>> x = torch.randn(1, 4)
>>> output = siglu(x)
>>> output.shape
torch.Size([1, 2])
>>> siglu = SiGLU(dim=0)
>>> x = torch.randn(4, 1)
>>> output = siglu(x)
>>> output.shape
torch.Size([2, 1])
Source code in src/ml_networks/torch/activations.py
def __init__(self, dim: int = -1) -> None:
    super().__init__()
    self.dim = dim

Attributes

dim instance-attribute

dim = dim

Functions

forward

forward(x)
Source code in src/ml_networks/torch/activations.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    x1, x2 = x.chunk(2, dim=self.dim)
    return x1 * F.silu(x2)
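Because the input is chunked in two, SiGLU halves the split dimension (as the doctest shapes above show): the first half is gated by SiLU of the second half. A list-based sketch of the same math (`siglu_1d` is an illustrative name, not the library API):

```python
import math

def silu(x: float) -> float:
    # SiLU / Swish: x * sigmoid(x)
    return x / (1.0 + math.exp(-x))

def siglu_1d(vec: list[float]) -> list[float]:
    # Split in half; gate the first half with SiLU of the second half.
    half = len(vec) // 2
    x1, x2 = vec[:half], vec[half:]
    return [a * silu(b) for a, b in zip(x1, x2)]
```

When sizing layers, remember that the dimension fed into SiGLU must be even, and the output along that dimension is half the input.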

CRReLU

Correction Regularized ReLU: a ReLU with a regularizing correction term. See paper

CRReLU

CRReLU(lr=0.01)

Bases: Module

Correction Regularized ReLU activation function. This is a variant of the ReLU activation function.

Parameters:

lr (float): Learning rate (used as the initial value of the learnable correction coefficient). Default is 0.01.
References

https://openreview.net/forum?id=7TZYM6Hm9p

Examples:

>>> crrelu = CRReLU()
>>> x = torch.randn(1, 3)
>>> output = crrelu(x)
>>> output.shape
torch.Size([1, 3])
Source code in src/ml_networks/torch/activations.py
def __init__(self, lr: float = 0.01) -> None:
    super().__init__()
    self.lr = nn.Parameter(torch.tensor(lr).float())

Attributes

lr instance-attribute

lr = Parameter(tensor(lr).float())

Functions

forward

forward(x)
Source code in src/ml_networks/torch/activations.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    return F.relu(x) + self.lr * x * torch.exp(-(x**2) / 2)
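The correction term adds a small Gaussian-weighted signal concentrated around zero, so negative inputs near zero are not completely cut off. A scalar sketch of the same formula (the fixed `lr` argument here stands in for the learnable nn.Parameter):

```python
import math

def crrelu(x: float, lr: float = 0.01) -> float:
    # ReLU plus a correction term lr * x * exp(-x^2 / 2); the Gaussian factor
    # keeps the correction local to zero and vanishing for large |x|.
    return max(x, 0.0) + lr * x * math.exp(-(x**2) / 2.0)
```

Far from zero the function is indistinguishable from plain ReLU; only the region around the kink is modified.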

TanhExp

Positioned as an improved version of Mish. See article

TanhExp

Bases: Module

TanhExp activation function.

Examples:

>>> tanhexp = TanhExp()
>>> x = torch.randn(1, 3)
>>> output = tanhexp(x)
>>> output.shape
torch.Size([1, 3])

Functions

forward

forward(x)
Source code in src/ml_networks/torch/activations.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    return TanhExpBase.apply(x)
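The custom autograd Function (TanhExpBase) is not shown on this page; the forward math it computes is x * tanh(exp(x)). A dependency-free scalar sketch of that forward pass:

```python
import math

def tanh_exp(x: float) -> float:
    # TanhExp: x * tanh(exp(x)). Nearly the identity for large positive x,
    # with a small bounded negative dip for x < 0 (similar in spirit to Mish).
    return x * math.tanh(math.exp(x))
```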

L2Norm

L2 normalization layer. Projects features onto the unit hypersphere.

L2Norm

Bases: Module

L2 Normalization layer.

Examples:

>>> l2norm = L2Norm()
>>> x = torch.randn(2, 3)
>>> output = l2norm(x)
>>> output.shape
torch.Size([2, 3])

Functions

forward

forward(x)
Source code in src/ml_networks/torch/activations.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    return F.normalize(x, p=2, dim=-1)
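F.normalize divides by the L2 norm along the last dimension, clamped by a small eps. A list-based sketch of the same operation (`l2_normalize` is an illustrative name; the eps clamp mirrors torch's default guard against all-zero inputs):

```python
import math

def l2_normalize(vec: list[float], eps: float = 1e-12) -> list[float]:
    # Divide by max(||vec||_2, eps) so the result lies on the unit hypersphere;
    # eps prevents division by zero for the all-zeros vector.
    norm = math.sqrt(sum(v * v for v in vec))
    return [v / max(norm, eps) for v in vec]
```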