jax-metal gather operation failing for 2+D inputs with striding

The lax.gather operation fails on inputs with more than one dimension. First, an example that works on a 1D input:

>>> import jax
>>> jax.devices()
Metal device set to: AMD Radeon Pro 555X

systemMemory: 32.00 GB
maxCacheSize: 2.00 GB

[MetalDevice(id=0, process_index=0)]
>>> jax.devices()[0].platform
'METAL'
>>> jax.devices()[0].device_kind
'Metal'
>>> jax.devices()[0].client.platform
'METAL'
>>> jax.devices()[0].client.runtime_type
'tfrt'
>>> import jax.numpy as jnp
>>> from jax import lax
>>> x = jnp.arange(1000)
>>> idx = (10 * jnp.arange(50) + 5).astype(int)
>>> idx = idx.reshape(-1,1)
>>> dnums = lax.GatherDimensionNumbers((1,), (), (0,))
>>> out = lax.gather(x, idx, dnums, (5,))
>>> out.shape
(50, 5)

Given a 2D or higher-dimensional input, the gather operation fails:

>>> x = jnp.arange(1000).reshape(10,100)
>>> idx = jnp.stack(jnp.meshgrid(jnp.arange(3)*2, jnp.arange(20)*3), axis=-1)
>>> dnums = lax.GatherDimensionNumbers((1,2), (), (0,1))
>>> out = lax.gather(x, idx, dnums, (2,5))
LLVM ERROR: Failed to infer result type(s).
[1]    96135 abort      python
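
For comparison, the result this call would be expected to return can be sketched in plain NumPy. This is only a reference, assuming standard XLA gather semantics: offset_dims=(1, 2) places the 2x5 window at output axes 1 and 2, and the two batch axes of idx fill output axes 0 and 3. The name `expected` is introduced here for illustration, and the loop is not how jax-metal or XLA implements the operation.

```python
import numpy as np

# Same operand and start indices as in the failing call above.
x = np.arange(1000).reshape(10, 100)
grid = np.meshgrid(np.arange(3) * 2, np.arange(20) * 3)
idx = np.stack(grid, axis=-1)  # shape (20, 3, 2): one (row, col) start per batch element

slice_sizes = (2, 5)

# Assumed layout: batch axis 0, window rows, window cols, batch axis 1.
expected = np.empty((idx.shape[0], *slice_sizes, idx.shape[1]), dtype=x.dtype)
for b0 in range(idx.shape[0]):
    for b1 in range(idx.shape[1]):
        r, c = idx[b0, b1]
        # Copy the 2x5 window that starts at (r, c) in the operand.
        expected[b0, :, :, b1] = x[r:r + slice_sizes[0], c:c + slice_sizes[1]]

print(expected.shape)  # (20, 2, 5, 3)
```

On a backend where the gather works, `out.shape` should match this shape if the assumptions above hold.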

This is a problem for array slicing in particular. It appears to be a mismatch in tensor shapes, as shown in this error:

>>> import jax.numpy as jnp
>>> jnp.arange(10).at[::2].get()
Metal device set to: AMD Radeon Pro 555X

systemMemory: 32.00 GB
maxCacheSize: 2.00 GB

Array([0, 2, 4, 6, 8], dtype=int32)
>>> jnp.arange(20).reshape(4,5).at[::2,::2].get()
"builtin.module"() ({
  "func.func"() ({
  ^bb0(%arg0: tensor<2xsi32>):
    %0 = "mps.constant"() {value = dense<[2, 3, 1]> : tensor<3xsi64>} : () -> tensor<3xsi64>
    %1 = "mps.broadcast_to"(%arg0, %0) : (tensor<2xsi32>, tensor<3xsi64>) -> tensor<2x3x2xsi32>
    "func.return"(%1) : (tensor<2x3x2xsi32>) -> ()
  }) {arg_attrs = [{mhlo.sharding = "{replicated}"}], function_type = (tensor<2xsi32>) -> tensor<2x3x1xsi32>, res_attrs = [{}], sym_name = "main", sym_visibility = "public"} : () -> ()
}) {mhlo.num_partitions = 1 : i32, mhlo.num_replicas = 1 : i32, sym_name = "jit_broadcast_in_dim"} : () -> ()
Traceback (most recent call last):
  File "<stdin>", line 1, in <module>

.....

  File "/Users/will/jax-jax-v0.4.11/jax/_src/dispatch.py", line 465, in backend_compile
    return backend.compile(built_c, compile_options=options)
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <unknown>:0: error: type of return operand 0 ('tensor<2x3x2xsi32>') doesn't match function result type ('tensor<2x3x1xsi32>') in function @main
<unknown>:0: note: see current operation: "func.return"(%1) : (tensor<2x3x2xsi32>) -> ()
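
As a reference point, the values the 2D strided slice should return, and the shape the broadcast in the dumped module is asked to produce, can be checked with plain NumPy. This is a sketch, not a diagnosis: it assumes the tensor<2xsi32> argument holds the row indices [0, 2], which the dump does not state, and the names `row_idx` and `target_shape` are introduced here for illustration.

```python
import numpy as np

# Expected result of the slice that fails on the Metal backend.
a = np.arange(20).reshape(4, 5)
expected = a[::2, ::2]
print(expected)
# [[ 0  2  4]
#  [10 12 14]]

# Assumption: the tensor<2xsi32> argument holds the row indices [0, 2].
row_idx = np.arange(0, 4, 2)

# The dense<[2, 3, 1]> constant in the dump gives the broadcast target shape.
target_shape = (2, 3, 1)
broadcast = np.broadcast_to(row_idx.reshape(2, 1, 1), target_shape)
print(broadcast.shape)  # (2, 3, 1)
```

Under that assumption, the (2, 3, 1) shape agrees with both the shape constant and the declared function result type in the dump, while the "mps.broadcast_to" op is typed as returning tensor<2x3x2xsi32>.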

Some gather operations work, but others fail:

>>> x.at[1:5,1:3].get()
Array([[ 37,  38],
       [ 73,  74],
       [109, 110],
       [145, 146]], dtype=int32)
>>> x.at[1:5,1:3:-1].get()
[1]    16271 segmentation fault  python
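
For reference, NumPy treats the slice 1:3:-1 as empty, so the second call would be expected to return an array of shape (4, 0) rather than crash. The sketch below uses a stand-in array: the definition of x for this session is not shown, and a 36-column layout is assumed only because it reproduces the values printed by the first call.

```python
import numpy as np

# Assumption: a stand-in for x with 36 columns, chosen because it
# reproduces the values printed by x.at[1:5,1:3].get() above.
x_standin = np.arange(360).reshape(10, 36)

forward = x_standin[1:5, 1:3]
print(forward.tolist())  # [[37, 38], [73, 74], [109, 110], [145, 146]]

# A negative step with start < stop selects nothing along that axis.
reverse = x_standin[1:5, 1:3:-1]
print(reverse.shape)  # (4, 0)
```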