Python - Complex-valued linear algebra on GPU

Hi,

I am looking for a routine to perform complex-valued linear algebra on the GPU in Python for scientific programming, in particular quantum physics simulations.

At the moment I am looking for a routine for complex-valued matrix multiplication. I found that MLX has a routine for float matrix multiplication, but it does not work directly for complex-valued matrices. I worked around this by splitting each complex-valued matrix into its real and imaginary parts and operating on the pair (see the sketch below), but that makes it cumbersome to integrate with the rest of the code. I was hoping for a library-based implementation similar to CuPy.
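For reference, this is roughly what the workaround looks like (a minimal sketch, assuming mlx.core.matmul for the float matmuls; the exact code in my project is more involved):

```python
# Real/imaginary split workaround for complex matmul, using
# (A_r + i A_i)(B_r + i B_i) = (A_r B_r - A_i B_i) + i (A_r B_i + A_i B_r),
# i.e. four real-valued matmuls per complex product.
import numpy as np
import mlx.core as mx

def complex_matmul(a: np.ndarray, b: np.ndarray) -> np.ndarray:
    # Split into real/imaginary parts (float32, since MLX has no float64 on the GPU)
    ar, ai = mx.array(a.real.astype(np.float32)), mx.array(a.imag.astype(np.float32))
    br, bi = mx.array(b.real.astype(np.float32)), mx.array(b.imag.astype(np.float32))
    # Four real-valued matmuls on the default (GPU) device
    re = mx.matmul(ar, br) - mx.matmul(ai, bi)
    im = mx.matmul(ar, bi) + mx.matmul(ai, br)
    # Reassemble a complex NumPy array on the host
    return np.array(re) + 1j * np.array(im)

a = np.random.randn(8, 8) + 1j * np.random.randn(8, 8)
b = np.random.randn(8, 8) + 1j * np.random.randn(8, 8)
print(np.allclose(complex_matmul(a, b), a @ b, atol=1e-4))
```

This works, but carrying the real/imaginary pair through the rest of the simulation code is what makes it cumbersome.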

I also tried the TensorFlow linear algebra routines, but so far I could not get them to run on the GPU. Specifically, a test script using tensorflow.keras.applications.ResNet50 runs on the GPU, but the routines from tensorflow.linalg and tensorflow.math that I tested (matmul, expm, eigh) were not running on the GPU.
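In case it helps, this is the kind of minimal check I ran to see where the ops land (a sketch, not my exact test file; device placement is printed per op via tf.debugging.set_log_device_placement):

```python
# Minimal placement check for the complex linear algebra ops in question.
import tensorflow as tf

tf.debugging.set_log_device_placement(True)          # log the device of each op
print(tf.config.list_physical_devices("GPU"))        # confirm a GPU is visible

a = tf.complex(tf.random.normal([512, 512]), tf.random.normal([512, 512]))
b = tf.complex(tf.random.normal([512, 512]), tf.random.normal([512, 512]))

c = tf.linalg.matmul(a, b)                           # complex matmul
e = tf.linalg.expm(a)                                # matrix exponential
w, v = tf.linalg.eigh(a + tf.linalg.adjoint(a))      # eigh needs a Hermitian input
```

In my runs, the placement logs showed these ops on the CPU even though the ResNet50 test used the GPU.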

Any advice on how to get linear algebra calculations running on Mac GPUs would be highly appreciated! For my application the unified memory might be especially beneficial.

Thank you!