Skip to content

JAX API リファレンス

JAX(Flax NNX)バックエンドのAPIリファレンスです。

PyTorchバックエンドと同一のインターフェースを提供しています。詳細は各PyTorch APIリファレンスページを参照してください。

レイヤー (ml_networks.jax.layers)

MLPLayer

MLPLayer(input_dim, output_dim, cfg, *, rngs)

Bases: Module

Multi-layer perceptron layer.

Parameters:

Name Type Description Default
input_dim int

Input dimension.

required
output_dim int

Output dimension.

required
cfg MLPConfig

MLP layer configuration.

required
rngs Rngs

Random number generators.

required

Examples:

>>> rngs = nnx.Rngs(0)
>>> cfg = MLPConfig(
...     hidden_dim=16,
...     n_layers=3,
...     output_activation="ReLU",
...     linear_cfg=LinearConfig(
...         activation="ReLU",
...         norm="layer",
...         norm_cfg={"eps": 1e-05, "elementwise_affine": True, "bias": True},
...         dropout=0.1,
...         norm_first=False,
...         bias=True
...     )
... )
>>> mlp = MLPLayer(32, 16, cfg, rngs=rngs)
>>> x = jnp.ones((1, 32))
>>> output = mlp(x)
>>> output.shape
(1, 16)
Source code in src/ml_networks/jax/layers.py
def __init__(
    self,
    input_dim: int,
    output_dim: int,
    cfg: MLPConfig,
    *,
    rngs: nnx.Rngs,
) -> None:
    """
    Build the MLP layers from the given configuration.

    Parameters
    ----------
    input_dim : int
        Input dimension.
    output_dim : int
        Output dimension.
    cfg : MLPConfig
        MLP configuration (hidden_dim, n_layers, per-layer linear settings).
    rngs : nnx.Rngs
        Random number generators used to initialize the layers.
    """
    # Deep-copy so later mutation of the caller's cfg cannot affect this module.
    self.cfg = deepcopy(cfg)
    self.input_dim = input_dim
    self.output_dim = output_dim
    self.hidden_dim = cfg.hidden_dim
    self.n_layers = cfg.n_layers
    # _build_dense (defined elsewhere in the class) produces the layer stack.
    self.layers = nnx.List(self._build_dense(rngs=rngs))

Attributes

cfg instance-attribute

cfg = deepcopy(cfg)

hidden_dim instance-attribute

hidden_dim = hidden_dim

input_dim instance-attribute

input_dim = input_dim

layers instance-attribute

layers = List(_build_dense(rngs=rngs))

n_layers instance-attribute

n_layers = n_layers

output_dim instance-attribute

output_dim = output_dim

Functions

__call__

__call__(x)

Forward pass.

Parameters:

Name Type Description Default
x Array

Input tensor of shape (*, input_dim)

required

Returns:

Type Description
Array

Output tensor of shape (*, output_dim)

Source code in src/ml_networks/jax/layers.py
def __call__(self, x: jax.Array) -> jax.Array:
    """
    Apply every MLP layer in sequence.

    Parameters
    ----------
    x : jax.Array
        Input tensor of shape (*, input_dim)

    Returns
    -------
    jax.Array
        Output tensor of shape (*, output_dim)
    """
    out = x
    for block in self.layers:
        out = block(out)
    return out

LinearNormActivation

LinearNormActivation(input_dim, output_dim, cfg, *, rngs)

Bases: Module

Linear layer with normalization and activation, and dropouts.

Parameters:

Name Type Description Default
input_dim int

Input dimension.

required
output_dim int

Output dimension.

required
cfg LinearConfig

Linear layer configuration.

required
rngs Rngs

Random number generators.

required

Examples:

>>> rngs = nnx.Rngs(0)
>>> cfg = LinearConfig(
...     activation="ReLU",
...     norm="layer",
...     norm_cfg={"eps": 1e-05, "elementwise_affine": True, "bias": True},
...     dropout=0.1,
...     norm_first=False,
...     bias=True
... )
>>> linear = LinearNormActivation(32, 16, cfg, rngs=rngs)
>>> x = jnp.ones((1, 32))
>>> output = linear(x)
>>> output.shape
(1, 16)
Source code in src/ml_networks/jax/layers.py
def __init__(
    self,
    input_dim: int,
    output_dim: int,
    cfg: LinearConfig,
    *,
    rngs: nnx.Rngs,
) -> None:
    """
    Build the linear, normalization, activation and dropout sub-modules.

    Parameters
    ----------
    input_dim : int
        Input dimension.
    output_dim : int
        Output dimension.
    cfg : LinearConfig
        Linear layer configuration.
    rngs : nnx.Rngs
        Random number generators used to initialize the layers.
    """
    # GLU-style activations get double the features — presumably Activation
    # splits them back to output_dim (TODO confirm against Activation).
    out_features = output_dim * 2 if "glu" in cfg.activation.lower() else output_dim

    self.linear = nnx.Linear(input_dim, out_features, use_bias=cfg.bias, rngs=rngs)

    # Pre-norm sees the raw input features; post-norm sees the linear output.
    normalized_shape = input_dim if cfg.norm_first else out_features

    norm_cfg = dict(cfg.norm_cfg)
    norm_cfg["normalized_shape"] = normalized_shape
    self.norm = get_norm(cfg.norm, rngs=rngs, **norm_cfg)
    self.activation = Activation(cfg.activation)
    self.dropout: nnx.Module
    if cfg.dropout > 0:
        self.dropout = nnx.Dropout(rate=cfg.dropout, rngs=rngs)
    else:
        # Zero dropout: use a no-op module so __call__ stays branch-free.
        self.dropout = Identity()
    self.norm_first = cfg.norm_first

Attributes

activation instance-attribute

activation = Activation(activation)

dropout instance-attribute

dropout

linear instance-attribute

linear = Linear(input_dim, out_features, use_bias=bias, rngs=rngs)

norm instance-attribute

norm = get_norm(norm, rngs=rngs, **norm_cfg)

norm_first instance-attribute

norm_first = norm_first

Functions

__call__

__call__(x)

Forward pass.

Parameters:

Name Type Description Default
x Array

Input tensor of shape (*, input_dim)

required

Returns:

Type Description
Array

Output tensor of shape (*, output_dim)

Source code in src/ml_networks/jax/layers.py
def __call__(self, x: jax.Array) -> jax.Array:
    """
    Apply linear, normalization, activation and dropout.

    Normalization runs before the linear layer when ``norm_first`` is set,
    otherwise after it.

    Parameters
    ----------
    x : jax.Array
        Input tensor of shape (*, input_dim)

    Returns
    -------
    jax.Array
        Output tensor of shape (*, output_dim)
    """
    if self.norm_first:
        x = self.norm(x)
    x = self.linear(x)
    if not self.norm_first:
        x = self.norm(x)
    x = self.activation(x)
    return self.dropout(x)

ConvNormActivation

ConvNormActivation(in_channels, out_channels, cfg, *, rngs)

Bases: Module

Convolutional layer with normalization and activation, and dropouts.

Uses NHWC (channels-last) format.

Parameters:

Name Type Description Default
in_channels int

Input channels.

required
out_channels int

Output channels.

required
cfg ConvConfig

Convolutional layer configuration.

required
rngs Rngs

Random number generators.

required

Examples:

>>> rngs = nnx.Rngs(0)
>>> cfg = ConvConfig(
...     activation="ReLU",
...     kernel_size=3,
...     stride=1,
...     padding=1,
...     dilation=1,
...     groups=1,
...     bias=True,
...     dropout=0.1,
...     norm="batch",
...     norm_cfg={"affine": True, "track_running_stats": True},
...     scale_factor=0
... )
>>> conv = ConvNormActivation(3, 16, cfg, rngs=rngs)
>>> x = jnp.ones((1, 32, 32, 3))
>>> output = conv(x)
>>> output.shape
(1, 32, 32, 16)
Source code in src/ml_networks/jax/layers.py
def __init__(
    self,
    in_channels: int,
    out_channels: int,
    cfg: ConvConfig,
    *,
    rngs: nnx.Rngs,
) -> None:
    """
    Build the convolution, normalization, activation and dropout sub-modules.

    Parameters
    ----------
    in_channels : int
        Input channels.
    out_channels : int
        Output channels.
    cfg : ConvConfig
        Convolutional layer configuration.
    rngs : nnx.Rngs
        Random number generators used to initialize the layers.
    """
    # GLU-style activations get double the channels — presumably Activation
    # splits them back to out_channels (TODO confirm against Activation).
    out_channels_ = out_channels
    if "glu" in cfg.activation.lower():
        out_channels_ *= 2
    # Positive scale_factor: factor**2 more channels (pixel-shuffle upsample);
    # negative: factor**2 fewer (pixel-unshuffle downsample).
    if cfg.scale_factor > 0:
        out_channels_ *= abs(cfg.scale_factor) ** 2
    elif cfg.scale_factor < 0:
        out_channels_ //= abs(cfg.scale_factor) ** 2

    # Handle padding mode: non-"zeros" modes are applied manually in __call__
    # via _pad_input, so the conv itself then runs unpadded ("VALID").
    self.padding_mode = cfg.padding_mode
    self.manual_padding = cfg.padding if cfg.padding_mode != "zeros" else 0
    conv_padding: Any
    if cfg.padding_mode != "zeros":
        conv_padding = "VALID"
    else:
        conv_padding = ((cfg.padding, cfg.padding), (cfg.padding, cfg.padding))

    self.conv = nnx.Conv(
        in_features=in_channels,
        out_features=out_channels_,
        kernel_size=(cfg.kernel_size, cfg.kernel_size),
        strides=(cfg.stride, cfg.stride),
        padding=conv_padding,
        kernel_dilation=(cfg.dilation, cfg.dilation),
        feature_group_count=cfg.groups,
        use_bias=cfg.bias,
        rngs=rngs,
    )

    norm_cfg = dict(cfg.norm_cfg) if cfg.norm_cfg else {}
    if cfg.norm not in {"none", "group"}:
        # FIX: when normalization runs before the conv (norm_first) it sees the
        # input with in_channels channels, not the conv output — mirror the
        # "group" branch below and LinearNormActivation's normalized_shape.
        norm_cfg["num_features"] = in_channels if cfg.norm_first else out_channels_
    elif cfg.norm == "group":
        norm_cfg["num_channels"] = in_channels if cfg.norm_first else out_channels_

    # "batch" maps to the 2-D batch-norm variant for NHWC image tensors.
    norm_type: Literal["layer", "rms", "group", "batch2d", "batch1d", "none"] = (
        "batch2d" if cfg.norm == "batch" else cfg.norm
    )  # type: ignore[assignment]
    self.norm = get_norm(norm_type, rngs=rngs, **norm_cfg)

    self.scale_factor = cfg.scale_factor
    self.activation = Activation(cfg.activation, dim=-1)
    # Zero dropout: use a no-op module so __call__ stays branch-free.
    self.dropout: nnx.Module = nnx.Dropout(rate=cfg.dropout, rngs=rngs) if cfg.dropout > 0 else Identity()
    self.norm_first = cfg.norm_first

Attributes

activation instance-attribute

activation = Activation(activation, dim=-1)

conv instance-attribute

conv = Conv(in_features=in_channels, out_features=out_channels_, kernel_size=(kernel_size, kernel_size), strides=(stride, stride), padding=conv_padding, kernel_dilation=(dilation, dilation), feature_group_count=groups, use_bias=bias, rngs=rngs)

dropout instance-attribute

dropout = Dropout(rate=dropout, rngs=rngs) if dropout > 0 else Identity()

manual_padding instance-attribute

manual_padding = padding if padding_mode != 'zeros' else 0

norm instance-attribute

norm = get_norm(norm_type, rngs=rngs, **norm_cfg)

norm_first instance-attribute

norm_first = norm_first

padding_mode instance-attribute

padding_mode = padding_mode

scale_factor instance-attribute

scale_factor = scale_factor

Functions

__call__

__call__(x)

Forward pass.

Parameters:

Name Type Description Default
x Array

Input tensor of shape (B, H, W, in_channels) or (H, W, in_channels)

required

Returns:

Type Description
Array

Output tensor of shape (B, H', W', out_channels) or (H', W', out_channels)

Source code in src/ml_networks/jax/layers.py
def __call__(self, x: jax.Array) -> jax.Array:
    """
    Apply padding, convolution, normalization, shuffle, activation and dropout.

    Normalization runs on the input when ``norm_first`` is set, otherwise on
    the convolution output.

    Parameters
    ----------
    x : jax.Array
        Input tensor of shape (B, H, W, in_channels) or (H, W, in_channels)

    Returns
    -------
    jax.Array
        Output tensor of shape (B, H', W', out_channels) or (H', W', out_channels)
    """
    if self.norm_first:
        x = self.norm(x)
    x = _pad_input(x, self.manual_padding, self.padding_mode, n_spatial_dims=2)
    x = self.conv(x)
    if not self.norm_first:
        x = self.norm(x)
    x = self._apply_shuffle(x)
    x = self.activation(x)
    return self.dropout(x)

ConvTransposeNormActivation

ConvTransposeNormActivation(in_channels, out_channels, cfg, *, rngs)

Bases: Module

Transposed convolutional layer with normalization and activation, and dropouts.

Uses NHWC (channels-last) format.

Parameters:

Name Type Description Default
in_channels int

Input channels.

required
out_channels int

Output channels.

required
cfg ConvConfig

Convolutional layer configuration.

required
rngs Rngs

Random number generators.

required

Examples:

>>> rngs = nnx.Rngs(0)
>>> cfg = ConvConfig(
...     activation="ReLU",
...     kernel_size=3,
...     stride=1,
...     padding=1,
...     output_padding=0,
...     dilation=1,
...     groups=1,
...     bias=True,
...     dropout=0.1,
...     norm="batch",
...     norm_cfg={"affine": True, "track_running_stats": True}
... )
>>> conv = ConvTransposeNormActivation(3, 16, cfg, rngs=rngs)
>>> x = jnp.ones((1, 32, 32, 3))
>>> output = conv(x)
>>> output.shape
(1, 32, 32, 16)
Source code in src/ml_networks/jax/layers.py
def __init__(
    self,
    in_channels: int,
    out_channels: int,
    cfg: ConvConfig,
    *,
    rngs: nnx.Rngs,
) -> None:
    """
    Build the transposed convolution, normalization, activation and dropout.

    Parameters
    ----------
    in_channels : int
        Input channels.
    out_channels : int
        Output channels.
    cfg : ConvConfig
        Convolutional layer configuration.
    rngs : nnx.Rngs
        Random number generators used to initialize the layers.
    """
    # GLU-style activations get double the features — presumably Activation
    # splits them back to out_channels (TODO confirm against Activation).
    out_features = out_channels * 2 if "glu" in cfg.activation.lower() else out_channels

    self.conv = nnx.ConvTranspose(
        in_features=in_channels,
        out_features=out_features,
        kernel_size=(cfg.kernel_size, cfg.kernel_size),
        strides=(cfg.stride, cfg.stride),
        # output_padding is folded into the trailing edge of each spatial dim.
        padding=((cfg.padding, cfg.padding + cfg.output_padding), (cfg.padding, cfg.padding + cfg.output_padding)),
        kernel_dilation=(cfg.dilation, cfg.dilation),
        use_bias=cfg.bias,
        rngs=rngs,
    )

    norm_cfg = dict(cfg.norm_cfg) if cfg.norm_cfg else {}
    # Non-group norms are configured with num_features; group norm takes
    # num_channels instead.
    if cfg.norm not in {"none", "group"}:
        norm_cfg["num_features"] = out_features
    elif cfg.norm == "group":
        norm_cfg["num_channels"] = out_features
    # "batch" maps to the 2-D batch-norm variant for NHWC image tensors.
    norm_type: Literal["layer", "rms", "group", "batch2d", "batch1d", "none"] = (
        "batch2d" if cfg.norm == "batch" else cfg.norm
    )  # type: ignore[assignment]
    self.norm = get_norm(norm_type, rngs=rngs, **norm_cfg)
    self.activation = Activation(cfg.activation, dim=-1)
    # Zero dropout: use a no-op module so __call__ stays branch-free.
    self.dropout: nnx.Module = nnx.Dropout(rate=cfg.dropout, rngs=rngs) if cfg.dropout > 0 else Identity()

Attributes

activation instance-attribute

activation = Activation(activation, dim=-1)

conv instance-attribute

conv = ConvTranspose(in_features=in_channels, out_features=out_features, kernel_size=(kernel_size, kernel_size), strides=(stride, stride), padding=((padding, padding + output_padding), (padding, padding + output_padding)), kernel_dilation=(dilation, dilation), use_bias=bias, rngs=rngs)

dropout instance-attribute

dropout = Dropout(rate=dropout, rngs=rngs) if dropout > 0 else Identity()

norm instance-attribute

norm = get_norm(norm_type, rngs=rngs, **norm_cfg)

Functions

__call__

__call__(x)

Forward pass.

Parameters:

Name Type Description Default
x Array

Input tensor of shape (B, H, W, in_channels) or (H, W, in_channels)

required

Returns:

Type Description
Array

Output tensor of shape (B, H', W', out_channels) or (H', W', out_channels)

