あかすくぱるふぇ

同人サークル「あかすくぱるふぇ」のブログです。

chainer-goghで画像を更新する仕組みについて解説します。
chainer-goghは沙耶の唄VRの映像を作った時に、肉塊画風の転写に利用させていただきました。
http://akasuku.blog.jp/archives/68121368.html

なお、本記事ではchainer-goghの概要については説明しませんので、概要がわからない方はまず元記事をお読みください。
https://research.preferred.jp/2015/09/chainer-gogh/

chainer-goghと比較するために、まずは通常のニューラルネット学習の例としてtrain_mnistでネットワークを更新する仕組みについて触れます。
train_mnistでは、以下のような手順でネットワークを更新します。
model = L.Classifier(MLP(args.unit, 10)) # モデル(最適化対象パラメータ)
optimizer = chainer.optimizers.Adam() # 最適化モジュール
optimizer.setup(model) # モデルのセット
optimizer.update(loss_func) # 更新
ここで、optimizer.update()は、trainer.run() → updater.update() → optimizer.update()という流れで呼び出されます。
loss_funcにはmodelが代入されています。

optimizer.update()の中では以下のように、順伝搬・逆伝搬・パラメータ更新の処理がなされます。
if lossfun is not None:
loss = lossfun() # 順伝搬
loss.backward() # 逆伝搬
for param in self.target.params():
param.update() # パラメータ更新

以上が、通常のニューラルネット学習の手順です。
続いて、chainer-goghで画像を更新する仕組みについて解説します。

chainer-goghでは、以下のように、モデルの代わりに(更新対象の)画像をoptimizerにセットします。
img_gen = xp.random.uniform(-20,20,(1,3,width,width),dtype=np.float32)
img_gen = chainer.links.Parameter(img_gen) # 画像の更新対象パラメータ化
optimizer.setup(img_gen) # 画像のセット
そして、以下の手順で画像の更新を行います。
nn = VGG() # モデル(chainer-goghではこいつは更新しない)
y = nn.forward(img_gen.W) # 順伝搬
L = ... * F.mean_squared_error(y, Variable(...)) # 順伝搬
L.backward() # 逆伝搬
optimizer.update() # パラメータ更新
ここで重要なのが、train_mnistの場合と異なり、optimizer.update()を引数loss_funcなしで呼び出している点です。
上記したように、optimizer.update()内ではif文でloss_funcが与えられた場合のみ順伝搬・逆伝搬の計算をするようになっています。
このような仕組みによって、chainer-goghでは、モデルに画像を入力して得られる勾配を基に(モデルではなく)画像を更新するという処理を実現しています。

以上が、chainer-goghで画像を更新する仕組みです。

one two three four looking for the beat
one two three four rocking your heart
tight beat 刻んだ dance on a stage
限界無理矢理突破気味
迷える魂 ramble soul
震える瞬間 Possible time
リアルの体感追い越すClick
CoolなStanding 進めよ World

ChainerのTrainer内で、printによるログとファイル出力によるログを出力することができます。
実際に、train_mnist.pyなどでは、この仕組みを用いて、これら2種類のログを出力しています。
これらのうち、printによるログに関わるコードのみを抽出し、それぞれの命令がどのログ出力に関連しているのかを解析してみました。

train_mnist.pyから、printによるログに関わるコードのみを抽出した結果は以下の通りです。
# プログレスバー
trainer.extend(extensions.ProgressBar())

# 訓練ログ
trainer.extend(extensions.LogReport())
trainer.extend(extensions.PrintReport(
['epoch', 'main/loss', 'validation/main/loss',
'main/accuracy', 'validation/main/accuracy', 'elapsed_time']))

# テストデータのログ出力準備
trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu))

コメントにある通り、それぞれ、プログレスバー、訓練ログ、テストデータのログ出力準備のための命令です。
trainer.extend(extensions.LogReport())を呼ばないと、それ以下の2命令がエラーになります。
また、trainer.extend(extensions.Evaluator())を呼ばないと、'validation/main/loss'と'validation/main/accuracy’の結果が出力されなくなります。

↑このページのトップヘ