最近訳あってpytorchを入門中です.
自動微分最強と思いつつも,モデルのパラメーターに関する取り扱いについて不明な部分があったので詳しく調べてみました.

model.parameters()

SGDなどを用いて最適化をする際に,勾配の計算対象とするパラメーターを model.parameters() で渡してあげます.

model = MyLinear()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

これだけでパラメーターの最適化をすることができるわけですが,model.parameters()はどうやってモデル内のパラメーターを取得しているのでしょうか?
これを追ってみます.

pytorchのバージョンはv0.4.1でお送りします.

大枠

  1. 初期化時に,_parameters, _modules という辞書を用意する
  2. インスタンス変数をセットする際にParameterの時は_parametersに,Moduleの時は_modulesに登録
  3. parameters()では,再帰的に,_parameters_modulesの中身を取り出す

故に,

学習の対象パラメーターは,必ずParameterModuleとしてインスタンス変数に登録しなければならない

これらをソースコードを追いつつ解説してみます.

init

torch.nn.Moduleは, __init__ 時に,_parameters_modules というインスタンス変数を OrderedDict で初期化しています.

    self._parameters = OrderedDict()
    self._modules = OrderedDict()

ソースコード(Github)

この辞書に,自身にセットするModuleParameterを保存していくことになります.

setattr

torch.nn.Moduleは,__setattr__を定義しています.__setattr__は,インスタンス変数をセットするたびに呼ばれる特殊メソッドです.

if isinstance(value, Parameter):
    if params is None:
        raise AttributeError(
            "cannot assign parameters before Module.__init__() call")
    remove_from(self.__dict__, self._buffers, self._modules)
    self.register_parameter(name, value)

ソースコード(Github)

まず,インスタンス変数をセットする時に,Parameterであるかを判定して,真であればregister_parameter を呼んでいます.

def register_parameter(self, name, param):
    # 中略
    self._parameters[name] = param

ソースコード(Github)
命名などのバリデーションを経て,_parametersに登録しています.

Module をセットする場合も同様に,

modules = self.__dict__.get('_modules')
if isinstance(value, Module):
    if modules is None:
        raise AttributeError(
            "cannot assign module before Module.__init__() call")
    remove_from(self.__dict__, self._parameters, self._buffers)
    modules[name] = value

ソースコード(Github)

命名などのバリデーションを経て,_modulesに登録しています.

parameters

以上が前準備で,これを前提に parameters(self)を見ていきます.

すぐに実体がnamed_parameters(self) だとわかります.
ソースコード(Github)

まず,_parameters の辞書の中身を全て yieldしています.
なお,重複しないようにmemoをセットしています.

if memo is None:
    memo = set()
for name, p in self._parameters.items():
    if p is not None and p not in memo:
        memo.add(p)
        yield prefix + ('.' if prefix else '') + name, p

次に,named_children()を呼びます.ここで,_modulesの辞書を全てyieldしていることがわかります.

memo = set()
for name, module in self._modules.items():
    if module is not None and module not in memo:
        memo.add(module)
        yield name, module

ソースコード(Github)

各moduleに対して,named_parametersを再帰的に呼び出していることがわかります.

for mname, module in self.named_children():
    submodule_prefix = prefix + ('.' if prefix else '') + mname
    for name, p in module.named_parameters(memo, submodule_prefix):
        yield name, p

このようにして,モデルのパラメーターを全て取り出していることがわかります.

正しい例

自分で新たなモデルを作りたい時は,nn.Moduleを継承したクラスを作る必要があります.

import torch
import torch.nn as nn
from torch.autograd import Variable

class MyLinear(nn.Module):
    def __init__(self, input_size):
        super(MyLinear, self).__init__()
        self.linear = nn.Linear(input_size, 1)
        self.my_bias = nn.Parameter(torch.randn(1))
        self.not_param = Variable(torch.randn(1), requires_grad=True)

    def forward(self, x):
        out = self.linear(x) + self.my_bias
        return out

このようなモデルを自作して,中のパラメーターを出力してみます.

model = MyLinear(5)

