MLIRコンパイラインフラストラクチャにおけるONNXモデルの表現と参照の低下
GitHubでプロジェクトを見る onnx/onnx-mlir
このプロジェクトは onnx によってメンテナンスされています
GitHub Pagesでホストされています — テーマは orderedlist です
このドキュメントでは、ONNXダイアレクトにおける演算の定数伝播を行うために使用される`--constprop-onnx`パスについて説明します。
以下のコードが与えられた場合
func @foo() -> tensor<1xf32> {
%0 = "onnx.Constant"() {value = dense<[1.0]> : tensor<1xf32>} : () -> tensor<1xf32>
%1 = "onnx.Constant"() {value = dense<[2.0]> : tensor<1xf32>} : () -> tensor<1xf32>
%2 = "onnx.Add"(%0, %1) : (tensor<1xf32> , tensor<1xf32>) -> tensor<1xf32>
%3 = "onnx.Constant"() {value = dense<[3.0]> : tensor<1xf32>} : () -> tensor<1xf32>
%4 = "onnx.Add"(%2, %3) : (tensor<1xf32> , tensor<1xf32>) -> tensor<1xf32>
"std.return"(%4) : (tensor<1xf32>) -> ()
}
`onnx-mlir-op --constprop-onnx`を呼び出すと、以下のようになります。
func @foo() -> tensor<1xf32> {
%0 = "onnx.Constant"() {value = dense<[6.0]> : tensor<1xf32>} : () -> tensor<1xf32>
"std.return"(%0) : (tensor<1xf32>) -> ()
}
ONNXConstantOpは、定数値を格納するためにMLIR DenseElementsAttrを使用します。重要なのは、DenseElementsAttrが作成されると、コンパイル終了まで有効であり、メモリを消費し続けることです。例では、3つのONNXConstantOpにある3つのDenseElementsAttrはすべて、コンパイル終了まで存在します。特に、2つのONNXAddOpを畳み込むことで生成された2つのONNXConstantOpの中間的な2つのDenseElementsAttrも存在します。実際のモデルでは、中間的なDenseElementsAttrの数は急速に増加し、コンパイル中のメモリフットプリントが大きくなります。
`--constprop-onnx`中に中間的なONNXConstantOpのためにあまりにも多くのDenseElementsAttrを作成することを避けるために、中間的なONNXConstantOpのバッファを動的に割り当ておよび解放し、定数伝播とその他のONNXダイアレクトパスが終了した後、Krnl(または他のターゲットダイアレクト)に下げる直前にのみDenseElementsAttrを作成するメカニズムを設計しました。
これは、複雑でないスカラー要素型(bool、整数、浮動小数点型)の一般的なケースでDenseElementsAttrの代替として機能するカスタム属性DisposableElementsAttrを使用して実現されます。DisposableElementsAttrはDenseElementsAttrと同じElementsAttrインターフェースを実装しており、ほとんどの場合、機能的には同じであり、周囲のコードは区別する必要がありません。メモリフットプリントとパフォーマンスの利点を得るには、OnnxElementsAttrBuilderクラスとElementsAttrHelper関数を使用してElementsAttrインスタンスを構築およびアクセスするだけで済みます。
DisposableElementsAttrバッファの解放は、コンパイラパスの間のDisposableGarbageCollectorで行われます。これは、PassManagerによって「モジュール」パス(他のパスが並行して実行されていないことが保証されている「stop the world」パス)の間で「計測」として実行されます。
DisposableElementsAttrは、クラスのソースファイルのコメントに概要が示されており、会議wikiページにリンクされている2022年11月のプレゼンテーションで説明されている、その他のメモリと速度の利点をもたらします。
定数伝播のパターンを書くために、MLIR宣言型書き換えルール(DRR)を使用します。パターンを定義するために使用されるDRR定義は以下に示されています。
class Pattern<
dag sourcePattern,
list<dag> resultPatterns,
list<dag> additionalConstraints = [],
list<dag> supplementalPatterns = [],
dag benefitsAdded = (addBenefit 0)
>;
DRRに関する詳細情報はこちらにあります。
それでは、ONNXAddOpに定数伝播を追加する簡単な例を見ていきましょう。
最初に、ConstProp.tdにパターンを追加します。
// Constant Propagation for Add
def AddConstProp : Pat<
// source patten: From add(lhs, rhs).
(ONNXAddOp:$addOp (ONNXConstantOp:$lhs $_, $_, $_, $_, $_, $_, $_, $_),
(ONNXConstantOp:$rhs $_, $_, $_, $_, $_, $_, $_, $_)),
// result pattern: To c = lhs + rhs
(CreateAddOfTwoConst $addOp, $lhs, $rhs),
// Additional constraints: if both lhs and rhs are dense constants.
[(IsFromDenseONNXConstantOp:$lhs), (IsFromDenseONNXConstantOp:$rhs)]>;
上記のパターンは、入力が定数であるONNXAddOpを、コンパイル時にそれらの入力を加算することで新しい定数に置き換えます。入力が定数であるかどうかを確認するには、ONNXConstantOpを使用するだけでは不十分です。定数テンソルはスパースになる可能性があり、現在はデンスな定数テンソルのみをサポートしているためです。`IsFromDenseONNXConstantOp`を使用して、デンスな定数テンソルをさらにチェックする必要があります。
結果パターンでは、ONNXConstantOpを生成するために、コンパイル時に`lhs`と`rhs`を加算し、ONNXConstantOpを発行します。メモリフットプリントを最小限に抑えるために、このONNXConstantOpは従来のDenseElementsAttrではなくDisposableElementsAttrを持ちます。
関数`CreateAddOfTwoConst`は、コンパイル時に加算を行い、ONNXConstantOpを返します。
def CreateAddOfTwoConst :
NativeCodeCall<"ConstPropElementwiseBinary<mlir::ONNXAddOp>($_builder, $0, $1, $2)">;
パターン内の関数`CreateAddOfTwoConst`は、ConstProp.cpp内の`ConstPropElementwiseBinary`を呼び出します。その内容は次のとおりです。
template <typename ElementwiseBinaryOp>
Value ConstPropElementwiseBinary(PatternRewriter &rewriter,
Value replacingValue, Value lhsValue, Value rhsValue) {
ConstPropCounters::count("ElementwiseBinary", {lhsValue, rhsValue});
Type replacingType = mlir::cast<ShapedType>(replacingValue.getType());
// Get lhs and rhs ElementsAttr from the values' defining constant ops.
ElementsAttr lhs = getConstValueElements(lhsValue);
ElementsAttr rhs = getConstValueElements(rhsValue);
Type operandsElemType = lhs.getElementType();
assert(operandsElemType == rhs.getElementType() &&
"all element-wise binary ops have matching operands element types");
OnnxElementsAttrBuilder elementsBuilder(rewriter.getContext());
ElementsAttr resultElements = elementsBuilder.combine(lhs, rhs, replacingType,
combinerOfElementwiseBinaryOp<ElementwiseBinaryOp>(operandsElemType));
// Construct and return a new ONNXConstantOp with the resultElements attribute.
return createReplacingConstantOp(rewriter, replacingValue, resultElements)
.getResult();
}
ここで、`OnnxElementsAttrBuilder.combine(...)`は必要に応じてlhsとrhsの要素をブロードキャストし、要素ごとのバイナリ関数`combinerOfElementwiseBinaryOp<ElementwiseBinaryOp>(operandsElemType)`(ElementwiseBinaryOp ONNX opをC++演算子にマッピングする)の適用結果を要素とする新しい(Disposable)ElementsAttrを構築します。
定数伝播の詳細については、ConstProp.tdとConstProp.cppを参照してください。