PyTorch Training Fails on Apple M2 - device = "mps"

I'm trying to train a simple GNN on the Cora dataset, which is really small. I'm using a Jupyter notebook, and the kernel dies every time. When I switch to device = "cpu", it runs quickly. Am I missing something when using MPS?
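One common pattern is to pick MPS only when PyTorch reports it as available and to fall back to the CPU otherwise. This is a sketch, not code from the original post, and the helper name `pick_device` is hypothetical. The import is guarded so the snippet still runs on a machine without PyTorch installed.

```python
def pick_device(prefer_mps=True):
    """Return "mps" if requested and usable, otherwise "cpu" (hypothetical helper)."""
    try:
        import torch
    except ImportError:
        # No PyTorch installed: nothing but the CPU to offer.
        return "cpu"
    # torch.backends.mps only exists in PyTorch 1.12 and later, so look it up defensively.
    mps = getattr(torch.backends, "mps", None)
    if prefer_mps and mps is not None and mps.is_available():
        return "mps"
    return "cpu"


device = pick_device()
print(f"Using device: {device}")
```

Both the model and the data need to be moved to the same device, e.g. `model.to(device)` and `data.to(device)`. Running the same training code as a plain `python` script rather than in Jupyter may show the actual error message that a dead kernel hides.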

Hi @narenq7 - could you please share the PyTorch version you are using (from pip list)? Also, if you could share a small sample reproducing the issue, it would be very useful.
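To answer the version question, a short script like the following collects the PyTorch version together with whether the MPS backend was built and is available on this machine. The function name `collect_env_info` is a hypothetical helper. PyTorch itself ships `python -m torch.utils.collect_env`, which produces a fuller report.

```python
import platform
import sys


def collect_env_info():
    """Gather basic environment details useful when reporting an MPS issue."""
    info = {
        "python": sys.version.split()[0],
        "platform": platform.platform(),
        "machine": platform.machine(),  # reports 'arm64' on Apple silicon
    }
    try:
        import torch
    except ImportError:
        info["torch"] = None
        return info
    info["torch"] = torch.__version__
    # is_built(): this PyTorch binary includes MPS support.
    # is_available(): MPS can actually be used on this machine and macOS version.
    mps = getattr(torch.backends, "mps", None)
    info["mps_built"] = bool(mps is not None and mps.is_built())
    info["mps_available"] = bool(mps is not None and mps.is_available())
    return info


if __name__ == "__main__":
    for key, value in collect_env_info().items():
        print(f"{key}: {value}")
```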
