Skip to content

API リファレンス

ml-networksの完全なAPIリファレンスです。

モジュール一覧

共通モジュール

PyTorch (ml_networks.torch)

  • レイヤー - 基本的なレイヤー(MLP、Conv、Attention、Transformerなど)
  • ビジョン - ビジョン関連のモジュール(Encoder、Decoder、ConvNet、ResNet、ViTなど)
  • 分布 - 分布関連のクラスと関数
  • 損失関数 - 損失関数
  • 活性化関数 - カスタム活性化関数
  • UNet - 条件付きUNetクラス
  • その他 - HyperNet、ContrastiveLearning、BaseModule、ProgressBarCallback

JAX (ml_networks.jax)

  • JAX API - JAX(Flax NNX)実装のAPIリファレンス

主要なクラスと関数

レイヤー

MLPLayer

MLPLayer(input_dim, output_dim, cfg)

Bases: LightningModule

Multi-layer perceptron layer.

Parameters:

Name Type Description Default
input_dim int

Input dimension.

required
output_dim int

Output dimension.

required
cfg MLPConfig

MLP layer configuration.

required

Examples:

>>> cfg = MLPConfig(
...     hidden_dim=16,
...     n_layers=3,
...     output_activation="ReLU",
...     linear_cfg=LinearConfig(
...         activation="ReLU",
...         norm="layer",
...         norm_cfg={"eps": 1e-05, "elementwise_affine": True, "bias": True},
...         dropout=0.1,
...         norm_first=False,
...         bias=True
...     )
... )
>>> mlp = MLPLayer(32, 16, cfg)
>>> x = torch.randn(1, 32)
>>> output = mlp(x)
>>> output.shape
torch.Size([1, 16])
Source code in src/ml_networks/torch/layers.py
def __init__(
    self,
    input_dim: int,
    output_dim: int,
    cfg: MLPConfig,
) -> None:
    """
    Build a multi-layer perceptron from the given configuration.

    Parameters
    ----------
    input_dim : int
        Input dimension.
    output_dim : int
        Output dimension.
    cfg : MLPConfig
        MLP layer configuration.

    """
    super().__init__()
    # Keep a private copy so later mutations of the caller's cfg cannot
    # affect this module.
    self.cfg = deepcopy(cfg)
    self.input_dim, self.output_dim = input_dim, output_dim
    self.hidden_dim, self.n_layers = cfg.hidden_dim, cfg.n_layers
    # Assemble the stack of linear blocks last, once all dimensions are
    # recorded on the instance.
    self.dense = self._build_dense()

Attributes

cfg instance-attribute

cfg = deepcopy(cfg)

dense instance-attribute

dense = _build_dense()

hidden_dim instance-attribute

hidden_dim = hidden_dim

input_dim instance-attribute

input_dim = input_dim

n_layers instance-attribute

n_layers = n_layers

output_dim instance-attribute

output_dim = output_dim

Functions

forward

forward(x)

Forward pass.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (*, input_dim)

required

Returns:

Type Description
Tensor

Output tensor of shape (*, output_dim)

Source code in src/ml_networks/torch/layers.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Apply the MLP to the input.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape (*, input_dim)

    Returns
    -------
    torch.Tensor
        Output tensor of shape (*, output_dim)

    """
    out = self.dense(x)
    return out

LinearNormActivation

LinearNormActivation(input_dim, output_dim, cfg)

Bases: Module

Linear layer with normalization and activation, and dropouts.

Parameters:

Name Type Description Default
input_dim int

Input dimension.

required
output_dim int

Output dimension.

required
cfg LinearConfig

Linear layer configuration.

required
References

LayerNorm: https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html RMSNorm: https://pytorch.org/docs/stable/generated/torch.nn.RMSNorm.html Linear: https://pytorch.org/docs/stable/generated/torch.nn.Linear.html Dropout: https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html

Examples:

>>> cfg = LinearConfig(
...     activation="ReLU",
...     norm="layer",
...     norm_cfg={"eps": 1e-05, "elementwise_affine": True, "bias": True},
...     dropout=0.1,
...     norm_first=False,
...     bias=True
... )
>>> linear = LinearNormActivation(32, 16, cfg)
>>> linear
LinearNormActivation(
  (linear): Linear(in_features=32, out_features=16, bias=True)
  (norm): LayerNorm((16,), eps=1e-05, elementwise_affine=True)
  (activation): Activation(
    (activation): ReLU()
  )
  (dropout): Dropout(p=0.1, inplace=False)
)
>>> x = torch.randn(1, 32)
>>> output = linear(x)
>>> output.shape
torch.Size([1, 16])
>>> cfg = LinearConfig(
...     activation="SiGLU",
...     norm="none",
...     norm_cfg={},
...     dropout=0.0,
...     norm_first=True,
...     bias=True
... )
>>> linear = LinearNormActivation(32, 16, cfg)
>>> # If activation includes "glu", linear output_dim is doubled to adjust actual output_dim.
>>> linear
LinearNormActivation(
  (linear): Linear(in_features=32, out_features=32, bias=True)
  (norm): Identity()
  (activation): Activation(
    (activation): SiGLU()
  )
  (dropout): Identity()
)
>>> x = torch.randn(1, 32)
>>> output = linear(x)
>>> output.shape
torch.Size([1, 16])
Source code in src/ml_networks/torch/layers.py
def __init__(
    self,
    input_dim: int,
    output_dim: int,
    cfg: LinearConfig,
) -> None:
    """
    Build a linear block with optional normalization, activation and dropout.

    Parameters
    ----------
    input_dim : int
        Input dimension.
    output_dim : int
        Output dimension.
    cfg : LinearConfig
        Linear layer configuration.

    """
    super().__init__()
    # GLU-style activations halve the channel count, so the linear layer
    # has to emit twice as many features to end up at output_dim.
    is_glu = "glu" in cfg.activation.lower()
    linear_out = output_dim * 2 if is_glu else output_dim
    self.linear = nn.Linear(input_dim, linear_out, bias=cfg.bias)

    # When normalization runs before the linear layer it sees the input
    # features; otherwise it sees the (possibly doubled) linear output.
    normalized_shape = input_dim if cfg.norm_first else linear_out

    norm_cfg = dict(cfg.norm_cfg)
    norm_cfg["normalized_shape"] = normalized_shape
    self.norm = get_norm(cfg.norm, **norm_cfg)
    self.activation = Activation(cfg.activation)
    self.dropout: nn.Module = nn.Dropout(cfg.dropout) if cfg.dropout > 0 else nn.Identity()
    self.norm_first = cfg.norm_first

Attributes

activation instance-attribute

activation = Activation(activation)

dropout instance-attribute

dropout

linear instance-attribute

linear = Linear(input_dim, output_dim * 2 if 'glu' in lower() else output_dim, bias=bias)

norm instance-attribute

norm = get_norm(norm, **norm_cfg)

norm_first instance-attribute

norm_first = norm_first

Functions

forward

forward(x)

Forward pass.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (*, input_dim)

required

Returns:

Type Description
Tensor

Output tensor of shape (*, output_dim)

Source code in src/ml_networks/torch/layers.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Forward pass.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape (*, input_dim)

    Returns
    -------
    torch.Tensor
        Output tensor of shape (*, output_dim)

    """
    # Normalization is applied either before or after the linear
    # projection depending on the configuration; activation and dropout
    # always come last.
    if self.norm_first:
        x = self.linear(self.norm(x))
    else:
        x = self.norm(self.linear(x))
    return self.dropout(self.activation(x))

ConvNormActivation

ConvNormActivation(in_channels, out_channels, cfg)

Bases: Module

Convolutional layer with normalization and activation, and dropouts.

Parameters:

Name Type Description Default
in_channels int

Input channels.

required
out_channels int

Output channels.

required
cfg ConvConfig

Convolutional layer configuration.

required
References

PixelShuffle: https://pytorch.org/docs/stable/generated/torch.nn.PixelShuffle.html PixelUnshuffle: https://pytorch.org/docs/stable/generated/torch.nn.PixelUnshuffle.html BatchNorm2d: https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html GroupNorm: https://pytorch.org/docs/stable/generated/torch.nn.GroupNorm.html LayerNorm: https://pytorch.org/docs/stable/generated/torch.nn.LayerNorm.html InstanceNorm2d: https://pytorch.org/docs/stable/generated/torch.nn.InstanceNorm2d.html Conv2d: https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html Dropout: https://pytorch.org/docs/stable/generated/torch.nn.Dropout.html

Examples:

>>> cfg = ConvConfig(
...     activation="ReLU",
...     kernel_size=3,
...     stride=1,
...     padding=1,
...     dilation=1,
...     groups=1,
...     bias=True,
...     dropout=0.1,
...     norm="batch",
...     norm_cfg={"affine": True, "track_running_stats": True},
...     scale_factor=0
... )
>>> conv = ConvNormActivation(3, 16, cfg)
>>> x = torch.randn(1, 3, 32, 32)
>>> output = conv(x)
>>> output.shape
torch.Size([1, 16, 32, 32])
>>> cfg = ConvConfig(
...     activation="SiGLU",
...     kernel_size=3,
...     stride=1,
...     padding=1,
...     dilation=1,
...     groups=1,
...     bias=True,
...     dropout=0.0,
...     norm="none",
...     norm_cfg={},
...     scale_factor=2
... )
>>> conv = ConvNormActivation(3, 16, cfg)
>>> x = torch.randn(1, 3, 32, 32)
>>> output = conv(x)
>>> output.shape
torch.Size([1, 16, 64, 64])
Source code in src/ml_networks/torch/layers.py
def __init__(
    self,
    in_channels: int,
    out_channels: int,
    cfg: ConvConfig,
) -> None:
    """
    Build a conv block with optional normalization, activation,
    pixel (un)shuffle and dropout.

    Parameters
    ----------
    in_channels : int
        Input channels.
    out_channels : int
        Output channels.
    cfg : ConvConfig
        Convolutional layer configuration.

    """
    super().__init__()

    # GLU-style activations consume twice the channels; pixel shuffle /
    # unshuffle rescales the channel count by scale_factor ** 2.
    out_channels_ = out_channels
    if "glu" in cfg.activation.lower():
        out_channels_ *= 2
    if cfg.scale_factor > 0:
        out_channels_ *= abs(cfg.scale_factor) ** 2
    elif cfg.scale_factor < 0:
        out_channels_ //= abs(cfg.scale_factor) ** 2
    self.conv = nn.Conv2d(
        in_channels=in_channels,
        out_channels=out_channels_,
        kernel_size=cfg.kernel_size,
        stride=cfg.stride,
        padding=cfg.padding,
        dilation=cfg.dilation,
        groups=cfg.groups,
        bias=cfg.bias,
        padding_mode=cfg.padding_mode,
    )
    norm_cfg = dict(cfg.norm_cfg) if cfg.norm_cfg else {}
    if cfg.norm == "group":
        norm_cfg["num_channels"] = in_channels if cfg.norm_first else out_channels_
    elif cfg.norm != "none":
        # BUGFIX: when norm_first is set, forward applies the norm to the
        # *input* tensor (before the conv), so the norm must be sized with
        # in_channels.  Previously out_channels_ was always used here,
        # which fails at runtime for norm_first configurations whenever
        # in_channels != out_channels_ (the group-norm branch already
        # handled norm_first correctly).
        norm_cfg["num_features"] = in_channels if cfg.norm_first else out_channels_

    norm_type: Literal["layer", "rms", "group", "batch2d", "batch1d", "none"] = (
        "batch2d" if cfg.norm == "batch" else cfg.norm
    )  # type: ignore[assignment]
    self.norm = get_norm(norm_type, **norm_cfg)
    # Positive scale_factor upsamples via PixelShuffle, negative
    # downsamples via PixelUnshuffle, zero disables rescaling.
    self.pixel_shuffle: nn.Module
    if cfg.scale_factor > 0:
        self.pixel_shuffle = nn.PixelShuffle(cfg.scale_factor)
    elif cfg.scale_factor < 0:
        self.pixel_shuffle = nn.PixelUnshuffle(abs(cfg.scale_factor))
    else:
        self.pixel_shuffle = nn.Identity()
    self.activation = Activation(cfg.activation, dim=-3)
    self.dropout: nn.Module = nn.Dropout(cfg.dropout) if cfg.dropout > 0 else nn.Identity()
    self.norm_first = cfg.norm_first

