tf.random not working inside tf.function

I've stumbled into issues using tensorflow-metal on the M1.

When generating random tensors inside a tf.function, tf.random gets stuck: the first call returns one value, and every later call repeats a single value. The same code behaves as expected outside of a tf.function.

Example:

import tensorflow as tf

tf.random.set_seed(1)

def func():
    return tf.random.uniform(())

Calling from a regular function works:

Edit: This only works when the seed is set beforehand. Without a seed, the regular function shows the same issue as below, where the random tensor is only generated 'once'.

print("Normal function -> seed works")
print(func())
print(func())
print(func())
print(func())
> Normal function -> seed works
> tf.Tensor(0.16513085, shape=(), dtype=float32)
> tf.Tensor(0.51010704, shape=(), dtype=float32)
> tf.Tensor(0.8292774, shape=(), dtype=float32)
> tf.Tensor(0.2364521, shape=(), dtype=float32)

Inside a tf.function, however, the values start repeating:

tf_func = tf.function(func)
print("tf.function -> seed stops working after 1 generation")
print(tf_func())
print(tf_func())
print(tf_func())
print(tf_func())
> tf.function -> seed stops working after 1 generation
> tf.Tensor(0.81269646, shape=(), dtype=float32)
> tf.Tensor(0.31179297, shape=(), dtype=float32)
> tf.Tensor(0.31179297, shape=(), dtype=float32)
> tf.Tensor(0.31179297, shape=(), dtype=float32)
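To check for the repetition without comparing the printed values by eye, a small helper can count how many distinct values a sampler returns. This is only a sketch: count_distinct and the two stand-in samplers are names made up for illustration, and the demo uses plain Python callables so that it runs without TensorFlow.

```python
import itertools


def count_distinct(sample_fn, n_calls=4):
    """Call sample_fn n_calls times; return (number of distinct values, values)."""
    # float() is assumed to work on whatever sample_fn returns
    # (a Python number here, or a scalar eager tensor).
    values = [float(sample_fn()) for _ in range(n_calls)]
    return len(set(values)), values


# Stand-in samplers (hypothetical), mimicking the two outputs shown above.
_healthy_iter = itertools.count(1)
_stuck_iter = itertools.chain([0.81269646], itertools.repeat(0.31179297))


def healthy_sampler():
    # Returns a different value on every call.
    return next(_healthy_iter) / 10.0


def stuck_sampler():
    # Returns one value first, then repeats a single value forever.
    return next(_stuck_iter)


healthy_count, _ = count_distinct(healthy_sampler)
stuck_count, stuck_values = count_distinct(stuck_sampler)
print(healthy_count, stuck_count)
```

With the functions from this post, the calls would be count_distinct(func) and count_distinct(tf_func), assuming float() accepts the scalar tensors they return. Four distinct values out of four calls is the expected result; two out of four matches the output above.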

I've seen that the team is working on a similar issue at the moment, but maybe this is something you are not aware of yet.

System Info

system: Darwin
release: 21.4.0
version: Darwin Kernel Version 21.4.0: Fri Mar 18 00:47:26 PDT 2022; root:xnu-8020.101.4~15/RELEASE_ARM64_T8101
machine: arm64
processor: arm
python_implementation: CPython
python_version: 3.9.7
python_version_tuple: ('3', '9', '7')
python_build: ('default', 'Sep 29 2021 19:24:02')
python_compiler: Clang 11.1.0 
platform: macOS-12.3.1-arm64-arm-64bit

Packages:
tensorflow-macos         2.9.2
tensorflow-metal         0.5.0

Replies

Hi @fstermann

Thanks for reporting the issue! We have a fix for this that will be included in the next tensorflow-metal version we release.

Unfortunately, the changes in tensorflow-metal==0.5.1 did not fix the issue as I had expected. I will need to investigate further, since it appears to be related to tf.function interfering with advancing the counter in the RNG.
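To illustrate why a counter that stops advancing would give exactly the pattern in the original post, here is a toy model in plain Python. It is not how TensorFlow or tensorflow-metal implement their RNG; ToyCounterRNG and its advance flag are hypothetical and exist only to show the idea that each output depends on (seed, counter) alone.

```python
import hashlib
import struct


class ToyCounterRNG:
    """Toy counter-based generator: each output depends only on (seed, counter)."""

    def __init__(self, seed):
        self.seed = seed
        self.counter = 0

    def _value(self):
        # Hash (seed, counter) and map the top 53 bits to a float in [0, 1).
        digest = hashlib.sha256(struct.pack("<qq", self.seed, self.counter)).digest()
        return (int.from_bytes(digest[:8], "little") >> 11) / 2**53

    def uniform(self, advance=True):
        value = self._value()
        if advance:
            # Normal behavior: move on so the next call gives a new value.
            self.counter += 1
        return value


healthy = ToyCounterRNG(seed=1)
stuck = ToyCounterRNG(seed=1)

# Healthy generator: the counter advances on every call.
healthy_values = [healthy.uniform() for _ in range(4)]

# Stuck generator: the counter advances once, then never again.
stuck_values = [stuck.uniform()] + [stuck.uniform(advance=False) for _ in range(3)]

print(healthy_values)
print(stuck_values)
```

In this model the stuck generator returns one value on the first call and then the same second value on every later call, which is the same shape as the tf.function output reported above.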