Skip to content

分布

分布関連のクラスと関数を提供します。

ml_networks.torch.distributions(PyTorch)とml_networks.jax.distributions(JAX)の両方で提供されています。

Distribution

Distribution

Distribution(in_dim, dist, n_groups=1, spherical=False)

Bases: Module

A distribution function.

Parameters:

Name Type Description Default
in_dim int

Input dimension.

required
dist Literal['normal', 'categorical', 'bernoulli']

Distribution type.

required
n_groups int

Number of groups. Default is 1. This is used for the categorical and Bernoulli distributions.

1
spherical bool

Whether to map binary samples to a symmetric code. Default is False. Only used by the categorical and Bernoulli distributions. If True and dist=="categorical", each one-hot entry is mapped from {0, 1} to {-1, 1}. If True and dist=="bernoulli", each group of bits is mapped from {0, 1} to a point on the unit sphere (Binary Spherical Quantization).

refer to https://arxiv.org/abs/2406.07548

False

Examples:

>>> dist = Distribution(10, "normal")
>>> data = torch.randn(2, 20)
>>> posterior = dist(data)
>>> posterior.__class__.__name__
'NormalStoch'
>>> posterior.shape
NormalShape(mean=torch.Size([2, 10]), std=torch.Size([2, 10]), stoch=torch.Size([2, 10]))
>>> dist = Distribution(10, "categorical", n_groups=2)
>>> data = torch.randn(2, 10)
>>> posterior = dist(data)
>>> posterior.__class__.__name__
'CategoricalStoch'
>>> posterior.shape
CategoricalShape(logits=torch.Size([2, 2, 5]), probs=torch.Size([2, 2, 5]), stoch=torch.Size([2, 10]))
>>> dist = Distribution(10, "bernoulli", n_groups=2)
>>> data = torch.randn(2, 10)
>>> posterior = dist(data)
>>> posterior.__class__.__name__
'BernoulliStoch'
>>> posterior.shape
BernoulliShape(logits=torch.Size([2, 2, 5]), probs=torch.Size([2, 2, 5]), stoch=torch.Size([2, 10]))
Source code in src/ml_networks/torch/distributions.py
def __init__(
    self,
    in_dim: int,
    dist: Literal["normal", "categorical", "bernoulli"],
    n_groups: int = 1,
    spherical: bool = False,
) -> None:
    """Build the distribution head and select its posterior function.

    Parameters
    ----------
    in_dim : int
        Dimension of the stochastic variable. The "normal" head expects inputs
        whose last dimension is ``2 * in_dim``; the grouped heads expect ``in_dim``.
    dist : {"normal", "categorical", "bernoulli"}
        Distribution type.
    n_groups : int, optional
        Number of equal groups the logits are split into. Default is 1.
        Only used by the categorical and Bernoulli heads, for which ``in_dim``
        must be divisible by a positive ``n_groups``.
    spherical : bool, optional
        Whether to map binary samples to a symmetric / spherical code.
        Default is False.

    Raises
    ------
    ValueError
        If ``dist`` is "categorical" or "bernoulli" and ``n_groups`` is not
        positive or does not divide ``in_dim``.
    NotImplementedError
        If ``dist`` is not a supported distribution type.
    """
    super().__init__()

    # The grouped heads torch.chunk the logits into `n_groups` pieces and stack
    # them; a remainder yields uneven chunks that fail (or silently mis-group)
    # only at forward time, so reject it up front.
    if dist in ("categorical", "bernoulli") and (n_groups <= 0 or in_dim % n_groups != 0):
        msg = f"in_dim ({in_dim}) must be divisible by a positive n_groups ({n_groups}) for the {dist!r} distribution."
        raise ValueError(msg)

    self.dist = dist
    self.spherical = spherical
    self.n_class = in_dim // n_groups
    self.in_dim = in_dim
    self.n_groups = n_groups

    if dist == "normal":
        self.posterior = self.normal  # type: ignore[assignment]
    elif dist == "categorical":
        self.posterior = self.categorical  # type: ignore[assignment]
    elif dist == "bernoulli":
        self.posterior = self.bernoulli  # type: ignore[assignment]
    else:
        msg = f"Unsupported distribution type: {dist!r}."
        raise NotImplementedError(msg)

    if spherical:
        # NOTE(review): BSQCodebook materialises 2**n_class codes, so memory grows
        # exponentially with n_class; only `bits_to_codes` is used by `bernoulli`.
        self.codebook = BSQCodebook(self.n_class)

Attributes

codebook instance-attribute

codebook = BSQCodebook(n_class)

dist instance-attribute

dist = dist

in_dim instance-attribute

in_dim = in_dim

n_class instance-attribute

n_class = in_dim // n_groups

n_groups instance-attribute

n_groups = n_groups

posterior instance-attribute

posterior = normal

spherical instance-attribute

spherical = spherical

Functions

bernoulli

bernoulli(logits, deterministic=False, inv_tmp=1.0)
Source code in src/ml_networks/torch/distributions.py
def bernoulli(self, logits: torch.Tensor, deterministic: bool = False, inv_tmp: float = 1.0) -> BernoulliStoch:
    """Build a grouped Bernoulli posterior from logits.

    Parameters
    ----------
    logits : torch.Tensor
        Logits whose last dimension is split into ``self.n_groups`` equal chunks.
    deterministic : bool, optional
        If True, each bit is 1 where its probability exceeds 0.5 and 0 otherwise,
        with a straight-through gradient through ``probs``, instead of being
        sampled. Default is False.
    inv_tmp : float, optional
        Inverse temperature multiplied into the logits. Default is 1.0.

    Returns
    -------
    BernoulliStoch
        Scaled ``logits`` and ``probs`` of shape ``(*batch, n_groups, n_class)``,
        and the sample flattened to ``(*batch, n_groups * n_class)``.
    """
    batch_shape = logits.shape[:-1]
    chunked_logits = torch.chunk(logits, self.n_groups, dim=-1)
    logits = torch.stack(chunked_logits, dim=-2)
    logits = logits * inv_tmp
    probs = torch.sigmoid(logits)

    if deterministic:
        # Threshold the probabilities rather than a random sample: a Bernoulli
        # sample is already 0/1, so thresholding it is not deterministic at all.
        hard = (probs > 0.5).to(probs.dtype)
        sample = hard + probs - probs.detach()
    else:
        posterior_dist = D.Independent(BernoulliStraightThrough(probs=probs), 1)
        sample = posterior_dist.rsample()

    if self.spherical:
        # Project after hardening so deterministic samples are projected too
        # (thresholding projected codes of magnitude 1/sqrt(d) would zero them).
        sample = self.codebook.bits_to_codes(sample)

    return BernoulliStoch(
        logits,
        probs,
        sample.reshape([*batch_shape, -1]),
    )

