I'm implementing an optimized matmul kernel on Metal: https://github.com/crynux-ai/metal-matmul/blob/main/metal/1_shared_mem.metal
I've noticed that performance differs significantly depending on the threadgroup memory length set via [computeEncoder setThreadgroupMemoryLength:atIndex:]. All other lines are exactly the same; the only difference is this one parameter.
Matmul performance is roughly 250 GFLOPS if I set the length to 32768 (the maximum allowed per threadgroup on this M1 Max), but around 400 GFLOPS if I set it to 8192.
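For context, here is a minimal sketch of the host-side encoding I'm describing. The pipeline, buffer, and dispatch names/sizes below are placeholders rather than the repo's actual identifiers; the only line that changes between the two measured runs is the setThreadgroupMemoryLength:atIndex: call:

```objc
#import <Metal/Metal.h>

// Sketch of the encoding path (placeholder names, not the repo's code).
// The only parameter that differs between the two measurements:
//   32768 bytes (M1 Max per-threadgroup maximum) -> ~250 GFLOPS
//    8192 bytes                                  -> ~400 GFLOPS
static void encodeMatmul(id<MTLComputeCommandEncoder> computeEncoder,
                         id<MTLComputePipelineState> pipelineState,
                         id<MTLBuffer> bufA, id<MTLBuffer> bufB, id<MTLBuffer> bufC,
                         NSUInteger threadgroupMemoryLength) {
    [computeEncoder setComputePipelineState:pipelineState];
    [computeEncoder setBuffer:bufA offset:0 atIndex:0];
    [computeEncoder setBuffer:bufB offset:0 atIndex:1];
    [computeEncoder setBuffer:bufC offset:0 atIndex:2];

    // This is the single line I vary between runs.
    [computeEncoder setThreadgroupMemoryLength:threadgroupMemoryLength atIndex:0];

    // Placeholder grid/threadgroup sizes for illustration.
    [computeEncoder dispatchThreadgroups:MTLSizeMake(64, 64, 1)
                   threadsPerThreadgroup:MTLSizeMake(16, 16, 1)];
    [computeEncoder endEncoding];
}
```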
Why does this happen? How can I optimize it?