Tensorflow 모델 저장하기
Tensorflow를 활용한 모델 저장 방법
데이터 과학자나 개발자라면, 모델을 훈련시키는 것만큼이나 중요한 것이 훈련된 모델을 저장하고 불러오는 것입니다. 이 기능을 활용하면, 한 번 훈련된 모델을 다시 사용하거나 다른 사람과 공유할 수 있습니다. 이번 포스트에서는 TensorFlow를 활용해 모델을 저장하고 불러오는 방법에 대해 알아보겠습니다.
TensorFlow에서 모델 저장하기: Checkpoint 사용
TensorFlow에서는 Checkpoint라는 기능을 이용해 모델을 저장하고 불러올 수 있습니다. Checkpoint는 모델의 변수들을 특정 파일에 저장하는 방식으로 동작합니다.
1
saver = tf.train.Saver()
모델을 저장하는 방법은 다음과 같습니다.
1
2
3
4
5
6
checkpoint_dir = "YOUR_DIRECTORY"
if not os.path.exists(checkpoint_dir):
os.makedirs(checkpoint_dir)
saver.save(sess, os.path.join(checkpoint_dir, "YOUR_FILENAME.ckpt"), global_step=0)
TensorFlow에서 모델 불러오기: Checkpoint 사용
1
2
3
4
5
6
7
8
9
10
ckpt = tf.train.get_checkpoint_state(checkpoint_dir)
if ckpt and ckpt.model_checkpoint_path:
ckpt_name = os.path.basename(ckpt.model_checkpoint_path)
saver.restore(sess, os.path.join(checkpoint_dir, ckpt_name))
counter = int(next(re.finditer("(\d+)(?!.*\d)", ckpt_name)).group(0))
print("Success")
else:
print("Failed")
Tensorflow에서 그래프 저장하기
모델이 아닌 그래프를 저장하고 싶다면, 다음과 같이 할 수 있습니다.
1
tf.train.write_graph(sess.graph_def, '.', 'graph.pbtxt')
TensorFlow에서 .pb 파일 저장하기
마지막으로, TensorFlow에서는 .pb (protobuf) 형식으로 그래프를 저장할 수 있습니다.
1
2
3
4
from tensorflow.python.framework import graph_io
frozen = tf.graph_util.convert_variables_to_constants(sess, sess.graph_def, ["output_node_name"])
graph_io.write_graph(frozen, './', 'graph.pb', as_text=False)
- output node name을 아는 것이 중요하다.
- tensorboard에서 찾아보기
TensorFlow에서 TensorBoard 사용하기
TensorBoard는 TensorFlow에서 제공하는 시각화 도구입니다. 훈련 과정에서 중요한 지표들을 시각화해줍니다.
1
2
tf.summary.scalar("loss",loss)
merge = tf.summary.merge_all()
TensorBoard에 로그를 저장하려면 다음과 같이 합니다.
1
2
with tf.Session() as sess:
writer = tf.summary.FileWriter('./log/', sess.graph)
그리고 이렇게 저장된 로그를 TensorBoard에서 확인할 수 있습니다.
1
tensorboard --logdir=./logs/
This post is licensed under CC BY 4.0 by the author.