categorical

categorical(logits, deterministic=False, inv_tmp=1.0)
Source code in src/ml_networks/torch/distributions.py
def categorical(self, logits: torch.Tensor, deterministic: bool = False, inv_tmp: float = 1.0) -> CategoricalStoch:
    """Build a grouped one-hot categorical posterior from logits.

    Parameters
    ----------
    logits : torch.Tensor
        Logits whose last dimension is split into ``self.n_groups`` equal chunks.
    deterministic : bool, optional
        If True, the argmax one-hot vector (with a straight-through gradient) is
        used instead of a random sample. Default is False.
    inv_tmp : float, optional
        Inverse temperature; the softmax uses temperature ``1 / inv_tmp``.
        Default is 1.0.

    Returns
    -------
    CategoricalStoch
        Grouped ``logits`` (not scaled by ``inv_tmp``) and tempered ``probs`` of
        shape ``(*batch, n_groups, n_class)``, and the sample flattened to
        ``(*batch, n_groups * n_class)``. With ``self.spherical`` the sample
        entries are mapped from {0, 1} to {-1, 1}, in both modes.
    """
    batch_shape = logits.shape[:-1]
    logits_chunk = torch.chunk(logits, self.n_groups, dim=-1)
    logits = torch.stack(logits_chunk, dim=-2)
    probs = softmax(logits, dim=-1, temperature=1 / inv_tmp)

    if deterministic:
        sample = self.deterministic_onehot(probs)
    else:
        posterior_dist = D.Independent(D.OneHotCategoricalStraightThrough(probs=probs), 1)
        sample = posterior_dist.rsample()

    # Apply the symmetric mapping in both modes; previously the deterministic
    # path skipped it and returned {0, 1} codes.
    if self.spherical:
        sample = sample * 2 - 1

    return CategoricalStoch(
        logits,
        probs,
        sample.reshape([*batch_shape, -1]),
    )

deterministic_onehot

deterministic_onehot(input)

Compute the one-hot vector by argmax.

Parameters:

Name Type Description Default
input Tensor

Input tensor.

required

Returns:

Type Description
Tensor

One-hot vector.

Examples:

>>> input = torch.arange(6).reshape(2, 3) / 5.0
>>> dist = Distribution(3, "categorical")
>>> onehot = dist.deterministic_onehot(input)
>>> onehot
tensor([[0., 0., 1.],
        [0., 0., 1.]])
Source code in src/ml_networks/torch/distributions.py
def deterministic_onehot(self, input: torch.Tensor) -> torch.Tensor:
    """
    Return the argmax one-hot encoding of ``input`` with a straight-through gradient.

    The forward value is the hard one-hot vector of the argmax over the last
    dimension; ``input - input.detach()`` adds zero in value but lets gradients
    flow back to ``input``.

    Parameters
    ----------
    input : torch.Tensor
        Scores or probabilities whose last dimension has ``self.n_class`` entries.

    Returns
    -------
    torch.Tensor
        One-hot tensor with the same shape as ``input``.

    Examples
    --------
    >>> input = torch.arange(6).reshape(2, 3) / 5.0
    >>> dist = Distribution(3, "categorical")
    >>> onehot = dist.deterministic_onehot(input)
    >>> onehot
    tensor([[0., 0., 1.],
            [0., 0., 1.]])
    """
    winners = input.argmax(dim=-1)
    hard = F.one_hot(winners, num_classes=self.n_class)
    return hard + input - input.detach()

forward

forward(x, deterministic=False, inv_tmp=1.0)

Compute the posterior distribution.

Parameters:

Name Type Description Default
x Tensor

Input tensor.

required
deterministic bool

Whether to use the deterministic mode. Default is False. If True and dist=="normal", the mean is returned. If True and dist=="categorical", the one-hot vector computed by argmax is returned. If True and dist=="bernoulli", each entry is set to 1 if it exceeds 0.5 and to 0 otherwise.

False
inv_tmp float

Inverse temperature. Default is 1.0. This is used for the categorical and Bernoulli distributions.

1.0

Returns:

Type Description
StochState

Posterior distribution.

Source code in src/ml_networks/torch/distributions.py
def forward(
    self,
    x: torch.Tensor,
    deterministic: bool = False,
    inv_tmp: float = 1.0,
) -> StochState:
    """
    Map ``x`` to a posterior state with the head chosen at construction time.

    Parameters
    ----------
    x : torch.Tensor
        Input features passed to the selected head (``normal``, ``categorical``
        or ``bernoulli``).
    deterministic : bool, optional
        Whether to use the deterministic mode. Default is False.
        For "normal" the mean is returned; for "categorical" the argmax one-hot
        vector; for "bernoulli" a hard 0/1 value thresholded at 0.5.
    inv_tmp : float, optional
        Inverse temperature. Default is 1.0.
        Only used by the categorical and Bernoulli heads.

    Returns
    -------
    StochState
        Posterior state produced by the selected head.
    """
    head = self.posterior
    return head(x, inv_tmp=inv_tmp, deterministic=deterministic)

normal

normal(mu_std, deterministic=False, inv_tmp=1.0)
Source code in src/ml_networks/torch/distributions.py
def normal(self, mu_std: torch.Tensor, deterministic: bool = False, inv_tmp: float = 1.0) -> NormalStoch:
    """Build a diagonal normal posterior from concatenated mean / raw-std features.

    Parameters
    ----------
    mu_std : torch.Tensor
        Tensor whose last dimension is ``2 * self.in_dim``. The first half is the
        mean; the second half goes through softplus (plus 1e-6) to give the std.
    deterministic : bool, optional
        If True, the mean is used as the sample. Default is False.
    inv_tmp : float, optional
        Unused; accepted so every head shares the same signature.

    Returns
    -------
    NormalStoch
        Mean, std and a reparameterised sample (or the mean when deterministic).

    Raises
    ------
    ValueError
        If the last dimension of ``mu_std`` is not ``2 * self.in_dim``.
    """
    expected = self.in_dim * 2
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if mu_std.shape[-1] != expected:
        msg = f"mu_std.shape[-1] ({mu_std.shape[-1]}) must be 2 * in_dim ({expected}) to hold both the mean and the std."
        raise ValueError(msg)

    mu, std = torch.chunk(mu_std, 2, dim=-1)
    # softplus keeps the std positive; the offset keeps it away from zero.
    std = F.softplus(std) + 1e-6

    if deterministic:
        sample = mu
    else:
        sample = D.Independent(D.Normal(mu, std), 1).rsample()

    return NormalStoch(mu, std, sample)

