端の知識の備忘録

技術メモになりきれない、なにものか達の供養先

【PyTorch小ネタ】 複数モデルを組み合わせたモデルにおいて別々のPretrained Weightを読み込む方法

まえがき

時の流れは早いもので、もう2024年になってしまったようです。

最近はKaggleやら実践的な個人活動は一先ずお休みし、社会人5年目になる前に一旦基礎固めし直そうと線形代数/解析学/統計の勉強をしたり新しく出たBishop本を読んだりしており、あまりアウトプットするものがありませんでした。

bishopbook.com

物買ったネタもEPYC買ったりオールフラッシュNASを作ったり好き勝手した結果欲しい物が無くなってしまったので、本当にブログに書く内容がない状態。

とはいえたまには書いとかないとアレなので、ちょっとした小ネタで更新しておきます。

最近ちょっとあったケースで、既存の2つのモデルのEmbeddingモデルの出力をくっつけて、FC層に流して何らかの出力を得るようなモデルを作りたいという相談を受けました。

例えば、画像とキャプションテキストから記事のジャンルを分類する課題を考えたとき、画像に関してはEfficientNetで、テキストはBERTで埋め込みを取得して、その出力をconcatして分類するみたいな。

こういうとき、既存のPretrained Weightを正しく読み込むのにどうすればいいのか、PyTorchのStatedictをいじりながら見ていきます。

たまにモデルのパラメータを直接さわりたいときがあると思いますが、意外と柔軟に弄れるので経験しておくと役に立つかもしれません。この辺頑張ると下の記事みたいな応用例もあるかも。

logmi.jp

まとめ

下ではうだうだと例を以て説明を書いておりますが、モデル内の一部でウェイトを読み込みたい場合、モデルのインスタンス変数として対象のレイヤーを呼び出し、load_state_dictメソッドを利用します。

model.embedding1.load_state_dict(state_dict_a)

メソッドの公式ソースは下

pytorch.org

モデル定義

デモ用に簡単な2つのモデルを作ります。若干出力サイズだけ変えてありますがほぼ同じモデルです。

実際には片方が画像用、もう片方がテキスト用のモデルみたいなケースを想像しておいてください

import torch
import torch.nn as nn
import torch.nn.functional as F

