Skip to content

Commit 4d12c41

Browse files
committed
updated model
1 parent ee2aef1 commit 4d12c41

8 files changed

+235
-127
lines changed

README.md

+11-1
Original file line numberDiff line numberDiff line change
@@ -57,7 +57,17 @@ any image file as an input.
5757
Download and compile the tool [dataset_transformer](https://github.com/michalfaber/rmpe_dataset_transformer).
5858
Use this tool to create final datasets `dataset/train_dataset.h5` `dataset/val_dataset.h5`
5959
- You can verify the datasets `inspect_dataset.ipynb`
60-
- Start training `python train_pose.py` (TODO)
60+
- Start training `python train_pose.py`
61+
62+
NOTE:
63+
I trained the model from scratch for 3.5 days on a single 1070 GPU but didn't obtain satisfactory results.
64+
38 epochs is about 200000 iterations in caffe.
65+
I noticed that reducing the learning rate after step 136106 (as in the original caffe model) was probably too early
66+
because learning process slowed down.
67+
68+
<div align="center">
69+
<img src="https://github.com/michalfaber/keras_Realtime_Multi-Person_Pose_Estimation/blob/master/readme/tr_results.png" width="450" height="563">
70+
</div>
6171

6272
## Related repository
6373
- CVPR'16, [Convolutional Pose Machines](https://github.com/shihenw/convolutional-pose-machines-release).

demo.ipynb

+31-29
Large diffs are not rendered by default.

demo_image.py

+3-3
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
import util
77
from config_reader import config_reader
88
from scipy.ndimage.filters import gaussian_filter
9-
from model import get_model
9+
from model import get_testing_model
1010

1111

1212
keras_weights_file = "model/keras/model.h5"
@@ -44,7 +44,7 @@ def process (input_image, params, model_params):
4444
imageToTest_padded, pad = util.padRightDownCorner(imageToTest, model_params['stride'],
4545
model_params['padValue'])
4646

47-
input_img = np.transpose(np.float32(imageToTest_padded[:,:,:,np.newaxis]), (3,0,1,2))/256 - 0.5; # required shape (1, width, height, channels)
47+
input_img = np.transpose(np.float32(imageToTest_padded[:,:,:,np.newaxis]), (3,0,1,2)) # required shape (1, width, height, channels)
4848

4949
output_blobs = model.predict(input_img)
5050

@@ -238,7 +238,7 @@ def process (input_image, params, model_params):
238238
print('start processing...')
239239

240240
# load model
241-
model = get_model()
241+
model = get_testing_model()
242242
model.load_weights(keras_weights_file)
243243

244244
# load config

model.py

+69-31
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from keras.layers.pooling import MaxPooling2D
66
from keras.layers.merge import Multiply
77
from keras.regularizers import l2
8+
from keras.initializers import random_normal,constant
89

910
def relu(x): return Activation('relu')(x)
1011

@@ -14,7 +15,9 @@ def conv(x, nf, ks, name, weight_decay):
1415

1516
x = Conv2D(nf, (ks, ks), padding='same', name=name,
1617
kernel_regularizer=kernel_reg,
17-
bias_regularizer=bias_reg)(x)
18+
bias_regularizer=bias_reg,
19+
kernel_initializer=random_normal(stddev=0.01),
20+
bias_initializer=constant(0.0))(x)
1821
return x
1922

2023
def pooling(x, ks, st, name):
@@ -62,7 +65,7 @@ def vgg_block(x, weight_decay):
6265
return x
6366

6467

65-
def stage1_block(x, x1, x2, num_p, branch, weight_decay):
68+
def stage1_block(x, num_p, branch, weight_decay):
6669
# Block 1
6770
x = conv(x, 128, 3, "Mconv1_stage1_L%d" % branch, (weight_decay, 0))
6871
x = relu(x)
@@ -74,17 +77,10 @@ def stage1_block(x, x1, x2, num_p, branch, weight_decay):
7477
x = relu(x)
7578
x = conv(x, num_p, 1, "Mconv5_stage1_L%d" % branch, (weight_decay, 0))
7679

77-
w_name = "weight_stage1_L%d" % branch
78-
if num_p == 38:
79-
w = Multiply(name=w_name)([x, x1]) # vec_weight
80-
81-
else:
82-
w = Multiply(name=w_name)([x, x2]) # vec_heat
83-
84-
return x, w
80+
return x
8581

8682

87-
def stageT_block(x, x1, x2, num_p, stage, branch, weight_decay):
83+
def stageT_block(x, num_p, stage, branch, weight_decay):
8884
# Block 1
8985
x = conv(x, 128, 7, "Mconv1_stage%d_L%d" % (stage, branch), (weight_decay, 0))
9086
x = relu(x)
@@ -100,17 +96,20 @@ def stageT_block(x, x1, x2, num_p, stage, branch, weight_decay):
10096
x = relu(x)
10197
x = conv(x, num_p, 1, "Mconv7_stage%d_L%d" % (stage, branch), (weight_decay, 0))
10298