NormalStoch

NormalStoch dataclass

NormalStoch(mean, std, stoch)

Parameters of a normal distribution and its stochastic sample.

Attributes:

Name Type Description
mean Tensor

Mean of the normal distribution.

std Tensor

Standard deviation of the normal distribution.

stoch Tensor

sample from the normal distribution with reparametrization trick.

Attributes

mean instance-attribute

mean

shape property

shape

mean, std, stoch の shape をタプルで返す.

std instance-attribute

std

stoch instance-attribute

stoch

Functions

__getattr__

__getattr__(name)

torch.Tensor に含まれるメソッドを呼び出したら、各メンバに適用する.

例: normal.flatten() → NormalStoch(mean.flatten(), std.flatten(), stoch.flatten()).

Parameters:

Name Type Description Default
name str

メソッド名。

required

Returns:

Type Description
callable

torch.Tensorのメソッドを各メンバに適用する関数。

Raises:

Type Description
AttributeError

指定された名前がtorch.Tensorのメソッドでない場合。

Source code in src/ml_networks/torch/distributions.py
def __getattr__(self, name: str) -> Any:
    """Forward ``torch.Tensor`` method calls to every member.

    Only invoked when regular attribute lookup fails. For example,
    ``normal.flatten()`` returns
    ``NormalStoch(mean.flatten(), std.flatten(), stoch.flatten())``.

    Parameters
    ----------
    name : str
        Method name.

    Returns
    -------
    callable
        A function that applies the ``torch.Tensor`` method to each member and
        wraps the results in a new ``NormalStoch``.

    Raises
    ------
    AttributeError
        If ``name`` is a dunder name or is not a callable ``torch.Tensor`` attribute.
    """
    msg = f"'{self.__class__.__name__}' object has no attribute '{name}'"
    # Never forward dunder probes: copy/pickle look up e.g. `__setstate__` on a
    # half-built instance whose fields do not exist yet; forwarding it re-enters
    # this hook for `self.mean` and breaks copy.copy / pickle round-trips.
    if name.startswith("__") and name.endswith("__"):
        raise AttributeError(msg)
    # Only methods are forwarded; plain tensor attributes such as `device` or
    # `dtype` are not callable and must not become a method wrapper.
    if not callable(getattr(torch.Tensor, name, None)):
        raise AttributeError(msg)

    def method(*args: Any, **kwargs: Any) -> NormalStoch:
        return NormalStoch(
            getattr(self.mean, name)(*args, **kwargs),
            getattr(self.std, name)(*args, **kwargs),
            getattr(self.stoch, name)(*args, **kwargs),
        )

    return method

__getitem__

__getitem__(idx)

インデックスアクセス.

Parameters:

Name Type Description Default
idx int or slice or tuple

インデックス指定。

required

Returns:

Type Description
NormalStoch

指定されたインデックスに対応するNormalStoch

Source code in src/ml_networks/torch/distributions.py
def __getitem__(self, idx: int | slice | tuple) -> NormalStoch:
    """Index every member with the same index.

    Parameters
    ----------
    idx : int or slice or tuple
        Any index accepted by ``torch.Tensor.__getitem__``.

    Returns
    -------
    NormalStoch
        A new ``NormalStoch`` built from ``mean[idx]``, ``std[idx]`` and ``stoch[idx]``.
    """
    mean_part = self.mean[idx]
    std_part = self.std[idx]
    stoch_part = self.stoch[idx]
    return NormalStoch(mean_part, std_part, stoch_part)

__len__

__len__()

長さを返す.

Returns:

Type Description
int

バッチ次元の長さ。

Source code in src/ml_networks/torch/distributions.py
def __len__(self) -> int:
    """Return the size of the leading (batch) dimension.

    Returns
    -------
    int
        Size of dimension 0 of ``stoch``.
    """
    return self.stoch.size(0)

__post_init__

__post_init__()

初期化後の処理.

Raises:

Type Description
ValueError

mean と std の shape が異なる場合、または std に負の値が含まれる場合。

Source code in src/ml_networks/torch/distributions.py
def __post_init__(self) -> None:
    """Validate the fields after dataclass initialisation.

    Raises
    ------
    ValueError
        If ``mean`` and ``std`` have different shapes, or if ``std`` contains a
        negative value.
    """
    shapes_match = self.mean.shape == self.std.shape
    if not shapes_match:
        raise ValueError(f"mean.shape {self.mean.shape} and std.shape {self.std.shape} must be the same.")
    has_negative_std = bool((self.std < 0).any())
    if has_negative_std:
        raise ValueError("std must be non-negative.")

get_distribution

get_distribution(independent=1)
Source code in src/ml_networks/torch/distributions.py
def get_distribution(self, independent: int = 1) -> D.Independent:
    """Return the normal distribution described by ``mean`` and ``std``.

    Parameters
    ----------
    independent : int, optional
        Number of trailing batch dimensions reinterpreted as event dimensions.
        Default is 1.

    Returns
    -------
    D.Independent
        ``Independent(Normal(mean, std), independent)``.
    """
    base = D.Normal(loc=self.mean, scale=self.std)
    return D.Independent(base, reinterpreted_batch_ndims=independent)

save

save(path)

Save the parameters of the normal distribution to the specified path.

Parameters:

Name Type Description Default
path str

Path to save the parameters.

required
Source code in src/ml_networks/torch/distributions.py
def save(self, path: str) -> None:
    """
    Write ``mean``, ``std`` and ``stoch`` to ``<path>/<name>.blosc2``.

    The directory is created if needed. Each tensor is detached, cloned and moved
    to the CPU before being converted to a NumPy array.

    Parameters
    ----------
    path : str
        Directory that receives the three files.
    """
    os.makedirs(path, exist_ok=True)

    for field_name in ("mean", "std", "stoch"):
        array = getattr(self, field_name).detach().clone().cpu().numpy()
        save_blosc2(f"{path}/{field_name}.blosc2", array)

squeeze

squeeze(dim)

Squeeze the parameters of the normal distribution.

Parameters:

Name Type Description Default
dim int

Dimension to squeeze.

required

Returns:

Type Description
NormalStoch

