|
# ----------------------------------------------------------------------------
#
# Phiflow Karman vortex solver framework
# Copyright 2020-2021 Kiwon Um, Nils Thuerey
#
# This program is free software, distributed under the terms of the
# Apache License, Version 2.0
# http://www.apache.org/licenses/LICENSE-2.0
#
# Training
#
# ----------------------------------------------------------------------------

import os, sys, logging, argparse, pickle, glob, random, distutils.dir_util

log = logging.getLogger()
log.addHandler(logging.StreamHandler())
log.setLevel(logging.INFO)
# log.setLevel(logging.DEBUG)

params = {}
parser = argparse.ArgumentParser(description='Parameter Parser', formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--gpu', default='0', help='visible GPUs')
parser.add_argument('--train', default=None, help='training; will load data from this simulation folder (set) and save down-sampled files')
parser.add_argument('--skip-ds', action='store_true', help='skip down-scaling; assume the down-scaled files have already been saved')
parser.add_argument('--only-ds', action='store_true', help='exit after down-scaling and saving; use only for data pre-processing')
parser.add_argument('--log', default=None, help='path to a log file')
parser.add_argument('-s', '--scale', default=4, type=int, help='simulation scale for high-res')
parser.add_argument('-n', '--nsims', default=1, type=int, help='number of simulations')
parser.add_argument('-b', '--sbatch', default=1, type=int, help='size of a batch; e.g., 10 simulations with a batch size of 5 are split into two batches')
parser.add_argument('-t', '--simsteps', default=1500, type=int, help='simulation steps; # of data samples (i.e. frames) per simulation')
parser.add_argument('-m', '--msteps', default=2, type=int, help='multi steps in training loss')
parser.add_argument('-e', '--epochs', default=10, type=int, help='training epochs')
parser.add_argument('--seed', default=None, type=int, help='seed for random number generator')
parser.add_argument('-r', '--res', default=32, type=int, help='target (i.e., low-res) resolution')  # FIXME: save and restore from the data
parser.add_argument('-l', '--len', default=100, type=int, help='length of the reference axis')  # FIXME: save and restore from the data
parser.add_argument('--model', default='mars_moon', help='(predefined) network model')
parser.add_argument('--reg-loss', action='store_true', help='turn on regularization loss')
parser.add_argument('--lr', default=1e-3, type=float, help='start learning rate')
parser.add_argument('--adplr', action='store_true', help='turn on adaptive learning rate')
parser.add_argument('--clip-grad', action='store_true', help='turn on gradient clipping')
parser.add_argument('--resume', default=-1, type=int, help='resume training from this epoch')
parser.add_argument('--inittf', default=None, help='load initial model weights (warm start)')
parser.add_argument('--pretf', default=None, help='load pre-trained weights (only for testing a pre-trained supervised model; do not use for a warm start!)')
parser.add_argument('--tf', default='/tmp/phiflow/tf', help='path to a tensorflow output dir (model, logs, etc.)')
sys.argv += ['--' + p for p in params if isinstance(params[p], bool) and params[p]]
pargs = parser.parse_args()
params.update(vars(pargs))

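# Illustrative invocation (the data path and values below are placeholders, not project defaults):
#   python <this_script>.py --train /path/to/simdata --nsims 10 --sbatch 5 \
#       --simsteps 500 --msteps 4 --epochs 100 \
#       --tf /tmp/phiflow/tf/run01 --log /tmp/phiflow/tf/run01/train.log
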
os.environ['CUDA_VISIBLE_DEVICES'] = params['gpu']

from phi.physics._boundaries import Domain, OPEN, STICKY as CLOSED
from phi.tf.flow import *

gpus = tf.config.list_physical_devices('GPU')
if gpus:
    for gpu in gpus: tf.config.experimental.set_memory_growth(gpu, True)
    logical_gpus = tf.config.experimental.list_logical_devices('GPU')
    log.info('{} Physical GPUs {} Logical GPUs'.format(len(gpus), len(logical_gpus)))

from tensorflow import keras

random.seed(params['seed'])
np.random.seed(params['seed'])
tf.random.set_seed(params['seed'])

if params['resume']>0 and params['log']:
    params['log'] = os.path.splitext(params['log'])[0] + '_resume{:04d}'.format(params['resume']) + os.path.splitext(params['log'])[1]

if params['log']:
    distutils.dir_util.mkpath(os.path.dirname(params['log']))
    log.addHandler(logging.FileHandler(params['log']))

if (params['nsims'] % params['sbatch']) != 0:
    params['nsims'] = (params['nsims']//params['sbatch'])*params['sbatch']
    log.info('Number of simulations is not divisible by the batch size; adjusted to {}'.format(params['nsims']))

log.info(params)
log.info('tensorflow-{} ({}, {}); keras-{} ({})'.format(tf.__version__, tf.sysconfig.get_include(), tf.sysconfig.get_lib(), keras.__version__, keras.__path__))

def model_mercury(inputs_dict):
    with tf.name_scope('model_mercury') as scope:
        return keras.Sequential([
            keras.layers.Input(**inputs_dict),
            keras.layers.Conv2D(filters=32, kernel_size=5, padding='same', activation=tf.nn.relu),
            keras.layers.Conv2D(filters=64, kernel_size=5, padding='same', activation=tf.nn.relu),
            keras.layers.Conv2D(filters=2, kernel_size=5, padding='same', activation=None),  # u, v
        ], name='mercury')

def model_mars_moon(inputs_dict):
    with tf.name_scope('model_mars_moon') as scope:
        l_input = keras.layers.Input(**inputs_dict)
        block_0 = keras.layers.Conv2D(filters=32, kernel_size=5, padding='same')(l_input)
        block_0 = keras.layers.LeakyReLU()(block_0)

        l_conv1 = keras.layers.Conv2D(filters=32, kernel_size=5, padding='same')(block_0)
        l_conv1 = keras.layers.LeakyReLU()(l_conv1)
        l_conv2 = keras.layers.Conv2D(filters=32, kernel_size=5, padding='same')(l_conv1)
        l_skip1 = keras.layers.add([block_0, l_conv2])
        block_1 = keras.layers.LeakyReLU()(l_skip1)

        l_conv3 = keras.layers.Conv2D(filters=32, kernel_size=5, padding='same')(block_1)
        l_conv3 = keras.layers.LeakyReLU()(l_conv3)
        l_conv4 = keras.layers.Conv2D(filters=32, kernel_size=5, padding='same')(l_conv3)
        l_skip2 = keras.layers.add([block_1, l_conv4])
        block_2 = keras.layers.LeakyReLU()(l_skip2)

        l_conv5 = keras.layers.Conv2D(filters=32, kernel_size=5, padding='same')(block_2)
        l_conv5 = keras.layers.LeakyReLU()(l_conv5)
        l_conv6 = keras.layers.Conv2D(filters=32, kernel_size=5, padding='same')(l_conv5)
        l_skip3 = keras.layers.add([block_2, l_conv6])
        block_3 = keras.layers.LeakyReLU()(l_skip3)

        l_conv7 = keras.layers.Conv2D(filters=32, kernel_size=5, padding='same')(block_3)
        l_conv7 = keras.layers.LeakyReLU()(l_conv7)
        l_conv8 = keras.layers.Conv2D(filters=32, kernel_size=5, padding='same')(l_conv7)
        l_skip4 = keras.layers.add([block_3, l_conv8])
        block_4 = keras.layers.LeakyReLU()(l_skip4)

        l_conv9 = keras.layers.Conv2D(filters=32, kernel_size=5, padding='same')(block_4)
        l_conv9 = keras.layers.LeakyReLU()(l_conv9)
        l_convA = keras.layers.Conv2D(filters=32, kernel_size=5, padding='same')(l_conv9)
        l_skip5 = keras.layers.add([block_4, l_convA])
        block_5 = keras.layers.LeakyReLU()(l_skip5)

        l_output = keras.layers.Conv2D(filters=2, kernel_size=5, padding='same')(block_5)
        return keras.models.Model(inputs=l_input, outputs=l_output, name='mars_moon')

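# Illustrative only: both builders take a dict of keras.layers.Input kwargs. The training code below
# calls them with shape=(res*2, res, 3), i.e. three input channels [u, v, Re] on the low-res grid,
# and they return a network that predicts a 2-channel velocity correction [u, v]. For example:
#   model = model_mars_moon(dict(shape=(64, 32, 3)))   # assumes the default --res 32
#   print(model.count_params())
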
def lr_schedule(epoch, current_lr):
    """Learning rate schedule.

    The learning rate is reduced at epochs 11, 16, 21, and 23 (i.e., after 10, 15, 20, and 22
    completed epochs). Intended to be called once per epoch during training.

    # Arguments
        epoch (int): current epoch number
        current_lr (float): current learning rate

    # Returns
        lr (float32): updated learning rate
    """
    lr = current_lr
    if epoch == 23: lr *= 0.5
    elif epoch == 21: lr *= 1e-1
    elif epoch == 16: lr *= 1e-1
    elif epoch == 11: lr *= 1e-1
    return lr

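# Sketch (assumption, not wired up in the code shown here): with --adplr set, a schedule like the one
# above could be applied at the start of each training epoch, e.g.
#   if params['adplr']:
#       new_lr = lr_schedule(j+1, float(tf.keras.backend.get_value(opt.learning_rate)))
#       tf.keras.backend.set_value(opt.learning_rate, new_lr)
# where `opt` is the Adam optimizer created further below and `j` is the epoch loop index.
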
class KarmanFlow():
    def __init__(self, domain):
        self.domain = domain

        shape_v = self.domain.staggered_grid(0).vector['y'].shape
        vel_yBc = np.zeros(shape_v.sizes)
        vel_yBc[0:2, 0:vel_yBc.shape[1]-1] = 1.0
        vel_yBc[0:vel_yBc.shape[0], 0:1] = 1.0
        vel_yBc[0:vel_yBc.shape[0], -1:] = 1.0
        self.vel_yBc = math.tensor(vel_yBc, shape_v)
        self.vel_yBcMask = math.tensor(np.copy(vel_yBc), shape_v)  # warning, only works for 1s, otherwise setup/scale

        self.inflow = self.domain.scalar_grid(Box[5:10, 25:75])          # TODO: scale with domain if necessary!
        self.obstacles = [Obstacle(Sphere(center=[50, 50], radius=10))]  # TODO: scale with domain if necessary!

    def step(self, density_in, velocity_in, re, res, buoyancy_factor=0, dt=1.0, make_input_divfree=False, make_output_divfree=True):  # , conserve_density=True
        velocity = velocity_in
        density = density_in

        # apply viscosity
        velocity = phi.flow.diffuse.explicit(field=velocity, diffusivity=1.0/re*dt*res*res, dt=dt)
        vel_x = velocity.vector['x']
        vel_y = velocity.vector['y']

        # apply velocity BCs, only y for now; velBCy should be pre-multiplied
        vel_y = vel_y*(1.0 - self.vel_yBcMask) + self.vel_yBc
        velocity = self.domain.staggered_grid(phi.math.stack([vel_y.data, vel_x.data], channel('vector')))

        pressure = None
        if make_input_divfree:
            velocity, pressure = fluid.make_incompressible(velocity, self.obstacles)

        # --- Advection ---
        density = advect.semi_lagrangian(density+self.inflow, velocity, dt=dt)
        velocity = advected_velocity = advect.semi_lagrangian(velocity, velocity, dt=dt)
        # if conserve_density and self.domain.boundaries['accessible_extrapolation'] == math.extrapolation.ZERO:  # solid boundary
        #     density = field.normalize(density, self.density)

        # --- Pressure solve ---
        if make_output_divfree:
            velocity, pressure = fluid.make_incompressible(velocity, self.obstacles)

        self.solve_info = {
            'pressure': pressure,
            'advected_velocity': advected_velocity,
        }

        return [density, velocity]

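# Minimal standalone usage sketch (names mirror the setup further below; all values are illustrative):
#   dom = Domain(y=64, x=32, bounds=Box[0:200, 0:100], boundaries=OPEN)
#   sim = KarmanFlow(domain=dom)
#   dens, velo = dom.scalar_grid(0), dom.staggered_grid(0)
#   dens, velo = sim.step(density_in=dens, velocity_in=velo, re=1e6, res=32)
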
class PhifDataset():
    def __init__(self, domain, dirpath, num_frames, num_sims=None, batch_size=1, print_fn=print, skip_preprocessing=False):
        self.dataSims = sorted(glob.glob(dirpath + '/sim_0*'))[0:num_sims]
        self.pathsDen = [ sorted(glob.glob(asim + '/dens_0*.npz')) for asim in self.dataSims ]
        self.pathsVel = [ sorted(glob.glob(asim + '/velo_0*.npz')) for asim in self.dataSims ]
        self.dataFrms = [ np.arange(num_frames) for _ in self.dataSims ]  # NOTE: may contain different numbers of frames
        self.batchSize = batch_size
        self.epoch = None
        self.epochIdx = 0
        self.batch = None
        self.batchIdx = 0
        self.step = None
        self.stepIdx = 0
        self.dataPreloaded = None
        self.printFn = print_fn
        self.domain = domain  # phiflow: target domain (i.e., low-res.)

        self.numOfSims = num_sims
        self.numOfBatchs = self.numOfSims//self.batchSize
        self.numOfFrames = num_frames
        self.numOfSteps = num_frames

        if not skip_preprocessing:
            self.printFn('Pre-processing: Loading data from {} = {} and save down-scaled data'.format(dirpath, self.dataSims))
            for j,asim in enumerate(self.dataSims):
                for i in range(num_frames):
                    if not os.path.isfile(self.filenameToDownscaled(self.pathsDen[j][i])):
                        d = phi.field.read(file=self.pathsDen[j][i]).at(self.domain.scalar_grid())
                        phi.field.write(field=d, file=self.filenameToDownscaled(self.pathsDen[j][i]))
                        self.printFn('Wrote {}'.format(self.filenameToDownscaled(self.pathsDen[j][i])))
                    if not os.path.isfile(self.filenameToDownscaled(self.pathsVel[j][i])):
                        v = phi.field.read(file=self.pathsVel[j][i]).at(self.domain.staggered_grid())
                        phi.field.write(field=v, file=self.filenameToDownscaled(self.pathsVel[j][i]))
                        self.printFn('Wrote {}'.format(self.filenameToDownscaled(self.pathsVel[j][i])))

        self.printFn('Preload: Loading data from {} = {}'.format(dirpath, self.dataSims))
        self.dataPreloaded = {  # dataPreloaded['sim_key'][frame number][0=density, 1=x-velocity, 2=y-velocity]
            asim: [
                (
                    np.expand_dims(phi.field.read(file=self.filenameToDownscaled(self.pathsDen[j][i])).values.numpy(('y', 'x')), axis=0),               # density
                    np.expand_dims(phi.field.read(file=self.filenameToDownscaled(self.pathsVel[j][i])).vector['x'].values.numpy(('y', 'x')), axis=0),   # x-velocity
                    np.expand_dims(phi.field.read(file=self.filenameToDownscaled(self.pathsVel[j][i])).vector['y'].values.numpy(('y', 'x')), axis=0),   # y-velocity
                ) for i in range(num_frames)
            ] for j,asim in enumerate(self.dataSims)
        }  # for each, keep shape=[batch-size, res-y, res-x]
        assert len(self.dataPreloaded[self.dataSims[0]][0][0].shape)==3, 'Data shape is wrong.'
        assert len(self.dataPreloaded[self.dataSims[0]][0][1].shape)==3, 'Data shape is wrong.'
        assert len(self.dataPreloaded[self.dataSims[0]][0][2].shape)==3, 'Data shape is wrong.'

        self.dataStats = {
            'std': (
                np.std(np.concatenate([np.absolute(self.dataPreloaded[asim][i][0].reshape(-1)) for asim in self.dataSims for i in range(num_frames)], axis=-1)),  # density
                np.std(np.concatenate([np.absolute(self.dataPreloaded[asim][i][1].reshape(-1)) for asim in self.dataSims for i in range(num_frames)], axis=-1)),  # x-velocity
                np.std(np.concatenate([np.absolute(self.dataPreloaded[asim][i][2].reshape(-1)) for asim in self.dataSims for i in range(num_frames)], axis=-1)),  # y-velocity
            )
        }

        self.extConstChannelPerSim = {}  # extConstChannelPerSim['sim_key'][0=first channel, ...]; for now, only Reynolds Nr.
        num_of_ext_channel = 1
        for asim in self.dataSims:
            with open(asim+'/params.pickle', 'rb') as f:
                sim_params = pickle.load(f)
                self.extConstChannelPerSim[asim] = [ sim_params['re'] ]  # Reynolds Nr.

        self.dataStats.update({
            'ext.std': [
                np.std([np.absolute(self.extConstChannelPerSim[asim][i]) for asim in self.dataSims]) for i in range(num_of_ext_channel)  # Reynolds Nr
            ]
        })
        self.printFn(self.dataStats)

    def filenameToDownscaled(self, fname):
        return os.path.dirname(fname) + '/ds_' + os.path.basename(fname)

    def getInstance(self, sim_idx=0, frame=0):
        d0_hi = math.concat([self.dataPreloaded[self.dataSims[sim_idx+i]][frame][0] for i in range(self.batchSize)], axis=0)
        u0_hi = math.concat([self.dataPreloaded[self.dataSims[sim_idx+i]][frame][1] for i in range(self.batchSize)], axis=0)
        v0_hi = math.concat([self.dataPreloaded[self.dataSims[sim_idx+i]][frame][2] for i in range(self.batchSize)], axis=0)
        return [d0_hi, u0_hi, v0_hi]  # TODO: additional channels

    def newEpoch(self, exclude_tail=0, shuffle_data=True):
        self.numOfSteps = self.numOfFrames - exclude_tail
        sim_frames = [ (asim, self.dataFrms[i][0:(len(self.dataFrms[i])-exclude_tail)]) for i,asim in enumerate(self.dataSims) ]
        sim_frame_pairs = []
        for i,_ in enumerate(sim_frames):
            sim_frame_pairs += [ (i, aframe) for aframe in sim_frames[i][1] ]  # [(sim_idx, frame_number), ...]

        if shuffle_data: random.shuffle(sim_frame_pairs)
        self.epoch = [ list(sim_frame_pairs[i*self.numOfSteps:(i+1)*self.numOfSteps]) for i in range(self.batchSize*self.numOfBatchs) ]
        self.epochIdx += 1
        self.batchIdx = 0
        self.stepIdx = 0

    def nextBatch(self):  # batch size may be the number of simulations in a batch
        self.batchIdx += self.batchSize
        self.stepIdx = 0

    def nextStep(self):
        self.stepIdx += 1

    def getData(self, consecutive_frames, with_skip=1):
        d_hi = [
            np.concatenate([
                self.dataPreloaded[
                    self.dataSims[self.epoch[self.batchIdx+i][self.stepIdx][0]]  # sim_key
                ][
                    self.epoch[self.batchIdx+i][self.stepIdx][1]+j*with_skip     # frames
                ][0]
                for i in range(self.batchSize)
            ], axis=0) for j in range(consecutive_frames+1)
        ]
        u_hi = [
            np.concatenate([
                self.dataPreloaded[
                    self.dataSims[self.epoch[self.batchIdx+i][self.stepIdx][0]]  # sim_key
                ][
                    self.epoch[self.batchIdx+i][self.stepIdx][1]+j*with_skip     # frames
                ][1]
                for i in range(self.batchSize)
            ], axis=0) for j in range(consecutive_frames+1)
        ]
        v_hi = [
            np.concatenate([
                self.dataPreloaded[
                    self.dataSims[self.epoch[self.batchIdx+i][self.stepIdx][0]]  # sim_key
                ][
                    self.epoch[self.batchIdx+i][self.stepIdx][1]+j*with_skip     # frames
                ][2]
                for i in range(self.batchSize)
            ], axis=0) for j in range(consecutive_frames+1)
        ]
        ext = [
            self.extConstChannelPerSim[
                self.dataSims[self.epoch[self.batchIdx+i][self.stepIdx][0]]
            ][0] for i in range(self.batchSize)
        ]
        return [d_hi, u_hi, v_hi, ext]

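    # Iteration protocol (as used by the training loop at the bottom of this file): call newEpoch()
    # once per epoch, then for each batch and step call getData() followed by nextStep(), and
    # nextBatch() after the inner loop. getData(consecutive_frames=m) returns, per field, a list of
    # m+1 numpy arrays of shape [batch-size, res-y, res-x], plus one Reynolds number per batch entry.
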
    def getPrevData(self, previous_frames, with_skip=1):  # NOTE: not in use; need to test
        d_hi = [
            math.concat([
                self.dataPreloaded[
                    self.dataSims[self.epoch[self.batchIdx+i][self.stepIdx][0]]
                ][
                    max([0, self.epoch[self.batchIdx+i][self.stepIdx][1]-j*with_skip])
                ][0]
                for i in range(self.batchSize)
            ], axis=0) for j in range(previous_frames)
        ]
        u_hi = [
            math.concat([
                self.dataPreloaded[
                    self.dataSims[self.epoch[self.batchIdx+i][self.stepIdx][0]]
                ][
                    max([0, self.epoch[self.batchIdx+i][self.stepIdx][1]-j*with_skip])
                ][1]
                for i in range(self.batchSize)
            ], axis=0) for j in range(previous_frames)
        ]
        v_hi = [
            math.concat([
                self.dataPreloaded[
                    self.dataSims[self.epoch[self.batchIdx+i][self.stepIdx][0]]
                ][
                    max([0, self.epoch[self.batchIdx+i][self.stepIdx][1]-j*with_skip])
                ][2]
                for i in range(self.batchSize)
            ], axis=0) for j in range(previous_frames)
        ]
        # TODO: additional channels
        return [d_hi, u_hi, v_hi]  # also return u_hi, which was computed but previously dropped

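# Expected on-disk layout, as implied by the glob patterns and pickle read above (file counts and the
# exact zero-padding per simulation are assumptions beyond those patterns):
#   <train_dir>/sim_000000/dens_000000.npz, velo_000000.npz, ... (one pair per frame)
#   <train_dir>/sim_000000/params.pickle    # dict containing at least {'re': <Reynolds number>}
# Down-scaled copies are written next to the originals with a 'ds_' prefix.
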
domain = Domain(y=params['res']*2, x=params['res'], bounds=Box[0:params['len']*2, 0:params['len']], boundaries=OPEN)
simulator_lo = KarmanFlow(domain=domain)

if params['train'] is None:
    log.info('No pre-loadable training data path is given.')
    exit(0)

dataset = PhifDataset(
    domain=domain,
    dirpath=params['train'],
    num_frames=params['simsteps'], num_sims=params['nsims'], batch_size=params['sbatch'],
    print_fn=log.info,
    skip_preprocessing=params['skip_ds']
)
if params['only_ds']: exit(0)

if params['pretf']:
    with open(os.path.dirname(params['pretf'])+'/stats.pickle', 'rb') as f: ld_stats = pickle.load(f)
    dataset.dataStats['in.std'] = (ld_stats['in.std'][0], (ld_stats['in.std'][1], ld_stats['in.std'][2]))
    dataset.dataStats['out.std'] = ld_stats['out.std']
    log.info(dataset.dataStats)

if params['resume']>0:
    with open(params['tf']+'/dataStats.pickle', 'rb') as f: dataset.dataStats = pickle.load(f)

tf_tb_writer = tf.summary.create_file_writer(params['tf']+'/summary/training')

# model
netModel = eval('model_{}'.format(params['model']))
model = netModel(dict(shape=(params['res']*2, params['res'], 3)))
model.summary(print_fn=log.info)

if params['pretf']:
    log.info('load a pre-trained model: {}'.format(params['pretf']))
    ld_model = keras.models.load_model(params['pretf'], compile=False)
    model.set_weights(ld_model.get_weights())

if params['inittf']:
    log.info('load an initial model (warm start): {}'.format(params['inittf']))
    ld_model = keras.models.load_model(params['inittf'], compile=False)
    model.set_weights(ld_model.get_weights())

if params['resume']<1:
    if params['tf']: distutils.dir_util.mkpath(params['tf'])
    with open(params['tf']+'/dataStats.pickle', 'wb') as f: pickle.dump(dataset.dataStats, f)

else:
    ld_model = keras.models.load_model(params['tf']+'/model_epoch{:04d}.h5'.format(params['resume']))
    model.set_weights(ld_model.get_weights())

opt = tf.keras.optimizers.Adam(learning_rate=params['lr'])

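# Sketch (assumption): the --clip-grad flag is parsed above but no clipping is applied in the code
# shown here. One minimal way to honor it would be global-norm clipping inside train_step, just
# before apply_gradients, e.g.
#   if params['clip_grad']:
#       gradients, _ = tf.clip_by_global_norm(gradients, 1.0)  # the clip norm of 1.0 is an arbitrary choice
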
def to_feature(dens_vel_grid_array, ext_const_channel):
    # drop the unused edges of the staggered velocity grid so its dimensions match the centered grid's
    with tf.name_scope('to_feature') as scope:
        return math.stack(
            [
                dens_vel_grid_array[1].vector['x'].x[:-1].values,           # u
                dens_vel_grid_array[1].vector['y'].y[:-1].values,           # v
                math.ones(dens_vel_grid_array[0].shape)*ext_const_channel   # Re
            ],
            math.channel('channels')
        )

def to_staggered(tf_tensor, domain):
    with tf.name_scope('to_staggered') as scope:
        return domain.staggered_grid(
            math.stack(
                [
                    math.tensor(tf.pad(tf_tensor[..., 1], [(0,0), (0,1), (0,0)]), math.batch('batch'), math.spatial('y, x')),  # v
                    math.tensor(tf.pad(tf_tensor[..., 0], [(0,0), (0,0), (0,1)]), math.batch('batch'), math.spatial('y, x')),  # u
                ], math.channel('vector')
            )
        )

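# Data flow (illustrative summary): to_feature packs a [batch, y, x, 3] tensor of [u, v, Re] from a
# (density, velocity) pair; the network returns a [batch, y, x, 2] correction, which to_staggered
# pads back onto the staggered grid as a phiflow StaggeredGrid. Both are used in train_step below.
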
def train_step(pf_in_dens_gt, pf_in_velo_gt, pf_in_Re, i_step):
    with tf.name_scope('train_step'), tf.GradientTape() as tape:
        with tf.name_scope('sol') as scope:
            pf_co_prd, pf_cv_md = [], []  # predicted states with correction, inferred velocity corrections
            for i in range(params['msteps']):
                with tf.name_scope('solve_and_correct') as scope:
                    with tf.name_scope('solver_step') as scope:
                        pf_co_prd += [
                            simulator_lo.step(
                                density_in=pf_in_dens_gt[0] if i==0 else pf_co_prd[-1][0],
                                velocity_in=pf_in_velo_gt[0] if i==0 else pf_co_prd[-1][1],
                                re=pf_in_Re,
                                res=params['res'],
                            )
                        ]  # pf_co_prd: [[density1, velocity1], [density2, velocity2], ...]

                    with tf.name_scope('pred') as scope:
                        model_input = to_feature(pf_co_prd[-1], pf_in_Re)
                        model_input /= math.tensor([dataset.dataStats['std'][1], dataset.dataStats['std'][2], dataset.dataStats['ext.std'][0]], channel('channels'))  # [u, v, Re]
                        model_out = model(model_input.native(['batch', 'y', 'x', 'channels']), training=True)
                        model_out *= [dataset.dataStats['std'][1], dataset.dataStats['std'][2]]  # [u, v]
                        pf_cv_md += [ to_staggered(model_out, domain) ]  # pf_cv_md: [velocity_correction1, velocity_correction2, ...]

                        pf_co_prd[-1][1] = pf_co_prd[-1][1] + pf_cv_md[-1]

        with tf.name_scope('loss') as scope, tf_tb_writer.as_default():
            with tf.name_scope('steps_x') as scope:
                loss_steps_x = [
                    tf.nn.l2_loss(
                        (
                            pf_in_velo_gt[i+1].vector['x'].values.native(('batch', 'y', 'x'))
                            - pf_co_prd[i][1].vector['x'].values.native(('batch', 'y', 'x'))
                        )/dataset.dataStats['std'][1]
                    )
                    for i in range(params['msteps'])
                ]
                loss_steps_x_sum = tf.math.reduce_sum(loss_steps_x)

            with tf.name_scope('steps_y') as scope:
                loss_steps_y = [
                    tf.nn.l2_loss(
                        (
                            pf_in_velo_gt[i+1].vector['y'].values.native(('batch', 'y', 'x'))
                            - pf_co_prd[i][1].vector['y'].values.native(('batch', 'y', 'x'))
                        )/dataset.dataStats['std'][2]
                    )
                    for i in range(params['msteps'])
                ]
                loss_steps_y_sum = tf.math.reduce_sum(loss_steps_y)

            loss = (loss_steps_x_sum + loss_steps_y_sum)/params['msteps']

            for i,a_step_loss in enumerate(loss_steps_x): tf.summary.scalar(name='loss_each_step_vel_x{:02d}'.format(i+1), data=a_step_loss, step=math.to_int64(i_step).native())
            for i,a_step_loss in enumerate(loss_steps_y): tf.summary.scalar(name='loss_each_step_vel_y{:02d}'.format(i+1), data=a_step_loss, step=math.to_int64(i_step).native())
            tf.summary.scalar(name='sum_steps_loss', data=loss, step=math.to_int64(i_step).native())

            total_loss = loss
            if params['reg_loss']:
                reg_loss = tf.math.add_n(model.losses)
                total_loss += reg_loss
                tf.summary.scalar(name='loss_regularization', data=reg_loss, step=math.to_int64(i_step).native())

            tf.summary.scalar(name='loss', data=total_loss, step=math.to_int64(i_step).native())

        with tf.name_scope('apply_gradients') as scope:
            gradients = tape.gradient(total_loss, model.trainable_variables)
            opt.apply_gradients(zip(gradients, model.trainable_variables))

        return math.tensor(total_loss)

jit_step = math.jit_compile(train_step)

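# Note (descriptive, not from the original comments): math.jit_compile wraps train_step so that, with
# the TensorFlow backend used here, it is traced and executed as a compiled graph; the loop below
# therefore calls jit_step in place of the eager train_step.
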
i_st = 0
for j in range(params['epochs']):  # training
    dataset.newEpoch(exclude_tail=params['msteps'])
    if j<params['resume']:
        log.info('resume: skipping {} epoch'.format(j+1))
        i_st += dataset.numOfSteps*dataset.numOfBatchs
        continue

    for ib in range(dataset.numOfBatchs):    # for each batch
        for i in range(dataset.numOfSteps):  # for each step
            # adata: [[dens0, dens1, ...], [x-velo0, x-velo1, ...], [y-velo0, y-velo1, ...], [ReynoldsNr(s)]]
            adata = dataset.getData(consecutive_frames=params['msteps'], with_skip=1)
            dens_gt = [  # [density0:CenteredGrid, density1, ...]
                domain.scalar_grid(
                    math.tensor(adata[0][k], math.batch('batch'), math.spatial('y, x'))
                ) for k in range(params['msteps']+1)
            ]
            velo_gt = [  # [velocity0:StaggeredGrid, velocity1, ...]
                domain.staggered_grid(
                    math.stack(
                        [
                            math.tensor(adata[2][k], math.batch('batch'), math.spatial('y, x')),
                            math.tensor(adata[1][k], math.batch('batch'), math.spatial('y, x')),
                        ], math.channel('vector')
                    )
                ) for k in range(params['msteps']+1)
            ]
            re_nr = math.tensor(adata[3], math.batch('batch'))

            if i_st==0: tf.summary.trace_on(graph=True, profiler=True)

            l2 = jit_step(dens_gt, velo_gt, re_nr, math.tensor(i_st))

            if i_st==0:
                with tf_tb_writer.as_default():
                    tf.summary.trace_export(name="trace_train_step", step=i_st, profiler_outdir=params['tf']+'/summary/training')

            i_st += 1

            log.info('epoch {:03d}/{:03d}, batch {:03d}/{:03d}, step {:04d}/{:04d}: loss={}'.format(
                j+1, params['epochs'], ib+1, dataset.numOfBatchs, i+1, dataset.numOfSteps, l2
            ))
            dataset.nextStep()

        dataset.nextBatch()

    if j%10==9: model.save(params['tf']+'/model_epoch{:04d}.h5'.format(j+1))

tf_tb_writer.close()
model.save(params['tf']+'/model.h5')
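
# Resuming (illustrative; flags as defined above): checkpoints are written every 10 epochs as
# model_epoch{NNNN}.h5 in the --tf directory, so a run can be continued from, e.g., epoch 10 with
#   --resume 10 --tf <same output dir as before>
# which reloads model_epoch0010.h5 and dataStats.pickle and skips the first 10 epochs.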