ml-networks¶
村田研共通フレームワーク化計画の一環として、基本的な深層学習モデルのアーキテクチャを提供するPythonパッケージです。
概要¶
ml-networksは、PyTorchおよびJAX(Flax NNX)ベースの深層学習モデル構築を支援するライブラリです。以下の機能を提供します:
- 基本的なニューラルネットワークアーキテクチャ: MLP、Encoder、Decoder、UNet、Vision Transformer(ViT)など
- 分布のサポート: 正規分布、カテゴリカル分布、ベルヌーイ分布、BSQコードブック
- 損失関数: Focal Loss、Charbonnier Loss、Focal Frequency Loss、KLダイバージェンスなど
- 便利なユーティリティ: 活性化関数、最適化手法、データ保存・読み込み機能
- 高度な機能: HyperNetwork、対照学習(Contrastive Learning)、条件付きUNet
特徴¶
- マルチフレームワーク対応: PyTorchとJAX(Flax NNX)の両方をサポート。同一のConfig体系で切り替え可能
- 使いやすい: 直感的なAPI設計。YAMLファイルから設定を読み込み、
hydra.utils.instantiateでインスタンス化 - 柔軟性: 豊富な設定オプション。バックボーン、正規化、活性化関数を自由に組み合わせ
- 包括的: 深層学習に必要な主要コンポーネントを網羅
- 実用的: 型チェック(mypy)対応、CI/CD整備済み
パッケージ構成¶
ml_networks/
├── config.py # 共通設定クラス(PyTorch/JAX共通)
├── utils.py # 共通ユーティリティ
├── callbacks.py # PyTorch Lightning コールバック
├── torch/ # PyTorch実装
│ ├── layers.py # MLP, Conv, Attention, Transformerなど
│ ├── vision.py # Encoder, Decoder, ConvNet, ResNet, ViT
│ ├── unet.py # ConditionalUnet1d, ConditionalUnet2d
│ ├── distributions.py # 確率分布
│ ├── loss.py # 損失関数
│ ├── activations.py # カスタム活性化関数
│ ├── hypernetworks.py # HyperNetwork
│ ├── contrastive.py # 対照学習
│ └── torch_utils.py # PyTorch固有ユーティリティ
└── jax/ # JAX (Flax NNX) 実装
├── layers.py # MLP, Conv, Attention, Transformerなど
├── vision.py # Encoder, Decoder, ConvNet, ResNet, ViT
├── unet.py # ConditionalUnet1d, ConditionalUnet2d
├── distributions.py # 確率分布
├── loss.py # 損失関数
├── activations.py # カスタム活性化関数
├── hypernetworks.py # HyperNetwork
├── contrastive.py # 対照学習
└── jax_utils.py # JAX固有ユーティリティ
インストール¶
要件¶
- Python >= 3.10
- PyTorch >= 2.0(PyTorchバックエンドを使用する場合)
- JAX >= 0.4.30 + Flax >= 0.12.0(JAXバックエンドを使用する場合)
インストール方法¶
以下のいずれかの方法でインストールできます:
pipを使用する場合¶
ryeを使用する場合¶
uvを使用する場合¶
JAXサポートを追加する場合¶
注意: uvを使用する場合は、<access token>をGitHubのPersonal Access Tokenに置き換えてください。
クイックスタート¶
from ml_networks.torch import MLPLayer
from ml_networks import MLPConfig, LinearConfig
import torch
# MLPの設定
mlp_config = MLPConfig(
hidden_dim=128,
n_layers=2,
output_activation="Tanh",
linear_cfg=LinearConfig(activation="ReLU", bias=True)
)
# MLPの作成と推論
mlp = MLPLayer(input_dim=16, output_dim=8, mlp_config=mlp_config)
x = torch.randn(32, 16)
y = mlp(x)
print(y.shape) # torch.Size([32, 8])
from ml_networks.jax import MLPLayer
from ml_networks import MLPConfig, LinearConfig
import jax
import jax.numpy as jnp
# MLPの設定(PyTorchと同じConfigを使用)
mlp_config = MLPConfig(
hidden_dim=128,
n_layers=2,
output_activation="Tanh",
linear_cfg=LinearConfig(activation="ReLU", bias=True)
)
# MLPの作成と推論
mlp = MLPLayer(input_dim=16, output_dim=8, mlp_config=mlp_config, rngs=jax.random.PRNGKey(0))
x = jnp.ones((32, 16))
y = mlp(x)
print(y.shape) # (32, 8)
詳細はクイックスタートガイドを参照してください。
ドキュメント¶
- クイックスタート - 基本的な使用方法
- 設定管理ガイド - YAMLファイルから設定を読み込む方法(推奨)
- ガイド:
- MLP - 多層パーセプトロン
- Encoder - 画像エンコーダ
- Decoder - 画像デコーダ
- UNet - 条件付きUNet
- Distributions - 確率分布
- 損失関数 - 各種損失関数
- データの保存と読み込み - blosc2形式のデータI/O
- その他の便利な機能 - 活性化関数、最適化、seed固定など
- 高度な機能 - HyperNetwork、対照学習、Attention
- JAXバックエンド - JAX (Flax NNX) での使用方法
- API リファレンス - 完全なAPIドキュメント
ライセンス¶
このプロジェクトのライセンス情報については、リポジトリのLICENSEファイルを参照してください。
作者¶
- oakwood-fujiken (oakwood.n14.4sp@keio.jp)
- nomutin (nomura0508@icloud.com)