JAX Metal error: failed to legalize operation 'mhlo.scatter'

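I only get this error when using the JAX Metal device (CPU is fine). It seems to happen whenever I modify values of an array "in place" with .at[...].set(...), which is JAX's functional replacement for item assignment: it returns an updated copy instead of mutating the array.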

note: see current operation: 
%2903 = "mhlo.scatter"(%arg3, %2902, %2893) ({
^bb0(%arg4: tensor<f32>, %arg5: tensor<f32>):
  "mhlo.return"(%arg5) : (tensor<f32>) -> ()
}) {indices_are_sorted = true, scatter_dimension_numbers = #mhlo.scatter<update_window_dims = [0, 1], inserted_window_dims = [1], scatter_dims_to_operand_dims = [1]>, unique_indices = true} : (tensor<10x100x4xf32>, tensor<1xsi32>, tensor<10x4xf32>) -> tensor<10x100x4xf32>
        blocks = blocks.at[i].set(
...
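A minimal sketch of the pattern, assuming the shapes read off the log above: an operand of shape (10, 100, 4) and an update of shape (10, 4) written along axis 1. The leading axis of size 10 may come from vmap or a batch dimension in the original code. That part is a guess, as is the helper name update_block. It runs on CPU, where the report says this works. With a traced index, .at[...].set(...) is the kind of call that produced the mhlo.scatter in the log.

```python
import jax
import jax.numpy as jnp

# Hypothetical helper; shapes are inferred from the error log, not from the original code.
def update_block(blocks, i, new_block):
    # Returns a new array with blocks[:, i, :] replaced by new_block.
    return blocks.at[:, i].set(new_block)

blocks = jnp.zeros((10, 100, 4), dtype=jnp.float32)
new_block = jnp.ones((10, 4), dtype=jnp.float32)

# Pin the computation to CPU, where the thread reports it works.
cpu = jax.devices("cpu")[0]
with jax.default_device(cpu):
    out = jax.jit(update_block)(blocks, 3, new_block)

# To see which op was emitted, inspect the lowered module text and look for
# "scatter"; the dialect prefix (mhlo/stablehlo) depends on the JAX version.
hlo_text = jax.jit(update_block).lower(blocks, 3, new_block).as_text()
```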

Thanks for reporting this. Several advanced-indexing bugs involving GatherOp and ScatterOp conversion have been fixed at tip, and the example in this post should now work. The fixes will be included in the next jax-metal release.
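Until a release with the fix is installed, one possible workaround (not suggested in the thread and untested on Metal) is to express a contiguous block update with jax.lax.dynamic_update_slice, which lowers to a dynamic-update-slice op rather than a scatter. The sketch below assumes the same hypothetical shapes as the log and checks that both forms agree on CPU.

```python
import jax
import jax.numpy as jnp
from jax import lax

def update_block_scatter(blocks, i, new_block):
    # Original pattern: lowers to a scatter.
    return blocks.at[:, i].set(new_block)

def update_block_dus(blocks, i, new_block):
    # Write a (10, 1, 4) window starting at (0, i, 0); no scatter involved.
    return lax.dynamic_update_slice(blocks, new_block[:, None, :], (0, i, 0))

blocks = jnp.zeros((10, 100, 4), dtype=jnp.float32)
new_block = jnp.arange(40, dtype=jnp.float32).reshape(10, 4)

a = jax.jit(update_block_scatter)(blocks, 7, new_block)
b = jax.jit(update_block_dus)(blocks, 7, new_block)
```

The two versions can differ for negative or out-of-range i, because dynamic_update_slice adjusts start indices so the window fits while .at[...].set(...) has its own index handling, so keep i within [0, 100).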

Hi, is there an ETA on the next release with these fixes? Thanks in advance!

I'd like to add that 64-bit mode doesn't work for me when the GPU is enabled.
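For reference, 64-bit mode is switched on with the jax_enable_x64 config flag, ideally at program start. The thread does not say what exactly fails on the GPU, so the sketch below only assumes that pinning 64-bit work to the CPU device is an acceptable fallback.

```python
import jax

# Enable 64-bit types before any arrays are created.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

# Assumed fallback: run the 64-bit computation on CPU instead of the Metal device.
cpu = jax.devices("cpu")[0]
with jax.default_device(cpu):
    x = jnp.arange(3.0)  # float64 once x64 is enabled
    y = x * 2.0
```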

Thanks for the great code!

@dingshuhan any ETA on this?

Hi @dingshuhan, any update on this release? Because of the scatter bug, Jacobians can't be computed in JAX on jax-metal.
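A plausible reason Jacobians hit the same error even without an explicit .at[...].set(...): integer-array indexing lowers to a gather, and reverse-mode differentiation (jax.jacrev) transposes that gather into a scatter-add. The function f below is a hypothetical example. Forward mode (jax.jacfwd) does not transpose the gather, so it may avoid the scatter when the function itself contains none; that is an untested assumption for Metal.

```python
import jax
import jax.numpy as jnp

idx = jnp.array([0, 2])

def f(x):
    # Integer-array indexing lowers to a gather.
    return x[idx] ** 2

x = jnp.array([1.0, 2.0, 3.0])

J_rev = jax.jacrev(f)(x)  # reverse mode: gather transposed into a scatter-add
J_fwd = jax.jacfwd(f)(x)  # forward mode: no transpose of the gather

# Print this to look for a scatter-add primitive in the reverse-mode program.
jaxpr_text = str(jax.make_jaxpr(jax.jacrev(f))(x))
```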

This seems to be fixed in the latest release (0.0.5).
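To check whether an environment already has that release, the installed jax-metal version can be read from package metadata. The threshold of 0.0.5 comes from the comment above ("seems to be fixed"); the helper names are hypothetical.

```python
import re
from importlib.metadata import PackageNotFoundError, version

def installed_jax_metal_version():
    # Returns the installed jax-metal version string, or None if it is not installed.
    try:
        return version("jax-metal")
    except PackageNotFoundError:
        return None

def has_scatter_fix(ver, minimum=(0, 0, 5)):
    # Compare the leading major.minor.patch numbers; unparseable strings count as no fix.
    m = re.match(r"(\d+)\.(\d+)\.(\d+)", ver)
    if m is None:
        return False
    return tuple(int(g) for g in m.groups()) >= minimum

ver = installed_jax_metal_version()
```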
