Loss Functions
This module provides loss functions. They are available in both ml_networks.torch.loss (PyTorch) and ml_networks.jax.loss (JAX).
focal_loss
Focal loss function, mainly for multi-class classification.
Reference
Focal Loss for Dense Object Detection https://arxiv.org/abs/1708.02002
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| prediction | Tensor | The predicted tensor (logits, before softmax). | required |
| target | Tensor | The target tensor. | required |
| gamma | float | The focusing parameter gamma. | 2.0 |
| sum_dim | int | The dimension over which the loss is summed. | -1 |
Returns:
| Type | Description |
|---|---|
| Tensor | The focal loss. |
Source code in src/ml_networks/torch/loss.py
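As a rough illustration, the multi-class focal loss can be sketched in NumPy as follows. This is a minimal sketch, not the library's implementation: it assumes `target` is one-hot (or a probability distribution) along `sum_dim` and uses the formulation FL = -sum(t * (1 - p)^gamma * log p) from the reference paper. The names `log_softmax` and `focal_loss_sketch` are hypothetical.

```python
import numpy as np


def log_softmax(x: np.ndarray, axis: int) -> np.ndarray:
    # Numerically stable log-softmax: subtract the max before exponentiating.
    shifted = x - x.max(axis=axis, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=axis, keepdims=True))


def focal_loss_sketch(prediction: np.ndarray, target: np.ndarray,
                      gamma: float = 2.0, sum_dim: int = -1) -> np.ndarray:
    """Hypothetical multi-class focal loss; `prediction` holds logits (pre-softmax)."""
    log_p = log_softmax(prediction, axis=sum_dim)
    p = np.exp(log_p)
    # (1 - p)^gamma down-weights classes that are already predicted confidently.
    loss = -target * (1.0 - p) ** gamma * log_p
    return loss.sum(axis=sum_dim)


logits = np.array([[2.0, 0.5, -1.0]])
onehot = np.array([[1.0, 0.0, 0.0]])
print(focal_loss_sketch(logits, onehot))             # focal loss per sample
print(focal_loss_sketch(logits, onehot, gamma=0.0))  # gamma = 0 gives plain cross-entropy
```

With gamma = 0 the expression reduces to ordinary cross-entropy; larger gamma shrinks the contribution of well-classified examples.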
binary_focal_loss
Binary focal loss function, mainly for binary classification.
Reference
Focal Loss for Dense Object Detection https://arxiv.org/abs/1708.02002
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| prediction | Tensor | The predicted tensor (logits, before sigmoid). | required |
| target | Tensor | The target tensor. | required |
| gamma | float | The focusing parameter gamma. | 2.0 |
| sum_dim | int | The dimension over which the loss is summed. | -1 |
Returns:
| Type | Description |
|---|---|
| Tensor | The binary focal loss. |
Source code in src/ml_networks/torch/loss.py
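A minimal NumPy sketch of the binary variant, assuming `target` contains 0/1 labels (or soft labels in [0, 1]) and the per-element formulation from the reference paper, -[t(1 - p)^gamma log p + (1 - t) p^gamma log(1 - p)]. The function name `binary_focal_loss_sketch` is hypothetical and the library's exact reduction may differ.

```python
import numpy as np


def binary_focal_loss_sketch(prediction, target, gamma=2.0, sum_dim=-1):
    """Hypothetical binary focal loss; `prediction` holds logits (pre-sigmoid)."""
    prediction = np.asarray(prediction, dtype=float)
    target = np.asarray(target, dtype=float)
    # Stable log-sigmoid: log s(x) = -log(1 + e^-x), log(1 - s(x)) = -log(1 + e^x).
    log_p = -np.logaddexp(0.0, -prediction)
    log_1mp = -np.logaddexp(0.0, prediction)
    p = np.exp(log_p)
    # Positive and negative terms are each scaled by their own focusing factor.
    loss = -(target * (1.0 - p) ** gamma * log_p
             + (1.0 - target) * p ** gamma * log_1mp)
    return loss.sum(axis=sum_dim)


print(binary_focal_loss_sketch([0.0, 3.0], [1.0, 1.0]))
```

Working in log space via `np.logaddexp` avoids taking the log of a sigmoid that has underflowed to 0 or 1 for large logits.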
charbonnier
Charbonnier loss function.
Reference
A General and Adaptive Robust Loss Function http://arxiv.org/abs/1701.03077
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| prediction | Tensor | The predicted tensor. | required |
| target | Tensor | The target tensor. | required |
| epsilon | float | A small constant that avoids division by zero. | 0.001 |
| alpha | float | The alpha parameter. | 1 |
| sum_dim | int \| list[int] \| tuple[int, ...] \| None | The dimension(s) over which the loss is summed. If None, sums over [-1, -2, -3]. | None |
Returns:
| Type | Description |
|---|---|
| Tensor | The Charbonnier loss. |
Source code in src/ml_networks/torch/loss.py
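A minimal NumPy sketch of a generalized Charbonnier loss. The exact formula is an assumption: it uses ((x - y)^2 + epsilon^2)^(alpha / 2), which reduces to the classic sqrt((x - y)^2 + epsilon^2) at alpha = 1; the library may parameterize alpha differently. The function name `charbonnier_sketch` is hypothetical.

```python
import numpy as np


def charbonnier_sketch(prediction, target, epsilon=1e-3, alpha=1.0, sum_dim=None):
    """Hypothetical generalized Charbonnier loss: ((x - y)^2 + eps^2)^(alpha / 2)."""
    if sum_dim is None:
        # Mirror the documented default: sum over the last three dimensions.
        sum_dim = (-1, -2, -3)
    diff = np.asarray(prediction, dtype=float) - np.asarray(target, dtype=float)
    # epsilon keeps the loss (and its gradient) smooth when diff == 0.
    loss = (diff ** 2 + epsilon ** 2) ** (alpha / 2.0)
    return loss.sum(axis=sum_dim)


images_a = np.zeros((2, 3, 4, 4))
images_b = np.full((2, 3, 4, 4), 0.5)
print(charbonnier_sketch(images_a, images_b))  # one value per image in the batch
```

For residuals much larger than epsilon this behaves like an L1 loss, while near zero it is smooth like an L2 loss.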
FocalFrequencyLoss
FocalFrequencyLoss(loss_weight=1.0, alpha=1.0, patch_factor=1, ave_spectrum=False, log_matrix=False, batch_matrix=False)
A torch.nn.Module that implements focal frequency loss, a frequency-domain loss function for optimizing generative models.
Reference
Focal Frequency Loss for Image Reconstruction and Synthesis. In ICCV 2021. https://arxiv.org/pdf/2012.12821.pdf
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| loss_weight | float | Weight of the focal frequency loss. | 1.0 |
| alpha | float | Scaling factor alpha of the spectrum weight matrix, for flexibility. | 1.0 |
| patch_factor | int | Factor by which images are cropped into patches for patch-based focal frequency loss. | 1 |
| ave_spectrum | bool | Whether to use the minibatch-average spectrum. | False |
| log_matrix | bool | Whether to adjust the spectrum weight matrix by a logarithm. | False |
| batch_matrix | bool | Whether to compute the spectrum weight matrix using batch-based statistics. | False |
Source code in src/ml_networks/torch/loss.py
Functions
__call__
Computes the focal frequency loss (forward pass).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| pred | Tensor | Predicted tensor of shape (N, C, H, W). | required |
| target | Tensor | Target tensor of shape (N, C, H, W). | required |
| matrix | Tensor \| None | Precomputed spectrum weight matrix. If None, the matrix is computed online (dynamically). | None |
| mean_batch | bool | Whether to average the loss over the batch dimension. | True |
Returns:
| Type | Description |
|---|---|
| Tensor | The focal frequency loss. |
Source code in src/ml_networks/torch/loss.py
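The core computation of focal frequency loss can be sketched in NumPy as below. This is a simplified sketch following the reference paper's formulation, not the library's code: it assumes patch_factor=1, ave_spectrum=False, batch_matrix=False, no precomputed `matrix`, and a plain mean as the reduction (the library's `mean_batch` handling may differ). The function name `focal_frequency_loss_sketch` is hypothetical.

```python
import numpy as np


def focal_frequency_loss_sketch(pred, target, alpha=1.0, log_matrix=False, loss_weight=1.0):
    """Hypothetical FFL for (N, C, H, W) arrays with patch_factor=1."""
    # Orthonormal 2D DFT over the spatial dimensions (H, W).
    pred_freq = np.fft.fft2(pred, norm="ortho")
    target_freq = np.fft.fft2(target, norm="ortho")
    # Squared distance between complex frequency coefficients.
    freq_distance = np.abs(pred_freq - target_freq) ** 2
    # Spectrum weight matrix: larger for frequencies that are reconstructed badly.
    weight = np.sqrt(freq_distance) ** alpha
    if log_matrix:
        weight = np.log(weight + 1.0)
    # Normalize per image and channel so the weights lie in [0, 1].
    max_w = weight.max(axis=(-2, -1), keepdims=True)
    weight = np.divide(weight, max_w, out=np.zeros_like(weight), where=max_w > 0)
    # In an autodiff framework the weight matrix is detached (no gradient through it).
    return loss_weight * np.mean(weight * freq_distance)


rng = np.random.default_rng(0)
fake, real = rng.normal(size=(2, 3, 8, 8)), rng.normal(size=(2, 3, 8, 8))
print(focal_frequency_loss_sketch(fake, real))
```

Unlike a pixel-wise MSE, the dynamic weight matrix concentrates the loss on the frequencies the model currently reconstructs worst.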
loss_formulation
Source code in src/ml_networks/torch/loss.py
tensor2freq
Source code in src/ml_networks/torch/loss.py
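The helpers loss_formulation and tensor2freq have no descriptions here. Judging by their names, which match the reference focal frequency loss implementation, tensor2freq presumably splits images into patch_factor x patch_factor patches and moves each patch to the frequency domain. The following NumPy sketch illustrates that step under this assumption; `tensor2freq_sketch` is a hypothetical name, not the library's method.

```python
import numpy as np


def tensor2freq_sketch(x, patch_factor=1):
    """Hypothetical: split (N, C, H, W) into patches and return their 2D spectra."""
    n, c, h, w = x.shape
    if h % patch_factor or w % patch_factor:
        raise ValueError("H and W must be divisible by patch_factor")
    ph, pw = h // patch_factor, w // patch_factor
    # Collect the patch_factor**2 non-overlapping patches along a new axis.
    patches = [x[:, :, i * ph:(i + 1) * ph, j * pw:(j + 1) * pw]
               for i in range(patch_factor) for j in range(patch_factor)]
    y = np.stack(patches, axis=1)  # shape (N, P, C, ph, pw)
    freq = np.fft.fft2(y, norm="ortho")
    # Store real and imaginary parts in a trailing axis of size 2.
    return np.stack([freq.real, freq.imag], axis=-1)


images = np.random.default_rng(0).normal(size=(2, 3, 8, 8))
print(tensor2freq_sketch(images, patch_factor=2).shape)
```

Because the orthonormal DFT is unitary, the total energy of the spectra equals that of the input, so the frequency-domain distances stay on the same scale as pixel-space ones.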
kl_divergence
Computes the KL divergence between two StochState distributions in ml-networks.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| posterior | StochState | The posterior distribution. | required |
| prior | StochState | The prior distribution. | required |
Returns:
| Type | Description |
|---|---|
| Tensor | The KL divergence between the two distributions. |
Source code in src/ml_networks/torch/loss.py
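The internal structure of StochState is not described here. As an illustration only, the sketch below assumes both distributions are diagonal Gaussians given by a mean and a standard deviation, and computes KL(posterior || prior) in closed form, which is the usual direction in world models; the library's direction and parameterization are not confirmed by this page. The function name `gaussian_kl_sketch` is hypothetical.

```python
import numpy as np


def gaussian_kl_sketch(mean_q, std_q, mean_p, std_p, sum_dim=-1):
    """Hypothetical KL(q || p) between diagonal Gaussians, summed over `sum_dim`."""
    mean_q, std_q, mean_p, std_p = (np.asarray(a, dtype=float)
                                    for a in (mean_q, std_q, mean_p, std_p))
    var_q, var_p = std_q ** 2, std_p ** 2
    # Closed form: log(sp / sq) + (sq^2 + (mq - mp)^2) / (2 sp^2) - 1/2 per dimension.
    kl = (np.log(std_p / std_q)
          + (var_q + (mean_q - mean_p) ** 2) / (2.0 * var_p)
          - 0.5)
    return kl.sum(axis=sum_dim)


# posterior q = N(1, 1), prior p = N(0, 1)
print(gaussian_kl_sketch([1.0], [1.0], [0.0], [1.0]))
```

The KL divergence is asymmetric, so swapping posterior and prior generally changes the value.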
kl_balancing
KL balancing loss for StochState distributions in ml-networks.
Reference
Mastering Atari with Discrete World Models. In ICLR 2021. https://arxiv.org/abs/2010.02193
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| posterior | StochState | The posterior distribution. | required |
| prior | StochState | The prior distribution. | required |
| weight | float | Weight of the prior-side gradient in KL balancing. | 0.8 |
Returns:
| Type | Description |
|---|---|
| Tensor | The KL balancing loss. |
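In the reference paper, KL balancing computes weight * KL(sg(posterior) || prior) + (1 - weight) * KL(posterior || sg(prior)), where sg is stop-gradient. The forward value equals the plain KL divergence; only the gradients are split, so with weight = 0.8 the prior is pulled toward the posterior faster than the posterior is regularized toward the prior. Since NumPy has no autodiff, the sketch below makes the split explicit with hand-derived gradients, assuming Gaussians that share a fixed standard deviation. In PyTorch the stop-gradient would be `.detach()`. The function name `kl_balancing_sketch` is hypothetical.

```python
import numpy as np


def kl_balancing_sketch(post_mean, prior_mean, std=1.0, weight=0.8):
    """Hypothetical KL balancing for Gaussians sharing a fixed std.

    Returns the loss value and the balanced gradients w.r.t. both means.
    """
    post_mean = np.asarray(post_mean, dtype=float)
    prior_mean = np.asarray(prior_mean, dtype=float)
    diff = post_mean - prior_mean
    # With equal stds, KL(post || prior) = sum((mu_post - mu_prior)^2) / (2 std^2).
    kl = np.sum(diff ** 2) / (2.0 * std ** 2)
    # loss = weight * KL(sg(post) || prior) + (1 - weight) * KL(post || sg(prior)):
    # the value is just KL, but each term passes gradients to one side only.
    grad_prior = weight * (-diff / std ** 2)
    grad_post = (1.0 - weight) * (diff / std ** 2)
    return kl, grad_post, grad_prior


loss, g_post, g_prior = kl_balancing_sketch([1.0, -0.5], [0.0, 0.0])
print(loss, g_post, g_prior)
```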