99+
return x
100+
101+
102+
def apply_mask(x, mask1, mask2, num_p, stage, branch):
103103
w_name = "weight_stage%d_L%d" % (stage, branch)
104104
if num_p == 38:
105-
w = Multiply(name=w_name)([x, x1]) # vec_weight
105+
w = Multiply(name=w_name)([x, mask1]) # vec_weight
106106

107107
else:
108-
w = Multiply(name=w_name)([x, x2]) # vec_heat
108+
w = Multiply(name=w_name)([x, mask2]) # vec_heat
109+
return w
109110

110-
return x, w
111111

112-
113-
def get_model(training=True, weight_decay=None):
112+
def get_training_model(weight_decay):
114113

115114
stages = 6
116115
np_branch1 = 38
@@ -131,38 +130,77 @@ def get_model(training=True, weight_decay=None):
131130
inputs.append(vec_weight_input)
132131
inputs.append(heat_weight_input)
133132

134-
img_normalized = Lambda(lambda x: x / 127.5 - 1.0)(img_input)
133+
img_normalized = Lambda(lambda x: x / 256 - 0.5)(img_input)
135134

136135
# VGG
137136
stage0_out = vgg_block(img_normalized, weight_decay)
138137

139-
# stage 1
140-
stage1_branch1_out,w1 = stage1_block(stage0_out, vec_weight_input,
141-
heat_weight_input, np_branch1, 1, weight_decay)
142-
stage1_branch2_out,w2 = stage1_block(stage0_out, vec_weight_input,
143-
heat_weight_input, np_branch2, 2, weight_decay)
138+
# stage 1 - branch 1 (PAF)
139+
stage1_branch1_out = stage1_block(stage0_out, np_branch1, 1, weight_decay)
140+
w1 = apply_mask(stage1_branch1_out, vec_weight_input, heat_weight_input, np_branch1, 1, 1)
141+
142+
# stage 1 - branch 2 (confidence maps)
143+
stage1_branch2_out = stage1_block(stage0_out, np_branch2, 2, weight_decay)
144+
w2 = apply_mask(stage1_branch2_out, vec_weight_input, heat_weight_input, np_branch2, 1, 2)
145+
144146
x = Concatenate()([stage1_branch1_out, stage1_branch2_out, stage0_out])
145147

146148
outputs.append(w1)
147149
outputs.append(w2)
148150

149-
# stage t >= 2
150-
#stageT_branch1_out = None
151-
#stageT_branch2_out = None
151+
# stage sn >= 2
152152
for sn in range(2, stages + 1):
153-
stageT_branch1_out, w1 = stageT_block(x, vec_weight_input,
154-
heat_weight_input, np_branch1, sn, 1, weight_decay)
155-
stageT_branch2_out, w2 = stageT_block(x, vec_weight_input,
156-
heat_weight_input, np_branch2, sn, 2, weight_decay)
153+
# stage SN - branch 1 (PAF)
154+
stageT_branch1_out = stageT_block(x, np_branch1, sn, 1, weight_decay)
155+
w1 = apply_mask(stageT_branch1_out, vec_weight_input, heat_weight_input, np_branch1, sn, 1)
156+
157+
# stage SN - branch 2 (confidence maps)
158+
stageT_branch2_out = stageT_block(x, np_branch2, sn, 2, weight_decay)
159+
w2 = apply_mask(stageT_branch2_out, vec_weight_input, heat_weight_input, np_branch2, sn, 2)
157160

158161
outputs.append(w1)
159162
outputs.append(w2)
160163

161164
if (sn < stages):
162165
x = Concatenate()([stageT_branch1_out, stageT_branch2_out, stage0_out])
163166

164-
#outputs.insert(0, stageT_branch1_out)
165-
#outputs.insert(1, stageT_branch2_out)
166167
model = Model(inputs=inputs, outputs=outputs)
167168

