JAX API Reference¶
API reference for the JAX (Flax NNX) backend.
It provides the same interface as the PyTorch backend; see the corresponding PyTorch API reference pages for details.
Layers (ml_networks.jax.layers)¶
MLPLayer ¶
Bases: Module
Multi-layer perceptron layer.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| input_dim | int | Input dimension. | required |
| output_dim | int | Output dimension. | required |
| cfg | MLPConfig | MLP configuration. | required |
| rngs | Rngs | Random number generators. | required |
Examples:
>>> rngs = nnx.Rngs(0)
>>> cfg = MLPConfig(
... hidden_dim=16,
... n_layers=3,
... output_activation="ReLU",
... linear_cfg=LinearConfig(
... activation="ReLU",
... norm="layer",
... norm_cfg={"eps": 1e-05, "elementwise_affine": True, "bias": True},
... dropout=0.1,
... norm_first=False,
... bias=True
... )
... )
>>> mlp = MLPLayer(32, 16, cfg, rngs=rngs)
>>> x = jnp.ones((1, 32))
>>> output = mlp(x)
>>> output.shape
(1, 16)
Source code in src/ml_networks/jax/layers.py
LinearNormActivation ¶
Bases: Module
Linear layer with normalization, activation, and dropout.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| input_dim | int | Input dimension. | required |
| output_dim | int | Output dimension. | required |
| cfg | LinearConfig | Linear layer configuration. | required |
| rngs | Rngs | Random number generators. | required |
Examples:
>>> rngs = nnx.Rngs(0)
>>> cfg = LinearConfig(
... activation="ReLU",
... norm="layer",
... norm_cfg={"eps": 1e-05, "elementwise_affine": True, "bias": True},
... dropout=0.1,
... norm_first=False,
... bias=True
... )
>>> linear = LinearNormActivation(32, 16, cfg, rngs=rngs)
>>> x = jnp.ones((1, 32))
>>> output = linear(x)
>>> output.shape
(1, 16)
Source code in src/ml_networks/jax/layers.py
Attributes¶
Functions¶
__call__ ¶
Forward pass.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input tensor of shape (*, input_dim). | required |

Returns:

| Type | Description |
|---|---|
| Array | Output tensor of shape (*, output_dim). |
Source code in src/ml_networks/jax/layers.py
ConvNormActivation ¶
Bases: Module
Convolutional layer with normalization, activation, and dropout.
Uses NHWC (channels-last) format.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| in_channels | int | Input channels. | required |
| out_channels | int | Output channels. | required |
| cfg | ConvConfig | Convolutional layer configuration. | required |
| rngs | Rngs | Random number generators. | required |
Examples:
>>> rngs = nnx.Rngs(0)
>>> cfg = ConvConfig(
... activation="ReLU",
... kernel_size=3,
... stride=1,
... padding=1,
... dilation=1,
... groups=1,
... bias=True,
... dropout=0.1,
... norm="batch",
... norm_cfg={"affine": True, "track_running_stats": True},
... scale_factor=0
... )
>>> conv = ConvNormActivation(3, 16, cfg, rngs=rngs)
>>> x = jnp.ones((1, 32, 32, 3))
>>> output = conv(x)
>>> output.shape
(1, 32, 32, 16)
Source code in src/ml_networks/jax/layers.py
Attributes¶
conv
instance-attribute
¶
conv = Conv(in_features=in_channels, out_features=out_channels_, kernel_size=(kernel_size, kernel_size), strides=(stride, stride), padding=conv_padding, kernel_dilation=(dilation, dilation), feature_group_count=groups, use_bias=bias, rngs=rngs)
dropout
instance-attribute
¶
Functions¶
__call__ ¶
Forward pass.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input tensor of shape (B, H, W, in_channels) or (H, W, in_channels). | required |

Returns:

| Type | Description |
|---|---|
| Array | Output tensor of shape (B, H', W', out_channels) or (H', W', out_channels). |
Source code in src/ml_networks/jax/layers.py
ConvTransposeNormActivation ¶
Bases: Module
Transposed convolutional layer with normalization, activation, and dropout.
Uses NHWC (channels-last) format.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| in_channels | int | Input channels. | required |
| out_channels | int | Output channels. | required |
| cfg | ConvConfig | Convolutional layer configuration. | required |
| rngs | Rngs | Random number generators. | required |
Examples:
>>> rngs = nnx.Rngs(0)
>>> cfg = ConvConfig(
... activation="ReLU",
... kernel_size=3,
... stride=1,
... padding=1,
... output_padding=0,
... dilation=1,
... groups=1,
... bias=True,
... dropout=0.1,
... norm="batch",
... norm_cfg={"affine": True, "track_running_stats": True}
... )
>>> conv = ConvTransposeNormActivation(3, 16, cfg, rngs=rngs)
>>> x = jnp.ones((1, 32, 32, 3))
>>> output = conv(x)
>>> output.shape
(1, 32, 32, 16)
Source code in src/ml_networks/jax/layers.py
Attributes¶
conv
instance-attribute
¶
conv = ConvTranspose(in_features=in_channels, out_features=out_features, kernel_size=(kernel_size, kernel_size), strides=(stride, stride), padding=((padding, padding + output_padding), (padding, padding + output_padding)), kernel_dilation=(dilation, dilation), use_bias=bias, rngs=rngs)
dropout
instance-attribute
¶
Functions¶
__call__ ¶
Forward pass.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input tensor of shape (B, H, W, in_channels) or (H, W, in_channels). | required |

Returns:

| Type | Description |
|---|---|
| Array | Output tensor of shape (B, H', W', out_channels) or (H', W', out_channels). |
Source code in src/ml_networks/jax/layers.py
Vision (ml_networks.jax.vision)¶
Encoder ¶
Bases: Module
Image encoder module (NHWC format).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| feature_dim | int \| tuple[int, int, int] | Output feature dimension. If int, a fully-connected layer flattens and projects the backbone output. If tuple, the backbone output is returned directly (fc is identity). | required |
| obs_shape | tuple[int, int, int] | Observation shape in (H, W, C) format. | required |
| backbone_cfg | ViTConfig \| ConvNetConfig \| ResNetConfig | Backbone configuration. | required |
| fc_cfg | MLPConfig \| LinearConfig \| SpatialSoftmaxConfig \| AdaptiveAveragePoolingConfig \| None | Fully-connected layer configuration. Required when feature_dim is an int. | None |
| rngs | Rngs | Random number generators. | required |
Source code in src/ml_networks/jax/vision.py
Attributes¶
Functions¶
__call__ ¶
Forward pass.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input tensor of shape (*, H, W, C) in NHWC format. | required |

Returns:

| Type | Description |
|---|---|
| Array | Encoded tensor of shape (*, feature_dim). |
Source code in src/ml_networks/jax/vision.py
Decoder ¶
Bases: Module
Image decoder module (NHWC format).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| feature_dim | int \| tuple[int, int, int] | Input feature dimension. If int, a fully-connected layer projects and reshapes input before the backbone. If tuple, input is passed directly to the backbone. | required |
| obs_shape | tuple[int, int, int] | Output observation shape in (H, W, C) format. | required |
| backbone_cfg | ConvNetConfig \| ViTConfig \| ResNetConfig | Backbone configuration. | required |
| fc_cfg | MLPConfig \| LinearConfig \| None | Fully-connected layer configuration. Required when feature_dim is an int. | None |
| rngs | Rngs | Random number generators. | required |
Source code in src/ml_networks/jax/vision.py
Attributes¶
decoder
instance-attribute
¶
Functions¶
__call__ ¶
Forward pass.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input tensor of shape (*, feature_dim). | required |

Returns:

| Type | Description |
|---|---|
| Array | Decoded tensor of shape (*, H, W, C) in NHWC format. |
Source code in src/ml_networks/jax/vision.py
ConvNet ¶
Bases: Module
Convolutional network (NHWC format).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| obs_shape | tuple[int, int, int] | Observation shape in (H, W, C) format. | required |
| cfg | ConvNetConfig | Configuration. | required |
| rngs | Rngs | Random number generators. | required |
Source code in src/ml_networks/jax/vision.py
Attributes¶
conved_shape
property
¶
Get the spatial shape of the output after convolutional layers.
Functions¶
__call__ ¶
Forward pass.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input tensor of shape (B, H, W, C) in NHWC format. | required |

Returns:

| Type | Description |
|---|---|
| Array | Flattened tensor of shape (B, output_dim). |
Source code in src/ml_networks/jax/vision.py
ConvTranspose ¶
Bases: Module
Transposed convolutional network (NHWC format).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| in_shape | tuple[int, int, int] | Input shape in (H, W, C) format. | required |
| obs_shape | tuple[int, int, int] | Output observation shape in (H, W, C) format. | required |
| cfg | ConvNetConfig | Configuration (channels are in decode order). | required |
| rngs | Rngs | Random number generators. | required |
Source code in src/ml_networks/jax/vision.py
Attributes¶
first_conv
instance-attribute
¶
Functions¶
__call__ ¶
Forward pass.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input tensor of shape (B, H, W, C) in NHWC format. | required |

Returns:

| Type | Description |
|---|---|
| Array | Output tensor of shape (B, H', W', out_C) in NHWC format. |
Source code in src/ml_networks/jax/vision.py
get_input_shape
staticmethod
¶
Get the required input shape for a given output shape and config.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| obs_shape | tuple[int, int, int] | Output shape in (H, W, C) format. | required |
| cfg | ConvNetConfig | Configuration. | required |

Returns:

| Type | Description |
|---|---|
| tuple[int, ...] | Required input shape in (H, W, C) format. |
Source code in src/ml_networks/jax/vision.py
ResNetPixUnshuffle ¶
Bases: Module
ResNet with PixelUnshuffle downsampling (NHWC format).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| obs_shape | tuple[int, int, int] | Input observation shape in (H, W, C) format. | required |
| cfg | ResNetConfig | Configuration. | required |
| rngs | Rngs | Random number generators. | required |
Source code in src/ml_networks/jax/vision.py
Attributes¶
conv1
instance-attribute
¶
conv2
instance-attribute
¶
conv3
instance-attribute
¶
Functions¶
__call__ ¶
Forward pass.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input tensor of shape (B, H, W, C) in NHWC format. | required |

Returns:

| Type | Description |
|---|---|
| Array | Downsampled output. |
Source code in src/ml_networks/jax/vision.py
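The PixelUnshuffle/PixelShuffle rearrangements behind these two blocks can be sketched in NumPy for the NHWC layout. This is a minimal illustration of the space-to-depth and depth-to-space reshapes, not the library's implementation:

```python
import numpy as np

def pixel_unshuffle_nhwc(x, r):
    """Space-to-depth: (B, H, W, C) -> (B, H/r, W/r, C*r*r)."""
    b, h, w, c = x.shape
    x = x.reshape(b, h // r, r, w // r, r, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)       # (B, H/r, W/r, r, r, C)
    return x.reshape(b, h // r, w // r, r * r * c)

def pixel_shuffle_nhwc(x, r):
    """Depth-to-space: (B, H, W, C*r*r) -> (B, H*r, W*r, C)."""
    b, h, w, c = x.shape
    x = x.reshape(b, h, w, r, r, c // (r * r))
    x = x.transpose(0, 1, 3, 2, 4, 5)       # (B, H, r, W, r, C/r^2)
    return x.reshape(b, h * r, w * r, c // (r * r))

x = np.arange(2 * 8 * 8 * 3, dtype=np.float32).reshape(2, 8, 8, 3)
down = pixel_unshuffle_nhwc(x, 2)           # (2, 4, 4, 12)
up = pixel_shuffle_nhwc(down, 2)            # exact round-trip back to x
```

The two functions are exact inverses, which is why the pair can be used for lossless down/upsampling inside an autoencoder.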
ResNetPixShuffle ¶
Bases: Module
ResNet with PixelShuffle upsampling (NHWC format).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| in_shape | tuple[int, int, int] | Input shape in (H, W, C) format. | required |
| obs_shape | tuple[int, int, int] | Output observation shape in (H, W, C) format. | required |
| cfg | ResNetConfig | Configuration. | required |
| rngs | Rngs | Random number generators. | required |
Source code in src/ml_networks/jax/vision.py
Attributes¶
conv1
instance-attribute
¶
conv2
instance-attribute
¶
conv3
instance-attribute
¶
Functions¶
__call__ ¶
Forward pass.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input tensor of shape (B, H, W, C) in NHWC format. | required |

Returns:

| Type | Description |
|---|---|
| Array | Upsampled output of shape (B, H', W', C'). |
Source code in src/ml_networks/jax/vision.py
get_input_shape
staticmethod
¶
Get the required input shape for a given output shape and config.
Source code in src/ml_networks/jax/vision.py
ViT ¶
Bases: Module
Vision Transformer for Encoder and Decoder (NHWC format).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| in_shape | tuple[int, int, int] | Input shape in (H, W, C) format. | required |
| cfg | ViTConfig | ViT configuration. | required |
| obs_shape | tuple[int, int, int] \| None | Output shape in (H, W, C) format. If None, acts as encoder. | None |
| rngs | Rngs | Random number generators. | required |
Source code in src/ml_networks/jax/vision.py
Attributes¶
out_patch_dim
instance-attribute
¶
patch_embed
instance-attribute
¶
patch_embed = PatchEmbed(emb_dim=in_patch_dim, patch_size=patch_size, obs_shape=in_shape, rngs=rngs)
positional_encoding
instance-attribute
¶
positional_encoding = PositionalEncoding(in_patch_dim, dropout, max_len=get_n_patches(in_shape), rngs=rngs)
Functions¶
__call__ ¶
Forward pass.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input tensor of shape (B, H, W, C) in NHWC format. | required |
| return_cls_token | bool | Whether to return the CLS token only. Default is False. | False |

Returns:

| Type | Description |
|---|---|
| Array | Output tensor. |
Source code in src/ml_networks/jax/vision.py
get_input_shape
staticmethod
¶
Get the required input shape (NHWC: H, W, C).
get_n_patches ¶
Get number of patches for a given shape (NHWC: H, W, C).
get_patch_dim ¶
patchify ¶
Split images into patches.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| imgs | Array | Input images of shape (N, H, W, C) in NHWC format. | required |

Returns:

| Type | Description |
|---|---|
| Array | Patchified images of shape (N, L, patch_size**2 * C). |
Source code in src/ml_networks/jax/vision.py
unpatchify ¶
Reconstruct images from patches.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input of shape (N, L, patch_size**2 * C). | required |

Returns:

| Type | Description |
|---|---|
| Array | Images of shape (N, H, W, C) in NHWC format. |
Source code in src/ml_networks/jax/vision.py
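The patchify/unpatchify round-trip can be sketched in NumPy. This assumes the usual ViT convention of non-overlapping patches read in row-major order; it illustrates the reshapes only, not the library's exact code:

```python
import numpy as np

def patchify(imgs, p):
    """(N, H, W, C) -> (N, L, p*p*C), with L = (H/p) * (W/p)."""
    n, h, w, c = imgs.shape
    x = imgs.reshape(n, h // p, p, w // p, p, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)               # (N, H/p, W/p, p, p, C)
    return x.reshape(n, (h // p) * (w // p), p * p * c)

def unpatchify(x, h, w, c, p):
    """(N, L, p*p*C) -> (N, H, W, C); inverse of patchify."""
    n = x.shape[0]
    x = x.reshape(n, h // p, w // p, p, p, c)
    x = x.transpose(0, 1, 3, 2, 4, 5)               # (N, H/p, p, W/p, p, C)
    return x.reshape(n, h, w, c)

imgs = np.arange(1 * 8 * 8 * 3, dtype=np.float32).reshape(1, 8, 8, 3)
patches = patchify(imgs, 4)                          # (1, 4, 48)
rec = unpatchify(patches, 8, 8, 3, 4)                # recovers imgs exactly
```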
Distributions (ml_networks.jax.distributions)¶
Distribution ¶
Bases: Module
A distribution function.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| in_dim | int | Input dimension. | required |
| dist | Literal['normal', 'categorical', 'bernoulli'] | Distribution type. | required |
| n_groups | int | Number of groups. Default is 1. | 1 |
| spherical | bool | Whether to project samples to the unit sphere. Default is False. | False |
| rngs | Rngs | Random number generators. | required |
Source code in src/ml_networks/jax/distributions.py
Attributes¶
Functions¶
__call__ ¶
Compute the posterior distribution.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input tensor. | required |
| deterministic | bool | Whether to use deterministic mode. Default is False. | False |
| inv_tmp | float | Inverse temperature. Default is 1.0. | 1.0 |

Returns:

| Type | Description |
|---|---|
| StochState | Posterior distribution. |
Source code in src/ml_networks/jax/distributions.py
bernoulli ¶
Source code in src/ml_networks/jax/distributions.py
categorical ¶
Source code in src/ml_networks/jax/distributions.py
deterministic_onehot ¶
Compute the one-hot vector by argmax with straight-through.
Source code in src/ml_networks/jax/distributions.py
normal ¶
Source code in src/ml_networks/jax/distributions.py
NormalStoch
dataclass
¶
Parameters of a normal distribution and its stochastic sample.
Attributes¶
Functions¶
__getitem__ ¶
__len__ ¶
__post_init__ ¶
Source code in src/ml_networks/jax/distributions.py
detach ¶
flatten ¶
Flatten along specified dimensions.
Source code in src/ml_networks/jax/distributions.py
get_distribution ¶
reshape ¶
save ¶
Source code in src/ml_networks/jax/distributions.py
CategoricalStoch
dataclass
¶
BernoulliStoch
dataclass
¶
Loss functions (ml_networks.jax.loss)¶
focal_loss ¶
Focal loss function. Mainly for multi-class classification.
Reference
Focal Loss for Dense Object Detection https://arxiv.org/abs/1708.02002
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| prediction | Array | The predicted tensor. This should be before softmax. | required |
| target | Array | The target tensor (integer class labels). | required |
| gamma | float | The gamma parameter. Default is 2.0. | 2.0 |
| sum_axis | int | The axis to sum the loss. Default is -1. | -1 |

Returns:

| Type | Description |
|---|---|
| Array | The focal loss. |
Source code in src/ml_networks/jax/loss.py
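The core of the focal loss is easy to see in a few lines. A NumPy sketch of the per-example computation (an illustration of the formula from the paper, not this module's exact reduction behavior):

```python
import numpy as np

def softmax(z, axis=-1):
    z = z - z.max(axis=axis, keepdims=True)
    e = np.exp(z)
    return e / e.sum(axis=axis, keepdims=True)

def focal_loss(logits, labels, gamma=2.0):
    """Multi-class focal loss: -(1 - p_t)^gamma * log(p_t).

    `logits` are pre-softmax scores; `labels` are integer class indices.
    With gamma = 0 this reduces to the ordinary cross-entropy.
    """
    p = softmax(logits)
    p_t = np.take_along_axis(p, labels[..., None], axis=-1)[..., 0]
    return -((1.0 - p_t) ** gamma) * np.log(p_t)

logits = np.array([[2.0, 0.5, -1.0]])
labels = np.array([0])
loss = focal_loss(logits, labels)            # shape (1,)
```

The `(1 - p_t)^gamma` factor down-weights well-classified examples, so the loss focuses on hard ones.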
binary_focal_loss ¶
Binary focal loss function. Mainly for binary classification.
Reference
Focal Loss for Dense Object Detection https://arxiv.org/abs/1708.02002
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| prediction | Array | The predicted tensor. This should be before sigmoid. | required |
| target | Array | The target tensor. | required |
| gamma | float | The gamma parameter. Default is 2.0. | 2.0 |
| sum_axis | int | The axis to sum the loss. Default is -1. | -1 |

Returns:

| Type | Description |
|---|---|
| Array | The binary focal loss. |
Source code in src/ml_networks/jax/loss.py
charbonnier ¶
Charbonnier loss function.
Reference
A General and Adaptive Robust Loss Function http://arxiv.org/abs/1701.03077
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| prediction | Array | The predicted tensor. | required |
| target | Array | The target tensor. | required |
| epsilon | float | A small value to avoid division by zero. Default is 1e-3. | 0.001 |
| alpha | float | The alpha parameter. Default is 1. | 1 |
| sum_axes | int \| list[int] \| tuple[int, ...] \| None | The axes to sum the loss. Default is None (sums over [-1, -2, -3]). | None |

Returns:

| Type | Description |
|---|---|
| Array | The Charbonnier loss. |
Source code in src/ml_networks/jax/loss.py
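The element-wise penalty behind this loss can be sketched directly from the generalized Charbonnier formula. This shows only the per-element term; the module additionally sums over `sum_axes`:

```python
import numpy as np

def charbonnier(pred, target, epsilon=1e-3, alpha=1.0):
    """Generalized Charbonnier penalty: (diff^2 + eps^2)^(alpha / 2).

    With alpha = 1 this is sqrt(diff^2 + eps^2), a smooth L1 that behaves
    like |diff| for large errors and is differentiable at zero.
    """
    diff = pred - target
    return (diff ** 2 + epsilon ** 2) ** (alpha / 2.0)

at_zero = charbonnier(np.array(0.0), np.array(0.0))   # ~= epsilon
at_one = charbonnier(np.array(1.0), np.array(0.0))    # ~= 1.0
```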
kl_divergence ¶
KL divergence between two StochState distributions.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| posterior | StochState | The posterior distribution. | required |
| prior | StochState | The prior distribution. | required |

Returns:

| Type | Description |
|---|---|
| Array | The KL divergence. |
Source code in src/ml_networks/jax/loss.py
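For the normal case, the KL term has a closed form per dimension. A NumPy sketch of the diagonal-Gaussian KL (the categorical and Bernoulli branches of the actual function differ):

```python
import numpy as np

def kl_normal(mu_q, sigma_q, mu_p, sigma_p):
    """KL( N(mu_q, sigma_q^2) || N(mu_p, sigma_p^2) ), element-wise."""
    return (np.log(sigma_p / sigma_q)
            + (sigma_q ** 2 + (mu_q - mu_p) ** 2) / (2.0 * sigma_p ** 2)
            - 0.5)

same = kl_normal(0.0, 1.0, 0.0, 1.0)      # identical distributions -> 0
shifted = kl_normal(1.0, 1.0, 0.0, 1.0)   # unit mean shift -> 0.5
```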
Activations (ml_networks.jax.activations)¶
Activation ¶
Bases: Module
Generic activation function.
Source code in src/ml_networks/jax/activations.py
REReLU ¶
Bases: Module
Reparameterized ReLU activation function. Its backward pass is differentiable.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| reparametarize_fn | str | Reparameterization function. Default is "gelu". | 'gelu' |
|
References
https://openreview.net/forum?id=lNCnZwcH5Z
Examples:
>>> rerelu = REReLU()
>>> x = jnp.array([[1.0, -1.0, 0.5]])
>>> output = rerelu(x)
>>> output.shape
(1, 3)
Source code in src/ml_networks/jax/activations.py
SiGLU ¶
Bases: Module
SiGLU activation function.
This is equivalent to the SwiGLU (Swish-variant Gated Linear Unit) activation function.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| dim | int | Dimension along which to split the tensor. Default is -1. | -1 |
|
References
https://arxiv.org/abs/2102.11972
Examples:
Source code in src/ml_networks/jax/activations.py
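The computation can be sketched in a few lines of NumPy: split the input in two along the chosen axis and use one half to gate the other through SiLU, halving the feature dimension. Which half acts as the gate is an assumption here, not confirmed by the source:

```python
import numpy as np

def siglu(x, axis=-1):
    """SwiGLU-style gate: split in two, return a * silu(b)."""
    a, b = np.split(x, 2, axis=axis)
    return a * (b / (1.0 + np.exp(-b)))   # silu(b) = b * sigmoid(b)

x = np.ones((1, 32))
out = siglu(x)                            # feature dim halves: (1, 16)
```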
CRReLU ¶
Bases: Module
Correction Regularized ReLU activation function, a variant of the ReLU activation function.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| lr | float | Learning rate. Default is 0.01. | 0.01 |
|
References
https://openreview.net/forum?id=7TZYM6Hm9p
Examples:
>>> crrelu = CRReLU()
>>> x = jnp.array([[1.0, -1.0, 0.5]])
>>> output = crrelu(x)
>>> output.shape
(1, 3)
Source code in src/ml_networks/jax/activations.py
TanhExp ¶
UNet (ml_networks.jax.unet)¶
ConditionalUnet2d ¶
Bases: Module
Conditional UNet model (NHWC format).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| feature_dim | int | Dimension of the conditioning features. | required |
| obs_shape | tuple[int, int, int] | Observation shape (H, W, C) in NHWC format. | required |
| cfg | UNetConfig | UNet configuration. | required |
| rngs | Rngs | Random number generators. | required |
Source code in src/ml_networks/jax/unet.py
Attributes¶
final_conv1
instance-attribute
¶
final_conv2
instance-attribute
¶
mid_modules
instance-attribute
¶
mid_modules = List([ConditionalResidualBlock2d(mid_dim, mid_dim, cond_dim=feature_dim, conv_cfg=conv_cfg, cond_predict_scale=cond_pred_scale, rngs=rngs), Attention2d(mid_dim, nhead, rngs=rngs) if has_attn and nhead is not None else Identity(), ConditionalResidualBlock2d(mid_dim, mid_dim, cond_dim=feature_dim, conv_cfg=conv_cfg, cond_predict_scale=cond_pred_scale, rngs=rngs)])
Functions¶
__call__ ¶
Forward pass.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| base | Array | Input tensor of shape (B, H, W, C) in NHWC format. | required |
| cond | Array | Conditional tensor of shape (B, cond_dim). | required |

Returns:

| Type | Description |
|---|---|
| Array | Output tensor of shape (B, H, W, C). |
Source code in src/ml_networks/jax/unet.py
ConditionalUnet1d ¶
Bases: Module
Conditional 1D UNet model (NLC format).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| feature_dim | int | Dimension of the conditioning features. | required |
| obs_shape | tuple[int, int] | Observation shape (L, C) in NLC format. | required |
| cfg | UNetConfig | UNet configuration. | required |
| rngs | Rngs | Random number generators. | required |
Source code in src/ml_networks/jax/unet.py
Attributes¶
final_conv1
instance-attribute
¶
final_conv2
instance-attribute
¶
mid_modules
instance-attribute
¶
mid_modules = List([ConditionalResidualBlock1d(mid_dim, mid_dim, cond_dim=feature_dim, conv_cfg=conv_cfg, cond_predict_scale=cond_pred_scale, rngs=rngs) if not use_hypernet else HyperConditionalResidualBlock1d(mid_dim, mid_dim, cond_dim=feature_dim, conv_cfg=conv_cfg, hyper_mlp_cfg=hyper_mlp_cfg, rngs=rngs), Attention1d(mid_dim, nhead, rngs=rngs) if has_attn and nhead is not None else Identity(), ConditionalResidualBlock1d(mid_dim, mid_dim, cond_dim=feature_dim, conv_cfg=conv_cfg, cond_predict_scale=cond_pred_scale, rngs=rngs) if not use_hypernet else HyperConditionalResidualBlock1d(mid_dim, mid_dim, cond_dim=feature_dim, conv_cfg=conv_cfg, hyper_mlp_cfg=hyper_mlp_cfg, rngs=rngs)])
Functions¶
__call__ ¶
Forward pass.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| base | Array | Input tensor of shape (B, L, C) in NLC format. | required |
| cond | Array | Conditional tensor of shape (B, cond_dim). | required |

Returns:

| Type | Description |
|---|---|
| Array | Output tensor of shape (B, L, C). |
Source code in src/ml_networks/jax/unet.py
Utilities (ml_networks.jax.jax_utils)¶
get_optimizer ¶
Get optimizer from optax.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| name | str | Optimizer name (e.g. "adam", "sgd", "adamw", "lamb", "rmsprop"). | required |
| kwargs | dict | Optimizer arguments (e.g. learning_rate=0.01). | {} |

Returns:

| Type | Description |
|---|---|
| GradientTransformation | The optax optimizer. |
Examples:
Source code in src/ml_networks/jax/jax_utils.py
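A function like this typically resolves the name against the optax module by attribute lookup. A framework-free sketch of that dispatch pattern, where the registry object is a stand-in for optax (not the library's actual code):

```python
from types import SimpleNamespace

# Stand-in mimicking optax's constructor-style namespace (illustrative only).
optax_like = SimpleNamespace(
    adam=lambda learning_rate: ("adam", learning_rate),
    sgd=lambda learning_rate: ("sgd", learning_rate),
)

def get_optimizer(name, **kwargs):
    """Look up an optimizer constructor by name and apply keyword arguments."""
    try:
        ctor = getattr(optax_like, name)
    except AttributeError as e:
        raise ValueError(f"Unknown optimizer: {name!r}") from e
    return ctor(**kwargs)

opt = get_optimizer("adam", learning_rate=1e-3)
```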
jax_fix_seed ¶
Fix the random seed.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| seed | int | Random seed. | 42 |

Returns:

| Type | Description |
|---|---|
| Array | JAX PRNG key. |
Source code in src/ml_networks/jax/jax_utils.py
MinMaxNormalize ¶
SoftmaxTransformation ¶
Softmax transformation class.
Source code in src/ml_networks/jax/jax_utils.py
Attributes¶
Functions¶
__call__ ¶
get_transformed_dim ¶
inverse ¶
Inverse of the SoftmaxTransformation.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input tensor. | required |

Returns:

| Type | Description |
|---|---|
| Array | Output tensor. |
Source code in src/ml_networks/jax/jax_utils.py
transform ¶
Apply the SoftmaxTransformation.

Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | Array | Input tensor. | required |

Returns:

| Type | Description |
|---|---|
| Array | Output tensor. |
Examples:
>>> trans = SoftmaxTransformation(SoftmaxTransConfig(vector=16, sigma=0.01, n_ignore=1, min=-1.0, max=1.0))
>>> x = jnp.ones((2, 3, 4))
>>> transformed = trans(x)
>>> transformed.shape
(2, 3, 49)
>>> trans = SoftmaxTransformation(SoftmaxTransConfig(vector=11, sigma=0.05, n_ignore=0, min=-1.0, max=1.0))
>>> x = jnp.ones((2, 3, 4))
>>> transformed = trans(x)
>>> transformed.shape
(2, 3, 44)
Source code in src/ml_networks/jax/jax_utils.py
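The doctest shapes suggest each input scalar (except the final n_ignore dimensions, which pass through unchanged) is encoded as a softmax over `vector` bins spanning [min, max]: the (2, 3, 4) input with vector=16 and n_ignore=1 becomes 3 × 16 + 1 = 49 features. A NumPy sketch under that assumption; the exact bin placement and temperature handling here are guesses, not the library's algorithm:

```python
import numpy as np

def softmax_transform(x, vector=16, sigma=0.01, lo=-1.0, hi=1.0, n_ignore=1):
    """Encode each scalar (except the last n_ignore dims) as a softmax over bins."""
    split = x.shape[-1] - n_ignore
    head, tail = x[..., :split], x[..., split:]
    centers = np.linspace(lo, hi, vector)                  # bin centers on [lo, hi]
    # Gaussian-shaped logits peaked at the nearest bin center, width sigma.
    logits = -((head[..., None] - centers) ** 2) / (2.0 * sigma ** 2)
    logits -= logits.max(axis=-1, keepdims=True)
    probs = np.exp(logits)
    probs /= probs.sum(axis=-1, keepdims=True)
    flat = probs.reshape(*head.shape[:-1], split * vector)
    return np.concatenate([flat, tail], axis=-1)

x = np.zeros((2, 3, 4))
y = softmax_transform(x)     # (2, 3, 3 * 16 + 1) == (2, 3, 49)
```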
Miscellaneous¶
HyperNet ¶
Bases: Module, HyperNetMixin
A hypernetwork that generates weights for a target network.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| input_dim | int | Dimension of the input. | required |
| output_shapes | dict[str, Shape] | Shapes of the primary network weights being predicted. | required |
| fc_cfg | MLPConfig \| None | Configuration for the MLP backbone. | None |
| encoding | InputMode | The input encoding mode. Default is None. | None |
| rngs | Rngs | Random number generators. | required |
Source code in src/ml_networks/jax/hypernetworks.py
Attributes¶
Functions¶
__call__ ¶
Forward pass.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| inputs | Array | Input tensor. | required |

Returns:

| Type | Description |
|---|---|
| dict[str, Array] | Dictionary of output tensors. |
Source code in src/ml_networks/jax/hypernetworks.py
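The defining step of a hypernetwork is slicing one flat prediction vector into the target network's named weight tensors. A NumPy sketch of that split (the MLP that produces the flat vector is omitted; dictionary insertion order determines the slicing order here, which is an assumption):

```python
import numpy as np

def split_into_shapes(flat, output_shapes):
    """Slice a flat vector into named tensors matching `output_shapes`."""
    out, offset = {}, 0
    for name, shape in output_shapes.items():
        size = int(np.prod(shape))
        out[name] = flat[offset : offset + size].reshape(shape)
        offset += size
    return out

shapes = {"kernel": (4, 3), "bias": (3,)}
flat = np.arange(15, dtype=np.float32)       # 4*3 + 3 = 15 predicted values
params = split_into_shapes(flat, shapes)
```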
ContrastiveLearningLoss ¶
Bases: Module
Contrastive learning module.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| dim_input1 | int | Dimension of first input. | required |
| dim_input2 | int | Dimension of second input. | required |
| cfg | ContrastiveLearningConfig | Configuration for contrastive learning. | required |
| rngs | Rngs | Random number generators. | required |
Source code in src/ml_networks/jax/contrastive.py
Attributes¶
eval_func2
instance-attribute
¶
Functions¶
calc_nce ¶
Calculate the Noise Contrastive Estimation (NCE) loss.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| feature1 | Array | First input tensor of shape (*, dim_input1). | required |
| feature2 | Array | Second input tensor of shape (*, dim_input2). | required |
| return_emb | bool | Whether to return embeddings. Default is False. | False |

Returns:

| Type | Description |
|---|---|
| dict or tuple | Loss dictionary, optionally with embeddings. |
Source code in src/ml_networks/jax/contrastive.py
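The NCE (InfoNCE) objective treats matching rows of the two feature batches as positives and all other rows as negatives. A NumPy sketch of the loss itself; the module additionally maps both features through learned evaluation heads before comparison:

```python
import numpy as np

def info_nce(z1, z2):
    """InfoNCE over a batch: diagonal pairs are positives."""
    z1 = z1 / np.linalg.norm(z1, axis=-1, keepdims=True)
    z2 = z2 / np.linalg.norm(z2, axis=-1, keepdims=True)
    logits = z1 @ z2.T                                    # (B, B) cosine similarities
    logits -= logits.max(axis=-1, keepdims=True)          # stabilized softmax
    log_probs = logits - np.log(np.exp(logits).sum(axis=-1, keepdims=True))
    return -np.mean(np.diag(log_probs))                   # -log p(positive)

rng = np.random.default_rng(0)
z = rng.normal(size=(8, 16))
loss = info_nce(z, z)      # perfectly aligned features -> loss below log(B)
```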
calc_sigmoid ¶
Calculate the Sigmoid loss for contrastive learning.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| feature1 | Array | First input tensor of shape (*, dim_input1). | required |
| feature2 | Array | Second input tensor of shape (*, dim_input2). | required |
| return_emb | bool | Whether to return embeddings. Default is False. | False |
| temperature | float | Temperature. Default is 0.1. | 0.1 |
| bias | float | Bias. Default is 0.0. | 0.0 |

Returns:

| Type | Description |
|---|---|
| dict or tuple | Loss dictionary, optionally with embeddings. |
Source code in src/ml_networks/jax/contrastive.py
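Unlike InfoNCE, the sigmoid formulation (as in SigLIP) treats every pair in the batch as an independent binary classification: diagonal pairs get label +1, all others -1. A NumPy sketch of the loss under that convention (sign and reduction choices are assumptions, not confirmed by the source):

```python
import numpy as np

def sigmoid_contrastive(z1, z2, temperature=0.1, bias=0.0):
    """SigLIP-style pairwise sigmoid loss over a (B, B) similarity matrix."""
    z1 = z1 / np.linalg.norm(z1, axis=-1, keepdims=True)
    z2 = z2 / np.linalg.norm(z2, axis=-1, keepdims=True)
    logits = z1 @ z2.T / temperature + bias
    labels = 2.0 * np.eye(len(z1)) - 1.0          # +1 on diagonal, -1 elsewhere
    # -log sigmoid(labels * logits), computed stably via logaddexp.
    return np.mean(np.logaddexp(0.0, -labels * logits))

rng = np.random.default_rng(0)
z1 = rng.normal(size=(4, 8))
z2 = rng.normal(size=(4, 8))
loss = sigmoid_contrastive(z1, z2)
```

Because each pair is scored independently, this loss avoids the batch-wide softmax normalization of InfoNCE.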
calc_timeseries_nce ¶
calc_timeseries_nce(feature1, feature2, positive_range_self=0, positive_range_tgt=0, return_emb=False)
Calculate the NCE loss for time series data.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| feature1 | Array | First input tensor of shape (*batch, length, dim_input1). | required |
| feature2 | Array | Second input tensor of shape (*batch, length, dim_input2). | required |
| positive_range_self | int | Range for self-positive samples. Default is 0. | 0 |
| positive_range_tgt | int | Range for target-positive samples. Default is 0. | 0 |
| return_emb | bool | Whether to return embeddings. Default is False. | False |

Returns:

| Type | Description |
|---|---|
| dict or tuple | Loss dictionary, optionally with embeddings. |
Source code in src/ml_networks/jax/contrastive.py
BaseModule ¶
Bases: Module
Base module for JAX/Flax NNX.
Functions¶
freeze_biases ¶
Freeze all bias parameters.
Source code in src/ml_networks/jax/base.py
freeze_weights ¶
Freeze all weight parameters (kernel in Flax).
Source code in src/ml_networks/jax/base.py
unfreeze_biases ¶
Unfreeze all bias parameters.
Source code in src/ml_networks/jax/base.py
unfreeze_weights ¶
Unfreeze all weight parameters (kernel in Flax).