I have probably found a bug when indexing tensors with tensorflow-metal. It is best demonstrated by the following minimal example:
import tensorflow as tf
print(tf.constant([[1, 2], [3, 4]], dtype=tf.float32)[..., :2, 1])
The expected result is [2, 4] (i.e. the second column of the matrix), which is what I get when tensorflow-metal is not installed (and on non-Apple machines), but with tensorflow-metal I get [2, 2] (i.e. the first element of the column is repeated; this also happens if there are more than two rows).
The following conditions seem to be necessary in order to trigger this behavior:
- dtype must be float32; it works correctly with float64, int32 and int64.
- the sequence of ellipsis (for batch axes), slice (for row), index (for column) is critical; i.e. it works correctly when the column is also a slice, and it works if the row is a single number or the "full" slice (:).
- the indexed tensor does not actually have batch axes (the ellipsis is there because it could have).
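The conditions above could be checked systematically with something like the following sketch, which compares TensorFlow's result against NumPy for each dtype and index pattern mentioned. The names VALUES, PATTERNS and check are made up for this illustration, and NumPy is only assumed here as a reference for the expected indexing semantics.

```python
import numpy as np
import tensorflow as tf

VALUES = [[1, 2], [3, 4]]

# Index patterns mentioned above; only the first is reported to misbehave.
PATTERNS = {
    "[..., :2, 1]": lambda x: x[..., :2, 1],      # slice rows, index column
    "[..., :2, 1:2]": lambda x: x[..., :2, 1:2],  # column is also a slice
    "[..., 0, 1]": lambda x: x[..., 0, 1],        # row is a single number
    "[..., :, 1]": lambda x: x[..., :, 1],        # row is the full slice
}

def check(dtype):
    """Return {pattern: bool}, True where TensorFlow matches NumPy."""
    tensor = tf.constant(VALUES, dtype=dtype)
    reference = np.array(VALUES, dtype=dtype.as_numpy_dtype)
    return {
        name: bool(np.array_equal(index(tensor).numpy(), index(reference)))
        for name, index in PATTERNS.items()
    }

for dtype in (tf.float32, tf.float64, tf.int32, tf.int64):
    print(dtype.name, check(dtype))
```

On a setup without the problem every entry should be True; with tensorflow-metal I would expect only the float32 entry for [..., :2, 1] to be False, based on what I observed.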
The original context is: I have a function that gets a tensor with 0 or more batch axes containing 4x4 homogeneous matrices, from which I want to extract the translation, i.e. the first three rows of the last column, which leads to [..., :3, 3].
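A possible workaround, based only on the observation that the result is correct when the column is also a slice, is to slice the last column as 3:4 and then drop the extra axis. This is a minimal sketch; the function name extract_translation is made up for this example, and I have not confirmed that it avoids the problem in every case.

```python
import tensorflow as tf

def extract_translation(matrices):
    """Extract the translation from (..., 4, 4) homogeneous matrices.

    Slices the last column (3:4) instead of indexing it (3), then
    squeezes the trailing axis, giving a tensor of shape (..., 3).
    """
    return tf.squeeze(matrices[..., :3, 3:4], axis=-1)

# A single matrix with translation (5, 6, 7) and no batch axes.
matrix = tf.constant(
    [[1, 0, 0, 5],
     [0, 1, 0, 6],
     [0, 0, 1, 7],
     [0, 0, 0, 1]],
    dtype=tf.float32,
)
print(extract_translation(matrix))

# The same matrix repeated along one batch axis.
batch = tf.stack([matrix, matrix])
print(extract_translation(batch).shape)
```

The result has the same shape as [..., :3, 3] would give, so callers should not need to change.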
Versions:
- python 3.9.6 (system)
- tensorflow-macos 2.13.0
- tensorflow-metal 1.0.1