JAX API Reference
API reference for the JAX (Flax NNX) backend.
It provides the same interface as the PyTorch backend; see the corresponding PyTorch API reference pages for details.
Layers (ml_networks.jax.layers)
MLPLayer
Bases: Module
Multi-layer perceptron layer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input_dim | int | Input dimension. | required |
| output_dim | int | Output dimension. | required |
| cfg | MLPConfig | MLP configuration. | required |
| rngs | Rngs | Random number generators. | required |
Examples:
>>> rngs = nnx.Rngs(0)
>>> cfg = MLPConfig(
... hidden_dim=16,
... n_layers=3,
... output_activation="ReLU",
... linear_cfg=LinearConfig(
... activation="ReLU",
... norm="layer",
... norm_cfg={"eps": 1e-05, "elementwise_affine": True, "bias": True},
... dropout=0.1,
... norm_first=False,
... bias=True
... )
... )
>>> mlp = MLPLayer(32, 16, cfg, rngs=rngs)
>>> x = jnp.ones((1, 32))
>>> output = mlp(x)
>>> output.shape
(1, 16)
Source code in src/ml_networks/jax/layers.py
LinearNormActivation
Bases: Module
Linear layer with normalization, activation, and dropout.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input_dim | int | Input dimension. | required |
| output_dim | int | Output dimension. | required |
| cfg | LinearConfig | Linear layer configuration. | required |
| rngs | Rngs | Random number generators. | required |
Examples:
>>> rngs = nnx.Rngs(0)
>>> cfg = LinearConfig(
... activation="ReLU",
... norm="layer",
... norm_cfg={"eps": 1e-05, "elementwise_affine": True, "bias": True},
... dropout=0.1,
... norm_first=False,
... bias=True
... )
>>> linear = LinearNormActivation(32, 16, cfg, rngs=rngs)
>>> x = jnp.ones((1, 32))
>>> output = linear(x)
>>> output.shape
(1, 16)
Source code in src/ml_networks/jax/layers.py
Functions
__call__
Forward pass.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input tensor of shape (*, input_dim) | required |
Returns:
| Type | Description |
|---|---|
| Array | Output tensor of shape (*, output_dim) |
Source code in src/ml_networks/jax/layers.py
ConvNormActivation
Bases: Module
Convolutional layer with normalization, activation, and dropout.
Uses NHWC (channels-last) format.
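PyTorch convolutions expect channels-first (NCHW) tensors. Image batches laid out for the PyTorch backend therefore need their channel axis moved to the end before they go into these layers. Below is a minimal sketch using plain `jnp.transpose`. The helper names are illustrative and are not part of ml_networks.

```python
import jax.numpy as jnp


def nchw_to_nhwc(x):
    """(B, C, H, W) -> (B, H, W, C): move the channel axis to the end."""
    return jnp.transpose(x, (0, 2, 3, 1))


def nhwc_to_nchw(x):
    """(B, H, W, C) -> (B, C, H, W): inverse permutation of nchw_to_nhwc."""
    return jnp.transpose(x, (0, 3, 1, 2))


x_nchw = jnp.zeros((1, 3, 32, 32))  # one 3-channel 32x32 image, channels-first
x_nhwc = nchw_to_nhwc(x_nchw)
print(x_nhwc.shape)  # (1, 32, 32, 3), the layout used in the example below
```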
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_channels | int | Input channels. | required |
| out_channels | int | Output channels. | required |
| cfg | ConvConfig | Convolutional layer configuration. | required |
| rngs | Rngs | Random number generators. | required |
Examples:
>>> rngs = nnx.Rngs(0)
>>> cfg = ConvConfig(
... activation="ReLU",
... kernel_size=3,
... stride=1,
... padding=1,
... dilation=1,
... groups=1,
... bias=True,
... dropout=0.1,
... norm="batch",
... norm_cfg={"affine": True, "track_running_stats": True},
... scale_factor=0
... )
>>> conv = ConvNormActivation(3, 16, cfg, rngs=rngs)
>>> x = jnp.ones((1, 32, 32, 3))
>>> output = conv(x)
>>> output.shape
(1, 32, 32, 16)
Source code in src/ml_networks/jax/layers.py
Attributes
conv (instance-attribute)
conv = Conv(in_features=in_channels, out_features=out_channels_, kernel_size=(kernel_size, kernel_size), strides=(stride, stride), padding=conv_padding, kernel_dilation=(dilation, dilation), feature_group_count=groups, use_bias=bias, rngs=rngs)
dropout (instance-attribute)
Functions
__call__
Forward pass.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input tensor of shape (B, H, W, in_channels) or (H, W, in_channels) | required |
required |
Returns:
| Type | Description |
|---|---|
| Array | Output tensor of shape (B, H', W', out_channels) or (H', W', out_channels) |
Source code in src/ml_networks/jax/layers.py
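For symmetric zero padding, the output height H' (and likewise W') follows the usual convolution formula. The sketch assumes that the `padding`, `stride`, and `dilation` values from `ConvConfig` are applied symmetrically. It does not model any resizing controlled by `scale_factor`. `conv_output_size` is a hypothetical helper.

```python
def conv_output_size(size: int, kernel_size: int, stride: int = 1,
                     padding: int = 0, dilation: int = 1) -> int:
    """Spatial output size of a convolution with symmetric zero padding."""
    effective_kernel = dilation * (kernel_size - 1) + 1  # receptive field of a dilated kernel
    return (size + 2 * padding - effective_kernel) // stride + 1


print(conv_output_size(32, kernel_size=3, stride=1, padding=1))  # 32, as in the ConvNormActivation example
print(conv_output_size(32, kernel_size=4, stride=2, padding=1))  # 16
```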
ConvTransposeNormActivation
Bases: Module
Transposed convolutional layer with normalization, activation, and dropout.
Uses NHWC (channels-last) format.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_channels | int | Input channels. | required |
| out_channels | int | Output channels. | required |
| cfg | ConvConfig | Convolutional layer configuration. | required |
| rngs | Rngs | Random number generators. | required |
Examples:
>>> rngs = nnx.Rngs(0)
>>> cfg = ConvConfig(
... activation="ReLU",
... kernel_size=3,
... stride=1,
... padding=1,
... output_padding=0,
... dilation=1,
... groups=1,
... bias=True,
... dropout=0.1,
... norm="batch",
... norm_cfg={"affine": True, "track_running_stats": True}
... )
>>> conv = ConvTransposeNormActivation(3, 16, cfg, rngs=rngs)
>>> x = jnp.ones((1, 32, 32, 3))
>>> output = conv(x)
>>> output.shape
(1, 32, 32, 16)
Source code in src/ml_networks/jax/layers.py
Attributes
conv (instance-attribute)
conv = ConvTranspose(in_features=in_channels, out_features=out_features, kernel_size=(kernel_size, kernel_size), strides=(stride, stride), padding=((padding, padding + output_padding), (padding, padding + output_padding)), kernel_dilation=(dilation, dilation), use_bias=bias, rngs=rngs)
dropout (instance-attribute)
Functions
__call__
Forward pass.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input tensor of shape (B, H, W, in_channels) or (H, W, in_channels) | required |
Returns:
| Type | Description |
|---|---|
| Array | Output tensor of shape (B, H', W', out_channels) or (H', W', out_channels) |
Source code in src/ml_networks/jax/layers.py
Vision (ml_networks.jax.vision)
Encoder
Bases: Module
Image encoder module (NHWC format).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| feature_dim | int \| tuple[int, int, int] | Output feature dimension. If int, a fully-connected layer flattens and projects the backbone output. If tuple, the backbone output is returned directly (fc is identity). | required |
| obs_shape | tuple[int, int, int] | Observation shape in (H, W, C) format. | required |
| backbone_cfg | ViTConfig \| ConvNetConfig \| ResNetConfig | Backbone configuration. | required |
| fc_cfg | MLPConfig \| LinearConfig \| SpatialSoftmaxConfig \| AdaptiveAveragePoolingConfig \| None | Fully-connected layer configuration. Required when | None |
| rngs | Rngs | Random number generators. | required |
Source code in src/ml_networks/jax/vision.py
Functions
__call__
Forward pass.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input tensor of shape (*, H, W, C) in NHWC format. | required |
Returns:
| Type | Description |
|---|---|
| Array | Encoded tensor of shape (*, feature_dim). |
Source code in src/ml_networks/jax/vision.py
Decoder
Bases: Module
Image decoder module (NHWC format).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| feature_dim | int \| tuple[int, int, int] | Input feature dimension. If int, a fully-connected layer projects and reshapes input before the backbone. If tuple, input is passed directly to the backbone. | required |
| obs_shape | tuple[int, int, int] | Output observation shape in (H, W, C) format. | required |
| backbone_cfg | ConvNetConfig \| ViTConfig \| ResNetConfig | Backbone configuration. | required |
| fc_cfg | MLPConfig \| LinearConfig \| None | Fully-connected layer configuration. Required when | None |
| rngs | Rngs | Random number generators. | required |
Source code in src/ml_networks/jax/vision.py
Functions
__call__
Forward pass.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input tensor of shape (*, feature_dim). | required |
Returns:
| Type | Description |
|---|---|
| Array | Decoded tensor of shape (*, H, W, C) in NHWC format. |
Source code in src/ml_networks/jax/vision.py
ConvNet
Bases: Module
Convolutional network (NHWC format).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| obs_shape | tuple[int, int, int] | Observation shape in (H, W, C) format. | required |
| cfg | ConvNetConfig | Configuration. | required |
| rngs | Rngs | Random number generators. | required |
Source code in src/ml_networks/jax/vision.py
Attributes
conved_shape (property)
Get the spatial shape of the output after the convolutional layers.
Functions
__call__
Forward pass.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input tensor of shape (B, H, W, C) in NHWC format. | required |
Returns:
| Type | Description |
|---|---|
| Array | Flattened tensor of shape (B, output_dim). |
Source code in src/ml_networks/jax/vision.py
ConvTranspose
Bases: Module
Transposed convolutional network (NHWC format).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_shape | tuple[int, int, int] | Input shape in (H, W, C) format. | required |
| obs_shape | tuple[int, int, int] | Output observation shape in (H, W, C) format. | required |
| cfg | ConvNetConfig | Configuration (channels are in decode order). | required |
| rngs | Rngs | Random number generators. | required |
Source code in src/ml_networks/jax/vision.py
Attributes
first_conv (instance-attribute)
Functions
__call__
Forward pass.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input tensor of shape (B, H, W, C) in NHWC format. | required |
Returns:
| Type | Description |
|---|---|
| Array | Output tensor of shape (B, H', W', out_C) in NHWC format. |
Source code in src/ml_networks/jax/vision.py
get_input_shape (staticmethod)
Get the required input shape for a given output shape and config.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| obs_shape | tuple[int, int, int] | Output shape in (H, W, C) format. | required |
| cfg | ConvNetConfig | Configuration. | required |
Returns:
| Type | Description |
|---|---|
| tuple[int, ...] | Required input shape in (H, W, C) format. |
Source code in src/ml_networks/jax/vision.py
ResNetPixUnshuffle
Bases: Module
ResNet with PixelUnshuffle downsampling (NHWC format).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| obs_shape | tuple[int, int, int] | Input observation shape in (H, W, C) format. | required |
| cfg | ResNetConfig | Configuration. | required |
| rngs | Rngs | Random number generators. | required |
Source code in src/ml_networks/jax/vision.py
Attributes
conv1 (instance-attribute)
conv2 (instance-attribute)
conv3 (instance-attribute)
Functions
__call__
Forward pass.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input tensor of shape (B, H, W, C) in NHWC format. | required |
Returns:
| Type | Description |
|---|---|
| Array | Downsampled output. |
Source code in src/ml_networks/jax/vision.py
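PixelUnshuffle (space-to-depth) trades spatial resolution for channels: each r×r block of pixels becomes r·r extra channels. Below is a minimal NHWC sketch built from reshape and transpose. It assumes PyTorch's channel ordering `c * r * r + i * r + j`; the exact ordering inside ResNetPixUnshuffle is not documented here, and `pixel_unshuffle_nhwc` is a hypothetical helper.

```python
import jax.numpy as jnp


def pixel_unshuffle_nhwc(x, r):
    """Space-to-depth on NHWC input: (B, H, W, C) -> (B, H // r, W // r, C * r * r)."""
    b, h, w, c = x.shape
    x = x.reshape(b, h // r, r, w // r, r, c)  # split H and W into (block, offset)
    x = x.transpose(0, 1, 3, 5, 2, 4)          # (B, H/r, W/r, C, r_h, r_w)
    return x.reshape(b, h // r, w // r, c * r * r)


x = jnp.arange(4.0).reshape(1, 2, 2, 1)  # a single 2x2 one-channel image
y = pixel_unshuffle_nhwc(x, 2)
print(y.shape)         # (1, 1, 1, 4)
print(y.reshape(-1))   # [0. 1. 2. 3.]
```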
ResNetPixShuffle
Bases: Module
ResNet with PixelShuffle upsampling (NHWC format).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_shape | tuple[int, int, int] | Input shape in (H, W, C) format. | required |
| obs_shape | tuple[int, int, int] | Output observation shape in (H, W, C) format. | required |
| cfg | ResNetConfig | Configuration. | required |
| rngs | Rngs | Random number generators. | required |
Source code in src/ml_networks/jax/vision.py
Attributes
conv1 (instance-attribute)
conv2 (instance-attribute)
conv3 (instance-attribute)
Functions
__call__
Forward pass.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input tensor of shape (B, H, W, C) in NHWC format. | required |
Returns:
| Type | Description |
|---|---|
| Array | Upsampled output of shape (B, H', W', C'). |
Source code in src/ml_networks/jax/vision.py
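PixelShuffle (depth-to-space) is the inverse operation: groups of r·r channels are spread back out into r×r spatial blocks. Below is a minimal NHWC sketch. It assumes PyTorch's channel ordering (`c * r * r + i * r + j`), which may differ from what ResNetPixShuffle uses internally. `pixel_shuffle_nhwc` is a hypothetical helper.

```python
import jax.numpy as jnp


def pixel_shuffle_nhwc(x, r):
    """Depth-to-space on NHWC input: (B, H, W, C * r * r) -> (B, H * r, W * r, C)."""
    b, h, w, c = x.shape
    c_out = c // (r * r)
    x = x.reshape(b, h, w, c_out, r, r)  # channel index = c_out_idx * r * r + i * r + j
    x = x.transpose(0, 1, 4, 2, 5, 3)    # (B, H, r_h, W, r_w, C_out)
    return x.reshape(b, h * r, w * r, c_out)


x = jnp.arange(4.0).reshape(1, 1, 1, 4)  # one pixel holding four channels
y = pixel_shuffle_nhwc(x, 2)
print(y.shape)         # (1, 2, 2, 1)
print(y[0, :, :, 0])   # [[0. 1.]
                       #  [2. 3.]]
```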
get_input_shape (staticmethod)
Get the required input shape for a given output shape and config.
Source code in src/ml_networks/jax/vision.py
ViT
Bases: Module
Vision Transformer for Encoder and Decoder (NHWC format).
The encoder mode (obs_shape is None) follows the DETR convention: a learnable per-patch
positional embedding is added to the query and key tensors of every self-attention layer
(rather than being added once at the input). When cfg.cls_token is True, a CLS token
with its own learnable positional embedding is prepended and the forward pass returns the
CLS token of shape (B, d_model).
The decoder mode (obs_shape is not None) takes a CLS token of shape (B, d_model) or
(B, 1, d_model) and reconstructs an image. The CLS token is projected to a hidden
dimension and used as the key/value of cross-attention. A fixed set of
P = (H // p) * (W // p) learnable query tokens interacts with this representation
through several cross-attention layers with residual MLP blocks. Each query is then
linearly projected to p * p * C pixels and rearranged into a (B, H, W, C) image.
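The DETR convention described above differs from standard ViT in two ways. The positional embedding is added to the queries and keys only (not to the values), and this happens inside every attention layer rather than once at the input. Below is a single-head, single-layer sketch of that idea. The projection matrices are hypothetical stand-ins for learned parameters; ml_networks' actual layers (multi-head attention, normalization, MLP blocks) are not reproduced.

```python
import jax
import jax.numpy as jnp


def self_attention_with_pos(x, pos, w_q, w_k, w_v):
    """Single-head self-attention with DETR-style positional embedding.

    x:   (B, N, D) token features
    pos: (N, D) per-token positional embedding, broadcast over the batch
    w_q, w_k, w_v: (D, D) hypothetical projection matrices
    """
    q = (x + pos) @ w_q  # positions enter the queries...
    k = (x + pos) @ w_k  # ...and the keys
    v = x @ w_v          # but not the values
    scores = q @ jnp.swapaxes(k, -1, -2) / jnp.sqrt(x.shape[-1])
    weights = jax.nn.softmax(scores, axis=-1)  # (B, N, N)
    return weights @ v


k1, k2, k3, k4, k5 = jax.random.split(jax.random.PRNGKey(0), 5)
B, N, D = 2, 5, 8
x = jax.random.normal(k1, (B, N, D))
pos = jax.random.normal(k2, (N, D))
w_q, w_k, w_v = (jax.random.normal(k, (D, D)) for k in (k3, k4, k5))
out = self_attention_with_pos(x, pos, w_q, w_k, w_v)
print(out.shape)  # (2, 5, 8)
```

In a stacked encoder, the same `pos` would be passed to every layer, which is what "added to the query and key tensors of every self-attention layer" amounts to.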
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_shape | tuple[int, ...] | Encoder mode: image shape | required |
| cfg | ViTConfig | ViT configuration. | required |
| obs_shape | tuple[int, int, int] \| None | Output shape in (H, W, C) format. If | None |
| rngs | Rngs | Random number generators. | required |
Source code in src/ml_networks/jax/vision.py
Attributes
obs_shape (instance-attribute)
Functions
__call__
Forward pass.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Encoder mode: image tensor of shape | required |
| return_cls_token | bool | Retained for backward compatibility; the encoder always returns the CLS token. | False |
Returns:
| Type | Description |
|---|---|
| Array | Encoder mode: CLS token of shape |
Source code in src/ml_networks/jax/vision.py
get_input_shape (staticmethod)
Input shape consumed by the ViT decoder: the CLS token has dimension d_model.
Source code in src/ml_networks/jax/vision.py
get_n_patches
Get the number of patches for a given shape (NHWC: H, W, C).
get_patch_dim
patchify
unpatchify
Reconstruct images from patches (NHWC).
Source code in src/ml_networks/jax/vision.py
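The patch arithmetic used above works out as follows. An (H, W, C) image with patch size p yields P = (H // p) * (W // p) patches, each flattened to p * p * C values. Below is a minimal NHWC sketch of patchify/unpatchify. The (p, p, C) ordering within each patch is an assumption, and the function names are hypothetical rather than the ViT methods' actual signatures.

```python
import jax.numpy as jnp


def patchify_nhwc(x, p):
    """(B, H, W, C) -> (B, (H // p) * (W // p), p * p * C)."""
    b, h, w, c = x.shape
    x = x.reshape(b, h // p, p, w // p, p, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (B, H/p, W/p, p, p, C)
    return x.reshape(b, (h // p) * (w // p), p * p * c)


def unpatchify_nhwc(patches, p, h, w, c):
    """Inverse of patchify_nhwc: (B, P, p * p * C) -> (B, H, W, C)."""
    b = patches.shape[0]
    x = patches.reshape(b, h // p, w // p, p, p, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)  # (B, H/p, p, W/p, p, C)
    return x.reshape(b, h, w, c)


x = jnp.arange(2 * 8 * 8 * 3, dtype=jnp.float32).reshape(2, 8, 8, 3)
patches = patchify_nhwc(x, 4)
print(patches.shape)  # (2, 4, 48): P = (8 // 4) * (8 // 4), patch dim = 4 * 4 * 3
```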
Distributions (ml_networks.jax.distributions)
Distribution
Bases: Module
A distribution function.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_dim | int | Input dimension. | required |
| dist | Literal['normal', 'categorical', 'bernoulli'] | Distribution type. | required |
| n_groups | int | Number of groups. Default is 1. | 1 |
| spherical | bool | Whether to project samples to the unit sphere. Default is False. | False |
| rngs | Rngs | Random number generators. | required |
Source code in src/ml_networks/jax/distributions.py
Functions
__call__
Compute the posterior distribution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input tensor. | required |
| deterministic | bool | Whether to use the deterministic mode. Default is False. | False |
| inv_tmp | float | Inverse temperature. Default is 1.0. | 1.0 |
| key | Array \| None | PRNG key. If provided, | None |
Returns:
| Type | Description |
|---|---|
| StochState | Posterior distribution. |
Source code in src/ml_networks/jax/distributions.py
bernoulli
Source code in src/ml_networks/jax/distributions.py
categorical
Source code in src/ml_networks/jax/distributions.py
deterministic_onehot
Compute a one-hot vector via argmax, using a straight-through gradient estimator.
Source code in src/ml_networks/jax/distributions.py
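A common way to write the straight-through estimator mentioned above: the forward value is the hard argmax one-hot vector, while gradients flow through the softmax probabilities. This is a sketch of the general technique, not the ml_networks source. How `inv_tmp` or `n_groups` enter the computation is not modeled, and `straight_through_onehot` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def straight_through_onehot(logits):
    """Hard one-hot in the forward pass, softmax gradient in the backward pass."""
    probs = jax.nn.softmax(logits, axis=-1)
    onehot = jax.nn.one_hot(jnp.argmax(logits, axis=-1), logits.shape[-1], dtype=probs.dtype)
    # Value: onehot + probs - probs = onehot; gradient: d(probs), because the
    # subtracted copy is wrapped in stop_gradient.
    return onehot + probs - jax.lax.stop_gradient(probs)


logits = jnp.array([0.1, 0.5, 0.2])
print(straight_through_onehot(logits))  # approximately [0. 1. 0.]
```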
normal
Source code in src/ml_networks/jax/distributions.py
NormalStoch (dataclass)
Parameters of a normal distribution and its stochastic sample.
Functions
__getitem__
__len__
__post_init__
Source code in src/ml_networks/jax/distributions.py
broadcast
copy
device_put
Place all members on the specified device.
Source code in src/ml_networks/jax/distributions.py
expand_dims
flatten
Flatten along the specified axes.
Source code in src/ml_networks/jax/distributions.py
get_distribution
reshape
save
Source code in src/ml_networks/jax/distributions.py
squeeze
stop_gradient
Apply stop_gradient to all members.
swapaxes
to_numpy
CategoricalStoch (dataclass)
Parameters of a categorical distribution and its stochastic sample.
Functions
__getitem__
__len__
__post_init__
broadcast
copy
device_put
Place all members on the specified device.
Source code in src/ml_networks/jax/distributions.py
expand_dims
flatten
Flatten along the specified axes.
Source code in src/ml_networks/jax/distributions.py
get_distribution
reshape
save
Source code in src/ml_networks/jax/distributions.py
squeeze
stop_gradient
Apply stop_gradient to all members.
swapaxes
to_numpy
Convert all members to NumPy arrays.
BernoulliStoch (dataclass)
Parameters of a Bernoulli distribution and its stochastic sample.
Functions
__getitem__
__len__
__post_init__
broadcast
copy
device_put
Place all members on the specified device.
Source code in src/ml_networks/jax/distributions.py
expand_dims
flatten
Flatten along the specified axes.
Source code in src/ml_networks/jax/distributions.py
get_distribution
reshape
save
Source code in src/ml_networks/jax/distributions.py
squeeze
stop_gradient
Apply stop_gradient to all members.
swapaxes
to_numpy
Convert all members to NumPy arrays.
Loss functions (ml_networks.jax.loss)
focal_loss
Focal loss, mainly for multi-class classification.
Reference
Focal Loss for Dense Object Detection https://arxiv.org/abs/1708.02002
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| prediction | Array | The predicted logits (before softmax). | required |
| target | Array | The target tensor (integer class labels). | required |
| gamma | float | The gamma (focusing) parameter. Default is 2.0. | 2.0 |
| sum_axis | int | The axis along which to sum the loss. Default is -1. | -1 |
Returns:
| Type | Description |
|---|---|
| Array | The focal loss. |
Source code in src/ml_networks/jax/loss.py
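The focal loss from the reference above down-weights well-classified examples: FL(p_t) = -(1 - p_t)^γ · log p_t, where p_t is the predicted probability of the true class. With γ = 0 it reduces to cross-entropy. Below is a minimal sketch with integer labels. Summing over the class axis via `sum_axis` is an assumption about how that parameter is used, and `focal_loss_sketch` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def focal_loss_sketch(logits, target, gamma=2.0, sum_axis=-1):
    """Multi-class focal loss from raw logits and integer class labels."""
    log_p = jax.nn.log_softmax(logits, axis=-1)
    onehot = jax.nn.one_hot(target, logits.shape[-1], dtype=log_p.dtype)
    p = jnp.exp(log_p)
    # Only the true-class entry survives the one-hot mask: -(1 - p_t)^gamma * log p_t
    loss = -onehot * (1.0 - p) ** gamma * log_p
    return jnp.sum(loss, axis=sum_axis)


logits = jnp.array([[2.0, 0.5, -1.0], [0.0, 0.0, 0.0]])
target = jnp.array([0, 2])
print(focal_loss_sketch(logits, target))  # one loss value per sample
```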
binary_focal_loss
Binary focal loss, mainly for binary classification.
Reference
Focal Loss for Dense Object Detection https://arxiv.org/abs/1708.02002
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| prediction | Array | The predicted logits (before sigmoid). | required |
| target | Array | The target tensor. | required |
| gamma | float | The gamma (focusing) parameter. Default is 2.0. | 2.0 |
| sum_axis | int | The axis along which to sum the loss. Default is -1. | -1 |
Returns:
| Type | Description |
|---|---|
| Array | The binary focal loss. |
Source code in src/ml_networks/jax/loss.py
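In the binary case the focal term is applied to each class separately: -[y(1 - p)^γ log p + (1 - y) p^γ log(1 - p)], with p = sigmoid(logit). Below is a minimal sketch that uses `log_sigmoid` for numerical stability. It is an illustration of the formula, not the library code, and `binary_focal_loss_sketch` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def binary_focal_loss_sketch(logits, target, gamma=2.0, sum_axis=-1):
    """Binary focal loss from raw logits and {0, 1} targets."""
    log_p = jax.nn.log_sigmoid(logits)      # log p
    log_1mp = jax.nn.log_sigmoid(-logits)   # log (1 - p)
    p = jnp.exp(log_p)
    loss = -(target * (1.0 - p) ** gamma * log_p
             + (1.0 - target) * p ** gamma * log_1mp)
    return jnp.sum(loss, axis=sum_axis)


x = jnp.array([0.3, -1.2, 2.0])
y = jnp.array([1.0, 0.0, 1.0])
print(binary_focal_loss_sketch(x, y))
```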
charbonnier
Charbonnier loss function.
Reference
A General and Adaptive Robust Loss Function http://arxiv.org/abs/1701.03077
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| prediction | Array | The predicted tensor. | required |
| target | Array | The target tensor. | required |
| epsilon | float | A small value to avoid division by zero. Default is 1e-3. | 0.001 |
| alpha | float | The alpha parameter. Default is 1. | 1 |
| sum_axes | int \| list[int] \| tuple[int, ...] \| None | The axes along which to sum the loss. Default is None (sums over [-1, -2, -3]). | None |
Returns:
| Type | Description |
|---|---|
| Array | The Charbonnier loss. |
Source code in src/ml_networks/jax/loss.py
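The Charbonnier loss is a smooth approximation of the L1 loss, sqrt(d² + ε²) with d = prediction - target. One common generalized form raises the term to a power controlled by α. The sketch below assumes the parameterization (d² + ε²)^(α/2), which reduces to the plain Charbonnier loss at the default α = 1. The library's exact use of `alpha` may differ, and `charbonnier_sketch` is a hypothetical name.

```python
import jax.numpy as jnp


def charbonnier_sketch(prediction, target, epsilon=1e-3, alpha=1.0, sum_axes=None):
    """Generalized Charbonnier loss (d^2 + eps^2)^(alpha / 2), summed over `sum_axes`."""
    if sum_axes is None:
        sum_axes = (-1, -2, -3)      # e.g. the H, W, C axes of an NHWC image
    elif isinstance(sum_axes, list):
        sum_axes = tuple(sum_axes)   # jnp.sum expects an int or a tuple
    diff = prediction - target
    loss = (diff ** 2 + epsilon ** 2) ** (alpha / 2)
    return jnp.sum(loss, axis=sum_axes)


pred = jnp.full((1, 1, 1, 1), 3.0)
print(charbonnier_sketch(pred, jnp.zeros((1, 1, 1, 1)), epsilon=4.0))  # sqrt(9 + 16) = 5
```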
kl_divergence
KL divergence between two StochState distributions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| posterior | StochState | The posterior distribution. | required |
| prior | StochState | The prior distribution. | required |
Returns:
| Type | Description |
|---|---|
| Array | The KL divergence. |
Source code in src/ml_networks/jax/loss.py
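For two diagonal Gaussians (the NormalStoch case), the KL divergence has the closed form KL(q‖p) = log(σ_p/σ_q) + (σ_q² + (μ_q - μ_p)²) / (2σ_p²) - 1/2 per dimension. Below is a minimal sketch of that formula on raw mean/std arrays. The StochState field names and the reduction axes used by `kl_divergence` are not shown here, so this helper and its arguments are hypothetical.

```python
import jax.numpy as jnp


def kl_diag_normal(mean_q, std_q, mean_p, std_p, axis=-1):
    """KL(q || p) between diagonal Gaussians, summed over `axis`."""
    var_q, var_p = std_q ** 2, std_p ** 2
    kl = jnp.log(std_p / std_q) + (var_q + (mean_q - mean_p) ** 2) / (2.0 * var_p) - 0.5
    return jnp.sum(kl, axis=axis)


# N(1, 1) against N(0, 1): the only contribution is the squared mean shift / 2
print(kl_diag_normal(jnp.array([1.0]), jnp.array([1.0]), jnp.array([0.0]), jnp.array([1.0])))  # 0.5
```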
Activation functions (ml_networks.jax.activations)
Activation
Bases: Module
Generic activation function.
Source code in src/ml_networks/jax/activations.py
REReLU
Bases: Module
Reparameterized ReLU activation function. Its backward pass is differentiable.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| reparametarize_fn | str | Reparameterization function. Default is GELU. | 'gelu' |
References
https://openreview.net/forum?id=lNCnZwcH5Z
Examples:
>>> rerelu = REReLU()
>>> x = jnp.array([[1.0, -1.0, 0.5]])
>>> output = rerelu(x)
>>> output.shape
(1, 3)
Source code in src/ml_networks/jax/activations.py
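One way to read "reparameterized ReLU" is as follows: the forward value is ReLU, but gradients are taken through a smooth surrogate (here GELU, matching the default `reparametarize_fn`), so negative inputs still receive a gradient. The sketch implements that reading with `jax.lax.stop_gradient`. Whether REReLU uses exactly this construction is an assumption, and `rerelu_sketch` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def rerelu_sketch(x):
    """Forward value of relu(x); gradient of gelu(x)."""
    smooth = jax.nn.gelu(x)
    # The stop_gradient term corrects the value to relu(x) without contributing a gradient.
    return smooth + jax.lax.stop_gradient(jax.nn.relu(x) - smooth)


x = jnp.array([[1.0, -1.0, 0.5]])
print(rerelu_sketch(x))                 # approximately [[1. 0. 0.5]]
print(jax.grad(rerelu_sketch)(-1.0))    # non-zero, unlike the ReLU gradient at -1
```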
SiGLU
Bases: Module
SiGLU activation function.
This is equivalent to the SwiGLU (Swish-gated linear unit) activation function.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dim | int | Dimension along which to split the tensor. Default is -1. | -1 |
References
https://arxiv.org/abs/2102.11972
Source code in src/ml_networks/jax/activations.py
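A SwiGLU that splits its input works as follows. The tensor is cut into two halves along `dim`, and one half is gated by the Swish (SiLU) of the other, so the output is half as large along that dimension. Below is a minimal sketch. Which half is gated is an assumption, and `swiglu_split` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def swiglu_split(x, axis=-1):
    """Split x in two along `axis` and return a * silu(b)."""
    a, b = jnp.split(x, 2, axis=axis)
    return a * jax.nn.silu(b)


x = jnp.ones((2, 8))
print(swiglu_split(x).shape)  # (2, 4)
```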
CRReLU
Bases: Module
Correction Regularized ReLU activation function, a variant of the ReLU activation function.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| lr | float | Learning rate. Default is 0.01. | 0.01 |
References
https://openreview.net/forum?id=7TZYM6Hm9p
Examples:
>>> crrelu = CRReLU()
>>> x = jnp.array([[1.0, -1.0, 0.5]])
>>> output = crrelu(x)
>>> output.shape
(1, 3)
Source code in src/ml_networks/jax/activations.py
TanhExp
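No description is given for TanhExp. Assuming it is the TanhExp activation of Liu and Di (2020), f(x) = x * tanh(exp(x)), a one-line sketch follows. The function behaves like the identity for large positive x and decays smoothly to 0 for large negative x.

```python
import jax.numpy as jnp


def tanhexp(x):
    """TanhExp activation (assumed definition): x * tanh(exp(x))."""
    return x * jnp.tanh(jnp.exp(x))


print(tanhexp(jnp.array([-10.0, 0.0, 10.0])))  # approximately [-0.0005, 0., 10.]
```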
UNet (ml_networks.jax.unet)
ConditionalUnet2d
Bases: Module
Conditional UNet model (NHWC format).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| feature_dim | int | Dimension of the conditioning features. | required |
| obs_shape | tuple[int, int, int] | Shape of the observation data, (H, W, C) in NHWC format. | required |
| cfg | UNetConfig | UNet configuration. | required |
| rngs | Rngs | Random number generators. | required |
Source code in src/ml_networks/jax/unet.py
Attributes
final_conv1 (instance-attribute)
final_conv2 (instance-attribute)
mid_modules (instance-attribute)
mid_modules = List([ConditionalResidualBlock2d(mid_dim, mid_dim, cond_dim=feature_dim, conv_cfg=conv_cfg, cond_predict_scale=cond_pred_scale, rngs=rngs), Attention2d(mid_dim, nhead, rngs=rngs) if has_attn and nhead is not None else Identity(), ConditionalResidualBlock2d(mid_dim, mid_dim, cond_dim=feature_dim, conv_cfg=conv_cfg, cond_predict_scale=cond_pred_scale, rngs=rngs)])
Functions
__call__
Forward pass.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| base | Array | Input tensor of shape (B, H, W, C) in NHWC format. | required |
| cond | Array | Conditional tensor of shape (B, cond_dim). | required |
Returns:
| Type | Description |
|---|---|
| Array | Output tensor of shape (B, H, W, C). |
Source code in src/ml_networks/jax/unet.py
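The `cond_predict_scale` option on the residual blocks suggests FiLM-style conditioning: the (B, cond_dim) condition is projected to a per-channel scale and bias, which are broadcast over the spatial axes of the NHWC feature map. This interpretation is an assumption. The sketch shows the technique with hypothetical projection parameters `w` and `b`, not the ConditionalResidualBlock2d implementation.

```python
import jax.numpy as jnp


def film_nhwc(features, cond, w, b):
    """FiLM conditioning for NHWC features.

    features: (B, H, W, C); cond: (B, cond_dim)
    w: (cond_dim, 2 * C) and b: (2 * C,) are hypothetical projection parameters.
    """
    scale, bias = jnp.split(cond @ w + b, 2, axis=-1)  # each (B, C)
    # Insert singleton H and W axes so the per-channel values broadcast spatially.
    return features * scale[:, None, None, :] + bias[:, None, None, :]


features = jnp.ones((2, 4, 4, 3))
cond = jnp.ones((2, 5))
w = jnp.zeros((5, 6))
b = jnp.concatenate([jnp.full((3,), 2.0), jnp.ones(3)])  # scale = 2, bias = 1
out = film_nhwc(features, cond, w, b)
print(out.shape)  # (2, 4, 4, 3), every value equal to 1 * 2 + 1 = 3
```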
ConditionalUnet1d
Bases: Module
Conditional 1D UNet model (NLC format).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| feature_dim | int | Dimension of the conditioning features. | required |
| obs_shape | tuple[int, int] | Shape of the observation data, (L, C) in NLC format. | required |
| cfg | UNetConfig | UNet configuration. | required |
| rngs | Rngs | Random number generators. | required |
Source code in src/ml_networks/jax/unet.py
Attributes
final_conv1 (instance-attribute)
final_conv2 (instance-attribute)
mid_modules (instance-attribute)
mid_modules = List([ConditionalResidualBlock1d(mid_dim, mid_dim, cond_dim=feature_dim, conv_cfg=conv_cfg, cond_predict_scale=cond_pred_scale, rngs=rngs) if not use_hypernet else HyperConditionalResidualBlock1d(mid_dim, mid_dim, cond_dim=feature_dim, conv_cfg=conv_cfg, hyper_mlp_cfg=hyper_mlp_cfg, rngs=rngs), Attention1d(mid_dim, nhead, rngs=rngs) if has_attn and nhead is not None else Identity(), ConditionalResidualBlock1d(mid_dim, mid_dim, cond_dim=feature_dim, conv_cfg=conv_cfg, cond_predict_scale=cond_pred_scale, rngs=rngs) if not use_hypernet else HyperConditionalResidualBlock1d(mid_dim, mid_dim, cond_dim=feature_dim, conv_cfg=conv_cfg, hyper_mlp_cfg=hyper_mlp_cfg, rngs=rngs)])
Functions
__call__
Forward pass.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| base | Array | Input tensor of shape (B, L, C) in NLC format. | required |
| cond | Array | Conditional tensor of shape (B, cond_dim). | required |
Returns:
| Type | Description |
|---|---|
| Array | Output tensor of shape (B, L, C). |
Source code in src/ml_networks/jax/unet.py
Utilities (ml_networks.jax.jax_utils)
get_optimizer
Get an optimizer from optax.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str | Optimizer name (e.g. "adam", "sgd", "adamw", "lamb", "rmsprop"). | required |
| kwargs | dict | Optimizer arguments (e.g. learning_rate=0.01). | {} |
Returns:
| Type | Description |
|---|---|
| GradientTransformation | The optax optimizer. |
Source code in src/ml_networks/jax/jax_utils.py
jax_fix_seed
Fix the random seed.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| seed | int | Random seed. | 42 |
Returns:
| Type | Description |
|---|---|
| Array | JAX PRNG key. |
Source code in src/ml_networks/jax/jax_utils.py
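JAX has no global random state, so fixing a seed amounts to seeding Python's and NumPy's global generators and returning an explicit PRNG key to thread through the code. The sketch below is a guess at what such a helper typically does; which generators `jax_fix_seed` actually seeds is not documented here, and `fix_seed_sketch` is a hypothetical name.

```python
import random

import jax
import jax.numpy as jnp
import numpy as np


def fix_seed_sketch(seed: int = 42):
    """Seed Python and NumPy global RNGs and return a JAX PRNG key."""
    random.seed(seed)
    np.random.seed(seed)
    return jax.random.PRNGKey(seed)


key = fix_seed_sketch(0)
key, subkey = jax.random.split(key)  # split keys instead of reusing one
print(jax.random.normal(subkey, (3,)))
```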
MinMaxNormalize
SoftmaxTransformation
Softmax transformation class.
Source code in src/ml_networks/jax/jax_utils.py
Functions
__call__
get_transformed_dim
inverse
Inverse of SoftmaxTransformation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input tensor. | required |
Returns:
| Type | Description |
|---|---|
| Array | Output tensor. |
Source code in src/ml_networks/jax/jax_utils.py
transform
Apply SoftmaxTransformation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input tensor. | required |
Returns:
| Type | Description |
|---|---|
| Array | Output tensor. |
Examples:
>>> trans = SoftmaxTransformation(SoftmaxTransConfig(vector=16, sigma=0.01, n_ignore=1, min=-1.0, max=1.0))
>>> x = jnp.ones((2, 3, 4))
>>> transformed = trans(x)
>>> transformed.shape
(2, 3, 49)
>>> trans = SoftmaxTransformation(SoftmaxTransConfig(vector=11, sigma=0.05, n_ignore=0, min=-1.0, max=1.0))
>>> x = jnp.ones((2, 3, 4))
>>> transformed = trans(x)
>>> transformed.shape
(2, 3, 44)
Source code in src/ml_networks/jax/jax_utils.py
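The shapes in the examples above follow one rule: every non-ignored feature expands into `vector` bins, and the `n_ignore` features are passed through. That gives (4 - 1) * 16 + 1 = 49 and 4 * 11 + 0 = 44. The sketch reproduces this dimension rule. It pairs the rule with a hypothetical soft-binning encoder: Gaussian logits around `vector` evenly spaced centers in [min, max], then a softmax. Which features are ignored (here, the last `n_ignore`) and the exact encoding are assumptions, not the SoftmaxTransformation implementation.

```python
import jax
import jax.numpy as jnp


def transformed_dim(dim, vector, n_ignore):
    """Output size: non-ignored features become `vector` bins each."""
    return (dim - n_ignore) * vector + n_ignore


def softmax_encode_sketch(x, vector, sigma, n_ignore, vmin=-1.0, vmax=1.0):
    """Hypothetical soft-binning of each non-ignored feature into `vector` bins."""
    if n_ignore > 0:
        enc, passthrough = x[..., :-n_ignore], x[..., -n_ignore:]
    else:
        enc, passthrough = x, x[..., :0]
    centers = jnp.linspace(vmin, vmax, vector)                      # (vector,)
    logits = -((enc[..., None] - centers) ** 2) / (2 * sigma ** 2)  # (..., D - n_ignore, vector)
    probs = jax.nn.softmax(logits, axis=-1)
    probs = probs.reshape(enc.shape[:-1] + (-1,))                   # merge feature and bin axes
    return jnp.concatenate([probs, passthrough], axis=-1)


x = jnp.ones((2, 3, 4))
print(softmax_encode_sketch(x, vector=16, sigma=0.01, n_ignore=1).shape)  # (2, 3, 49)
print(softmax_encode_sketch(x, vector=11, sigma=0.05, n_ignore=0).shape)  # (2, 3, 44)
```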
Miscellaneous
HyperNet
Bases: Module, HyperNetMixin
A hypernetwork that generates weights for a target network.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input_dim | int | Dimension of the input. | required |
| output_shapes | dict[str, Shape] | Shapes of the primary network weights being predicted. | required |
| fc_cfg | MLPConfig \| None | Configuration for the MLP backbone. | None |
| encoding | InputMode | The input encoding mode. Default is None. | None |
| rngs | Rngs | Random number generators. | required |
Source code in src/ml_networks/jax/hypernetworks.py
Functions
__call__
Forward pass.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| inputs | Array | Input tensor. | required |
Returns:
| Type | Description |
|---|---|
| dict[str, Array] | Dictionary of output tensors. |
Source code in src/ml_networks/jax/hypernetworks.py
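A hypernetwork that returns a dictionary of weight tensors typically predicts one flat vector whose length is the total number of target parameters. It then slices that vector into pieces and reshapes each piece to the shapes in `output_shapes`. Below is a minimal sketch of that final step. It assumes this flat-then-split design, and `split_flat_params` is a hypothetical helper.

```python
import math

import jax.numpy as jnp


def split_flat_params(flat, output_shapes):
    """Slice a (..., total) vector into a dict of (..., *shape) tensors."""
    out, offset = {}, 0
    for name, shape in output_shapes.items():
        size = math.prod(shape)
        piece = flat[..., offset:offset + size]
        out[name] = piece.reshape(flat.shape[:-1] + tuple(shape))
        offset += size
    return out


shapes = {"w": (3, 2), "b": (2,)}  # total = 3 * 2 + 2 = 8 predicted values
flat = jnp.arange(16.0).reshape(2, 8)  # stand-in for the MLP output for a batch of 2
params = split_flat_params(flat, shapes)
print({k: v.shape for k, v in params.items()})  # {'w': (2, 3, 2), 'b': (2, 2)}
```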
ContrastiveLearningLoss
Bases: Module
Contrastive learning module.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dim_input1 | int | Dimension of the first input. | required |
| dim_input2 | int | Dimension of the second input. | required |
| cfg | ContrastiveLearningConfig | Configuration for contrastive learning. | required |
| rngs | Rngs | Random number generators. | required |
Source code in src/ml_networks/jax/contrastive.py
Attributes
eval_func2 (instance-attribute)
Functions
calc_nce
Calculate the Noise Contrastive Estimation (NCE) loss.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| feature1 | Array | First input tensor of shape (*, dim_input1) | required |
| feature2 | Array | Second input tensor of shape (*, dim_input2) | required |
| return_emb | bool | Whether to return embeddings. Default is False. | False |
Returns:
| Type | Description |
|---|---|
| dict or tuple | Loss dictionary, optionally with embeddings. |
Source code in src/ml_networks/jax/contrastive.py
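The NCE objective is typically computed as InfoNCE over a batch of paired embeddings. Each row's matching partner is the positive, every other row in the batch is a negative, and the loss is cross-entropy over the (B, B) similarity matrix. Below is a minimal symmetric sketch on already-projected embeddings. How ml_networks projects the two inputs into a shared space, and where the temperature comes from, are assumptions here; the sketch's `temperature` argument is hypothetical.

```python
import jax
import jax.numpy as jnp


def info_nce(z1, z2, temperature=0.1):
    """Symmetric InfoNCE for paired embeddings of shape (B, D)."""
    z1 = z1 / jnp.linalg.norm(z1, axis=-1, keepdims=True)
    z2 = z2 / jnp.linalg.norm(z2, axis=-1, keepdims=True)
    logits = z1 @ z2.T / temperature  # (B, B); the diagonal holds positive pairs
    idx = jnp.arange(z1.shape[0])
    loss12 = -jax.nn.log_softmax(logits, axis=-1)[idx, idx].mean()
    loss21 = -jax.nn.log_softmax(logits.T, axis=-1)[idx, idx].mean()
    return 0.5 * (loss12 + loss21)


z = jnp.eye(4)
print(info_nce(z, z))        # near 0: every pair matches
print(info_nce(z, z[::-1]))  # large: positives are orthogonal
```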
calc_sigmoid
Calculate the Sigmoid loss for contrastive learning.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| feature1 | Array | First input tensor of shape (*, dim_input1) | required |
| feature2 | Array | Second input tensor of shape (*, dim_input2) | required |
| return_emb | bool | Whether to return embeddings. Default is False. | False |
| temperature | float | Temperature. Default is 0.1. | 0.1 |
| bias | float | Bias. Default is 0.0. | 0.0 |
Returns:
| Type | Description |
|---|---|
| dict or tuple | Loss dictionary, optionally with embeddings. |
Source code in src/ml_networks/jax/contrastive.py
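A sigmoid contrastive loss (as popularized by SigLIP) scores every pair in the batch independently with a binary log-sigmoid term. Matching pairs get label +1 and all others -1, so no softmax over the batch is needed. The sketch assumes the similarity is divided by `temperature` and shifted by `bias`; SigLIP itself multiplies by a learned scale, so the library's exact scaling may differ. `sigmoid_contrastive_loss` is a hypothetical name.

```python
import jax
import jax.numpy as jnp


def sigmoid_contrastive_loss(z1, z2, temperature=0.1, bias=0.0):
    """Pairwise sigmoid loss for paired embeddings of shape (B, D)."""
    z1 = z1 / jnp.linalg.norm(z1, axis=-1, keepdims=True)
    z2 = z2 / jnp.linalg.norm(z2, axis=-1, keepdims=True)
    logits = z1 @ z2.T / temperature + bias  # (B, B)
    n = z1.shape[0]
    signs = 2.0 * jnp.eye(n) - 1.0           # +1 for matching pairs, -1 otherwise
    return -jnp.sum(jax.nn.log_sigmoid(signs * logits)) / n


z = jnp.eye(4)
print(sigmoid_contrastive_loss(z, z))
print(sigmoid_contrastive_loss(z, z[::-1]))  # larger: positives are mismatched
```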
calc_timeseries_nce
calc_timeseries_nce(feature1, feature2, positive_range_self=0, positive_range_tgt=0, return_emb=False)
Calculate the NCE loss for time series data.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| feature1 | Array | First input tensor of shape (*batch, length, dim_input1) | required |
| feature2 | Array | Second input tensor of shape (*batch, length, dim_input2) | required |
| positive_range_self | int | Range for self-positive samples. Default is 0. | 0 |
| positive_range_tgt | int | Range for target-positive samples. Default is 0. | 0 |
| return_emb | bool | Whether to return embeddings. Default is False. | False |
Returns:
| Type | Description |
|---|---|
| dict or tuple | Loss dictionary, optionally with embeddings. |
Source code in src/ml_networks/jax/contrastive.py
BaseModule
Bases: Module
Base module for JAX/Flax NNX.
Functions
freeze_biases
Freeze all bias parameters.
Source code in src/ml_networks/jax/base.py
freeze_weights
Freeze all weight parameters (kernel in Flax).
Source code in src/ml_networks/jax/base.py
unfreeze_biases
Unfreeze all bias parameters.
Source code in src/ml_networks/jax/base.py
unfreeze_weights
Unfreeze all weight parameters (kernel in Flax).