S2V-DQNコードリーディング
https://xuzijian629.hatenadiary.jp/entry/2019/07/27/154356
以前環境構築についてまとめた
今回は実装を読んでいく。
MVCについて読むが多分構造は他も一緒
構造
eps_start = 1.0 eps_end = 0.05 eps_step = 10000.0 for iter in range(int(opt['max_iter'])): if iter and iter % 5000 == 0: gen_new_graphs(opt) eps = eps_end + max(0., (eps_start - eps_end) * (eps_step - iter) / eps_step) if iter % 10 == 0: api.lib.PlayGame(10, ctypes.c_double(eps)) if iter % 300 == 0: frac = 0.0 for idx in range(n_valid): frac += api.lib.Test(idx) print 'iter', iter, 'eps', eps, 'average size of vc: ', frac / n_valid sys.stdout.flush() model_path = '%s/nrange_%d_%d_iter_%d.model' % (opt['save_dir'], int(opt['min_n']), int(opt['max_n']), iter) api.SaveModel(model_path) if iter % 1000 == 0: api.TakeSnapshot() api.lib.Fit()
流れはこんな感じ。もろもろの関数はmvc_lib.cpp
にある。
- 5000 iterおきにグラフをgen_new_graphsをしている。内部的にはグラフのプールを更新している。プールには1000個グラフがある(main.py)
- epsilon greedyのepsをだんだん小さくしているっぽい
- 10 iterおきにPlayGameしている。10回最初からterminal stateまで実行して、結果の列をNStepReplayMemに格納する。毎回グラフをプールからサンプルするっぽい。
- 1000 iterおきにSnapShotをとっている。これが実は本質っぽいんだけど、lossの計算ではこの記事の一番下に書いてあるように、Snapshotをとったold_modelと、新しいmodelでの2つの予測結果の二乗誤差を考えている
- 毎iterでFitしている。これは、batch_sizeサンプルしてきて、
nn_api.cpp
のFitを呼んでいる。
net->SetupTrain(batch_idxes, g_list, covered, actions, target); net->fg.FeedForward({net->loss}, net->inputs, Phase::TRAIN); net->fg.BackPropagate({net->loss}); net->learner->Update(); loss += net->loss->AsScalar() * bsize;
みたいなことが行われている。誤差の計算のところはmvc_lib.cpp
にあって
PredictWithSnapshot(sample.g_list, sample.list_s_primes, list_pred);
からの
for (int i = 0; i < cfg::batch_size; ++i) { double q_rhs = 0; if (!sample.list_term[i]) q_rhs = max(sample.g_list[i]->num_nodes, list_pred[i]->data()); q_rhs += sample.list_rt[i]; list_target[i] = q_rhs; }
が行われている。PredictWithSnapshot
は古いモデルでの予測結果っぽい。
のところ。