conv2d_mod/Conv2D NCHW not implemented #12

Open
xiaoliangbai opened this issue Oct 19, 2020 · 2 comments

Comments

@xiaoliangbai

generated_images = self.GAN.GM.predict(n1 + [n2], batch_size = BATCH_SIZE)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training.py", line 909, in predict
    use_multiprocessing=use_multiprocessing)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_arrays.py", line 722, in predict
    callbacks=callbacks)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/engine/training_arrays.py", line 393, in model_iteration
    batch_outs = f(ins_batch)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/keras/backend.py", line 3740, in call
    outputs = self._graph_fn(*converted_inputs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py", line 1081, in call
    return self._call_impl(args, kwargs)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py", line 1121, in _call_impl
    return self._call_flat(args, self.captured_inputs, cancellation_manager)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py", line 1224, in _call_flat
    ctx, args, cancellation_manager=cancellation_manager)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/function.py", line 511, in call
    ctx=ctx)
  File "/usr/local/lib/python3.6/dist-packages/tensorflow_core/python/eager/execute.py", line 67, in quick_execute
    six.raise_from(core._status_to_exception(e.code, message), None)
  File "", line 3, in raise_from
tensorflow.python.framework.errors_impl.UnimplementedError: The Conv2D op currently only supports the NHWC tensor format on the CPU. The op was given the format: NCHW
  [[node model_1/conv2d_mod/Conv2D (defined at /usr/local/lib/python3.6/dist-packages/tensorflow_core/python/framework/ops.py:1751) ]] [Op:__inference_keras_scratch_graph_11413]

Function call stack:
keras_scratch_graph

It seems Conv2D does not accept the NCHW data format here. I tried forcing it to run on the GPU (with tf.device('/gpu:1'): ...), but that did not work.
I also tried different TensorFlow versions (2.0 and 2.3), including the Docker image for TF 2.0, and all of them hit the same error.

Does anyone know how to get around this issue?
Thanks

@anthonyivol

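This happens because the model is running on the CPU. Try batch_size = 1, and make these changes in conv_mod.py:

# add this: NCHW -> NHWC, since the CPU Conv2D kernel only supports NHWC
x = tf.transpose(x, [0, 2, 3, 1])

# change data_format from "NCHW" to "NHWC"
x = tf.nn.conv2d(x, w, strides=self.strides, padding="SAME", data_format="NHWC")

# add this: NHWC -> NCHW, restoring the layout the rest of the layer expects
x = tf.transpose(x, [0, 3, 1, 2])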
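The two tf.transpose calls in this fix are inverse axis permutations. As a quick check, here is a NumPy sketch (np.transpose takes the same axis list as tf.transpose's perm argument; the tensor shape is made up for illustration) confirming that [0, 2, 3, 1] moves the channel axis last, [0, 3, 1, 2] moves it back, and no element changes its (n, c, h, w) coordinates:

```python
import numpy as np

# Hypothetical NCHW activation: batch 2, 3 channels, 4x5 spatial grid.
n, c, h, w = 2, 3, 4, 5
x_nchw = np.arange(n * c * h * w, dtype=np.float32).reshape(n, c, h, w)

# Same permutation as tf.transpose(x, [0, 2, 3, 1]): NCHW -> NHWC.
to_nhwc = [0, 2, 3, 1]
# Same permutation as tf.transpose(x, [0, 3, 1, 2]): NHWC -> NCHW.
to_nchw = [0, 3, 1, 2]

x_nhwc = np.transpose(x_nchw, to_nhwc)
print(x_nhwc.shape)  # (2, 4, 5, 3)

# Every element keeps its coordinates; only the axis order changes:
# x_nhwc[n, h, w, c] == x_nchw[n, c, h, w]
assert x_nhwc[1, 2, 3, 0] == x_nchw[1, 0, 2, 3]

# The second permutation is the inverse of the first, so the
# round trip restores the original tensor exactly.
x_back = np.transpose(x_nhwc, to_nchw)
assert np.array_equal(x_back, x_nchw)
```

Because the permutations are inverses, the layer's input and output stay in NCHW and only the convolution itself sees NHWC.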

@xiaoliangbai
Author

Thanks Anthony, your solution works.
I had assumed the weights would also need their in_channels axis transposed to match the activation data format, but it turns out they don't.
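That is consistent with how tf.nn.conv2d defines its arguments: the filter always has shape [filter_height, filter_width, in_channels, out_channels], and data_format only describes the input and output tensors. The NumPy sketch below checks this on a toy case. It assumes stride 1, "SAME" padding and an odd kernel size; conv2d_nhwc and conv2d_nchw are hypothetical reference implementations (cross-correlation, as TensorFlow computes), not TensorFlow code. Both take the same, untransposed HWIO filter, and wrapping the NHWC version in the two transposes from the fix reproduces the direct NCHW result:

```python
import numpy as np


def conv2d_nhwc(x, k):
    """Hypothetical naive stride-1, 'SAME'-padded cross-correlation.

    x: (N, H, W, C_in) activations; k: (KH, KW, C_in, C_out) filter (HWIO).
    Assumes odd KH and KW so 'SAME' padding is symmetric.
    """
    kh, kw, _, c_out = k.shape
    ph, pw = kh // 2, kw // 2
    xp = np.pad(x, ((0, 0), (ph, ph), (pw, pw), (0, 0)))
    n, h, w, _ = x.shape
    out = np.zeros((n, h, w, c_out), dtype=x.dtype)
    for i in range(h):
        for j in range(w):
            patch = xp[:, i:i + kh, j:j + kw, :]  # (N, KH, KW, C_in)
            # Sum over kernel rows, kernel columns and input channels.
            out[:, i, j, :] = np.einsum("nabc,abco->no", patch, k)
    return out


def conv2d_nchw(x, k):
    """Same convolution computed directly on NCHW activations.

    The filter is still HWIO: only the activation axes move.
    """
    kh, kw, _, c_out = k.shape
    ph, pw = kh // 2, kw // 2
    xp = np.pad(x, ((0, 0), (0, 0), (ph, ph), (pw, pw)))
    n, _, h, w = x.shape
    out = np.zeros((n, c_out, h, w), dtype=x.dtype)
    for i in range(h):
        for j in range(w):
            patch = xp[:, :, i:i + kh, j:j + kw]  # (N, C_in, KH, KW)
            out[:, :, i, j] = np.einsum("ncab,abco->no", patch, k)
    return out


rng = np.random.default_rng(0)
x_nchw = rng.standard_normal((2, 3, 6, 6))  # NCHW input
kernel = rng.standard_normal((3, 3, 3, 4))  # HWIO filter, never transposed

# The workaround: NCHW -> NHWC, convolve, NHWC -> NCHW.
x_nhwc = np.transpose(x_nchw, [0, 2, 3, 1])
wrapped = np.transpose(conv2d_nhwc(x_nhwc, kernel), [0, 3, 1, 2])
direct = conv2d_nchw(x_nchw, kernel)

print(wrapped.shape)                 # (2, 4, 6, 6)
print(np.allclose(wrapped, direct))  # True
```

In both functions the filter's in_channels axis (index 2) is contracted against whichever activation axis holds the channels, so moving the channels in the activations never requires touching the weights.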
