텐서플로우가 업데이트가 되면서 많은것들이 바뀌었습니다.
기본 메서드들의 이름도 바뀌고, 다른 용법들도 많이 바뀌고 있습니다.
너무나 빠른 변화가 느껴집니다.
이번에는 학습 파라미터를 저장하고 그것을 불러오는 코드를 정리해보겠습니다.
1. 먼저 저장할 파라미터를 정합니다.
2. 그것들을 list 형식으로 만들어 줍니다.
3. tf.train.saver(list) 로 텐서플로우에게 이것이 저장할 변수라는 것을 알려줍니다.
>>코드 :
#Create a saver
param_list = [W_h1, b_h1, W_h2, b_h2, W_o, b_o]
saver = tf.train.Saver(param_list)
4. 학습할때마다 저장을 하도록 합니다.
for i in range(1000):
_,loss_, acc = sess.run([train,cost, accuracy],feed_dict = tensor_map)
if i % 100 == 0:
saver.save(sess,'./tensorflow_live.ckpt')
print("step: ",i)
print("loss_: ",loss_)
print("accuracy: ", acc)
print("============")
이렇게 해서 텐서플로우의 파라미터들을 저장해서 다음에 불러올 수 있습니다.
그러면 나중에 불러올때는 어떻게 해야될지 알아보겠습니다.
<Parameter restore>
1. 세션을 먼저 열어줍니다. tf.Session()
2. 변수들을 초기화 시켜줍니다. tf.global_variables_initializer()
3. 가져올 파일을 import를 시켜줍니다. tf.train.import.meta_graph('/파일이름.meta')
텐서플로우가 업데이트 되면서 파라미터들이 meta파일에 저장이 됩니다.
그것을 불러주면 해결이 됩니다.
4. 그리고 import한 파일을 restore시킵니다. new_saver.restore(sess, tf.train.latest_checkpoint('./'))
sess = tf.Session()
sess.run(tf.global_variables_initializer())
new_saver = tf.train.import_meta_graph('tensorflow_live.ckpt.meta')
new_saver.restore(sess, tf.train.latest_checkpoint('./'))
이렇게 하면 파라미터들을 저장하고 다시 읽어 올수가 있습니다.
시간을 이렇게 아껴요!
'딥러닝' 카테고리의 다른 글
[딥러닝]Wasserstein distance 에 관하여 (2) | 2017.08.28 |
---|---|
[딥러닝] 우분투 16.04에서 pyCuda 설치 (0) | 2017.08.09 |
[딥러닝] Tensorflow 윈도우10에 설치하기 (0) | 2017.01.26 |
[딥러닝] 2007년 인공지능에 관한 테드영상 (0) | 2017.01.23 |
[이찬우님 Tensorflow] 1강. Tensorflow의 자료형 (0) | 2016.10.31 |