Models with LSTM and GRU layers don't apply the masking layer during inference

Intro

When training RNN models on time-series data of shape (batch_size, timesteps, features) where the examples have different numbers of timesteps, the official TF documentation suggests padding the examples so that they all have the same number of timesteps, and then placing a Masking layer before the RNN layers (SimpleRNN, LSTM, GRU) so that the padded timesteps are skipped during both training and inference.
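As a rough illustration of which timesteps the Masking layer keeps, here is a NumPy sketch that mirrors its documented rule: a timestep is skipped only if all of its features equal `mask_value`. The helper `compute_padding_mask` is hypothetical and is not part of the Keras API; the small batch is made-up example data.

```python
import numpy as np

def compute_padding_mask(batch, mask_value):
    """Return a boolean mask of shape (batch_size, timesteps).

    Mirrors the rule keras.layers.Masking applies: a timestep is kept (True)
    if at least one of its features differs from mask_value.
    """
    return np.any(batch != mask_value, axis=-1)

padding_value = 2.0

# Two sequences with 3 features each, post-padded to 4 timesteps.
batch = np.array([
    [[0.1, 0.5, -1.0], [0.3, 0.2, 0.0], [2.0, 2.0, 2.0], [2.0, 2.0, 2.0]],
    [[1.0, 2.0, 2.0], [2.0, 2.0, 2.0], [2.0, 2.0, 2.0], [2.0, 2.0, 2.0]],
])

mask = compute_padding_mask(batch, padding_value)
# mask[0] -> [True, True, False, False]
# mask[1] -> [True, False, False, False]
```

In the second sequence the first timestep is kept even though two of its three features equal the padding value, because only timesteps where every feature matches are masked.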

Problem

Even though models built with LSTM and GRU layers skip the padded timesteps defined by the Masking layer during training, they seem to ignore the mask during inference (model.predict() or model.evaluate()). The SimpleRNN layer, on the other hand, works as intended: it skips the padded timesteps during both training and inference. These are the versions of my TF packages:

  • tensorflow-deps=2.10.0
  • tensorflow-estimator=2.9.0
  • tensorflow-macos=2.9.0
  • tensorflow-metal=0.5.0

Code

The code below trains a very simple LSTM network on padded random data. It then predicts on two examples that contain the same 8 timesteps, one without padding and one followed by 142 padded timesteps. If masking works during inference, both predictions should be identical, but I get different probabilities, which indicates that the mask is not applied. The same code works as intended in Colab.

import tensorflow as tf
from tensorflow import keras
import numpy as np
import random as python_random

# Make the results reproducible
np.random.seed(1375)
python_random.seed(1375)
tf.random.set_seed(1375)

# The padding value
padding_value = 2.0

# Number of samples and features in the dataset
num_samples = 1000
num_features = 3

# Create a dataset of sequences with 1 to 9 timesteps each
data = [np.random.randn(np.random.randint(1, 10), num_features) for _ in range(num_samples)]

# Create random binary target values
target = np.random.randint(0, 2, num_samples)

# Post-pad the dataset with the padding value so all examples have the same number of timesteps
data = keras.preprocessing.sequence.pad_sequences(data, dtype="float32", padding="post", value=padding_value)

# Define and train the model
inputs = keras.layers.Input(shape=(None, data.shape[2]))
x = keras.layers.Masking(mask_value=padding_value)(inputs)
x = keras.layers.LSTM(6)(x)
output = keras.layers.Dense(1, activation="sigmoid")(x)

model = keras.Model(inputs, output)

model.compile(
    loss="binary_crossentropy",
    metrics=["accuracy"],
    optimizer="RMSprop",
)

history = model.fit(
    x=data,
    y=target,
    validation_split=0.2,
    batch_size=100,
    epochs=10,
)

##############
## Below I create two samples (all_ones_no_pad, all_ones_with_pad) with the same 8 timesteps of ones,
## but the second one is followed by 142 padded timesteps.
## If the padded values are ignored during inference (which is what is intended), both samples must
## result in the same prediction probability.
##############

all_ones_no_pad = np.ones(shape=(1, 8, num_features))
all_ones_with_pad = np.ones(shape=(1, 150, num_features))
all_ones_with_pad[:,8:,:] = padding_value

no_padded_prediction = model.predict(all_ones_no_pad)
padded_prediction = model.predict(all_ones_with_pad)

# Compare with a tolerance instead of exact float equality
if np.allclose(no_padded_prediction, padded_prediction):
    print(f"Both examples result in the same prediction prob {padded_prediction} as intended.")
else:
    print(f"Examples were treated differently. Prediction probs are {no_padded_prediction} and {padded_prediction}")
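To make the expectation in the experiment concrete, here is a minimal NumPy sketch (not TensorFlow's implementation) of a single LSTM layer that carries its hidden and cell states unchanged across masked timesteps, which is how Keras RNN layers are meant to treat masked steps. The function name `masked_lstm_last_output`, the random weights, and the gate ordering (i, f, c, o, following the Keras kernel layout) are assumptions made for illustration.

```python
import numpy as np

def sigmoid(z):
    return 1.0 / (1.0 + np.exp(-z))

def masked_lstm_last_output(x, mask, W, U, b, units):
    """Run one LSTM layer over x of shape (timesteps, features) and return
    the final hidden state. At timesteps where mask is False, the previous
    (h, c) are carried over unchanged (assumed to match Keras semantics)."""
    h = np.zeros(units)
    c = np.zeros(units)
    for t in range(x.shape[0]):
        z = x[t] @ W + h @ U + b              # shape (4 * units,)
        i = sigmoid(z[:units])                # input gate
        f = sigmoid(z[units:2 * units])       # forget gate
        g = np.tanh(z[2 * units:3 * units])   # candidate cell state
        o = sigmoid(z[3 * units:])            # output gate
        c_new = f * c + i * g
        h_new = o * np.tanh(c_new)
        if mask[t]:
            h, c = h_new, c_new
        # Masked step: keep the previous h and c untouched.
    return h

# Hypothetical random weights for a 6-unit LSTM on 3 features.
rng = np.random.default_rng(0)
units, num_features, padding_value = 6, 3, 2.0
W = rng.normal(scale=0.5, size=(num_features, 4 * units))
U = rng.normal(scale=0.5, size=(units, 4 * units))
b = np.zeros(4 * units)

# Same shapes as the experiment: 8 timesteps of ones, then 142 padded steps.
no_pad = np.ones((8, num_features))
with_pad = np.full((150, num_features), padding_value)
with_pad[:8] = 1.0

mask_no_pad = np.any(no_pad != padding_value, axis=-1)
mask_with_pad = np.any(with_pad != padding_value, axis=-1)

h_no_pad = masked_lstm_last_output(no_pad, mask_no_pad, W, U, b, units)
h_masked = masked_lstm_last_output(with_pad, mask_with_pad, W, U, b, units)
h_unmasked = masked_lstm_last_output(with_pad, np.ones(150, dtype=bool), W, U, b, units)

print(np.allclose(h_no_pad, h_masked))    # padded steps skipped
print(np.allclose(h_no_pad, h_unmasked))  # padded steps update the state
```

Because the final Dense layer is a deterministic function of the last hidden state, equal hidden states imply equal prediction probabilities. In the unmasked run the 142 padded timesteps keep updating the state, which produces the kind of discrepancy described above.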

The code runs to completion, but the following warning is printed to the console:


2023-02-02 12:29:56.458998: W tensorflow/core/common_runtime/forward_type_inference.cc:231] Type inference failed. This indicates an invalid graph that escaped type checking. Error message: INVALID_ARGUMENT: expected compatible input types, but input 1:
type_id: TFT_OPTIONAL
args {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_TENSOR
    args {
      type_id: TFT_LEGACY_VARIANT
    }
  }
}
 is neither a subtype nor a supertype of the combined inputs preceding it:
type_id: TFT_OPTIONAL
args {
  type_id: TFT_PRODUCT
  args {
    type_id: TFT_TENSOR
    args {
      type_id: TFT_FLOAT
    }
  }
}

        while inferring type of node 'cond_19/output/_23'
2023-02-02 12:29:56.461511: I tensorflow/core/grappler/optimizers/custom_graph_optimizer_registry.cc:113] Plugin optimizer for device_type GPU is enabled.