Source code in src/ml_networks/jax/layers.py
def __call__(self, x: jax.Array) -> jax.Array:
    """
    Apply transposed convolution, normalization, activation and dropout.

    Parameters
    ----------
    x : jax.Array
        Input tensor of shape (B, H, W, in_channels) or (H, W, in_channels)

    Returns
    -------
    jax.Array
        Output tensor of shape (B, H', W', out_channels) or (H', W', out_channels)
    """
    y = self.conv(x)
    y = self.norm(y)
    y = self.activation(y)
    return self.dropout(y)

ビジョン (ml_networks.jax.vision)

Encoder

Encoder(feature_dim, obs_shape, backbone_cfg, fc_cfg=None, *, rngs)

Bases: Module

Image encoder module (NHWC format).

Parameters:

Name Type Description Default
feature_dim int | tuple[int, int, int]

Output feature dimension. If int, a fully-connected layer flattens and projects the backbone output. If tuple, the backbone output is returned directly (fc is identity).

required
obs_shape tuple[int, int, int]

Observation shape in (H, W, C) format.

required
backbone_cfg ViTConfig | ConvNetConfig | ResNetConfig

Backbone configuration.

required
fc_cfg MLPConfig | LinearConfig | SpatialSoftmaxConfig | AdaptiveAveragePoolingConfig | None

Fully-connected layer configuration. Required when feature_dim is int.

None
rngs Rngs

Random number generators.

required
Source code in src/ml_networks/jax/vision.py
def __init__(
    self,
    feature_dim: int | tuple[int, int, int],
    obs_shape: tuple[int, int, int],
    backbone_cfg: ViTConfig | ConvNetConfig | ResNetConfig,
    fc_cfg: MLPConfig | LinearConfig | SpatialSoftmaxConfig | AdaptiveAveragePoolingConfig | None = None,
    *,
    rngs: nnx.Rngs,
) -> None:
    """
    Build the backbone encoder and the optional fully-connected head.

    Parameters
    ----------
    feature_dim : int | tuple[int, int, int]
        Output feature dimension. If int, a fully-connected head projects the
        backbone output (fc_cfg is then required). If tuple, it must equal
        (last_channel, *conved_shape) and no head is required.
    obs_shape : tuple[int, int, int]
        Observation shape in (H, W, C) format.
    backbone_cfg : ViTConfig | ConvNetConfig | ResNetConfig
        Backbone configuration; its type selects the backbone class.
    fc_cfg : MLPConfig | LinearConfig | SpatialSoftmaxConfig | AdaptiveAveragePoolingConfig | None
        Fully-connected head configuration. Required when feature_dim is int.
    rngs : nnx.Rngs
        Random number generators.

    Raises
    ------
    NotImplementedError
        If backbone_cfg is not one of the supported config types.
    """
    self.obs_shape = obs_shape
    self.feature_dim = feature_dim

    # Select the backbone by config type and record its output geometry
    # (last_channel / conved_size / conved_shape come from the backbone).
    self.encoder: nnx.Module
    if isinstance(backbone_cfg, ViTConfig):
        self.encoder = ViT(obs_shape, backbone_cfg, rngs=rngs)
        self.last_channel: int = self.encoder.last_channel
        self.conved_size: int = cast("int", self.encoder.conved_size)
        self.conved_shape: tuple[int, ...] = cast("tuple[int, ...]", self.encoder.conved_shape)
    elif isinstance(backbone_cfg, ConvNetConfig):
        self.encoder = ConvNet(obs_shape, backbone_cfg, rngs=rngs)
        self.last_channel = self.encoder.last_channel
        self.conved_size = cast("int", self.encoder.conved_size)
        self.conved_shape = cast("tuple[int, ...]", self.encoder.conved_shape)
    elif isinstance(backbone_cfg, ResNetConfig):
        self.encoder = ResNetPixUnshuffle(obs_shape, backbone_cfg, rngs=rngs)
        self.last_channel = self.encoder.last_channel
        self.conved_size = cast("int", self.encoder.conved_size)
        self.conved_shape = cast("tuple[int, ...]", self.encoder.conved_shape)
    else:
        msg = f"{type(backbone_cfg)} is not implemented"
        raise NotImplementedError(msg)

    if isinstance(feature_dim, int):
        assert fc_cfg is not None, "fc_cfg must be provided if feature_dim is int"
    else:
        # A tuple feature_dim must match the backbone's output geometry exactly.
        assert feature_dim == (self.last_channel, *self.conved_shape), (
            f"{feature_dim} != {(self.last_channel, *self.conved_shape)}"
        )

    # Build the fully-connected head. For pooling / spatial-softmax configs an
    # optional additional_layer does the final projection; without one, the
    # effective feature_dim is overwritten with the reduced feature size.
    self.fc: nnx.Module
    if isinstance(fc_cfg, MLPConfig):
        assert isinstance(feature_dim, int)
        self.fc = MLPLayer(self.conved_size, feature_dim, fc_cfg, rngs=rngs)
    elif isinstance(fc_cfg, LinearConfig):
        assert isinstance(feature_dim, int)
        self.fc = LinearNormActivation(self.conved_size, feature_dim, fc_cfg, rngs=rngs)
    elif isinstance(fc_cfg, AdaptiveAveragePoolingConfig):
        assert isinstance(feature_dim, int)
        output_size = fc_cfg.output_size
        # Flattened size after pooling: channels * prod(pool output size).
        pooled_size = int(self.last_channel * np.prod(output_size))
        if isinstance(fc_cfg.additional_layer, LinearConfig):
            self.fc = LinearNormActivation(pooled_size, feature_dim, fc_cfg.additional_layer, rngs=rngs)
        elif isinstance(fc_cfg.additional_layer, MLPConfig):
            self.fc = MLPLayer(pooled_size, feature_dim, fc_cfg.additional_layer, rngs=rngs)
        else:
            self.fc = Identity()
        if fc_cfg.additional_layer is None:
            self.feature_dim = pooled_size
        self._adaptive_pool_output_size = output_size
    elif isinstance(fc_cfg, SpatialSoftmaxConfig):
        assert isinstance(feature_dim, int)
        # Spatial softmax yields 2 coordinates (x, y) per channel.
        if isinstance(fc_cfg.additional_layer, LinearConfig):
            self.fc = LinearNormActivation(
                self.last_channel * 2,
                feature_dim,
                fc_cfg.additional_layer,
                rngs=rngs,
            )
        elif isinstance(fc_cfg.additional_layer, MLPConfig):
            self.fc = MLPLayer(self.last_channel * 2, feature_dim, fc_cfg.additional_layer, rngs=rngs)
        else:
            self.fc = Identity()
        if fc_cfg.additional_layer is None:
            self.feature_dim = self.last_channel * 2
        self._spatial_softmax = SpatialSoftmax(fc_cfg)
    else:
        self.fc = Identity()
    # Kept so __call__ can dispatch on the head type.
    self._fc_cfg = fc_cfg

Attributes

conved_shape instance-attribute

conved_shape = cast('tuple[int, ...]', conved_shape)

conved_size instance-attribute

conved_size = cast('int', conved_size)

encoder instance-attribute

encoder

fc instance-attribute

fc

feature_dim instance-attribute

feature_dim = feature_dim

last_channel instance-attribute

last_channel = last_channel

obs_shape instance-attribute

obs_shape = obs_shape

Functions

__call__

__call__(x)

Forward pass.

Parameters:

Name Type Description Default
x Array

Input tensor of shape (*, H, W, C) in NHWC format.

required

Returns:

Type Description
Array

Encoded tensor of shape (*, feature_dim).

