Skip to content

UNet

条件付きUNet関連のクラスを提供します。

ml_networks.torch.unet(PyTorch)とml_networks.jax.unet(JAX)の両方で提供されています。

ConditionalUnet2d

2D画像データ用の条件付きUNet。Diffusion Modelのノイズ予測ネットワークとして典型的に使用されます。

ConditionalUnet2d

ConditionalUnet2d(feature_dim, obs_shape, cfg)

Bases: Module

条件付きUNetモデル.

Args:
    feature_dim (int): 条件付き特徴量の次元数
    obs_shape (tuple[int, int, int]): 観測データの形状 (チャンネル数, 高さ, 幅)
    cfg (UNetConfig): UNetの設定

Examples:

>>> from ml_networks.config import UNetConfig, ConvConfig, MLPConfig, LinearConfig
>>> cfg = UNetConfig(
...     channels=[64, 128, 256],
...     conv_cfg=ConvConfig(
...         kernel_size=3,
...         padding=1,
...         stride=1,
...         groups=1,
...         activation="ReLU",
...         dropout=0.0
...     ),
...     has_attn=True,
...     nhead=8,
...     cond_pred_scale=True
... )
>>> net = ConditionalUnet2d(feature_dim=32, obs_shape=(3, 64, 64), cfg=cfg)
>>> x = torch.randn(2, 3, 64, 64)
>>> cond = torch.randn(2, 32)
>>> out = net(x, cond)
>>> out.shape
torch.Size([2, 3, 64, 64])
Source code in src/ml_networks/torch/unet.py
def __init__(
    self,
    feature_dim: int,
    obs_shape: tuple[int, int, int],
    cfg: UNetConfig,
) -> None:
    """Build the conditional 2D UNet.

    Args:
        feature_dim (int): Dimensionality of the conditioning feature vector.
        obs_shape (tuple[int, int, int]): Observation shape
            (channels, height, width).
        cfg (UNetConfig): UNet configuration (channel widths, conv settings,
            attention and conditioning options).
    """
    super().__init__()
    # Channel progression across resolution stages: input channels followed
    # by the configured widths, e.g. C=3 with channels [64, 128] -> [3, 64, 128].
    all_dims = [obs_shape[0], *list(cfg.channels)]
    start_dim = cfg.channels[0]
    self.obs_shape = obs_shape

    # Consecutive (dim_in, dim_out) pairs, one per encoder/decoder stage.
    in_out = list(pairwise(all_dims))

    # Bottleneck: two conditional residual blocks with optional attention
    # in between (Identity when attention is disabled or nhead is unset).
    mid_dim = all_dims[-1]
    self.mid_modules = nn.ModuleList([
        ConditionalResidualBlock2d(
            mid_dim,
            mid_dim,
            cond_dim=feature_dim,
            conv_cfg=cfg.conv_cfg,
            cond_predict_scale=cfg.cond_pred_scale,
        ),
        Attention2d(mid_dim, cfg.nhead) if cfg.has_attn and cfg.nhead is not None else nn.Identity(),
        ConditionalResidualBlock2d(
            mid_dim,
            mid_dim,
            cond_dim=feature_dim,
            conv_cfg=cfg.conv_cfg,
            cond_predict_scale=cfg.cond_pred_scale,
        ),
    ])

    # Encoder: each stage is (res block, optional attention, res block,
    # downsample); the deepest stage keeps its resolution (Identity).
    down_modules = nn.ModuleList([])
    for ind, (dim_in, dim_out) in enumerate(in_out):
        is_last = ind >= (len(in_out) - 1)
        down_modules.append(
            nn.ModuleList([
                ConditionalResidualBlock2d(
                    dim_in,
                    dim_out,
                    cond_dim=feature_dim,
                    conv_cfg=cfg.conv_cfg,
                    cond_predict_scale=cfg.cond_pred_scale,
                ),
                Attention2d(dim_out, cfg.nhead) if cfg.has_attn and cfg.nhead is not None else nn.Identity(),
                ConditionalResidualBlock2d(
                    dim_out,
                    dim_out,
                    cond_dim=feature_dim,
                    conv_cfg=cfg.conv_cfg,
                    cond_predict_scale=cfg.cond_pred_scale,
                ),
                Downsample2d(dim_out, cfg.use_shuffle) if not is_last else nn.Identity(),
            ]),
        )

    # Decoder: mirrors the encoder; the first residual block consumes the
    # upsampled features concatenated with the skip connection, hence the
    # `dim_out * 2` input channels.
    up_modules = nn.ModuleList([])
    for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
        # NOTE(review): this loop runs len(in_out) - 1 times, so `is_last`
        # can never be True — every decoder stage upsamples, matching the
        # len(in_out) - 1 downsamples performed by the encoder above.
        is_last = ind >= (len(in_out) - 1)
        up_modules.append(
            nn.ModuleList([
                ConditionalResidualBlock2d(
                    dim_out * 2,
                    dim_in,
                    cond_dim=feature_dim,
                    conv_cfg=cfg.conv_cfg,
                    cond_predict_scale=cfg.cond_pred_scale,
                ),
                Attention2d(dim_in, cfg.nhead) if cfg.has_attn and cfg.nhead is not None else nn.Identity(),
                ConditionalResidualBlock2d(
                    dim_in,
                    dim_in,
                    cond_dim=feature_dim,
                    conv_cfg=cfg.conv_cfg,
                    cond_predict_scale=cfg.cond_pred_scale,
                ),
                Upsample2d(dim_in, cfg.use_shuffle) if not is_last else nn.Identity(),
            ]),
        )

    # Project back to the input channel count with a 1x1 conv.
    final_conv = nn.Sequential(
        ConvNormActivation(start_dim, start_dim, cfg.conv_cfg),
        nn.Conv2d(start_dim, obs_shape[0], 1),
    )

    # ModuleList carries no element type information, so cast to the more
    # specific list[DownBlock] to satisfy mypy.
    self.up_modules = cast("list[DownBlock]", up_modules)
    self.down_modules = cast("list[DownBlock]", down_modules)
    self.final_conv = final_conv