Squeezed normal distribution.

Source code in src/ml_networks/torch/distributions.py
def squeeze(self, dim: int) -> NormalStoch:
    """
    Apply ``torch.Tensor.squeeze(dim)`` to every member.

    Parameters
    ----------
    dim : int
        Dimension to squeeze.

    Returns
    -------
    NormalStoch
        New state with ``mean``, ``std`` and ``stoch`` squeezed at ``dim``.
    """
    squeezed = [tensor.squeeze(dim) for tensor in (self.mean, self.std, self.stoch)]
    return NormalStoch(*squeezed)

unsqueeze

unsqueeze(dim)

Unsqueeze the parameters of the normal distribution.

Parameters:

Name Type Description Default
dim int

Dimension to unsqueeze.

required

Returns:

Type Description
NormalStoch

Unsqueezed normal distribution.

Source code in src/ml_networks/torch/distributions.py
def unsqueeze(self, dim: int) -> NormalStoch:
    """
    Apply ``torch.Tensor.unsqueeze(dim)`` to every member.

    Parameters
    ----------
    dim : int
        Position at which to insert a new size-1 dimension.

    Returns
    -------
    NormalStoch
        New state with ``mean``, ``std`` and ``stoch`` unsqueezed at ``dim``.
    """
    expanded = [tensor.unsqueeze(dim) for tensor in (self.mean, self.std, self.stoch)]
    return NormalStoch(*expanded)

CategoricalStoch

CategoricalStoch dataclass

CategoricalStoch(logits, probs, stoch)

Parameters of a categorical distribution and its stochastic sample.

Attributes:

Name Type Description
logits Tensor

Logits of the categorical distribution.

probs Tensor

Probabilities of the categorical distribution.

stoch Tensor

sample from the categorical distribution with Straight-Through Estimator.

Attributes

logits instance-attribute

logits

probs instance-attribute

probs

shape property

shape

logits, probs, stoch の shape をタプルで返す.

stoch instance-attribute

stoch

Functions

__getattr__

__getattr__(name)

torch.Tensor に含まれるメソッドを呼び出したら、各メンバに適用する.

例: categorical.flatten() → CategoricalStoch(logits.flatten(), probs.flatten(), stoch.flatten()).

Parameters:

Name Type Description Default
name str

メソッド名。

required

Returns:

Type Description
callable

torch.Tensorのメソッドを各メンバに適用する関数。

Raises:

Type Description
AttributeError

指定された名前がtorch.Tensorのメソッドでない場合。

Source code in src/ml_networks/torch/distributions.py
def __getattr__(self, name: str) -> Any:
    """Forward ``torch.Tensor`` method calls to every member.

    Only invoked when regular attribute lookup fails. For example,
    ``categorical.flatten()`` returns
    ``CategoricalStoch(logits.flatten(), probs.flatten(), stoch.flatten())``.

    Parameters
    ----------
    name : str
        Method name.

    Returns
    -------
    callable
        A function that applies the ``torch.Tensor`` method to each member and
        wraps the results in a new ``CategoricalStoch``.

    Raises
    ------
    AttributeError
        If ``name`` is a dunder name or is not a callable ``torch.Tensor`` attribute.
    """
    msg = f"'{self.__class__.__name__}' object has no attribute '{name}'"
    # Never forward dunder probes: copy/pickle look up e.g. `__setstate__` on a
    # half-built instance whose fields do not exist yet; forwarding it re-enters
    # this hook for `self.logits` and breaks copy.copy / pickle round-trips.
    if name.startswith("__") and name.endswith("__"):
        raise AttributeError(msg)
    # Only methods are forwarded; plain tensor attributes such as `device` or
    # `dtype` are not callable and must not become a method wrapper.
    if not callable(getattr(torch.Tensor, name, None)):
        raise AttributeError(msg)

    def method(*args: Any, **kwargs: Any) -> CategoricalStoch:
        return CategoricalStoch(
            getattr(self.logits, name)(*args, **kwargs),
            getattr(self.probs, name)(*args, **kwargs),
            getattr(self.stoch, name)(*args, **kwargs),
        )

    return method

__getitem__

__getitem__(idx)

インデックスアクセス.

Parameters:

Name Type Description Default
idx int or slice or tuple

インデックス指定。

required

Returns:

Type Description
CategoricalStoch

指定されたインデックスに対応するCategoricalStoch

Source code in src/ml_networks/torch/distributions.py
def __getitem__(self, idx: int | slice | tuple) -> CategoricalStoch:
    """Index every member with the same index.

    Parameters
    ----------
    idx : int or slice or tuple
        Any index accepted by ``torch.Tensor.__getitem__``.

    Returns
    -------
    CategoricalStoch
        A new ``CategoricalStoch`` built from ``logits[idx]``, ``probs[idx]`` and
        ``stoch[idx]``.
    """
    parts = (self.logits[idx], self.probs[idx], self.stoch[idx])
    return CategoricalStoch(*parts)

__len__

__len__()

長さを返す.

Returns:

Type Description
int

バッチ次元の長さ。

Source code in src/ml_networks/torch/distributions.py
def __len__(self) -> int:
    """Return the size of the leading (batch) dimension.

    Returns
    -------
    int
        Size of dimension 0 of ``stoch``.
    """
    return self.stoch.size(0)

__post_init__

__post_init__()

初期化後の処理.

Raises:

Type Description
ValueError

logits と probs の shape が異なる場合、あるいは probs が [0, 1] の範囲外、または和が 1 から大きくずれている場合。

Source code in src/ml_networks/torch/distributions.py
def __post_init__(self) -> None:
    """Validate the fields after dataclass initialisation.

    Raises
    ------
    ValueError
        If ``logits`` and ``probs`` have different shapes, if ``probs`` leaves
        [0, 1], or if ``probs`` does not sum to 1 over its last dimension (within
        a tolerance of ``max(1e-6, 16 * eps)`` for floating dtypes).
    """
    if self.logits.shape != self.probs.shape:
        msg = f"logits.shape {self.logits.shape} and probs.shape {self.probs.shape} must be the same."
        raise ValueError(msg)
    if (self.probs < 0).any() or (self.probs > 1).any():
        msg = "probs must be in the range [0, 1]."
        raise ValueError(msg)
    # An empty tensor (e.g. from slicing to an empty batch) has nothing to check,
    # and `.max()` on an empty tensor would raise a RuntimeError.
    if self.probs.numel() == 0:
        return
    # A fixed 1e-6 tolerance rejects valid float16/bfloat16 probabilities, so
    # widen it according to the dtype precision.
    tol = 1e-6
    if self.probs.is_floating_point():
        tol = max(tol, 16 * torch.finfo(self.probs.dtype).eps)
    if (self.probs.sum(dim=-1) - 1).abs().max() > tol:
        msg = "probs must sum to 1."
        raise ValueError(msg)

