TensorFlow pre-training with GPU acceleration on M1 Mac with Metal fails

I installed TensorFlow 2.6.0 on a MacBook Air (2020, M1 chip) running macOS 12.0.1, together with the tensorflow-metal PluggableDevice for GPU acceleration.

Normally everything works if I disable eager execution.

Now I am testing a pre-training procedure: instantiate a model with a certain loss function -> train it on data -> save the model weights -> instantiate a new model with a different loss function (same architecture) -> set its weights to saved weights from pre-training -> train on data.

Specifically, I am doing this with a variational autoencoder (VAE).
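
For reference, the two loss terms used by the model in the example below can be written out as follows. This is read directly off loss_distortion and loss_rate in the code; the symbols are my own notation (x is the input, \hat{x} the reconstruction, \mu and s the two encoder outputs z_mu and z_sigma, \varepsilon the constant eps_std = 1e-2, and w the loss_weights dictionary).

```latex
\begin{aligned}
\mathcal{L}_{\text{distortion}} &= \sum_{i,j} \left( x_{ij} - \hat{x}_{ij} \right)^2, \\
\mathcal{L}_{\text{rate}} &= -\tfrac{1}{2} \sum_{i,k} \left( s_{ik} + \log \varepsilon^2 - \mu_{ik}^2 - \varepsilon^2 e^{s_{ik}} \right), \\
\mathcal{L} &= w_{\text{recon}} \, \mathcal{L}_{\text{distortion}} + w_{\text{kldiv}} \, \mathcal{L}_{\text{rate}}.
\end{aligned}
```

Since the sampler draws z = \mu + \varepsilon e^{s/2} \xi with standard normal \xi, the rate term matches the KL divergence between a Gaussian with mean \mu and variance \varepsilon^2 e^{s} and a standard normal prior, up to an additive constant (the code omits the usual +1 inside the sum). Each term is only added when its weight is greater than zero.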

The first training (pre-training) works fine and runs on the GPU (I checked GPU usage on datasets larger than the one in the reproducible example below). The second training does not work: calling .fit raises an error (Cannot assign a device for operation...). The desired behaviour is to train the second model on these data using the GPU.
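
As a diagnostic aside (not part of the reproducible example), one way to confirm that TensorFlow can see the Metal GPU at all is sketched below. The helper name gpu_report is hypothetical; it only wraps tf.config.list_physical_devices, and it assumes that a working tensorflow-metal install shows up as a device of type "GPU".

```python
# Diagnostic sketch only; gpu_report is a hypothetical helper name.
def gpu_report():
    """Return the TensorFlow version and the names of the GPUs it can see."""
    try:
        import tensorflow as tf
    except ImportError:
        # TensorFlow is not installed in this environment.
        return {"tf_version": None, "gpus": []}
    # Assumption: the Metal plugin registers itself under device type "GPU".
    gpus = tf.config.list_physical_devices("GPU")
    return {"tf_version": tf.__version__, "gpus": [d.name for d in gpus]}


report = gpu_report()
print(report)
```

An empty "gpus" list would suggest the plugin is not being loaded at all, which would be a different problem from the device-assignment error described above.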

Below is a reproducible example:

import tensorflow as tf
import tensorflow.keras.layers as layers
import tensorflow.keras.backend as K

# Public alias of the same function; everything below runs in graph mode
tf.compat.v1.disable_eager_execution()

import numpy as np

from sklearn import datasets

eps_std = tf.constant(1e-2, dtype=tf.float64)
eps_sq = eps_std ** 2
K.set_floatx('float64')

## Define encoder, sampler and decoder

class Encoder(layers.Layer):
    def __init__(self, latent_dim=2, name='Encoder', **kwargs):
        super(Encoder, self).__init__(name=name, **kwargs)
        self.l = layers.Dense(32, activation='selu')
        self.mu = layers.Dense(latent_dim)
        self.sigma = layers.Dense(latent_dim)
    def call(self, x):
        x = self.l(x)
        return self.mu(x), self.sigma(x)

class Sampler(layers.Layer):
    def call(self, z_mu, z_sigma):
        eps = tf.keras.backend.random_normal(shape=tf.shape(z_mu))
        return z_mu + eps_std * tf.exp(0.5*z_sigma) * eps

class Decoder(layers.Layer):
    def __init__(self, full_dim, name='Decoder', **kwargs):
        super(Decoder, self).__init__(name=name, **kwargs)
        self.l = layers.Dense(32, activation='selu')
        self.recon = layers.Dense(full_dim)
    def call(self, x):
        x = self.l(x)
        return self.recon(x)

## Define Variational Autoencoder

class vae(tf.keras.Model):
    def __init__(self, full_dim, latent_dim=2, loss_weights=None, name='vae', **kwargs):
        super(vae, self).__init__(name=name, **kwargs)
        self.full_dim = full_dim
        self.latent_dim = latent_dim
        self.encoder = Encoder(latent_dim=self.latent_dim)
        self.decoder = Decoder(full_dim=self.full_dim)
        self.sampler = Sampler()
        self.w = loss_weights
    def loss_distortion(self, y_true, y_pred):
        return tf.reduce_sum((y_true - y_pred) ** 2)
    def loss_rate(self, z_mu, z_sigma):
        return -0.5 * tf.reduce_sum(z_sigma + tf.math.log(eps_sq) - tf.square(z_mu) - eps_sq * tf.exp(z_sigma))
    def call(self, x):
        z_mu, z_sigma = self.encoder(x)
        z = self.sampler(z_mu, z_sigma)
        recon = self.decoder(z)
        if self.w['recon'] > 0.:
            self.add_loss(self.w['recon'] * self.loss_distortion(x, recon))
        if self.w['kldiv'] > 0.:
            self.add_loss(self.w['kldiv'] * self.loss_rate(z_mu, z_sigma))
        return recon

## Load data for training

iris = datasets.load_iris()
data = iris['data']

## Pre-training model

m = vae(full_dim=data.shape[1], 
        name='vae',
        loss_weights={'recon':1.0, 'kldiv':0.5} # reconstruction error, KL-divergence vs isotropic Gaussian prior
)
m.compile(optimizer=tf.keras.optimizers.Adam())
m.fit(x=data, y=None, batch_size=16, epochs=2, shuffle=True)

## Save weights from pre-training

w = m.get_weights()

## Create new model (same architecture; the real use case changes the loss function, but the loss weights are left unchanged in this example)

m = vae(full_dim=data.shape[1], loss_weights={'recon':1.0, 'kldiv':0.5}, name='vae')
m.compile(optimizer=tf.keras.optimizers.Adam())

## Build the model by running it once (its weights are only created on the first call, so they can't be set before that)

dummy = m.predict(data[0:5])

## Set model weights to the ones we obtained via pre-training

m.set_weights(w)

## Attempt training (gives error)

m.fit(x=data, y=None, batch_size=16, epochs=2, shuffle=True)

The console output of the first .fit call and the error message raised by the second .fit call are attached below.

Does anybody know how to address this issue?

Many thanks.
