概览
LiteRT 支持将 TensorFlow RNN 模型转换为 LiteRT 的融合 LSTM 操作。融合运算旨在最大限度地提高其底层内核实现的性能,并提供更高级别的接口来定义量化等复杂转换。
由于 TensorFlow 中有许多 RNN API 变体,因此我们的方法分为两部分:
- 提供对标准 TensorFlow RNN API(如 Keras LSTM)的原生支持。 此为推荐选项。
- 提供一个接口,用于将用户定义的 RNN 实现插入到转化基础架构中并将其转化为 LiteRT。我们提供了几个使用 Lingvo 的 LSTMCellSimple 和 LayerNormalizedLSTMCellSimple RNN 接口进行此类转换的现成示例。
转换器 API
该功能是 TensorFlow 2.3 版本的一部分。您也可以通过 tf-nightly pip 或从主分支获取。
通过 SavedModel 或直接从 Keras 模型转换为 LiteRT 时,可以使用此转换功能。请参阅使用示例。
根据已保存的模型
# build a saved model. Here concrete_function is the exported function
# corresponding to the TensorFlow model containing one or more
# Keras LSTM layers.
saved_model, saved_model_dir = build_saved_model_lstm(...)
saved_model.save(saved_model_dir, save_format="tf", signatures=concrete_func)
# Convert the model.
converter = TFLiteConverter.from_saved_model(saved_model_dir)
tflite_model = converter.convert()
从 Keras 模型
# build a Keras model
keras_model = build_keras_lstm(...)
# Convert the model.
converter = TFLiteConverter.from_keras_model(keras_model)
tflite_model = converter.convert()
示例
Keras LSTM 到 LiteRT Colab 演示了 LiteRT 解释器的端到端用法。
支持的 TensorFlow RNN API
Keras LSTM 转换(推荐)
我们支持将 Keras LSTM 直接转换为 LiteRT。如需详细了解此功能的运作方式,请参阅 Keras LSTM 接口 以及此处的转换逻辑。
同样重要的是,要根据 Keras 操作定义突出显示 LiteRT 的 LSTM 合约:
- 输入张量的维度 0 是批次大小。
- recurrent_weight 张量的维度 0 是输出的数量。
- 转置了 weight 和 recurrent_kernel 张量。
- 转置后的权重、转置后的 recurrent_kernel 和 偏差张量沿维度 0 分割为 4 个大小相等的张量。它们分别对应于输入门、遗忘门、细胞和输出门。
Keras LSTM 变体
时间主要
用户可以选择时间优先或不选择时间优先。Keras LSTM 在函数定义属性中添加了 time-major 属性。对于单向序列 LSTM,我们可以简单地映射到 unidirecional_sequence_lstm 的 time major 属性。
双向 LSTM
双向 LSTM 可以通过两个 Keras LSTM 层(一个用于前向,一个用于后向)来实现,请参阅此处的示例。 一旦看到 go_backward 属性,我们就将其识别为反向 LSTM,然后将正向和反向 LSTM 分组在一起。这是未来的工作。目前,这会在 LiteRT 模型中创建两个 UnidirectionalSequenceLSTM 操作。
用户定义的 LSTM 转化示例
LiteRT 还提供了一种转换用户定义的 LSTM 实现的方法。下面以 Lingvo 的 LSTM 为例,说明如何实现这一点。如需了解详情,请参阅 lingvo.LSTMCellSimple 接口和此处的转换逻辑。 我们还在此处提供了 Lingvo 的另一个 LSTM 定义(即 lingvo.LayerNormalizedLSTMCellSimple 接口)及其转换逻辑的示例。
将“自带的 TensorFlow RNN”引入 LiteRT
如果用户的 RNN 接口与标准支持的接口不同,则有以下几种选择:
选项 1:在 TensorFlow Python 中编写适配器代码,以将 RNN 接口适配到 Keras RNN 接口。这意味着,在生成的 RNN 接口的函数上,具有 tf_implements 注释的 tf.function 与由 Keras LSTM 层生成的函数相同。之后,用于 Keras LSTM 的同一转化 API 将正常运行。
方法 2:如果上述方法不可行(例如,Keras LSTM 缺少 LiteRT 的融合 LSTM op 目前公开的某些功能,如层归一化),则通过编写自定义转换代码来扩展 LiteRT 转换器,并将其插入到 prepare-composite-functions MLIR-pass 中(此处)。 该函数的接口应视为 API 合约,并且应包含转换为融合的 LiteRT LSTM 操作所需的实参,即输入、偏差、权重、投影、层归一化等。最好将具有已知秩(即 MLIR 中的 RankedTensorType)的张量作为实参传递给此函数。这样一来,编写可将这些张量视为 RankedTensorType 的转换代码就容易多了,而且有助于将这些张量转换为与融合的 LiteRT 运算符的运算对象对应的秩张量。
此类转换流程的完整示例是 Lingvo 的 LSTMCellSimple 到 LiteRT 转换。
Lingvo 中的 LSTMCellSimple 定义在此处。 使用此 LSTM 单元训练的模型可以按如下方式转换为 LiteRT:
- 将 LSTMCellSimple 的所有用法封装在带有 tf_implements 注解的 tf.function 中,并标记为这样(例如,lingvo.LSTMCellSimple 在这里会是一个不错的注解名称)。确保生成的 tf.function 与转换代码中预期的函数接口相匹配。这是添加注释的模型作者与转化代码之间的合同。
扩展 prepare-composite-functions 传递,以插入自定义复合操作,从而实现 LiteRT 融合 LSTM 操作转换。请参阅 LSTMCellSimple 转换代码。
转化合同:
权重和投影张量已转置。
通过对转置后的权重张量进行切片,提取 {cell, input gate, forget gate, output gate} 的 {input, recurrent}。
通过对偏差张量进行切片,提取 {偏差}到 {单元、输入门、遗忘门、输出门}。
通过对转置后的投影张量进行切片来提取投影。
为 LayerNormalizedLSTMCellSimple 写入了类似的转换。
LiteRT 转换基础架构的其余部分(包括定义的所有 MLIR pass 以及最终导出到 LiteRT flatbuffer)都可以重复使用。
已知问题/限制
- 目前仅支持转换无状态 Keras LSTM(Keras 中的默认行为)。有状态 Keras LSTM 转换是未来的工作。
- 不过,您仍然可以使用底层无状态 Keras LSTM 层对有状态 Keras LSTM 层进行建模,并在用户程序中显式管理状态。使用此处描述的功能,仍可将此类 TensorFlow 程序转换为 LiteRT。
- 在 LiteRT 中,双向 LSTM 目前建模为两个 UnidirectionalSequenceLSTM 操作。此 op 将替换为单个 BidirectionalSequenceLSTM op。