get_distribution

get_distribution(independent=1)
Source code in src/ml_networks/torch/distributions.py
def get_distribution(self, independent: int = 1) -> D.Independent:
    """Return the straight-through one-hot categorical distribution over ``probs``.

    Parameters
    ----------
    independent : int, optional
        Number of trailing batch dimensions reinterpreted as event dimensions.
        Default is 1.

    Returns
    -------
    D.Independent
        ``Independent(OneHotCategoricalStraightThrough(probs), independent)``.
    """
    base = D.OneHotCategoricalStraightThrough(probs=self.probs)
    return D.Independent(base, reinterpreted_batch_ndims=independent)

save

save(path)

Save the parameters of the categorical distribution to the specified path.

Parameters:

Name Type Description Default
path str

Path to save the parameters.

required
Source code in src/ml_networks/torch/distributions.py
def save(self, path: str) -> None:
    """
    Write ``logits``, ``probs`` and ``stoch`` to ``<path>/<name>.blosc2``.

    The directory is created if needed. Each tensor is detached, cloned and moved
    to the CPU before being converted to a NumPy array.

    Parameters
    ----------
    path : str
        Directory that receives the three files.
    """
    os.makedirs(path, exist_ok=True)

    for field_name in ("logits", "probs", "stoch"):
        array = getattr(self, field_name).detach().clone().cpu().numpy()
        save_blosc2(f"{path}/{field_name}.blosc2", array)

squeeze

squeeze(dim)

Squeeze the parameters of the categorical distribution.

Parameters:

Name Type Description Default
dim int

Dimension to squeeze.

required

Returns:

Type Description
CategoricalStoch

Squeezed categorical distribution.

Source code in src/ml_networks/torch/distributions.py
def squeeze(self, dim: int) -> CategoricalStoch:
    """
    Apply ``torch.Tensor.squeeze(dim)`` to every member.

    Parameters
    ----------
    dim : int
        Dimension to squeeze.

    Returns
    -------
    CategoricalStoch
        New state with ``logits``, ``probs`` and ``stoch`` squeezed at ``dim``.
    """
    squeezed = [tensor.squeeze(dim) for tensor in (self.logits, self.probs, self.stoch)]
    return CategoricalStoch(*squeezed)

unsqueeze

unsqueeze(dim)

Unsqueeze the parameters of the categorical distribution.

Parameters:

Name Type Description Default
dim int

Dimension to unsqueeze.

required

Returns:

Type Description
CategoricalStoch

Unsqueezed categorical distribution.

Source code in src/ml_networks/torch/distributions.py
def unsqueeze(self, dim: int) -> CategoricalStoch:
    """
    Apply ``torch.Tensor.unsqueeze(dim)`` to every member.

    Parameters
    ----------
    dim : int
        Position at which to insert a new size-1 dimension.

    Returns
    -------
    CategoricalStoch
        New state with ``logits``, ``probs`` and ``stoch`` unsqueezed at ``dim``.
    """
    expanded = [tensor.unsqueeze(dim) for tensor in (self.logits, self.probs, self.stoch)]
    return CategoricalStoch(*expanded)

BernoulliStoch

BernoulliStoch dataclass

BernoulliStoch(logits, probs, stoch)

Parameters of a Bernoulli distribution and its stochastic sample.

Attributes:

Name Type Description
logits Tensor

Logits of the Bernoulli distribution.

probs Tensor

Probabilities of the Bernoulli distribution.

stoch Tensor

sample from the Bernoulli distribution with Straight-Through Estimator.

Attributes

logits instance-attribute

logits

probs instance-attribute

probs

shape property

shape

logits, probs, stoch の shape をタプルで返す.

stoch instance-attribute

stoch

Functions

__getattr__

__getattr__(name)

torch.Tensor に含まれるメソッドを呼び出したら、各メンバに適用する.

例: bernoulli.flatten() → BernoulliStoch(logits.flatten(), probs.flatten(), stoch.flatten()).

Parameters:

Name Type Description Default
name str

メソッド名。

required

Returns:

Type Description
callable

torch.Tensorのメソッドを各メンバに適用する関数。

Raises:

Type Description
AttributeError

指定された名前がtorch.Tensorのメソッドでない場合。

Source code in src/ml_networks/torch/distributions.py
def __getattr__(self, name: str) -> Any:
    """Forward ``torch.Tensor`` method calls to every member.

    Only invoked when regular attribute lookup fails. For example,
    ``bernoulli.flatten()`` returns
    ``BernoulliStoch(logits.flatten(), probs.flatten(), stoch.flatten())``.

    Parameters
    ----------
    name : str
        Method name.

    Returns
    -------
    callable
        A function that applies the ``torch.Tensor`` method to each member and
        wraps the results in a new ``BernoulliStoch``.

    Raises
    ------
    AttributeError
        If ``name`` is a dunder name or is not a callable ``torch.Tensor`` attribute.
    """
    msg = f"'{self.__class__.__name__}' object has no attribute '{name}'"
    # Never forward dunder probes: copy/pickle look up e.g. `__setstate__` on a
    # half-built instance whose fields do not exist yet; forwarding it re-enters
    # this hook for `self.logits` and breaks copy.copy / pickle round-trips.
    if name.startswith("__") and name.endswith("__"):
        raise AttributeError(msg)
    # Only methods are forwarded; plain tensor attributes such as `device` or
    # `dtype` are not callable and must not become a method wrapper.
    if not callable(getattr(torch.Tensor, name, None)):
        raise AttributeError(msg)

    def method(*args: Any, **kwargs: Any) -> BernoulliStoch:
        return BernoulliStoch(
            getattr(self.logits, name)(*args, **kwargs),
            getattr(self.probs, name)(*args, **kwargs),
            getattr(self.stoch, name)(*args, **kwargs),
        )

    return method

__getitem__

__getitem__(idx)

インデックスアクセス.

Parameters:

Name Type Description Default
idx int or slice or tuple

インデックス指定。

required

Returns:

Type Description
BernoulliStoch

指定されたインデックスに対応するBernoulliStoch

