Distributions
Provides distribution-related classes and functions.
They are available in both ml_networks.torch.distributions (PyTorch) and ml_networks.jax.distributions (JAX).
Distribution
Bases: Module
A module that computes a posterior distribution (a StochState) from an input tensor.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_dim | int | Input dimension. In the "normal" example, an input with 2 * in_dim features yields a mean and std of in_dim features each. | required |
| dist | Literal['normal', 'categorical', 'bernoulli'] | Distribution type. | required |
| n_groups | int | Number of groups. Used only by the categorical and Bernoulli distributions. | 1 |
| spherical | bool | Whether to project samples onto the unit sphere. Used only by the categorical and Bernoulli distributions. If True and dist=="categorical", samples are mapped from {0, 1} to {-1, 1}; if True and dist=="bernoulli", samples are projected from {0, 1} onto the unit sphere. See https://arxiv.org/abs/2406.07548. | False |
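The categorical example (in_dim=10, n_groups=2 giving logits of shape (2, 2, 5) and stoch of shape (2, 10)) suggests that the last dimension is split into n_groups groups of in_dim // n_groups classes. The function grouped_onehot_sample below is a hypothetical sketch of that reshaping plus the {0, 1} → {-1, 1} map described for spherical; it omits the straight-through gradient and is not the library's code.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Categorical


def grouped_onehot_sample(x: torch.Tensor, n_groups: int, spherical: bool = False):
    """Hypothetical sketch: one categorical per group over in_dim // n_groups classes."""
    *batch, in_dim = x.shape
    n_classes = in_dim // n_groups
    logits = x.reshape(*batch, n_groups, n_classes)      # (..., n_groups, n_classes)
    index = Categorical(logits=logits).sample()          # (..., n_groups)
    sample = F.one_hot(index, n_classes).to(x.dtype)     # one-hot per group, values in {0, 1}
    if spherical:
        sample = 2 * sample - 1                          # assumption: {0, 1} -> {-1, 1} via 2x - 1
    return logits, sample.reshape(*batch, in_dim)        # flatten the groups back to in_dim


logits, sample = grouped_onehot_sample(torch.randn(2, 10), n_groups=2, spherical=True)
print(logits.shape, sample.shape)  # torch.Size([2, 2, 5]) torch.Size([2, 10])
```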
Examples:
>>> import torch
>>> from ml_networks.torch.distributions import Distribution
>>> dist = Distribution(10, "normal")
>>> data = torch.randn(2, 20)
>>> posterior = dist(data)
>>> posterior.__class__.__name__
'NormalStoch'
>>> posterior.shape
NormalShape(mean=torch.Size([2, 10]), std=torch.Size([2, 10]), stoch=torch.Size([2, 10]))
>>> dist = Distribution(10, "categorical", n_groups=2)
>>> data = torch.randn(2, 10)
>>> posterior = dist(data)
>>> posterior.__class__.__name__
'CategoricalStoch'
>>> posterior.shape
CategoricalShape(logits=torch.Size([2, 2, 5]), probs=torch.Size([2, 2, 5]), stoch=torch.Size([2, 10]))
>>> dist = Distribution(10, "bernoulli", n_groups=2)
>>> data = torch.randn(2, 10)
>>> posterior = dist(data)
>>> posterior.__class__.__name__
'BernoulliStoch'
>>> posterior.shape
BernoulliShape(logits=torch.Size([2, 2, 5]), probs=torch.Size([2, 2, 5]), stoch=torch.Size([2, 10]))
Source code in src/ml_networks/torch/distributions.py
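The shapes in the "normal" example (a (2, 20) input giving a (2, 10) mean, std and stoch) are consistent with splitting the input in half into a mean and a standard-deviation parameter. The sketch below is a hypothetical version of such a head; the use of chunk, softplus and the min_std floor are assumptions, not the library's confirmed implementation.

```python
import torch
import torch.nn.functional as F
from torch.distributions import Normal


def normal_head(x: torch.Tensor, min_std: float = 1e-4):
    """Hypothetical normal head: split the last dimension into a mean and a raw std."""
    mean, raw_std = x.chunk(2, dim=-1)       # (..., 2 * D) -> two tensors of shape (..., D)
    std = F.softplus(raw_std) + min_std      # assumption: softplus plus a floor keeps std > 0
    stoch = Normal(mean, std).rsample()      # reparameterized (differentiable) sample
    return mean, std, stoch


mean, std, stoch = normal_head(torch.randn(2, 20))
print(mean.shape, std.shape, stoch.shape)  # torch.Size([2, 10]) torch.Size([2, 10]) torch.Size([2, 10])
```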
Functions
bernoulli
categorical
deterministic_onehot
Compute the one-hot vector by argmax.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| input | Tensor | Input tensor. | required |
Returns:
| Type | Description |
|---|---|
| Tensor | One-hot vector. |
Examples:
>>> input = torch.arange(6).reshape(2, 3) / 5.0
>>> dist = Distribution(3, "categorical")
>>> onehot = dist.deterministic_onehot(input)
>>> onehot
tensor([[0., 0., 1.],
        [0., 0., 1.]])
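An argmax one-hot can be written with standard PyTorch calls (argmax and torch.nn.functional.one_hot). This minimal sketch reproduces the example output above; argmax_onehot is a hypothetical name, and whether the library uses exactly these calls is an assumption.

```python
import torch
import torch.nn.functional as F


def argmax_onehot(x: torch.Tensor) -> torch.Tensor:
    """One-hot vector of the argmax along the last dimension (no gradient flows through argmax)."""
    index = x.argmax(dim=-1)
    return F.one_hot(index, num_classes=x.shape[-1]).to(x.dtype)


x = torch.arange(6).reshape(2, 3) / 5.0
print(argmax_onehot(x))
```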
forward
Compute the posterior distribution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Tensor | Input tensor. | required |
| deterministic | bool | Whether to use the deterministic mode. If True and dist=="normal", the mean is returned; if True and dist=="categorical", the one-hot vector computed by argmax is returned; if True and dist=="bernoulli", 1 is returned where x > 0.5 and 0 where x <= 0.5. | False |
| inv_tmp | float | Inverse temperature. Used only by the categorical and Bernoulli distributions. | 1.0 |
Returns:
| Type | Description |
|---|---|
| StochState | Posterior distribution. |
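The effect of inv_tmp and of the Bernoulli deterministic mode can be illustrated with a small sketch. It assumes that the inverse temperature multiplies the logits before the softmax, and it applies the 0.5 threshold to probabilities; tempered_probs and bernoulli_mode are hypothetical names, not part of ml_networks.

```python
import torch


def tempered_probs(logits: torch.Tensor, inv_tmp: float = 1.0) -> torch.Tensor:
    """Hypothetical: multiply logits by the inverse temperature, then take the softmax."""
    return (logits * inv_tmp).softmax(dim=-1)


def bernoulli_mode(probs: torch.Tensor) -> torch.Tensor:
    """Deterministic Bernoulli output: 1 where probs > 0.5, else 0."""
    return (probs > 0.5).to(probs.dtype)


logits = torch.tensor([1.0, 2.0, 3.0])
for inv_tmp in (0.1, 1.0, 10.0):
    # A larger inverse temperature sharpens the distribution, so the peak probability grows.
    print(inv_tmp, tempered_probs(logits, inv_tmp).max().item())
print(bernoulli_mode(torch.tensor([0.2, 0.5, 0.9])))  # tensor([0., 0., 1.])
```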
normal
NormalStoch (dataclass)
Parameters of a normal distribution and its stochastic sample.
Attributes:
| Name | Type | Description |
|---|---|---|
| mean | Tensor | Mean of the normal distribution. |
| std | Tensor | Standard deviation of the normal distribution. |
| stoch | Tensor | Sample from the normal distribution, drawn with the reparameterization trick. |
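The reparameterization trick writes the sample as a deterministic function of the parameters and independent noise, stoch = mean + std * eps with eps drawn from a standard normal, so gradients reach mean and std. A minimal standalone illustration in plain PyTorch (torch.distributions.Normal(mean, std).rsample() does the same thing); it is not taken from the library source.

```python
import torch

mean = torch.zeros(3, requires_grad=True)
std = torch.ones(3, requires_grad=True)

eps = torch.randn(3)          # noise drawn independently of the parameters
stoch = mean + std * eps      # reparameterized sample, differentiable w.r.t. mean and std
stoch.sum().backward()

print(mean.grad)                      # tensor([1., 1., 1.]): d(sum)/d(mean) = 1
print(torch.allclose(std.grad, eps))  # True: d(sum)/d(std) = eps
```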
Functions
__getattr__
Calling a torch.Tensor method on this object applies it to each member.
Example: normal.flatten() → NormalStoch(mean.flatten(), std.flatten(), stoch.flatten()).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str | Method name. | required |
Returns:
| Type | Description |
|---|---|
| callable | A function that applies the torch.Tensor method to each member. |
Raises:
| Type | Description |
|---|---|
| AttributeError | If the given name is not a torch.Tensor method. |
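The delegation described for __getattr__, together with index access via __getitem__, can be sketched with a small dataclass. PairStoch is a hypothetical two-field class, not part of ml_networks, and checking callable(getattr(torch.Tensor, name, None)) is only one possible way to decide whether a name is a Tensor method.

```python
from dataclasses import dataclass, fields

import torch


@dataclass
class PairStoch:
    """Hypothetical two-field container; not part of ml_networks."""

    mean: torch.Tensor
    std: torch.Tensor

    def __getattr__(self, name: str):
        # Python calls __getattr__ only when normal lookup fails (fields and methods are found first).
        if not callable(getattr(torch.Tensor, name, None)):
            raise AttributeError(f"{type(self).__name__!r} object has no attribute {name!r}")

        def apply_to_fields(*args, **kwargs):
            # Call the same Tensor method on every field and rebuild the container.
            return type(self)(*(getattr(getattr(self, f.name), name)(*args, **kwargs) for f in fields(self)))

        return apply_to_fields

    def __getitem__(self, idx):
        # Index every field with the same idx (int, slice, or tuple).
        return type(self)(*(getattr(self, f.name)[idx] for f in fields(self)))


p = PairStoch(torch.zeros(2, 3), torch.ones(2, 3))
flat = p.flatten()
print(flat.mean.shape, flat.std.shape)  # torch.Size([6]) torch.Size([6])
print(p[0].mean.shape)                  # torch.Size([3])
```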
__getitem__
Index access.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| idx | int or slice or tuple | Index. | required |
Returns:
| Type | Description |
|---|---|
| NormalStoch | The NormalStoch corresponding to the specified index. |
__len__
__post_init__
Post-initialization processing.
Raises:
| Type | Description |
|---|---|
| ValueError | |
get_distribution
save
Save the parameters of the normal distribution to the specified path.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| path | str | Path to save the parameters. | required |
squeeze
Squeeze the parameters of the normal distribution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dim | int | Dimension to squeeze. | required |
Returns:
| Type | Description |
|---|---|
| NormalStoch | Squeezed normal distribution. |
unsqueeze
Unsqueeze the parameters of the normal distribution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dim | int | Dimension to unsqueeze. | required |
Returns:
| Type | Description |
|---|---|
| NormalStoch | Unsqueezed normal distribution. |
CategoricalStoch (dataclass)
Parameters of a categorical distribution and its stochastic sample.
Attributes:
| Name | Type | Description |
|---|---|---|
| logits | Tensor | Logits of the categorical distribution. |
| probs | Tensor | Probabilities of the categorical distribution. |
| stoch | Tensor | Sample from the categorical distribution, drawn with the straight-through estimator. |
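A straight-through estimator keeps a hard one-hot sample in the forward pass while letting gradients flow through probs. The expression onehot + (probs - probs.detach()) is a common way to write this in PyTorch; whether CategoricalStoch builds stoch exactly this way is an assumption.

```python
import torch
from torch.distributions import OneHotCategorical

logits = torch.randn(4, 5, requires_grad=True)
probs = logits.softmax(dim=-1)
onehot = OneHotCategorical(probs=probs).sample()   # hard one-hot sample; sample() carries no gradient
stoch = onehot + (probs - probs.detach())         # forward value: onehot; backward: gradient of probs

print(torch.equal(stoch, onehot))  # True: the added term is exactly zero in the forward pass
(stoch * torch.arange(5.0)).sum().backward()
print(logits.grad is not None)     # True: gradients reach the logits through probs
```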
Functions
__getattr__
Calling a torch.Tensor method on this object applies it to each member.
Example: categorical.flatten() → CategoricalStoch(logits.flatten(), probs.flatten(), stoch.flatten()).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str | Method name. | required |
Returns:
| Type | Description |
|---|---|
| callable | A function that applies the torch.Tensor method to each member. |
Raises:
| Type | Description |
|---|---|
| AttributeError | If the given name is not a torch.Tensor method. |
__getitem__
Index access.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| idx | int or slice or tuple | Index. | required |
Returns:
| Type | Description |
|---|---|
| CategoricalStoch | The CategoricalStoch corresponding to the specified index. |
__len__
__post_init__
Post-initialization processing.
Raises:
| Type | Description |
|---|---|
| ValueError | |
get_distribution
save
Save the parameters of the categorical distribution to the specified path.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| path | str | Path to save the parameters. | required |
squeeze
Squeeze the parameters of the categorical distribution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dim | int | Dimension to squeeze. | required |
Returns:
| Type | Description |
|---|---|
| CategoricalStoch | Squeezed categorical distribution. |
unsqueeze
Unsqueeze the parameters of the categorical distribution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dim | int | Dimension to unsqueeze. | required |
Returns:
| Type | Description |
|---|---|
| CategoricalStoch | Unsqueezed categorical distribution. |
BernoulliStoch (dataclass)
Parameters of a Bernoulli distribution and its stochastic sample.
Attributes:
| Name | Type | Description |
|---|---|---|
| logits | Tensor | Logits of the Bernoulli distribution. |
| probs | Tensor | Probabilities of the Bernoulli distribution. |
| stoch | Tensor | Sample from the Bernoulli distribution, drawn with the straight-through estimator. |
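For a Bernoulli variable, the straight-through estimator can be sketched the same way: the forward value is a hard {0, 1} sample and the backward pass uses the gradient of probs, here the sigmoid derivative p(1 - p). This is an illustrative sketch in plain PyTorch; that BernoulliStoch uses sigmoid probabilities and this exact construction is an assumption.

```python
import torch

logits = torch.randn(2, 10, requires_grad=True)
probs = torch.sigmoid(logits)
hard = torch.bernoulli(probs.detach())    # hard samples in {0, 1}; no gradient path
stoch = hard + (probs - probs.detach())   # straight-through: forward value is hard, gradient follows probs

stoch.sum().backward()
print(torch.equal(stoch, hard))                                     # True
print(torch.allclose(logits.grad, (probs * (1 - probs)).detach()))  # True: the sigmoid derivative
```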
Functions
__getattr__
Calling a torch.Tensor method on this object applies it to each member.
Example: bernoulli.flatten() → BernoulliStoch(logits.flatten(), probs.flatten(), stoch.flatten()).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| name | str | Method name. | required |
Returns:
| Type | Description |
|---|---|
| callable | A function that applies the torch.Tensor method to each member. |
Raises:
| Type | Description |
|---|---|
| AttributeError | If the given name is not a torch.Tensor method. |
__getitem__
Index access.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| idx | int or slice or tuple | Index. | required |
Returns:
| Type | Description |
|---|---|
| BernoulliStoch | The BernoulliStoch corresponding to the specified index. |
__len__
__post_init__
Post-initialization processing.
Raises:
| Type | Description |
|---|---|
| ValueError | |
get_distribution
save
squeeze
Squeeze the parameters of the Bernoulli distribution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dim | int | Dimension to squeeze. | required |
Returns:
| Type | Description |
|---|---|
| BernoulliStoch | Squeezed Bernoulli distribution. |
unsqueeze
Unsqueeze the parameters of the Bernoulli distribution.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| dim | int | Dimension to unsqueeze. | required |
Returns:
| Type | Description |
|---|---|
| BernoulliStoch | Unsqueezed Bernoulli distribution. |
StochState
BSQCodebook
Bases: Module
Binary Spherical Quantization codebook.
Reference
https://arxiv.org/abs/2406.07548
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| codebook_dim | int | Dimension of the codebook. | required |
Attributes:
| Name | Type | Description |
|---|---|---|
| codebook_dim | int | Dimension of the codebook. |
| codebook_size | int | Size of the codebook. This is equal to 2 ** codebook_dim. |
| codebook | Tensor | Codebook. |
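Since codebook_size equals 2 ** codebook_dim, the codebook can be viewed as every sign pattern of length codebook_dim, scaled to unit norm. The sketch below enumerates it with itertools.product; enumerate_codebook is a hypothetical name, and the row ordering (row i spells i in binary, most significant bit first, with -1 for a 0-bit) is an assumption chosen to agree with the indices_to_codes example.

```python
import itertools
import math

import torch


def enumerate_codebook(codebook_dim: int) -> torch.Tensor:
    """Hypothetical: all 2**codebook_dim sign patterns, each scaled onto the unit sphere."""
    # itertools.product varies the last position fastest, so row i spells i in binary (MSB first).
    patterns = list(itertools.product([-1.0, 1.0], repeat=codebook_dim))
    return torch.tensor(patterns) / math.sqrt(codebook_dim)


codebook = enumerate_codebook(5)
print(codebook.shape)  # torch.Size([32, 5])
print(codebook[19])    # sign pattern of 19 = 0b10011, scaled by 1/sqrt(5)
```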
Functions
bits_to_codes (staticmethod)
Convert bits of either 0 or 1 to codes on the unit sphere.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| bits | Tensor | Bits of either 0 or 1. | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Codes with each bit mapped to ±1/√d, where d is the number of bits (the dimension of the sphere). |
Examples:
>>> bits = torch.tensor([[0., 1.], [1., 0.]])
>>> BSQCodebook.bits_to_codes(bits)
tensor([[-0.7071,  0.7071],
        [ 0.7071, -0.7071]])
>>> bits = torch.tensor([[0., 0., 1.], [1., 0., 0.]])
>>> BSQCodebook.bits_to_codes(bits)
tensor([[-0.5774, -0.5774,  0.5774],
        [ 0.5774, -0.5774, -0.5774]])
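The example outputs are consistent with mapping each bit b to (2b - 1)/√d, where d is the size of the last dimension (1/√2 ≈ 0.7071, 1/√3 ≈ 0.5774). A minimal sketch of that formula, not the library's source:

```python
import math

import torch


def bits_to_codes(bits: torch.Tensor) -> torch.Tensor:
    """Map {0, 1} to {-1, 1} with 2b - 1, then divide by sqrt(d) so each code has unit norm."""
    d = bits.shape[-1]
    return (2 * bits - 1) / math.sqrt(d)


print(bits_to_codes(torch.tensor([[0.0, 1.0], [1.0, 0.0]])))
print(bits_to_codes(torch.tensor([[0.0, 0.0, 1.0], [1.0, 0.0, 0.0]])))
```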
indices_to_codes
Convert indices to codes. Each index is written as codebook_dim bits (most significant bit first), and each bit is mapped to -1 or 1 and then scaled to unit norm.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| indices | Tensor | Indices. | required |
Returns:
| Type | Description |
|---|---|
| Tensor | Codes with entries of ±1/√codebook_dim, where codebook_dim is the dimension of the sphere. |
Examples:
>>> indices = torch.tensor([[31], [19]])
>>> codebook = BSQCodebook(5)
>>> codebook.indices_to_codes(indices)
tensor([[ 0.4472,  0.4472,  0.4472,  0.4472,  0.4472],
        [ 0.4472, -0.4472, -0.4472,  0.4472,  0.4472]])
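The example is consistent with reading each index in binary, most significant bit first (19 = 0b10011 gives the signs +, -, -, +, +). The sketch below implements that reading with bit masks, plus a hypothetical inverse codes_to_indices for a round-trip check. Both function names are illustrative, not library API, and the sketch takes a 1-D index tensor, so the result has shape (2, 5).

```python
import math

import torch


def indices_to_codes(indices: torch.Tensor, codebook_dim: int) -> torch.Tensor:
    """Expand integer indices into codebook_dim bits (MSB first), then map to ±1/sqrt(codebook_dim)."""
    masks = torch.tensor([2 ** k for k in range(codebook_dim - 1, -1, -1)])
    bits = (indices.unsqueeze(-1) & masks) != 0           # (..., codebook_dim), bool
    return (2 * bits.float() - 1) / math.sqrt(codebook_dim)


def codes_to_indices(codes: torch.Tensor) -> torch.Tensor:
    """Hypothetical inverse: positive coordinates are 1-bits, read back as a binary number."""
    codebook_dim = codes.shape[-1]
    masks = torch.tensor([2 ** k for k in range(codebook_dim - 1, -1, -1)])
    return ((codes > 0).long() * masks).sum(dim=-1)


indices = torch.tensor([31, 19])
codes = indices_to_codes(indices, 5)
print(codes)
print(codes_to_indices(codes))  # tensor([31, 19])
```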
stack_dist
Stack the parameters of the distributions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| stochs | Tuple[StochState, ...] | Tuple of the distributions. | required |
| dim | int | Dimension to stack the parameters of the distributions. Default is 0. | 0 |
Returns:
| Type | Description |
|---|---|
| StochState | Stacked distribution. |
Examples:
>>> dist1 = NormalStoch(torch.randn(2, 3), torch.rand(2, 3), torch.randn(2, 3))
>>> dist2 = NormalStoch(torch.randn(2, 3), torch.rand(2, 3), torch.randn(2, 3))
>>> stacked = stack_dist((dist1, dist2))
>>> stacked.shape
NormalShape(mean=torch.Size([2, 2, 3]), std=torch.Size([2, 2, 3]), stoch=torch.Size([2, 2, 3]))
cat_dist
Concatenate the parameters of the distributions.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| stochs | Tuple[StochState, ...] | Tuple of the distributions. | required |
| dim | int | Dimension to concatenate the parameters of the distributions. Default is -1. | -1 |
Returns:
| Type | Description |
|---|---|
| StochState | Concatenated distribution. |
Examples:
>>> dist1 = NormalStoch(torch.randn(2, 3), torch.rand(2, 3), torch.randn(2, 3))
>>> dist2 = NormalStoch(torch.randn(2, 3), torch.rand(2, 3), torch.randn(2, 3))
>>> concatenated = cat_dist((dist1, dist2))
>>> concatenated.shape
NormalShape(mean=torch.Size([2, 6]), std=torch.Size([2, 6]), stoch=torch.Size([2, 6]))
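Both stack_dist and cat_dist can be understood as applying one tensor operation field by field. The sketch below does this for a hypothetical stand-in dataclass with the same three fields as NormalStoch; ToyNormalStoch and combine_fields are illustrative names, and the field-by-field approach is an assumption about how the library works. The resulting shapes match the two examples.

```python
from dataclasses import dataclass, fields

import torch


@dataclass
class ToyNormalStoch:
    """Hypothetical stand-in with the same three tensor fields as NormalStoch."""

    mean: torch.Tensor
    std: torch.Tensor
    stoch: torch.Tensor


def combine_fields(states, op, dim):
    """Apply op (torch.stack or torch.cat) to each field across states of one dataclass type."""
    cls = type(states[0])
    return cls(*(op([getattr(s, f.name) for s in states], dim=dim) for f in fields(cls)))


a = ToyNormalStoch(torch.randn(2, 3), torch.rand(2, 3), torch.randn(2, 3))
b = ToyNormalStoch(torch.randn(2, 3), torch.rand(2, 3), torch.randn(2, 3))
stacked = combine_fields((a, b), torch.stack, dim=0)
concatenated = combine_fields((a, b), torch.cat, dim=-1)
print(stacked.mean.shape, concatenated.mean.shape)  # torch.Size([2, 2, 3]) torch.Size([2, 6])
```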
get_dist
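Get the distribution corresponding to a stochastic state.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| state | StochState | Stochastic state of a normal, categorical, or Bernoulli distribution. | required |
Returns:
| Type | Description |
|---|---|
| Independent | The distribution corresponding to the given state. |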
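The return type Independent suggests torch.distributions.Independent, which reinterprets trailing batch dimensions as event dimensions so that log_prob sums over them. A sketch with the standard torch.distributions classes follows; the choice of base distributions (Normal, OneHotCategorical, Bernoulli) and of reinterpreted_batch_ndims for each case are assumptions, not confirmed behavior of get_dist.

```python
import torch
from torch.distributions import Bernoulli, Independent, Normal, OneHotCategorical

# Normal state: treat the last dimension as one event.
mean, std = torch.zeros(2, 10), torch.ones(2, 10)
normal = Independent(Normal(mean, std), 1)
print(normal.batch_shape, normal.event_shape)  # torch.Size([2]) torch.Size([10])

# Categorical state with logits of shape (batch, n_groups, n_classes): groups become part of the event.
logits = torch.zeros(2, 2, 5)
categorical = Independent(OneHotCategorical(logits=logits), 1)
print(categorical.batch_shape, categorical.event_shape)  # torch.Size([2]) torch.Size([2, 5])

# Bernoulli state: every element is scalar, so two dimensions are reinterpreted as the event.
bernoulli = Independent(Bernoulli(logits=logits), 2)
print(bernoulli.log_prob(torch.ones(2, 2, 5)).shape)  # torch.Size([2])
```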