Hi, I am trying to translate a Transformer Model (from hugging face) to a coreml model and the only piece I have left is to figure out the shape of the input.
#Here is the code
from transformers import BertModel, BertTokenizer, BertConfig
import torch
import coremltools as ct
import torch
import torchvision
enc = BertTokenizer.from_pretrained("bert-base-uncased")
# Tokenizing input text
text = "[CLS] Who was Jim Henson ? [SEP] Jim Henson was a puppeteer [SEP]"
tokenized_text = enc.tokenize(text)
# Masking one of the input tokens
masked_index = 8
tokenized_text[masked_index] = '[MASK]'
indexed_tokens = enc.convert_tokens_to_ids(tokenized_text)
segments_ids = [0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1, 1, 1, 1]
# Creating a dummy input
tokens_tensor = torch.tensor([indexed_tokens])
segments_tensors = torch.tensor([segments_ids])
dummy_input = [tokens_tensor, segments_tensors]
# Initializing the model with the torchscript flag
# Flag set to True even though it is not necessary as this model does not have an LM Head.
config = BertConfig(vocab_size_or_config_json_file=32000, hidden_size=768,
num_hidden_layers=12, num_attention_heads=12, intermediate_size=3072, torchscript=True)
# Instantiating the model
model = BertModel(config)
# The model needs to be in evaluation mode
model.eval()
# If you are instantiating the model with `from_pretrained` you can also easily set the TorchScript flag
model = BertModel.from_pretrained("bert-base-uncased", torchscript=True)
# Creating the trace
traced_model = torch.jit.trace(model, [tokens_tensor, segments_tensors])
torch.jit.save(traced_model, "traced_bert.pt")
## Get a pytorch model and save it as a *.pt file
#model = torchvision.models.mobilenet_v2()
#model.eval()
#example_input = torch.rand(1, 3, 224, 224)
#traced_model = torch.jit.trace(model, example_input)
#traced_model.save("torchvision_mobilenet_v2.pt")
#from torchsummary import summary
#summary(traced_model, ) #I think torchsummary can tell me this but i'm not sure.
# Convert the saved PyTorch model to Core ML
breakpoint()
mlmodel = ct.convert("traced_bert.pt", inputs=[???]) #Error is here I don't know how to determine the shape of the input.
(gist of above is at https://gist.github.com/zitterbewegung/589a869f59ae54b32c964c08c2f7bb80)