Attributes

down_modules instance-attribute

down_modules = cast('list[DownBlock]', down_modules)

final_conv instance-attribute

final_conv = final_conv

mid_modules instance-attribute

mid_modules = ModuleList([ConditionalResidualBlock2d(mid_dim, mid_dim, cond_dim=feature_dim, conv_cfg=conv_cfg, cond_predict_scale=cond_pred_scale), Attention2d(mid_dim, nhead) if has_attn and nhead is not None else Identity(), ConditionalResidualBlock2d(mid_dim, mid_dim, cond_dim=feature_dim, conv_cfg=conv_cfg, cond_predict_scale=cond_pred_scale)])

obs_shape instance-attribute

obs_shape = obs_shape

up_modules instance-attribute

up_modules = cast('list[DownBlock]', up_modules)

Functions

forward

forward(base, cond)

Forward pass.

Parameters:

Name Type Description Default
base Tensor

Input tensor of shape (B, C, H, W), matching obs_shape.

required
cond Tensor

Conditional tensor of shape (B, cond_dim).

required

Returns:

Type Description
Tensor

Output tensor of shape (B, C, H, W), same shape as the input.

Source code in src/ml_networks/torch/unet.py
def forward(
    self,
    base: torch.Tensor,
    cond: torch.Tensor,
) -> torch.Tensor:
    """Forward pass.

    Parameters
    ----------
    base : torch.Tensor
        Input tensor of shape (*batch, C, H, W), where (C, H, W) equals
        ``self.obs_shape``. Arbitrary leading batch dimensions are allowed.
    cond : torch.Tensor
        Conditioning tensor of shape (*batch, cond_dim).

    Returns
    -------
    torch.Tensor
        Output tensor with the same shape as ``base``.
    """
    lead_dims = base.shape[:-3]
    assert base.shape[-3:] == self.obs_shape, (
        f"Input shape {base.shape[-3:]} does not match expected shape {self.obs_shape}"
    )

    # Flatten all leading batch dimensions into one.
    feat = base.reshape(-1, *self.obs_shape)
    cond_flat = cond.reshape(-1, cond.shape[-1])

    # Encoder: run each stage, stash its output for the skip connection.
    skips: list[torch.Tensor] = []
    for block_a, attn, block_b, downsample in self.down_modules:
        feat = block_b(attn(block_a(feat, cond_flat)), cond_flat)
        skips.append(feat)
        feat = downsample(feat)

    # Bottleneck (Identity placeholders take no conditioning input).
    for mid in self.mid_modules:
        feat = mid(feat) if isinstance(mid, nn.Identity) else mid(feat, cond_flat)

    # Decoder: consume the skips in reverse order via channel concatenation.
    for block_a, attn, block_b, upsample in self.up_modules:
        feat = block_a(torch.cat((feat, skips.pop()), dim=1), cond_flat)
        feat = upsample(block_b(attn(feat), cond_flat))

    # Restore the original leading batch dimensions.
    return self.final_conv(feat).reshape(*lead_dims, *self.obs_shape)