for n, p in model.named_parameters():
    print(n)
    print(p)
    print("--------")

出力は次のようになります.

my_bias
Parameter containing:
tensor([-0.8927], requires_grad=True)
--------
linear.weight
Parameter containing:
tensor([[-0.4040, -0.3878,  0.2836, -0.0322, -0.1306]], requires_grad=True)
--------
linear.bias
Parameter containing:
tensor([0.4088], requires_grad=True)
--------

確かに,my_biasパラメーターと,linearモジュールのパラメーターが列挙されていることがわかります.
一方で,インスタンス変数not_paramはパラメーターとして登録していないので,列挙されていないということがわかります.

やりそうな間違い

モデルの層を可変にしたいと思い,ちょっと凝ったをやろうとすると次のようにやってしまいがちです.

import torch
import torch.nn as nn

class NGMultiLinear(nn.Module):
    def __init__(self, input_size, layer_num):
        super(NGMultiLinear, self).__init__()
        self.linears = [nn.Linear(input_size, input_size) for _ in range(layer_num)]

    def forward(self, x):
        for linear in self.linears:
            x = torch.relu(linear(x))
        return x

こうしてしまった場合,forwardはうまくいきます.

model = NGMultiLinear(5, 5)
print(model(torch.randn(5)))

しかし,model.parameters() でパラメーターを取り出すことはできません.

for n, p in model.named_parameters():
    print(n)
    print(p)
    print("--------")

# => 出力なし

これは,インスタンス変数self.linearsをセットするときにlistとしてセットしていて,ParameterでもModuleでもないので,無視されてしまうという訳ですね.

逆に言えば,パラメーター学習の対象としたい場合は,必ずParameterModuleとしてインスタンス変数に登録しなければならないという訳です.

上の例のような場合では,nn.ModuleListを用います.

class MultiLinear(nn.Module):
    def __init__(self, input_size, layer_num):
        super(MultiLinear, self).__init__()
        self.linears = nn.ModuleList([nn.Linear(input_size, input_size) for _ in range(layer_num)])

    def forward(self, x):
        for linear in self.linears:
            x = torch.relu(linear(x))
        return x

listで代入していた部分を,nn.ModuleListの引数に渡してあげるだけで済みます.

model = MultiLinear(5, 2)

for n, p in model.named_parameters():
    print(n)
    print(p)
    print("--------")

各層のパラメーターが出力されることがわかります.

linears.0.weight
Parameter containing:
tensor([[-0.1373, -0.1944, -0.3631, -0.0286, -0.3060],
        [ 0.0248, -0.0721,  0.2569, -0.0984,  0.1653],
        [ 0.0190,  0.3214,  0.3876,  0.0626,  0.1884],
        [ 0.3648, -0.1109,  0.0741, -0.0240,  0.1980],
        [ 0.0940, -0.4317,  0.0085,  0.1093, -0.3265]], requires_grad=True)
--------
linears.0.bias
Parameter containing:
tensor([-0.2768, -0.2334, -0.3220, -0.2035,  0.4458], requires_grad=True)
--------
linears.1.weight
Parameter containing:
tensor([[ 0.4232,  0.4181,  0.0354,  0.2646, -0.4470],
        [-0.4198, -0.3613, -0.3049, -0.3048,  0.2506],
        [ 0.3719, -0.3787,  0.0460,  0.3418,  0.0614],
        [ 0.2464,  0.1336,  0.3618, -0.0931,  0.3158],
        [ 0.3431,  0.2322,  0.4372, -0.3667,  0.2039]], requires_grad=True)
--------
linears.1.bias
Parameter containing:
tensor([-0.0262,  0.0904, -0.0037, -0.1591, -0.2842], requires_grad=True)
--------

なお,nn.ParameterList や,Dict版のnn.ModuleDictもあるようです.

まとめ

実用的には,

学習の対象パラメーターは,必ずParameterModuleとしてインスタンス変数に登録しなければならない

ということに尽きます.これを確かめるためにソースコードを読んでみました.

まあ自分が上記のやりがちな間違いをやってしまったので,これを機に理解を深めていたわけです.