I only get this error when using the JAX Metal device (CPU is fine). It seems to be a problem whenever I want to modify values of an array in-place using `.at[...].set(...)`.
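For context, `.at[...].set(...)` never mutates the original array. It returns an updated copy, and under `jit` XLA may reuse the buffer. Here is a minimal sketch of the pattern that runs fine on CPU. The `(100, 4)` shape, the index and the values are illustrative assumptions, not taken from my actual code:

```python
import jax.numpy as jnp

# Illustrative array; the (100, 4) shape is an assumption for this sketch.
blocks = jnp.zeros((100, 4), dtype=jnp.float32)

# .at[i].set(v) returns a NEW array with row i replaced; `blocks` is unchanged.
updated = blocks.at[3].set(jnp.ones(4, dtype=jnp.float32))

print(float(blocks.sum()))   # the original is still all zeros
print(float(updated.sum()))  # exactly one row of ones
```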
```
note: see current operation:
%2903 = "mhlo.scatter"(%arg3, %2902, %2893) ({
^bb0(%arg4: tensor<f32>, %arg5: tensor<f32>):
"mhlo.return"(%arg5) : (tensor<f32>) -> ()
}) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [0, 1], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1]>, unique_indices = true} : (tensor<10x100x4xf32>, tensor<1xsi32>, tensor<10x4xf32>) -> tensor<10x100x4xf32>
```
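The failing scatter has these parts. The operand is 10x100x4 and there is a single int32 index mapped to dimension 1 (`scatter_dims_to_operand_dims = [1]`). That dimension is collapsed (`inserted_window_dims = [1]`), and the update is 10x4. As far as I can tell, this is equivalent to `x.at[:, i].set(u)`. I think it is also what `jax.vmap` produces from a per-example `block.at[i].set(row)` on `(100, 4)` blocks. The sketch below checks on CPU that the two forms agree. The index value and the vmap framing are my assumptions, not something the log confirms:

```python
import jax
import jax.numpy as jnp

# Shapes copied from the scatter in the log: operand 10x100x4, updates 10x4.
x = jnp.zeros((10, 100, 4), dtype=jnp.float32)
u = jnp.ones((10, 4), dtype=jnp.float32)
i = 7  # hypothetical index; the log only shows a single int32 index

# Direct form: replace the slice x[:, i, :] with u.
direct = x.at[:, i].set(u)

# Hypothetical per-example form: replace row i of each (100, 4) block,
# vectorized over the leading axis of size 10 with jax.vmap.
def set_row(block, row):
    return block.at[i].set(row)

batched = jax.vmap(set_row)(x, u)

print(bool(jnp.array_equal(direct, batched)))
```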
```python
blocks = blocks.at[i].set(
...
```
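A possible stopgap is to express the same row update without a scatter. I have not confirmed that either alternative works on the Metal backend. The sketch compares `jax.lax.dynamic_update_slice` and a `jnp.where` mask against `.at[i].set(...)` on CPU. The function names and the `(100, 4)` shapes are hypothetical:

```python
import jax
import jax.numpy as jnp
from jax import lax

def set_row_scatter(blocks, i, row):
    # Same pattern as the failing code; this lowers to a scatter op.
    return blocks.at[i].set(row)

def set_row_dus(blocks, i, row):
    # Lowers to dynamic_update_slice instead of scatter.
    # The update must have the same rank as `blocks`, so add a leading axis.
    return lax.dynamic_update_slice(blocks, row[None, :], (i, 0))

def set_row_where(blocks, i, row):
    # Mask-based select: take `row` where the row index equals i.
    mask = (jnp.arange(blocks.shape[0]) == i)[:, None]
    return jnp.where(mask, row[None, :], blocks)

# Illustrative data (hypothetical shapes and values).
blocks = jnp.arange(100 * 4, dtype=jnp.float32).reshape(100, 4)
row = jnp.full((4,), -1.0, dtype=jnp.float32)
i = jnp.int32(42)

ref = jax.jit(set_row_scatter)(blocks, i, row)
dus = jax.jit(set_row_dus)(blocks, i, row)
wh = jax.jit(set_row_where)(blocks, i, row)

print(bool(jnp.array_equal(ref, dus)), bool(jnp.array_equal(ref, wh)))
```

The alternatives differ from the original for out-of-range `i`. `dynamic_update_slice` clamps the start index, `.at[].set()` drops out-of-bounds updates by default, and the `jnp.where` version leaves the array unchanged. Another stopgap is to run just this update on CPU inside `with jax.default_device(jax.devices("cpu")[0]):`. This may add device transfers.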