Best way to update values in MPSTensors?

I have a few loops that are part of my deep learning pipeline. I've got everything in MPSGraph, but a graph.for loop with scatter is extremely slow for my use case, and there is no sane way to vectorize it.

What is the best way to update values in an existing tensor, or to help the Metal optimizer vectorize these ops? Maybe something like JAX's x.at[idx].set(value)?

(Right now I'm using a graph.for loop with scatter/gather ops.)
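
To make the intended semantics concrete, here is a minimal plain-Python sketch (not MPSGraph code). The helper names at_set and loop_of_scatters are hypothetical and exist only for illustration. It contrasts one single-element update per loop iteration with a single batched update over all indices, which is roughly what a JAX-style indexed set does.

```python
def at_set(x, idx, values):
    """Return a copy of x with x[idx[k]] replaced by values[k].

    Hypothetical helper that mimics a functional indexed update.
    """
    out = list(x)  # copy, so the input list is left untouched
    for i, v in zip(idx, values):
        out[i] = v
    return out


def loop_of_scatters(x, idx, values):
    """Apply one single-element update per iteration (the slow pattern)."""
    for i, v in zip(idx, values):
        x = at_set(x, [i], [v])
    return x


x = [0.0] * 6
idx = [1, 4, 2]
values = [10.0, 20.0, 30.0]

batched = at_set(x, idx, values)           # all updates in one call
looped = loop_of_scatters(x, idx, values)  # one update per iteration

print(batched == looped)
```

The two versions agree only when every index and value is known before the loop starts. If an iteration reads a value written by an earlier iteration, the loop carries a real dependency and cannot be collapsed into one batched update this way.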
