Skip to content

ビジョン

ビジョン関連のモジュール(Encoder、Decoderなど)を提供します。

ml_networks.torch.vision(PyTorch)とml_networks.jax.vision(JAX)の両方で提供されています。

Encoder

Encoder

Encoder(feature_dim, obs_shape, backbone_cfg, fc_cfg=None)

Bases: BaseModule

Encoder with various architectures.

Parameters:

Name Type Description Default
feature_dim int | tuple[int, int, int]

Dimension of the feature tensor. If int, the Encoder includes a fully connected head that maps the backbone output to a vector of this size. Otherwise, the Encoder has no fully connected head and returns the backbone output directly; the tuple must then equal the backbone output shape (last_channel, *conved_shape).

required
obs_shape tuple[int, int, int]

shape of the input tensor

required
backbone_cfg ViTConfig | ConvNetConfig | ResNetConfig

configuration of the network

required
fc_cfg MLPConfig | LinearConfig | AdaptiveAveragePoolingConfig | SpatialSoftmaxConfig | None

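Configuration of the fully connected head. If feature_dim is an int, fc_cfg must be provided (with a ViT backbone it may be None only when feature_dim equals the transformer d_model). If feature_dim is a tuple, fc_cfg should be None: the MLP, linear, adaptive-average-pooling and spatial-softmax configs all require an int feature_dim. Default is None.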

None

Examples:

>>> feature_dim = 128
>>> obs_shape = (3, 64, 64)
>>> cfg = ConvNetConfig(
...     channels=[16, 32, 64],
...     conv_cfgs=[
...         ConvConfig(kernel_size=3, stride=2, padding=1, activation="ReLU", norm="batch", dropout=0.0),
...         ConvConfig(kernel_size=3, stride=2, padding=1, activation="ReLU", norm="batch", dropout=0.0),
...         ConvConfig(kernel_size=3, stride=2, padding=1, activation="ReLU", norm="batch", dropout=0.0),
...     ]
... )
>>> fc_cfg = LinearConfig(
...     activation="ReLU",
...     bias=True
... )
>>> encoder = Encoder(feature_dim, obs_shape, cfg, fc_cfg)
>>> x = torch.randn(2, *obs_shape)
>>> y = encoder(x)
>>> y.shape
torch.Size([2, 128])
>>> encoder
Encoder(
  (encoder): ConvNet(
    (conv): Sequential(
      (0): ConvNormActivation(
        (conv): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (norm): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (pixel_shuffle): Identity()
        (activation): Activation(
          (activation): ReLU()
        )
        (dropout): Identity()
      )
      (1): ConvNormActivation(
        (conv): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (norm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (pixel_shuffle): Identity()
        (activation): Activation(
          (activation): ReLU()
        )
        (dropout): Identity()
      )
      (2): ConvNormActivation(
        (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (norm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (pixel_shuffle): Identity()
        (activation): Activation(
          (activation): ReLU()
        )
        (dropout): Identity()
      )
    )
  )
  (fc): Sequential(
    (0): Flatten(start_dim=1, end_dim=-1)
    (1): LinearNormActivation(
      (linear): Linear(in_features=4096, out_features=128, bias=True)
      (norm): Identity()
      (activation): Activation(
        (activation): ReLU()
      )
      (dropout): Identity()
    )
  )
)
Source code in src/ml_networks/torch/vision.py
def __init__(
    self,
    feature_dim: int | tuple[int, int, int],
    obs_shape: tuple[int, int, int],
    backbone_cfg: ViTConfig | ConvNetConfig | ResNetConfig,
    fc_cfg: MLPConfig | LinearConfig | AdaptiveAveragePoolingConfig | SpatialSoftmaxConfig | None = None,
) -> None:
    # Build the backbone (``self.encoder``) first, then the head (``self.fc``)
    # that maps the backbone output to ``feature_dim``.
    super().__init__()

    self.obs_shape = obs_shape

    self.encoder: nn.Module
    self.fc: nn.Module
    self._is_vit = isinstance(backbone_cfg, ViTConfig)
    if isinstance(backbone_cfg, ViTConfig):
        self.encoder = ViT(obs_shape, backbone_cfg)
    elif isinstance(backbone_cfg, ConvNetConfig):
        self.encoder = ConvNet(obs_shape, backbone_cfg)
    elif isinstance(backbone_cfg, ResNetConfig):
        self.encoder = ResNetPixUnshuffle(obs_shape, backbone_cfg)
    else:
        msg = f"{type(backbone_cfg)} is not implemented"
        raise NotImplementedError(msg)

    self.feature_dim = feature_dim

    if self._is_vit:
        # ViT encoder returns the CLS token (B, d_model); skip the convolutional reshape.
        assert isinstance(backbone_cfg, ViTConfig)
        d_model = backbone_cfg.transformer_cfg.d_model
        self.last_channel = d_model
        self.conved_shape = (1, 1)
        self.conved_size = d_model
        assert isinstance(feature_dim, int), "feature_dim must be int when using ViTConfig backbone"
        if isinstance(fc_cfg, MLPConfig):
            self.fc = MLPLayer(d_model, feature_dim, fc_cfg)
        elif isinstance(fc_cfg, LinearConfig):
            self.fc = LinearNormActivation(d_model, feature_dim, fc_cfg)
        elif fc_cfg is None:
            assert d_model == feature_dim, (
                f"feature_dim must equal transformer d_model when fc_cfg is None, got {feature_dim} vs {d_model}"
            )
            self.fc = nn.Identity()
        else:
            msg = f"fc_cfg type {type(fc_cfg)} is not supported with ViTConfig backbone"
            raise NotImplementedError(msg)
        return

    # Explicit casts to give the type checker concrete types for the backbone's geometry.
    self.conved_size = cast("int", self.encoder.conved_size)
    self.conved_shape = cast("tuple[int, int]", self.encoder.conved_shape)
    self.last_channel = cast("int", self.encoder.last_channel)

    if isinstance(feature_dim, int):
        assert fc_cfg is not None, "fc_cfg must be provided if feature_dim is provided"
    else:
        # Without a head, the backbone output must already have the requested shape.
        assert feature_dim == (self.last_channel, *self.conved_shape), (
            f"{feature_dim} != {(self.last_channel, *self.conved_shape)}"
        )

    def _projection(in_dim: int, out_dim: int, layer_cfg: object) -> nn.Module:
        # Optional projection after pooling / spatial softmax: Linear or MLP when
        # configured, otherwise Identity (the pooled features are used as-is).
        if isinstance(layer_cfg, LinearConfig):
            return LinearNormActivation(in_dim, out_dim, layer_cfg)
        if isinstance(layer_cfg, MLPConfig):
            return MLPLayer(in_dim, out_dim, layer_cfg)
        return nn.Identity()

    if isinstance(fc_cfg, MLPConfig):
        assert isinstance(feature_dim, int), "feature_dim must be int when using MLPConfig"
        self.fc = nn.Sequential(
            nn.Flatten(),
            MLPLayer(self.conved_size, feature_dim, fc_cfg),
        )
    elif isinstance(fc_cfg, LinearConfig):
        assert isinstance(feature_dim, int), "feature_dim must be int when using LinearConfig"
        self.fc = nn.Sequential(
            nn.Flatten(),
            LinearNormActivation(self.conved_size, feature_dim, fc_cfg),
        )
    elif isinstance(fc_cfg, AdaptiveAveragePoolingConfig):
        assert isinstance(feature_dim, int), "feature_dim must be int when using AdaptiveAveragePoolingConfig"
        output_size = fc_cfg.output_size
        # nn.AdaptiveAvgPool2d treats an int output_size as a square H x H target,
        # so the number of pooled cells is output_size**2, not output_size.
        pooled_cells = output_size**2 if isinstance(output_size, int) else int(np.prod(output_size))
        pooled_dim = int(self.last_channel * pooled_cells)
        self.fc = nn.Sequential(
            nn.AdaptiveAvgPool2d(output_size),
            nn.Flatten(),
            _projection(pooled_dim, feature_dim, fc_cfg.additional_layer),
        )
        if fc_cfg.additional_layer is None:
            # No projection: the output is the flattened pooled map.
            self.feature_dim = pooled_dim

    elif isinstance(fc_cfg, SpatialSoftmaxConfig):
        assert isinstance(self.feature_dim, int), "feature_dim must be int when using SpatialSoftmaxConfig"
        self.fc = nn.Sequential(
            SpatialSoftmax(fc_cfg),
            nn.Flatten(),
            _projection(self.last_channel * 2, self.feature_dim, fc_cfg.additional_layer),
        )
        if fc_cfg.additional_layer is None:
            # No projection: the output has last_channel * 2 features.
            self.feature_dim = self.last_channel * 2
    else:
        self.fc = nn.Identity()

Attributes

conved_shape instance-attribute

conved_shape = cast('tuple[int, int]', conved_shape)

conved_size instance-attribute

conved_size = cast('int', conved_size)

encoder instance-attribute

encoder

fc instance-attribute

fc

feature_dim instance-attribute

feature_dim = feature_dim

last_channel instance-attribute

last_channel = cast('int', last_channel)

obs_shape instance-attribute

obs_shape = obs_shape

Functions

forward

forward(x)

Forward pass.

Parameters:

Name Type Description Default
x Tensor

input tensor of shape (batch_size, *obs_shape)

required

Returns:

Type Description
Tensor

output tensor of shape (batch_size, *feature_dim)

Source code in src/ml_networks/torch/vision.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Encode a (possibly multi-batch-dimensional) stack of observations.

    Parameters
    ----------
    x: torch.Tensor
        Tensor whose trailing three dimensions are ``obs_shape``; any leading
        dimensions are treated as batch dimensions.

    Returns
    -------
    torch.Tensor
        Encoded features with the original leading dimensions restored,
        i.e. ``(*batch_dims, *feature_dim)``.
    """
    leading_dims = x.shape[:-3]

    # Collapse every leading dimension into one batch axis for the backbone.
    features = self.encoder(x.reshape([-1, *self.obs_shape]))
    if not self._is_vit:
        # Convolutional backbones: make the (C, H, W) layout explicit for the head.
        features = features.view(-1, self.last_channel, *self.conved_shape)
    # ViT backbones already return (B, d_model) and go straight into the head.
    projected = self.fc(features)

    return projected.reshape([*leading_dims, *projected.shape[1:]])

Decoder

Decoder

Decoder(feature_dim, obs_shape, backbone_cfg, fc_cfg=None)

Bases: BaseModule

Decoder with various architectures.

Parameters:

Name Type Description Default
feature_dim int | tuple[int, int, int]

Dimension of the feature tensor. If int, the Decoder includes a fully connected layer that maps the feature vector to the backbone's input size. Otherwise, the Decoder has no fully connected layer and feeds the input directly to the backbone network; the tuple must then equal the backbone's input shape.

required
obs_shape tuple[int, int, int]

shape of the output tensor

required
backbone_cfg ConvNetConfig | ViTConfig | ResNetConfig

configuration of the network

required
fc_cfg MLPConfig | LinearConfig | None

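Configuration of the fully connected layer. If feature_dim is an int, fc_cfg must be provided (with a ViT backbone it may be None only when feature_dim equals the transformer d_model). If feature_dim is a tuple, fc_cfg should be None, because MLPConfig and LinearConfig require an int feature_dim. Default is None.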

None

Examples:

>>> feature_dim = 128
>>> obs_shape = (3, 64, 64)
>>> cfg = ConvNetConfig(
...     channels=[64, 32, 16],
...     conv_cfgs=[
...         ConvConfig(kernel_size=4, stride=2, padding=1, activation="ReLU", norm="batch", dropout=0.0),
...         ConvConfig(kernel_size=4, stride=2, padding=1, activation="ReLU", norm="batch", dropout=0.0),
...         ConvConfig(kernel_size=4, stride=2, padding=1, activation="ReLU", norm="batch", dropout=0.0),
...     ]
... )
>>> fc_cfg = MLPConfig(
...     hidden_dim=256,
...     n_layers=2,
...     output_activation= "ReLU",
...     linear_cfg= LinearConfig(
...         activation= "ReLU",
...         bias= True
...     )
... )
>>> decoder = Decoder(feature_dim, obs_shape, cfg, fc_cfg)
>>> x = torch.randn(2, feature_dim)
>>> y = decoder(x)
>>> y.shape
torch.Size([2, 3, 64, 64])
>>> decoder
Decoder(
  (fc): MLPLayer(
    (dense): Sequential(
      (0): LinearNormActivation(
        (linear): Linear(in_features=128, out_features=256, bias=True)
        (norm): Identity()
        (activation): Activation(
          (activation): ReLU()
        )
        (dropout): Identity()
      )
      (1): LinearNormActivation(
        (linear): Linear(in_features=256, out_features=256, bias=True)
        (norm): Identity()
        (activation): Activation(
          (activation): ReLU()
        )
        (dropout): Identity()
      )
      (2): LinearNormActivation(
        (linear): Linear(in_features=256, out_features=1024, bias=True)
        (norm): Identity()
        (activation): Activation(
          (activation): ReLU()
        )
        (dropout): Identity()
      )
    )
  )
  (decoder): ConvTranspose(
    (first_conv): Conv2d(16, 64, kernel_size=(1, 1), stride=(1, 1))
    (conv): Sequential(
      (0): ConvTransposeNormActivation(
        (conv): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (norm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activation): Activation(
          (activation): ReLU()
        )
        (dropout): Identity()
      )
      (1): ConvTransposeNormActivation(
        (conv): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (norm): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activation): Activation(
          (activation): ReLU()
        )
        (dropout): Identity()
      )
      (2): ConvTransposeNormActivation(
        (conv): ConvTranspose2d(16, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
        (norm): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
        (activation): Activation(
          (activation): ReLU()
        )
        (dropout): Identity()
      )
    )
  )
)
Source code in src/ml_networks/torch/vision.py
def __init__(
    self,
    feature_dim: int | tuple[int, int, int],
    obs_shape: tuple[int, int, int],
    backbone_cfg: ConvNetConfig | ViTConfig | ResNetConfig,
    fc_cfg: MLPConfig | LinearConfig | None = None,
) -> None:
    # Build ``self.fc`` (feature vector -> backbone input) and then the
    # backbone ``self.decoder`` (backbone input -> observation).
    super().__init__()

    self.obs_shape = obs_shape
    self.feature_dim = feature_dim
    self._is_vit = isinstance(backbone_cfg, ViTConfig)
    self.fc: nn.Module
    self.input_shape: tuple[int, ...]
    self.decoder: nn.Module

    if self._is_vit:
        assert isinstance(backbone_cfg, ViTConfig)
        d_model = backbone_cfg.transformer_cfg.d_model
        assert isinstance(feature_dim, int), "feature_dim must be int when using ViTConfig backbone"
        # The ViT decoder consumes a single (B, d_model) token, so the head maps
        # feature_dim -> d_model, or is skipped when the widths already agree.
        if fc_cfg is None:
            assert feature_dim == d_model, (
                f"feature_dim must equal transformer d_model when fc_cfg is None, got {feature_dim} vs {d_model}"
            )
            self.fc = nn.Identity()
        elif isinstance(fc_cfg, MLPConfig):
            self.fc = MLPLayer(feature_dim, d_model, fc_cfg)
        elif isinstance(fc_cfg, LinearConfig):
            self.fc = LinearNormActivation(feature_dim, d_model, fc_cfg)
        else:
            msg = f"fc_cfg type {type(fc_cfg)} is not supported with ViTConfig backbone"
            raise NotImplementedError(msg)
        self.input_shape = (d_model,)
        self.has_fc = True
        self.decoder = ViT(in_shape=(d_model,), obs_shape=obs_shape, cfg=backbone_cfg)
        return

    # Ask the convolutional backbone which (C, H, W) it needs to produce obs_shape.
    in_shape3: tuple[int, int, int]
    if isinstance(backbone_cfg, ConvNetConfig):
        in_shape3 = cast("tuple[int, int, int]", ConvTranspose.get_input_shape(obs_shape, backbone_cfg))
    elif isinstance(backbone_cfg, ResNetConfig):
        in_shape3 = cast("tuple[int, int, int]", ResNetPixShuffle.get_input_shape(obs_shape, backbone_cfg))
    else:
        msg = f"{type(backbone_cfg)} is not implemented"
        raise NotImplementedError(msg)
    self.input_shape = in_shape3

    # An int feature_dim means a projection layer is required; a tuple must
    # already match the backbone input shape.
    self.has_fc = isinstance(feature_dim, int)
    if self.has_fc:
        assert fc_cfg is not None, "fc_cfg must be provided if feature_dim is provided"
    else:
        assert feature_dim == in_shape3, f"{feature_dim} != {in_shape3}"

    flat_size = int(np.prod(in_shape3))
    if isinstance(fc_cfg, MLPConfig):
        assert isinstance(feature_dim, int), "feature_dim must be int when using MLPConfig"
        self.fc = MLPLayer(feature_dim, flat_size, fc_cfg)
    elif isinstance(fc_cfg, LinearConfig):
        assert isinstance(feature_dim, int), "feature_dim must be int when using LinearConfig"
        self.fc = LinearNormActivation(feature_dim, flat_size, fc_cfg)
    else:
        # fc_cfg is None (or unrecognised): no projection layer.
        self.fc = nn.Identity()

    if isinstance(backbone_cfg, ConvNetConfig):
        self.decoder = ConvTranspose(in_shape=in_shape3, obs_shape=obs_shape, cfg=backbone_cfg)
    elif isinstance(backbone_cfg, ResNetConfig):
        self.decoder = ResNetPixShuffle(in_shape=in_shape3, obs_shape=obs_shape, cfg=backbone_cfg)

Attributes

decoder instance-attribute

decoder

fc instance-attribute

fc

feature_dim instance-attribute

feature_dim = feature_dim

has_fc instance-attribute

has_fc = True

input_shape instance-attribute

input_shape = in_shape3

obs_shape instance-attribute

obs_shape = obs_shape

Functions

forward

forward(x)

Forward pass.

Parameters:

Name Type Description Default
x Tensor

input tensor of shape (batch_size, *feature_dim)

required

Returns:

Type Description
Tensor

output tensor of shape (batch_size, *obs_shape)

Source code in src/ml_networks/torch/vision.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Decode feature tensors back into observations.

    Parameters
    ----------
    x: torch.Tensor
        Feature tensor whose trailing dimensions are ``feature_dim`` (one
        trailing dimension when a projection layer is used or the backbone is a
        ViT, three otherwise); leading dimensions are batch dimensions.

    Returns
    -------
    torch.Tensor
        Decoded tensor of shape ``(*batch_dims, *obs_shape)``.

    """
    # ViT decoders always take a flat vector; has_fc is not consulted for them.
    feature_rank = 1 if (self._is_vit or self.has_fc) else 3
    leading_dims = x.shape[:-feature_rank]

    hidden = self.fc(x.reshape([-1, *x.shape[-feature_rank:]]))
    if not self._is_vit:
        # Convolutional backbones expect the explicit (C, H, W) input layout.
        hidden = hidden.reshape([-1, *self.input_shape])
    decoded = self.decoder(hidden)

    return decoded.reshape([*leading_dims, *self.obs_shape])

ConvNet

ConvNet

ConvNet(obs_shape, cfg)

Bases: Module

Convolutional Neural Network for Encoder.

Parameters:

Name Type Description Default
obs_shape tuple[int, int, int]

shape of input tensor

required
cfg ConvNetConfig

configuration of the network

required

Examples:

>>> obs_shape = (3, 64, 64)
>>> cfg = ConvNetConfig(
...     channels=[16, 32, 64],
...     conv_cfgs=[
...         ConvConfig(kernel_size=3, stride=2, padding=1, activation="ReLU", norm="batch", dropout=0.0),
...         ConvConfig(kernel_size=3, stride=2, padding=1, activation="ReLU", norm="batch", dropout=0.0),
...         ConvConfig(kernel_size=3, stride=2, padding=1, activation="ReLU", norm="batch", dropout=0.0),
...     ]
... )
>>> encoder = ConvNet(obs_shape, cfg)
>>> encoder
ConvNet(
  (conv): Sequential(
    (0): ConvNormActivation(
      (conv): Conv2d(3, 16, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (norm): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (pixel_shuffle): Identity()
      (activation): Activation(
        (activation): ReLU()
      )
      (dropout): Identity()
    )
    (1): ConvNormActivation(
      (conv): Conv2d(16, 32, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (norm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (pixel_shuffle): Identity()
      (activation): Activation(
        (activation): ReLU()
      )
      (dropout): Identity()
    )
    (2): ConvNormActivation(
      (conv): Conv2d(32, 64, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (norm): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (pixel_shuffle): Identity()
      (activation): Activation(
        (activation): ReLU()
      )
      (dropout): Identity()
    )
  )
)
>>> x = torch.randn(2, *obs_shape)
>>> y = encoder(x)
>>> y.shape
torch.Size([2, 64, 8, 8])
Source code in src/ml_networks/torch/vision.py
def __init__(
    self,
    obs_shape: tuple[int, int, int],
    cfg: ConvNetConfig,
) -> None:
    # Store the configuration, build the conv stack, and record its output width.
    super().__init__()

    self.obs_shape = obs_shape
    # Channel sequence: the observation's channel count followed by each layer's width.
    input_channels = obs_shape[0]
    self.channels = [input_channels, *cfg.channels]
    self.cfg = cfg

    # _build_conv() presumably reads self.channels / self.cfg set above.
    self.conv = self._build_conv()

    # Channel count produced by the last layer of the stack.
    self.last_channel = self.channels[-1]

Attributes

cfg instance-attribute

cfg = cfg

channels instance-attribute

channels = [obs_shape[0], *(channels)]

conv instance-attribute

conv = _build_conv()

conved_shape property

conved_shape

Get the shape of the output tensor after convolutional layers.

Returns:

Type Description
tuple[int, int]

shape of the output tensor

Examples:

>>> obs_shape = (3, 64, 64)
>>> cfg = ConvNetConfig(
...     channels=[64, 32, 16],
...     conv_cfgs=[
...         ConvConfig(kernel_size=3, stride=2, padding=1, activation="ReLU", norm="batch", dropout=0.0),
...         ConvConfig(kernel_size=3, stride=2, padding=1, activation="ReLU", norm="batch", dropout=0.0),
...         ConvConfig(kernel_size=3, stride=2, padding=1, activation="ReLU", norm="batch", dropout=0.0),
...     ]
... )
>>> encoder = ConvNet(obs_shape, cfg)
>>> encoder.conved_shape
(8, 8)

conved_size property

conved_size

Get the size of the output tensor after convolutional layers.

Returns:

Type Description
int

size of the output tensor

last_channel instance-attribute

last_channel = channels[-1]

obs_shape instance-attribute

obs_shape = obs_shape

Functions

forward

forward(x)

Forward pass.

Parameters:

Name Type Description Default
x Tensor

input tensor of shape (batch_size, *obs_shape)

required

Returns:

Type Description
Tensor

output tensor of shape (batch_size, self.last_channel, *self.conved_shape)

Source code in src/ml_networks/torch/vision.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Run the convolutional stack on a batch of observations.

    Parameters
    ----------
    x: torch.Tensor
        Input batch of shape (batch_size, *obs_shape).

    Returns
    -------
    torch.Tensor
        Feature map of shape (batch_size, self.last_channel, *self.conved_shape).

    """
    conv_stack = self.conv
    feature_map = conv_stack(x)
    return feature_map

ConvTranspose

ConvTranspose

ConvTranspose(in_shape, obs_shape, cfg)

Bases: Module

Convolutional Transpose Network for Decoder.

Parameters:

Name Type Description Default
in_shape tuple[int, int, int]

shape of input tensor

required
obs_shape tuple[int, int, int]

shape of output tensor

required
cfg ConvNetConfig

configuration of the network

required

Examples:

>>> in_shape = (128, 8, 8)
>>> obs_shape = (3, 64, 64)
>>> cfg = ConvNetConfig(
...     channels=[64, 32, 16],
...     conv_cfgs=[
...         ConvConfig(kernel_size=4, stride=2, padding=1, activation="ReLU", norm="batch", dropout=0.0),
...         ConvConfig(kernel_size=4, stride=2, padding=1, activation="ReLU", norm="batch", dropout=0.0),
...         ConvConfig(kernel_size=4, stride=2, padding=1, activation="ReLU", norm="batch", dropout=0.0),
...     ]
... )
>>> decoder = ConvTranspose(in_shape, obs_shape, cfg)
>>> decoder
ConvTranspose(
  (first_conv): Conv2d(128, 64, kernel_size=(1, 1), stride=(1, 1))
  (conv): Sequential(
    (0): ConvTransposeNormActivation(
      (conv): ConvTranspose2d(64, 32, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (norm): BatchNorm2d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activation): Activation(
        (activation): ReLU()
      )
      (dropout): Identity()
    )
    (1): ConvTransposeNormActivation(
      (conv): ConvTranspose2d(32, 16, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (norm): BatchNorm2d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activation): Activation(
        (activation): ReLU()
      )
      (dropout): Identity()
    )
    (2): ConvTransposeNormActivation(
      (conv): ConvTranspose2d(16, 3, kernel_size=(4, 4), stride=(2, 2), padding=(1, 1))
      (norm): BatchNorm2d(3, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (activation): Activation(
        (activation): ReLU()
      )
      (dropout): Identity()
    )
  )
)
>>> x = torch.randn(2, *in_shape)
>>> y = decoder(x)
>>> y.shape
torch.Size([2, 3, 64, 64])
Source code in src/ml_networks/torch/vision.py
def __init__(
    self,
    in_shape: tuple[int, int, int],
    obs_shape: tuple[int, int, int],
    cfg: ConvNetConfig,
) -> None:
    # Record the configuration, optionally add a 1x1 channel-matching conv,
    # trace the spatial size through every transposed conv, then build the stack.
    super().__init__()
    self.in_shape = in_shape
    self.obs_shape = obs_shape
    self.conv_out_shapes = []
    self.cfg = cfg
    # Output channel of each layer; the final layer emits the observation channels.
    self.channels = [*cfg.channels, obs_shape[0]]
    assert len(cfg.channels) == len(cfg.conv_cfgs)

    needs_projection = self.in_shape[0] != cfg.channels[0]
    if needs_projection:
        # 1x1 conv mapping the input channels to cfg.channels[0].
        self.first_conv = nn.Conv2d(in_shape[0], cfg.channels[0], kernel_size=1, stride=1, padding=0)
        self.init_channel = cfg.channels[0]
    else:
        self.init_channel = in_shape[0]
    self.have_first_conv = bool(needs_projection)

    # (H, W) after each transposed conv, starting from the input's (H, W).
    hw: tuple[int, int] = tuple(in_shape[1:])  # type: ignore[assignment]
    for layer in cfg.conv_cfgs:
        hw = tuple(  # type: ignore[assignment]
            conv_transpose_out_shape(hw, layer.padding, layer.kernel_size, layer.stride, layer.dilation),
        )
        self.conv_out_shapes.append(hw)
    # The last layer must land exactly on the requested observation size.
    assert self.conv_out_shapes[-1] == obs_shape[1:], f"{self.conv_out_shapes[-1]} != {obs_shape[1:]}"

    self.conv = self._build_conv()

Attributes

cfg instance-attribute

cfg = cfg

channels instance-attribute

channels = [*(channels), obs_shape[0]]

conv instance-attribute

conv = _build_conv()

conv_out_shapes instance-attribute

conv_out_shapes = []

first_conv instance-attribute

first_conv = Conv2d(in_shape[0], channels[0], kernel_size=1, stride=1, padding=0)

have_first_conv instance-attribute

have_first_conv = True

in_shape instance-attribute

in_shape = in_shape

init_channel instance-attribute

init_channel = channels[0]

obs_shape instance-attribute

obs_shape = obs_shape

Functions

forward

forward(z)

Forward pass.

Parameters:

Name Type Description Default
z Tensor

input tensor of shape (batch_size, *in_shape)

required

Returns:

Type Description
Tensor

output tensor of shape (batch_size, *obs_shape)

Source code in src/ml_networks/torch/vision.py
def forward(self, z: torch.Tensor) -> torch.Tensor:
    """
    Upsample a latent feature map into an observation.

    Parameters
    ----------
    z: torch.Tensor
        Latent batch of shape (batch_size, *in_shape).

    Returns
    -------
    torch.Tensor
        Reconstructed batch of shape (batch_size, *obs_shape).

    """
    # Apply the 1x1 channel projection only when one was created in __init__.
    h = self.first_conv(z) if self.have_first_conv else z
    return self.conv(h)

get_input_shape staticmethod

get_input_shape(obs_shape, cfg)

Get input shape of the decoder.

Parameters:

Name Type Description Default
obs_shape tuple[int, int, int]

shape of the output tensor

required
cfg ConvNetConfig

configuration of the network

required

Returns:

Type Description
tuple[int, int, int]

shape of the input tensor

Examples:

>>> obs_shape = (3, 64, 64)
>>> cfg = ConvNetConfig(
...     channels=[64, 32, 16],
...     conv_cfgs=[
...         ConvConfig(kernel_size=4, stride=2, padding=1, activation="ReLU", norm="batch", dropout=0.0),
...         ConvConfig(kernel_size=4, stride=2, padding=1, activation="ReLU", norm="batch", dropout=0.0),
...         ConvConfig(kernel_size=4, stride=2, padding=1, activation="ReLU", norm="batch", dropout=0.0),
...     ]
... )
>>> ConvTranspose.get_input_shape(obs_shape, cfg)
(16, 8, 8)
Source code in src/ml_networks/torch/vision.py
@staticmethod
def get_input_shape(obs_shape: tuple[int, int, int], cfg: ConvNetConfig) -> tuple[int, ...]:
    """
    Compute the input shape the decoder needs to produce ``obs_shape``.

    Parameters
    ----------
    obs_shape: tuple[int, int, int]
        Desired output shape (C, H, W).
    cfg: ConvNetConfig
        Network configuration; its conv layers are inverted from last to first.

    Returns
    -------
    tuple[int, int, int]
        ``(cfg.init_channel, H_in, W_in)``.

    Examples
    --------
    >>> obs_shape = (3, 64, 64)
    >>> cfg = ConvNetConfig(
    ...     channels=[64, 32, 16],
    ...     conv_cfgs=[
    ...         ConvConfig(kernel_size=4, stride=2, padding=1, activation="ReLU", norm="batch", dropout=0.0),
    ...         ConvConfig(kernel_size=4, stride=2, padding=1, activation="ReLU", norm="batch", dropout=0.0),
    ...         ConvConfig(kernel_size=4, stride=2, padding=1, activation="ReLU", norm="batch", dropout=0.0),
    ...     ]
    ... )
    >>> ConvTranspose.get_input_shape(obs_shape, cfg)
    (16, 8, 8)
    """
    spatial: tuple[int, int] = tuple(obs_shape[1:])  # type: ignore[assignment]
    # Walk the layers from output back to input, inverting each transposed conv.
    for layer in reversed(cfg.conv_cfgs):
        spatial = tuple(  # type: ignore[assignment]
            conv_transpose_in_shape(spatial, layer.padding, layer.kernel_size, layer.stride, layer.dilation),
        )
    return (cfg.init_channel, *spatial)

ResNetPixUnshuffle

ResNetPixUnshuffle

ResNetPixUnshuffle(obs_shape, cfg)

Bases: Module

ResNet with PixelUnshuffle for Encoder.

Parameters:

Name Type Description Default
obs_shape tuple[int, int, int]

shape of input tensor

required
cfg ResNetConfig

configuration of the network

required

Examples:

>>> obs_shape = (3, 64, 64)
>>> cfg = ResNetConfig(
...     conv_channel=64,
...     conv_kernel=3,
...     f_kernel=3,
...     conv_activation="ReLU",
...     out_activation="ReLU",
...     n_res_blocks=2,
...     scale_factor=2,
...     n_scaling=3,
...     norm="batch",
...     norm_cfg={},
...     dropout=0.0
... )
>>> encoder = ResNetPixUnshuffle(obs_shape, cfg)
>>> x = torch.randn(2, *obs_shape)
>>> y = encoder(x)
>>> y.shape
torch.Size([2, 64, 8, 8])
Source code in src/ml_networks/torch/vision.py
def __init__(
    self,
    obs_shape: tuple[int, int, int],
    cfg: ResNetConfig,
) -> None:
    """
    Build the PixelUnshuffle ResNet encoder.

    Parameters
    ----------
    obs_shape: tuple[int, int, int]
        shape of the input tensor; ``obs_shape[0]`` is the number of input channels
    cfg: ResNetConfig
        configuration of the network
    """
    import copy  # local import: each stage gets its own shallow copy of the ConvConfig

    super().__init__()

    self.obs_shape = obs_shape
    self.cfg = cfg

    first_cfg = ConvConfig(
        activation=cfg.conv_activation,
        kernel_size=cfg.f_kernel,
        stride=1,
        padding=cfg.f_kernel // 2,
        dilation=1,
        groups=1,
        bias=True,
        dropout=cfg.dropout,
        norm=cfg.norm,
        norm_cfg=cfg.norm_cfg,
        padding_mode=cfg.padding_mode,
    )
    # First layer
    self.conv1 = ConvNormActivation(self.obs_shape[0], cfg.conv_channel, first_cfg)

    # Every later stage derives its own shallow copy of ``first_cfg``. Assigning
    # ``first_cfg`` directly would alias one object across all stages, so each
    # stage's tweaks would silently rewrite the config that earlier layers were
    # built with.

    # downsampling
    downsample_cfg = copy.copy(first_cfg)
    downsample_cfg.kernel_size = cfg.conv_kernel
    downsample_cfg.padding = cfg.conv_kernel // 2
    # Negative scale factor here vs. positive in the PixelShuffle decoder;
    # presumably the sign selects unshuffle (down) vs shuffle (up) inside
    # ConvNormActivation -- TODO confirm.
    downsample_cfg.scale_factor = -cfg.scale_factor
    downsample: list[nn.Module] = [
        ConvNormActivation(cfg.conv_channel, cfg.conv_channel, downsample_cfg)
        for _ in range(cfg.n_scaling)
    ]
    self.downsample = nn.Sequential(*downsample)

    # Residual blocks (optionally interleaved with attention)
    res_blocks: list[nn.Module] = []
    for _ in range(cfg.n_res_blocks):
        res_blocks += [
            ResidualBlock(
                cfg.conv_channel,
                cfg.conv_kernel,
                cfg.conv_activation,
                cfg.norm,
                cfg.norm_cfg,
                cfg.dropout,
                cfg.padding_mode,
            ),
        ]
        if cfg.attention is not None:
            res_blocks += [Attention2d(cfg.conv_channel, nhead=None, attn_cfg=cfg.attention)]

    self.res_blocks = nn.Sequential(*res_blocks)

    # Second conv layer post residual blocks (no resampling)
    conv2_cfg = copy.copy(first_cfg)
    conv2_cfg.kernel_size = cfg.conv_kernel
    conv2_cfg.padding = cfg.conv_kernel // 2
    conv2_cfg.scale_factor = 0
    self.conv2 = ConvNormActivation(cfg.conv_channel, cfg.conv_channel, conv2_cfg)

    # Final output layer (no resampling).
    # NOTE(review): cfg.out_activation is not used by this encoder; the final
    # layer keeps cfg.conv_activation -- confirm this is intended.
    final_cfg = copy.copy(first_cfg)
    final_cfg.kernel_size = cfg.conv_kernel
    final_cfg.padding = cfg.conv_kernel // 2
    final_cfg.scale_factor = 0
    self.conv3 = ConvNormActivation(cfg.conv_channel, cfg.conv_channel, final_cfg)
    self.last_channel = cfg.conv_channel

Attributes

cfg instance-attribute

cfg = cfg

conv1 instance-attribute

conv1 = ConvNormActivation(obs_shape[0], conv_channel, first_cfg)

conv2 instance-attribute

conv2 = ConvNormActivation(conv_channel, conv_channel, cov2_cfg)

conv3 instance-attribute

conv3 = ConvNormActivation(conv_channel, conv_channel, final_cfg)

conved_shape property

conved_shape

Get the shape of the output tensor after convolutional layers.

Returns:

Type Description
tuple[int, int]

shape of the output tensor

conved_size property

conved_size

Get the size of the output tensor after convolutional layers.

Returns:

Type Description
int

size of the output tensor

downsample instance-attribute

downsample = Sequential(*downsample)

last_channel instance-attribute

last_channel = conv_channel

obs_shape instance-attribute

obs_shape = obs_shape

res_blocks instance-attribute

res_blocks = Sequential(*res_blocks)

Functions

forward

forward(x)

Forward pass.

Parameters:

Name Type Description Default
x Tensor

input tensor of shape (batch_size, *obs_shape)

required

Returns:

Type Description
Tensor

output tensor of shape (batch_size, self.last_channel, *self.conved_shape)

Source code in src/ml_networks/torch/vision.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Encode a batch of observations.

    The stem output is downsampled, passed through the residual stack and a
    second convolution, added back to the downsampled features (skip
    connection), and finally projected by ``conv3``.

    Parameters
    ----------
    x: torch.Tensor
        input tensor of shape (batch_size, *obs_shape)

    Returns
    -------
    torch.Tensor
        output tensor of shape (batch_size, self.last_channel, *self.conved_shape)

    """
    stem = self.conv1(x)
    reduced = self.downsample(stem)
    refined = self.conv2(self.res_blocks(reduced))
    return self.conv3(torch.add(reduced, refined))

ResNetPixShuffle

ResNetPixShuffle

ResNetPixShuffle(in_shape, obs_shape, cfg)

Bases: Module

ResNet with PixelShuffle.

Parameters:

Name Type Description Default
in_shape tuple[int, int, int]

shape of input tensor

required
obs_shape tuple[int, int, int]

shape of output tensor

required
cfg ResNetConfig

configuration of the network

required

Examples:

>>> in_shape = (128, 16, 16)
>>> obs_shape = (3, 64, 64)
>>> cfg = ResNetConfig(
...     conv_channel=64,
...     conv_kernel=3,
...     f_kernel=3,
...     conv_activation="ReLU",
...     out_activation="ReLU",
...     n_res_blocks=2,
...     scale_factor=2,
...     n_scaling=2,
...     norm="batch",
...     norm_cfg={},
...     dropout=0.0
... )
>>> decoder = ResNetPixShuffle(in_shape, obs_shape, cfg)
>>> x = torch.randn(2, *in_shape)
>>> y = decoder(x)
>>> y.shape
torch.Size([2, 3, 64, 64])
Source code in src/ml_networks/torch/vision.py
def __init__(
    self,
    in_shape: tuple[int, int, int],
    obs_shape: tuple[int, int, int],
    cfg: ResNetConfig,
) -> None:
    """
    Build the PixelShuffle ResNet decoder.

    Parameters
    ----------
    in_shape: tuple[int, int, int]
        shape of the input tensor ``(C_in, H_in, W_in)``
    obs_shape: tuple[int, int, int]
        shape of the output tensor ``(C, H, W)``
    cfg: ResNetConfig
        configuration of the network

    Raises
    ------
    AssertionError
        if ``H // cfg.scale_factor ** cfg.n_scaling`` or
        ``W // cfg.scale_factor ** cfg.n_scaling`` does not match ``in_shape[1:]``.
    """
    import copy  # local import: each stage gets its own shallow copy of the ConvConfig

    super().__init__()

    self.in_shape = in_shape
    self.obs_shape = obs_shape
    self.conv_channel = cfg.conv_channel
    self.conv_kernel = cfg.conv_kernel
    self.final_kernel = cfg.f_kernel
    self.conv_activation = cfg.conv_activation
    self.out_activation = cfg.out_activation
    self.n_res_blocks = cfg.n_res_blocks
    self.upscale_factor = cfg.scale_factor
    self.n_upsampling = cfg.n_scaling
    self.norm = cfg.norm
    self.norm_cfg = cfg.norm_cfg
    self.dropout = cfg.dropout

    # Total spatial magnification produced by the upsampling stack.
    self._scaling_factor = self.upscale_factor**self.n_upsampling

    height = obs_shape[1]
    width = obs_shape[2]

    out_channels = obs_shape[0]
    # The input's spatial size must equal the output size divided by the total
    # upsampling factor.
    self.input_height, self.input_width = height // self._scaling_factor, width // self._scaling_factor
    assert self.input_height == in_shape[1], f"{self.input_height} != {in_shape[1]}"
    assert self.input_width == in_shape[2], f"{self.input_width} != {in_shape[2]}"

    conv_cfg = ConvConfig(
        activation=self.conv_activation,
        kernel_size=self.conv_kernel,
        stride=1,
        padding=self.conv_kernel // 2,
        dilation=1,
        groups=1,
        bias=True,
        dropout=self.dropout,
        norm=cfg.norm,
        norm_cfg=cfg.norm_cfg,
        padding_mode=cfg.padding_mode,
    )

    # First layer
    self.conv1 = ConvNormActivation(in_shape[0], self.conv_channel, conv_cfg)

    # Residual blocks (optionally interleaved with attention)
    res_blocks: list[nn.Module] = []
    for _ in range(self.n_res_blocks):
        res_blocks += [
            ResidualBlock(
                self.conv_channel,
                self.conv_kernel,
                self.conv_activation,
                self.norm,
                self.norm_cfg,
                self.dropout,
                cfg.padding_mode,
            ),
        ]
        if cfg.attention is not None:
            res_blocks += [Attention2d(self.conv_channel, nhead=None, attn_cfg=cfg.attention)]
    self.res_blocks = nn.Sequential(*res_blocks)

    # Second conv layer post residual blocks
    self.conv2 = ConvNormActivation(self.conv_channel, self.conv_channel, conv_cfg)

    # The later stages each derive their own shallow copy of ``conv_cfg``.
    # Assigning ``conv_cfg`` directly would alias a single object, so the
    # upsampling and final-layer tweaks would rewrite the config that conv1/conv2
    # were built with.
    upscale_cfg = copy.copy(conv_cfg)
    upscale_cfg.scale_factor = self.upscale_factor

    # Upsampling layers
    upsampling: list[nn.Module] = [
        ConvNormActivation(self.conv_channel, self.conv_channel, upscale_cfg)
        for _ in range(self.n_upsampling)
    ]
    self.upsampling = nn.Sequential(*upsampling)

    # Final output layer: no norm, no dropout, no resampling, output activation.
    final_cfg = copy.copy(conv_cfg)
    final_cfg.kernel_size = self.final_kernel
    final_cfg.padding = self.final_kernel // 2
    final_cfg.activation = self.out_activation
    final_cfg.norm = "none"
    final_cfg.norm_cfg = {}
    final_cfg.dropout = 0.0
    final_cfg.scale_factor = 0
    self.conv3 = ConvNormActivation(self.conv_channel, out_channels, final_cfg)

Attributes

conv1 instance-attribute

conv1 = ConvNormActivation(in_shape[0], conv_channel, conv_cfg)

conv2 instance-attribute

conv2 = ConvNormActivation(conv_channel, conv_channel, conv_cfg)

conv3 instance-attribute

conv3 = ConvNormActivation(conv_channel, out_channels, final_cfg)

conv_activation instance-attribute

conv_activation = conv_activation

conv_channel instance-attribute

conv_channel = conv_channel

conv_kernel instance-attribute

conv_kernel = conv_kernel

dropout instance-attribute

dropout = dropout

final_kernel instance-attribute

final_kernel = f_kernel

in_shape instance-attribute

in_shape = in_shape

n_res_blocks instance-attribute

n_res_blocks = n_res_blocks

n_upsampling instance-attribute

n_upsampling = n_scaling

norm instance-attribute

norm = norm

norm_cfg instance-attribute

norm_cfg = norm_cfg

obs_shape instance-attribute

obs_shape = obs_shape

out_activation instance-attribute

out_activation = out_activation

res_blocks instance-attribute

res_blocks = Sequential(*res_blocks)

upsampling instance-attribute

upsampling = Sequential(*upsampling)

upscale_factor instance-attribute

upscale_factor = scale_factor

Functions

forward

forward(x)

Forward pass.

Parameters:

Name Type Description Default
x Tensor

input tensor of shape (batch_size, *in_shape)

required

Returns:

Type Description
Tensor

output tensor of shape (batch_size, *obs_shape)

Source code in src/ml_networks/torch/vision.py
def forward(self, x: torch.Tensor) -> torch.Tensor:
    """
    Decode a batch of feature maps into observations.

    ``conv1`` output is kept as a skip connection. It is added to the result
    of the residual stack plus ``conv2``, then upsampled and projected to the
    output channels by ``conv3``.

    Parameters
    ----------
    x: torch.Tensor
        input tensor of shape (batch_size, *in_shape)

    Returns
    -------
    torch.Tensor
        output tensor of shape (batch_size, *obs_shape)

    """
    skip = self.conv1(x)
    residual = self.conv2(self.res_blocks(skip))
    merged = torch.add(skip, residual)
    return self.conv3(self.upsampling(merged))

get_input_shape staticmethod

get_input_shape(obs_shape, cfg)

Get input shape of the decoder.

Parameters:

Name Type Description Default
obs_shape tuple[int, int, int]

shape of the output tensor

required
cfg ResNetConfig

configuration of the network

required

Returns:

Type Description
tuple[int, int, int]

shape of the input tensor

Source code in src/ml_networks/torch/vision.py
@staticmethod
def get_input_shape(obs_shape: tuple[int, int, int], cfg: ResNetConfig) -> tuple[int, int, int]:
    """
    Get input shape of the decoder.

    The decoder upsamples ``cfg.n_scaling`` times by ``cfg.scale_factor``.
    Each spatial dimension of the input is therefore the output dimension
    floor-divided by ``cfg.scale_factor ** cfg.n_scaling``.

    Parameters
    ----------
    obs_shape: tuple[int, int, int]
        shape of the output tensor ``(C, H, W)``
    cfg: ResNetConfig
        configuration of the network

    Returns
    -------
    tuple[int, int, int]
        shape of the input tensor ``(cfg.init_channel, H // s, W // s)`` with
        ``s = cfg.scale_factor ** cfg.n_scaling``

    """
    total_scale = cfg.scale_factor**cfg.n_scaling
    return (
        cfg.init_channel,
        obs_shape[1] // total_scale,
        obs_shape[2] // total_scale,
    )

ViT

ViT

ViT(in_shape, cfg, obs_shape=None)

Bases: Module

Vision Transformer for Encoder and Decoder.

The encoder mode (obs_shape is None) follows the DETR convention: a learnable per-patch positional embedding is added to the query and key tensors of every self-attention layer (rather than being added once at the input). When cfg.cls_token is True, a CLS token with its own learnable positional embedding is prepended, and the forward pass returns the CLS token of shape (B, d_model).

The decoder mode (obs_shape is not None) takes a CLS token of shape (B, d_model) or (B, 1, d_model) and reconstructs an image. The CLS token is projected to a hidden dimension and used as the key/value of cross-attention. A fixed set of P = (H // p) * (W // p) learnable query tokens interacts with this representation through several cross-attention layers with residual MLP blocks. Each query is then linearly projected to p * p * C pixels and rearranged into a (B, C, H, W) image.

Parameters:

Name Type Description Default
in_shape tuple[int, ...]

Encoder mode: image shape (C, H, W). Decoder mode: (d_model,); the CLS token is the input.

required
cfg ViTConfig

Network configuration.

required
obs_shape tuple[int, int, int] | None

Output image shape for decoder mode. None selects encoder mode.

None

Examples:

>>> from ml_networks.config import TransformerConfig
>>> in_shape = (3, 64, 64)
>>> cfg = ViTConfig(
...     patch_size=8,
...     cls_token=True,
...     transformer_cfg=TransformerConfig(
...         d_model=64,
...         nhead=8,
...         dim_ff=256,
...         n_layers=2,
...         dropout=0.0,
...         hidden_activation="GELU",
...         output_activation="GELU",
...     ),
...     init_channel=3,
... )
>>> encoder = ViT(in_shape, cfg)
>>> x = torch.randn(2, *in_shape)
>>> cls = encoder(x)
>>> cls.shape
torch.Size([2, 64])
>>> decoder = ViT(in_shape=(64,), cfg=cfg, obs_shape=(3, 64, 64))
>>> y = decoder(cls)
>>> y.shape
torch.Size([2, 3, 64, 64])
Source code in src/ml_networks/torch/vision.py
def __init__(
    self,
    in_shape: tuple[int, ...],
    cfg: ViTConfig,
    obs_shape: tuple[int, int, int] | None = None,
) -> None:
    """
    Build a ViT encoder (``obs_shape is None``) or decoder.

    Parameters
    ----------
    in_shape: tuple[int, ...]
        encoder mode: image shape ``(C, H, W)``; decoder mode: ``(d_model,)``
    cfg: ViTConfig
        network configuration
    obs_shape: tuple[int, int, int] | None
        output image shape in decoder mode; ``None`` selects encoder mode
    """
    super().__init__()

    # In encoder mode the "observation" is the input image itself.
    resolved_obs_shape = obs_shape if obs_shape is not None else cast("tuple[int, int, int]", in_shape)

    self.in_shape = in_shape
    self.cfg = cfg
    self.patch_size = cfg.patch_size
    self.transformer_cfg = cfg.transformer_cfg
    self.is_encoder = obs_shape is None
    self.obs_shape: tuple[int, int, int] = resolved_obs_shape
    self.d_model = self.transformer_cfg.d_model

    if not self.is_encoder:
        self._build_decoder()
        # Each decoder query is projected to one flattened p x p x C patch.
        self.out_patch_dim = self.patch_size**2 * self.obs_shape[0]
        self.last_channel = self.out_patch_dim
        return

    self._build_encoder()
    self.last_channel = self.d_model
    self.out_patch_dim = self.d_model

Attributes

cfg instance-attribute

cfg = cfg

conved_shape property

conved_shape

conved_size property

conved_size

d_model instance-attribute

d_model = d_model

in_shape instance-attribute

in_shape = in_shape

is_encoder instance-attribute

is_encoder = obs_shape is None

last_channel instance-attribute

last_channel = d_model in encoder mode; out_patch_dim (= patch_size**2 * obs_shape[0]) in decoder mode

obs_shape instance-attribute

obs_shape = obs_shape if obs_shape is not None else cast('tuple[int, int, int]', in_shape)

out_patch_dim instance-attribute

out_patch_dim = d_model in encoder mode; patch_size**2 * obs_shape[0] in decoder mode

patch_size instance-attribute

patch_size = patch_size

transformer_cfg instance-attribute

transformer_cfg = transformer_cfg

Functions

forward

forward(x, return_cls_token=False)

Forward pass.

Parameters:

Name Type Description Default
x Tensor

Encoder mode: image of shape (B, C, H, W). Decoder mode: CLS token of shape (B, d_model) or (B, 1, d_model).

required
return_cls_token bool

Retained for backward compatibility; the encoder always returns the CLS token.

False

Returns:

Type Description
Tensor

Encoder mode: CLS token of shape (B, d_model). Decoder mode: reconstructed image of shape (B, C, H, W).

Source code in src/ml_networks/torch/vision.py
def forward(self, x: torch.Tensor, return_cls_token: bool = False) -> torch.Tensor:
    """
    Forward pass.

    Dispatches to the encoder or the decoder path depending on the mode
    chosen at construction time.

    Parameters
    ----------
    x : torch.Tensor
        Encoder mode: image of shape ``(B, C, H, W)``.
        Decoder mode: CLS token of shape ``(B, d_model)`` or ``(B, 1, d_model)``.
    return_cls_token : bool
        Retained for backward compatibility; the encoder always returns the CLS token.

    Returns
    -------
    torch.Tensor
        Encoder mode: CLS token of shape ``(B, d_model)``.
        Decoder mode: reconstructed image of shape ``(B, C, H, W)``.
    """
    del return_cls_token  # accepted for API compatibility only
    if not self.is_encoder:
        return self._forward_decoder(x)
    return self._forward_encoder(x)

get_input_shape staticmethod

get_input_shape(obs_shape, cfg)

Input shape consumed by the ViT decoder: the CLS token has dimension d_model.

Source code in src/ml_networks/torch/vision.py
@staticmethod
def get_input_shape(obs_shape: tuple[int, int, int], cfg: ViTConfig) -> tuple[int, ...]:
    """
    Return the input shape consumed by the ViT decoder.

    The decoder is fed a single CLS token, so the shape is ``(d_model,)``
    regardless of ``obs_shape``.
    """
    del obs_shape  # the CLS-token input does not depend on the image size
    transformer_cfg = cfg.transformer_cfg
    return (transformer_cfg.d_model,)

get_n_patches

get_n_patches(obs_shape)
Source code in src/ml_networks/torch/vision.py
def get_n_patches(self, obs_shape: tuple[int, int, int]) -> int:
    """
    Return how many ``patch_size x patch_size`` patches tile an image of ``obs_shape``.

    Rows/columns left over when ``H`` or ``W`` is not a multiple of
    ``patch_size`` are not counted (floor division).
    """
    rows = obs_shape[1] // self.patch_size
    cols = obs_shape[2] // self.patch_size
    return rows * cols

get_patch_dim

get_patch_dim(obs_shape)
Source code in src/ml_networks/torch/vision.py
def get_patch_dim(self, obs_shape: tuple[int, int, int]) -> int:
    """Return the length of one flattened patch: ``patch_size ** 2 * channels``."""
    channels = obs_shape[0]
    side = self.patch_size
    return side**2 * channels

patchify

patchify(imgs)

画像をパッチに分割する.

Source code in src/ml_networks/torch/vision.py
def patchify(self, imgs: torch.Tensor) -> torch.Tensor:
    """
    Split a batch of images into flattened, non-overlapping patches.

    Parameters
    ----------
    imgs : torch.Tensor
        images of shape ``(N, C, H, W)``; ``H`` and ``W`` must be divisible by
        ``self.patch_size``.

    Returns
    -------
    torch.Tensor
        patches of shape ``(N, (H // p) * (W // p), p * p * C)`` with
        ``p = self.patch_size``; each patch is flattened in ``(p1, p2, c)`` order.
    """
    p = self.patch_size
    # Width is checked before height, matching the original order of checks.
    assert imgs.shape[-1] % p == 0, f"image width {imgs.shape[-1]} is not divisible by patch_size {p}"
    assert imgs.shape[-2] % p == 0, f"image height {imgs.shape[-2]} is not divisible by patch_size {p}"
    return rearrange(imgs, "n c (h p1) (w p2) -> n (h w) (p1 p2 c)", p1=p, p2=p)

unpatchify

unpatchify(x)

パッチを画像に戻す.

Source code in src/ml_networks/torch/vision.py
def unpatchify(self, x: torch.Tensor) -> torch.Tensor:
    """
    Reassemble flattened patches into images of shape ``self.obs_shape``.

    This is the inverse rearrangement of ``patchify``:
    ``(N, h * w, p * p * C) -> (N, C, h * p, w * p)``. Here ``p = self.patch_size``
    and ``h, w`` are the patch-grid sizes derived from ``self.obs_shape``.
    """
    patch = self.patch_size
    grid_h, grid_w = self.obs_shape[1] // patch, self.obs_shape[2] // patch
    n_patches = grid_h * grid_w
    assert n_patches == x.shape[1], (
        f"{n_patches} != {x.shape[1]}, please check the shape {x.shape} and obs_shape {self.obs_shape}"
    )
    return rearrange(x, "n (h w) (p1 p2 c) -> n c (h p1) (w p2)", h=grid_h, w=grid_w, p1=patch, p2=patch)