Source code in src/ml_networks/jax/vision.py
def __call__(self, x: jax.Array) -> jax.Array:
    """
    Forward pass.

    Parameters
    ----------
    x : jax.Array
        Input tensor of shape (*, H, W, C) in NHWC format.

    Returns
    -------
    jax.Array
        Encoded tensor of shape (*, feature_dim).
    """
    # Collapse arbitrary leading batch dims into one for the backbone.
    batch_shape = x.shape[:-3]
    x = x.reshape(-1, *self.obs_shape)
    x = self.encoder(x)
    if isinstance(self._fc_cfg, AdaptiveAveragePoolingConfig):
        # NHWC adaptive average pooling: (B, H, W, C) -> (B, oh, ow, C)
        # NOTE(review): this unpacks a 4-D backbone output and assumes
        # h % oh == 0 and w % ow == 0 — confirm backbones used with this
        # config satisfy both.
        pool_size = self._adaptive_pool_output_size
        assert isinstance(pool_size, tuple)
        oh, ow = pool_size
        b, h, w, c = x.shape
        # Reshape to windows and average
        x = x.reshape(b, oh, h // oh, ow, w // ow, c)
        x = x.mean(axis=(2, 4))
        x = x.reshape(b, -1)
        x = self.fc(x)
    elif isinstance(self._fc_cfg, SpatialSoftmaxConfig):
        x = self._spatial_softmax(x)
        x = x.reshape(x.shape[0], -1)
        x = self.fc(x)
    else:
        # NOTE(review): with a tuple feature_dim the result is flattened here,
        # so the trailing shape is 1-D rather than the tuple — confirm this
        # matches the intended contract.
        x = x.reshape(x.shape[0], -1)
        x = self.fc(x)
    # Restore the original leading batch dims.
    return x.reshape(*batch_shape, *x.shape[1:])

Decoder

Decoder(feature_dim, obs_shape, backbone_cfg, fc_cfg=None, *, rngs)

Bases: Module

Image decoder module (NHWC format).

Parameters:

Name Type Description Default
feature_dim int | tuple[int, int, int]

Input feature dimension. If int, a fully-connected layer projects and reshapes input before the backbone. If tuple, input is passed directly to the backbone.

required
obs_shape tuple[int, int, int]

Output observation shape in (H, W, C) format.

required
backbone_cfg ConvNetConfig | ViTConfig | ResNetConfig

Backbone configuration.

required
fc_cfg MLPConfig | LinearConfig | None

Fully-connected layer configuration. Required when feature_dim is int.

None
rngs Rngs

Random number generators.

required
Source code in src/ml_networks/jax/vision.py
def __init__(
    self,
    feature_dim: int | tuple[int, int, int],
    obs_shape: tuple[int, int, int],
    backbone_cfg: ConvNetConfig | ViTConfig | ResNetConfig,
    fc_cfg: MLPConfig | LinearConfig | None = None,
    *,
    rngs: nnx.Rngs,
) -> None:
    """
    Build the optional fully-connected head and the decoder backbone.

    Parameters
    ----------
    feature_dim : int | tuple[int, int, int]
        Input feature dimension. If int, a fully-connected head projects the
        input to the backbone's input size (fc_cfg is then required). If
        tuple, it must equal the backbone's required input shape.
    obs_shape : tuple[int, int, int]
        Output observation shape in (H, W, C) format.
    backbone_cfg : ConvNetConfig | ViTConfig | ResNetConfig
        Backbone configuration; its type selects the decoder class.
    fc_cfg : MLPConfig | LinearConfig | None
        Fully-connected head configuration. Required when feature_dim is int.
    rngs : nnx.Rngs
        Random number generators.

    Raises
    ------
    NotImplementedError
        If backbone_cfg is not one of the supported config types.
    """
    self.obs_shape = obs_shape
    self.feature_dim = feature_dim

    # Work out the (H, W, C) tensor shape the backbone needs at its input so
    # it can produce obs_shape at its output.
    self.input_shape: tuple[int, int, int]
    if isinstance(backbone_cfg, ViTConfig):
        self.input_shape = ViT.get_input_shape(obs_shape, backbone_cfg)
    elif isinstance(backbone_cfg, ConvNetConfig):
        self.input_shape = cast(
            "tuple[int, int, int]",
            ConvTranspose.get_input_shape(obs_shape, backbone_cfg),
        )
    elif isinstance(backbone_cfg, ResNetConfig):
        self.input_shape = ResNetPixShuffle.get_input_shape(obs_shape, backbone_cfg)
    else:
        msg = f"{type(backbone_cfg)} is not implemented"
        raise NotImplementedError(msg)

    if isinstance(feature_dim, int):
        assert fc_cfg is not None, "fc_cfg must be provided if feature_dim is int"
        self.has_fc = True
    else:
        # A tuple feature_dim must match the backbone's input shape exactly.
        assert feature_dim == self.input_shape, f"{feature_dim} != {self.input_shape}"
        self.has_fc = False

    # The fc head projects a flat feature vector to the flattened backbone
    # input; Identity when no head is configured.
    input_size = int(np.prod(self.input_shape))
    self.fc: nnx.Module
    if isinstance(fc_cfg, MLPConfig):
        assert isinstance(feature_dim, int)
        self.fc = MLPLayer(feature_dim, input_size, fc_cfg, rngs=rngs)
    elif isinstance(fc_cfg, LinearConfig):
        assert isinstance(feature_dim, int)
        self.fc = LinearNormActivation(feature_dim, input_size, fc_cfg, rngs=rngs)
    else:
        self.fc = Identity()

    # No trailing else needed: unsupported config types already raised above.
    if isinstance(backbone_cfg, ViTConfig):
        self.decoder: nnx.Module = ViT(
            in_shape=self.input_shape,
            cfg=backbone_cfg,
            obs_shape=obs_shape,
            rngs=rngs,
        )
    elif isinstance(backbone_cfg, ConvNetConfig):
        self.decoder = ConvTranspose(
            in_shape=self.input_shape,
            obs_shape=obs_shape,
            cfg=backbone_cfg,
            rngs=rngs,
        )
    elif isinstance(backbone_cfg, ResNetConfig):
        self.decoder = ResNetPixShuffle(
            in_shape=self.input_shape,
            obs_shape=obs_shape,
            cfg=backbone_cfg,
            rngs=rngs,
        )

Attributes

decoder instance-attribute

decoder = ViT(in_shape=input_shape, cfg=backbone_cfg, obs_shape=obs_shape, rngs=rngs)

fc instance-attribute

fc

feature_dim instance-attribute

feature_dim = feature_dim

has_fc instance-attribute

has_fc = True

input_shape instance-attribute

input_shape

obs_shape instance-attribute

obs_shape = obs_shape

Functions

__call__

__call__(x)

Forward pass.

Parameters:

Name Type Description Default
x Array

Input tensor of shape (*, feature_dim).

required

Returns:

Type Description
Array

Decoded tensor of shape (*, H, W, C) in NHWC format.

Source code in src/ml_networks/jax/vision.py
def __call__(self, x: jax.Array) -> jax.Array:
    """
    Decode feature tensors back into observations.

    Parameters
    ----------
    x : jax.Array
        Input tensor of shape (*, feature_dim).

    Returns
    -------
    jax.Array
        Decoded tensor of shape (*, H, W, C) in NHWC format.
    """
    # Trailing data portion: a flat feature vector when an fc head exists,
    # otherwise an (H, W, C) volume fed straight to the backbone.
    data_ndim = 1 if self.has_fc else 3
    batch_shape = x.shape[:-data_ndim]
    x = x.reshape(-1, *x.shape[-data_ndim:])
    x = self.fc(x)
    x = x.reshape(-1, *self.input_shape)
    x = self.decoder(x)
    return x.reshape(*batch_shape, *self.obs_shape)

ConvNet

ConvNet(obs_shape, cfg, *, rngs)

Bases: Module

Convolutional network (NHWC format).

Parameters:

Name Type Description Default
obs_shape tuple[int, int, int]

Observation shape in (H, W, C) format.

required
cfg ConvNetConfig

Configuration.

required
rngs Rngs

Random number generators.

required
Source code in src/ml_networks/jax/vision.py
def __init__(
    self,
    obs_shape: tuple[int, int, int],
    cfg: ConvNetConfig,
    *,
    rngs: nnx.Rngs,
) -> None:
    """
    Build the stack of conv layers with optional per-layer attention.

    Parameters
    ----------
    obs_shape : tuple[int, int, int]
        Observation shape in (H, W, C) format.
    cfg : ConvNetConfig
        Configuration.
    rngs : nnx.Rngs
        Random number generators.
    """
    self.cfg = cfg
    self.obs_shape = obs_shape
    in_channels = obs_shape[2]  # NHWC
    self.channels = [in_channels, *cfg.channels]

    layers: list[nnx.Module] = []
    attn_layers: list[nnx.Module] = []
    spatial_shape: tuple[int, ...] = (obs_shape[0], obs_shape[1])

    # One conv per (channel, config) pair; attn_layers stays aligned with
    # layers by inserting Identity when attention is disabled.
    for ch, conv_cfg_i in zip(cfg.channels, cfg.conv_cfgs, strict=True):
        layers.append(ConvNormActivation(in_channels, ch, conv_cfg_i, rngs=rngs))

        if cfg.attention is not None:
            attn_layers.append(
                Attention2d(ch, nhead=None, attn_cfg=cfg.attention, rngs=rngs),
            )
        else:
            attn_layers.append(Identity())

        # Track the spatial shape analytically so the output size is known
        # without running a forward pass.
        spatial_shape = conv_out_shape(
            spatial_shape,
            padding=conv_cfg_i.padding,
            kernel_size=conv_cfg_i.kernel_size,
            stride=conv_cfg_i.stride,
            dilation=conv_cfg_i.dilation,
        )
        in_channels = ch

    self.conv_layers = nnx.List(layers)
    self.attn_layers = nnx.List(attn_layers)
    self.output_spatial_shape = spatial_shape
    self.output_channels = in_channels
    self.last_channel = in_channels
    # Flattened output size: channels * prod(spatial dims).
    self.output_dim = in_channels * int(np.prod(spatial_shape))

Attributes

attn_layers instance-attribute

attn_layers = List(attn_layers)

cfg instance-attribute

cfg = cfg

channels instance-attribute

channels = [in_channels, *(channels)]

conv_layers instance-attribute

conv_layers = List(layers)

conved_shape property

conved_shape

Get the spatial shape of the output after convolutional layers.

conved_size property

conved_size

Get the flattened size of the output after convolutional layers.

last_channel instance-attribute

last_channel = in_channels

obs_shape instance-attribute

obs_shape = obs_shape

output_channels instance-attribute

output_channels = in_channels

output_dim instance-attribute

output_dim = in_channels * int(prod(spatial_shape))

output_spatial_shape instance-attribute

output_spatial_shape = spatial_shape

Functions

__call__

__call__(x)

Forward pass.

Parameters:

Name Type Description Default
x Array

Input tensor of shape (B, H, W, C) in NHWC format.

required

Returns:

Type Description
Array

Flattened tensor of shape (B, output_dim).

Source code in src/ml_networks/jax/vision.py
def __call__(self, x: jax.Array) -> jax.Array:
    """
    Forward pass.

    Parameters
    ----------
    x : jax.Array
        Input tensor of shape (B, H, W, C) in NHWC format.

    Returns
    -------
    jax.Array
        Flattened tensor of shape (B, output_dim).
    """
    # conv_layers and attn_layers are built pairwise in __init__ (with
    # strict=True there), so enforce the same invariant here and fail fast on
    # any length mismatch instead of silently dropping layers.
    for conv, attn in zip(self.conv_layers, self.attn_layers, strict=True):
        x = conv(x)
        x = attn(x)
    return x.reshape(x.shape[0], -1)

ConvTranspose

ConvTranspose(in_shape, obs_shape, cfg, *, rngs)

Bases: Module

Transposed convolutional network (NHWC format).

Parameters:

Name Type Description Default
in_shape tuple[int, int, int]

Input shape in (H, W, C) format.

required
obs_shape tuple[int, int, int]

Output observation shape in (H, W, C) format.

required
cfg ConvNetConfig

Configuration (channels are in decode order).

required
rngs Rngs

Random number generators.

required
Source code in src/ml_networks/jax/vision.py
def __init__(
    self,
    in_shape: tuple[int, int, int],
    obs_shape: tuple[int, int, int],
    cfg: ConvNetConfig,
    *,
    rngs: nnx.Rngs,
) -> None:
    """
    Build the transposed-convolution stack with optional attention.

    Parameters
    ----------
    in_shape : tuple[int, int, int]
        Input shape in (H, W, C) format.
    obs_shape : tuple[int, int, int]
        Output observation shape in (H, W, C) format.
    cfg : ConvNetConfig
        Configuration (channels are in decode order).
    rngs : nnx.Rngs
        Random number generators.
    """
    self.in_shape = in_shape
    self.obs_shape = obs_shape
    self.cfg = cfg
    # channels: cfg.channels -> obs_shape[2] (output channels)
    self.channels = [*cfg.channels, obs_shape[2]]

    assert len(cfg.channels) == len(cfg.conv_cfgs)

    # first_conv if input channels != first cfg channel
    # (a 1x1 "Identity"-activation conv that only adapts the channel count).
    self.have_first_conv = in_shape[2] != cfg.channels[0]
    if self.have_first_conv:
        first_conv_cfg = ConvConfig(
            activation="Identity",
            kernel_size=1,
            stride=1,
            padding=0,
            norm="none",
        )
        self.first_conv = ConvNormActivation(
            in_shape[2],
            cfg.channels[0],
            first_conv_cfg,
            rngs=rngs,
        )

    # When attention is enabled it precedes each transposed conv, operating
    # on that conv's input channels.
    layers: list[nnx.Module] = []
    for i, conv_cfg_i in enumerate(cfg.conv_cfgs):
        if cfg.attention is not None:
            layers.append(
                Attention2d(self.channels[i], nhead=None, attn_cfg=cfg.attention, rngs=rngs),
            )
        layers.append(
            ConvTransposeNormActivation(self.channels[i], self.channels[i + 1], conv_cfg_i, rngs=rngs),
        )

    self.conv_layers = nnx.List(layers)

Attributes

cfg instance-attribute

cfg = cfg

channels instance-attribute

channels = [*(channels), obs_shape[2]]

conv_layers instance-attribute

conv_layers = List(layers)

first_conv instance-attribute

first_conv = ConvNormActivation(in_shape[2], channels[0], first_conv_cfg, rngs=rngs)

have_first_conv instance-attribute

have_first_conv = in_shape[2] != channels[0]

in_shape instance-attribute

in_shape = in_shape

obs_shape instance-attribute

obs_shape = obs_shape

Functions

__call__

__call__(x)

Forward pass.

Parameters:

Name Type Description Default
x Array

Input tensor of shape (B, H, W, C) in NHWC format.

required

Returns:

Type Description
Array

Output tensor of shape (B, H', W', out_C) in NHWC format.

Source code in src/ml_networks/jax/vision.py
def __call__(self, x: jax.Array) -> jax.Array:
    """
    Forward pass through the transposed-convolution stack.

    Parameters
    ----------
    x : jax.Array
        Input tensor of shape (B, H, W, C) in NHWC format.

    Returns
    -------
    jax.Array
        Output tensor of shape (B, H', W', out_C) in NHWC format.
    """
    # Optional 1x1 channel adapter first, then the (attention +) conv stack.
    out = self.first_conv(x) if self.have_first_conv else x
    for block in self.conv_layers:
        out = block(out)
    return out

get_input_shape staticmethod

get_input_shape(obs_shape, cfg)

Get the required input shape for a given output shape and config.

Parameters:

Name Type Description Default
obs_shape tuple[int, int, int]

Output shape in (H, W, C) format.

required
cfg ConvNetConfig

Configuration.

required

Returns:

Type Description
tuple[int, ...]

Required input shape in (H, W, C) format.

Source code in src/ml_networks/jax/vision.py
@staticmethod
def get_input_shape(obs_shape: tuple[int, int, int], cfg: ConvNetConfig) -> tuple[int, ...]:
    """Get the required input shape for a given output shape and config.

    Walks the transposed-conv configs in reverse, inverting each layer's
    spatial transform to recover the required input resolution.

    Parameters
    ----------
    obs_shape : tuple[int, int, int]
        Output shape in (H, W, C) format.
    cfg : ConvNetConfig
        Configuration.

    Returns
    -------
    tuple[int, ...]
        Required input shape in (H, W, C) format.
    """
    # NHWC: spatial dims are [0], [1]
    spatial: tuple[int, ...] = obs_shape[:2]
    for layer_cfg in cfg.conv_cfgs[::-1]:
        spatial = conv_transpose_in_shape(
            spatial,
            padding=layer_cfg.padding,
            kernel_size=layer_cfg.kernel_size,
            stride=layer_cfg.stride,
            dilation=layer_cfg.dilation,
        )
    return (*spatial, cfg.init_channel)

ResNetPixUnshuffle

ResNetPixUnshuffle(obs_shape, cfg, *, rngs)

Bases: Module

ResNet with PixelUnshuffle downsampling (NHWC format).

Parameters:

Name Type Description Default
obs_shape tuple[int, int, int]

Input observation shape in (H, W, C) format.

required
cfg ResNetConfig

Configuration.

required
rngs Rngs

Random number generators.

required
Source code in src/ml_networks/jax/vision.py
def __init__(
    self,
    obs_shape: tuple[int, int, int],
    cfg: ResNetConfig,
    *,
    rngs: nnx.Rngs,
) -> None:
    """Build the downsampling ResNet: stem conv, downscalers, residual trunk, head.

    Parameters
    ----------
    obs_shape : tuple[int, int, int]
        Input observation shape in (H, W, C) format.
    cfg : ResNetConfig
        Configuration.
    rngs : nnx.Rngs
        Random number generators.
    """
    self.cfg = cfg
    self.obs_shape = obs_shape

    # Template conv settings; padding = kernel // 2 for "same"-style padding.
    first_cfg = ConvConfig(
        activation=cfg.conv_activation,
        kernel_size=cfg.f_kernel,
        stride=1,
        padding=cfg.f_kernel // 2,
        dilation=1,
        groups=1,
        bias=True,
        dropout=cfg.dropout,
        norm=cfg.norm,
        norm_cfg=cfg.norm_cfg,
        padding_mode=cfg.padding_mode,
    )
    # First layer: input channels -> conv_channel
    self.conv1 = ConvNormActivation(obs_shape[2], cfg.conv_channel, first_cfg, rngs=rngs)

    # Downsampling layers
    downsample_cfg = deepcopy(first_cfg)
    downsample_cfg.kernel_size = cfg.conv_kernel
    downsample_cfg.padding = cfg.conv_kernel // 2
    # NOTE(review): a negative scale_factor presumably selects PixelUnshuffle
    # (downscaling) inside ConvNormActivation -- confirm against its impl.
    downsample_cfg.scale_factor = -cfg.scale_factor
    downsample_layers = [
        ConvNormActivation(cfg.conv_channel, cfg.conv_channel, downsample_cfg, rngs=rngs)
        for _ in range(cfg.n_scaling)
    ]
    self.downsample = nnx.List(downsample_layers)

    # Residual blocks, each optionally followed by a 2D attention layer
    res_blocks: list[nnx.Module] = []
    for _ in range(cfg.n_res_blocks):
        res_blocks.append(
            ResidualBlock(
                cfg.conv_channel,
                cfg.conv_kernel,
                cfg.conv_activation,
                cfg.norm,
                cfg.norm_cfg,
                cfg.dropout,
                cfg.padding_mode,
                rngs=rngs,
            ),
        )
        if cfg.attention is not None:
            res_blocks.append(
                Attention2d(cfg.conv_channel, nhead=None, attn_cfg=cfg.attention, rngs=rngs),
            )
    self.res_blocks = nnx.List(res_blocks)

    # Post-residual conv
    conv_cfg = deepcopy(first_cfg)
    conv_cfg.kernel_size = cfg.conv_kernel
    conv_cfg.padding = cfg.conv_kernel // 2
    conv_cfg.scale_factor = 0
    self.conv2 = ConvNormActivation(cfg.conv_channel, cfg.conv_channel, conv_cfg, rngs=rngs)

    # Final conv
    self.conv3 = ConvNormActivation(cfg.conv_channel, cfg.conv_channel, conv_cfg, rngs=rngs)
    self.last_channel = cfg.conv_channel

Attributes

cfg instance-attribute

cfg = cfg

conv1 instance-attribute

conv1 = ConvNormActivation(obs_shape[2], conv_channel, first_cfg, rngs=rngs)

conv2 instance-attribute

conv2 = ConvNormActivation(conv_channel, conv_channel, conv_cfg, rngs=rngs)

conv3 instance-attribute

conv3 = ConvNormActivation(conv_channel, conv_channel, conv_cfg, rngs=rngs)

conved_shape property

conved_shape

Get the spatial shape after downsampling.

conved_size property

conved_size

Get the flattened size after downsampling.

downsample instance-attribute

downsample = List(downsample_layers)

last_channel instance-attribute

last_channel = conv_channel

obs_shape instance-attribute

obs_shape = obs_shape

res_blocks instance-attribute

res_blocks = List(res_blocks)

Functions

__call__

__call__(x)

Forward pass.

Parameters:

Name Type Description Default
x Array

Input tensor of shape (B, H, W, C) in NHWC format.

required

Returns:

Type Description
Array

Downsampled output.

Source code in src/ml_networks/jax/vision.py
def __call__(self, x: jax.Array) -> jax.Array:
    """Apply stem conv, downsampling, residual trunk with a skip, and head convs.

    Parameters
    ----------
    x : jax.Array
        Input tensor of shape (B, H, W, C) in NHWC format.

    Returns
    -------
    jax.Array
        Downsampled output.
    """
    h = self.conv1(x)
    for down in self.downsample:
        h = down(h)
    # Skip connection spans the residual trunk + conv2.
    skip = h
    for block in self.res_blocks:
        h = block(h)
    h = skip + self.conv2(h)
    return self.conv3(h)

ResNetPixShuffle

ResNetPixShuffle(in_shape, obs_shape, cfg, *, rngs)

Bases: Module

ResNet with PixelShuffle upsampling (NHWC format).

Parameters:

Name Type Description Default
in_shape tuple[int, int, int]

Input shape in (H, W, C) format.

required
obs_shape tuple[int, int, int]

Output observation shape in (H, W, C) format.

required
cfg ResNetConfig

Configuration.

required
rngs Rngs

Random number generators.

required
Source code in src/ml_networks/jax/vision.py
def __init__(
    self,
    in_shape: tuple[int, int, int],
    obs_shape: tuple[int, int, int],
    cfg: ResNetConfig,
    *,
    rngs: nnx.Rngs,
) -> None:
    """Build the upsampling ResNet: stem conv, residual trunk, upscalers, output head.

    Parameters
    ----------
    in_shape : tuple[int, int, int]
        Input shape in (H, W, C) format.
    obs_shape : tuple[int, int, int]
        Output observation shape in (H, W, C) format.
    cfg : ResNetConfig
        Configuration.
    rngs : nnx.Rngs
        Random number generators.
    """
    self.cfg = cfg
    self.in_shape = in_shape
    self.obs_shape = obs_shape
    out_channels = obs_shape[2]  # NHWC

    # Shared conv settings; padding = kernel // 2 for "same"-style padding.
    conv_cfg = ConvConfig(
        activation=cfg.conv_activation,
        kernel_size=cfg.conv_kernel,
        stride=1,
        padding=cfg.conv_kernel // 2,
        dilation=1,
        groups=1,
        bias=True,
        dropout=cfg.dropout,
        norm=cfg.norm,
        norm_cfg=cfg.norm_cfg,
        padding_mode=cfg.padding_mode,
    )

    # First layer
    self.conv1 = ConvNormActivation(in_shape[2], cfg.conv_channel, conv_cfg, rngs=rngs)

    # Residual blocks, each optionally followed by a 2D attention layer
    res_blocks: list[nnx.Module] = []
    for _ in range(cfg.n_res_blocks):
        res_blocks.append(
            ResidualBlock(
                cfg.conv_channel,
                cfg.conv_kernel,
                cfg.conv_activation,
                cfg.norm,
                cfg.norm_cfg,
                cfg.dropout,
                cfg.padding_mode,
                rngs=rngs,
            ),
        )
        if cfg.attention is not None:
            res_blocks.append(
                Attention2d(cfg.conv_channel, nhead=None, attn_cfg=cfg.attention, rngs=rngs),
            )
    self.res_blocks = nnx.List(res_blocks)

    # Second conv layer post residual blocks
    self.conv2 = ConvNormActivation(cfg.conv_channel, cfg.conv_channel, conv_cfg, rngs=rngs)

    # Upsampling layers
    # NOTE(review): a positive scale_factor presumably selects PixelShuffle
    # (upscaling) inside ConvNormActivation -- confirm against its impl.
    upscale_cfg = deepcopy(conv_cfg)
    upscale_cfg.scale_factor = cfg.scale_factor
    upsample_layers = [
        ConvNormActivation(cfg.conv_channel, cfg.conv_channel, upscale_cfg, rngs=rngs) for _ in range(cfg.n_scaling)
    ]
    self.upsampling = nnx.List(upsample_layers)

    # Final output layer: no norm/dropout, output activation only
    final_cfg = ConvConfig(
        activation=cfg.out_activation,
        kernel_size=cfg.f_kernel,
        stride=1,
        padding=cfg.f_kernel // 2,
        norm="none",
        norm_cfg={},
        dropout=0.0,
    )
    self.conv3 = ConvNormActivation(cfg.conv_channel, out_channels, final_cfg, rngs=rngs)

Attributes

cfg instance-attribute

cfg = cfg

conv1 instance-attribute

conv1 = ConvNormActivation(in_shape[2], conv_channel, conv_cfg, rngs=rngs)

conv2 instance-attribute

conv2 = ConvNormActivation(conv_channel, conv_channel, conv_cfg, rngs=rngs)

conv3 instance-attribute

conv3 = ConvNormActivation(conv_channel, out_channels, final_cfg, rngs=rngs)

in_shape instance-attribute

in_shape = in_shape

obs_shape instance-attribute

obs_shape = obs_shape

res_blocks instance-attribute

res_blocks = List(res_blocks)

upsampling instance-attribute

upsampling = List(upsample_layers)

Functions

__call__

__call__(x)

Forward pass.

Parameters:

Name Type Description Default
x Array

Input tensor of shape (B, H, W, C) in NHWC format.

required

Returns:

Type Description
Array

Upsampled output of shape (B, H', W', C').

Source code in src/ml_networks/jax/vision.py
def __call__(self, x: jax.Array) -> jax.Array:
    """Run stem conv, residual trunk with skip connection, upsampling, and head.

    Parameters
    ----------
    x : jax.Array
        Input tensor of shape (B, H, W, C) in NHWC format.

    Returns
    -------
    jax.Array
        Upsampled output of shape (B, H', W', C').
    """
    # Skip connection spans the residual trunk + conv2.
    skip = self.conv1(x)
    h = skip
    for block in self.res_blocks:
        h = block(h)
    h = skip + self.conv2(h)
    for up in self.upsampling:
        h = up(h)
    return self.conv3(h)

get_input_shape staticmethod

get_input_shape(obs_shape, cfg)

Get the required input shape for a given output shape and config.

Source code in src/ml_networks/jax/vision.py
@staticmethod
def get_input_shape(obs_shape: tuple[int, int, int], cfg: ResNetConfig) -> tuple[int, int, int]:
    """Get the required input shape for a given output shape and config."""
    scaling = cfg.scale_factor**cfg.n_scaling
    return (
        obs_shape[0] // scaling,
        obs_shape[1] // scaling,
        cfg.init_channel,
    )

ViT

ViT(in_shape, cfg, obs_shape=None, *, rngs)

Bases: Module

Vision Transformer for Encoder and Decoder (NHWC format).

Parameters:

Name Type Description Default
in_shape tuple[int, int, int]

Input shape in (H, W, C) format.

required
cfg ViTConfig

ViT configuration.

required
obs_shape tuple[int, int, int] | None

Output shape in (H, W, C) format. If None, acts as encoder.

None
rngs Rngs

Random number generators.

required
Source code in src/ml_networks/jax/vision.py
def __init__(
    self,
    in_shape: tuple[int, int, int],
    cfg: ViTConfig,
    obs_shape: tuple[int, int, int] | None = None,
    *,
    rngs: nnx.Rngs,
) -> None:
    """Build a ViT that encodes (``obs_shape`` is None) or decodes to ``obs_shape``.

    Parameters
    ----------
    in_shape : tuple[int, int, int]
        Input shape in (H, W, C) format.
    cfg : ViTConfig
        ViT configuration.
    obs_shape : tuple[int, int, int] | None
        Output shape in (H, W, C) format. If None, acts as encoder.
    rngs : nnx.Rngs
        Random number generators.
    """
    self.cfg = cfg
    self.in_shape = in_shape
    self.obs_shape = obs_shape if obs_shape is not None else in_shape
    self.patch_size = cfg.patch_size

    t_cfg = cfg.transformer_cfg
    self.transformer_cfg = t_cfg
    # NHWC: (H, W, C) -> patch_dim = patch_size^2 * C
    self.in_patch_dim = self.get_patch_dim(in_shape)
    # Decoder projects tokens to the output patch dim; encoder keeps d_model.
    self.out_patch_dim = self.get_patch_dim(obs_shape) if obs_shape is not None else t_cfg.d_model

    self.positional_encoding = PositionalEncoding(
        self.in_patch_dim,
        t_cfg.dropout,
        max_len=self.get_n_patches(in_shape),
        rngs=rngs,
    )
    self.vit = TransformerLayer(self.in_patch_dim, self.out_patch_dim, t_cfg, rngs=rngs)

    # Only the encoder owns a learned patch embedding; in __call__ the
    # decoder path patchifies raw images directly instead.
    self.is_encoder = obs_shape is None
    if self.is_encoder:
        self.n_patches = self.get_n_patches(in_shape)
        self.patch_embed = PatchEmbed(
            emb_dim=self.in_patch_dim,
            patch_size=self.patch_size,
            obs_shape=in_shape,
            rngs=rngs,
        )

    self.should_unpatchify = cfg.unpatchify
    if cfg.cls_token:
        # Learned CLS token with small (0.02-scaled) normal init.
        self._cls_token = nnx.Param(jax.random.normal(rngs(), (1, 1, self.in_patch_dim)) * 0.02)
    # NOTE(review): ``last_channel`` is the token count here, mirroring the
    # conv encoders' channel attribute -- confirm downstream usage.
    self.last_channel = self.get_n_patches(in_shape)
    self.output_dim = self.out_patch_dim

Attributes

cfg instance-attribute

cfg = cfg

conved_shape property

conved_shape

Get the output shape after transformer.

conved_size property

conved_size

Get the flattened output size.

in_patch_dim instance-attribute

in_patch_dim = get_patch_dim(in_shape)

in_shape instance-attribute

in_shape = in_shape

is_encoder instance-attribute

is_encoder = obs_shape is None

last_channel instance-attribute

last_channel = get_n_patches(in_shape)

n_patches instance-attribute

n_patches = get_n_patches(in_shape)

obs_shape instance-attribute

obs_shape = obs_shape if obs_shape is not None else in_shape

out_patch_dim instance-attribute

out_patch_dim = get_patch_dim(obs_shape) if obs_shape is not None else d_model

output_dim instance-attribute

output_dim = out_patch_dim

patch_embed instance-attribute

patch_embed = PatchEmbed(emb_dim=in_patch_dim, patch_size=patch_size, obs_shape=in_shape, rngs=rngs)

patch_size instance-attribute

patch_size = patch_size

positional_encoding instance-attribute

positional_encoding = PositionalEncoding(in_patch_dim, dropout, max_len=get_n_patches(in_shape), rngs=rngs)

should_unpatchify instance-attribute

should_unpatchify = unpatchify

transformer_cfg instance-attribute

transformer_cfg = t_cfg

vit instance-attribute

vit = TransformerLayer(in_patch_dim, out_patch_dim, t_cfg, rngs=rngs)

Functions

__call__

__call__(x, *, return_cls_token=False)

Forward pass.

Parameters:

Name Type Description Default
x Array

Input tensor of shape (B, H, W, C) in NHWC format.

required
return_cls_token bool

Whether to return CLS token only. Default is False.

False

Returns:

Type Description
Array

Output tensor.

Source code in src/ml_networks/jax/vision.py
def __call__(self, x: jax.Array, *, return_cls_token: bool = False) -> jax.Array:
    """Tokenize, add positions (and CLS if configured), run the transformer.

    Parameters
    ----------
    x : jax.Array
        Input tensor of shape (B, H, W, C) in NHWC format.
    return_cls_token : bool
        Whether to return CLS token only. Default is False.

    Returns
    -------
    jax.Array
        Output tensor (or CLS token when requested and configured).
    """
    has_cls = hasattr(self, "_cls_token")
    tokens = self.patch_embed(x) if self.is_encoder else self.patchify(x)
    tokens = self.positional_encoding(tokens)
    if has_cls:
        cls = jnp.broadcast_to(self._cls_token.value, (tokens.shape[0], 1, tokens.shape[-1]))
        tokens = jnp.concatenate([cls, tokens], axis=1)
    tokens = self.vit(tokens)
    cls_out = None
    if has_cls:
        # Separate the CLS token from the patch tokens.
        cls_out = tokens[:, 0]
        tokens = tokens[:, 1:]
    if self.should_unpatchify:
        tokens = self.unpatchify(tokens)
    if return_cls_token and has_cls:
        return cls_out
    return tokens

get_input_shape staticmethod

get_input_shape(obs_shape, cfg)

Get the required input shape (NHWC: H, W, C).

Source code in src/ml_networks/jax/vision.py
@staticmethod
def get_input_shape(obs_shape: tuple[int, int, int], cfg: ViTConfig) -> tuple[int, int, int]:
    """Get the required input shape (NHWC: H, W, C)."""
    return (obs_shape[0], obs_shape[1], cfg.init_channel)

get_n_patches

get_n_patches(obs_shape)

Get number of patches for a given shape (NHWC: H, W, C).

Source code in src/ml_networks/jax/vision.py
def get_n_patches(self, obs_shape: tuple[int, int, int]) -> int:
    """Count non-overlapping ``patch_size`` squares covering the (H, W) plane."""
    patches_h = obs_shape[0] // self.patch_size
    patches_w = obs_shape[1] // self.patch_size
    return patches_h * patches_w

get_patch_dim

get_patch_dim(obs_shape)

Get patch dimension for a given shape (NHWC: H, W, C).

Source code in src/ml_networks/jax/vision.py
def get_patch_dim(self, obs_shape: tuple[int, int, int]) -> int:
    """Flattened per-patch feature size: patch_size² × channels (NHWC)."""
    channels = obs_shape[2]
    return self.patch_size * self.patch_size * channels

patchify

patchify(imgs)

Split images into patches.

Parameters:

Name Type Description Default
imgs Array

Input images of shape (N, H, W, C) in NHWC format.

required

Returns:

Type Description
Array

Patchified images of shape (N, L, patch_size**2 * C).

Source code in src/ml_networks/jax/vision.py
def patchify(self, imgs: jax.Array) -> jax.Array:
    """Cut NHWC images into flattened, non-overlapping square patches.

    Parameters
    ----------
    imgs : jax.Array
        Input images of shape (N, H, W, C) in NHWC format.

    Returns
    -------
    jax.Array
        Patchified images of shape (N, L, patch_size**2 * C).
    """
    patch = self.patch_size
    return rearrange(imgs, "n (h p1) (w p2) c -> n (h w) (p1 p2 c)", p1=patch, p2=patch)

unpatchify

unpatchify(x)

Reconstruct images from patches.

Parameters:

Name Type Description Default
x Array

Input of shape (N, L, patch_size**2 * C).

required

Returns:

Type Description
Array

Images of shape (N, H, W, C) in NHWC format.

Source code in src/ml_networks/jax/vision.py
def unpatchify(self, x: jax.Array) -> jax.Array:
    """Reconstruct NHWC images from flattened patches.

    Parameters
    ----------
    x : jax.Array
        Input of shape (N, L, patch_size**2 * C).

    Returns
    -------
    jax.Array
        Images of shape (N, H, W, C) in NHWC format.

    Raises
    ------
    ValueError
        If the patch count L does not match ``obs_shape`` and ``patch_size``.
    """
    p = self.patch_size
    h = self.obs_shape[0] // p
    w = self.obs_shape[1] // p
    # Use a real exception instead of ``assert`` so the validation is not
    # stripped under ``python -O``.
    if h * w != x.shape[1]:
        msg = f"{h * w} != {x.shape[1]}, please check the shape {x.shape} and obs_shape {self.obs_shape}"
        raise ValueError(msg)
    return rearrange(x, "n (h w) (p1 p2 c) -> n (h p1) (w p2) c", h=h, w=w, p1=p, p2=p)

分布 (ml_networks.jax.distributions)

Distribution

Distribution(in_dim, dist, n_groups=1, spherical=False, *, rngs)

Bases: Module

A distribution function.

Parameters:

Name Type Description Default
in_dim int

Input dimension.

required
dist Literal['normal', 'categorical', 'bernoulli']

Distribution type.

required
n_groups int

Number of groups. Default is 1.

1
spherical bool

Whether to project samples to the unit sphere. Default is False.

False
rngs Rngs

Random number generators.

required
Source code in src/ml_networks/jax/distributions.py
def __init__(
    self,
    in_dim: int,
    dist: Literal["normal", "categorical", "bernoulli"],
    n_groups: int = 1,
    spherical: bool = False,
    *,
    rngs: nnx.Rngs,
) -> None:
    """Store distribution settings and, if spherical, build the BSQ codebook.

    Parameters
    ----------
    in_dim : int
        Input dimension.
    dist : Literal["normal", "categorical", "bernoulli"]
        Distribution type.
    n_groups : int
        Number of groups. Default is 1.
    spherical : bool
        Whether to project samples to the unit sphere. Default is False.
    rngs : nnx.Rngs
        Random number generators.
    """
    self.dist_type = dist
    self.spherical = spherical
    # NOTE(review): integer division -- for categorical/bernoulli, in_dim
    # should be divisible by n_groups (jnp.split enforces this later).
    self.n_class = in_dim // n_groups
    self.in_dim = in_dim
    self.n_groups = n_groups
    self.rngs = rngs

    if spherical:
        # Codebook used by bernoulli sampling (bits_to_codes).
        self.codebook = BSQCodebook(self.n_class)

Attributes

codebook instance-attribute

codebook = BSQCodebook(n_class)

dist_type instance-attribute

dist_type = dist

in_dim instance-attribute

in_dim = in_dim

n_class instance-attribute

n_class = in_dim // n_groups

n_groups instance-attribute

n_groups = n_groups

rngs instance-attribute

rngs = rngs

spherical instance-attribute

spherical = spherical

Functions

__call__

__call__(x, deterministic=False, inv_tmp=1.0)

Compute the posterior distribution.

Parameters:

Name Type Description Default
x Array

Input tensor.

required
deterministic bool

Whether to use the deterministic mode. Default is False.

False
inv_tmp float

Inverse temperature. Default is 1.0.

1.0

Returns:

Type Description
StochState

Posterior distribution.

Source code in src/ml_networks/jax/distributions.py
def __call__(
    self,
    x: jax.Array,
    deterministic: bool = False,
    inv_tmp: float = 1.0,
) -> StochState:
    """
    Compute the posterior distribution.

    Parameters
    ----------
    x : jax.Array
        Input tensor.
    deterministic : bool, optional
        Whether to use the deterministic mode. Default is False.
    inv_tmp : float, optional
        Inverse temperature. Default is 1.0.

    Returns
    -------
    StochState
        Posterior distribution.
    """
    # Dispatch on the configured distribution type.
    handlers = {
        "normal": self.normal,
        "categorical": self.categorical,
        "bernoulli": self.bernoulli,
    }
    handler = handlers.get(self.dist_type)
    if handler is None:
        raise NotImplementedError
    return handler(x, deterministic=deterministic, inv_tmp=inv_tmp)

bernoulli

bernoulli(logits, deterministic=False, inv_tmp=1.0)
Source code in src/ml_networks/jax/distributions.py
def bernoulli(self, logits: jax.Array, deterministic: bool = False, inv_tmp: float = 1.0) -> BernoulliStoch:
    """Sample grouped Bernoulli variables with a straight-through gradient.

    Parameters
    ----------
    logits : jax.Array
        Pre-sigmoid logits; the last axis is split into ``n_groups`` groups.
    deterministic : bool, optional
        If True, replace the random sample with a 0.5-threshold one. Default is False.
    inv_tmp : float, optional
        Inverse temperature multiplied onto the logits. Default is 1.0.

    Returns
    -------
    BernoulliStoch
        Stacked logits, probabilities and the flattened sample.
    """
    batch_shape = logits.shape[:-1]
    # Split last axis into n_groups, then stack -> (..., n_groups, n_class)
    chunked_logits = jnp.split(logits, self.n_groups, axis=-1)
    logits_stacked = jnp.stack(chunked_logits, axis=-2)
    logits_stacked = logits_stacked * inv_tmp
    probs = jax.nn.sigmoid(logits_stacked)

    key = self.rngs()
    # Bernoulli sampling with straight-through
    u = jax.random.uniform(key, probs.shape)
    sample = (u < probs).astype(jnp.float32)
    # Straight-through estimator
    sample = sample + probs - jax.lax.stop_gradient(probs)

    if self.spherical:
        # NOTE(review): output range depends on BSQCodebook.bits_to_codes --
        # confirm against its implementation.
        sample = self.codebook.bits_to_codes(sample)

    if deterministic:
        # NOTE(review): thresholds the (possibly code-mapped) sample at 0.5;
        # verify this is the intended interaction when ``spherical`` is True.
        sample = (
            jnp.where(sample > 0.5, jnp.ones_like(sample), jnp.zeros_like(sample))
            + probs
            - jax.lax.stop_gradient(probs)
        )

    return BernoulliStoch(
        logits_stacked,
        probs,
        sample.reshape(*batch_shape, -1),
    )

categorical

categorical(logits, deterministic=False, inv_tmp=1.0)
Source code in src/ml_networks/jax/distributions.py
def categorical(self, logits: jax.Array, deterministic: bool = False, inv_tmp: float = 1.0) -> CategoricalStoch:
    """Sample grouped one-hot categorical variables with a straight-through gradient.

    Parameters
    ----------
    logits : jax.Array
        Pre-softmax logits; the last axis is split into ``n_groups`` groups.
    deterministic : bool, optional
        If True, use the argmax one-hot instead of a random sample. Default is False.
    inv_tmp : float, optional
        Inverse temperature (softmax runs at temperature ``1 / inv_tmp``). Default is 1.0.

    Returns
    -------
    CategoricalStoch
        Stacked logits, probabilities and the flattened one-hot sample.
    """
    batch_shape = logits.shape[:-1]
    # Split last axis into n_groups, then stack -> (..., n_groups, n_class)
    logits_chunks = jnp.split(logits, self.n_groups, axis=-1)
    logits_stacked = jnp.stack(logits_chunks, axis=-2)
    probs = softmax(logits_stacked, axis=-1, temperature=1 / inv_tmp)

    key = self.rngs()
    # Sample using Gumbel-max trick for one-hot with straight-through
    gumbel_noise = -jnp.log(-jnp.log(jax.random.uniform(key, probs.shape, minval=1e-10, maxval=1.0)))
    y = jnp.log(probs + 1e-10) + gumbel_noise
    hard = jax.nn.one_hot(jnp.argmax(y, axis=-1), self.n_class)
    sample = hard + probs - jax.lax.stop_gradient(probs)

    if self.spherical:
        # Map the {0, 1} one-hot sample into {-1, 1}.
        sample = sample * 2 - 1

    if deterministic:
        stoch = self.deterministic_onehot(probs).reshape(*batch_shape, -1)
    else:
        stoch = sample.reshape(*batch_shape, -1)

    return CategoricalStoch(logits_stacked, probs, stoch)

deterministic_onehot

deterministic_onehot(input)

Compute the one-hot vector by argmax with straight-through.

Source code in src/ml_networks/jax/distributions.py
def deterministic_onehot(self, input: jax.Array) -> jax.Array:
    """Argmax one-hot whose gradient passes straight through to ``input``."""
    onehot_hard = jax.nn.one_hot(jnp.argmax(input, axis=-1), self.n_class)
    # Forward value equals the hard one-hot; gradient flows via ``input``.
    return onehot_hard + input - jax.lax.stop_gradient(input)

normal

normal(mu_std, deterministic=False, inv_tmp=1.0)
Source code in src/ml_networks/jax/distributions.py
def normal(self, mu_std: jax.Array, deterministic: bool = False, inv_tmp: float = 1.0) -> NormalStoch:
    """Build a diagonal-Gaussian state from concatenated (mean, raw-std).

    Parameters
    ----------
    mu_std : jax.Array
        Concatenation of mean and pre-softplus std along the last axis;
        the last dimension must equal ``2 * in_dim``.
    deterministic : bool, optional
        If True, the sample is the mean (no noise). Default is False.
    inv_tmp : float, optional
        Unused for the normal distribution; kept for interface parity.

    Returns
    -------
    NormalStoch
        Mean, std and a reparameterized sample.

    Raises
    ------
    ValueError
        If the last dimension of ``mu_std`` is not ``2 * in_dim``.
    """
    # Real exception instead of ``assert`` (survives ``python -O``); the old
    # message also wrongly claimed the two values had to be equal.
    if mu_std.shape[-1] != self.in_dim * 2:
        msg = f"mu_std.shape[-1] {mu_std.shape[-1]} must equal 2 * in_dim ({self.in_dim * 2})."
        raise ValueError(msg)

    mu, std = jnp.split(mu_std, 2, axis=-1)
    std = jax.nn.softplus(std) + 1e-6  # keep std strictly positive

    if deterministic:
        sample = mu
    else:
        key = self.rngs()
        sample = mu + std * jax.random.normal(key, mu.shape)

    return NormalStoch(mu, std, sample)

NormalStoch dataclass

NormalStoch(mean, std, stoch)

Parameters of a normal distribution and its stochastic sample.

Attributes

mean instance-attribute

mean

shape property

shape

std instance-attribute

std

stoch instance-attribute

stoch

Functions

__getitem__

__getitem__(idx)
Source code in src/ml_networks/jax/distributions.py
def __getitem__(self, idx: int | slice | tuple) -> NormalStoch:
    """Index every component array with ``idx``, keeping the dataclass wrapper."""
    return NormalStoch(self.mean[idx], self.std[idx], self.stoch[idx])

__len__

__len__()
Source code in src/ml_networks/jax/distributions.py
def __len__(self) -> int:
    """Length of the leading (batch) axis of the sample."""
    return self.stoch.shape[0]

__post_init__

__post_init__()
Source code in src/ml_networks/jax/distributions.py
def __post_init__(self) -> None:
    """Validate that mean/std shapes match and std is non-negative."""
    if self.mean.shape != self.std.shape:
        msg = f"mean.shape {self.mean.shape} and std.shape {self.std.shape} must be the same."
        raise ValueError(msg)
    if (self.std < 0).any():
        msg = "std must be non-negative."
        raise ValueError(msg)

detach

detach()

Stop gradient equivalent.

Source code in src/ml_networks/jax/distributions.py
def detach(self) -> NormalStoch:
    """Return a copy with gradients blocked on every component array."""
    sg = jax.lax.stop_gradient
    return NormalStoch(sg(self.mean), sg(self.std), sg(self.stoch))

flatten

flatten(start_dim=0, end_dim=-1)

Flatten along specified dimensions.

Source code in src/ml_networks/jax/distributions.py
def flatten(self, start_dim: int = 0, end_dim: int = -1) -> NormalStoch:
    """Collapse dimensions ``start_dim`` through ``end_dim`` (inclusive) into one axis."""
    ndim = self.mean.ndim
    if end_dim < 0:
        end_dim += ndim
    leading = self.mean.shape[:start_dim]
    trailing = self.mean.shape[end_dim + 1 :]
    return self.reshape(*leading, -1, *trailing)

get_distribution

get_distribution(independent=1)
Source code in src/ml_networks/jax/distributions.py
def get_distribution(self, independent: int = 1) -> distrax.Distribution:
    """Wrap as a distrax Normal with the trailing ``independent`` dims as event dims."""
    base = distrax.Normal(self.mean, self.std)
    return distrax.Independent(base, independent)

reshape

reshape(*shape)
Source code in src/ml_networks/jax/distributions.py
def reshape(self, *shape: int) -> NormalStoch:
    """Reshape every component array to ``shape``."""
    mean, std, stoch = (arr.reshape(*shape) for arr in (self.mean, self.std, self.stoch))
    return NormalStoch(mean, std, stoch)

save

save(path)
Source code in src/ml_networks/jax/distributions.py
def save(self, path: str) -> None:
    """Write mean/std/stoch as blosc2 files under ``path`` (created if missing)."""
    os.makedirs(path, exist_ok=True)
    for name, arr in (("mean", self.mean), ("std", self.std), ("stoch", self.stoch)):
        save_blosc2(f"{path}/{name}.blosc2", np.asarray(arr))

CategoricalStoch dataclass

CategoricalStoch(logits, probs, stoch)

Parameters of a categorical distribution and its stochastic sample.

Attributes

logits instance-attribute

logits

probs instance-attribute

probs

shape property

shape

stoch instance-attribute

stoch

Functions

__getitem__

__getitem__(idx)
Source code in src/ml_networks/jax/distributions.py
def __getitem__(self, idx: int | slice | tuple) -> CategoricalStoch:
    """Index every component array with ``idx``, keeping the dataclass wrapper."""
    return CategoricalStoch(self.logits[idx], self.probs[idx], self.stoch[idx])

__len__

__len__()
Source code in src/ml_networks/jax/distributions.py
def __len__(self) -> int:
    """Length of the leading (batch) axis of the sample."""
    return self.stoch.shape[0]

__post_init__

__post_init__()
Source code in src/ml_networks/jax/distributions.py
def __post_init__(self) -> None:
    """Validate that logits and probs have matching shapes."""
    if self.logits.shape != self.probs.shape:
        msg = f"logits.shape {self.logits.shape} and probs.shape {self.probs.shape} must be the same."
        raise ValueError(msg)

detach

detach()

Stop gradient equivalent.

Source code in src/ml_networks/jax/distributions.py
def detach(self) -> CategoricalStoch:
    """Return a copy with gradients blocked on every component array."""
    sg = jax.lax.stop_gradient
    return CategoricalStoch(sg(self.logits), sg(self.probs), sg(self.stoch))

get_distribution

get_distribution(independent=1)
Source code in src/ml_networks/jax/distributions.py
def get_distribution(self, independent: int = 1) -> distrax.Distribution:
    """Wrap as a distrax OneHotCategorical with ``independent`` trailing event dims."""
    base = distrax.OneHotCategorical(probs=self.probs)
    return distrax.Independent(base, independent)

reshape

reshape(*shape)
Source code in src/ml_networks/jax/distributions.py
def reshape(self, *shape: int) -> CategoricalStoch:
    """Reshape every component array to ``shape``."""
    logits, probs, stoch = (arr.reshape(*shape) for arr in (self.logits, self.probs, self.stoch))
    return CategoricalStoch(logits, probs, stoch)

save

save(path)
Source code in src/ml_networks/jax/distributions.py
def save(self, path: str) -> None:
    """Write logits/probs/stoch as blosc2 files under ``path`` (created if missing)."""
    os.makedirs(path, exist_ok=True)
    for name, arr in (("logits", self.logits), ("probs", self.probs), ("stoch", self.stoch)):
        save_blosc2(f"{path}/{name}.blosc2", np.asarray(arr))

squeeze

squeeze(axis)
Source code in src/ml_networks/jax/distributions.py
def squeeze(self, axis: int) -> CategoricalStoch:
    """Drop a size-1 ``axis`` from every component array."""
    logits, probs, stoch = (jnp.squeeze(arr, axis=axis) for arr in (self.logits, self.probs, self.stoch))
    return CategoricalStoch(logits, probs, stoch)

BernoulliStoch dataclass

BernoulliStoch(logits, probs, stoch)

Parameters of a Bernoulli distribution and its stochastic sample.

Attributes

logits instance-attribute

logits

probs instance-attribute

probs

shape property

shape

stoch instance-attribute

stoch

Functions

__getitem__

__getitem__(idx)
Source code in src/ml_networks/jax/distributions.py
def __getitem__(self, idx: int | slice | tuple) -> BernoulliStoch:
    """Index every component array with ``idx``, keeping the dataclass wrapper."""
    return BernoulliStoch(self.logits[idx], self.probs[idx], self.stoch[idx])

__len__

__len__()
Source code in src/ml_networks/jax/distributions.py
def __len__(self) -> int:
    """Length of the leading (batch) axis of the sample."""
    return self.stoch.shape[0]

__post_init__

__post_init__()
Source code in src/ml_networks/jax/distributions.py
def __post_init__(self) -> None:
    """Validate that logits and probs have matching shapes."""
    if self.logits.shape != self.probs.shape:
        msg = f"logits.shape {self.logits.shape} and probs.shape {self.probs.shape} must be the same."
        raise ValueError(msg)

detach

detach()

Stop gradient equivalent.

Source code in src/ml_networks/jax/distributions.py
def detach(self) -> BernoulliStoch:
    """Return a copy with gradients blocked on every component array."""
    sg = jax.lax.stop_gradient
    return BernoulliStoch(sg(self.logits), sg(self.probs), sg(self.stoch))

get_distribution

get_distribution(independent=1)
Source code in src/ml_networks/jax/distributions.py
def get_distribution(self, independent: int = 1) -> distrax.Distribution:
    """Wrap as a distrax Bernoulli with ``independent`` trailing event dims."""
    base = distrax.Bernoulli(probs=self.probs)
    return distrax.Independent(base, independent)

reshape

reshape(*shape)
Source code in src/ml_networks/jax/distributions.py
def reshape(self, *shape: int) -> BernoulliStoch:
    """Reshape every component array to ``shape``."""
    logits, probs, stoch = (arr.reshape(*shape) for arr in (self.logits, self.probs, self.stoch))
    return BernoulliStoch(logits, probs, stoch)

save

save(path)
Source code in src/ml_networks/jax/distributions.py
def save(self, path: str) -> None:
    """Write logits/probs/stoch as blosc2 files under ``path`` (created if missing)."""
    os.makedirs(path, exist_ok=True)
    for name, arr in (("logits", self.logits), ("probs", self.probs), ("stoch", self.stoch)):
        save_blosc2(f"{path}/{name}.blosc2", np.asarray(arr))

squeeze

squeeze(axis)
Source code in src/ml_networks/jax/distributions.py
def squeeze(self, axis: int) -> BernoulliStoch:
    """Drop a size-1 ``axis`` from every component array."""
    logits, probs, stoch = (jnp.squeeze(arr, axis=axis) for arr in (self.logits, self.probs, self.stoch))
    return BernoulliStoch(logits, probs, stoch)

損失関数 (ml_networks.jax.loss)

focal_loss

focal_loss(prediction, target, gamma=2.0, sum_axis=-1)

Focal loss function. Mainly for multi-class classification.

Reference

Focal Loss for Dense Object Detection https://arxiv.org/abs/1708.02002

Parameters:

Name Type Description Default
prediction Array

The predicted tensor. This should be before softmax.

required
target Array

The target tensor (integer class labels).

required
gamma float

The gamma parameter. Default is 2.0.

2.0
sum_axis int

The axis to sum the loss. Default is -1.

-1

Returns:

Type Description
Array

The focal loss.

Source code in src/ml_networks/jax/loss.py
def focal_loss(
    prediction: jax.Array,
    target: jax.Array,
    gamma: float = 2.0,
    sum_axis: int = -1,
) -> jax.Array:
    """
    Focal loss function. Mainly for multi-class classification.

    Reference
    ---------
    Focal Loss for Dense Object Detection
    https://arxiv.org/abs/1708.02002

    Parameters
    ----------
    prediction : jax.Array
        The predicted tensor. This should be before softmax.
    target : jax.Array
        The target tensor (integer class labels).
    gamma : float
        The gamma parameter. A falsy gamma (0) reduces to plain cross entropy.
        Default is 2.0.
    sum_axis : int
        The axis to sum the loss. Default is -1.

    Returns
    -------
    jax.Array
        The focal loss.
    """
    # Rearrange: unsqueeze(1), transpose sum_axis with 1, squeeze(-1),
    # so the class axis ends up at axis 1.
    prediction = jnp.expand_dims(prediction, axis=1)
    prediction = jnp.moveaxis(prediction, sum_axis, 1)
    prediction = jnp.squeeze(prediction, axis=-1)

    # Shared by both branches (previously duplicated).
    n_classes = prediction.shape[1]
    target_one_hot = jax.nn.one_hot(target, n_classes)
    log_prob = jax.nn.log_softmax(prediction, axis=1)

    if gamma:
        # Down-weight well-classified examples by (1 - p)^gamma.
        prob = jnp.exp(log_prob)
        loss = -jnp.sum(((1 - prob) ** gamma) * log_prob * target_one_hot, axis=1)
    else:
        # Plain cross entropy.
        loss = -jnp.sum(log_prob * target_one_hot, axis=1)
    return loss.mean(axis=0).sum()

binary_focal_loss

binary_focal_loss(prediction, target, gamma=2.0, sum_axis=-1)

Binary focal loss function. Mainly for binary classification.

Reference

Focal Loss for Dense Object Detection https://arxiv.org/abs/1708.02002

Parameters:

Name Type Description Default
prediction Array

The predicted tensor. This should be before sigmoid.

required
target Array

The target tensor.

required
gamma float

The gamma parameter. Default is 2.0.

2.0
sum_axis int

The axis to sum the loss. Default is -1.

-1

Returns:

Type Description
Array

The binary focal loss.

Source code in src/ml_networks/jax/loss.py
def binary_focal_loss(
    prediction: jax.Array,
    target: jax.Array,
    gamma: float = 2.0,
    sum_axis: int = -1,
) -> jax.Array:
    """
    Binary focal loss on raw logits (sigmoid is applied internally).

    Reference
    ---------
    Focal Loss for Dense Object Detection
    https://arxiv.org/abs/1708.02002

    Parameters
    ----------
    prediction : jax.Array
        The predicted tensor. This should be before sigmoid.
    target : jax.Array
        The target tensor.
    gamma : float
        The gamma parameter. Default is 2.0.
    sum_axis : int
        The axis to sum the loss. Default is -1.

    Returns
    -------
    jax.Array
        The binary focal loss.
    """
    is_positive = target == 1
    if gamma:
        # Per-class probabilities and their logs, computed in a numerically
        # stable way via log_sigmoid.
        p = jax.nn.sigmoid(prediction)
        log_p = jax.nn.log_sigmoid(prediction)
        log_not_p = jax.nn.log_sigmoid(-prediction)
        # Modulating factor down-weights well-classified examples.
        modulating = jnp.where(is_positive, (1 - p) ** gamma, p**gamma)
        cross_entropy = jnp.where(is_positive, -log_p, -log_not_p)
        loss = modulating * cross_entropy
    else:
        # gamma == 0 reduces to plain binary cross entropy with logits
        # (the max/abs form avoids overflow for large |prediction|).
        loss = (
            jnp.maximum(prediction, 0)
            - prediction * target
            + jnp.log(1 + jnp.exp(-jnp.abs(prediction)))
        )
    return loss.sum(axis=sum_axis)

charbonnier

charbonnier(prediction, target, epsilon=0.001, alpha=1, sum_axes=None)

Charbonnier loss function.

Reference

A General and Adaptive Robust Loss Function http://arxiv.org/abs/1701.03077

Parameters:

Name Type Description Default
prediction Array

The predicted tensor.

required
target Array

The target tensor.

required
epsilon float

A small value to avoid division by zero. Default is 1e-3.

0.001
alpha float

The alpha parameter. Default is 1.

1
sum_axes int | list[int] | tuple[int, ...] | None

The axes to sum the loss. Default is None (sums over [-1, -2, -3]).

None

Returns:

Type Description
Array

The Charbonnier loss.

Source code in src/ml_networks/jax/loss.py
def charbonnier(
    prediction: jax.Array,
    target: jax.Array,
    epsilon: float = 1e-3,
    alpha: float = 1,
    sum_axes: int | list[int] | tuple[int, ...] | None = None,
) -> jax.Array:
    """
    Charbonnier loss function.

    Reference
    ---------
    A General and Adaptive Robust Loss Function
    http://arxiv.org/abs/1701.03077

    Parameters
    ----------
    prediction : jax.Array
        The predicted tensor.
    target : jax.Array
        The target tensor.
    epsilon : float
        A small value to avoid division by zero. Default is 1e-3.
    alpha : float
        The alpha parameter. Default is 1.
    sum_axes : int | list[int] | tuple[int, ...] | None
        The axes to sum the loss. Default is None (sums over [-1, -2, -3]).

    Returns
    -------
    jax.Array
        The Charbonnier loss.
    """
    if sum_axes is None:
        sum_axes = [-1, -2, -3]
    x = prediction - target
    loss = (x**2 + epsilon**2) ** (alpha / 2)
    return jnp.sum(loss, axis=sum_axes)

kl_divergence

kl_divergence(posterior, prior)

KL divergence between two StochState distributions.

Parameters:

Name Type Description Default
posterior StochState

The posterior distribution.

required
prior StochState

The prior distribution.

required

Returns:

Type Description
Array

The KL divergence.

Source code in src/ml_networks/jax/loss.py
def kl_divergence(posterior: StochState, prior: StochState) -> jax.Array:
    """Compute KL(posterior || prior) between two StochState distributions.

    Parameters
    ----------
    posterior : StochState
        The posterior distribution.
    prior : StochState
        The prior distribution.

    Returns
    -------
    jax.Array
        The KL divergence.
    """
    # Materialize both underlying distributions, then delegate to the
    # posterior's own kl_divergence implementation.
    posterior_dist = posterior.get_distribution()
    prior_dist = prior.get_distribution()
    return posterior_dist.kl_divergence(prior_dist)

活性化関数 (ml_networks.jax.activations)

Activation

Activation(activation, **kwargs)

Bases: Module

Generic activation function.

Source code in src/ml_networks/jax/activations.py
def __init__(self, activation: str, **kwargs: Any) -> None:
    """Resolve an activation by name.

    Built-in activations map to plain JAX functions; the remaining ones are
    implemented as nnx submodules and may consume extra keyword arguments.
    """
    # The "dim" kwarg is only meaningful for GLU-style gated activations.
    if "glu" not in activation.lower():
        kwargs.pop("dim", None)

    function_table: dict[str, Any] = {
        "ReLU": jax.nn.relu,
        "GELU": jax.nn.gelu,
        "GeLU": jax.nn.gelu,
        "SiLU": jax.nn.silu,
        "Tanh": jnp.tanh,
        "Sigmoid": jax.nn.sigmoid,
        "ELU": jax.nn.elu,
        "LeakyReLU": jax.nn.leaky_relu,
        "Mish": lambda x: x * jnp.tanh(jax.nn.softplus(x)),
        "Softplus": jax.nn.softplus,
        "Identity": lambda x: x,
    }

    if activation in function_table:
        self._fn = function_table[activation]
        self._module: nnx.Module | None = None
        return

    # Module-backed activations; factories defer construction until the
    # requested name is known to be valid.
    module_table: dict[str, Any] = {
        "TanhExp": lambda: TanhExp(),
        "REReLU": lambda: REReLU(**kwargs),
        "SiGLU": lambda: SiGLU(**kwargs),
        "SwiGLU": lambda: SiGLU(**kwargs),
        "CRReLU": lambda: CRReLU(**kwargs),
        "L2Norm": lambda: L2Norm(),
    }
    if activation not in module_table:
        msg = f"Activation: '{activation}' is not implemented yet."
        raise NotImplementedError(msg)
    self._fn = None
    self._module = module_table[activation]()

Functions

__call__

__call__(x)
Source code in src/ml_networks/jax/activations.py
def __call__(self, x: jax.Array) -> jax.Array:
    """Apply the configured activation to ``x``."""
    # Exactly one of _fn / _module is set by __init__.
    if self._module is None:
        return self._fn(x)
    return self._module(x)

REReLU

REReLU(reparametarize_fn='gelu')

Bases: Module

Reparametarized ReLU activation function. Its backward pass is differentiable.

Parameters:

Name Type Description Default
reparametarize_fn str

Reparametarization function. Default is GELU.

'gelu'
References

https://openreview.net/forum?id=lNCnZwcH5Z

Examples:

>>> rerelu = REReLU()
>>> x = jnp.array([[1.0, -1.0, 0.5]])
>>> output = rerelu(x)
>>> output.shape
(1, 3)
Source code in src/ml_networks/jax/activations.py
def __init__(self, reparametarize_fn: str = "gelu") -> None:
    """Select the smooth surrogate used for the backward pass.

    Raises ValueError for unsupported function names.
    """
    supported: dict[str, Any] = {
        "gelu": jax.nn.gelu,
        "relu": jax.nn.relu,
        "silu": jax.nn.silu,
        "elu": jax.nn.elu,
    }
    # Name lookup is case-insensitive.
    key = reparametarize_fn.lower()
    if key not in supported:
        msg = f"Reparametarization function '{key}' is not supported."
        raise ValueError(msg)
    self.reparametarize_fn = supported[key]

Attributes

reparametarize_fn instance-attribute

reparametarize_fn = reparam_fns[reparametarize_fn]

Functions

__call__

__call__(x)
Source code in src/ml_networks/jax/activations.py
def __call__(self, x: jax.Array) -> jax.Array:
    """Forward value equals ReLU(x); gradients flow through the surrogate."""
    hard = jax.lax.stop_gradient(jax.nn.relu(x))
    soft = self.reparametarize_fn(x)
    # Straight-through construction: soft - stop_grad(soft) is zero in the
    # forward pass but carries the surrogate's gradient in the backward pass.
    return hard + soft - jax.lax.stop_gradient(soft)

SiGLU

SiGLU(dim=-1)

Bases: Module

SiGLU activation function.

This is equivalent to SwiGLU (Swish variant of Gated Linear Unit) activation function.

Parameters:

Name Type Description Default
dim int

Dimension to split the tensor. Default is -1.

-1
References

https://arxiv.org/abs/2102.11972

Examples:

>>> siglu = SiGLU()
>>> x = jnp.ones((1, 4))
>>> output = siglu(x)
>>> output.shape
(1, 2)
>>> siglu = SiGLU(dim=0)
>>> x = jnp.ones((4, 1))
>>> output = siglu(x)
>>> output.shape
(2, 1)
Source code in src/ml_networks/jax/activations.py
def __init__(self, dim: int = -1) -> None:
    # Axis along which the input is split into (value, gate) halves.
    self.dim = dim

Attributes

dim instance-attribute

dim = dim

Functions

__call__

__call__(x)
Source code in src/ml_networks/jax/activations.py
def __call__(self, x: jax.Array) -> jax.Array:
    """Halve ``x`` along ``self.dim`` and gate the first half with SiLU of the second."""
    value, gate = jnp.split(x, 2, axis=self.dim)
    gated = jax.nn.silu(gate)
    return value * gated

CRReLU

CRReLU(lr=0.01)

Bases: Module

Correction Regularized ReLU activation function. This is a variant of the ReLU activation function.

Parameters:

Name Type Description Default
lr float

Learning rate. Default is 0.01.

0.01
References

https://openreview.net/forum?id=7TZYM6Hm9p

Examples:

>>> crrelu = CRReLU()
>>> x = jnp.array([[1.0, -1.0, 0.5]])
>>> output = crrelu(x)
>>> output.shape
(1, 3)
Source code in src/ml_networks/jax/activations.py
def __init__(self, lr: float = 0.01) -> None:
    # Learnable correction coefficient, stored as a trainable nnx parameter.
    self.lr = nnx.Param(jnp.array(lr, dtype=jnp.float32))

Attributes

lr instance-attribute

lr = Param(array(lr, dtype=float32))

Functions

__call__

__call__(x)
Source code in src/ml_networks/jax/activations.py
def __call__(self, x: jax.Array) -> jax.Array:
    """ReLU plus a learnable Gaussian-windowed linear correction term."""
    correction = self.lr.value * x * jnp.exp(-(x**2) / 2)
    return jax.nn.relu(x) + correction

TanhExp

Bases: Module

TanhExp activation function.

Examples:

>>> tanhexp = TanhExp()
>>> x = jnp.array([[1.0, -1.0, 0.5]])
>>> output = tanhexp(x)
>>> output.shape
(1, 3)

Functions

__call__

__call__(x)
Source code in src/ml_networks/jax/activations.py
def __call__(self, x: jax.Array) -> jax.Array:
    # Delegates to the module-level helper; presumably x * tanh(exp(x))
    # per the TanhExp paper — confirm against _tanhexp's definition.
    return _tanhexp(x)

UNet (ml_networks.jax.unet)

ConditionalUnet2d

ConditionalUnet2d(feature_dim, obs_shape, cfg, *, rngs)

Bases: Module

条件付きUNetモデル (NHWC format).

Parameters:

Name Type Description Default
feature_dim int

条件付き特徴量の次元数

required
obs_shape tuple[int, int, int]

観測データの形状 (H, W, C) in NHWC format.

required
cfg UNetConfig

UNetの設定

required
rngs Rngs

Random number generators.

required
Source code in src/ml_networks/jax/unet.py
def __init__(
    self,
    feature_dim: int,
    obs_shape: tuple[int, int, int],
    cfg: UNetConfig,
    *,
    rngs: nnx.Rngs,
) -> None:
    """Build a conditional 2D UNet operating in NHWC layout.

    Parameters
    ----------
    feature_dim : int
        Dimension of the conditioning feature vector.
    obs_shape : tuple[int, int, int]
        Observation shape (H, W, C) in NHWC format.
    cfg : UNetConfig
        UNet configuration.
    rngs : nnx.Rngs
        Random number generators.
    """
    in_channels = obs_shape[2]  # NHWC: C is last
    # Channel widths per resolution level, starting from the input channels.
    all_dims = [in_channels, *list(cfg.channels)]
    start_dim = cfg.channels[0]
    self.obs_shape = obs_shape

    # Consecutive (in, out) channel pairs, one per encoder level.
    in_out = list(pairwise(all_dims))

    # Bottleneck: residual block -> optional attention -> residual block.
    mid_dim = all_dims[-1]
    self.mid_modules = nnx.List([
        ConditionalResidualBlock2d(
            mid_dim,
            mid_dim,
            cond_dim=feature_dim,
            conv_cfg=cfg.conv_cfg,
            cond_predict_scale=cfg.cond_pred_scale,
            rngs=rngs,
        ),
        Attention2d(mid_dim, cfg.nhead, rngs=rngs) if cfg.has_attn and cfg.nhead is not None else Identity(),
        ConditionalResidualBlock2d(
            mid_dim,
            mid_dim,
            cond_dim=feature_dim,
            conv_cfg=cfg.conv_cfg,
            cond_predict_scale=cfg.cond_pred_scale,
            rngs=rngs,
        ),
    ])

    # Encoder: two conditional residual blocks (+ optional attention) per
    # level, downsampling everywhere except the deepest level.
    down_modules = []
    for ind, (dim_in, dim_out) in enumerate(in_out):
        is_last = ind >= (len(in_out) - 1)
        down_modules.append(
            nnx.List([
                ConditionalResidualBlock2d(
                    dim_in,
                    dim_out,
                    cond_dim=feature_dim,
                    conv_cfg=cfg.conv_cfg,
                    cond_predict_scale=cfg.cond_pred_scale,
                    rngs=rngs,
                ),
                Attention2d(dim_out, cfg.nhead, rngs=rngs)
                if cfg.has_attn and cfg.nhead is not None
                else Identity(),
                ConditionalResidualBlock2d(
                    dim_out,
                    dim_out,
                    cond_dim=feature_dim,
                    conv_cfg=cfg.conv_cfg,
                    cond_predict_scale=cfg.cond_pred_scale,
                    rngs=rngs,
                ),
                Downsample2d(dim_out, cfg.use_shuffle, rngs=rngs) if not is_last else Identity(),
            ])
        )
    self.down_modules = nnx.List(down_modules)

    # Decoder: mirrors the encoder; the first block takes dim_out * 2
    # channels because the matching skip connection is concatenated.
    # NOTE(review): this loop runs len(in_out) - 1 times, so is_last is
    # never True and every decoder level upsamples — which matches the
    # len(in_out) - 1 Downsample2d modules built above.
    up_modules = []
    for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
        is_last = ind >= (len(in_out) - 1)
        up_modules.append(
            nnx.List([
                ConditionalResidualBlock2d(
                    dim_out * 2,
                    dim_in,
                    cond_dim=feature_dim,
                    conv_cfg=cfg.conv_cfg,
                    cond_predict_scale=cfg.cond_pred_scale,
                    rngs=rngs,
                ),
                Attention2d(dim_in, cfg.nhead, rngs=rngs) if cfg.has_attn and cfg.nhead is not None else Identity(),
                ConditionalResidualBlock2d(
                    dim_in,
                    dim_in,
                    cond_dim=feature_dim,
                    conv_cfg=cfg.conv_cfg,
                    cond_predict_scale=cfg.cond_pred_scale,
                    rngs=rngs,
                ),
                Upsample2d(dim_in, cfg.use_shuffle, rngs=rngs) if not is_last else Identity(),
            ])
        )
    self.up_modules = nnx.List(up_modules)

    # Output head: conv-norm-activation followed by a 1x1 projection back
    # to the input channel count.
    self.final_conv1 = ConvNormActivation(start_dim, start_dim, cfg.conv_cfg, rngs=rngs)
    self.final_conv2 = nnx.Conv(
        in_features=start_dim,
        out_features=in_channels,
        kernel_size=(1, 1),
        rngs=rngs,
    )

Attributes

down_modules instance-attribute

down_modules = List(down_modules)

final_conv1 instance-attribute

final_conv1 = ConvNormActivation(start_dim, start_dim, conv_cfg, rngs=rngs)

final_conv2 instance-attribute

final_conv2 = Conv(in_features=start_dim, out_features=in_channels, kernel_size=(1, 1), rngs=rngs)

mid_modules instance-attribute

mid_modules = List([ConditionalResidualBlock2d(mid_dim, mid_dim, cond_dim=feature_dim, conv_cfg=conv_cfg, cond_predict_scale=cond_pred_scale, rngs=rngs), Attention2d(mid_dim, nhead, rngs=rngs) if has_attn and nhead is not None else Identity(), ConditionalResidualBlock2d(mid_dim, mid_dim, cond_dim=feature_dim, conv_cfg=conv_cfg, cond_predict_scale=cond_pred_scale, rngs=rngs)])

obs_shape instance-attribute

obs_shape = obs_shape

up_modules instance-attribute

up_modules = List(up_modules)

Functions

__call__

__call__(base, cond)

Forward pass.

Parameters:

Name Type Description Default
base Array

Input tensor of shape (B, H, W, C) in NHWC format.

required
cond Array

Conditional tensor of shape (B, cond_dim).

required

Returns:

Type Description
Array

Output tensor of shape (B, H, W, C).

Source code in src/ml_networks/jax/unet.py
def __call__(self, base: jax.Array, cond: jax.Array) -> jax.Array:
    """Forward pass.

    Parameters
    ----------
    base : jax.Array
        Input tensor of shape (B, H, W, C) in NHWC format.
    cond : jax.Array
        Conditional tensor of shape (B, cond_dim).

    Returns
    -------
    jax.Array
        Output tensor of shape (B, H, W, C).
    """
    batch_shape = base.shape[:-3]
    assert base.shape[-3:] == self.obs_shape, (
        f"Input shape {base.shape[-3:]} does not match expected shape {self.obs_shape}"
    )
    # Flatten any leading batch dims into a single batch axis.
    x = base.reshape(-1, *self.obs_shape)
    global_feature = cond.reshape(-1, cond.shape[-1])

    # Encoder path: stash each level's features for the skip connections.
    skips: list[jax.Array] = []
    for level in self.down_modules:
        res_in, attn, res_out, downsample = level[0], level[1], level[2], level[3]
        x = res_in(x, global_feature)
        x = attn(x)
        x = res_out(x, global_feature)
        skips.append(x)
        x = downsample(x)

    # Bottleneck; Identity placeholders take no conditioning input.
    for block in self.mid_modules:
        x = block(x) if isinstance(block, Identity) else block(x, global_feature)

    # Decoder path: concatenate the matching skip on the channel axis (NHWC).
    for level in self.up_modules:
        res_in, attn, res_out, upsample = level[0], level[1], level[2], level[3]
        x = jnp.concatenate((x, skips.pop()), axis=-1)
        x = res_in(x, global_feature)
        x = attn(x)
        x = res_out(x, global_feature)
        x = upsample(x)

    x = self.final_conv1(x)
    x = self.final_conv2(x)

    # Restore the original leading batch dims.
    return x.reshape(*batch_shape, *self.obs_shape)

ConditionalUnet1d

ConditionalUnet1d(feature_dim, obs_shape, cfg, *, rngs)

Bases: Module

条件付き1D UNetモデル (NLC format).

Parameters:

Name Type Description Default
feature_dim int

条件付き特徴量の次元数

required
obs_shape tuple[int, int]

観測データの形状 (L, C) in NLC format.

required
cfg UNetConfig

UNetの設定

required
rngs Rngs

Random number generators.

required
Source code in src/ml_networks/jax/unet.py
def __init__(
    self,
    feature_dim: int,
    obs_shape: tuple[int, int],
    cfg: UNetConfig,
    *,
    rngs: nnx.Rngs,
) -> None:
    """Build a conditional 1D UNet operating in NLC layout.

    Parameters
    ----------
    feature_dim : int
        Dimension of the conditioning feature vector.
    obs_shape : tuple[int, int]
        Observation shape (L, C) in NLC format.
    cfg : UNetConfig
        UNet configuration.
    rngs : nnx.Rngs
        Random number generators.
    """
    in_channels = obs_shape[1]  # NLC: C is last
    # Channel widths per resolution level, starting from the input channels.
    all_dims = [in_channels, *list(cfg.channels)]
    start_dim = cfg.channels[0]
    self.obs_shape = obs_shape

    # Consecutive (in, out) channel pairs, one per encoder level.
    in_out = list(pairwise(all_dims))

    # Bottleneck: residual block -> optional attention -> residual block.
    # With cfg.use_hypernet, residual blocks are hypernetwork-conditioned
    # variants instead of the standard conditional blocks.
    mid_dim = all_dims[-1]
    self.mid_modules = nnx.List([
        ConditionalResidualBlock1d(
            mid_dim,
            mid_dim,
            cond_dim=feature_dim,
            conv_cfg=cfg.conv_cfg,
            cond_predict_scale=cfg.cond_pred_scale,
            rngs=rngs,
        )
        if not cfg.use_hypernet
        else HyperConditionalResidualBlock1d(
            mid_dim,
            mid_dim,
            cond_dim=feature_dim,
            conv_cfg=cfg.conv_cfg,
            hyper_mlp_cfg=cfg.hyper_mlp_cfg,
            rngs=rngs,
        ),
        Attention1d(mid_dim, cfg.nhead, rngs=rngs) if cfg.has_attn and cfg.nhead is not None else Identity(),
        ConditionalResidualBlock1d(
            mid_dim,
            mid_dim,
            cond_dim=feature_dim,
            conv_cfg=cfg.conv_cfg,
            cond_predict_scale=cfg.cond_pred_scale,
            rngs=rngs,
        )
        if not cfg.use_hypernet
        else HyperConditionalResidualBlock1d(
            mid_dim,
            mid_dim,
            cond_dim=feature_dim,
            conv_cfg=cfg.conv_cfg,
            hyper_mlp_cfg=cfg.hyper_mlp_cfg,
            rngs=rngs,
        ),
    ])

    # Encoder: two residual blocks (+ optional attention) per level,
    # downsampling everywhere except the deepest level.
    down_modules = []
    for ind, (dim_in, dim_out) in enumerate(in_out):
        is_last = ind >= (len(in_out) - 1)
        down_modules.append(
            nnx.List([
                ConditionalResidualBlock1d(
                    dim_in,
                    dim_out,
                    cond_dim=feature_dim,
                    conv_cfg=cfg.conv_cfg,
                    cond_predict_scale=cfg.cond_pred_scale,
                    rngs=rngs,
                )
                if not cfg.use_hypernet
                else HyperConditionalResidualBlock1d(
                    dim_in,
                    dim_out,
                    cond_dim=feature_dim,
                    conv_cfg=cfg.conv_cfg,
                    hyper_mlp_cfg=cfg.hyper_mlp_cfg,
                    rngs=rngs,
                ),
                Attention1d(dim_out, cfg.nhead, rngs=rngs)
                if cfg.has_attn and cfg.nhead is not None
                else Identity(),
                ConditionalResidualBlock1d(
                    dim_out,
                    dim_out,
                    cond_dim=feature_dim,
                    conv_cfg=cfg.conv_cfg,
                    cond_predict_scale=cfg.cond_pred_scale,
                    rngs=rngs,
                )
                if not cfg.use_hypernet
                else HyperConditionalResidualBlock1d(
                    dim_out,
                    dim_out,
                    cond_dim=feature_dim,
                    conv_cfg=cfg.conv_cfg,
                    hyper_mlp_cfg=cfg.hyper_mlp_cfg,
                    rngs=rngs,
                ),
                Downsample1d(dim_out, cfg.use_shuffle, rngs=rngs) if not is_last else Identity(),
            ])
        )
    self.down_modules = nnx.List(down_modules)

    # Decoder: mirrors the encoder; the first block takes dim_out * 2
    # channels because the matching skip connection is concatenated.
    # NOTE(review): this loop runs len(in_out) - 1 times, so is_last is
    # never True and every decoder level upsamples — which matches the
    # len(in_out) - 1 Downsample1d modules built above.
    up_modules = []
    for ind, (dim_in, dim_out) in enumerate(reversed(in_out[1:])):
        is_last = ind >= (len(in_out) - 1)
        up_modules.append(
            nnx.List([
                ConditionalResidualBlock1d(
                    dim_out * 2,
                    dim_in,
                    cond_dim=feature_dim,
                    conv_cfg=cfg.conv_cfg,
                    cond_predict_scale=cfg.cond_pred_scale,
                    rngs=rngs,
                )
                if not cfg.use_hypernet
                else HyperConditionalResidualBlock1d(
                    dim_out * 2,
                    dim_in,
                    cond_dim=feature_dim,
                    conv_cfg=cfg.conv_cfg,
                    hyper_mlp_cfg=cfg.hyper_mlp_cfg,
                    rngs=rngs,
                ),
                Attention1d(dim_in, cfg.nhead, rngs=rngs) if cfg.has_attn and cfg.nhead is not None else Identity(),
                ConditionalResidualBlock1d(
                    dim_in,
                    dim_in,
                    cond_dim=feature_dim,
                    conv_cfg=cfg.conv_cfg,
                    cond_predict_scale=cfg.cond_pred_scale,
                    rngs=rngs,
                )
                if not cfg.use_hypernet
                else HyperConditionalResidualBlock1d(
                    dim_in,
                    dim_in,
                    cond_dim=feature_dim,
                    conv_cfg=cfg.conv_cfg,
                    hyper_mlp_cfg=cfg.hyper_mlp_cfg,
                    rngs=rngs,
                ),
                Upsample1d(dim_in, cfg.use_shuffle, rngs=rngs) if not is_last else Identity(),
            ])
        )
    self.up_modules = nnx.List(up_modules)

    # Output head: conv-norm-activation followed by a kernel-1 projection
    # back to the input channel count.
    self.final_conv1 = ConvNormActivation1d(start_dim, start_dim, cfg.conv_cfg, rngs=rngs)
    self.final_conv2 = nnx.Conv(
        in_features=start_dim,
        out_features=in_channels,
        kernel_size=(1,),
        rngs=rngs,
    )

Attributes

down_modules instance-attribute

down_modules = List(down_modules)

final_conv1 instance-attribute

final_conv1 = ConvNormActivation1d(start_dim, start_dim, conv_cfg, rngs=rngs)

final_conv2 instance-attribute

final_conv2 = Conv(in_features=start_dim, out_features=in_channels, kernel_size=(1,), rngs=rngs)

mid_modules instance-attribute

mid_modules = List([ConditionalResidualBlock1d(mid_dim, mid_dim, cond_dim=feature_dim, conv_cfg=conv_cfg, cond_predict_scale=cond_pred_scale, rngs=rngs) if not use_hypernet else HyperConditionalResidualBlock1d(mid_dim, mid_dim, cond_dim=feature_dim, conv_cfg=conv_cfg, hyper_mlp_cfg=hyper_mlp_cfg, rngs=rngs), Attention1d(mid_dim, nhead, rngs=rngs) if has_attn and nhead is not None else Identity(), ConditionalResidualBlock1d(mid_dim, mid_dim, cond_dim=feature_dim, conv_cfg=conv_cfg, cond_predict_scale=cond_pred_scale, rngs=rngs) if not use_hypernet else HyperConditionalResidualBlock1d(mid_dim, mid_dim, cond_dim=feature_dim, conv_cfg=conv_cfg, hyper_mlp_cfg=hyper_mlp_cfg, rngs=rngs)])

obs_shape instance-attribute

obs_shape = obs_shape

up_modules instance-attribute

up_modules = List(up_modules)

Functions

__call__

__call__(base, cond)

Forward pass.

Parameters:

Name Type Description Default
base Array

Input tensor of shape (B, L, C) in NLC format.

required
cond Array

Conditional tensor of shape (B, cond_dim).

required

Returns:

Type Description
Array

Output tensor of shape (B, L, C).

Source code in src/ml_networks/jax/unet.py
def __call__(self, base: jax.Array, cond: jax.Array) -> jax.Array:
    """Forward pass.

    Parameters
    ----------
    base : jax.Array
        Input tensor of shape (B, L, C) in NLC format.
    cond : jax.Array
        Conditional tensor of shape (B, cond_dim).

    Returns
    -------
    jax.Array
        Output tensor of shape (B, L, C).
    """
    batch_shape = base.shape[:-2]
    assert base.shape[-2:] == self.obs_shape, (
        f"Input shape {base.shape[-2:]} does not match expected shape {self.obs_shape}"
    )
    # Flatten any leading batch dims into a single batch axis.
    x = base.reshape(-1, *self.obs_shape)
    global_feature = cond.reshape(-1, cond.shape[-1])

    # Encoder path: stash each level's features for the skip connections.
    skips: list[jax.Array] = []
    for level in self.down_modules:
        res_in, attn, res_out, downsample = level[0], level[1], level[2], level[3]
        x = res_in(x, global_feature)
        x = attn(x)
        x = res_out(x, global_feature)
        skips.append(x)
        x = downsample(x)

    # Bottleneck; Identity placeholders take no conditioning input.
    for block in self.mid_modules:
        x = block(x) if isinstance(block, Identity) else block(x, global_feature)

    # Decoder path: concatenate the matching skip on the channel axis (NLC).
    for level in self.up_modules:
        res_in, attn, res_out, upsample = level[0], level[1], level[2], level[3]
        x = jnp.concatenate((x, skips.pop()), axis=-1)
        x = res_in(x, global_feature)
        x = attn(x)
        x = res_out(x, global_feature)
        x = upsample(x)

    x = self.final_conv1(x)
    x = self.final_conv2(x)

    # Restore the original leading batch dims.
    return x.reshape(*batch_shape, *self.obs_shape)

ユーティリティ (ml_networks.jax.jax_utils)

get_optimizer

get_optimizer(name, **kwargs)

Get optimizer from optax.

Parameters:

Name Type Description Default
name str

Optimizer name (e.g. "adam", "sgd", "adamw", "lamb", "rmsprop").

required
kwargs dict

Optimizer arguments (e.g. learning_rate=0.01).

{}

Returns:

Type Description
GradientTransformation

Examples:

>>> opt = get_optimizer("adam", learning_rate=0.01)
Source code in src/ml_networks/jax/jax_utils.py
def get_optimizer(
    name: str,
    **kwargs: Any,
) -> optax.GradientTransformation:
    """
    Get optimizer from optax.

    Parameters
    ----------
    name : str
        Optimizer name (e.g. "adam", "sgd", "adamw", "lamb", "rmsprop").
    kwargs : dict
        Optimizer arguments (e.g. learning_rate=0.01).

    Returns
    -------
    optax.GradientTransformation

    Examples
    --------
    >>> opt = get_optimizer("adam", learning_rate=0.01)
    """
    # PyTorch-style optimizer names -> optax factory names.
    torch_to_optax: dict[str, str] = {
        "Adam": "adam",
        "AdamW": "adamw",
        "SGD": "sgd",
        "RMSprop": "rmsprop",
        "Lamb": "lamb",
        "Lars": "lars",
        "Adagrad": "adagrad",
    }
    resolved = torch_to_optax.get(name, name)

    # PyTorch uses "lr"; optax expects "learning_rate".
    call_kwargs = dict(kwargs)
    if "lr" in call_kwargs:
        call_kwargs["learning_rate"] = call_kwargs.pop("lr")

    factory = getattr(optax, resolved, None)
    if factory is None:
        msg = f"Optimizer {name} is not implemented in optax. "
        msg += "Please check the name."
        raise NotImplementedError(msg)
    return factory(**call_kwargs)

jax_fix_seed

jax_fix_seed(seed=42)

乱数を固定する関数.

Parameters:

Name Type Description Default
seed int

Random seed.

42

Returns:

Type Description
Array

JAX PRNG key.

Source code in src/ml_networks/jax/jax_utils.py
def jax_fix_seed(seed: int = 42) -> jax.Array:
    """
    Seed the global RNGs and return a JAX PRNG key.

    Parameters
    ----------
    seed : int
        Random seed.

    Returns
    -------
    jax.Array
        JAX PRNG key derived from ``seed``.
    """
    # Seed NumPy's and the stdlib's global generators for reproducibility.
    np.random.seed(seed)
    random.seed(seed)
    # JAX randomness is functional: it is carried explicitly by this key.
    return jax.random.PRNGKey(seed)

MinMaxNormalize

MinMaxNormalize(min_val, max_val, old_min=0.0, old_max=1.0)

MinMax 正規化変換.

JAX/NumPy版。入力の値域 [old_min, old_max] を [min_val, max_val] に変換する。

Source code in src/ml_networks/jax/jax_utils.py
def __init__(self, min_val: float, max_val: float, old_min: float = 0.0, old_max: float = 1.0) -> None:
    """Precompute the affine map taking [old_min, old_max] onto [min_val, max_val]."""
    self.min_val = min_val
    self.max_val = max_val
    # y = x * scale + shift maps old_min -> min_val and old_max -> max_val.
    span_ratio = (max_val - min_val) / (old_max - old_min)
    self.scale = span_ratio
    self.shift = min_val - old_min * span_ratio

Attributes

max_val instance-attribute

max_val = max_val

min_val instance-attribute

min_val = min_val

scale instance-attribute

scale = (max_val - min_val) / (old_max - old_min)

shift instance-attribute

shift = min_val - old_min * scale

Functions

__call__

__call__(x)
Source code in src/ml_networks/jax/jax_utils.py
def __call__(self, x: jax.Array) -> jax.Array:
    """Apply the precomputed affine rescaling."""
    scaled = x * self.scale
    return scaled + self.shift

SoftmaxTransformation

SoftmaxTransformation(cfg)

Softmax 変換クラス.

Source code in src/ml_networks/jax/jax_utils.py
def __init__(
    self,
    cfg: SoftmaxTransConfig,
) -> None:
    """Cache the config fields and precompute the softmax bin centers."""
    super().__init__()
    self.vector = cfg.vector  # number of bins each transformed dim expands to
    self.sigma = cfg.sigma  # Gaussian bandwidth used by transform()
    self.n_ignore = cfg.n_ignore  # trailing dims passed through untransformed
    self.min = cfg.min
    self.max = cfg.max
    # Evenly spaced bin centers over [min, max].
    self.k = jnp.linspace(cfg.min, cfg.max, cfg.vector)

Attributes

k instance-attribute

k = linspace(min, max, vector)

max instance-attribute

max = max

min instance-attribute

min = min

n_ignore instance-attribute

n_ignore = n_ignore

sigma instance-attribute

sigma = sigma

vector instance-attribute

vector = vector

Functions

__call__

__call__(x)
Source code in src/ml_networks/jax/jax_utils.py
def __call__(self, x: jax.Array) -> jax.Array:
    # Calling the instance is an alias for transform().
    return self.transform(x)

get_transformed_dim

get_transformed_dim(dim)
Source code in src/ml_networks/jax/jax_utils.py
def get_transformed_dim(self, dim: int) -> int:
    """Size of the last axis after transform(): each transformed dim expands
    to ``vector`` bins while the last ``n_ignore`` dims pass through unchanged."""
    passthrough = self.n_ignore
    expanded = (dim - passthrough) * self.vector
    return expanded + passthrough

inverse

inverse(x)

SoftmaxTransformation の逆変換.

Parameters:

Name Type Description Default
x Array

入力テンソル.

required

Returns:

Type Description
Array

出力テンソル.

Source code in src/ml_networks/jax/jax_utils.py
def inverse(self, x: jax.Array) -> jax.Array:
    """
    Invert the softmax expansion: collapse each block of ``vector`` bin
    weights back to a scalar via a weighted sum over the bin centers ``k``.

    Parameters
    ----------
    x : jax.Array
        入力テンソル.

    Returns
    -------
    jax.Array
        出力テンソル.
    """
    *batch, dim = x.shape
    flat = x.reshape(-1, dim)
    # The last n_ignore columns were passed through untransformed.
    if self.n_ignore:
        data, ignored = flat[:, : -self.n_ignore], flat[:, -self.n_ignore :]
    else:
        data = flat

    data = data.reshape(len(data), -1, self.vector)

    # (b, d, v) -> (v, b, d); same as einops rearrange "b d v -> v b d".
    data = jnp.transpose(data, (2, 0, 1))

    # Weighted sum of bin centers, one term per bin.
    data = jnp.stack([data[v] * self.k[v] for v in range(self.vector)]).sum(axis=0)

    data = jnp.concatenate([data, ignored], axis=-1) if self.n_ignore else data
    return data.reshape(*batch, -1)

transform

transform(x)

SoftmaxTransformation の実行.

Parameters:

Name Type Description Default
x Array

入力テンソル.

required

Returns:

Type Description
Array

出力テンソル.

Examples:

>>> trans = SoftmaxTransformation(SoftmaxTransConfig(vector=16, sigma=0.01, n_ignore=1, min=-1.0, max=1.0))
>>> x = jnp.ones((2, 3, 4))
>>> transformed = trans(x)
>>> transformed.shape
(2, 3, 49)
>>> trans = SoftmaxTransformation(SoftmaxTransConfig(vector=11, sigma=0.05, n_ignore=0, min=-1.0, max=1.0))
>>> x = jnp.ones((2, 3, 4))
>>> transformed = trans(x)
>>> transformed.shape
(2, 3, 44)
Source code in src/ml_networks/jax/jax_utils.py
def transform(self, x: jax.Array) -> jax.Array:
    """
    Expand each value of the last axis into a normalized Gaussian
    responsibility over the ``vector`` bin centers ``k``; the last
    ``n_ignore`` dims are passed through unchanged.

    Parameters
    ----------
    x : jax.Array
        入力テンソル.

    Returns
    -------
    jax.Array
        出力テンソル.

    Examples
    --------
    >>> trans = SoftmaxTransformation(SoftmaxTransConfig(vector=16, sigma=0.01, n_ignore=1, min=-1.0, max=1.0))
    >>> x = jnp.ones((2, 3, 4))
    >>> transformed = trans(x)
    >>> transformed.shape
    (2, 3, 49)

    >>> trans = SoftmaxTransformation(SoftmaxTransConfig(vector=11, sigma=0.05, n_ignore=0, min=-1.0, max=1.0))
    >>> x = jnp.ones((2, 3, 4))
    >>> transformed = trans(x)
    >>> transformed.shape
    (2, 3, 44)

    """
    *batch, dim = x.shape
    flat = x.reshape(-1, dim)
    if self.n_ignore:
        data, ignored = flat[:, : -self.n_ignore], flat[:, -self.n_ignore :]
    else:
        data = flat

    # Unnormalized Gaussian responsibility of every bin center, stacked
    # on a leading bin axis: shape (v, b, d).
    negative = jnp.stack([jnp.exp((-((data - self.k[v]) ** 2)) / self.sigma) for v in range(self.vector)])
    # Epsilon guards against all-zero responsibilities far from every bin.
    transformed = negative / (negative.sum(axis=0) + 1e-8)

    # (v, b, d) -> (b, d*v); same as einops rearrange "v b d -> b (d v)".
    per_bin = jnp.transpose(transformed, (1, 2, 0))
    transformed = per_bin.reshape(per_bin.shape[0], -1)

    transformed = jnp.concatenate([transformed, ignored], axis=-1) if self.n_ignore else transformed
    return transformed.reshape(*batch, self.get_transformed_dim(dim))

その他

HyperNet

HyperNet(input_dim, output_shapes, fc_cfg=None, encoding=None, *, rngs)

Bases: Module, HyperNetMixin

A hypernetwork that generates weights for a target network.

Parameters:

Name Type Description Default
input_dim int

Dimension of the input.

required
output_shapes dict[str, Shape]

Shapes of the primary network weights being predicted.

required
fc_cfg MLPConfig | None

Configuration for the MLP backbone.

None
encoding InputMode

The input encoding mode. Default is None.

None
rngs Rngs

Random number generators.

required
Source code in src/ml_networks/jax/hypernetworks.py
def __init__(
    self,
    input_dim: int,
    output_shapes: dict[str, Shape],
    fc_cfg: MLPConfig | None = None,
    encoding: InputMode = None,
    *,
    rngs: nnx.Rngs,
) -> None:
    """Build the hypernetwork.

    Parameters
    ----------
    input_dim : int
        Dimension of the input.
    output_shapes : dict[str, Shape]
        Shapes of the primary-network weights being predicted.
    fc_cfg : MLPConfig | None
        Configuration for the MLP backbone; a single linear layer is
        used when omitted.
    encoding : InputMode
        The input encoding mode. Default is None.
    rngs : nnx.Rngs
        Random number generators.
    """
    self.input_dim = input_dim
    self.output_shapes = output_shapes
    self.encoding = encoding

    # Precomputed offsets for slicing the flat output into named tensors.
    self._output_offsets = self.output_offsets()

    # Backbone maps the (possibly encoded) input to one flat weight vector.
    self.backbone: nnx.Module
    flat_size = self.flat_output_size()
    if fc_cfg is None:
        self.backbone = nnx.Linear(self.input_dim, flat_size, rngs=rngs)
    else:
        self.backbone = MLPLayer(self.input_dim, flat_size, fc_cfg, rngs=rngs)

Attributes

backbone instance-attribute

backbone

encoding instance-attribute

encoding = encoding

input_dim instance-attribute

input_dim = input_dim

output_shapes instance-attribute

output_shapes = output_shapes

Functions

__call__

__call__(inputs)

Forward pass.

Parameters:

Name Type Description Default
inputs Array

Input tensor.

required

Returns:

Type Description
dict[str, Array]

Dictionary of output tensors.

Source code in src/ml_networks/jax/hypernetworks.py
def __call__(self, inputs: jax.Array) -> dict[str, jax.Array]:
    """Forward pass.

    Parameters
    ----------
    inputs : jax.Array
        Input tensor.

    Returns
    -------
    dict[str, jax.Array]
        Dictionary of output tensors, one entry per predicted weight.
    """
    # Optionally encode the raw input before feeding the backbone.
    encoded = inputs if self.encoding is None else encode_input(inputs, self.encoding)
    # One flat prediction, split back into per-weight tensors.
    return self.unflatten_output(self.backbone(encoded))

ContrastiveLearningLoss

ContrastiveLearningLoss(dim_input1, dim_input2, cfg, *, rngs)

Bases: Module

Contrastive learning module.

Parameters:

Name Type Description Default
dim_input1 int

Dimension of first input.

required
dim_input2 int

Dimension of second input.

required
cfg ContrastiveLearningConfig

Configuration for contrastive learning.

required
rngs Rngs

Random number generators.

required
Source code in src/ml_networks/jax/contrastive.py
def __init__(
    self,
    dim_input1: int,
    dim_input2: int,
    cfg: ContrastiveLearningConfig,
    *,
    rngs: nnx.Rngs,
) -> None:
    """Build the contrastive-learning module.

    Parameters
    ----------
    dim_input1 : int
        Dimension of first input.
    dim_input2 : int
        Dimension of second input.
    cfg : ContrastiveLearningConfig
        Configuration for contrastive learning.
    rngs : nnx.Rngs
        Random number generators.
    """
    self.cfg = cfg
    self.dim_feature = cfg.dim_feature
    self.dim_input1 = dim_input1
    self.dim_input2 = dim_input2
    self.is_ce_like = cfg.cross_entropy_like

    # Project both inputs into the shared feature space; when the input
    # dimensions match, one shared encoder serves both sides.
    self.eval_func = MLPLayer(dim_input1, cfg.dim_feature, cfg.eval_func, rngs=rngs)
    if dim_input1 == dim_input2:
        self.eval_func2 = self.eval_func
    else:
        self.eval_func2 = MLPLayer(dim_input2, cfg.dim_feature, cfg.eval_func, rngs=rngs)

Attributes

cfg instance-attribute

cfg = cfg

dim_feature instance-attribute

dim_feature = dim_feature

dim_input1 instance-attribute

dim_input1 = dim_input1

dim_input2 instance-attribute

dim_input2 = dim_input2

eval_func instance-attribute

eval_func = MLPLayer(dim_input1, dim_feature, eval_func, rngs=rngs)

eval_func2 instance-attribute

eval_func2 = MLPLayer(dim_input2, dim_feature, eval_func, rngs=rngs)

is_ce_like instance-attribute

is_ce_like = cross_entropy_like

Functions

calc_nce

calc_nce(feature1, feature2, return_emb=False)

Calculate the Noise Contrastive Estimation (NCE) loss.

Parameters:

Name Type Description Default
feature1 Array

First input tensor of shape (*, dim_input1)

required
feature2 Array

Second input tensor of shape (*, dim_input2)

required
return_emb bool

Whether to return embeddings. Default is False.

False

Returns:

Type Description
dict or tuple

Loss dictionary, optionally with embeddings.

Source code in src/ml_networks/jax/contrastive.py
def calc_nce(
    self,
    feature1: jax.Array,
    feature2: jax.Array,
    return_emb: bool = False,
) -> dict[str, jax.Array] | tuple[dict[str, jax.Array], tuple[jax.Array, jax.Array]]:
    """
    Calculate the Noise Contrastive Estimation (NCE) loss.

    Parameters
    ----------
    feature1 : jax.Array
        First input tensor of shape (*, dim_input1)
    feature2 : jax.Array
        Second input tensor of shape (*, dim_input2)
    return_emb : bool
        Whether to return embeddings. Default is False.

    Returns
    -------
    dict or tuple
        Loss dictionary, optionally with embeddings.
    """
    loss_dict: dict[str, jax.Array] = {}
    batch_shape = feature1.shape[:-1]
    emb_1 = self.eval_func(feature1.reshape(-1, self.dim_input1))
    emb_2 = self.eval_func2(feature2.reshape(-1, self.dim_input2))

    if self.is_ce_like:
        labels = jnp.arange(len(emb_1))
        sim_matrix = emb_1 @ emb_2.T
        # Cross entropy loss
        log_softmax = jax.nn.log_softmax(sim_matrix, axis=-1)
        nce_loss = -log_softmax[jnp.arange(len(labels)), labels] - np.log(len(sim_matrix))
        loss_dict["nce"] = nce_loss
    else:
        positive = jnp.sum(emb_1 * emb_2, axis=-1)
        loss_dict["positive"] = jax.lax.stop_gradient(positive).mean()

        sim_matrix = emb_1 @ emb_2.T
        negative = jax.nn.logsumexp(sim_matrix, axis=-1) - np.log(len(sim_matrix))
        loss_dict["negative"] = jax.lax.stop_gradient(negative).mean()

        nce_loss = -positive + negative
        loss_dict["nce"] = nce_loss.reshape(batch_shape)

    if return_emb:
        return loss_dict, (emb_1, emb_2)
    return loss_dict

calc_sigmoid

calc_sigmoid(feature1, feature2, return_emb=False, temperature=0.1, bias=0.0)

Calculate the Sigmoid loss for contrastive learning.

Parameters:

Name Type Description Default
feature1 Array

First input tensor of shape (*, dim_input1)

required
feature2 Array

Second input tensor of shape (*, dim_input2)

required
return_emb bool

Whether to return embeddings. Default is False.

False
temperature float

Temperature. Default is 0.1.

0.1
bias float

Bias. Default is 0.0.

0.0

Returns:

Type Description
dict or tuple

Loss dictionary, optionally with embeddings.

Source code in src/ml_networks/jax/contrastive.py
def calc_sigmoid(
    self,
    feature1: jax.Array,
    feature2: jax.Array,
    return_emb: bool = False,
    temperature: float | jax.Array = 0.1,
    bias: float | jax.Array = 0.0,
) -> dict[str, jax.Array] | tuple[dict[str, jax.Array], tuple[jax.Array, jax.Array]]:
    """
    Calculate the Sigmoid loss for contrastive learning.

    Parameters
    ----------
    feature1 : jax.Array
        First input tensor of shape (*, dim_input1)
    feature2 : jax.Array
        Second input tensor of shape (*, dim_input2)
    return_emb : bool
        Whether to return embeddings. Default is False.
    temperature : float
        Temperature. Default is 0.1.
    bias : float
        Bias. Default is 0.0.

    Returns
    -------
    dict or tuple
        Loss dictionary, optionally with embeddings.
    """
    loss_dict: dict[str, jax.Array] = {}
    batch_shape = feature1.shape[:-1]
    emb_1 = self.eval_func(feature1.reshape(-1, self.dim_input1))
    emb_2 = self.eval_func2(feature2.reshape(-1, self.dim_input2))

    logits = emb_1 @ emb_2.T * temperature + bias
    labels = jnp.eye(len(logits)) * 2 - 1
    loss = -jax.nn.log_sigmoid(logits * labels).sum(axis=-1)
    loss_dict["sigmoid"] = loss.reshape(batch_shape)
    if return_emb:
        return loss_dict, (emb_1, emb_2)
    return loss_dict

calc_timeseries_nce

calc_timeseries_nce(feature1, feature2, positive_range_self=0, positive_range_tgt=0, return_emb=False)

Calculate the NCE loss for time series data.

Parameters:

Name Type Description Default
feature1 Array

First input tensor of shape (*batch, length, dim_input1)

required
feature2 Array

Second input tensor of shape (*batch, length, dim_input2)

required
positive_range_self int

Range for self-positive samples. Default is 0.

0
positive_range_tgt int

Range for target-positive samples. Default is 0.

0
return_emb bool

Whether to return embeddings. Default is False.

False

Returns:

Type Description
dict or tuple

Loss dictionary, optionally with embeddings.

Source code in src/ml_networks/jax/contrastive.py
def calc_timeseries_nce(
    self,
    feature1: jax.Array,
    feature2: jax.Array,
    positive_range_self: int = 0,
    positive_range_tgt: int = 0,
    return_emb: bool = False,
) -> dict[str, jax.Array] | tuple[dict[str, jax.Array], tuple[jax.Array, jax.Array]]:
    """
    Calculate the NCE loss for time series data.

    Parameters
    ----------
    feature1 : jax.Array
        First input tensor of shape (*batch, length, dim_input1)
    feature2 : jax.Array
        Second input tensor of shape (*batch, length, dim_input2)
    positive_range_self : int
        Range for self-positive samples. Default is 0.
    positive_range_tgt : int
        Range for target-positive samples. Default is 0.
    return_emb : bool
        Whether to return embeddings. Default is False.

    Returns
    -------
    dict or tuple
        Loss dictionary, optionally with embeddings.
    """
    if not positive_range_self and not positive_range_tgt:
        return self.calc_nce(feature1, feature2, return_emb)

    feature1 = feature1.reshape(-1, feature1.shape[-2], self.dim_input1)
    feature2 = feature2.reshape(-1, feature2.shape[-2], self.dim_input2)
    batch, length, _ = feature1.shape

    emb_1 = self.eval_func(feature1)
    emb_2 = self.eval_func2(feature2)

    loss_dict: dict[str, jax.Array] = {}

    positive = jnp.sum(
        emb_1.reshape(-1, emb_1.shape[-1]) * emb_2.reshape(-1, emb_2.shape[-1]),
        axis=-1,
    )
    loss_dict["positive"] = jax.lax.stop_gradient(positive).mean()

    if positive_range_self > 0:
        self_positive_1, self_positive_2 = self._calculate_self_positive_pairs(
            emb_1,
            emb_2,
            batch,
            length,
            positive_range_self,
        )
        positive = positive + self_positive_1.reshape(-1) + self_positive_2.reshape(-1)
        loss_dict["self_positive_1"] = jax.lax.stop_gradient(self_positive_1).mean()
        loss_dict["self_positive_2"] = jax.lax.stop_gradient(self_positive_2).mean()

    if positive_range_tgt > 0:
        tgt_positive = self._calculate_target_positive_pairs(
            emb_1,
            emb_2,
            batch,
            length,
            positive_range_tgt,
        )
        positive = positive + tgt_positive.reshape(-1)
        loss_dict["tgt_positive"] = jax.lax.stop_gradient(tgt_positive).mean()

    flat_emb1 = emb_1.reshape(-1, emb_1.shape[-1])
    flat_emb2 = emb_2.reshape(-1, emb_2.shape[-1])
    sim_matrix = flat_emb1 @ flat_emb2.T
    negative = jax.nn.logsumexp(sim_matrix, axis=-1) - np.log(len(sim_matrix))
    loss_dict["negative"] = jax.lax.stop_gradient(negative).mean()

    nce_loss = -positive + negative
    nce_loss = nce_loss.mean()
    loss_dict["nce"] = nce_loss

    if return_emb:
        return loss_dict, (emb_1, emb_2)
    return loss_dict

BaseModule

Bases: Module

Base module for JAX/Flax NNX.

Functions

freeze_biases

freeze_biases()

Freeze all bias parameters.

Source code in src/ml_networks/jax/base.py
def freeze_biases(self) -> None:
    """Freeze all bias parameters."""
    _graph_def, current = nnx.split(self)
    flat = current.flat_state()
    for path, var_state in flat.items():
        if isinstance(var_state, nnx.VariableState) and "bias" in path:
            # Retype so optimizers filtering on nnx.Param skip this variable.
            # NOTE(review): the frozen type is nnx.VariableState (a state
            # container class) while unfreeze_biases restores nnx.Param; a
            # Variable subclass (e.g. nnx.Variable) would be the conventional
            # frozen type — confirm this is intentional.
            var_state.type = nnx.VariableState  # type: ignore[assignment]
    nnx.update(self, nnx.State.from_flat_path(flat))

freeze_weights

freeze_weights()

Freeze all weight parameters (kernel in Flax).

Source code in src/ml_networks/jax/base.py
def freeze_weights(self) -> None:
    """Freeze all weight parameters (kernel in Flax)."""
    _graph_def, current = nnx.split(self)
    flat = current.flat_state()
    for path, var_state in flat.items():
        if isinstance(var_state, nnx.VariableState) and "kernel" in path:
            # Retype so optimizers filtering on nnx.Param skip this variable.
            # NOTE(review): the frozen type is nnx.VariableState (a state
            # container class) while unfreeze_weights restores nnx.Param; a
            # Variable subclass (e.g. nnx.Variable) would be the conventional
            # frozen type — confirm this is intentional.
            var_state.type = nnx.VariableState  # type: ignore[assignment]
    nnx.update(self, nnx.State.from_flat_path(flat))

unfreeze_biases

unfreeze_biases()

Unfreeze all bias parameters.

Source code in src/ml_networks/jax/base.py
def unfreeze_biases(self) -> None:
    """Unfreeze all bias parameters."""
    _graph_def, current = nnx.split(self)
    flat = current.flat_state()
    for path, var_state in flat.items():
        if isinstance(var_state, nnx.VariableState) and "bias" in path:
            # Restore nnx.Param so optimizers treat the bias as trainable again.
            var_state.type = nnx.Param  # type: ignore[assignment]
    nnx.update(self, nnx.State.from_flat_path(flat))

unfreeze_weights

unfreeze_weights()

Unfreeze all weight parameters (kernel in Flax).

Source code in src/ml_networks/jax/base.py
def unfreeze_weights(self) -> None:
    """Unfreeze all weight parameters (kernel in Flax)."""
    _graph_def, current = nnx.split(self)
    flat = current.flat_state()
    for path, var_state in flat.items():
        if isinstance(var_state, nnx.VariableState) and "kernel" in path:
            # Restore nnx.Param so optimizers treat the kernel as trainable again.
            var_state.type = nnx.Param  # type: ignore[assignment]
    nnx.update(self, nnx.State.from_flat_path(flat))