Yao Lirong's Blog

Running MobileBert on Android with TensorFlow Lite

2024/09/22

So Google, fxxk you.

Prerequisites

This picture explains very well how TFLite works, and also why TensorFlow 2 has both a tf and a keras API.

TFLite Workflow

Detours

This section is mostly a rant, but it should keep you from taking any of the wrong paths. Skip to the next section for a tutorial on what to do.

  1. We first found Google’s official release http://google-research/mobilebert/, but

    • the tutorial was unclear: Why do I need data_dir and output_dir to export TFLite? How do I even read in the pre-trained weights?
    • the code itself was pretty messy: why do the export function and the training function both live in the same file, run_squad.py, and why is the only way to tell the program whether to train or export to check whether export_dir is None rather than to pass a flag?

    While figuring out what each part of this code does, I looked up TensorFlow 1’s docs and, good lord, they were broken. Google doesn’t even host them anywhere: you have to go to a GitHub repo and read them in .md format. At that moment I decided I would not touch anything written with TensorFlow 1’s API. (I actually went through this pain back at my first ML internship at Haier, but not again.)

  2. Sidenote before this: I didn’t know you could release models on Kaggle (I thought everyone released on Hugging Face), or that Google had moved its own TensorFlow Hub to Kaggle.

    So my supervisor found me a more readable Google release on Kaggle that has a high-level API and doesn’t require you to read the painful source code. The above link redirects to a TensorFlow 2 implementation with an official TFLite release. How neat.

    However, the official TFLite release

    1. doesn’t have a signature, TensorFlow’s specification of inputs and outputs (remember that when you pass inputs to a model you need to name them, e.g. token_ids = ..., mask = ...), which is required for Xiaomi Service Framework to run a TFLite file. P.S. Yes, specifying a signature is not required when exporting, but for god’s sake all your tutorials teach people to use it and your own release ditched it? WTF Google.
    2. is broken (as expected?). When I tried to run it on my PC, I got the following error: indices_has_only_positive_elements was not true.gather index out of boundsNode number 2 (GATHER) failed to invoke. Someone encountered a similar bug while running the example code provided by TensorFlow, and the Google SWE found a bug in their example. At that moment I decided not to trust this TFLite file anymore and to just convert it on my own.
  3. So let’s use this official TensorFlow 2 implementation and convert it to TFLite. It all ran fine on my PC, but

    1. Its output format was really weird
      • Its output consists of 'mobile_bert_encoder', 'mobile_bert_encoder_1', 'mobile_bert_encoder_2', ..., 'mobile_bert_encoder_51'
      • Each of these has shape (1, 4, 128, 128) for a model with seq_length = 128 and hidden_dim = 512. I figured 4 was the number of heads and the other 128 was the hidden_dim of each head.
      • They are attention scores, not the final encoded vectors: my input was 5 tokens and the output is output[0, 0, 0, :] = array([0.198, 0.138, 0.244, 0.148, 0.270, 0. , 0. , .... The values sum to 1 and every other position in the output is 0, so attention scores were my best guess.
    2. It doesn’t run on an Android phone: tflite engine load failed due to java.lang.IllegalArgumentException: Internal error: Cannot create interpreter: Op builtin_code out of range: 153. Are you using old TFLite binary with newer model? A Stack Overflow answer suggests that the TensorFlow version used to export the TFLite file on my PC doesn’t match the version of the TFLite runtime on the Android phone. It could also have been caused by me messing up the whole environment while installing Optimum to export TFLite the night before, but I didn’t bother to look because I finally found the solution.
  4. And here comes the savior, the king, the go-to solution in MLOps: Hugging Face. Reminded by a discussion I read by chance, I came to realize that TFMobileBertModel.from_pretrained actually returns a Keras model (and the class without the TF prefix returns a PyTorch model). That means I can just use the Hugging Face API to read the model in, then use the native TensorFlow 2 API to export it to TFLite. And everything works like a charm now. The final output signature is just Hugging Face’s familiar ['last_hidden_state', 'pooler_output']

Converting TensorFlow Model to TFLite

Conversion is pretty straightforward. You can just follow this official guide: For Mobile & Edge: Convert TensorFlow models. I actually followed my predecessor’s note, though (which itself comes from another TF tutorial). He also warned me that calling tf.disable_eager_execution() can lead to a missing signature, so do not call it to disable eager mode.

import tensorflow as tf
from transformers import MobileBertTokenizerFast, TFMobileBertModel

# Convert Model
# be_sane, keras_file and the *_path variables are set by the caller
if be_sane:
    # Load through Hugging Face, which hands back a Keras model
    bert_model = TFMobileBertModel.from_pretrained(kerasH5_model_path) if keras_file else \
        TFMobileBertModel.from_pretrained(pytorch_model_path, from_pt = True)
    converter = tf.lite.TFLiteConverter.from_keras_model(bert_model)
else: # be crazy or already knows the messy TensorFlow.SavedModel format
    converter = tf.lite.TFLiteConverter.from_saved_model(model_path)
tflite_model = converter.convert()

# Save Model
tflite_output_path = '/model.tflite'
with open(tflite_output_path, 'wb') as f:
    f.write(tflite_model)

# Check Signature
# Empty signature means error in the export process and the file cannot be used by Xiaomi Service Framework
# Load from the saved file, or equivalently from the in-memory bytes:
# interpreter = tf.lite.Interpreter(model_content=tflite_model)
interpreter = tf.lite.Interpreter(model_path=tflite_output_path)
interpreter.allocate_tensors()
signatures = interpreter.get_signature_list()
print("tflite model signatures:", signatures)
{'serving_default': {'inputs': ['attention_mask',
                                'input_ids',
                                'token_type_ids'],
                     'outputs': ['last_hidden_state', 'pooler_output']}}
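Since an empty signature list is easy to miss in printed output, it can help to turn the check into a hard failure. The following is a minimal sketch, assuming the signature layout printed above; check_signature is a made-up helper name for illustration, not part of TFLite.

```python
def check_signature(signatures, key="serving_default",
                    required_inputs=("attention_mask", "input_ids", "token_type_ids")):
    """Validate the dict returned by interpreter.get_signature_list().

    Hypothetical helper (not a TFLite API): raises ValueError when the
    signature is missing or lacks an expected input, else returns the entry.
    """
    if not signatures:
        raise ValueError("model has no signatures; the export went wrong")
    if key not in signatures:
        raise ValueError(f"signature {key!r} not found, got {sorted(signatures)}")
    missing = set(required_inputs) - set(signatures[key]["inputs"])
    if missing:
        raise ValueError(f"signature {key!r} lacks inputs: {sorted(missing)}")
    return signatures[key]
```

Calling check_signature(interpreter.get_signature_list()) right after conversion would then stop the script before a signature-less file gets shipped.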

In addition, summarizing from the detours I took,

  • Do not use Hugging Face’s Optimum for (at least vanilla) conversion because it just calls the above command (see code)
  • Do not even bother to look at Google’s original code converting MobileBert to TFLite because nobody knows what they’re writing.

Running TFLite (on PC)

Running TFLite on an Android phone is the other department’s task. I just wanted to run the TFLite file on my PC to test that everything is good. To do that, I strictly followed TensorFlow’s official guide: TensorFlow Lite inference: Load and run a model in Python. Since our converted models have signatures, you can just follow the “with a defined SignatureDef” part of the guide.

tokenizer = MobileBertTokenizerFast(f"{model_path}/vocab.txt")
t_output = tokenizer("越过长城,走向世界", return_tensors="tf")
ii, tt, am = t_output['input_ids'], t_output['token_type_ids'], t_output['attention_mask']
# `get_signature_runner()` with empty input gives the "serving_default" runner
# `runner` input parameter is specified by `serving_default['inputs']`
runner = interpreter.get_signature_runner()
output = runner(input_ids = ii, token_type_ids = tt, attention_mask = am)
# `output` is a dict keyed by the names in `serving_default['outputs']`
assert sorted(output.keys()) == ['last_hidden_state', 'pooler_output']

On the other hand, for a model without signatures, you need to use the more primitive API input_details and output_details. They specify the following properties, where index is (probably) the index of this tensor in the compute graph. To pass input values and get output values, you need to access them by this index.

interpreter.allocate_tensors()
input_details = interpreter.get_input_details()
output_details = interpreter.get_output_details()

interpreter.set_tensor(input_details[0]['index'], input_data)
interpreter.invoke()
output_data = interpreter.get_tensor(output_details[0]['index'])
print(output_data)

The following is the input_details of the non-signature Google packed MobileBert.

interpreter.get_input_details()
[{'name': 'model_attention_mask:0',
  'index': 0,
  'shape': array([ 1, 512], dtype=int32),
  'shape_signature': array([ 1, 512], dtype=int32),
  'dtype': numpy.int64,
  'quantization': (0.0, 0),
  'quantization_parameters': {'scales': array([], dtype=float32),
                              'zero_points': array([], dtype=int32),
                              'quantized_dimension': 0},
  'sparsity_parameters': {}},
 {...}]
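Because the order of input_details is not something I would rely on, one option is to look tensors up by name instead of by position. Below is a minimal sketch, assuming each tensor name contains the short input name as a substring (as 'model_attention_mask:0' does above); match_inputs is a hypothetical helper, not a TFLite function.

```python
def match_inputs(input_details, feed):
    """Map {short_name: value} to {tensor_index: value}.

    Hypothetical helper: matches each short name against the 'name' field of
    the dicts returned by interpreter.get_input_details() by substring, and
    refuses to guess when a name matches zero or several tensors.
    """
    by_index = {}
    for short_name, value in feed.items():
        hits = [d for d in input_details if short_name in d["name"]]
        if len(hits) != 1:
            raise KeyError(f"{short_name!r} matched {len(hits)} input tensors")
        by_index[hits[0]["index"]] = value
    return by_index

# Usage with a real interpreter (not executed here):
# feed = {"attention_mask": am, ...}
# for index, value in match_inputs(interpreter.get_input_details(), feed).items():
#     interpreter.set_tensor(index, value)
# interpreter.invoke()
```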

Numerical Accuracy

Our original torch/TensorFlow encoder and the converted TFLite encoder, both running on a PC using Python, have a 1.2% difference in their outputs (last_hidden_state or pooler_output). We do not know where this discrepancy comes from.
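For reference, one way to put a number on such a gap is a mean relative difference between the two flattened outputs. The sketch below is an assumption about the metric, not necessarily the one behind the 1.2% figure; mean_relative_difference is a made-up name.

```python
def mean_relative_difference(reference, candidate, eps=1e-12):
    """Sum of absolute element-wise errors divided by the sum of |reference|.

    Both arguments are flat sequences of floats of equal length, e.g. a
    flattened last_hidden_state from the original and the TFLite model.
    """
    if len(reference) != len(candidate):
        raise ValueError("outputs must have the same number of elements")
    error = sum(abs(r - c) for r, c in zip(reference, candidate))
    scale = sum(abs(r) for r in reference) + eps  # eps avoids division by zero
    return error / scale
```

A result of 0.012 under this definition would correspond to a 1.2% difference.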

Converting Tokenizer to TFLite

We exported and ran the encoder, but that’s not enough. We can’t ask the user to type in token_ids every time. Therefore, we need to integrate the preprocessor (tokenizer) into our TFLite file. To do that, we first tried integrating Google’s official Keras tokenizer implementation into our BERT model and converting them together into one TFLite file (yeah, I didn’t learn my lesson). This failed at the conversion step for reasons that would become clear later. So we switched gears, followed another guide, and first tried to convert a standalone tokenizer to TFLite.

The tokenizer is part of the TensorFlow Text library. I followed the official guide: Converting TensorFlow Text operators to TensorFlow Lite with text.FastBertTokenizer. When you follow it, do so carefully and closely. I encountered a few problems along the way:

  1. When you change the text.WhitespaceTokenizer in the guide to our text.FastBertTokenizer, remember to construct it as text.FastBertTokenizer(vocab=vocab_lst). It needs not the path to the vocab file but the actual list, e.g. [ "[PAD]", "[unused0]", "[unused1]", ...] describes a vocab where [PAD] maps to token id 0, [unused0] to token id 1, and so on.

  2. text.FastBertTokenizer (or the standard version) does not add the [CLS] token for you. Google says this is to make sure “you are able to manipulate the tokens and determine how to construct your segments separately” (GitHub issue). How considerate of you, dear Google. I spent a day and a half figuring out how to add these tokens while keeping the model’s input length fixed; otherwise TensorFlow’s compute graph throws a “can’t get variable-length input” error. I finally found a solution in Google’s MediaPipe implementation.

  3. Could not translate MLIR to FlatBuffer when running tflite_model = converter.convert(): as mentioned, you must follow the guide very carefully. The guide specifies a TensorFlow Text version. With any other version, the conversion fails.

    pip install -U "tensorflow-text==2.11.*"
  4. Encountered unresolved custom op: FastBertNormalize when running the converted interpreter / signature: as stated in the Inference section of the guide, tokenizers are custom operations and need to be registered when running inference. (I can’t find docs for InterpreterWithCustomOps anywhere, but it does have a model_path argument.)

    import tensorflow_text as tf_text
    from tensorflow.lite.python import interpreter

    interp = interpreter.InterpreterWithCustomOps(
        model_content=tflite_model,  # or model_path=TFLITE_FILE_PATH
        custom_op_registerers=tf_text.tflite_registrar.SELECT_TFTEXT_OPS)
  5. TensorFlow Text custom ops are not found on Android: the above inference guide writes

    while the example below shows inference in Python, the steps are similar in other languages with some minor API translations

    which is a total lie. Android does not support these operations, as the custom text op list only mentions Python support.
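To make the fixed-length requirement from item 2 concrete, here is a minimal sketch of the layout the encoder expects. It is plain Python rather than the in-graph TensorFlow ops I actually needed, and the ids are assumptions taken from this post's examples: 0 for [PAD] (see the vocab list above), and 101 / 102 for [CLS] / [SEP] (they match the decoded examples later in this post). build_bert_inputs is a made-up name.

```python
def build_bert_inputs(token_ids, max_len=128, cls_id=101, sep_id=102, pad_id=0):
    """Wrap token ids in [CLS] ... [SEP] and pad or truncate to max_len.

    Returns (input_ids, attention_mask, token_type_ids), each a list of
    exactly max_len ints, for a single-segment input.
    """
    body = list(token_ids)[: max_len - 2]          # leave room for [CLS] and [SEP]
    input_ids = [cls_id] + body + [sep_id]
    attention_mask = [1] * len(input_ids)          # 1 marks real tokens
    padding = max_len - len(input_ids)
    input_ids += [pad_id] * padding
    attention_mask += [0] * padding                # 0 marks padding
    token_type_ids = [0] * max_len                 # single segment only
    return input_ids, attention_mask, token_type_ids
```

The same logic carries over almost line by line to a hand-written tokenizer in another language, since it only uses list operations.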

In the end, I did manage to (1) merge the above tokenizer and the Hugging Face model, and (2) export a TFLite model that reads in text and outputs the last hidden state. However, I seem to have lost that piece of code. Don’t worry though: thanks to Google’s shitty framework, it only works with very few tokenizer implementations anyway. The works-for-all solution is to build your own tokenizer in Java.

P.S. While debugging the FlatBuffer error, I came across the TensorFlow authoring tool, which can explicitly specify a function’s input and output format and detect ops unsupported by TFLite. However, the tool is pretty broken for me. Debugging it would probably take longer than finding the problem yourself online or asking on a forum.

Writing Your Own Tokenizer

What’s weird is that TensorFlow does have an official BERT on Android example. Reading it again, I found that their tokenizer is actually implemented in C++ (see this example). The repo containing the tokenizer code is called tflite-support. Looking at this library’s docs, it becomes clear that the text-related operations are currently not supported.

TFLite-Support Current use-case coverage

Google seems to have used JNI to call the C++ implementation of tokenizer (see code).

Therefore, we’d better write our own tokenizer. Luckily, Hugging Face also has a BERT on Android example, tflite-android-transformers, with more accessible code. We directly copied their tokenizer implementation.

However, when switching to Chinese vocabulary, the tokenizer goes glitchy. See the following example where we tokenize the sentence「越过长城 ,走向世界」

# Our Java tokenizer gives the following tokens, which detokenizes to the following string
tokenizer.decode([101, 6632, 19871, 20327, 14871, 8024, 6624, 14460, 13743, 17575, 102])
'[CLS] 越过长城 , 走向世界 [SEP]'

# On the other hand, official Hugging Face python BertTokenizer gives
tokenizer.decode([101, 6632, 6814, 7270, 1814, 8024, 6624, 1403, 686, 4518, 102])
'[CLS] 越 过 长 城 , 走 向 世 界 [SEP]'

# Inspecting the first difference, our Java tokenizer emitted a WordPiece continuation piece ("##")
tokenizer.decode([19871])
'##过'

It turns out that BERT in its original implementation (code) does not apply its subword tokenizer (WordPiece, which is what produces the ## pieces, rather than SentencePiece) across runs of Chinese characters. Instead, it tokenizes Chinese at the character level. Therefore, we first need to insert whitespace around every Chinese character so that each one is treated as its own word. Note that the Hugging Face tokenizer follows BERT’s original Python code very closely, so you can easily find where to insert that piece of code.

  • BERT’s original implementation in Python, with the Chinese logic

    def tokenize(self, text):
        """Tokenizes a piece of text."""
        text = convert_to_unicode(text)
        text = self._clean_text(text)
        # Chinese Logic
        text = self._tokenize_chinese_chars(text)
        orig_tokens = whitespace_tokenize(text)
  • Hugging Face tokenizer in Java, without Chinese logic

    public final class BasicTokenizer {
        public List<String> tokenize(String text) {
            String cleanedText = cleanText(text);
            // Insert Here
            List<String> origTokens = whitespaceTokenize(cleanedText);
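For reference, the missing step is small. Below is a sketch of it in Python, modeled on the _tokenize_chinese_chars logic of the original BERT code; the function names here are mine, and the code point ranges are the CJK blocks as I remember them from that code, so double-check them against the source before porting to Java.

```python
# CJK Unicode blocks that get the character-level treatment
_CJK_RANGES = (
    (0x4E00, 0x9FFF), (0x3400, 0x4DBF), (0x20000, 0x2A6DF), (0x2A700, 0x2B73F),
    (0x2B740, 0x2B81F), (0x2B820, 0x2CEAF), (0xF900, 0xFAFF), (0x2F800, 0x2FA1F),
)

def is_cjk(char):
    """True if the character's code point falls in one of the CJK blocks."""
    cp = ord(char)
    return any(lo <= cp <= hi for lo, hi in _CJK_RANGES)

def pad_cjk(text):
    """Surround every CJK character with spaces so whitespace splitting isolates it."""
    return "".join(f" {ch} " if is_cjk(ch) else ch for ch in text)

# After padding, a plain whitespace split yields one token per Chinese character
tokens = pad_cjk("越过长城").split()
```

In the Java tokenizer, the equivalent call would go at the "Insert Here" mark, on cleanedText, before whitespaceTokenize runs.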

Building a Classifier

The final task is actually to build a classifier over 28 online-store commodity classes. As I mentioned in the Detours section, I do not know, and don’t want to bother to learn, how to define or change a signature. Therefore, I again turned to Hugging Face for its MobileBertForSequenceClassification.

The default classification head has only 1 layer, so I changed its structure to give it more expressive power.

from collections import OrderedDict

import torch
from torch import nn
from transformers import MobileBertForSequenceClassification

model = MobileBertForSequenceClassification.from_pretrained(
    model_path, num_labels=len(labels), problem_type="multi_label_classification",
    id2label=id2label, label2id=label2id)
# Swap the single linear head for a 2-layer head
model.classifier = nn.Sequential(OrderedDict([
    ('fc1', nn.Linear(768, 1024)),
    ('relu1', nn.LeakyReLU()),
    ('fc2', nn.Linear(1024, len(labels)))
]))
# Fine-tune ...
torch.save(model.state_dict(), model_path)

However, this fails when you try to read such a fine-tuned model back in. MobileBertForSequenceClassification is defined with a one-layer classification head, so it cannot read in the weights of your self-defined classifier.

torch_model = MobileBertForSequenceClassification.from_pretrained(
    model_path, problem_type="multi_label_classification",
    num_labels=len(labels), id2label=id2label, label2id=label2id)

> Some weights of MobileBertForSequenceClassification were not initialized from the model checkpoint at ./ckpts/ and are newly initialized: ['classifier.bias', 'classifier.weight']

To solve this, you can

  1. Save encoder weight and classifier weight separately, then load them separately
  2. Create a custom class corresponding to your weights and initialize an instance of that class instead

Option 2 is clearly the more sensible way. You should read the very clearly written MobileBertForSequenceClassification to understand what exactly needs to be changed. It turns out that all we have to do is extend the original class and change its __init__ so that it has a 2-layer classification head.

from collections import OrderedDict

from tensorflow import keras
from torch import nn
from transformers import MobileBertForSequenceClassification, TFMobileBertForSequenceClassification

class CustomMobileBertForSequenceClassification(MobileBertForSequenceClassification):
    def __init__(self, config):
        super().__init__(config)
        # Replace the default one-layer head with the 2-layer head used in fine-tuning
        self.classifier = nn.Sequential(OrderedDict([
            ('fc1', nn.Linear(768, 1024)),
            ('relu1', nn.LeakyReLU()),
            ('fc2', nn.Linear(1024, 28))
        ]))
        self.post_init()

class TFCustomMobileBertForSequenceClassification(TFMobileBertForSequenceClassification):
    def __init__(self, config, *inputs, **kwargs):
        super().__init__(config, *inputs, **kwargs)
        self.classifier = keras.Sequential([
            keras.layers.Dense(1024, input_dim=768, name='fc1'),
            keras.layers.LeakyReLU(alpha=0.01, name = 'relu1'), # Keras defaults alpha to 0.3
            keras.layers.Dense(28, name='fc2')
        ])

torch_model = CustomMobileBertForSequenceClassification.from_pretrained(
    model_path, problem_type="multi_label_classification",
    num_labels=len(labels), id2label=id2label, label2id=label2id)
tf_model = TFCustomMobileBertForSequenceClassification.from_pretrained(
    ..., from_pt=True)

However, you may find that these two models output different values on the same input. A closer look at the weights reveals that Hugging Face didn’t convert the classifier’s weights from our Torch model to the TensorFlow model correctly. We have to set them manually instead.

tf_model.classifier.get_layer("fc1").set_weights([torch_model.classifier.fc1.weight.transpose(1, 0).detach(), torch_model.classifier.fc1.bias.detach()])
tf_model.classifier.get_layer("fc2").set_weights([torch_model.classifier.fc2.weight.transpose(1, 0).detach(), torch_model.classifier.fc2.bias.detach()])
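The transpose(1, 0) is there because the two frameworks lay out dense weights differently: PyTorch’s nn.Linear stores its weight as (out_features, in_features), while a Keras Dense kernel is (in_features, out_features). Here is a tiny framework-free sketch of that relationship, with made-up numbers and helper names of my own, just to show the two layers agree once the matrix is transposed.

```python
def torch_style_linear(x, weight, bias):
    """y = x @ weight.T + bias, with weight shaped (out_features, in_features)."""
    return [sum(w * xi for w, xi in zip(row, x)) + b for row, b in zip(weight, bias)]

def keras_style_dense(x, kernel, bias):
    """y = x @ kernel + bias, with kernel shaped (in_features, out_features)."""
    return [sum(x[i] * kernel[i][j] for i in range(len(x))) + bias[j]
            for j in range(len(bias))]

def transpose(matrix):
    """Swap rows and columns of a list-of-lists matrix."""
    return [list(column) for column in zip(*matrix)]

weight = [[1.0, 2.0, 3.0],   # 2 output units, 3 input features
          [4.0, 5.0, 6.0]]
bias = [0.5, -0.5]
x = [1.0, 0.0, 2.0]

torch_out = torch_style_linear(x, weight, bias)
keras_out = keras_style_dense(x, transpose(weight), bias)  # same numbers, transposed layout
```

Passing the untransposed weight to the Keras-style layer would not even have the right shape here, which is a handy sanity check when copying weights by hand.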

And now we are finally ready to go.

Quantization

I followed this official doc: Post-training quantization. Because of time limits, I didn’t try Quantization Aware Training (QAT).

# Baseline: plain float32 conversion
vanilla_converter = tf.lite.TFLiteConverter.from_keras_model(bert_model)
tflite_model = vanilla_converter.convert()

# Optimize.DEFAULT with no representative dataset: dynamic range quantization (8-bit weights)
quant8_converter = tf.lite.TFLiteConverter.from_keras_model(bert_model)
quant8_converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quant8_model = quant8_converter.convert()

# Optimize.DEFAULT restricted to float16: weights stored as float16
quant16_converter = tf.lite.TFLiteConverter.from_keras_model(bert_model)
quant16_converter.optimizations = [tf.lite.Optimize.DEFAULT]
quant16_converter.target_spec.supported_types = [tf.float16]
tflite_quant16_model = quant16_converter.convert()

Below I report several key metrics for this Chinese-MobileBERT plus a 2-layer classification head of [768*1024, 1024*class_num]. This was tested on a Xiaomi 12X with a Snapdragon 870. The baseline model is my colleague’s BERT-Large implementation, with accuracy 88.50% and size 1230MB. My model’s accuracy was bad at first: 75.01% with hyper-parameters weight_decay = 0.01, learning_rate = 1e-4. A hyper-parameter search then found weight_decay = 2e-4, learning_rate = 2e-5, giving 86.01%. We had 28 classes and 38,000 training samples in total, and trained for 5 epochs, at which point the validation accuracy roughly flattens.

Quantization | Logit Difference | Accuracy | Accuracy (after hyper-param search) | Model Size (MB) | Inference Time (ms) | Power Usage (mA) | CPU (%) | Memory (MB)
float32 (No quant) | 0 | 75.01% | 86.094% | 101.4 | 1003.3 | 89.98 | 108.02 | 267.11
float16 | 0.015% | 75.01% | 86.073% | 51 | 838 | 64.15 | 108.77 | 377.11
int8 | 4.251% | 63.49% | 85.947% | 25.9 | 573.8 | 60.09 | 110.83 | 233.19

If we look at only the vanilla transformer encoder, without fine-tuning, the last_hidden_state differs as follows:

Quantization | Logit Difference | Model Size (MB)
float32 (No quant) | 0 | 97
float16 | 0.1% | 48.1
int8 | 19.8% | 24.9
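To get some intuition for why the int8 rows drift so much more than float16, here is a toy round trip through symmetric 8-bit quantization. This is only an illustration under my own simplifying assumptions (one scale for the whole list, made-up values); TFLite’s real scheme may use per-channel scales and other refinements.

```python
def quantize_int8(values):
    """Map floats to integers in [-127, 127] using a single shared scale."""
    scale = max(abs(v) for v in values) / 127.0
    if scale == 0.0:
        scale = 1.0  # all-zero input: any scale works
    quantized = [max(-127, min(127, round(v / scale))) for v in values]
    return quantized, scale

def dequantize(quantized, scale):
    """Recover approximate floats from the integers."""
    return [q * scale for q in quantized]

weights = [0.5, -1.27, 0.003, 1.0]
q, scale = quantize_int8(weights)
restored = dequantize(q, scale)
# Small weights suffer most: 0.003 is rounded all the way down to 0
errors = [abs(w - r) for w, r in zip(weights, restored)]
```

Each value can be off by up to half a quantization step (scale / 2), and in a deep encoder those per-weight errors accumulate from layer to layer.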

Small Language Models

BERT is the go-to option for classification tasks. But when it comes to small BERT models, we had several options:

  • MobileBERT

  • DistilBERT

  • TinyBERT

As this post suggests, we ended up using MobileBERT because it’s by Google Brain and Google probably knows its own thing best.

On the other hand, if you’re looking for a small generative model, which people mostly call an SLM (Small Language Model) as opposed to an LLM, I found these options but didn’t try them myself.

  • OpenELM: Apple, 1.1B
  • Phi-2: Microsoft, 2.7B

Post Script

If you want to build an app that uses an edge transformer, I would recommend reading the source code of Hugging Face’s toy app. It doesn’t have a README or tutorial, nor have I gone through it personally, but everything from TensorFlow sucks (including MediaPipe, unfortunately).

When checking back on this tutorial on 2024/12/28, I found that Google had released AI Edge Torch, the official tool for converting PyTorch models into the .tflite format. So you may want to try this first, but again, don’t trust anything from the TensorFlow team.