ConditionalUnet1d

1D時系列データ用の条件付きUNet。

ConditionalUnet1d

ConditionalUnet1d(feature_dim, obs_shape, cfg)

Bases: Module

条件付き1D UNetモデル。

Args:
    feature_dim (int): 条件付き特徴量の次元数
    obs_shape (tuple[int, int]): 観測データの形状 (チャンネル数, 長さ)
    cfg (UNetConfig): UNetの設定

Examples:

>>> from ml_networks.config import UNetConfig, ConvConfig, MLPConfig, LinearConfig
>>> cfg = UNetConfig(
...     channels=[64, 128, 256],
...     conv_cfg=ConvConfig(
...         kernel_size=3,
...         padding=1,
...         stride=1,
...         groups=1,
...         activation="ReLU",
...         dropout=0.0
...     ),
...     has_attn=True,
...     nhead=8,
...     cond_pred_scale=True
... )
>>> net = ConditionalUnet1d(feature_dim=32, obs_shape=(3, 64), cfg=cfg)
>>> x = torch.randn(2, 3, 64)
>>> cond = torch.randn(2, 32)
>>> out = net(x, cond)
>>> out.shape
torch.Size([2, 3, 64])
Source code in src/ml_networks/torch/unet.py
def __init__(
    self,
    feature_dim: int,
    obs_shape: tuple[int, int],
    cfg: UNetConfig,
) -> None:
    """Build the conditional 1D UNet.

    Args:
        feature_dim (int): Dimensionality of the conditioning feature vector.
        obs_shape (tuple[int, int]): Observation shape (channels, length).
        cfg (UNetConfig): UNet configuration (channel widths, conv settings,
            attention, hypernet and conditioning options).
    """
    super().__init__()
    # Channel progression across resolution stages: input channels followed
    # by the configured widths, e.g. C=3 with channels [64, 128] -> [3, 64, 128].
    all_dims = [obs_shape[0], *list(cfg.channels)]
    start_dim = cfg.channels[0]
    self.obs_shape = obs_shape

    # Consecutive (dim_in, dim_out) pairs, one per encoder/decoder stage.
    in_out = list(pairwise(all_dims))

    # Bottleneck: two conditional residual blocks with optional attention in
    # between. When cfg.use_hypernet is set, each residual block is replaced
    # by its hypernetwork-conditioned variant.
    mid_dim = all_dims[-1]
    self.mid_modules = nn.ModuleList([
        ConditionalResidualBlock1d(
            mid_dim,
            mid_dim,
            cond_dim=feature_dim,
            conv_cfg=cfg.conv_cfg,
            cond_predict_scale=cfg.cond_pred_scale,
        )
        if not cfg.use_hypernet
        else HyperConditionalResidualBlock1d(
            mid_dim,
            mid_dim,
            cond_dim=feature_dim,
            conv_cfg=cfg.conv_cfg,
            hyper_mlp_cfg=cfg.hyper_mlp_cfg,
        ),
        Attention1d(mid_dim, cfg.nhead) if cfg.has_attn and cfg.nhead is not None else nn.Identity(),
        ConditionalResidualBlock1d(
            mid_dim,
            mid_dim,
            cond_dim=feature_dim,
            conv_cfg=cfg.conv_cfg,
            cond_predict_scale=cfg.cond_pred_scale,
        )
        if not cfg.use_hypernet
        else HyperConditionalResidualBlock1d(
            mid_dim,
            mid_dim,
            cond_dim=feature_dim,
            conv_cfg=cfg.conv_cfg,
            hyper_mlp_cfg=cfg.hyper_mlp_cfg,
        ),
    ])

    # Encoder: each stage is (res block, optional attention, res block,
    # downsample); the deepest stage keeps its resolution (Identity).
    down_modules = nn.ModuleList([])
    for ind, (dim_in, dim_out) in enumerate(in_out):
        is_last = ind >= (len(in_out) - 1)
        down_modules.append(
            nn.ModuleList([
                ConditionalResidualBlock1d(
                    dim_in,
                    dim_out,
                    cond_dim=feature_dim,
                    conv_cfg=cfg.conv_cfg,
                    cond_predict_scale=cfg.cond_pred_scale,
                )
                if not cfg.use_hypernet
                else HyperConditionalResidualBlock1d(
                    dim_in,
                    dim_out,
                    cond_dim=feature_dim,
                    conv_cfg=cfg.conv_cfg,
                    hyper_mlp_cfg=cfg.hyper_mlp_cfg,
                ),
                Attention1d(dim_out, cfg.nhead) if cfg.has_attn and cfg.nhead is not None else nn.Identity(),
                ConditionalResidualBlock1d(
                    dim_out,
                    dim_out,
                    cond_dim=feature_dim,
                    conv_cfg=cfg.conv_cfg,
                    cond_predict_scale=cfg.cond_pred_scale,
                )
                if not cfg.use_hypernet
                else HyperConditionalResidualBlock1d(
                    dim_out,
                    dim_out,
                    cond_dim=feature_dim,
                    conv_cfg=cfg.conv_cfg,
                    hyper_mlp_cfg=cfg.hyper_mlp_cfg,
                ),
                Downsample1d(dim_out, cfg.use_shuffle) if not is_last else nn.Identity(),
            ]),
        )

    # Decoder: mirrors the encoder; the first residual block consumes the
    # upsampled features concatenated with the skip connection, hence the
    # `dim_out * 2` input channels.
    up_modules = nn.ModuleList([])
    for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
        # NOTE(review): this loop runs len(in_out) - 1 times, so `is_last`
        # can never be True — every decoder stage upsamples, matching the
        # len(in_out) - 1 downsamples performed by the encoder above.
        is_last = ind >= (len(in_out) - 1)
        up_modules.append(
            nn.ModuleList([
                ConditionalResidualBlock1d(
                    dim_out * 2,
                    dim_in,
                    cond_dim=feature_dim,
                    conv_cfg=cfg.conv_cfg,
                    cond_predict_scale=cfg.cond_pred_scale,
                )
                if not cfg.use_hypernet
                else HyperConditionalResidualBlock1d(
                    dim_out * 2,
                    dim_in,
                    cond_dim=feature_dim,
                    conv_cfg=cfg.conv_cfg,
                    hyper_mlp_cfg=cfg.hyper_mlp_cfg,
                ),
                Attention1d(dim_in, cfg.nhead) if cfg.has_attn and cfg.nhead is not None else nn.Identity(),
                ConditionalResidualBlock1d(
                    dim_in,
                    dim_in,
                    cond_dim=feature_dim,
                    conv_cfg=cfg.conv_cfg,
                    cond_predict_scale=cfg.cond_pred_scale,
                )
                if not cfg.use_hypernet
                else HyperConditionalResidualBlock1d(
                    dim_in,
                    dim_in,
                    cond_dim=feature_dim,
                    conv_cfg=cfg.conv_cfg,
                    hyper_mlp_cfg=cfg.hyper_mlp_cfg,
                ),
                Upsample1d(dim_in, cfg.use_shuffle) if not is_last else nn.Identity(),
            ]),
        )
    # Project back to the input channel count with a kernel-size-1 conv.
    final_conv = nn.Sequential(
        ConvNormActivation1d(start_dim, start_dim, cfg.conv_cfg),
        nn.Conv1d(start_dim, obs_shape[0], 1),
    )

    # ModuleList carries no element type information, so cast to the more
    # specific list[DownBlock] to satisfy mypy.
    self.up_modules = cast("list[DownBlock]", up_modules)
    self.down_modules = cast("list[DownBlock]", down_modules)
    self.final_conv = final_conv

