Joeの精進記録

旧:競プロ練習記録

S2V-DQNコードリーディング

https://xuzijian629.hatenadiary.jp/entry/2019/07/27/154356

以前環境構築についてまとめた

今回は実装を読んでいく。

MVCについて読むが多分構造は他も一緒

構造

eps_start = 1.0
eps_end = 0.05
eps_step = 10000.0
for iter in range(int(opt['max_iter'])):
    if iter and iter % 5000 == 0:
        gen_new_graphs(opt)
    eps = eps_end + max(0., (eps_start - eps_end) * (eps_step - iter) / eps_step)
    if iter % 10 == 0:
        api.lib.PlayGame(10, ctypes.c_double(eps))

    if iter % 300 == 0:
        frac = 0.0
        for idx in range(n_valid):
            frac += api.lib.Test(idx)
        print 'iter', iter, 'eps', eps, 'average size of vc: ', frac / n_valid
        sys.stdout.flush()
        model_path = '%s/nrange_%d_%d_iter_%d.model' % (opt['save_dir'], int(opt['min_n']), int(opt['max_n']), iter)
        api.SaveModel(model_path)

    if iter % 1000 == 0:
        api.TakeSnapshot()

    api.lib.Fit()

流れはこんな感じ。もろもろの関数はmvc_lib.cppにある。

  • 5000 iterおきにグラフをgen_new_graphsをしている。内部的にはグラフのプールを更新している。プールには1000個グラフがある(main.py)
  • epsilon greedyのepsをだんだん小さくしているっぽい
  • 10 iterおきにPlayGameしている。10回最初からterminal stateまで実行して、結果の列をNStepReplayMemに格納する。毎回グラフをプールからサンプルするっぽい。
  • 1000 iterおきにSnapShotをとっている。これが実は本質っぽいんだけど、lossの計算ではこの記事の一番下に書いてあるように、Snapshotをとったold_modelと、新しいmodelでの2つの予測結果の二乗誤差を考えている
  • iterでFitしている。これは、batch_sizeサンプルしてきて、nn_api.cppのFitを呼んでいる。
net->SetupTrain(batch_idxes, g_list, covered, actions, target);
net->fg.FeedForward({net->loss}, net->inputs, Phase::TRAIN);
net->fg.BackPropagate({net->loss});
net->learner->Update();

loss += net->loss->AsScalar() * bsize;

みたいなことが行われている。誤差の計算のところはmvc_lib.cppにあって

PredictWithSnapshot(sample.g_list, sample.list_s_primes, list_pred);

からの

for (int i = 0; i < cfg::batch_size; ++i)
{
    double q_rhs = 0;
    if (!sample.list_term[i])
        q_rhs = max(sample.g_list[i]->num_nodes, list_pred[i]->data());
    q_rhs += sample.list_rt[i];
    list_target[i] = q_rhs;
}

が行われている。PredictWithSnapshotは古いモデルでの予測結果っぽい。

のところ。