Source code in src/ml_networks/torch/distributions.py
def __getitem__(self, idx: int | slice | tuple) -> BernoulliStoch:
    """Index every member with the same index.

    Parameters
    ----------
    idx : int or slice or tuple
        Any index accepted by ``torch.Tensor.__getitem__``.

    Returns
    -------
    BernoulliStoch
        A new ``BernoulliStoch`` built from ``logits[idx]``, ``probs[idx]`` and
        ``stoch[idx]``.
    """
    parts = (self.logits[idx], self.probs[idx], self.stoch[idx])
    return BernoulliStoch(*parts)

__len__

__len__()

長さを返す.

Returns:

Type Description
int

バッチ次元の長さ。

Source code in src/ml_networks/torch/distributions.py
def __len__(self) -> int:
    """Return the size of the leading (batch) dimension.

    Returns
    -------
    int
        Size of dimension 0 of ``stoch``.
    """
    return self.stoch.size(0)

__post_init__

__post_init__()

初期化後の処理.

Raises:

Type Description
ValueError

logits と probs の shape が異なる場合、あるいは probs が [0, 1] の範囲外の場合。

Source code in src/ml_networks/torch/distributions.py
def __post_init__(self) -> None:
    """Validate the fields after dataclass initialisation.

    Raises
    ------
    ValueError
        If ``logits`` and ``probs`` have different shapes, or if any entry of
        ``probs`` lies outside [0, 1].
    """
    if self.logits.shape != self.probs.shape:
        raise ValueError(f"logits.shape {self.logits.shape} and probs.shape {self.probs.shape} must be the same.")
    out_of_range = (self.probs < 0) | (self.probs > 1)
    if out_of_range.any():
        raise ValueError("probs must be in the range [0, 1].")

get_distribution

get_distribution(independent=1)
Source code in src/ml_networks/torch/distributions.py
def get_distribution(self, independent: int = 1) -> D.Independent:
    """Return the straight-through Bernoulli distribution over ``probs``.

    Parameters
    ----------
    independent : int, optional
        Number of trailing batch dimensions reinterpreted as event dimensions.
        Default is 1.

    Returns
    -------
    D.Independent
        ``Independent(BernoulliStraightThrough(probs), independent)``.
    """
    base = BernoulliStraightThrough(self.probs)
    return D.Independent(base, reinterpreted_batch_ndims=independent)

save

save(path)
Source code in src/ml_networks/torch/distributions.py
def save(self, path: str) -> None:
    """
    Write ``logits``, ``probs`` and ``stoch`` to ``<path>/<name>.blosc2``.

    The directory is created if needed. Each tensor is detached, cloned and moved
    to the CPU before being converted to a NumPy array.

    Parameters
    ----------
    path : str
        Directory that receives the three files.
    """
    os.makedirs(path, exist_ok=True)

    for field_name in ("logits", "probs", "stoch"):
        array = getattr(self, field_name).detach().clone().cpu().numpy()
        save_blosc2(f"{path}/{field_name}.blosc2", array)

squeeze

squeeze(dim)

Squeeze the parameters of the Bernoulli distribution.

Parameters:

Name Type Description Default
dim int

Dimension to squeeze.

required

Returns:

Type Description
BernoulliStoch

Squeezed Bernoulli distribution.

Source code in src/ml_networks/torch/distributions.py
def squeeze(self, dim: int) -> BernoulliStoch:
    """
    Apply ``torch.Tensor.squeeze(dim)`` to every member.

    Parameters
    ----------
    dim : int
        Dimension to squeeze.

    Returns
    -------
    BernoulliStoch
        New state with ``logits``, ``probs`` and ``stoch`` squeezed at ``dim``.
    """
    squeezed = [tensor.squeeze(dim) for tensor in (self.logits, self.probs, self.stoch)]
    return BernoulliStoch(*squeezed)

unsqueeze

unsqueeze(dim)

Unsqueeze the parameters of the Bernoulli distribution.

Parameters:

Name Type Description Default
dim int

Dimension to unsqueeze.

required

Returns:

Type Description
BernoulliStoch

Unsqueezed Bernoulli distribution.

Source code in src/ml_networks/torch/distributions.py
def unsqueeze(self, dim: int) -> BernoulliStoch:
    """
    Apply ``torch.Tensor.unsqueeze(dim)`` to every member.

    Parameters
    ----------
    dim : int
        Position at which to insert a new size-1 dimension.

    Returns
    -------
    BernoulliStoch
        New state with ``logits``, ``probs`` and ``stoch`` unsqueezed at ``dim``.
    """
    expanded = [tensor.unsqueeze(dim) for tensor in (self.logits, self.probs, self.stoch)]
    return BernoulliStoch(*expanded)

StochState

StochState module-attribute

StochState = NormalStoch | CategoricalStoch | BernoulliStoch

BSQCodebook

BSQCodebook

BSQCodebook(codebook_dim)

Bases: Module

Binary Spherical Quantization codebook.

Reference

https://arxiv.org/abs/2406.07548

Parameters:

Name Type Description Default
codebook_dim int

Dimension of the codebook.

required

Attributes:

Name Type Description
codebook_dim int

Dimension of the codebook.

codebook_size int

Size of the codebook. This is equal to 2 ** codebook_dim.

codebook Tensor

Codebook.

Source code in src/ml_networks/torch/distributions.py
def __init__(
    self,
    codebook_dim: int,
) -> None:
    """Build the Binary Spherical Quantization codebook.

    Parameters
    ----------
    codebook_dim : int
        Number of bits per code; the codebook has ``2 ** codebook_dim`` entries.
    """
    super().__init__()
    self.codebook_dim = codebook_dim
    self.codebook_size = 2**codebook_dim
    # Bit weights from the most significant bit down to 1, e.g. [16, 8, 4, 2, 1].
    mask = 2 ** torch.arange(codebook_dim - 1, -1, -1)
    self.mask: torch.Tensor
    # register_buffer already exposes the tensor as `self.mask`.
    self.register_buffer("mask", mask)
    # NOTE(review): this materialises every code (2**codebook_dim rows of
    # codebook_dim values), so memory grows exponentially with codebook_dim.
    all_codes = torch.arange(self.codebook_size)
    # int64 bit arithmetic, matching `indices_to_codes`; int32 would truncate
    # indices >= 2**31.
    bits = ((all_codes[..., None].long() & self.mask) != 0).float()
    codebook = self.bits_to_codes(bits)
    self.register_buffer("codebook", codebook.float(), persistent=False)

Attributes

codebook_dim instance-attribute

