tf.function decorator with tensorflow-metal breaks tf.signal.fft3d()

I consistently get corrupted results from tf.signal.fft3d() when it is called inside a function decorated with @tf.function. All entries beyond a certain x index are zero (0.) (see image). Surprisingly, the issue depends on the tensor size: for example, (1023, 1023, 287) works but (1023, 1023, 575) does not. This is problematic because the failure is silent and does not occur for every size, so it can easily slip through tests.

The error occurs only when tensorflow-metal is installed. The TensorFlow version is 2.16.1. My hardware is a MacBook Pro M3 Max with 40 GPU cores and 128 GB RAM, running macOS Sonoma 14.5 (23F79). A Python environment that reproduces the bug can be created as follows:

conda create --name tfmetalbug python=3.11.9
conda activate tfmetalbug
pip install tensorflow tensorflow-metal
conda install matplotlib

The following code reproduces the issue:

import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt

# Wrap fft3d with tf.function
@tf.function
def fft3d_wrapper_function(x):
    return tf.signal.fft3d(x)

# Generate a 3D image
img = tf.random.normal(shape=(1023, 1023, 575), stddev=1., dtype=tf.float32)  # generate a random 3D image
img = tf.dtypes.cast(img, tf.complex64)  # convert to complex values

# Compute the 3D FFT
img_fft = fft3d_wrapper_function(img)

# Visualize the 3D FFT
plt.imshow(np.real(img_fft)[:, img_fft.shape[1]//2+10, :], cmap="gray", vmin=-0.001, vmax=0.001)
plt.savefig("fft3d_wrapper_function.png")

For me, removing the @tf.function decorator has resolved the issue.
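Because the failure is silent, a cheap regression check may help catch it in tests. The sketch below uses a hypothetical helper, first_zero_tail_index (not part of TensorFlow or NumPy), that reports where a trailing run of all-zero slices begins along a chosen axis. Which array axis corresponds to the "x index" in the image is an assumption here, so the axis is a parameter.

```python
import numpy as np

def first_zero_tail_index(arr, axis=0):
    """Hypothetical helper: return the smallest index i such that every
    slice at position j >= i along `axis` is exactly zero, or None if the
    last slice contains a non-zero value (i.e. there is no zero tail)."""
    # Move the axis of interest to the front and flatten the rest,
    # so each row corresponds to one slice along `axis`.
    a = np.moveaxis(np.asarray(arr), axis, 0)
    slice_has_nonzero = np.any(a.reshape(a.shape[0], -1) != 0, axis=1)
    if slice_has_nonzero[-1]:
        return None  # last slice is non-zero, so no zero tail exists
    nonzero_idx = np.flatnonzero(slice_has_nonzero)
    # The tail starts right after the last slice that has a non-zero entry;
    # if every slice is zero, the tail starts at index 0.
    return 0 if nonzero_idx.size == 0 else int(nonzero_idx[-1]) + 1
```

Applied to the reproduction above, e.g. first_zero_tail_index(img_fft.numpy(), axis=0), a correct transform of random input should return None, while an integer result points at the index where the zeroed region starts.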