Attributes

activation instance-attribute

activation = Activation(activation, dim=-3)

conv instance-attribute

conv = Conv2d(in_channels=in_channels, out_channels=out_channels_, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias, padding_mode=padding_mode)

dropout instance-attribute

dropout = Dropout(dropout) if dropout > 0 else Identity()

norm instance-attribute

norm = get_norm(norm_type, **norm_cfg)

norm_first instance-attribute

norm_first = norm_first

pixel_shuffle instance-attribute

pixel_shuffle

Functions

forward

forward(x)

Forward pass.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (B, in_channels, H, W) or (in_channels, H, W)

required

Returns:

Type Description
Tensor

Output tensor of shape (B, out_channels, H', W') or (out_channels, H', W')

H' and W' are calculated as follows:
H' = (H + 2*padding - dilation * (kernel_size - 1) - 1) // stride + 1
H' = H' * scale_factor if scale_factor > 0 else H' // abs(scale_factor) if scale_factor < 0 else H'
W' = (W + 2*padding - dilation * (kernel_size - 1) - 1) // stride + 1
W' = W' * scale_factor if scale_factor > 0 else W' // abs(scale_factor) if scale_factor < 0 else W'
Source code in src/ml_networks/torch/layers.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Forward pass.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape (B, in_channels, H, W) or (in_channels, H, W)

    Returns
    -------
    torch.Tensor
        Output tensor of shape (B, out_channels, H', W') or (out_channels, H', W')
    H' and W' are calculated as follows:
    H' = (H + 2*padding - dilation * (kernel_size - 1) - 1) // stride + 1
    H' = H' * scale_factor if scale_factor > 0 else H' // abs(scale_factor) if scale_factor < 0 else H'
    W' = (W + 2*padding - dilation * (kernel_size - 1) - 1) // stride + 1
    W' = W' * scale_factor if scale_factor > 0 else W' // abs(scale_factor) if scale_factor < 0 else W'

    """
    # The norm wraps either the input or the conv output; pixel shuffle,
    # activation and dropout always follow in that order.
    if self.norm_first:
        x = self.conv(self.norm(x))
    else:
        x = self.norm(self.conv(x))
    x = self.pixel_shuffle(x)
    x = self.activation(x)
    return self.dropout(x)

ConvTransposeNormActivation

ConvTransposeNormActivation(in_channels, out_channels, cfg)

Bases: Module

Transposed convolutional layer with normalization and activation, and dropouts.

Parameters:

Name Type Description Default
in_channels int

Input channels.

required
out_channels int

Output channels.

required
cfg ConvConfig

Convolutional layer configuration.

required

Examples:

>>> cfg = ConvConfig(
...     activation="ReLU",
...     kernel_size=3,
...     stride=1,
...     padding=1,
...     output_padding=0,
...     dilation=1,
...     groups=1,
...     bias=True,
...     dropout=0.1,
...     norm="batch",
...     norm_cfg={"affine": True, "track_running_stats": True}
... )
>>> conv = ConvTransposeNormActivation(3, 16, cfg)
>>> x = torch.randn(1, 3, 32, 32)
>>> output = conv(x)
>>> output.shape
torch.Size([1, 16, 32, 32])
>>> cfg = ConvConfig(
...     activation="SiGLU",
...     kernel_size=3,
...     stride=1,
...     padding=1,
...     output_padding=0,
...     dilation=1,
...     groups=1,
...     bias=True,
...     dropout=0.0,
...     norm="none",
...     norm_cfg={}
... )
>>> conv = ConvTransposeNormActivation(3, 16, cfg)
>>> x = torch.randn(1, 3, 32, 32)
>>> output = conv(x)
>>> output.shape
torch.Size([1, 16, 32, 32])
Source code in src/ml_networks/torch/layers.py
def __init__(
    self,
    in_channels: int,
    out_channels: int,
    cfg: ConvConfig,
) -> None:
    """
    Build a transposed-conv block with optional normalization,
    activation and dropout.

    Parameters
    ----------
    in_channels : int
        Input channels.
    out_channels : int
        Output channels.
    cfg : ConvConfig
        Convolutional layer configuration.

    """
    super().__init__()

    # GLU-style activations halve the channel count, so the transposed
    # convolution must produce twice as many channels.
    conv_out = out_channels * 2 if "glu" in cfg.activation.lower() else out_channels
    self.conv = nn.ConvTranspose2d(
        in_channels,
        conv_out,
        cfg.kernel_size,
        cfg.stride,
        cfg.padding,
        cfg.output_padding,
        cfg.groups,
        bias=cfg.bias,
        dilation=cfg.dilation,
    )
    norm_cfg = dict(cfg.norm_cfg) if cfg.norm_cfg else {}
    # Group norm is configured through num_channels, every other norm
    # (except "none") through num_features.
    if cfg.norm == "group":
        norm_cfg["num_channels"] = conv_out
    elif cfg.norm != "none":
        norm_cfg["num_features"] = conv_out
    norm_type: Literal["layer", "rms", "group", "batch2d", "batch1d", "none"] = (
        "batch2d" if cfg.norm == "batch" else cfg.norm
    )  # type: ignore[assignment]
    self.norm = get_norm(norm_type, **norm_cfg)
    self.activation = Activation(cfg.activation, dim=-3)
    self.dropout: nn.Module = nn.Dropout(cfg.dropout) if cfg.dropout > 0 else nn.Identity()

Attributes

activation instance-attribute

activation = Activation(activation, dim=-3)

conv instance-attribute

conv = ConvTranspose2d(in_channels, out_channels * 2 if 'glu' in lower() else out_channels, kernel_size, stride, padding, output_padding, groups, bias=bias, dilation=dilation)

dropout instance-attribute

dropout = Dropout(dropout) if dropout > 0 else Identity()

norm instance-attribute

norm = get_norm(norm_type, **norm_cfg)

Functions

forward

forward(x)

Forward pass.

Parameters:

Name Type Description Default
x Tensor

Input tensor of shape (B, in_channels, H, W) or (in_channels, H, W)

required

Returns:

Type Description
Tensor

Output tensor of shape (B, out_channels, H', W') or (out_channels, H', W')

H' and W' are calculated as follows:
H' = (H - 1) * stride - 2 * padding + kernel_size + output_padding
W' = (W - 1) * stride - 2 * padding + kernel_size + output_padding
Source code in src/ml_networks/torch/layers.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Forward pass.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor of shape (B, in_channels, H, W) or (in_channels, H, W)

    Returns
    -------
    torch.Tensor
        Output tensor of shape (B, out_channels, H', W') or (out_channels, H', W')
    H' and W' are calculated as follows:
    H' = (H - 1) * stride - 2 * padding + kernel_size + output_padding
    W' = (W - 1) * stride - 2 * padding + kernel_size + output_padding

    """
    # Fixed pipeline: transposed conv -> norm -> activation -> dropout.
    return self.dropout(self.activation(self.norm(self.conv(x))))

ビジョン

Encoder

Encoder(feature_dim, obs_shape, backbone_cfg, fc_cfg=None)

Bases: BaseModule

Encoder with various architectures.

Parameters:

Name Type Description Default
feature_dim int | tuple[int, int, int]

Dimension of the feature tensor. If int, Encoder includes full connection layer to downsample the feature tensor. Otherwise, Encoder does not include full connection layer and directly process with backbone network.

required
obs_shape tuple[int, int, int]

shape of the input tensor

required
backbone_cfg ViTConfig | ConvNetConfig | ResNetConfig

configuration of the network

required
fc_cfg MLPConfig | LinearConfig | SpatialSoftmaxConfig | None

configuration of the full connection layer. If feature_dim is tuple, fc_cfg is ignored. If feature_dim is int, fc_cfg must be provided. Default is None.

None

Examples:

>>> feature_dim = 128
>>> obs_shape = (3, 64, 64)
>>> cfg = ConvNetConfig(
...     channels=[16, 32, 64],
...     conv_cfgs=[
...         ConvConfig(kernel_size=3, stride=2, padding=1, activation="ReLU", norm="batch", dropout=0.0),
...         ConvConfig(kernel_size=3, stride=2, padding=1, activation="ReLU", norm="batch", dropout=0.0),
...         ConvConfig(kernel_size=3, stride=2, padding=1, activation="ReLU", norm="batch", dropout=0.0),
...     ]
... )
>>> fc_cfg = LinearConfig(
...     activation="ReLU",
...     bias=True
... )
>>> encoder = Encoder(feature_dim, obs_shape, cfg, fc_cfg)
>>> x = torch.randn(2, *obs_shape)
>>> y = encoder(x)
>>> y.shape
torch.Size([2, 128])
>>> encoder
Encoder(
  (encoder): ConvNet(
    (conv): Sequential(
      (0): ConvNormActivation(
        (conv): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (norm): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (pixel_shuffle): Identity()
        (activation): Activation(
          (activation): ReLU()
        )
        (dropout): Identity()
      )
      (1): ConvNormActivation(
        (conv): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (norm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (pixel_shuffle): Identity()
        (activation): Activation(
          (activation): ReLU()
        )
        (dropout): Identity()
      )
      (2): ConvNormActivation(
        (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (norm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (pixel_shuffle): Identity()
        (activation): Activation(
          (activation): ReLU()
        )
        (dropout): Identity()
      )
    )
  )
  (fc): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): LinearNormActivation(
      (linear): Linear(in_features=4096, out_features=128, bias=True)
      (norm): Identity()
      (activation): Activation(
        (activation): ReLU()
      )
      (dropout): Identity()
    )
  )
)
Source code in src/ml_networks/torch/vision.py
def __init__(
    self,
    feature_dim: int | tuple[int, int, int],
    obs_shape: tuple[int, int, int],
    backbone_cfg: ViTConfig | ConvNetConfig | ResNetConfig,
    fc_cfg: MLPConfig | LinearConfig | SpatialSoftmaxConfig | None = None,
) -> None:
    """
    Build an image encoder: a backbone plus an optional fc head.

    Parameters
    ----------
    feature_dim : int | tuple[int, int, int]
        Target feature dimension. If int, an fc head maps the backbone
        output down to this size; if tuple, the backbone output is used
        directly and must already have this (C, H, W) shape.
    obs_shape : tuple[int, int, int]
        Shape of the input tensor.
    backbone_cfg : ViTConfig | ConvNetConfig | ResNetConfig
        Backbone configuration; its concrete type selects the architecture.
    fc_cfg : MLPConfig | LinearConfig | SpatialSoftmaxConfig | None
        Configuration of the fc head. Must be provided when feature_dim
        is int; ignored when feature_dim is a tuple.
        NOTE(review): the body also accepts AdaptiveAveragePoolingConfig,
        which this annotation does not mention — confirm whether the
        annotation should be widened.

    Raises
    ------
    NotImplementedError
        If backbone_cfg is not one of the supported config types.

    """
    super().__init__()

    self.obs_shape = obs_shape

    # Backbone selection is dispatched on the config's concrete type.
    self.encoder: nn.Module
    if isinstance(backbone_cfg, ViTConfig):
        self.encoder = ViT(obs_shape, backbone_cfg)
    elif isinstance(backbone_cfg, ConvNetConfig):
        self.encoder = ConvNet(obs_shape, backbone_cfg)
    elif isinstance(backbone_cfg, ResNetConfig):
        self.encoder = ResNetPixUnshuffle(obs_shape, backbone_cfg)
    else:
        msg = f"{type(backbone_cfg)} is not implemented"
        raise NotImplementedError(msg)

    self.feature_dim = feature_dim
    # Cast explicitly to supply the type information the backbone
    # attributes lack.
    self.conved_size = cast("int", self.encoder.conved_size)
    self.conved_shape = cast("tuple[int, int]", self.encoder.conved_shape)
    self.last_channel = cast("int", self.encoder.last_channel)

    # An int feature_dim requires an fc head; a tuple feature_dim must
    # match the backbone's output shape exactly.
    if isinstance(feature_dim, int):
        assert fc_cfg is not None, "fc_cfg must be provided if feature_dim is provided"
    else:
        assert feature_dim == (self.last_channel, *self.conved_shape), (
            f"{feature_dim} != {(self.last_channel, *self.conved_shape)}"
        )
    self.fc: nn.Module
    if isinstance(fc_cfg, MLPConfig):
        assert isinstance(feature_dim, int), "feature_dim must be int when using MLPConfig"
        self.fc = nn.Sequential(
            nn.Flatten(),
            MLPLayer(self.conved_size, feature_dim, fc_cfg),
        )
    elif isinstance(fc_cfg, LinearConfig):
        assert isinstance(feature_dim, int), "feature_dim must be int when using LinearConfig"
        self.fc = nn.Sequential(
            nn.Flatten(),
            LinearNormActivation(self.conved_size, feature_dim, fc_cfg),
        )
    elif isinstance(fc_cfg, AdaptiveAveragePoolingConfig):
        assert isinstance(feature_dim, int), "feature_dim must be int when using AdaptiveAveragePoolingConfig"
        # Pool to a fixed spatial size, flatten, then optionally apply an
        # additional Linear/MLP layer depending on the config.
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(fc_cfg.output_size),
            nn.Flatten(),
            LinearNormActivation(
                int(self.last_channel * np.prod(fc_cfg.output_size)),
                feature_dim,
                fc_cfg.additional_layer,
            )
            if isinstance(
                fc_cfg.additional_layer,
                LinearConfig,
            )
            else MLPLayer(
                int(self.last_channel * np.prod(fc_cfg.output_size)),
                feature_dim,
                fc_cfg.additional_layer,
            )
            if isinstance(
                fc_cfg.additional_layer,
                MLPConfig,
            )
            else nn.Identity(),
        )
        # Without an additional layer the actual output size is fixed by
        # the pooling, so feature_dim is overwritten to reflect it.
        if fc_cfg.additional_layer is None:
            self.feature_dim = (
                self.last_channel * (fc_cfg.output_size**2)
                if isinstance(
                    fc_cfg.output_size,
                    int,
                )
                else self.last_channel * np.prod(fc_cfg.output_size)
            )

    elif isinstance(fc_cfg, SpatialSoftmaxConfig):
        assert isinstance(self.feature_dim, int), "feature_dim must be int when using SpatialSoftmaxConfig"
        # SpatialSoftmax yields 2 coordinates per channel, hence the
        # last_channel * 2 input size for the optional extra layer.
        self.fc = nn.Sequential(
            SpatialSoftmax(fc_cfg),
            nn.Flatten(),
            LinearNormActivation(
                self.last_channel * 2,
                self.feature_dim,
                fc_cfg.additional_layer,
            )
            if isinstance(
                fc_cfg.additional_layer,
                LinearConfig,
            )
            else MLPLayer(
                self.last_channel * 2,
                self.feature_dim,
                fc_cfg.additional_layer,
            )
            if isinstance(
                fc_cfg.additional_layer,
                MLPConfig,
            )
            else nn.Identity(),
        )
        # As above: without an extra layer the output size is determined
        # by SpatialSoftmax itself.
        if fc_cfg.additional_layer is None:
            self.feature_dim = self.last_channel * 2
    else:
        self.fc = nn.Identity()

Attributes

conved_shape instance-attribute

conved_shape = cast('tuple[int, int]', conved_shape)

conved_size instance-attribute

conved_size = cast('int', conved_size)

encoder instance-attribute

encoder

fc instance-attribute

fc

feature_dim instance-attribute

feature_dim = feature_dim

last_channel instance-attribute

last_channel = cast('int', last_channel)

obs_shape instance-attribute

obs_shape = obs_shape

Functions

forward

forward(x)

Forward pass.

Parameters:

Name Type Description Default
x Tensor

input tensor of shape (batch_size, *obs_shape)

required

Returns:

Type Description
Tensor

output tensor of shape (batch_size, *feature_dim)

Source code in src/ml_networks/torch/vision.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Forward pass.

    Parameters
    ----------
    x: torch.Tensor
        input tensor of shape (batch_size, *obs_shape)

    Returns
    -------
    torch.Tensor
        output tensor of shape (batch_size, *feature_dim)

    """
    # Fold any leading batch dimensions into a single one so the
    # backbone only ever sees a 4-D tensor, then restore them at the end.
    leading = x.shape[:-3]
    feat = x.reshape([-1, *self.obs_shape])
    feat = self.encoder(feat)
    feat = feat.view(-1, self.last_channel, *self.conved_shape)
    feat = self.fc(feat)
    return feat.reshape([*leading, *feat.shape[1:]])

Decoder

Decoder(feature_dim, obs_shape, backbone_cfg, fc_cfg=None)

Bases: BaseModule

Decoder with various architectures.

Parameters:

Name Type Description Default
feature_dim int | tuple[int, int, int]

dimension of the feature tensor, if int, Decoder includes full connection layer to upsample the feature tensor. Otherwise, Decoder does not include full connection layer and directly process with backbone network.

required
obs_shape tuple[int, int, int]

shape of the output tensor

required
backbone_cfg ConvNetConfig | ViTConfig | ResNetConfig

configuration of the network

required
fc_cfg MLPConfig | LinearConfig | None

configuration of the full connection layer. If feature_dim is tuple, fc_cfg is ignored. If feature_dim is int, fc_cfg must be provided. Default is None.

None

Examples:

>>> feature_dim = 128
>>> obs_shape = (3, 64, 64)
>>> cfg = ConvNetConfig(
...     channels=[64, 32, 16],
...     conv_cfgs=[
...         ConvConfig(kernel_size=4, stride=2, padding=1, activation="ReLU", norm="batch", dropout=0.0),
...         ConvConfig(kernel_size=4, stride=2, padding=1, activation="ReLU", norm="batch", dropout=0.0),
...         ConvConfig(kernel_size=4, stride=2, padding=1, activation="ReLU", norm="batch", dropout=0.0),
...     ]
... )
>>> fc_cfg = MLPConfig(
...     hidden_dim=256,
...     n_layers=2,
...     output_activation="ReLU",
...     linear_cfg=LinearConfig(
...         activation="ReLU",
...         bias=True
...     )
... )
>>> decoder = Decoder(feature_dim, obs_shape, cfg, fc_cfg)
>>> x = torch.randn(2, feature_dim)
>>> y = decoder(x)
>>> y.shape
torch.Size([2, 3, 64, 64])
>>> decoder
Decoder(
  (fc): MLPLayer(
    (dense): Sequential(
      (0): LinearNormActivation(
        (linear): Linear(in_features=128, out_features=256, bias=True)
        (norm): Identity()
        (activation): Activation(
          (activation): ReLU()
        )
        (dropout): Identity()
      )
      (1): LinearNormActivation(
        (linear): Linear(in_features=256, out_features=256, bias=True)
        (norm): Identity()
        (activation): Activation(
          (activation): ReLU()
        )
        (dropout): Identity()
      )
      (2): LinearNormActivation(
        (linear): Linear(in_features=256, out_features=1024, bias=True)
        (norm): Identity()
        (activation): Activation(
          (activation): ReLU()
        )
        (dropout): Identity()
      )
    )
  )
  (decoder): ConvTranspose(
    (first_conv): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1))
    (conv): Sequential(
      (0): ConvTransposeNormActivation(
        (conv): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (norm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activation): Activation(
          (activation): ReLU()
        )
        (dropout): Identity()
      )
      (1): ConvTransposeNormActivation(
        (conv): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (norm): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activation): Activation(
          (activation): ReLU()
        )
        (dropout): Identity()
      )
      (2): ConvTransposeNormActivation(
        (conv): ConvTranspose2d(16, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (norm): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activation): Activation(
          (activation): ReLU()
        )
        (dropout): Identity()
      )
    )
  )
)
Source code in src/ml_networks/torch/vision.py
def __init__(
    self,
    feature_dim: int | tuple[int, int, int],
    obs_shape: tuple[int, int, int],
    backbone_cfg: ConvNetConfig | ViTConfig | ResNetConfig,
    fc_cfg: MLPConfig | LinearConfig | None = None,
) -> None:
    """
    Build an image decoder: an optional fc head plus a backbone.

    Parameters
    ----------
    feature_dim : int | tuple[int, int, int]
        Dimension of the input feature. If int, an fc head upsamples it
        to the backbone's expected input size; if tuple, it must match
        the backbone's input shape exactly and fc_cfg is ignored.
    obs_shape : tuple[int, int, int]
        Shape of the output tensor.
    backbone_cfg : ConvNetConfig | ViTConfig | ResNetConfig
        Backbone configuration; its concrete type selects the architecture.
    fc_cfg : MLPConfig | LinearConfig | None
        Configuration of the fc head. Must be provided when feature_dim
        is int. Default is None.

    Raises
    ------
    NotImplementedError
        If backbone_cfg is not one of the supported config types.

    """
    super().__init__()

    self.obs_shape = obs_shape
    self.feature_dim = feature_dim

    # Ask the backbone class what input shape produces obs_shape output;
    # unsupported config types fail fast here, before any module is built.
    self.input_shape: tuple[int, int, int]
    if isinstance(backbone_cfg, ViTConfig):
        self.input_shape = ViT.get_input_shape(obs_shape, backbone_cfg)
    elif isinstance(backbone_cfg, ConvNetConfig):
        self.input_shape = cast(
            "tuple[int, int, int]",
            ConvTranspose.get_input_shape(obs_shape, backbone_cfg),
        )
    elif isinstance(backbone_cfg, ResNetConfig):
        self.input_shape = cast(
            "tuple[int, int, int]",
            ResNetPixShuffle.get_input_shape(obs_shape, backbone_cfg),
        )
    else:
        msg = f"{type(backbone_cfg)} is not implemented"
        raise NotImplementedError(msg)
    # has_fc also controls how forward splits batch vs. data dimensions.
    if isinstance(feature_dim, int):
        assert fc_cfg is not None, "fc_cfg must be provided if feature_dim is provided"
        self.has_fc = True
    else:
        assert feature_dim == self.input_shape, f"{feature_dim} != {self.input_shape}"
        self.has_fc = False

    # The fc head maps the flat feature vector up to the backbone's
    # flattened input size; registered before the decoder on purpose
    # (submodule order is part of the module's state_dict layout).
    if isinstance(fc_cfg, MLPConfig):
        assert isinstance(feature_dim, int), "feature_dim must be int when using MLPConfig"
        self.fc: nn.Module = MLPLayer(feature_dim, int(np.prod(self.input_shape)), fc_cfg)
    elif isinstance(fc_cfg, LinearConfig):
        assert isinstance(feature_dim, int), "feature_dim must be int when using LinearConfig"
        self.fc = LinearNormActivation(feature_dim, int(np.prod(self.input_shape)), fc_cfg)
    else:
        self.fc = nn.Identity()

    if isinstance(backbone_cfg, ViTConfig):
        self.decoder: nn.Module = ViT(in_shape=self.input_shape, obs_shape=obs_shape, cfg=backbone_cfg)
    elif isinstance(backbone_cfg, ConvNetConfig):
        self.decoder = ConvTranspose(in_shape=self.input_shape, obs_shape=obs_shape, cfg=backbone_cfg)
    elif isinstance(backbone_cfg, ResNetConfig):
        self.decoder = ResNetPixShuffle(in_shape=self.input_shape, obs_shape=obs_shape, cfg=backbone_cfg)

Attributes

decoder instance-attribute

decoder = ViT(in_shape=input_shape, obs_shape=obs_shape, cfg=backbone_cfg)

fc instance-attribute

fc = MLPLayer(feature_dim, int(prod(input_shape)), fc_cfg)

feature_dim instance-attribute

feature_dim = feature_dim

has_fc instance-attribute

has_fc = True

input_shape instance-attribute

input_shape

obs_shape instance-attribute

obs_shape = obs_shape

Functions

forward

forward(x)

Forward pass.

Parameters:

Name Type Description Default
x Tensor

input tensor of shape (batch_size, *feature_dim)

required

Returns:

Type Description
Tensor

output tensor of shape (batch_size, *obs_shape)

Source code in src/ml_networks/torch/vision.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Decode features back into observations.

    Parameters
    ----------
    x: torch.Tensor
        input tensor of shape (batch_size, *feature_dim)

    Returns
    -------
    torch.Tensor
        output tensor of shape (batch_size, *obs_shape)

    """
    # Flat feature vectors keep only the last axis as data; image-like
    # inputs keep the trailing (C, H, W) axes as data.
    split = -1 if self.has_fc else -3
    batch_shape = x.shape[:split]
    data_shape = x.shape[split:]

    out = self.fc(x.reshape([-1, *data_shape]))
    out = self.decoder(out.reshape([-1, *self.input_shape]))
    return out.reshape([*batch_shape, *self.obs_shape])

