import numpy as np
import random
import json
from glob import glob
from keras.models import model_from_json,load_model
from keras.preprocessing.image import ImageDataGenerator
from keras.callbacks import ModelCheckpoint,Callback,LearningRateScheduler
import keras.backend as K
from model import Unet_model
from losses import *
#from keras.utils.visualize_util import plot
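
# Callback that manually decays the SGD schedule: at the start of every epoch the
# learning rate is divided by 10 and the optimizer decay is multiplied by 10.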
class SGDLearningRateTracker(Callback):

    def on_epoch_begin(self, epoch, logs=None):
        optimizer = self.model.optimizer
        lr = K.get_value(optimizer.lr)
        decay = K.get_value(optimizer.decay)
        lr = lr / 10
        decay = decay * 10
        K.set_value(optimizer.lr, lr)
        K.set_value(optimizer.decay, decay)
        print('LR changed to:', lr)
        print('Decay changed to:', decay)


class Training(object):

    def __init__(self, batch_size, nb_epoch, load_model_resume_training=None):
        self.batch_size = batch_size
        self.nb_epoch = nb_epoch
        # load model from path to resume previous training without recompiling the whole model
        if load_model_resume_training is not None:
            self.model = load_model(load_model_resume_training,
                                    custom_objects={'gen_dice_loss': gen_dice_loss,
                                                    'dice_whole_metric': dice_whole_metric,
                                                    'dice_core_metric': dice_core_metric,
                                                    'dice_en_metric': dice_en_metric})
            print("pre-trained model loaded!")
        else:
            unet = Unet_model(img_shape=(128, 128, 4))
            self.model = unet.model
            print("U-net CNN compiled!")

    def fit_unet(self, X33_train, Y_train, X_patches_valid=None, Y_labels_valid=None):
        train_generator = self.img_msk_gen(X33_train, Y_train, 9999)
        # the 'brain_segmentation/' directory must already exist, otherwise ModelCheckpoint fails on the first save
        checkpointer = ModelCheckpoint(filepath='brain_segmentation/ResUnet.{epoch:02d}_{val_loss:.3f}.hdf5', verbose=1)
        self.model.fit_generator(train_generator, steps_per_epoch=len(X33_train) // self.batch_size, epochs=self.nb_epoch,
                                 validation_data=(X_patches_valid, Y_labels_valid), verbose=1,
                                 callbacks=[checkpointer, SGDLearningRateTracker()])
        #self.model.fit(X33_train, Y_train, epochs=self.nb_epoch, batch_size=self.batch_size, validation_data=(X_patches_valid, Y_labels_valid), verbose=1, callbacks=[checkpointer, SGDLearningRateTracker()])

    def img_msk_gen(self, X33_train, Y_train, seed):
        '''
        a custom generator that performs the same data augmentation on both patches and their corresponding targets (masks)
        '''
        datagen = ImageDataGenerator(horizontal_flip=True, data_format="channels_last")
        datagen_msk = ImageDataGenerator(horizontal_flip=True, data_format="channels_last")
        # the shared seed keeps the patch and mask augmentations identical
        image_generator = datagen.flow(X33_train, batch_size=self.batch_size, seed=seed)
        y_generator = datagen_msk.flow(Y_train, batch_size=self.batch_size, seed=seed)
        while True:
            yield next(image_generator), next(y_generator)
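
    # Illustrative sketch only (not part of the training flow): a paired augmented batch
    # can be pulled from the generator like this:
    #   gen = self.img_msk_gen(X33_train, Y_train, seed=9999)
    #   x_batch, y_batch = next(gen)   # identically flipped patches and masks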

    def save_model(self, model_name):
        '''
        INPUT string 'model_name': path where to save the model and weights, without extension
        Saves the current architecture as a JSON file and the weights as an HDF5 file
        '''
        model_tosave = '{}.json'.format(model_name)
        weights = '{}.hdf5'.format(model_name)
        json_string = self.model.to_json()
        self.model.save_weights(weights)
        with open(model_tosave, 'w') as f:
            json.dump(json_string, f)
        print('Model saved.')
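
    # Example with a hypothetical path: self.save_model('brain_segmentation/ResUnet_final')
    # writes 'brain_segmentation/ResUnet_final.json' (architecture) and
    # 'brain_segmentation/ResUnet_final.hdf5' (weights).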

    def load_model(self, model_name):
        '''
        Load a model previously saved with save_model
        INPUT (1) string 'model_name': filepath to the model and weights, not including extension
        OUTPUT: Keras model with loaded weights; it also replaces self.model
        '''
        print('Loading model {}'.format(model_name))
        model_toload = '{}.json'.format(model_name)
        weights = '{}.hdf5'.format(model_name)
        with open(model_toload) as f:
            m = f.read()
        model_comp = model_from_json(json.loads(m))
        model_comp.load_weights(weights)
        print('Model loaded.')
        self.model = model_comp
        return model_comp
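

# Round-trip sketch (hypothetical base path; assumes the model was saved earlier with save_model):
#   trainer = Training(batch_size=4, nb_epoch=3)
#   trainer.load_model('models/ResUnet_final')   # reads ResUnet_final.json + ResUnet_final.hdf5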


if __name__ == "__main__":
    # set arguments
    # reload an already trained model to resume training; set to None to compile a fresh U-net
    model_to_load = "Models/ResUnet.04_0.646.hdf5"
    #save=None
    # compile the model
    brain_seg = Training(batch_size=4, nb_epoch=3, load_model_resume_training=model_to_load)
    print("number of trainable parameters:", brain_seg.model.count_params())
    #print(brain_seg.model.summary())
    #plot(brain_seg.model, to_file='model_architecture.png', show_shapes=True)
    # load data from disk
    Y_labels = np.load("y_training.npy").astype(np.uint8)
    X_patches = np.load("x_training.npy").astype(np.float32)
    Y_labels_valid = np.load("y_valid.npy").astype(np.uint8)
    X_patches_valid = np.load("x_valid.npy").astype(np.float32)
    print("loading patches done\n")
    # fit model
    brain_seg.fit_unet(X_patches, Y_labels, X_patches_valid, Y_labels_valid)
    #if save is not None:
    #    brain_seg.save_model('models/' + save)