codebook_dim = codebook_dim

codebook_size instance-attribute

codebook_size = 2 ** codebook_dim

mask instance-attribute

mask = mask

Functions

bits_to_codes staticmethod

bits_to_codes(bits)

Convert bits (each either 0 or 1) to codes on the unit sphere.

Parameters:

Name Type Description Default
bits Tensor

Bits of either 0 or 1.

required

Returns:

Type Description
Tensor

Codes on the unit sphere: each bit b becomes (2b - 1) / sqrt(d), where d is the size of the last dimension (the dimension of the sphere).

Examples:

>>> bits = torch.tensor([[0., 1.], [1., 0.]])
>>> BSQCodebook.bits_to_codes(bits)
tensor([[-0.7071,  0.7071],
        [ 0.7071, -0.7071]])
>>> bits = torch.tensor([[0., 0., 1.], [1., 0., 0.]])
>>> BSQCodebook.bits_to_codes(bits)
tensor([[-0.5774, -0.5774,  0.5774],
        [ 0.5774, -0.5774, -0.5774]])
Source code in src/ml_networks/torch/distributions.py
@staticmethod
def bits_to_codes(bits: torch.Tensor) -> torch.Tensor:
    """Convert bits to codes, which are bits of either 0 or 1.

    Parameters
    ----------
    bits : torch.Tensor
        Bits of either 0 or 1.

    Returns
    -------
    torch.Tensor
        Codes, which are bits depending on codebook_dim(dimension of the sphery)

    Examples
    --------
    >>> bits = torch.tensor([[0., 1.], [1., 0.]])
    >>> BSQCodebook.bits_to_codes(bits)
    tensor([[-0.7071,  0.7071],
            [ 0.7071, -0.7071]])

    >>> bits = torch.tensor([[0., 0., 1.], [1., 0., 0.]])
    >>> BSQCodebook.bits_to_codes(bits)
    tensor([[-0.5774, -0.5774,  0.5774],
            [ 0.5774, -0.5774, -0.5774]])


    """
    bits = bits * 2 - 1
    return F.normalize(bits, dim=-1)

indices_to_codes

indices_to_codes(indices)

Convert integer indices to codes: each index is written in binary (most significant bit first), and its bits are mapped to ±1 / sqrt(codebook_dim).

Parameters:

Name Type Description Default
indices Tensor

Indices.

required

Returns:

Type Description
Tensor

Codes on the unit sphere with entries ±1 / sqrt(codebook_dim).

Examples:

>>> indices = torch.tensor([[31], [19]])
>>> codebook = BSQCodebook(5)
>>> codebook.indices_to_codes(indices)
tensor([[ 0.4472,  0.4472,  0.4472,  0.4472,  0.4472],
        [ 0.4472, -0.4472, -0.4472,  0.4472,  0.4472]])
Source code in src/ml_networks/torch/distributions.py
def indices_to_codes(
    self,
    indices: torch.Tensor,
) -> torch.Tensor:
    """
    Convert integer code indices to their spherical codes.

    Each index is expanded into bits using ``self.mask`` (most significant bit
    first). The bits are then mapped with ``bits_to_codes`` to entries of
    ``±1 / sqrt(codebook_dim)``.

    Parameters
    ----------
    indices : torch.Tensor
        Integer indices, typically with a trailing dimension of size 1, which is
        squeezed away.

    Returns
    -------
    torch.Tensor
        Codes whose last dimension has ``codebook_dim`` entries.

    Examples
    --------
    >>> indices = torch.tensor([[31], [19]])
    >>> codebook = BSQCodebook(5)
    >>> codebook.indices_to_codes(indices)
    tensor([[ 0.4472,  0.4472,  0.4472,  0.4472,  0.4472],
            [ 0.4472, -0.4472, -0.4472,  0.4472,  0.4472]])
    """
    indices = indices.squeeze(-1)

    # Extract the bits with int64 arithmetic: casting to int32 silently truncates
    # indices >= 2**31, which are valid once codebook_dim exceeds 31.
    bits = ((indices[..., None].long() & self.mask) != 0).float()

    return self.bits_to_codes(bits)

stack_dist

stack_dist

stack_dist(stochs, dim=0)

Stack the parameters of the distributions.

Parameters:

Name Type Description Default
stochs Tuple[StochState, ...]

Tuple of the distributions.

required
dim int

Dimension to stack the parameters of the distributions. Default is 0.

0

Returns:

Type Description
StochState

Stacked distribution.

Examples:

>>> dist1 = NormalStoch(torch.randn(2, 3), torch.rand(2, 3), torch.randn(2, 3))
>>> dist2 = NormalStoch(torch.randn(2, 3), torch.rand(2, 3), torch.randn(2, 3))
>>> stacked = stack_dist((dist1, dist2))
>>> stacked.shape
NormalShape(mean=torch.Size([2, 2, 3]), std=torch.Size([2, 2, 3]), stoch=torch.Size([2, 2, 3]))
Source code in src/ml_networks/torch/distributions.py
def stack_dist(stochs: tuple[StochState, ...], dim: int = 0) -> StochState | None:
    """
    Stack the parameters of several distributions of the same type.

    The type of ``stochs[0]`` selects the branch, and every member is combined
    with ``torch.stack``. For ``CategoricalStoch``, ``logits`` and ``probs`` have
    one more trailing (class) dimension than ``stoch``. A negative ``dim`` is
    therefore shifted by one for them, so the new axis is not inserted among the
    classes, which would break the sum-to-1 invariant of ``probs``.

    Parameters
    ----------
    stochs : tuple of StochState
        Distributions to stack; must be non-empty.
    dim : int, optional
        Dimension (counted on ``stoch``) at which to insert the new axis.
        Default is 0.

    Returns
    -------
    StochState or None
        Stacked distribution, or None if ``stochs[0]`` is not a supported type.

    Raises
    ------
    ValueError
        If ``stochs`` is empty.

    Examples
    --------
    >>> dist1 = NormalStoch(torch.randn(2, 3), torch.rand(2, 3), torch.randn(2, 3))
    >>> dist2 = NormalStoch(torch.randn(2, 3), torch.rand(2, 3), torch.randn(2, 3))
    >>> stacked = stack_dist((dist1, dist2))
    >>> stacked.shape
    NormalShape(mean=torch.Size([2, 2, 3]), std=torch.Size([2, 2, 3]), stoch=torch.Size([2, 2, 3]))
    """
    if not stochs:
        msg = "stochs must contain at least one distribution."
        raise ValueError(msg)

    first = stochs[0]
    if isinstance(first, NormalStoch):
        return NormalStoch(
            torch.stack([stoch.mean for stoch in stochs], dim=dim),
            torch.stack([stoch.std for stoch in stochs], dim=dim),
            torch.stack([stoch.stoch for stoch in stochs], dim=dim),
        )
    if isinstance(first, CategoricalStoch):
        # logits/probs are (..., n_groups, n_class) while stoch is
        # (..., n_groups * n_class): skip the extra class axis for negative dims.
        param_dim = dim - 1 if dim < 0 else dim
        return CategoricalStoch(
            torch.stack([stoch.logits for stoch in stochs], dim=param_dim),
            torch.stack([stoch.probs for stoch in stochs], dim=param_dim),
            torch.stack([stoch.stoch for stoch in stochs], dim=dim),
        )
    if isinstance(first, BernoulliStoch):
        # NOTE(review): BernoulliStoch produced by `Distribution.bernoulli` uses the
        # same layout as CategoricalStoch, so a negative `dim` probably needs the
        # same shift; left unchanged to preserve existing behavior — confirm.
        return BernoulliStoch(
            torch.stack([stoch.logits for stoch in stochs], dim=dim),
            torch.stack([stoch.probs for stoch in stochs], dim=dim),
            torch.stack([stoch.stoch for stoch in stochs], dim=dim),
        )
    return None

