tensorflow-macos gives shape errors only if tensorflow-metal is installed

Greetings!

I am working on a project using Magenta (https://github.com/magenta/magenta/) which has tools to train neural networks and generate music using them.

Everything works with only tensorflow-macos installed, but then all calculations run on the CPU. As soon as I install tensorflow-metal, I get shape errors such as:

tensorflow.python.framework.errors_impl.InvalidArgumentError: Graph execution error:

Detected at node 'RNN/while/AttentionCellWrapper/Linear/concat' defined at (most recent call last):
Node: 'RNN/while/AttentionCellWrapper/Linear/concat'
Detected at node 'RNN/while/AttentionCellWrapper/Linear/concat' defined at (most recent call last):
Node: 'RNN/while/AttentionCellWrapper/Linear/concat'
2 root error(s) found.
  (0) INVALID_ARGUMENT: concat_dim tensor should be a scalar integer, but got shape [1,256]
	 [[{{node RNN/while/AttentionCellWrapper/Linear/concat}}]]
	 [[RNN/while/AttentionCellWrapper/AttnOutputProjection/add/_69]]
  (1) INVALID_ARGUMENT: concat_dim tensor should be a scalar integer, but got shape [1,256]
	 [[{{node RNN/while/AttentionCellWrapper/Linear/concat}}]]
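
The failing node is a concat inside the RNN while loop, and the message says its axis argument (concat_dim) arrived with shape [1,256] instead of as a scalar. One way to check whether the bare op misbehaves on the Metal device is to run the same tf.concat with a scalar axis on the CPU and on the first GPU and compare the results. This is only a sketch under assumptions: the 1x256 input shape is borrowed from the error message, the helper names are hypothetical, and this simple case may well pass even though the graph-mode loop in Magenta fails.

```python
try:
    import tensorflow as tf
except ImportError:  # TensorFlow is not installed in this environment
    tf = None


def concat_on(device, x, y):
    """Concatenate x and y along axis 1 inside a tf.function pinned to `device`."""
    @tf.function
    def concat(a, b):
        with tf.device(device):
            # tf.concat expects a scalar axis (the "concat_dim" in the error).
            return tf.concat([a, b], axis=1)
    return concat(x, y)


def concat_matches_across_devices(batch=1, units=256):
    """True/False for CPU-vs-GPU agreement, or None if TensorFlow or a GPU is missing."""
    if tf is None or not tf.config.list_physical_devices("GPU"):
        return None
    x = tf.random.uniform((batch, units))
    y = tf.random.uniform((batch, units))
    cpu = concat_on("/CPU:0", x, y)
    gpu = concat_on("/GPU:0", x, y)
    if cpu.shape != (batch, 2 * units):
        return False
    return bool(tf.reduce_all(tf.equal(cpu, gpu)))


if __name__ == "__main__":
    print(concat_matches_across_devices())
```

If the GPU call itself raises InvalidArgumentError, that would reproduce the problem without Magenta in the picture.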

I suspect this is a problem with Apple's GPU acceleration implementation, since uninstalling tensorflow-metal is all it takes for the code to work again.
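
As a possible stopgap that avoids uninstalling tensorflow-metal, TensorFlow can be told to hide all GPUs so every op is placed on the CPU. This is a sketch under assumptions: it only helps from a Python entry point you control that runs before Magenta builds its graph (the drums_rnn_generate command does not call it for you), the hide_gpus name is hypothetical, and it has not been verified against tensorflow-metal specifically.

```python
try:
    import tensorflow as tf
except ImportError:  # TensorFlow is not installed in this environment
    tf = None


def hide_gpus():
    """Make every GPU invisible to TensorFlow so all ops run on the CPU.

    Must be called before TensorFlow initializes its devices (before any op
    or model is created); otherwise TensorFlow raises RuntimeError.
    Returns the names of the hidden GPUs, or None if TensorFlow is missing.
    """
    if tf is None:
        return None
    gpus = tf.config.list_physical_devices("GPU")
    tf.config.set_visible_devices([], "GPU")
    return [gpu.name for gpu in gpus]


if __name__ == "__main__":
    print("Hidden GPUs:", hide_gpus())
```

Uninstalling tensorflow-metal remains the simpler option; this only helps when the plugin needs to stay installed for other projects in the same environment.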

Steps to reproduce:

  1. Create a new virtualenv using the attached requirements file: pip install -r magenta-requirements.txt --no-deps (--no-deps is necessary because magenta depends on the "regular" tensorflow package and does not know about tensorflow-macos)
  2. Download example drum bundle: curl --output drum_kit_rnn.mag download.magenta.tensorflow.org/models/drum_kit_rnn.mag
  3. Generate drums: drums_rnn_generate --bundle_file="drum_kit_rnn.mag"
  4. Confirm that the program worked: some MIDI files are generated in /tmp/drums_rnn
  5. Install GPU acceleration: pip install tensorflow-metal
  6. Rerun step 3 and get the error above
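
When comparing the working and failing setups from these steps, it may help to attach the exact package versions to the report. A minimal sketch using only the standard library; the package list is an assumption based on the packages named in this post, and package_versions is a hypothetical helper name.

```python
import platform
from importlib import metadata


def package_versions(names=("tensorflow-macos", "tensorflow-metal", "magenta")):
    """Map each distribution name to its installed version, or None if absent."""
    versions = {}
    for name in names:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    # mac_ver() returns an empty release string on non-macOS systems.
    print("macOS:", platform.mac_ver()[0] or "n/a")
    print("Python:", platform.python_version())
    for name, version in package_versions().items():
        print(f"{name}: {version or 'not installed'}")
```

Running it once before step 5 and once after should show tensorflow-metal as the only difference.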

Thanks in advance for any help!

Hi @aggro,

Thanks for reporting this and providing the steps to reproduce. I'll be taking a look at this since it looks like there is something going wrong in our concat implementation. I'll let you know if I have any additional questions or when I've figured out a fix for the issue.

Thanks! Looking forward to it, and I'm available for any follow-up questions or discussion :)
