Tensor indexing

I believe I have found a bug in tensor indexing with tensorflow-metal. The following minimal example demonstrates it:

import tensorflow as tf

print(tf.constant([[1, 2], [3, 4]], dtype=tf.float32)[..., :2, 1])

The expected result is [2, 4] (i.e. the second column of the matrix), which is what I get when tensorflow-metal is not installed and on non-Apple machines. With tensorflow-metal installed I get [2, 2] instead: the first element of the column is repeated. This also happens if there are more than two rows.

The following conditions seem to be necessary in order to trigger this behavior:

  • The dtype must be float32; indexing works correctly with float64, int32 and int64.
  • The sequence of ellipsis (for batch axes), slice (for the row), and index (for the column) is critical: it works correctly when the column is also a slice, and when the row is a single index or the full slice :.
  • The indexed tensor does not actually have batch axes (the ellipsis is there because it could have them).
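To make the dtype condition easier to check on other machines, here is a small sketch that runs the same indexing pattern for each dtype mentioned above and compares the TensorFlow result against NumPy. The helper name check_indexing is made up for illustration, and using NumPy as the reference is my assumption about what "expected" should mean here.

```python
import numpy as np
import tensorflow as tf


def check_indexing(dtype):
    """Return (tensorflow_result, numpy_result) for the [..., :2, 1] pattern."""
    data = np.array([[1, 2], [3, 4]], dtype=dtype.as_numpy_dtype)
    # Same indexing expression as in the minimal example above.
    got = tf.constant(data)[..., :2, 1].numpy()
    # NumPy applies the same indexing rules and serves as the reference.
    expected = data[..., :2, 1]
    return got, expected


for dtype in (tf.float32, tf.float64, tf.int32, tf.int64):
    got, expected = check_indexing(dtype)
    status = "ok" if np.array_equal(got, expected) else "MISMATCH"
    print(f"{dtype.name}: got {got.tolist()}, expected {expected.tolist()} -> {status}")
```

If the conditions listed above are right, only the float32 line should report a mismatch when tensorflow-metal is installed.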

The original context is: I have a function that gets a tensor with 0 or more batch axes containing 4x4 homogeneous matrices from which I want to extract the translation, i.e. the first three rows of the last column, which leads to [..., :3, 3].
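Since the indexing works when the column is also a slice, one possible workaround is to slice the last column as a range and then drop the extra axis. This is only a sketch based on that observation: the function name extract_translation is hypothetical, and I have not confirmed that tf.squeeze is unaffected on tensorflow-metal.

```python
import tensorflow as tf


def extract_translation(transforms):
    """Return the translation part of 4x4 homogeneous matrices.

    `transforms` has shape (..., 4, 4); the result has shape (..., 3).
    """
    # Select the last column as the range 3:4 instead of the scalar index 3,
    # which yields shape (..., 3, 1).
    column = transforms[..., :3, 3:4]
    # Drop the trailing length-1 axis to get shape (..., 3).
    return tf.squeeze(column, axis=-1)


# Usage with a single matrix (no batch axes) whose translation is (5, 6, 7).
matrix = tf.constant(
    [[1, 0, 0, 5],
     [0, 1, 0, 6],
     [0, 0, 1, 7],
     [0, 0, 0, 1]],
    dtype=tf.float32,
)
print(extract_translation(matrix))
```

The same function accepts any number of leading batch axes, because both the slice and the squeeze only touch the last two axes.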

Versions:

  • python 3.9.6 (system)
  • tensorflow-macos 2.13.0
  • tensorflow-metal 1.0.1