-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathdraw_chart.py
38 lines (32 loc) · 1.39 KB
/
draw_chart.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
# -*- coding: utf-8 -*-
import matplotlib.pyplot as plt
import argparse
from save_load_util import load_metrics
def get_parser():
    """Build the command-line parser for the loss-chart script.

    Positional argument:
        metrics_path: path of the metrics file to read.

    Options:
        -x/--x-axis-str: label for the x-axis (default ``'Epochs'``).
        -y/--y-axis-str: label for the y-axis (default ``'Loss'``).
        -s/--saving-path: where to save the chart (default ``None``); when
            omitted the chart is shown on screen instead of being saved.

    Returns:
        argparse.ArgumentParser: the configured parser.
    """
    parser = argparse.ArgumentParser(
        description='draw the training loss chart from a metrics file')
    parser.add_argument('metrics_path', type=str,
                        help='specify the path of the metrics file')
    parser.add_argument('-x', '--x-axis-str', type=str, default='Epochs',
                        help='specify the string shown on the x-axis')
    parser.add_argument('-y', '--y-axis-str', type=str, default='Loss',
                        help='specify the string shown on the y-axis')
    parser.add_argument('-s', '--saving-path', type=str, default=None,
                        help='specify the saving path of the loss chart '
                             '(the chart is shown on screen if omitted)')
    return parser
def draw_loss_chart(metrics_file_path, x_axis_str, y_axis_str, saving_path):
    """Plot the train and validation loss curves stored in a metrics file.

    Args:
        metrics_file_path: path handed to ``load_metrics``. That call is
            unpacked as ``(train_loss_list, valid_loss_list,
            global_steps_list)``.
        x_axis_str: label for the x-axis.
        y_axis_str: label for the y-axis.
        saving_path: if truthy, the chart is written to this path;
            otherwise it is displayed with ``plt.show()``.
    """
    train_loss_list, valid_loss_list, global_steps_list = load_metrics(metrics_file_path)

    # Draw on a fresh figure. The bare pyplot calls would otherwise reuse
    # whatever figure is current, so a second call in the same process would
    # stack the previous curves onto the new chart.
    fig, ax = plt.subplots()
    ax.plot(global_steps_list, train_loss_list, label='Train')
    ax.plot(global_steps_list, valid_loss_list, label='Valid')
    ax.set_xlabel(x_axis_str)
    ax.set_ylabel(y_axis_str)
    ax.legend()

    if saving_path:
        fig.savefig(saving_path)
        # pyplot keeps every figure alive until it is closed; release this
        # one once it is on disk so repeated calls do not accumulate figures.
        plt.close(fig)
    else:
        plt.show()
def main(args):
    """Draw the loss chart using the options parsed from the command line.

    Args:
        args: namespace produced by ``get_parser().parse_args()``. It must
            provide ``metrics_path``, ``x_axis_str``, ``y_axis_str`` and
            ``saving_path``.
    """
    draw_loss_chart(
        metrics_file_path=args.metrics_path,
        x_axis_str=args.x_axis_str,
        y_axis_str=args.y_axis_str,
        saving_path=args.saving_path,
    )
if __name__ == '__main__':
    # Script entry point: parse the command line and hand the options to main().
    main(get_parser().parse_args())