Attributes

down_modules instance-attribute

down_modules = cast('list[DownBlock]', down_modules)

final_conv instance-attribute

final_conv = final_conv

mid_modules instance-attribute

mid_modules = ModuleList([ConditionalResidualBlock1d(mid_dim, mid_dim, cond_dim=feature_dim, conv_cfg=conv_cfg, cond_predict_scale=cond_pred_scale) if not use_hypernet else HyperConditionalResidualBlock1d(mid_dim, mid_dim, cond_dim=feature_dim, conv_cfg=conv_cfg, hyper_mlp_cfg=hyper_mlp_cfg), Attention1d(mid_dim, nhead) if has_attn and nhead is not None else Identity(), ConditionalResidualBlock1d(mid_dim, mid_dim, cond_dim=feature_dim, conv_cfg=conv_cfg, cond_predict_scale=cond_pred_scale) if not use_hypernet else HyperConditionalResidualBlock1d(mid_dim, mid_dim, cond_dim=feature_dim, conv_cfg=conv_cfg, hyper_mlp_cfg=hyper_mlp_cfg)])

obs_shape instance-attribute

obs_shape = obs_shape

up_modules instance-attribute

up_modules = cast('list[DownBlock]', up_modules)

Functions

forward

forward(base, cond)

Forward pass.

