-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathsave_load_util.py
39 lines (33 loc) · 1.33 KB
/
save_load_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
# -*- coding: utf-8 -*-
import torch
import os
# Default device for torch.load's map_location: the CUDA device when one is
# available, otherwise the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Save and Load Functions
def save_model(save_path, model, valid_loss):
    """Save a model's weights together with a validation loss value.

    The checkpoint written by ``torch.save`` is a dict with the keys
    ``'model_state_dict'`` (the result of ``model.state_dict()``) and
    ``'valid_loss'``.

    Args:
        save_path: Destination file path. If ``None``, the call is a no-op.
        model: Object exposing ``state_dict()`` (e.g. a ``torch.nn.Module``).
        valid_loss: Value stored under ``'valid_loss'`` in the checkpoint.

    Returns:
        None.
    """
    # Identity check: a None path means "do not save".
    if save_path is None:
        return
    state_dict = {'model_state_dict': model.state_dict(),
                  'valid_loss': valid_loss}
    torch.save(state_dict, save_path)
    print(f'Model saved to ==> {save_path}')
def load_model(load_path, model, map_location=None):
    """Load a checkpoint in the ``save_model`` format into *model*.

    Args:
        load_path: Checkpoint file path. If ``None``, nothing is loaded.
        model: Object with ``load_state_dict()``. It receives the
            checkpoint's ``'model_state_dict'`` entry.
        map_location: Forwarded to ``torch.load``. When ``None`` (the
            default), the module-level ``device`` is used, which preserves
            the original behavior.

    Returns:
        The checkpoint's ``'valid_loss'`` value, or ``None`` when
        ``load_path`` is ``None``.

    Raises:
        KeyError: If the loaded dict lacks ``'model_state_dict'`` or
            ``'valid_loss'``.
    """
    if load_path is None:
        return None
    if map_location is None:
        map_location = device
    # NOTE(review): torch.load unpickles the file; only load checkpoints
    # from trusted sources.
    state_dict = torch.load(load_path, map_location=map_location)
    model.load_state_dict(state_dict['model_state_dict'])
    print(f'Model loaded from <== {load_path}')
    return state_dict['valid_loss']
def save_metrics(save_path, train_loss_list, valid_loss_list, global_steps_list):
    """Save training-curve data with ``torch.save``.

    The saved object is a dict with the keys ``'train_loss_list'``,
    ``'valid_loss_list'`` and ``'global_steps_list'``, holding the
    corresponding arguments unchanged.

    Args:
        save_path: Destination file path. If ``None``, the call is a no-op.
        train_loss_list: Training losses to store.
        valid_loss_list: Validation losses to store.
        global_steps_list: Step values to store. Presumably aligned with the
            loss lists; this is not checked here.

    Returns:
        None.
    """
    # Identity check: a None path means "do not save".
    if save_path is None:
        return
    state_dict = {'train_loss_list': train_loss_list,
                  'valid_loss_list': valid_loss_list,
                  'global_steps_list': global_steps_list}
    torch.save(state_dict, save_path)
    print(f'Metrics saved to ==> {save_path}')
def load_metrics(load_path, map_location=None):
    """Load training-curve data saved in the ``save_metrics`` format.

    Args:
        load_path: File path to read. If ``None``, nothing is loaded.
        map_location: Forwarded to ``torch.load``. When ``None`` (the
            default), the module-level ``device`` is used, which preserves
            the original behavior.

    Returns:
        A tuple ``(train_loss_list, valid_loss_list, global_steps_list)``,
        or ``None`` when ``load_path`` is ``None``. Callers that unpack the
        result should handle the ``None`` case.

    Raises:
        KeyError: If any of the three expected keys is missing.
    """
    if load_path is None:
        return None
    if map_location is None:
        map_location = device
    # NOTE(review): torch.load unpickles the file; only load files from
    # trusted sources.
    state_dict = torch.load(load_path, map_location=map_location)
    print(f'Metrics loaded from <== {load_path}')
    return (state_dict['train_loss_list'],
            state_dict['valid_loss_list'],
            state_dict['global_steps_list'])