ONNXRuntime生态
ONNXRuntime

provider

注册算子的几种方法

ONNX_OPERATOR_KERNEL_EX

算子按开发模式分类

Python运行

import onnxruntime as ort

# Load custom op library
so = ort.SessionOptions()
so.register_custom_ops_library('my_custom_op.so')

# Create session
session = ort.InferenceSession('model.onnx', so)

C++运行

onnxruntime代码库内自带的算子融合

看看 EmbedLayerNormalization 是如何融合成的?

enable_embed_layer_norm -> fuse_embed_layer -> FusionEmbedLayerNormalization

BERT Embedding Layer Pattern:
                            (input_ids)
                           /         \
                          /          Shape
                         /              |
                        /              Gather (indices=1)
                       /                  |
                      /                 Add (optional, B=0)
                     /                    |
                Gather   segment_ids  Unsqueeze (axes=0)
                    \        |           |
                     \     Gather      Slice (data[1,512], starts=0, ends=*, axes=1, steps=1)
                      \    /            |
                        Add          Gather
                         \        /
                             Add
                              |
                        LayerNormalization

输入: input_ids, segment_ids (也就是 模型的输入token_type_ids),
word_embedding, (内置向量词表) position_embedding, (内置向量词表) segment_embedding, (内置向量词表) gamma, (模型 Mul算子里的 LayerNorm.weight ) beta, (模型 Add算子里的 LayerNorm.bias) mask (也就是 模型的输入att_mask), position_ids (input_ids经过位置编码过程得到的)

mask是怎么使用的? mask -> mask_index 输出将来给attention使用

UT: embed_layer_norm_op_test.cc

看看 SkipLayerNormalization 是如何融合成的?

输入: input skip (前一个SkipLayerNormalization的输出) gamma (也就是 attention.output.LayerNorm.weight) beta (也就是 attention.output.LayerNorm.bias) bias (也就是 attention.output.dense.bias) 先做Add

skip 是怎么使用的? 先把input和skip做Add

    input                      skip
      \                        /
     Add(dense.bias)        Add
          \                 /
        LayerNormalization(gamma, beta)

UT: skiplayernorm_op_test.cc

onxxruntime 线程模型

NNXRuntime的线程池接口在Eigen线程池接口基础之上扩展而来(题外话:TensorFlow中的线程池同样是建立在Eigen线程池基础上)

*****
Written by Lu.dev on 01 September 2024