ConvNet

ConvNet(obs_shape, cfg)

Bases: Module

Convolutional Neural Network for Encoder.

Parameters:

Name Type Description Default
obs_shape tuple[int, int, int]

shape of input tensor

required
cfg ConvNetConfig

configuration of the network

required

Examples:

>>> obs_shape = (3, 64, 64)
>>> cfg = ConvNetConfig(
...     channels=[16, 32, 64],
...     conv_cfgs=[
...         ConvConfig(kernel_size=3, stride=2, padding=1, activation="ReLU", norm="batch", dropout=0.0),
...         ConvConfig(kernel_size=3, stride=2, padding=1, activation="ReLU", norm="batch", dropout=0.0),
...         ConvConfig(kernel_size=3, stride=2, padding=1, activation="ReLU", norm="batch", dropout=0.0),
...     ]
... )
>>> encoder = ConvNet(obs_shape, cfg)
>>> encoder
ConvNet(
  (conv): Sequential(
    (0): ConvNormActivation(
      (conv): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (norm): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (pixel_shuffle): Identity()
      (activation): Activation(
        (activation): ReLU()
      )
      (dropout): Identity()
    )
    (1): ConvNormActivation(
      (conv): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (norm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (pixel_shuffle): Identity()
      (activation): Activation(
        (activation): ReLU()
      )
      (dropout): Identity()
    )
    (2): ConvNormActivation(
      (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (norm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (pixel_shuffle): Identity()
      (activation): Activation(
        (activation): ReLU()
      )
      (dropout): Identity()
    )
  )
)
>>> x = torch.randn(2, *obs_shape)
>>> y = encoder(x)
>>> y.shape
torch.Size([2, 64, 8, 8])
Source code in src/ml_networks/torch/vision.py
def __init__(
    self,
    obs_shape: tuple[int, int, int],
    cfg: ConvNetConfig,
) -> None:
    """Build the convolutional encoder stack.

    Parameters
    ----------
    obs_shape : tuple[int, int, int]
        Shape (C, H, W) of the input tensor.
    cfg : ConvNetConfig
        Configuration of the network.
    """
    super().__init__()

    self.obs_shape = obs_shape
    self.cfg = cfg
    # Prepend the input channel count so that consecutive pairs describe
    # each conv layer's (in_channels, out_channels).
    self.channels = [obs_shape[0], *cfg.channels]
    self.conv = self._build_conv()
    self.last_channel = self.channels[-1]

Attributes

cfg instance-attribute

cfg = cfg

channels instance-attribute

channels = [obs_shape[0], *(channels)]

conv instance-attribute

conv = _build_conv()

conved_shape property

conved_shape

Get the shape of the output tensor after convolutional layers.

Returns:

Type Description
tuple[int, int]

shape of the output tensor

Examples:

>>> obs_shape = (3, 64, 64)
>>> cfg = ConvNetConfig(
...     channels=[64, 32, 16],
...     conv_cfgs=[
...         ConvConfig(kernel_size=3, stride=2, padding=1, activation="ReLU", norm="batch", dropout=0.0),
...         ConvConfig(kernel_size=3, stride=2, padding=1, activation="ReLU", norm="batch", dropout=0.0),
...         ConvConfig(kernel_size=3, stride=2, padding=1, activation="ReLU", norm="batch", dropout=0.0),
...     ]
... )
>>> encoder = ConvNet(obs_shape, cfg)
>>> encoder.conved_shape
(8, 8)

conved_size property

conved_size

Get the size of the output tensor after convolutional layers.

Returns:

Type Description
int

size of the output tensor

last_channel instance-attribute

last_channel = channels[-1]

obs_shape instance-attribute

obs_shape = obs_shape

Functions

forward

forward(x)

Forward pass.

Parameters:

Name Type Description Default
x Tensor

input tensor of shape (batch_size, *obs_shape)

required

Returns:

Type Description
Tensor

output tensor of shape (batch_size, self.last_channel, *self.conved_shape)

Source code in src/ml_networks/torch/vision.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Encode a batch of observations.

    Parameters
    ----------
    x: torch.Tensor
        input tensor of shape (batch_size, *obs_shape)

    Returns
    -------
    torch.Tensor
        output tensor of shape (batch_size, self.last_channel, *self.conved_shape)

    """
    # The whole encoder is a single sequential conv stack.
    out = self.conv(x)
    return out

ResNetPixUnshuffle

ResNetPixUnshuffle(obs_shape, cfg)

Bases: Module

ResNet with PixelUnshuffle for Encoder.

Parameters:

Name Type Description Default
obs_shape tuple[int, int, int]

shape of input tensor

required
cfg ResNetConfig

configuration of the network

required

Examples:

>>> obs_shape = (3, 64, 64)
>>> cfg = ResNetConfig(
...     conv_channel=64,
...     conv_kernel=3,
...     f_kernel=3,
...     conv_activation="ReLU",
...     out_activation="ReLU",
...     n_res_blocks=2,
...     scale_factor=2,
...     n_scaling=3,
...     norm="batch",
...     norm_cfg={},
...     dropout=0.0
... )
>>> encoder = ResNetPixUnshuffle(obs_shape, cfg)
>>> x = torch.randn(2, *obs_shape)
>>> y = encoder(x)
>>> y.shape
torch.Size([2, 64, 8, 8])
Source code in src/ml_networks/torch/vision.py
def __init__(
    self,
    obs_shape: tuple[int, int, int],
    cfg: ResNetConfig,
) -> None:
    """Build the ResNet encoder with pixel-unshuffle downsampling.

    Parameters
    ----------
    obs_shape : tuple[int, int, int]
        Shape (C, H, W) of the input tensor.
    cfg : ResNetConfig
        Configuration of the network.
    """
    # Local import keeps the aliasing fix self-contained.
    from copy import deepcopy

    super().__init__()

    self.obs_shape = obs_shape
    self.cfg = cfg

    first_cfg = ConvConfig(
        activation=cfg.conv_activation,
        kernel_size=cfg.f_kernel,
        stride=1,
        padding=cfg.f_kernel // 2,
        dilation=1,
        groups=1,
        bias=True,
        dropout=cfg.dropout,
        norm=cfg.norm,
        norm_cfg=cfg.norm_cfg,
        padding_mode=cfg.padding_mode,
    )
    # First layer
    self.conv1 = ConvNormActivation(self.obs_shape[0], cfg.conv_channel, first_cfg)

    # downsampling
    # BUG FIX: the original wrote `downsample_cfg = first_cfg` (and later
    # `cov2_cfg = first_cfg`, `final_cfg = first_cfg`), aliasing one shared
    # config object.  Every later mutation also rewrote the config already
    # handed to conv1 and to the downsample layers.  Deep-copy before
    # mutating so each stage owns an independent config.
    downsample_cfg = deepcopy(first_cfg)
    downsample_cfg.kernel_size = cfg.conv_kernel
    downsample_cfg.padding = cfg.conv_kernel // 2
    # Negative scale factor selects downscaling in ConvNormActivation.
    downsample_cfg.scale_factor = -cfg.scale_factor
    downsample: list[nn.Module] = [
        ConvNormActivation(cfg.conv_channel, cfg.conv_channel, downsample_cfg)
        for _ in range(cfg.n_scaling)
    ]
    self.downsample = nn.Sequential(*downsample)

    # Residual blocks, optionally interleaved with attention.
    res_blocks: list[nn.Module] = []
    for _ in range(cfg.n_res_blocks):
        res_blocks.append(
            ResidualBlock(
                cfg.conv_channel,
                cfg.conv_kernel,
                cfg.conv_activation,
                cfg.norm,
                cfg.norm_cfg,
                cfg.dropout,
                cfg.padding_mode,
            )
        )
        if cfg.attention is not None:
            res_blocks.append(Attention2d(cfg.conv_channel, nhead=None, attn_cfg=cfg.attention))

    self.res_blocks = nn.Sequential(*res_blocks)

    # Second conv layer post residual blocks
    cov2_cfg = deepcopy(first_cfg)
    cov2_cfg.kernel_size = cfg.conv_kernel
    cov2_cfg.padding = cfg.conv_kernel // 2
    cov2_cfg.scale_factor = 0
    self.conv2 = ConvNormActivation(cfg.conv_channel, cfg.conv_channel, cov2_cfg)

    # Final output layer
    final_cfg = deepcopy(first_cfg)
    final_cfg.kernel_size = cfg.conv_kernel
    final_cfg.padding = cfg.conv_kernel // 2
    # The aliased original ended up with scale_factor == 0 here (set via
    # cov2_cfg); keep that value explicitly.
    final_cfg.scale_factor = 0

    self.conv3 = ConvNormActivation(cfg.conv_channel, cfg.conv_channel, final_cfg)
    self.last_channel = cfg.conv_channel

Attributes

cfg instance-attribute

cfg = cfg

conv1 instance-attribute

conv1 = ConvNormActivation(obs_shape[0], conv_channel, first_cfg)

conv2 instance-attribute

conv2 = ConvNormActivation(conv_channel, conv_channel, cov2_cfg)

conv3 instance-attribute

conv3 = ConvNormActivation(conv_channel, conv_channel, final_cfg)

conved_shape property

conved_shape

Get the shape of the output tensor after convolutional layers.

Returns:

Type Description
tuple[int, int]

shape of the output tensor

conved_size property

conved_size

Get the size of the output tensor after convolutional layers.

Returns:

Type Description
int

size of the output tensor

downsample instance-attribute

downsample = Sequential(*downsample)

last_channel instance-attribute

last_channel = conv_channel

obs_shape instance-attribute

obs_shape = obs_shape

res_blocks instance-attribute

res_blocks = Sequential(*res_blocks)

Functions

forward

forward(x)

Forward pass.

Parameters:

Name Type Description Default
x Tensor

input tensor of shape (batch_size, *obs_shape)

required

Returns:

Type Description
Tensor

output tensor of shape (batch_size, self.last_channel, *self.conved_shape)

Source code in src/ml_networks/torch/vision.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Run the encoder with a long skip connection.

    Parameters
    ----------
    x: torch.Tensor
        input tensor of shape (batch_size, *obs_shape)

    Returns
    -------
    torch.Tensor
        output tensor of shape (batch_size, self.last_channel, *self.conved_shape)

    """
    features = self.conv1(x)
    skip = self.downsample(features)
    # Residual path; its output is added back onto the downsampled skip.
    residual = self.conv2(self.res_blocks(skip))
    return self.conv3(skip + residual)

分布

Distribution

Distribution(in_dim, dist, n_groups=1, spherical=False)

Bases: Module

A distribution function.

Parameters:

Name Type Description Default
in_dim int

Input dimension.

required
dist Literal['normal', 'categorical', 'bernoulli']

Distribution type.

required
n_groups int

Number of groups. Default is 1. This is used for the categorical and Bernoulli distributions.

1
spherical bool

Whether to project samples to the unit sphere. Default is False. This is used for the categorical and Bernoulli distributions. If True and dist=="categorical", the samples are projected from {0, 1} to {-1, 1}. If True and dist=="bernoulli", the samples are projected from {0, 1} to the unit sphere.

refer to https://arxiv.org/abs/2406.07548

False

Examples:

>>> dist = Distribution(10, "normal")
>>> data = torch.randn(2, 20)
>>> posterior = dist(data)
>>> posterior.__class__.__name__
'NormalStoch'
>>> posterior.shape
NormalShape(mean=torch.Size([2, 10]), std=torch.Size([2, 10]), stoch=torch.Size([2, 10]))
>>> dist = Distribution(10, "categorical", n_groups=2)
>>> data = torch.randn(2, 10)
>>> posterior = dist(data)
>>> posterior.__class__.__name__
'CategoricalStoch'
>>> posterior.shape
CategoricalShape(logits=torch.Size([2, 2, 5]), probs=torch.Size([2, 2, 5]), stoch=torch.Size([2, 10]))
>>> dist = Distribution(10, "bernoulli", n_groups=2)
>>> data = torch.randn(2, 10)
>>> posterior = dist(data)
>>> posterior.__class__.__name__
'BernoulliStoch'
>>> posterior.shape
BernoulliShape(logits=torch.Size([2, 2, 5]), probs=torch.Size([2, 2, 5]), stoch=torch.Size([2, 10]))
Source code in src/ml_networks/torch/distributions.py
def __init__(
    self,
    in_dim: int,
    dist: Literal["normal", "categorical", "bernoulli"],
    n_groups: int = 1,
    spherical: bool = False,
) -> None:
    """Set up the posterior head for the requested distribution family.

    Parameters
    ----------
    in_dim : int
        Input dimension.
    dist : {"normal", "categorical", "bernoulli"}
        Distribution type.
    n_groups : int
        Number of groups (used by the categorical and Bernoulli families).
    spherical : bool
        Whether samples are projected ({0,1} -> {-1,1} for categorical,
        {0,1} -> unit-sphere codes for Bernoulli).
    """
    super().__init__()

    self.dist = dist
    self.spherical = spherical
    self.in_dim = in_dim
    self.n_groups = n_groups
    self.n_class = in_dim // n_groups

    # Dispatch table instead of an if/elif chain.
    samplers = {
        "normal": self.normal,
        "categorical": self.categorical,
        "bernoulli": self.bernoulli,
    }
    if dist not in samplers:
        raise NotImplementedError
    self.posterior = samplers[dist]  # type: ignore[assignment]

    if spherical:
        self.codebook = BSQCodebook(self.n_class)

Attributes

codebook instance-attribute

codebook = BSQCodebook(n_class)

dist instance-attribute

dist = dist

in_dim instance-attribute

in_dim = in_dim

n_class instance-attribute

n_class = in_dim // n_groups

n_groups instance-attribute

n_groups = n_groups

posterior instance-attribute

posterior = normal

spherical instance-attribute

spherical = spherical

Functions

bernoulli

bernoulli(logits, deterministic=False, inv_tmp=1.0)
Source code in src/ml_networks/torch/distributions.py
def bernoulli(self, logits: torch.Tensor, deterministic: bool = False, inv_tmp: float = 1.0) -> BernoulliStoch:
    """Build a Bernoulli posterior and draw a straight-through sample."""
    batch_shape = logits.shape[:-1]
    # Split flat logits into n_groups and scale by the inverse temperature.
    grouped = torch.stack(torch.chunk(logits, self.n_groups, dim=-1), dim=-2)
    grouped = grouped * inv_tmp
    probs = torch.sigmoid(grouped)

    posterior = D.Independent(BernoulliStraightThrough(probs=probs), 1)
    sample = posterior.rsample()

    if self.spherical:
        # Map {0, 1} bits onto unit-sphere codes.
        sample = self.codebook.bits_to_codes(sample)

    if deterministic:
        # Hard threshold at 0.5, with a straight-through gradient via probs.
        hard = torch.where(sample > 0.5, torch.ones_like(sample), torch.zeros_like(sample))
        sample = hard + probs - probs.detach()

    return BernoulliStoch(
        grouped,
        probs,
        sample.reshape([*batch_shape, -1]),
    )

categorical

categorical(logits, deterministic=False, inv_tmp=1.0)
Source code in src/ml_networks/torch/distributions.py
def categorical(self, logits: torch.Tensor, deterministic: bool = False, inv_tmp: float = 1.0) -> CategoricalStoch:
    """Build a one-hot categorical posterior and draw a straight-through sample.

    Parameters
    ----------
    logits : torch.Tensor
        Flat logits of shape (*, n_groups * n_class).
    deterministic : bool
        If True, return the argmax one-hot (straight-through) instead of a
        stochastic sample.
    inv_tmp : float
        Inverse softmax temperature.

    Returns
    -------
    CategoricalStoch
        Logits, probabilities, and the sample, flattened back to (*, in_dim).
    """
    batch_shape = logits.shape[:-1]
    # Split the flat logits into n_groups: (*, n_groups, n_class).
    # (The original also contained a dead no-op `logits = logits`; removed.)
    logits = torch.stack(torch.chunk(logits, self.n_groups, dim=-1), dim=-2)
    probs = softmax(logits, dim=-1, temperature=1 / inv_tmp)

    if deterministic:
        # Argmax one-hot with a straight-through gradient.  The original
        # drew (and discarded) an rsample here as well; skip the wasted draw.
        # NOTE(review): as in the original, the `spherical` projection is
        # not applied in deterministic mode — confirm this is intended.
        sample = self.deterministic_onehot(probs)
    else:
        dist = D.OneHotCategoricalStraightThrough(probs=probs)
        sample = D.Independent(dist, 1).rsample()
        if self.spherical:
            # Project the one-hot sample from {0, 1} onto {-1, 1}.
            sample = sample * 2 - 1

    return CategoricalStoch(logits, probs, sample.reshape([*batch_shape, -1]))

deterministic_onehot

deterministic_onehot(input)

Compute the one-hot vector by argmax.

Parameters:

Name Type Description Default
input Tensor

Input tensor.

required

Returns:

Type Description
Tensor

One-hot vector.

Examples:

>>> input = torch.arange(6).reshape(2, 3) / 5.0
>>> dist = Distribution(3, "categorical")
>>> onehot = dist.deterministic_onehot(input)
>>> onehot
tensor([[0., 0., 1.],
        [0., 0., 1.]])
Source code in src/ml_networks/torch/distributions.py
def deterministic_onehot(self, input: torch.Tensor) -> torch.Tensor:
    """
    Compute the argmax one-hot vector with a straight-through gradient.

    Parameters
    ----------
    input : torch.Tensor
        Input tensor whose last dimension holds class scores.

    Returns
    -------
    torch.Tensor
        One-hot vector.

    Examples
    --------
    >>> input = torch.arange(6).reshape(2, 3) / 5.0
    >>> dist = Distribution(3, "categorical")
    >>> onehot = dist.deterministic_onehot(input)
    >>> onehot
    tensor([[0., 0., 1.],
            [0., 0., 1.]])
    """
    hard = F.one_hot(input.argmax(dim=-1), num_classes=self.n_class)
    # `input - input.detach()` is numerically zero but carries the gradient.
    return hard + input - input.detach()

forward

forward(x, deterministic=False, inv_tmp=1.0)

Compute the posterior distribution.

Parameters:

Name Type Description Default
x Tensor

Input tensor.

required
deterministic bool

Whether to use the deterministic mode. Default is False. if True and dist=="normal", the mean is returned. if True and dist=="categorical", the one-hot vector computed by argmax is returned. if True and dist=="bernoulli", 1 is returned if x > 0.5 or 0 is returned if x <= 0.5.

False
inv_tmp float

Inverse temperature. Default is 1.0. This is used for the categorical and Bernoulli distributions.

1.0

Returns:

Type Description
StochState

Posterior distribution.

Source code in src/ml_networks/torch/distributions.py
def forward(
    self,
    x: torch.Tensor,
    deterministic: bool = False,
    inv_tmp: float = 1.0,
) -> StochState:
    """
    Compute the posterior distribution.

    Parameters
    ----------
    x : torch.Tensor
        Input tensor.
    deterministic : bool, optional
        Whether to use the deterministic mode. Default is False.
        For "normal" the mean is returned; for "categorical" the argmax
        one-hot; for "bernoulli" 1 where x > 0.5 and 0 otherwise.
    inv_tmp : float, optional
        Inverse temperature. Default is 1.0.
        Used by the categorical and Bernoulli distributions.

    Returns
    -------
    StochState
        Posterior distribution.

    """
    # self.posterior was bound to normal / categorical / bernoulli in __init__.
    return self.posterior(x, deterministic=deterministic, inv_tmp=inv_tmp)

normal

normal(mu_std, deterministic=False, inv_tmp=1.0)
Source code in src/ml_networks/torch/distributions.py
def normal(self, mu_std: torch.Tensor, deterministic: bool = False, inv_tmp: float = 1.0) -> NormalStoch:
    """Build a diagonal-Gaussian posterior and draw a reparameterized sample.

    Parameters
    ----------
    mu_std : torch.Tensor
        Concatenated mean and pre-softplus std, shape (*, 2 * in_dim).
    deterministic : bool
        If True, use the mean instead of a stochastic sample.
    inv_tmp : float
        Unused by the normal family; kept for a uniform sampler signature.

    Returns
    -------
    NormalStoch
        Mean, std, and the sample.
    """
    # Fixed assertion message: the original claimed the sizes "must be the
    # same", but the last dim must be twice in_dim (mean and std halves).
    assert mu_std.shape[-1] == self.in_dim * 2, (
        f"mu_std.shape[-1] {mu_std.shape[-1]} must equal 2 * in_dim (= {self.in_dim * 2})."
    )

    mu, std = torch.chunk(mu_std, 2, dim=-1)
    # softplus keeps std positive; the epsilon guards against exact zero.
    std = F.softplus(std) + 1e-6

    if deterministic:
        # The original tested `deterministic` twice and built a Normal whose
        # sample was discarded; short-circuit instead (same returned value).
        return NormalStoch(mu, std, mu)

    sample = D.Independent(D.Normal(mu, std), 1).rsample()
    return NormalStoch(mu, std, sample)

NormalStoch dataclass

NormalStoch(mean, std, stoch)

Parameters of a normal distribution and its stochastic sample.

Attributes:

Name Type Description
mean Tensor

Mean of the normal distribution.

std Tensor

Standard deviation of the normal distribution.

stoch Tensor

sample from the normal distribution with reparametrization trick.

Attributes

mean instance-attribute

mean

shape property

shape

mean, std, stoch の shape をタプルで返す.

std instance-attribute

std

stoch instance-attribute

stoch

Functions

__getattr__

__getattr__(name)

torch.Tensor に含まれるメソッドを呼び出したら、各メンバに適用する.

例: normal.flatten() → NormalStoch(mean.flatten(), std.flatten(), stoch.flatten()).

Parameters:

Name Type Description Default
name str

メソッド名。

required

Returns:

Type Description
callable

torch.Tensorのメソッドを各メンバに適用する関数。

Raises:

Type Description
AttributeError

指定された名前がtorch.Tensorのメソッドでない場合。

Source code in src/ml_networks/torch/distributions.py
def __getattr__(self, name: str) -> Any:
    """Broadcast torch.Tensor methods over all three members.

    E.g. ``normal.flatten()`` returns
    ``NormalStoch(mean.flatten(), std.flatten(), stoch.flatten())``.

    Parameters
    ----------
    name : str
        Method name to look up on torch.Tensor.

    Returns
    -------
    callable
        A function applying the tensor method to each member.

    Raises
    ------
    AttributeError
        If `name` is not a torch.Tensor method.
    """
    # Guard clause: anything that is not a tensor method is a real miss.
    if not hasattr(torch.Tensor, name):
        msg = f"'{self.__class__.__name__}' object has no attribute '{name}'"
        raise AttributeError(msg)

    def broadcast(*args: Any, **kwargs: Any) -> NormalStoch:
        # Apply the same tensor method, with the same arguments, to every
        # member and repackage the results.
        return NormalStoch(
            getattr(self.mean, name)(*args, **kwargs),
            getattr(self.std, name)(*args, **kwargs),
            getattr(self.stoch, name)(*args, **kwargs),
        )

    return broadcast

__getitem__

__getitem__(idx)

インデックスアクセス.

Parameters:

Name Type Description Default
idx int or slice or tuple

インデックス指定。

required

Returns:

Type Description
NormalStoch

指定されたインデックスに対応するNormalStoch

Source code in src/ml_networks/torch/distributions.py
def __getitem__(self, idx: int | slice | tuple) -> NormalStoch:
    """Index into every member at once.

    Parameters
    ----------
    idx : int or slice or tuple
        Index specification.

    Returns
    -------
    NormalStoch
        The `NormalStoch` at the given index.
    """
    mean, std, stoch = self.mean[idx], self.std[idx], self.stoch[idx]
    return NormalStoch(mean, std, stoch)

__len__

__len__()

長さを返す.

Returns:

Type Description
int

バッチ次元の長さ。

Source code in src/ml_networks/torch/distributions.py
def __len__(self) -> int:
    """Return the batch-dimension length.

    Returns
    -------
    int
        Length of the leading (batch) dimension of `stoch`.
    """
    batch_dim, *_ = self.stoch.shape
    return batch_dim

__post_init__

__post_init__()

初期化後の処理.

Raises:

Type Description
ValueError

mean と std のshapeが異なる場合、または std に負の値が含まれる場合。

Source code in src/ml_networks/torch/distributions.py
def __post_init__(self) -> None:
    """Validate the members after dataclass initialization.

    Raises
    ------
    ValueError
        If `mean` and `std` shapes differ, or `std` contains negative values.
    """
    if self.mean.shape != self.std.shape:
        raise ValueError(
            f"mean.shape {self.mean.shape} and std.shape {self.std.shape} must be the same."
        )
    if bool((self.std < 0).any()):
        raise ValueError("std must be non-negative.")

get_distribution

get_distribution(independent=1)
Source code in src/ml_networks/torch/distributions.py
def get_distribution(self, independent: int = 1) -> D.Independent:
    """Return an Independent Normal built from the stored parameters."""
    base = D.Normal(self.mean, self.std)
    return D.Independent(base, independent)

save

save(path)

Save the parameters of the normal distribution to the specified path.

Parameters:

Name Type Description Default
path str

Path to save the parameters.

required
Source code in src/ml_networks/torch/distributions.py
def save(self, path: str) -> None:
    """
    Save the parameters of the normal distribution to the specified path.

    Parameters
    ----------
    path : str
        Path to save the parameters (directory, created if missing).

    """
    os.makedirs(path, exist_ok=True)

    # One compressed file per member, detached and moved to host memory.
    members = {"mean": self.mean, "std": self.std, "stoch": self.stoch}
    for name, tensor in members.items():
        save_blosc2(f"{path}/{name}.blosc2", tensor.detach().clone().cpu().numpy())

squeeze

squeeze(dim)

Squeeze the parameters of the normal distribution.

Parameters:

Name Type Description Default
dim int

Dimension to squeeze.

required

Returns:

Type Description
NormalStoch

Squeezed normal distribution.

Source code in src/ml_networks/torch/distributions.py
def squeeze(self, dim: int) -> NormalStoch:
    """
    Squeeze the given dimension out of every member.

    Parameters
    ----------
    dim : int
        Dimension to squeeze.

    Returns
    -------
    NormalStoch
        Squeezed normal distribution.

    """
    squeezed = (t.squeeze(dim) for t in (self.mean, self.std, self.stoch))
    return NormalStoch(*squeezed)

unsqueeze

unsqueeze(dim)

Unsqueeze the parameters of the normal distribution.

Parameters:

Name Type Description Default
dim int

Dimension to unsqueeze.

required

Returns:

Type Description
NormalStoch

Unsqueezed normal distribution.

Source code in src/ml_networks/torch/distributions.py
def unsqueeze(self, dim: int) -> NormalStoch:
    """
    Insert a size-one dimension into every member.

    Parameters
    ----------
    dim : int
        Dimension to unsqueeze.

    Returns
    -------
    NormalStoch
        Unsqueezed normal distribution.

    """
    expanded = (t.unsqueeze(dim) for t in (self.mean, self.std, self.stoch))
    return NormalStoch(*expanded)

CategoricalStoch dataclass

CategoricalStoch(logits, probs, stoch)

Parameters of a categorical distribution and its stochastic sample.

Attributes:

Name Type Description
logits Tensor

Logits of the categorical distribution.

probs Tensor

Probabilities of the categorical distribution.

stoch Tensor

sample from the categorical distribution with Straight-Through Estimator.

Attributes

logits instance-attribute

logits

probs instance-attribute

probs

shape property

shape

logits, probs, stoch の shape をタプルで返す.

stoch instance-attribute

stoch

Functions

__getattr__

__getattr__(name)

torch.Tensor に含まれるメソッドを呼び出したら、各メンバに適用する.

例: categorical.flatten() → CategoricalStoch(logits.flatten(), probs.flatten(), stoch.flatten()).

Parameters:

Name Type Description Default
name str

メソッド名。

required

Returns:

Type Description
callable

torch.Tensorのメソッドを各メンバに適用する関数。

Raises:

Type Description
AttributeError

指定された名前がtorch.Tensorのメソッドでない場合。

Source code in src/ml_networks/torch/distributions.py
def __getattr__(self, name: str) -> Any:
    """Broadcast torch.Tensor methods over all three members.

    E.g. ``categorical.flatten()`` returns
    ``CategoricalStoch(logits.flatten(), probs.flatten(), stoch.flatten())``.

    Parameters
    ----------
    name : str
        Method name to look up on torch.Tensor.

    Returns
    -------
    callable
        A function applying the tensor method to each member.

    Raises
    ------
    AttributeError
        If `name` is not a torch.Tensor method.
    """
    # Guard clause: anything that is not a tensor method is a real miss.
    if not hasattr(torch.Tensor, name):
        msg = f"'{self.__class__.__name__}' object has no attribute '{name}'"
        raise AttributeError(msg)

    def broadcast(*args: Any, **kwargs: Any) -> CategoricalStoch:
        # Apply the same tensor method, with the same arguments, to every
        # member and repackage the results.
        return CategoricalStoch(
            getattr(self.logits, name)(*args, **kwargs),
            getattr(self.probs, name)(*args, **kwargs),
            getattr(self.stoch, name)(*args, **kwargs),
        )

    return broadcast

__getitem__

__getitem__(idx)

インデックスアクセス.

Parameters:

Name Type Description Default
idx int or slice or tuple

インデックス指定。

required

Returns:

Type Description
CategoricalStoch

指定されたインデックスに対応するCategoricalStoch

Source code in src/ml_networks/torch/distributions.py
def __getitem__(self, idx: int | slice | tuple) -> CategoricalStoch:
    """Index into every member at once.

    Parameters
    ----------
    idx : int or slice or tuple
        Index specification.

    Returns
    -------
    CategoricalStoch
        The `CategoricalStoch` at the given index.
    """
    logits, probs, stoch = self.logits[idx], self.probs[idx], self.stoch[idx]
    return CategoricalStoch(logits, probs, stoch)

__len__

__len__()

長さを返す.

Returns:

Type Description
int

バッチ次元の長さ。

Source code in src/ml_networks/torch/distributions.py
def __len__(self) -> int:
    """Return the batch-dimension length.

    Returns
    -------
    int
        Length of the leading (batch) dimension of `stoch`.
    """
    batch_dim, *_ = self.stoch.shape
    return batch_dim

__post_init__

__post_init__()

初期化後の処理.

Raises:

Type Description
ValueError

logits と probs のshapeが異なる場合、あるいは probs が [0, 1] の範囲外、または和が 1 から大きくずれている場合。

Source code in src/ml_networks/torch/distributions.py
def __post_init__(self) -> None:
    """Validate the members after dataclass initialization.

    Raises
    ------
    ValueError
        If `logits` and `probs` shapes differ, if `probs` leaves [0, 1],
        or if `probs` does not sum to 1 along the last dimension.
    """
    if self.logits.shape != self.probs.shape:
        raise ValueError(
            f"logits.shape {self.logits.shape} and probs.shape {self.probs.shape} must be the same."
        )
    if bool((self.probs < 0).any()) or bool((self.probs > 1).any()):
        raise ValueError("probs must be in the range [0, 1].")
    if (self.probs.sum(dim=-1) - 1).abs().max() > 1e-6:
        raise ValueError("probs must sum to 1.")

get_distribution

get_distribution(independent=1)
Source code in src/ml_networks/torch/distributions.py
def get_distribution(self, independent: int = 1) -> D.Independent:
    """Return an Independent one-hot categorical built from `probs`."""
    base = D.OneHotCategoricalStraightThrough(self.probs)
    return D.Independent(base, independent)

save

save(path)

Save the parameters of the categorical distribution to the specified path.

Parameters:

Name Type Description Default
path str

Path to save the parameters.

required
Source code in src/ml_networks/torch/distributions.py
def save(self, path: str) -> None:
    """
    Save the parameters of the categorical distribution to the specified path.

    Parameters
    ----------
    path : str
        Path to save the parameters (directory, created if missing).
    """
    os.makedirs(path, exist_ok=True)

    # One compressed file per member, detached and moved to host memory.
    members = {"logits": self.logits, "probs": self.probs, "stoch": self.stoch}
    for name, tensor in members.items():
        save_blosc2(f"{path}/{name}.blosc2", tensor.detach().clone().cpu().numpy())

squeeze

squeeze(dim)

Squeeze the parameters of the categorical distribution.

Parameters:

Name Type Description Default
dim int

Dimension to squeeze.

required

Returns:

Type Description
CategoricalStoch

Squeezed categorical distribution.

Source code in src/ml_networks/torch/distributions.py
def squeeze(self, dim: int) -> CategoricalStoch:
    """
    Squeeze the given dimension out of every member.

    Parameters
    ----------
    dim : int
        Dimension to squeeze.

    Returns
    -------
    CategoricalStoch
        Squeezed categorical distribution.

    """
    squeezed = (t.squeeze(dim) for t in (self.logits, self.probs, self.stoch))
    return CategoricalStoch(*squeezed)

unsqueeze

unsqueeze(dim)

Unsqueeze the parameters of the categorical distribution.

Parameters:

Name Type Description Default
dim int

Dimension to unsqueeze.

required

Returns:

Type Description
CategoricalStoch

Unsqueezed categorical distribution.

Source code in src/ml_networks/torch/distributions.py
def unsqueeze(self, dim: int) -> CategoricalStoch:
    """
    Insert a size-one dimension into every member.

    Parameters
    ----------
    dim : int
        Dimension to unsqueeze.

    Returns
    -------
    CategoricalStoch
        Unsqueezed categorical distribution.

    """
    expanded = (t.unsqueeze(dim) for t in (self.logits, self.probs, self.stoch))
    return CategoricalStoch(*expanded)

損失関数

focal_loss

focal_loss(prediction, target, gamma=2.0, sum_dim=-1)

Focal loss function. Mainly for multi-class classification.

Reference

Focal Loss for Dense Object Detection https://arxiv.org/abs/1708.02002

Parameters:

Name Type Description Default
prediction Tensor

The predicted tensor. This should be before softmax.

required
target Tensor

The target tensor.

required
gamma float

The gamma parameter. Default is 2.0.

2.0
sum_dim int

The dimension to sum the loss. Default is -1.

-1

Returns:

Type Description
Tensor

The focal loss.

Source code in src/ml_networks/torch/loss.py
def focal_loss(
    prediction: torch.Tensor,
    target: torch.Tensor,
    gamma: float = 2.0,
    sum_dim: int = -1,
) -> torch.Tensor:
    """
    Focal loss function. Mainly for multi-class classification.

    Reference
    ---------
    Focal Loss for Dense Object Detection
    https://arxiv.org/abs/1708.02002

    Parameters
    ----------
    prediction : torch.Tensor
        The predicted tensor. This should be before softmax.
    target : torch.Tensor
        The target tensor.
    gamma : float
        The gamma parameter. Default is 2.0.
    sum_dim : int
        The dimension to sum the loss. Default is -1.

    Returns
    -------
    torch.Tensor
        The focal loss.

    """
    # Move the class dimension into position 1, as nll_loss / cross_entropy expect.
    prediction = prediction.unsqueeze(1).transpose(sum_dim, 1).squeeze(-1)
    if not gamma:
        # gamma == 0 degenerates to plain cross-entropy.
        per_sample = F.cross_entropy(prediction, target, reduction="none")
    else:
        log_prob = F.log_softmax(prediction, dim=1)
        prob = log_prob.exp()
        # Down-weight well-classified examples by (1 - p)^gamma.
        per_sample = F.nll_loss((1 - prob) ** gamma * log_prob, target, reduction="none")
    return per_sample.mean(0).sum()

charbonnier

charbonnier(prediction, target, epsilon=0.001, alpha=1, sum_dim=None)

Charbonnier loss function.

Reference

A General and Adaptive Robust Loss Function http://arxiv.org/abs/1701.03077

Parameters:

Name Type Description Default
prediction Tensor

The predicted tensor.

required
target Tensor

The target tensor.

required
epsilon float

A small value to avoid division by zero. Default is 1e-3.

0.001
alpha float

The alpha parameter. Default is 1.

1
sum_dim int | list[int] | tuple[int, ...] | None

The dimension to sum the loss. Default is None (sums over [-1, -2, -3]).

None

Returns:

Type Description
Tensor

The Charbonnier loss.

Source code in src/ml_networks/torch/loss.py
def charbonnier(
    prediction: torch.Tensor,
    target: torch.Tensor,
    epsilon: float = 1e-3,
    alpha: float = 1,
    sum_dim: int | list[int] | tuple[int, ...] | None = None,
) -> torch.Tensor:
    """
    Charbonnier loss function.

    Reference
    ---------
    A General and Adaptive Robust Loss Function
    http://arxiv.org/abs/1701.03077

    Parameters
    ----------
    prediction : torch.Tensor
        The predicted tensor.
    target : torch.Tensor
        The target tensor.
    epsilon : float
        A small value to avoid division by zero. Default is 1e-3.
    alpha : float
        The alpha parameter. Default is 1.
    sum_dim : int | list[int] | tuple[int, ...] | None
        The dimension to sum the loss. Default is None (sums over [-1, -2, -3]).

    Returns
    -------
    torch.Tensor
        The Charbonnier loss.

    """
    if sum_dim is None:
        sum_dim = [-1, -2, -3]
    x = prediction - target
    loss = (x**2 + epsilon**2) ** (alpha / 2)
    return torch.sum(loss, dim=sum_dim)

FocalFrequencyLoss

FocalFrequencyLoss(loss_weight=1.0, alpha=1.0, patch_factor=1, ave_spectrum=False, log_matrix=False, batch_matrix=False)

The torch.nn.Module class that implements focal frequency loss.

A frequency domain loss function for optimizing generative models.

Reference

Focal Frequency Loss for Image Reconstruction and Synthesis. In ICCV 2021. https://arxiv.org/pdf/2012.12821.pdf

Parameters:

Name Type Description Default
loss_weight float

weight for focal frequency loss. Default: 1.0

1.0
alpha float

the scaling factor alpha of the spectrum weight matrix for flexibility. Default: 1.0

1.0
patch_factor int

the factor to crop image patches for patch-based focal frequency loss. Default: 1

1
ave_spectrum bool

whether to use minibatch average spectrum. Default: False

False
log_matrix bool

whether to adjust the spectrum weight matrix by logarithm. Default: False

False
batch_matrix bool

whether to calculate the spectrum weight matrix using batch-based statistics. Default: False

False
Source code in src/ml_networks/torch/loss.py
def __init__(
    self,
    loss_weight: float = 1.0,
    alpha: float = 1.0,
    patch_factor: int = 1,
    ave_spectrum: bool = False,
    log_matrix: bool = False,
    batch_matrix: bool = False,
) -> None:
    """Store the focal-frequency-loss hyperparameters.

    Parameters
    ----------
    loss_weight : float
        Multiplier applied to the final loss value. Default: 1.0.
    alpha : float
        Scaling exponent of the spectrum weight matrix. Default: 1.0.
    patch_factor : int
        Grid factor used to crop images into patches before the DFT.
        Default: 1 (whole image).
    ave_spectrum : bool
        Whether to average the spectrum over the minibatch. Default: False.
    log_matrix : bool
        Whether to log-scale the spectrum weight matrix. Default: False.
    batch_matrix : bool
        Whether to normalize the weight matrix with batch-wide statistics
        instead of per-image statistics. Default: False.
    """
    # NOTE(review): if this class subclasses nn.Module, super().__init__()
    # is missing here -- confirm against the class declaration.
    self.loss_weight = loss_weight
    self.alpha = alpha
    self.patch_factor = patch_factor
    self.ave_spectrum = ave_spectrum
    self.log_matrix = log_matrix
    self.batch_matrix = batch_matrix

Attributes

alpha instance-attribute

alpha = alpha

ave_spectrum instance-attribute

ave_spectrum = ave_spectrum

batch_matrix instance-attribute

batch_matrix = batch_matrix

log_matrix instance-attribute

log_matrix = log_matrix

loss_weight instance-attribute

loss_weight = loss_weight

patch_factor instance-attribute

patch_factor = patch_factor

Functions

__call__

__call__(pred, target, matrix=None, mean_batch=True)

Forward function to calculate focal frequency loss.

Parameters:

Name Type Description Default
pred Tensor

of shape (N, C, H, W). Predicted tensor.

required
target Tensor

of shape (N, C, H, W). Target tensor.

required
matrix Tensor | None

Default: None (If set to None: calculated online, dynamic).

None
mean_batch bool

Whether to average over batch dimension.

True

Returns:

Type Description
Tensor

The focal frequency loss.

Source code in src/ml_networks/torch/loss.py
def __call__(
    self,
    pred: torch.Tensor,
    target: torch.Tensor,
    matrix: torch.Tensor | None = None,
    mean_batch: bool = True,
) -> torch.Tensor:
    """Compute the focal frequency loss between ``pred`` and ``target``.

    Parameters
    ----------
    pred: torch.Tensor
        of shape (N, C, H, W). Predicted tensor.
    target: torch.Tensor
        of shape (N, C, H, W). Target tensor.
    matrix: torch.Tensor | None
        Predefined spectrum weight matrix; None computes it online.
    mean_batch: bool
        Whether to average over the batch dimension.

    Returns
    -------
    torch.Tensor
        The focal frequency loss.
    """
    if target.shape != pred.shape:
        target = target.expand_as(pred)

    # A 5-D input is folded along its two leading dims into a 4-D batch
    # for the patch/FFT machinery, and unfolded again at the end.
    had_extra_dim = pred.ndim == 5
    if had_extra_dim:
        leading_shape = pred.shape[:2]
        pred = pred.flatten(0, 1)
        target = target.flatten(0, 1)

    freq_pred = self.tensor2freq(pred)
    freq_target = self.tensor2freq(target)

    # Optionally replace per-sample spectra with the minibatch average.
    if self.ave_spectrum:
        freq_pred = torch.mean(freq_pred, 0, keepdim=True)
        freq_target = torch.mean(freq_target, 0, keepdim=True)

    loss = self.loss_formulation(freq_pred, freq_target, matrix, mean_batch) * self.loss_weight
    if had_extra_dim and not mean_batch:
        loss = loss.reshape(leading_shape)
    return loss

loss_formulation

loss_formulation(recon_freq, real_freq, matrix=None, mean_batch=True)
Source code in src/ml_networks/torch/loss.py
def loss_formulation(
    self,
    recon_freq: torch.Tensor,
    real_freq: torch.Tensor,
    matrix: torch.Tensor | None = None,
    mean_batch: bool = True,
) -> torch.Tensor:
    # spectrum weight matrix
    if matrix is not None:
        # if the matrix is predefined
        weight_matrix = matrix.detach()
    else:
        # if the matrix is calculated online: continuous, dynamic, based on current Euclidean distance
        matrix_tmp = (recon_freq - real_freq) ** 2
        matrix_tmp = torch.sqrt(matrix_tmp[..., 0] + matrix_tmp[..., 1]) ** self.alpha

        # whether to adjust the spectrum weight matrix by logarithm
        if self.log_matrix:
            matrix_tmp = torch.log(matrix_tmp + 1.0)

        # whether to calculate the spectrum weight matrix using batch-based statistics
        if self.batch_matrix:
            matrix_tmp = matrix_tmp / matrix_tmp.max()
        else:
            matrix_tmp = matrix_tmp / matrix_tmp.max(-1).values.max(-1).values[:, :, :, None, None]

        matrix_tmp[torch.isnan(matrix_tmp)] = 0.0
        matrix_tmp = torch.clamp(matrix_tmp, min=0.0, max=1.0)
        weight_matrix = matrix_tmp.clone().detach()

    min_val = weight_matrix.min().item()
    max_val = weight_matrix.max().item()
    assert min_val >= 0, f"The values of spectrum weight matrix should be >= 0, but got Min: {min_val:.10f}"
    assert max_val <= 1, f"The values of spectrum weight matrix should be <= 1, but got Max: {max_val:.10f}"

    # frequency distance using (squared) Euclidean distance
    tmp = (recon_freq - real_freq) ** 2
    freq_distance = tmp[..., 0] + tmp[..., 1]

    # dynamic spectrum weighting (Hadamard product)
    loss = weight_matrix * freq_distance
    loss = loss.sum(dim=[-1, -2, -3])
    if mean_batch:
        loss = loss.mean()
    return loss

tensor2freq

tensor2freq(x)
Source code in src/ml_networks/torch/loss.py
def tensor2freq(self, x: torch.Tensor) -> torch.Tensor:
    """Crop ``x`` into patches and return each patch's 2D Fourier spectrum.

    Parameters
    ----------
    x : torch.Tensor
        Image batch of shape (N, C, H, W); H and W must each be divisible
        by ``self.patch_factor``.

    Returns
    -------
    torch.Tensor
        Spectrum of shape (N, patch_factor**2, C, patch_h, patch_w, 2),
        where the last axis holds the real and imaginary parts.
    """
    # Crop the image into a patch_factor x patch_factor grid of patches.
    patch_factor = self.patch_factor
    _, _, h, w = x.shape
    # Fixed: the original messages stated the divisibility relation backwards
    # (the image side must be divisible by the patch factor, not vice versa).
    assert h % patch_factor == 0, "Image height should be divisible by patch factor"
    assert w % patch_factor == 0, "Image width should be divisible by patch factor"
    patch_h = h // patch_factor
    patch_w = w // patch_factor
    patch_list: list[torch.Tensor] = [
        x[:, :, i * patch_h : (i + 1) * patch_h, j * patch_w : (j + 1) * patch_w]
        for i in range(patch_factor)
        for j in range(patch_factor)
    ]

    # Stack the patches along a new dim -> (N, P, C, patch_h, patch_w).
    y = torch.stack(patch_list, 1)

    # Perform the 2D DFT (orthonormalized). IS_HIGH_VERSION presumably flags
    # a torch build that ships torch.fft; torch.rfft is the legacy fallback.
    if IS_HIGH_VERSION:
        freq = torch.fft.fft2(y, norm="ortho")
        freq = torch.stack([freq.real, freq.imag], -1)
    else:
        freq = torch.rfft(y, 2, onesided=False, normalized=True)  # type: ignore[attr-defined]
    return freq

ユーティリティ

get_optimizer

get_optimizer(param, name, **kwargs)

Get optimizer from torch.optim or pytorch_optimizer.

Args:

param : Iterator[nn.Parameter]
    Parameters of models to optimize.
name : str
    Optimizer name.
kwargs : dict
    Optimizer arguments (settings).

Returns:

Type Description
Optimizer

Examples:

>>> get_optimizer([nn.Parameter(torch.randn(1, 3))], "Adam", lr=0.01)
Adam (
Parameter Group 0
    amsgrad: False
    betas: (0.9, 0.999)
    capturable: False
    differentiable: False
    eps: 1e-08
    foreach: None
    fused: None
    lr: 0.01
    maximize: False
    weight_decay: 0
)
Source code in src/ml_networks/torch/torch_utils.py
def get_optimizer(
    param: Iterator[nn.Parameter],
    name: str,
    **kwargs: float | str | bool,
) -> torch.optim.Optimizer:
    """
    Look up an optimizer class by name and instantiate it.

    The name is resolved against ``schedulefree``, then ``torch.optim``,
    then ``pytorch_optimizer``, in that order.

    Args:
    -----
    param : Iterator[nn.Parameter]
        Parameters of models to optimize.
    name : str
        Optimizer name (case-sensitive).
    kwargs : dict
        Keyword arguments forwarded to the optimizer constructor.

    Returns
    -------
    torch.optim.Optimizer

    Examples
    --------
    >>> get_optimizer([nn.Parameter(torch.randn(1, 3))], "Adam", lr=0.01)
    Adam (
    Parameter Group 0
        amsgrad: False
        betas: (0.9, 0.999)
        capturable: False
        differentiable: False
        eps: 1e-08
        foreach: None
        fused: None
        lr: 0.01
        maximize: False
        weight_decay: 0
    )
    """
    # Resolution order matters: schedulefree shadows torch.optim, which
    # shadows pytorch_optimizer.
    for module in (schedulefree, torch.optim, pytorch_optimizer):
        if hasattr(module, name):
            return getattr(module, name)(param, **kwargs)
    msg = f"Optimizer {name} is not implemented in torch.optim or pytorch_optimizer, schedulefree. "
    msg += "Please check the name and capitalization."
    raise NotImplementedError(msg)

torch_fix_seed

torch_fix_seed(seed=42)

Fixes random seeds for reproducibility.

References
  • https://qiita.com/north_redwing/items/1e153139125d37829d2d
Source code in src/ml_networks/torch/torch_utils.py
def torch_fix_seed(seed: int = 42) -> None:
    """
    Fix random seeds so runs are reproducible.

    Parameters
    ----------
    seed : int
        Seed value applied to the RNGs. Default is 42.

    References
    ----------
    - https://qiita.com/north_redwing/items/1e153139125d37829d2d
    """
    random.seed(seed)
    # NOTE(review): pl.seed_everything presumably also seeds numpy/torch and
    # dataloader workers (workers=True) -- confirm against the Lightning docs.
    pl.seed_everything(seed, workers=True)
    # NOTE(review): matmul precision is unrelated to seeding; it appears to be
    # set here for convenience -- confirm this side effect is intended.
    torch.set_float32_matmul_precision("medium")
    # Disable cuDNN autotuning and force deterministic kernels.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True

save_blosc2

save_blosc2(path, x)

Save numpy array with blosc2 compression.

Args:

path : str
    Path to save.
x : np.ndarray
    Numpy array to save.

Examples:

>>> save_blosc2("test.blosc2", np.random.randn(10, 10))
Source code in src/ml_networks/utils.py
def save_blosc2(path: str, x: np.ndarray) -> None:
    """Save numpy array with blosc2 compression.

    Args:
    -----
    path : str
        Path to save.
    x : np.ndarray
        Numpy array to save.

    Examples
    --------
    >>> save_blosc2("test.blosc2", np.random.randn(10, 10))

    """
    # Compress first, then write the packed bytes in one shot.
    packed = blosc2.pack_array2(x)
    Path(path).write_bytes(packed)

load_blosc2

load_blosc2(path)

Load numpy array with blosc2 compression.

Args:

path : str
    Path to load.

Returns:

Type Description
ndarray

Numpy array.

Examples:

>>> data = load_blosc2("test.blosc2")
>>> type(data)
<class 'numpy.ndarray'>
Source code in src/ml_networks/utils.py
def load_blosc2(path: str) -> np.ndarray:
    """Load numpy array with blosc2 compression.

    Args:
    -----
    path : str
        Path to load.

    Returns
    -------
    np.ndarray
        Numpy array.

    Examples
    --------
    >>> data = load_blosc2("test.blosc2")
    >>> type(data)
    <class 'numpy.ndarray'>
    """
    # Read the packed bytes, then let blosc2 reconstruct the array.
    raw = Path(path).read_bytes()
    return blosc2.unpack_array2(raw)