I want to deploy some ViT models on an iPhone. Following https://machinelearning.apple.com/research/vision-transformers, I wrote a simple demo based on the code from https://github.com/apple/ml-vision-transformers-ane. However, the uncached load time on the phone is very long. The input is already aligned to 64 bytes, as the blog recommends, but loading is still very slow. Is there any way to speed it up? This is my test case:
import torch
import coremltools as ct
import math
from torch import nn
class SelfAttn(torch.nn.Module):
    """Multi-head self-attention computed in a channels-first 4-D layout.

    The input is a batch of token sequences ``(B, L, C)``. It is moved to a
    ``(B, C, 1, L)`` layout so that the Q/K/V projections can be expressed as
    1x1 convolutions, and attention is evaluated per head with ``einsum``.

    Args:
        window_size: Side length of the square window the tokens come from.
            Stored for reference only; the forward pass works for any
            sequence length ``L``.
        num_heads: Number of attention heads. Must be positive and divide
            ``dim_out`` evenly.
        dim: Number of input channels ``C``.
        dim_out: Total number of output channels across all heads.

    Raises:
        ValueError: If ``num_heads`` is not positive or does not divide
            ``dim_out``.
    """

    def __init__(self, window_size: int, num_heads: int, dim: int, dim_out: int):
        super().__init__()
        if num_heads <= 0:
            raise ValueError(f"num_heads must be positive, got {num_heads}")
        if dim_out % num_heads != 0:
            # torch.split with a floored head size would silently produce a
            # different number of (ragged) heads and a mismatched scale.
            raise ValueError(
                f"dim_out ({dim_out}) must be divisible by num_heads ({num_heads})"
            )
        self.window_size = window_size
        self.num_heads = num_heads
        self.dim = dim
        self.dim_out = dim_out
        self.head_dim = dim_out // num_heads
        self.q_proj = nn.Conv2d(
            in_channels=dim,
            out_channels=dim_out,
            kernel_size=1,
        )
        self.k_proj = nn.Conv2d(
            in_channels=dim,
            out_channels=dim_out,
            kernel_size=1,
        )
        self.v_proj = nn.Conv2d(
            in_channels=dim,
            out_channels=dim_out,
            kernel_size=1,
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply self-attention to ``x``.

        Args:
            x: Tensor of shape ``(B, L, C)`` with ``C == dim``.

        Returns:
            Tensor of shape ``(B, dim_out, 1, L)``.
        """
        x_flat = x.permute(0, 2, 1).unsqueeze(2)  # (B, C, 1, L)

        # All three projections are 1x1 convolutions, i.e. position-wise, so
        # they can share the flattened layout; no (B, C, H, W) round trip is
        # needed for V.
        q = self.q_proj(x_flat)
        k = self.k_proj(x_flat)
        v = self.v_proj(x_flat)

        mh_q = torch.split(q, self.head_dim, dim=1)  # each (B, c, 1, L)
        mh_v = torch.split(v, self.head_dim, dim=1)  # each (B, c, 1, L)
        # Keys are transposed to (B, L, 1, c) so the contraction over the
        # head channels is a plain einsum against the queries.
        mh_k = torch.split(k.permute(0, 3, 2, 1), self.head_dim, dim=3)

        scale_factor = 1 / math.sqrt(self.head_dim)
        attn_weights = [
            torch.einsum("bchq,bkhc->bkhq", qi, ki) * scale_factor
            for qi, ki in zip(mh_q, mh_k)
        ]
        # Axis 1 of each score tensor indexes the keys, so this normalizes
        # over keys for every query position.
        attn_weights = [torch.softmax(aw, dim=1) for aw in attn_weights]

        mh_x = [
            torch.einsum("bkhq,bchk->bchq", wi, vi)
            for wi, vi in zip(attn_weights, mh_v)
        ]
        return torch.cat(mh_x, dim=1)
# Demo configuration: a batch of 1024 windows, each holding 8 * 8 tokens with
# 96 channels in and out.
window_size = 8
path_batch = 1024
emb_dim = 96
emb_dim_out = 96

# Random example input of shape (batch, tokens per window, channels); it is
# used both to trace the module and to fix the Core ML input shape.
x = torch.rand(path_batch, window_size * window_size, emb_dim)

# Switch to inference mode before tracing so the traced graph reflects
# inference behavior.
qkv_layer = SelfAttn(window_size, 1, emb_dim, emb_dim_out).eval()
# Note the trailing comma: (x,) is a one-element tuple of example inputs.
jit = torch.jit.trace(qkv_layer, (x,))

# Convert the traced module to an ML Program with a fixed input shape.
mlmod_fixed_shape = ct.convert(
    jit,
    inputs=[
        ct.TensorType(name="x", shape=x.shape),
    ],
    convert_to="mlprogram",
)
mlmodel_path = "test_ane.mlpackage"
mlmod_fixed_shape.save(mlmodel_path)
The uncached load takes nearly 36 seconds, even though the model is just a single self-attention layer (three 1×1 convolutions and two matrix multiplications).