Skip to content

Commit ac25142

Browse files
committed
changed mechanism of generating samples
1 parent 4d12c41 commit ac25142

7 files changed

+528
-74
lines changed

README.md

+22-7
Original file line numberDiff line numberDiff line change
@@ -44,20 +44,35 @@ any image file as an input.
4444
- `python demo_camera.py` to run the web demo.
4545

4646
## Training steps
47+
**UPDATE 10/2017:**
48+
49+
**-Augmented samples are fetched from the [server](https://github.com/michalfaber/rmpe_dataset_server). The network never sees the same image twice
50+
which was a problem in the previous approach (the rmpe_dataset_transformer tool).
51+
This allows you to run augmentation locally or on separate node.
52+
You can start 2 instances, one serving training set and a second one serving validation set (on different port if locally)**
53+
54+
**-Experimentally I've added image normalization as in vgg paper (images should be zero-centered by mean pixel subtraction)**
55+
4756
- Install gsutil `curl https://sdk.cloud.google.com | bash`. This is a really helpful tool for downloading large datasets.
4857
- Download the data set (~25 GB) `cd dataset; sh get_dataset.sh`,
4958
- Download [COCO official toolbox](https://github.com/pdollar/coco) in `dataset/coco/` .
5059
- `cd coco/PythonAPI; sudo python setup.py install` to install pycocotools.
5160
- Go to the "training" folder `cd ../../../training`.
5261
- Generate masks `python generate_masks.py`. Note: set the parameter "mode" in generate_masks.py (validation or training)
5362
- Create intermediate dataset `python generate_hdf5.py`. This tool creates a dataset in hdf5 format. The structure of this dataset is very similar to the
54-
original lmdb dataset where a sample is represented as an array: 6 x width x height (3 channels for image, 1 channel for metadata, 2 channels for masks)
55-
Note: set the parameters "datasets", "val_size" in generate_hdf5.py
56-
- The resulting intermediate hdf5 dataset has to be transformed to the more keras friendly format with data and labels ready to use in python generator.
57-
Download and compile the tool [dataset_transformer](https://github.com/michalfaber/rmpe_dataset_transformer).
58-
Use this tool to create final datasets `dataset/train_dataset.h5` `dataset/val_dataset.h5`
59-
- You can verify the datasets `inspect_dataset.ipynb`
60-
- Start training `python train_pose.py`
63+
original lmdb dataset where a sample is represented as an array: 5 x width x height (3 channels for image, 1 channel for metadata, 1 channel for miss masks)
64+
For MPI dataset there are 6 channels with additional all masks.
65+
Note: set the parameters `datasets` and `val_size` in `generate_hdf5.py`
66+
- Download and compile the dataset server [rmpe_dataset_server](https://github.com/michalfaber/rmpe_dataset_server).
67+
This server generates augmented samples on the fly. Source samples are retrieved from previously generated hdf5 dataset file.
68+
- Start training data server in the first terminal session.
69+
`./rmpe_dataset_server ../../keras_Realtime_Multi-Person_Pose_Estimation/dataset/train_dataset.h5 5555`
70+
- Start validation data server in a second terminal session.
71+
`./rmpe_dataset_server ../../keras_Realtime_Multi-Person_Pose_Estimation/dataset/val_dataset.h5 5556`
72+
- Optionally you can verify the datasets `inspect_dataset.ipynb`
73+
- Set the correct number of samples within `python train_pose.py` - variables "train_samples = ???" and "val_samples = ???".
74+
This number is used by keras to determine how many samples are in 1 epoch.
75+
- Train the model in a third terminal `python train_pose.py`
6176

6277
NOTE:
6378
I trained the model from scratch for 3.5 days on a single 1070 GPU but didn't obtain satisfactory results.

demo_image.py

+6-1
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,10 @@ def process (input_image, params, model_params):
238238
print('start processing...')
239239

240240
# load model
241-
model = get_testing_model()
241+
242+
# authors of original model don't use
243+
# vgg normalization (subtracting mean) on input images
244+
model = get_testing_model(vgg_norm=False)
242245
model.load_weights(keras_weights_file)
243246

244247
# load config
@@ -252,5 +255,7 @@ def process (input_image, params, model_params):
252255

253256
cv2.imwrite(output, canvas)
254257

258+
cv2.destroyAllWindows()
259+
255260

256261

model.py

+13-4
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from keras.layers.merge import Multiply
77
from keras.regularizers import l2
88
from keras.initializers import random_normal,constant
9+
import numpy as np
910

1011
def relu(x): return Activation('relu')(x)
1112

@@ -109,7 +110,7 @@ def apply_mask(x, mask1, mask2, num_p, stage, branch):
109110
return w
110111

111112

112-
def get_training_model(weight_decay):
113+
def get_training_model(weight_decay, vgg_norm):
113114

114115
stages = 6
115116
np_branch1 = 38
@@ -130,7 +131,11 @@ def get_training_model(weight_decay):
130131
inputs.append(vec_weight_input)
131132
inputs.append(heat_weight_input)
132133

133-
img_normalized = Lambda(lambda x: x / 256 - 0.5)(img_input)
134+
if vgg_norm:
135+
vgg_mean = np.array([103.939, 116.779, 123.68]) # BGR
136+
img_normalized = Lambda(lambda x: x - vgg_mean)(img_input)
137+
else:
138+
img_normalized = Lambda(lambda x: x / 256 - 0.5)(img_input) # [-0.5, 0.5]
134139

135140
# VGG
136141
stage0_out = vgg_block(img_normalized, weight_decay)
@@ -169,7 +174,7 @@ def get_training_model(weight_decay):
169174
return model
170175

171176

172-
def get_testing_model():
177+
def get_testing_model(vgg_norm=False):
173178
stages = 6
174179
np_branch1 = 38
175180
np_branch2 = 19
@@ -178,7 +183,11 @@ def get_testing_model():
178183

179184
img_input = Input(shape=img_input_shape)
180185

181-
img_normalized = Lambda(lambda x: x / 256 - 0.5)(img_input) # [-0.5, 0.5]
186+
if vgg_norm:
187+
vgg_mean = np.array([103.939, 116.779, 123.68]) # BGR
188+
img_normalized = Lambda(lambda x: x - vgg_mean)(img_input)
189+
else:
190+
img_normalized = Lambda(lambda x: x / 256 - 0.5)(img_input) # [-0.5, 0.5]
182191

183192
# VGG
184193
stage0_out = vgg_block(img_normalized, None)

training/ds_generator_client.py

+137
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,137 @@
1+
import numpy as np
2+
import zmq
3+
from ast import literal_eval as make_tuple
4+
5+
import six
6+
if six.PY3:
7+
buffer_ = memoryview
8+
else:
9+
buffer_ = buffer # noqa
10+
11+
12+
class DataGeneratorClient(object):
13+
14+
def __init__(self, host, port, hwm=20, batch_size=10):
15+
"""
16+
:param host:
17+
:param port:
18+
:param hwm:, optional
19+
The `ZeroMQ high-water mark (HWM)
20+
<http://zguide.zeromq.org/page:all#High-Water-Marks>`_ on the
21+
sending socket. Increasing this increases the buffer, which can be
22+
useful if your data preprocessing times are very random. However,
23+
it will increase memory usage. There is no easy way to tell how
24+
many batches will actually be queued with a particular HWM.
25+
Defaults to 10. Be sure to set the corresponding HWM on the
26+
receiving end as well.
27+
:param batch_size:
28+
:param shuffle:
29+
:param seed:
30+
"""
31+
self.host = host
32+
self.port = port
33+
self.hwm = hwm
34+
self.socket = None
35+
36+
self.split_point = 38
37+
self.vec_num = 38
38+
self.heat_num = 19
39+
40+
self.batch_size = batch_size
41+
42+
def _recv_arrays(self):
43+
"""Receive a list of NumPy arrays.
44+
Parameters
45+
----------
46+
socket : :class:`zmq.Socket`
47+
The socket to receive the arrays on.
48+
Returns
49+
-------
50+
list
51+
A list of :class:`numpy.ndarray` objects.
52+
Raises
53+
------
54+
StopIteration
55+
If the first JSON object received contains the key `stop`,
56+
signifying that the server has finished a single epoch.
57+
"""
58+
headers = self.socket.recv_json()
59+
if 'stop' in headers:
60+
raise StopIteration
61+
arrays = []
62+
63+
for header in headers:
64+
data = self.socket.recv()
65+
buf = buffer_(data)
66+
array = np.frombuffer(buf, dtype=np.dtype(header['descr']))
67+
array.shape = make_tuple(header['shape'])
68+
69+
if header['fortran_order']:
70+
array.shape = header['shape'][::-1]
71+
array = array.transpose()
72+
arrays.append(array)
73+
74+
return arrays
75+
76+
def gen(self):
77+
batches_x, batches_x1, batches_x2, batches_y1, batches_y2 = \
78+
[None]*self.batch_size, [None]*self.batch_size, [None]*self.batch_size, \
79+
[None]*self.batch_size, [None]*self.batch_size
80+
81+
sample_idx = 0
82+
83+
while True:
84+
data_img, mask_img, label = tuple(self._recv_arrays())
85+
86+
# image
87+
dta_img = np.transpose(data_img, (1, 2, 0))
88+
batches_x[sample_idx]=dta_img[np.newaxis, ...]
89+
90+
# mask - the same for vec_weights, heat_weights
91+
vec_weights = np.repeat(mask_img[:,:,np.newaxis], self.vec_num, axis=2)
92+
heat_weights = np.repeat(mask_img[:,:,np.newaxis], self.heat_num, axis=2)
93+
94+
batches_x1[sample_idx]=vec_weights[np.newaxis, ...]
95+
batches_x2[sample_idx]=heat_weights[np.newaxis, ...]
96+
97+
# label
98+
vec_label = label[:self.split_point, :, :]
99+
vec_label = np.transpose(vec_label, (1, 2, 0))
100+
heat_label = label[self.split_point:, :, :]
101+
heat_label = np.transpose(heat_label, (1, 2, 0))
102+
103+
batches_y1[sample_idx]=vec_label[np.newaxis, ...]
104+
batches_y2[sample_idx]=heat_label[np.newaxis, ...]
105+
106+
sample_idx += 1
107+
108+
if sample_idx == self.batch_size:
109+
sample_idx = 0
110+
111+
batch_x = np.concatenate(batches_x)
112+
batch_x1 = np.concatenate(batches_x1)
113+
batch_x2 = np.concatenate(batches_x2)
114+
batch_y1 = np.concatenate(batches_y1)
115+
batch_y2 = np.concatenate(batches_y2)
116+
117+
yield [batch_x, batch_x1, batch_x2], \
118+
[batch_y1, batch_y2,
119+
batch_y1, batch_y2,
120+
batch_y1, batch_y2,
121+
batch_y1, batch_y2,
122+
batch_y1, batch_y2,
123+
batch_y1, batch_y2]
124+
125+
def start(self):
126+
context = zmq.Context()
127+
self.socket = context.socket(zmq.PULL)
128+
self.socket.set_hwm(self.hwm)
129+
self.socket.connect("tcp://{}:{}".format(self.host, self.port))
130+
131+
def stop(self):
132+
if self.socket:
133+
self.socket.__del__()
134+
135+
def restart(self):
136+
self.stop()
137+
self.start()

training/generate_hdf5.py

+5-6
Original file line numberDiff line numberDiff line change
@@ -22,16 +22,15 @@
2222
(val_anno_path, val_img_dir, val_mask_dir, "COCO")
2323
]
2424

25-
#datasets = [
26-
# (val_anno_path, val_img_dir, val_mask_dir, "COCO")
27-
#]
25+
# datasets = [
26+
# (val_anno_path, val_img_dir, val_mask_dir, "COCO")
27+
# ]
2828

2929
joint_all = []
30-
tr_hdf5_path = os.path.join(dataset_dir, "train_pre_dataset.h5")
31-
val_hdf5_path = os.path.join(dataset_dir, "val_pre_dataset.h5")
30+
tr_hdf5_path = os.path.join(dataset_dir, "train_dataset.h5")
31+
val_hdf5_path = os.path.join(dataset_dir, "val_dataset.h5")
3232

3333
val_size = 2645 # size of validation set
34-
3534
#val_size = 300
3635

3736
def process():

training/inspect_dataset.ipynb

+315-42
Large diffs are not rendered by default.

training/train_pose.py

+30-14
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,10 @@
66
sys.path.append("..")
77
from model import get_training_model
88
from ds_iterator import DataIterator
9+
from ds_generator_client import DataGeneratorClient
910
from optimizers import MultiSGD
1011
from keras.callbacks import LearningRateScheduler, ModelCheckpoint, CSVLogger, TensorBoard
1112
from keras.layers.convolutional import Conv2D
12-
from keras.utils.data_utils import get_file
1313
from keras.applications.vgg19 import VGG19
1414

1515
batch_size = 10
@@ -21,6 +21,9 @@
2121
stepsize = 136106 #68053 // after each stepsize iterations update learning rate: lr=lr*gamma
2222
max_iter = 200000 # 600000
2323

24+
# True = start data generator client, False = use augmented dataset file (deprecated)
25+
use_client_gen = True
26+
2427
WEIGHTS_BEST = "weights.best.h5"
2528
TRAINING_LOG = "training.csv"
2629
LOGS_DIR = "./logs"
@@ -30,7 +33,7 @@ def get_last_epoch():
3033
return max(data['epoch'].values)
3134

3235

33-
model = get_training_model(weight_decay)
36+
model = get_training_model(weight_decay, vgg_norm=True)
3437

3538
from_vgg = dict()
3639
from_vgg['conv1_1'] = 'block1_conv1'
@@ -64,15 +67,28 @@ def get_last_epoch():
6467
last_epoch = 0
6568

6669
# prepare generators
67-
train_di = DataIterator("../dataset/train_dataset.h5", data_shape=(3, 368, 368),
68-
mask_shape=(1, 46, 46),
69-
label_shape=(57, 46, 46),
70-
vec_num=38, heat_num=19, batch_size=batch_size, shuffle=True)
7170

72-
val_di = DataIterator("../dataset/val_dataset.h5", data_shape=(3, 368, 368),
73-
mask_shape=(1, 46, 46),
74-
label_shape=(57, 46, 46),
75-
vec_num=38, heat_num=19, batch_size=batch_size, shuffle=True)
71+
if use_client_gen:
72+
train_client = DataGeneratorClient(port=5555, host="localhost", hwm=160, batch_size=10)
73+
train_client.start()
74+
train_di = train_client.gen()
75+
train_samples = 52597
76+
77+
val_client = DataGeneratorClient(port=5556, host="localhost", hwm=160, batch_size=10)
78+
val_client.start()
79+
val_di = val_client.gen()
80+
val_samples = 2645
81+
else:
82+
train_di = DataIterator("../dataset/train_dataset.h5", data_shape=(3, 368, 368),
83+
mask_shape=(1, 46, 46),
84+
label_shape=(57, 46, 46),
85+
vec_num=38, heat_num=19, batch_size=batch_size, shuffle=True)
86+
train_samples=train_di.N
87+
val_di = DataIterator("../dataset/val_dataset.h5", data_shape=(3, 368, 368),
88+
mask_shape=(1, 46, 46),
89+
label_shape=(57, 46, 46),
90+
vec_num=38, heat_num=19, batch_size=batch_size, shuffle=True)
91+
val_samples=val_di.N
7692

7793
# setup lr multipliers for conv layers
7894
lr_mult=dict()
@@ -131,7 +147,7 @@ def get_last_epoch():
131147
loss_weights["weight_stage6_L2"] = 1
132148

133149
# learning rate schedule - equivalent of caffe lr_policy = "step"
134-
iterations_per_epoch = train_di.N // batch_size
150+
iterations_per_epoch = train_samples // batch_size
135151
def step_decay(epoch):
136152
initial_lrate = base_lr
137153
steps = epoch * iterations_per_epoch
@@ -155,12 +171,12 @@ def step_decay(epoch):
155171
model.compile(loss=losses, loss_weights=loss_weights, optimizer=multisgd, metrics=["accuracy"])
156172

157173
model.fit_generator(train_di,
158-
steps_per_epoch=train_di.N // batch_size,
174+
steps_per_epoch=train_samples // batch_size,
159175
epochs=max_iter,
160176
callbacks=callbacks_list,
161177
validation_data=val_di,
162-
validation_steps=val_di.N // batch_size,
163-
use_multiprocessing=True,
178+
validation_steps=val_samples // batch_size,
179+
use_multiprocessing=False,
164180
initial_epoch=last_epoch
165181
)
166182

0 commit comments

Comments
 (0)