class ModelA(nn.Module):
    def __init__(self):
        super(ModelA, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(6 * 13 * 13, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 10)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = x.view(-1, 6 * 13 * 13)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    
class ModelB(nn.Module):
    def __init__(self):
        super(ModelB, self).__init__()
        self.conv1 = nn.Conv2d(3, 6, 5)
        self.pool = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(6 * 13 * 13, 120)
        self.fc2 = nn.Linear(120, 84)
        self.fc3 = nn.Linear(84, 20)

    def forward(self, x):
        x = self.pool(F.relu(self.conv1(x)))
        x = x.view(-1, 6 * 13 * 13)
        x = F.relu(self.fc1(x))
        x = F.relu(self.fc2(x))
        x = self.fc3(x)
        return x
    

で、この2つのモデルを結合して利用するモデルを次のように定義します。

ここで設定するインスタンス変数名のembedding1embedding2が後々大事になります。

class CombinedModel(nn.Module):
    def __init__(self):
        super(CombinedModel, self).__init__()
        self.embedding1 = ModelA()
        self.embedding2 = ModelB()

        self.activation = nn.ReLU()
        self.fc1 = nn.Linear(30, 10)
        self.fc2 = nn.Linear(10, 1)

    def forward(self, x, y):
        x = self.embedding1(x) # [bs, 10]
        y = self.embedding2(y) # [bs, 20]

        z = torch.cat((x, y), 1) # [bs, 30]

        z = self.activation(self.fc1(z))
        z = self.fc2(z)
        return z

パラメータの操作

ModelAとModelBに関して、Weightを手動で割り当てるためにstate dictを作っていきます。

名前の通り、state dictはdict型でモデルの重みやOptimizerなどの状態を保存します。

今回は重みのみを保存しますので、単にLayer名をキーとして重みをバリューに持つDictを用意します。

本来は学習によって有意義な重みを学習するところですが、今回はシンプルにmodel_aは全部0,model_bは全部1の重みを設定します。

model_a = ModelA()
model_b = ModelB()

state_dict_a = {}
for param_tensor in model_a.state_dict():
    print(param_tensor, "\t", model_a.state_dict()[param_tensor].size())
    state_dict_a[param_tensor] = torch.zeros(model_a.state_dict()[param_tensor].size())

state_dict_b = {}
for param_tensor in model_b.state_dict():
    print(param_tensor, "\t", model_b.state_dict()[param_tensor].size())
    state_dict_b[param_tensor] = torch.ones(model_b.state_dict()[param_tensor].size())

printの出力は次のようになります。

conv1.weight     torch.Size([6, 3, 5, 5])
conv1.bias   torch.Size([6])
fc1.weight   torch.Size([120, 1014])
fc1.bias     torch.Size([120])
fc2.weight   torch.Size([84, 120])
fc2.bias     torch.Size([84])
fc3.weight   torch.Size([10, 84])
fc3.bias     torch.Size([10])
conv1.weight     torch.Size([6, 3, 5, 5])
conv1.bias   torch.Size([6])
fc1.weight   torch.Size([120, 1014])
fc1.bias     torch.Size([120])
fc2.weight   torch.Size([84, 120])
fc2.bias     torch.Size([84])
fc3.weight   torch.Size([20, 84])
fc3.bias     torch.Size([20])

作成したstate dictをモデルに読み込ませるには、load_state_dictメソッドを利用します。自前で作ったDictでも、ちゃんとキーとバリューのShapeが合っていれば読み込んでくれます。

また、保存の際にはtorch.saveを利用します。

model_a.load_state_dict(state_dict_a)
model_b.load_state_dict(state_dict_b)

torch.save(model_a.state_dict(), 'model_a.pth')
torch.save(model_b.state_dict(), 'model_b.pth')

CombinedModelにそれぞれの子モデルの重みを読み込む

CombinedModelのインスタンスを作成し、先程保存した.pthファイルを改めて読み込みます。

model = CombinedModel()
state_dict_a = torch.load('model_a.pth')
state_dict_b = torch.load('model_b.pth')

で、本題である子モデルの重みを読み込む方法ですが、単にモデルのインスタンス変数を読み出し、load_state_dictメソッドを利用します。

model.embedding1.load_state_dict(state_dict_a)
model.embedding2.load_state_dict(state_dict_b)

ちゃんと重みが更新されたかどうかを見てみます。embedding1には0、embedding1には1が格納されており、特に重みをロードしていないfc1などではランダムな値が入っており、意図通りの操作ができたことが確認できます。

print(next(model.embedding1.parameters())[0][:1])
print(next(model.embedding2.parameters())[0][:1])
print(next(model.fc1.parameters())[0])
tensor([[[0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.],
         [0., 0., 0., 0., 0.]]], device='cuda:0', grad_fn=<SliceBackward0>)
tensor([[[1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.],
         [1., 1., 1., 1., 1.]]], device='cuda:0', grad_fn=<SliceBackward0>)
tensor([-0.0896, -0.0968, -0.1777, -0.1720, -0.1266, -0.0351,  0.0431, -0.1412,
         0.1106,  0.0788, -0.0594,  0.1563,  0.0192, -0.1271,  0.0526, -0.1182,
        -0.0506, -0.1199, -0.0888,  0.0479, -0.0823,  0.1511,  0.1532,  0.0241,
         0.0064,  0.1728,  0.0547,  0.0753, -0.0233,  0.0545], device='cuda:0',
       grad_fn=<SelectBackward0>)

公式の説明は下。nn.Moduleのパラメータ以外、Optimizerなどのstate dict読み込みなどに関しても言及があります。

pytorch.org