169+
return model
170+
171+
172+
def get_testing_model():
173+
stages = 6
174+
np_branch1 = 38
175+
np_branch2 = 19
176+
177+
img_input_shape = (None, None, 3)
178+
179+
img_input = Input(shape=img_input_shape)
180+
181+
img_normalized = Lambda(lambda x: x / 256 - 0.5)(img_input) # [-0.5, 0.5]
182+
183+
# VGG
184+
stage0_out = vgg_block(img_normalized, None)
185+
186+
# stage 1 - branch 1 (PAF)
187+
stage1_branch1_out = stage1_block(stage0_out, np_branch1, 1, None)
188+
189+
# stage 1 - branch 2 (confidence maps)
190+
stage1_branch2_out = stage1_block(stage0_out, np_branch2, 2, None)
191+
192+
x = Concatenate()([stage1_branch1_out, stage1_branch2_out, stage0_out])
193+
194+
# stage t >= 2
195+
stageT_branch1_out = None
196+
stageT_branch2_out = None
197+
for sn in range(2, stages + 1):
198+
stageT_branch1_out = stageT_block(x, np_branch1, sn, 1, None)
199+
stageT_branch2_out = stageT_block(x, np_branch2, sn, 2, None)
200+
201+
if (sn < stages):
202+
x = Concatenate()([stageT_branch1_out, stageT_branch2_out, stage0_out])
203+
204+
model = Model(inputs=[img_input], outputs=[stageT_branch1_out, stageT_branch2_out])
205+
168206
return model

readme/tr_results.png

57.2 KB
Loading

training/generate_hdf5.py

+27-23
Original file line numberDiff line numberDiff line change
@@ -17,21 +17,23 @@
1717
val_img_dir = os.path.join(dataset_dir, "val2017")
1818
val_mask_dir = os.path.join(dataset_dir, "valmask2017")
1919

20-
# datasets = [
21-
# (tr_anno_path, tr_img_dir, tr_mask_dir, "COCO"),
22-
# (val_anno_path, val_img_dir, val_mask_dir, "COCO")
23-
# ]
24-
2520
datasets = [
21+
(tr_anno_path, tr_img_dir, tr_mask_dir, "COCO"),
2622
(val_anno_path, val_img_dir, val_mask_dir, "COCO")
2723
]
2824

25+
#datasets = [
26+
# (val_anno_path, val_img_dir, val_mask_dir, "COCO")
27+
#]
28+
2929
joint_all = []
3030
tr_hdf5_path = os.path.join(dataset_dir, "train_pre_dataset.h5")
3131
val_hdf5_path = os.path.join(dataset_dir, "val_pre_dataset.h5")
3232

3333
val_size = 2645 # size of validation set
3434

35+
#val_size = 300
36+
3537
def process():
3638
count = 0
3739
for _, ds in enumerate(datasets):
@@ -55,12 +57,8 @@ def process():
5557

5658
print("Image ID ", img_id)
5759

58-
if i < val_size:
59-
isValidation = 1
60-
else:
61-
isValidation = 0
62-
6360
persons = []
61+
prev_center = []
6462

6563
for p in range(numPeople):
6664

@@ -76,19 +74,18 @@ def process():
7674
person_center = [img_anns[p]["bbox"][0] + img_anns[p]["bbox"][2] / 2,
7775
img_anns[p]["bbox"][1] + img_anns[p]["bbox"][3] / 2]
7876

79-
# # skip this person if the distance to exiting person is too small
80-
# person_center = np.array((img_anns[p]["bbox"][0] + img_anns[p]["bbox"][2] / 2,
81-
# img_anns[p]["bbox"][1] + img_anns[p]["bbox"][3] / 2))
82-
# flag = 0
83-
#
84-
# for pc in prev_center:
85-
# dist = cdist(np.expand_dims(pc[:2], axis=0), np.expand_dims(person_center, axis=0))[0]
86-
# if dist < pc[2]*0.3:
87-
# flag = 1
88-
# continue
89-
#
90-
# if flag == 1:
91-
# continue
77+
# skip this person if the distance to an existing person is too small
78+
flag = 0
79+
for pc in prev_center:
80+
a = np.expand_dims(pc[:2], axis=0)
81+
b = np.expand_dims(person_center, axis=0)
82+
dist = cdist(a, b)[0]
83+
if dist < pc[2]*0.3:
84+
flag = 1
85+
continue
86+
87+
if flag == 1:
88+
continue
9289

9390
pers["objpos"] = person_center
9491
pers["bbox"] = img_anns[p]["bbox"]
@@ -110,13 +107,20 @@ def process():
110107
pers["scale_provided"] = img_anns[p]["bbox"][3] / 368
111108

112109
persons.append(pers)
110+
prev_center.append(np.append(person_center, max(img_anns[p]["bbox"][2], img_anns[p]["bbox"][3])))
111+
113112

114113
if len(persons) > 0:
115114

116115
joint_all.append(dict())
117116

118117
joint_all[count]["dataset"] = dataset_type
119118

119+
if count < val_size:
120+
isValidation = 1
121+
else:
122+
isValidation = 0
123+
120124
joint_all[count]["isValidation"] = isValidation
121125

122126
joint_all[count]["img_width"] = w

training/inspect_dataset.ipynb

+16-16
Large diffs are not rendered by default.

0 commit comments

Comments
 (0)