cat_dist

cat_dist

cat_dist(stochs, dim=-1)

Concatenate the parameters of the distributions.

Parameters:

Name Type Description Default
stochs Tuple[StochState, ...]

Tuple of the distributions.

required
dim int

Dimension to concatenate the parameters of the distributions. Default is -1.

-1

Returns:

Type Description
StochState

Concatenated distribution.

Examples:

>>> dist1 = NormalStoch(torch.randn(2, 3), torch.rand(2, 3), torch.randn(2, 3))
>>> dist2 = NormalStoch(torch.randn(2, 3), torch.rand(2, 3), torch.randn(2, 3))
>>> concatenated = cat_dist((dist1, dist2))
>>> concatenated.shape
NormalShape(mean=torch.Size([2, 6]), std=torch.Size([2, 6]), stoch=torch.Size([2, 6]))
Source code in src/ml_networks/torch/distributions.py
def cat_dist(stochs: tuple[StochState, ...], dim: int = -1) -> StochState | None:
    """
    Concatenate the parameters of several distributions of the same type.

    The type of ``stochs[0]`` selects the branch, and every member is combined
    with ``torch.cat``. For ``CategoricalStoch``, ``logits`` and ``probs`` have
    one more trailing (class) dimension than ``stoch``. A negative ``dim`` is
    therefore shifted by one for them. With the default ``dim=-1`` the groups are
    concatenated, instead of the class axis being merged, which would make
    ``probs`` sum to more than 1 and fail validation.

    Parameters
    ----------
    stochs : tuple of StochState
        Distributions to concatenate; must be non-empty.
    dim : int, optional
        Dimension (counted on ``stoch``) along which to concatenate.
        Default is -1.

    Returns
    -------
    StochState or None
        Concatenated distribution, or None if ``stochs[0]`` is not a supported type.

    Raises
    ------
    ValueError
        If ``stochs`` is empty.

    Examples
    --------
    >>> dist1 = NormalStoch(torch.randn(2, 3), torch.rand(2, 3), torch.randn(2, 3))
    >>> dist2 = NormalStoch(torch.randn(2, 3), torch.rand(2, 3), torch.randn(2, 3))
    >>> concatenated = cat_dist((dist1, dist2))
    >>> concatenated.shape
    NormalShape(mean=torch.Size([2, 6]), std=torch.Size([2, 6]), stoch=torch.Size([2, 6]))
    """
    if not stochs:
        msg = "stochs must contain at least one distribution."
        raise ValueError(msg)

    first = stochs[0]
    if isinstance(first, NormalStoch):
        return NormalStoch(
            torch.cat([stoch.mean for stoch in stochs], dim=dim),
            torch.cat([stoch.std for stoch in stochs], dim=dim),
            torch.cat([stoch.stoch for stoch in stochs], dim=dim),
        )
    if isinstance(first, CategoricalStoch):
        # logits/probs are (..., n_groups, n_class) while stoch is
        # (..., n_groups * n_class): skip the extra class axis for negative dims.
        param_dim = dim - 1 if dim < 0 else dim
        return CategoricalStoch(
            torch.cat([stoch.logits for stoch in stochs], dim=param_dim),
            torch.cat([stoch.probs for stoch in stochs], dim=param_dim),
            torch.cat([stoch.stoch for stoch in stochs], dim=dim),
        )
    if isinstance(first, BernoulliStoch):
        # NOTE(review): for dim=-1 this concatenates along the class axis, which
        # matches the flattened `stoch` only when n_groups == 1; negative batch
        # dims are likely off by one as well. Left unchanged to preserve existing
        # behavior — confirm with callers.
        return BernoulliStoch(
            torch.cat([stoch.logits for stoch in stochs], dim=dim),
            torch.cat([stoch.probs for stoch in stochs], dim=dim),
            torch.cat([stoch.stoch for stoch in stochs], dim=dim),
        )
    return None

get_dist

get_dist

get_dist(state)

確率的状態から分布を取得する.

Parameters:

Name Type Description Default
state StochState

正規分布・カテゴリカル分布・ベルヌーイ分布のいずれかの確率的状態。

required

Returns:

Type Description
Independent

与えられた状態に対応するtorch.distributions.Independentオブジェクト。

Source code in src/ml_networks/torch/distributions.py
def get_dist(state: StochState) -> D.Independent:
    """確率的状態から分布を取得する.

    Parameters
    ----------
    state : StochState
        正規分布・カテゴリカル分布・ベルヌーイ分布のいずれかの確率的状態。

    Returns
    -------
    D.Independent
        与えられた状態に対応する`torch.distributions.Independent`オブジェクト。
    """
    if isinstance(state, NormalStoch):
        normal = D.Normal(state.mean, state.std)
        return D.Independent(normal, 1)
    if isinstance(state, CategoricalStoch):
        categorical = D.OneHotCategoricalStraightThrough(probs=state.probs)
        return D.Independent(categorical, 1)
    if isinstance(state, BernoulliStoch):
        bernoulli = BernoulliStraightThrough(probs=state.probs)
        return D.Independent(bernoulli, 2)
    raise NotImplementedError