Parameters:

Name Type Description Default
base Tensor

Input tensor of shape (B, input_dim, T).

required
cond Tensor

Conditional tensor of shape (B, cond_dim).

required

Returns:

Type Description
Tensor

Output tensor of shape (B, input_dim, T), same shape as the input.

Source code in src/ml_networks/torch/unet.py
def forward(
    self,
    base: torch.Tensor,
    cond: torch.Tensor,
) -> torch.Tensor:
    """Forward pass.

    Parameters
    ----------
    base : torch.Tensor
        Input tensor of shape (*batch, input_dim, T), where (input_dim, T)
        equals ``self.obs_shape``. Arbitrary leading batch dimensions are
        allowed.
    cond : torch.Tensor
        Conditioning tensor of shape (*batch, cond_dim).

    Returns
    -------
    torch.Tensor
        Output tensor with the same shape as ``base``.
    """
    lead_dims = base.shape[:-2]
    assert base.shape[-2:] == self.obs_shape, (
        f"Input shape {base.shape[-2:]} does not match expected shape {self.obs_shape}"
    )

    # Flatten all leading batch dimensions into one.
    feat = base.reshape(-1, *self.obs_shape)
    cond_flat = cond.reshape(-1, cond.shape[-1])

    # Encoder: run each stage, stash its output for the skip connection.
    skips: list[torch.Tensor] = []
    for block_a, attn, block_b, downsample in self.down_modules:
        feat = block_b(attn(block_a(feat, cond_flat)), cond_flat)
        skips.append(feat)
        feat = downsample(feat)

    # Bottleneck (Identity placeholders take no conditioning input).
    for mid in self.mid_modules:
        feat = mid(feat) if isinstance(mid, nn.Identity) else mid(feat, cond_flat)

    # Decoder: consume the skips in reverse order via channel concatenation.
    for block_a, attn, block_b, upsample in self.up_modules:
        feat = block_a(torch.cat((feat, skips.pop()), dim=1), cond_flat)
        feat = upsample(block_b(attn(feat), cond_flat))

    # Restore the original leading batch dimensions.
    return self.final_conv(feat).reshape(*lead_dims, *self.obs_shape)