diff --git a/CHANGELOGS.rst b/CHANGELOGS.rst index dad0930..d0b6445 100644 --- a/CHANGELOGS.rst +++ b/CHANGELOGS.rst @@ -4,6 +4,7 @@ Change Logs 0.2.0 +++++ +* :pr:`71`: adds tools to compare two onnx graphs * :pr:`61`: adds function to plot onnx model as graphs * :pr:`60`: supports translation of local functions * :pr:`59`: add methods to update nodes in GraphAPI diff --git a/_doc/api/reference.rst b/_doc/api/reference.rst index acbf90a..3b4ae7d 100644 --- a/_doc/api/reference.rst +++ b/_doc/api/reference.rst @@ -5,3 +5,33 @@ ExtendedReferenceEvaluator ++++++++++++++++++++++++++ .. autoclass:: onnx_array_api.reference.ExtendedReferenceEvaluator + :members: + +ResultType +++++++++++ + +.. autoclass:: onnx_array_api.reference.ResultType + :members: + +ResultExecution ++++++++++++++++ + +.. autoclass:: onnx_array_api.reference.ResultExecution + :members: + +YieldEvaluator +++++++++++++++ + +.. autoclass:: onnx_array_api.reference.YieldEvaluator + :members: + +DistanceExecution ++++++++++++++++++ + +.. autoclass:: onnx_array_api.reference.DistanceExecution + :members: + +compare_onnx_execution +++++++++++++++++++++++ + +.. autofunction:: onnx_array_api.reference.compare_onnx_execution diff --git a/_doc/command_lines.rst b/_doc/command_lines.rst new file mode 100644 index 0000000..38ca5f2 --- /dev/null +++ b/_doc/command_lines.rst @@ -0,0 +1,52 @@ +============= +command lines +============= + +compare +======= + +The function convers an onnx file into some code. + +:: + + python -m compare -m1 model1.onnx -m2 model2.onnx -v 1 + +Output example:: + + [compare_onnx_execution] got 2 inputs + [compare_onnx_execution] execute first model + [compare_onnx_execution] got 5 results + [compare_onnx_execution] execute second model + [compare_onnx_execution] got 5 results + [compare_onnx_execution] compute edit distance + [compare_onnx_execution] got 4 pairs + [compare_onnx_execution] done + = | INPUT float32 5x6 AAAA X | INPUT float32 5x6 AAAA X + = | INPUT float32 5x6 AAAA Y | INPUT float32 5x6 AAAA Y + = | RESULT float32 5x6 AABB Add res | RESULT float32 5x6 AABB Add res + = | RESULT float32 5x6 AAAA Cos Z | RESULT float32 5x6 AAAA Cos Z + +.. runpython:: + + from onnx_array_api._command_lines_parser import get_parser_compare + get_parser_compare().print_help() + +See function :func:`onnx_array_api.reference.compare_onnx_execution`. + +translate +========= + +The function convers an onnx file into some code. + +:: + + python -m translate ... + +Output example:: + + not yet ready + +.. runpython:: + + from onnx_array_api._command_lines_parser import get_parser_translate + get_parser_translate().print_help() diff --git a/_doc/examples/plot_onnx_diff.py b/_doc/examples/plot_onnx_diff.py new file mode 100644 index 0000000..7a5f1d3 --- /dev/null +++ b/_doc/examples/plot_onnx_diff.py @@ -0,0 +1,68 @@ +""" + +.. _l-onnx-diff-example: + +Compares the conversions of the same model with different options +================================================================= + +The script compares two onnx models obtained with the same trained +scikit-learn models but converted with different options. + +A model ++++++++ +""" + +from sklearn.mixture import GaussianMixture +from sklearn.datasets import load_iris +from sklearn.model_selection import train_test_split +from skl2onnx import to_onnx +from onnx_array_api.reference import compare_onnx_execution +from onnx_array_api.plotting.text_plot import onnx_simple_text_plot + + +data = load_iris() +X_train, X_test = train_test_split(data.data) +model = GaussianMixture() +model.fit(X_train) + +################################# +# Conversion to onnx +# ++++++++++++++++++ + +onx = to_onnx( + model, X_train[:1], options={id(model): {"score_samples": True}}, target_opset=12 +) + +print(onnx_simple_text_plot(onx)) + +################################## +# Conversion to onnx without ReduceLogSumExp +# ++++++++++++++++++++++++++++++++++++++++++ + +onx2 = to_onnx( + model, + X_train[:1], + options={id(model): {"score_samples": True}}, + black_op={"ReduceLogSumExp"}, + target_opset=12, +) + +print(onnx_simple_text_plot(onx2)) + + +############################################# +# Differences +# +++++++++++ +# +# Function :func:`onnx_array_api.reference.compare_onnx_execution` +# compares the intermediate results of two onnx models. Then it finds +# the best alignmet between the two models using an edit distance. + +res1, res2, align, dc = compare_onnx_execution(onx, onx2, verbose=1) +print("------------") +text = dc.to_str(res1, res2, align) +print(text) + +############################### +# The display shows that ReduceSumSquare was replaced by Mul + ReduceSum, +# and ReduceLogSumExp by ReduceMax + Sub + Exp + Log + Add. diff --git a/_doc/index.rst b/_doc/index.rst index 02c4eed..b81be4f 100644 --- a/_doc/index.rst +++ b/_doc/index.rst @@ -36,6 +36,7 @@ The objective is to speed up the implementation of converter libraries. tutorial/index api/index tech/index + command_lines auto_examples/index .. toctree:: diff --git a/_doc/tutorial/index.rst b/_doc/tutorial/index.rst index f4cce00..9fcc557 100644 --- a/_doc/tutorial/index.rst +++ b/_doc/tutorial/index.rst @@ -10,4 +10,5 @@ Tutorial graph_api light_api numpy_api + tools benchmarks diff --git a/_doc/tutorial/tools.rst b/_doc/tutorial/tools.rst new file mode 100644 index 0000000..fe673f7 --- /dev/null +++ b/_doc/tutorial/tools.rst @@ -0,0 +1,20 @@ +===== +Tools +===== + +Some of useful tools. + +Text representation +=================== + +Plotting a graph is great but difficult to read when +the graph is big and it is slow. +:func:`onnx_array_api.plotting.text_plot.onnx_simple_text_plot` +prints out a text representation. + +Differences between two models +============================== + +How to understand the differences between two models +assuming they are producing the same outputs? +Example :ref:`l-onnx-diff-example` shows how to do it. diff --git a/_unittests/ut_reference/test_array_tensor.py b/_unittests/ut_reference/test_array_tensor.py index 59fe5f1..f13c3e5 100644 --- a/_unittests/ut_reference/test_array_tensor.py +++ b/_unittests/ut_reference/test_array_tensor.py @@ -1,7 +1,13 @@ import unittest import numpy as np from onnx import TensorProto -from onnx.helper import make_graph, make_model, make_node, make_tensor_value_info +from onnx.helper import ( + make_graph, + make_model, + make_node, + make_tensor_value_info, + make_opsetid, +) from onnx_array_api.ext_test_case import ExtTestCase from onnx_array_api.reference import ( to_array_extended, @@ -51,6 +57,24 @@ def make_model_f8(fr, to): back = from_array_extended(got, "a") self.assertEqual(to, back.data_type) + def test_fused_matmul(self): + model = make_model( + make_graph( + [make_node("FusedMatMul", ["X", "Y"], ["Z"], domain="com.microsoft")], + "name", + [ + make_tensor_value_info("X", TensorProto.FLOAT, None), + make_tensor_value_info("Y", TensorProto.FLOAT, None), + ], + [make_tensor_value_info("Z", TensorProto.FLOAT, None)], + ), + opset_imports=[make_opsetid("", 18), make_opsetid("com.microsoft", 1)], + ) + ref = ExtendedReferenceEvaluator(model) + a = np.arange(4).reshape(-1, 2) + got = ref.run(None, {"X": a, "Y": a}) + self.assertEqualArray(a @ a, got[0]) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/_unittests/ut_reference/test_evaluator_yield.py b/_unittests/ut_reference/test_evaluator_yield.py new file mode 100644 index 0000000..7181456 --- /dev/null +++ b/_unittests/ut_reference/test_evaluator_yield.py @@ -0,0 +1,464 @@ +import unittest +import numpy as np +from onnx import TensorProto +from onnx.helper import ( + make_function, + make_graph, + make_model, + make_node, + make_opsetid, + make_tensor_value_info, +) +from onnx.parser import parse_model +from onnx_array_api.ext_test_case import ExtTestCase +from onnx_array_api.reference import ( + YieldEvaluator, + ResultType, + DistanceExecution, + ResultExecution, + compare_onnx_execution, +) +from onnx_array_api.reference.evaluator_yield import make_summary + + +class TestArrayTensor(ExtTestCase): + def test_make_summary(self): + a = np.arange(12).reshape(3, 4) + v = make_summary(a) + self.assertEqual(v, "DMVE") + a = np.arange(12) + v = make_summary(a) + self.assertEqual(v, "DMVE") + a = np.arange(12).astype(np.float32) + v = make_summary(a) + self.assertEqual(v, "DMVE") + a = np.arange(13) + a[-1] = 0 + v = make_summary(a) + self.assertEqual(v, "GWMA") + + def test_evaluator_yield(self): + new_domain = "custom_domain" + opset_imports = [make_opsetid("", 14), make_opsetid(new_domain, 1)] + + node1 = make_node("MatMul", ["X", "A"], ["XA"]) + node2 = make_node("Add", ["XA", "B"], ["Y"]) + + linear_regression = make_function( + new_domain, + "LinearRegression", + ["X", "A", "B"], + ["Y"], + [node1, node2], + opset_imports, + [], + ) + + X = make_tensor_value_info("X", TensorProto.FLOAT, [None, None]) + A = make_tensor_value_info("A", TensorProto.FLOAT, [None, None]) + B = make_tensor_value_info("B", TensorProto.FLOAT, [None, None]) + Y = make_tensor_value_info("Y", TensorProto.FLOAT, None) + + graph = make_graph( + [ + make_node( + "LinearRegression", ["X", "A", "B"], ["Y1"], domain=new_domain + ), + make_node("Abs", ["Y1"], ["Y"]), + ], + "example", + [X, A, B], + [Y], + ) + + onnx_model = make_model( + graph, opset_imports=opset_imports, functions=[linear_regression] + ) + + cst = np.arange(4).reshape((-1, 2)).astype(np.float32) + yield_eval = YieldEvaluator(onnx_model) + results = list( + yield_eval.enumerate_results(None, {"A": cst, "B": cst, "X": cst}) + ) + expected = [ + ( + ResultType.INPUT, + "A", + np.array([[0.0, 1.0], [2.0, 3.0]], dtype=np.float32), + None, + ), + ( + ResultType.INPUT, + "B", + np.array([[0.0, 1.0], [2.0, 3.0]], dtype=np.float32), + None, + ), + ( + ResultType.INPUT, + "X", + np.array([[0.0, 1.0], [2.0, 3.0]], dtype=np.float32), + None, + ), + ( + ResultType.RESULT, + "Y1", + np.array([[2.0, 4.0], [8.0, 14.0]], dtype=np.float32), + "LinearRegression", + ), + ( + ResultType.RESULT, + "Y", + np.array([[2.0, 4.0], [8.0, 14.0]], dtype=np.float32), + "Abs", + ), + ( + ResultType.OUTPUT, + "Y", + np.array([[2.0, 4.0], [8.0, 14.0]], dtype=np.float32), + None, + ), + ] + self.assertEqual(len(expected), len(results)) + for a, b in zip(expected, results): + self.assertEqual(len(a), len(b)) + self.assertEqual(a[0], b[0]) + self.assertEqual(a[1], b[1]) + self.assertEqual(a[2].tolist(), b[2].tolist()) + self.assertEqual(a[3], b[3]) + + def test_evaluator_yield_summary(self): + new_domain = "custom_domain" + opset_imports = [make_opsetid("", 14), make_opsetid(new_domain, 1)] + + node1 = make_node("MatMul", ["X", "A"], ["XA"]) + node2 = make_node("Add", ["XA", "B"], ["Y"]) + + linear_regression = make_function( + new_domain, + "LinearRegression", + ["X", "A", "B"], + ["Y"], + [node1, node2], + opset_imports, + [], + ) + + X = make_tensor_value_info("X", TensorProto.FLOAT, [None, None]) + A = make_tensor_value_info("A", TensorProto.FLOAT, [None, None]) + B = make_tensor_value_info("B", TensorProto.FLOAT, [None, None]) + Y = make_tensor_value_info("Y", TensorProto.FLOAT, None) + + graph = make_graph( + [ + make_node( + "LinearRegression", ["X", "A", "B"], ["Y1"], domain=new_domain + ), + make_node("Abs", ["Y1"], ["Y"]), + ], + "example", + [X, A, B], + [Y], + ) + + onnx_model = make_model( + graph, opset_imports=opset_imports, functions=[linear_regression] + ) + + cst = np.arange(4).reshape((-1, 2)).astype(np.float32) + yield_eval = YieldEvaluator(onnx_model) + results = list( + yield_eval.enumerate_summarized(None, {"A": cst, "B": cst, "X": cst}) + ) + expected = [ + (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "A"), + (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "B"), + (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "X"), + ( + ResultType.RESULT, + np.dtype("float32"), + (2, 2), + "CEIO", + "LinearRegression", + "Y1", + ), + (ResultType.RESULT, np.dtype("float32"), (2, 2), "CEIO", "Abs", "Y"), + (ResultType.OUTPUT, np.dtype("float32"), (2, 2), "CEIO", None, "Y"), + ] + self.assertEqual(len(expected), len(results)) + for a, b in zip(expected, results): + self.assertEqual(len(a), len(b)) + self.assertEqual(a[0], b[0]) + self.assertEqual(a[1], b[1]) + self.assertEqual(a[2], b[2]) + self.assertEqual(a[3], b[3]) + self.assertEqual(a[4], b[4]) + self.assertEqual(a[5], b[5]) + + def test_distance_pair(self): + el1 = (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None) + el2 = el1 + dc = DistanceExecution() + self.assertEqual(dc.distance_pair(el1, el2), 0) + el2 = (ResultType.INPUT, np.dtype("float16"), (2, 2), "ABCD", None) + self.assertEqual(dc.distance_pair(el1, el2), 2) + el2 = (ResultType.OUTPUT, np.dtype("float16"), (2, 2, 4), "GBCD", "Abs") + self.assertEqual(dc.distance_pair(el1, el2), 1130) + el2 = (ResultType.OUTPUT, np.dtype("float16"), (2, 3), "GBCD", "Abs") + self.assertEqual(dc.distance_pair(el1, el2), 1021) + + def test_distance_sequence_0(self): + expected = [ + (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "A"), + (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "B"), + (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "X"), + ( + ResultType.RESULT, + np.dtype("float32"), + (2, 2), + "CEIO", + "LinearRegression", + "Y1", + ), + (ResultType.RESULT, np.dtype("float32"), (2, 2), "CEIO", "Abs", "Y"), + (ResultType.OUTPUT, np.dtype("float32"), (2, 2), "CEIO", None, "Y"), + ] + + dc = DistanceExecution() + d, align = dc.distance_sequence(expected, expected) + self.assertEqual(d, 0) + self.assertEqual(align, [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5)]) + + def test_distance_sequence_ins(self): + s1 = [ + (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "A"), + (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "B"), + (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "X"), + ( + ResultType.RESULT, + np.dtype("float32"), + (2, 2), + "CEIO", + "LinearRegression", + "Y1", + ), + (ResultType.RESULT, np.dtype("float32"), (2, 2), "CEIO", "Abs", "Y"), + (ResultType.OUTPUT, np.dtype("float32"), (2, 2), "CEIO", None, "Y"), + ] + s2 = [ + (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "A"), + (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "B"), + (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "X"), + ( + ResultType.RESULT, + np.dtype("float32"), + (2, 2), + "CEIO", + "LinearRegression", + "Y1", + ), + (ResultType.OUTPUT, np.dtype("float32"), (2, 2), "CEIO", None, "Y"), + ] + + dc = DistanceExecution() + d, align = dc.distance_sequence(s1, s2) + self.assertEqual(d, dc.insert_cost) + self.assertEqual(align, [(0, 0), (1, 1), (2, 2), (3, 3), (4, 3), (5, 4)]) + d, align = dc.distance_sequence(s2, s1) + self.assertEqual(d, dc.insert_cost) + self.assertEqual(align, [(0, 0), (1, 1), (2, 2), (3, 3), (3, 4), (4, 5)]) + + def test_distance_sequence_equal(self): + s1 = [ + (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "A"), + (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "B"), + (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "X"), + ( + ResultType.RESULT, + np.dtype("float32"), + (2, 2), + "CEIO", + "LinearRegression", + "Y1", + ), + (ResultType.RESULT, np.dtype("float32"), (2, 2), "CEIO", "Abs", "Y"), + (ResultType.OUTPUT, np.dtype("float32"), (2, 2), "CEIO", None, "Y"), + ] + s2 = [ + (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "A"), + (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "B"), + (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "X"), + ( + ResultType.RESULT, + np.dtype("float32"), + (2, 2), + "CEIO", + "LinearRegression", + "Y1", + ), + (ResultType.RESULT, np.dtype("float32"), (2, 2), "CEIO", "Abs", "Z"), + (ResultType.OUTPUT, np.dtype("float32"), (2, 2), "CEIO", None, "Y"), + ] + + dc = DistanceExecution() + d, align = dc.distance_sequence(s1, s2) + self.assertEqual(d, 0) + self.assertEqual(align, [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5)]) + + def test_distance_sequence_diff(self): + s1 = [ + (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "A"), + (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "B"), + (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "X"), + ( + ResultType.RESULT, + np.dtype("float32"), + (2, 2), + "CEIO", + "LinearRegression", + "Y1", + ), + (ResultType.RESULT, np.dtype("float32"), (2, 2), "CEIO", "Abs", "Y"), + (ResultType.OUTPUT, np.dtype("float32"), (2, 2), "CEIO", None, "Y"), + ] + s2 = [ + (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "A"), + (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "B"), + (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "X"), + ( + ResultType.RESULT, + np.dtype("float32"), + (2, 2), + "CEIO", + "LinearRegression", + "Y1", + ), + (ResultType.RESULT, np.dtype("float32"), (2, 2), "CEIP", "Abs", "Z"), + (ResultType.OUTPUT, np.dtype("float32"), (2, 2), "CEIO", None, "Y"), + ] + + dc = DistanceExecution() + d, align = dc.distance_sequence(s1, s2) + self.assertEqual(d, 1) + self.assertEqual(align, [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5)]) + + def test_distance_sequence_diff2(self): + s1 = [ + (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "A"), + (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "B"), + (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "X"), + ( + ResultType.RESULT, + np.dtype("float32"), + (2, 2), + "CEIO", + "LinearRegression", + "Y1", + ), + (ResultType.RESULT, np.dtype("float32"), (2, 2), "CEIO", "Abs", "Y"), + (ResultType.OUTPUT, np.dtype("float32"), (2, 2), "CEIO", None, "Y"), + ] + s2 = [ + (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "A"), + (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "B"), + (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "X"), + ( + ResultType.RESULT, + np.dtype("float32"), + (2, 2), + "CEIO", + "LinearRegression", + "Y1", + ), + (ResultType.RESULT, np.dtype("float32"), (2, 3), "CEIP", "Abs", "Z"), + (ResultType.OUTPUT, np.dtype("float32"), (2, 2), "CEIP", None, "Y"), + ] + + dc = DistanceExecution() + d, align = dc.distance_sequence(s1, s2) + self.assertEqual(d, 5) + self.assertEqual(align, [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4), (5, 5)]) + + def test_distance_sequence_str(self): + s1 = [ + (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "A"), + (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "B"), + (ResultType.INPUT, np.dtype("float32"), (2, 3), "ABCD", None, "X"), + (ResultType.RESULT, np.dtype("float32"), (2, 2), "CEIO", "Exp", "H"), + ( + ResultType.RESULT, + np.dtype("float32"), + (2, 2), + "CEIO", + "LinearRegression", + "Y1", + ), + (ResultType.RESULT, np.dtype("float32"), (2, 2), "CEIO", "Abs", "Y"), + (ResultType.OUTPUT, np.dtype("float32"), (2, 2), "CEIO", None, "Y"), + ] + s2 = [ + (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "A"), + (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "B"), + (ResultType.INPUT, np.dtype("float32"), (2, 2), "ABCD", None, "X"), + ( + ResultType.RESULT, + np.dtype("float32"), + (2, 2), + "CEIO", + "LinearRegression", + "Y1", + ), + (ResultType.RESULT, np.dtype("float32"), (2, 3), "CEIP", "Abs", "Z"), + (ResultType.OUTPUT, np.dtype("float32"), (2, 2), "CEIP", None, "Y"), + ] + s1 = [ResultExecution(*s) for s in s1] + s2 = [ResultExecution(*s) for s in s2] + + dc = DistanceExecution() + d, align = dc.distance_sequence(s1, s2) + self.assertEqual(d, 1008) + self.assertEqual( + align, [(0, 0), (1, 1), (2, 2), (3, 2), (4, 3), (5, 4), (6, 5)] + ) + text = dc.to_str(s1, s2, align) + self.assertIn("OUTPUT", text) + expected = """ + =|INPUTfloat322x2ABCDA|INPUTfloat322x2ABCDA + =|INPUTfloat322x2ABCDB|INPUTfloat322x2ABCDB + ~|INPUTfloat322x3ABCDX|INPUTfloat322x2ABCDX + -|RESULTfloat322x2CEIOExpH| + =|RESULTfloat322x2CEIOLinearRegrY1|RESULTfloat322x2CEIOLinearRegrY1 + ~|RESULTfloat322x2CEIOAbsY|RESULTfloat322x3CEIPAbsZ + ~|OUTPUTfloat322x2CEIOY|OUTPUTfloat322x2CEIPY + """.replace( + " ", "" + ).strip( + "\n " + ) + self.assertEqual(expected, text.replace(" ", "").strip("\n")) + + def test_compare_execution(self): + m1 = parse_model( + """ + + agraph (float[N] x) => (float[N] z) { + two = Constant () + four = Add(two, two) + z = Mul(x, x) + }""" + ) + m2 = parse_model( + """ + + agraph (float[N] x) => (float[N] z) { + two = Constant () + z = Mul(x, x) + }""" + ) + res1, res2, align, dc = compare_onnx_execution(m1, m2) + text = dc.to_str(res1, res2, align) + self.assertIn("CAAA Constant", text) + self.assertEqual(len(align), 5) + + +if __name__ == "__main__": + unittest.main(verbosity=2) diff --git a/_unittests/ut_xrun_doc/test_command_lines1.py b/_unittests/ut_xrun_doc/test_command_lines1.py index 8aa17ee..02f84bd 100644 --- a/_unittests/ut_xrun_doc/test_command_lines1.py +++ b/_unittests/ut_xrun_doc/test_command_lines1.py @@ -14,6 +14,7 @@ from onnx_array_api.ext_test_case import ExtTestCase from onnx_array_api._command_lines_parser import ( get_main_parser, + get_parser_compare, get_parser_translate, main, ) @@ -70,6 +71,42 @@ def test_command_translate(self): code = st.getvalue() self.assertIn("start(opset=", code) + def test_parser_compare(self): + st = StringIO() + with redirect_stdout(st): + get_parser_compare().print_help() + text = st.getvalue() + self.assertIn("model1", text) + + def test_command_compare(self): + X = make_tensor_value_info("X", TensorProto.FLOAT, [5, 6]) + Y = make_tensor_value_info("Y", TensorProto.FLOAT, [5, 6]) + Z = make_tensor_value_info("Z", TensorProto.FLOAT, [5, 6]) + graph = make_graph( + [ + make_node("Add", ["X", "Y"], ["res"]), + make_node("Cos", ["res"], ["Z"]), + ], + "g", + [X, Y], + [Z], + ) + onnx_model = make_model(graph, opset_imports=[make_opsetid("", 18)]) + + with tempfile.TemporaryDirectory() as root: + model_file = os.path.join(root, "model.onnx") + with open(model_file, "wb") as f: + f.write(onnx_model.SerializeToString()) + + args = ["compare", "-m1", model_file, "-m2", model_file, "-v", "1"] + st = StringIO() + with redirect_stdout(st): + main(args) + + code = st.getvalue() + self.assertIn("[compare_onnx_execution]", code) + self.assertIn("ADFF", code) + if __name__ == "__main__": unittest.main(verbosity=2) diff --git a/onnx_array_api/_command_lines_parser.py b/onnx_array_api/_command_lines_parser.py index 71f5a35..a180deb 100644 --- a/onnx_array_api/_command_lines_parser.py +++ b/onnx_array_api/_command_lines_parser.py @@ -14,12 +14,13 @@ def get_main_parser() -> ArgumentParser: ) parser.add_argument( "cmd", - choices=["translate"], + choices=["translate", "compare"], help=dedent( """ Selects a command. - 'translate' exports an onnx graph into a piece of code replicating it. + 'translate' exports an onnx graph into a piece of code replicating it, + 'compares' compares the execution of two onnx models """ ), ) @@ -65,8 +66,59 @@ def _cmd_translate(argv: List[Any]): print(code) +def get_parser_compare() -> ArgumentParser: + parser = ArgumentParser( + prog="compare", + description=dedent( + """ + Compares the execution of two onnx models. + """ + ), + epilog="This is used when two models are different but should produce the same results.", + ) + parser.add_argument( + "-m1", + "--model1", + type=str, + required=True, + help="first onnx model", + ) + parser.add_argument( + "-m2", + "--model2", + type=str, + required=True, + help="second onnx model", + ) + parser.add_argument( + "-v", + "--verbose", + default=0, + help="verbosity", + ) + parser.add_argument( + "-c", + "--column-size", + default=50, + help="column size when displaying the results", + ) + return parser + + +def _cmd_compare(argv: List[Any]): + from .reference import compare_onnx_execution + + parser = get_parser_compare() + args = parser.parse_args(argv[1:]) + onx1 = onnx.load(args.model1) + onx2 = onnx.load(args.model2) + res1, res2, align, dc = compare_onnx_execution(onx1, onx2, verbose=args.verbose) + text = dc.to_str(res1, res2, align, column_size=args.column_size) + print(text) + + def main(argv: Optional[List[Any]] = None): - fcts = dict(translate=_cmd_translate) + fcts = dict(translate=_cmd_translate, compare=_cmd_compare) if argv is None: argv = sys.argv[1:] diff --git a/onnx_array_api/reference/__init__.py b/onnx_array_api/reference/__init__.py index d8c5aa5..fd1d27c 100644 --- a/onnx_array_api/reference/__init__.py +++ b/onnx_array_api/reference/__init__.py @@ -11,6 +11,13 @@ ) from onnx.reference.op_run import to_array_extended from .evaluator import ExtendedReferenceEvaluator +from .evaluator_yield import ( + DistanceExecution, + ResultExecution, + ResultType, + YieldEvaluator, + compare_onnx_execution, +) def from_array_extended(tensor: np.array, name: Optional[str] = None) -> TensorProto: diff --git a/onnx_array_api/reference/evaluator.py b/onnx_array_api/reference/evaluator.py index e20be76..54f0c26 100644 --- a/onnx_array_api/reference/evaluator.py +++ b/onnx_array_api/reference/evaluator.py @@ -7,6 +7,7 @@ from .ops.op_cast_like import CastLike_15, CastLike_19 from .ops.op_concat import Concat from .ops.op_constant_of_shape import ConstantOfShape +from .ops.op_fused_matmul import FusedMatMul logger = getLogger("onnx-array-api-eval") @@ -32,6 +33,7 @@ class ExtendedReferenceEvaluator(ReferenceEvaluator): CastLike_15, CastLike_19, ConstantOfShape, + FusedMatMul, ] @staticmethod diff --git a/onnx_array_api/reference/evaluator_yield.py b/onnx_array_api/reference/evaluator_yield.py new file mode 100644 index 0000000..3935913 --- /dev/null +++ b/onnx_array_api/reference/evaluator_yield.py @@ -0,0 +1,449 @@ +from dataclasses import dataclass +from typing import Any, Dict, List, Iterator, Optional, Tuple +from enum import IntEnum +import numpy as np +from onnx import ModelProto, TensorProto, ValueInfoProto +from .evaluator import ExtendedReferenceEvaluator + + +def _align(res: str, limit: int) -> str: + if len(res) == limit: + return res + if len(res) > limit: + return res[:limit] + return res + " " * (limit - len(res)) + + +class ResultType(IntEnum): + RESULT = 1 + INITIALIZER = 2 + SPARSE_INITIALIZER = 4 + INPUT = 8 + OUTPUT = 16 + + def __repr__(self): + return f"{self.__class__.__name__}.{self._name_}" + + +@dataclass +class ResultExecution: + """ + The description of a result. + """ + + kind: ResultType + dtype: object + shape: tuple + summary: str + op_type: str + name: str + + def __len__(self) -> int: + return 6 + + def __getitem__(self, i: int) -> Any: + if i == 0: + return self.kind + if i == 1: + return self.dtype + if i == 2: + return self.shape + if i == 3: + return self.summary + if i == 4: + return self.op_type + if i == 5: + return self.name + raise IndexError(f"i={i} out of boundary") + + def __str__(self): + els = [ + _align(self.kind._name_, 6), + _align(str(self.dtype).replace("dtype(", "").replace(")", ""), 8), + _align("x".join(map(str, self.shape)), 15), + self.summary, + _align(self.op_type or "", 10), + self.name or "", + ] + return " ".join(els) + + +def make_summary(value: Any, length: int = 4, modulo: int = 26) -> str: + """ + Create a short string summarizing the value (discretization). + + :param value: array + :param length: number of value to produce + :param module: discretization parameter + :return: short string + """ + value4 = np.zeros(length, dtype=np.float64) + if value.size <= length: + value4[: value.size] = value.flatten().astype(np.float64) + else: + if value.size % length != 0: + value2 = np.zeros( + value.size + length - value.size % length, dtype=np.float64 + ) + value2[: value.size] = value.flatten().astype(np.float64) + else: + value2 = value.flatten().astype(np.float64) + value4 = value2.reshape((4, -1)).sum(axis=1) + value4i = value4.astype(np.int64) % modulo + s = "".join([chr(65 + i) for i in value4i]) + return s + + +class YieldEvaluator: + """ + This class implements method `enumerate_results` which iterates on + intermediates results. By default, it uses + :class:`onnx_array_api.reference.ExtendedReferenceEvaluator`. + + :param onnx_model: model to run + :param recursive: dig into subgraph and functions as well + """ + + def __init__( + self, + onnx_model: ModelProto, + recursive: bool = False, + cls=ExtendedReferenceEvaluator, + ): + assert not recursive, "recursive=True is not yet implemented" + self.onnx_model = onnx_model + self.evaluator = cls(onnx_model) if cls is not None else None + + def enumerate_results( + self, + output_names: Optional[List[str]] = None, + feed_inputs: Optional[Dict[str, Any]] = None, + ) -> Iterator[Tuple[ResultType, str, Any]]: + """ + Executes the onnx model and enumerate all the intermediate results. + + Args: + output_names: requested outputs by names, None for all + feed_inputs: dictionary `{ input name: input value }` + + Returns: + iterator on tuple(result kind, name, value, node.op_type or None) + """ + assert isinstance(self.evaluator, ExtendedReferenceEvaluator), ( + f"This implementation only works with " + f"ExtendedReferenceEvaluator not {type(self.evaluator)}" + ) + attributes = {} + if output_names is None: + output_names = self.evaluator.output_names + + results = {"": None} + results.update(self.evaluator.rt_inits_) + results.update(feed_inputs) + # step 0: initializer + for k, v in self.evaluator.rt_inits_.items(): + yield ResultType.INITIALIZER, k, v, None + # step 1: inputs + for k, v in feed_inputs.items(): + yield ResultType.INPUT, k, v, None + + # step 2: execute nodes + for node in self.evaluator.rt_nodes_: + for i in node.input: + if i not in results: + raise RuntimeError( + f"Unable to find input {i!r} in known results {sorted(results)}, " + f"self.rt_inits_ has {sorted(self.evaluator.rt_inits_)}, " + f"feed_inputs has {sorted(feed_inputs)}." + ) + inputs = [results[i] for i in node.input] + linked_attributes = {} + if node.has_linked_attribute and attributes: + linked_attributes["linked_attributes"] = attributes + if node.need_context(): + outputs = node.run(*inputs, context=results, **linked_attributes) + else: + outputs = node.run(*inputs, **linked_attributes) + for name, value in zip(node.output, outputs): + yield ResultType.RESULT, name, value, node.op_type + results[name] = value + + # step 3: outputs + for name in output_names: + if name not in results: + raise RuntimeError( + f"Unable to find output name {name!r} in {sorted(results)}, proto is\n{self.proto_}" + ) + yield ResultType.OUTPUT, name, results[name], None + + def enumerate_summarized( + self, + output_names: Optional[List[str]] = None, + feed_inputs: Optional[Dict[str, Any]] = None, + ) -> Iterator[ResultExecution]: + """ + Executes the onnx model and enumerate intermediate results without their names. + + Args: + output_names: requested outputs by names, None for all + feed_inputs: dictionary `{ input name: input value }` + + Returns: + iterator on tuple(result kind, node.type, dtype, shape, value, result name) + """ + for kind, name, value, op_type in self.enumerate_results( + output_names, feed_inputs + ): + summary = make_summary(value) + yield ResultExecution( + kind, value.dtype, value.shape, summary, op_type, name + ) + + +class DistanceExecution: + """ + Computes a distance between two results. + """ + + float_types = { + np.float16, + np.float32, + np.float64, + np.dtype("float16"), + np.dtype("float32"), + np.dtype("float64"), + } + + def __init__(self, max_lag: int = 50): + self.kind_cost = 1000 + self.type_cost = 10 + self.rank_cost = 100 + self.op_type_cost = 10 + self.max_lag = max_lag + self.insert_cost = 1000 + + def distance_pair(self, r1: ResultExecution, r2: ResultExecution) -> float: + """ + (ResultType.RESULT, np.dtype("float32"), (2, 2), "CEIO", "Abs"), + + :param r1: first result + :param r2: second result + :return: distance + """ + d = 0 + if r1[0] != r2[0]: + # difference type + d += self.kind_cost + if r1[1] != r2[1]: + d += self._cost_type(r1[1], r2[1]) * self.type_cost + if r1[2] != r2[2]: + d += self._cost_shape(r1[2], r2[2]) + if r1[3] != r2[3]: + d += self._cost_summary(r1[3], r2[3]) + if r1[4] != r2[4]: + d += self.op_type_cost + return d + + def _cost_type(self, t1: "np.dtype", t2: "np.dtype") -> float: + if t1 in self.float_types and t2 in self.float_types: + return 0.2 + return 1 + + def _cost_shape(self, s1: Tuple[int, ...], s2: Tuple[int, ...]) -> float: + d = abs(np.prod(s1) - np.prod(s2)) + if len(s1) != len(s2): + return self.rank_cost + d + for i, j in zip(s1, s2): + d += abs(i - j) + return d + + def _cost_summary(self, s1: str, s2: str) -> float: + if len(s1) != len(s2): + return 1e6 + d = 0 + for a, b in zip(s1, s2): + d += abs(ord(a) - ord(b)) + return d + + def distance_sequence( + self, s1: List[ResultExecution], s2: List[ResultExecution] + ) -> Tuple[float, List[Tuple[int, int]]]: + """ + Computes the distance between two sequences of results. + + :param s1: first sequence + :param s2: second sequence + :return: distance and alignment + """ + delay = self.max_lag + distance = {(-1, -1): 0} + predecessor = {(-1, -1): None} + for i in range(len(s1)): + for j in range(max(0, i - delay), min(len(s2), i + delay)): + best = 1e100 + pred = None + ki, kj = i - 1, j - 1 + if (ki, kj) in distance: + d = distance[ki, kj] + self.distance_pair(s1[i], s2[j]) + if d < best: + best = d + pred = (ki, kj) + ki, kj = i - 1, j + if (ki, kj) in distance: + d = distance[ki, kj] + self.insert_cost + if d < best: + best = d + pred = (ki, kj) + ki, kj = i, j - 1 + if (ki, kj) in distance: + d = distance[ki, kj] + self.insert_cost + if d < best: + best = d + pred = (ki, kj) + distance[i, j] = best + predecessor[i, j] = pred + + # reverse + way = [] + last = len(s1) - 1, len(s2) - 1 + while last is not None: + way.append(last) + last = predecessor[last] + return distance[len(s1) - 1, len(s2) - 1], list(reversed(way))[1:] + + def to_str( + self, + s1: List[ResultExecution], + s2: List[ResultExecution], + alignment: List[Tuple[int, int]], + column_size: int = 60, + ) -> str: + """ + Prints out the alignment between two sequences into a string. + :param s1: first sequence + :param s2: second sequence + :param alignment: alignment + :param column_size: column size + :return: test + """ + rows = [] + last = -1, -1 + for i, j in alignment: + assert i < len(s1), f"Unexpected value i={i} >= len(s1)={len(s1)}" + assert j < len(s2), f"Unexpected value i={j} >= len(s2)={len(s2)}" + expected = last[0] + 1, last[1] + 1 + + if expected == (i, j): + d1 = s1[i] + d2 = s2[j] + d = self.distance_pair(d1, d2) + symbol = "=" if d == 0 else "~" + rows.append( + f"{symbol} | {_align(str(d1), column_size)} | {_align(str(d2), column_size)}" + ) + elif i == last[0]: + d2 = s2[j] + rows.append( + f"+ | {_align('', column_size)} | {_align(str(d2), column_size)} " + ) + else: + d1 = s1[i] + rows.append( + f"- | {_align(str(d1), column_size)} | {_align('', column_size)}" + ) + last = i, j + return "\n".join(rows) + + +def generate_input(info: ValueInfoProto) -> np.ndarray: + """ + Generates one input. + """ + elem_type = info.type.tensor_type.elem_type + shape = [ + (getattr(d, "dim_value", None) or getattr(d, "dim_param")) + for d in info.type.tensor_type.shape.dim + ] + new_shape = [] + for sh in shape: + if isinstance(sh, str): + if len(new_shape) == 0: + new_shape.append(1) + else: + new_shape.append(16) + else: + new_shape.append(sh) + new_shape = tuple(new_shape) + p = np.prod(new_shape) + value = np.arange(p) + if elem_type == TensorProto.INT32: + return value.astype(np.int32).reshape(new_shape) + if elem_type == TensorProto.INT64: + return value.astype(np.int64).reshape(new_shape) + if elem_type == TensorProto.FLOAT: + return (value.astype(np.float32) / p).astype(np.float32).reshape(new_shape) + if elem_type == TensorProto.FLOAT16: + return (value.astype(np.float16) / p).astype(np.float16).reshape(new_shape) + if elem_type == TensorProto.DOUBLE: + return (value.astype(np.float64) / p).astype(np.float64).reshape(new_shape) + raise RuntimeError(f"Unexpected element_type {elem_type} for info={info}") + + +def generate_inputs(model: ModelProto) -> List[np.ndarray]: + """ + Generates inputs for a specific model. + + :param model: ModelProto + :return: list of inputs + """ + inputs = [] + inits = set(i.name for i in model.graph.initializer) + for inp in model.graph.input: + if inp.name in inits: + break + inputs.append(generate_input(inp)) + return inputs + + +def compare_onnx_execution( + model1: ModelProto, + model2: ModelProto, + inputs: Optional[List[Any]] = None, + verbose: int = 0, +) -> Tuple[List[ResultExecution], List[ResultExecution], List[Tuple[int, int]]]: + """ + Compares the execution of two onnx models. + The function assumes both models takes the same inputs. + See :ref:`l-onnx-diff-example` to see a full example using + this function. + + :param model1: first model + :param model2: second model + :param inputs: inputs to use + :param verbose: verbosity + :return: four results, a sequence of results for the first model and the second model, + the alignment between the two, DistanceExecution + """ + if verbose: + print("[compare_onnx_execution] generate inputs") + if inputs is None: + inputs = generate_inputs(model1) + feeds1 = {i.name: v for i, v in zip(model1.graph.input, inputs)} + feeds2 = {i.name: v for i, v in zip(model2.graph.input, inputs)} + if verbose: + print(f"[compare_onnx_execution] got {len(inputs)} inputs") + print("[compare_onnx_execution] execute first model") + res1 = list(YieldEvaluator(model1).enumerate_summarized(None, feeds1)) + if verbose: + print(f"[compare_onnx_execution] got {len(res1)} results") + print("[compare_onnx_execution] execute second model") + res2 = list(YieldEvaluator(model2).enumerate_summarized(None, feeds2)) + if verbose: + print(f"[compare_onnx_execution] got {len(res2)} results") + print("[compare_onnx_execution] compute edit distance") + dc = DistanceExecution() + _, align = dc.distance_sequence(res1, res2) + if verbose: + print(f"[compare_onnx_execution] got {len(align)} pairs") + print("[compare_onnx_execution] done") + return res1, res2, align, dc diff --git a/onnx_array_api/reference/ops/op_fused_matmul.py b/onnx_array_api/reference/ops/op_fused_matmul.py new file mode 100644 index 0000000..0f738c7 --- /dev/null +++ b/onnx_array_api/reference/ops/op_fused_matmul.py @@ -0,0 +1,31 @@ +import numpy as np +from onnx.reference.op_run import OpRun + + +class FusedMatMul(OpRun): + op_domain = "com.microsoft" + + def _run( + self, + A, + B, + alpha: float = 1, + transA: int = 0, + transB: int = 0, + transBatchA: int = 0, + transBatchB: int = 0, + ): + assert ( + transBatchA == 0 + ), f"Not implemented for transBatchA==1 and {A.shape}x{B.shape}" + assert ( + transBatchB == 0 + ), f"Not implemented for transBatchB==1 and {A.shape}x{B.shape}" + if transA: + dim = len(A.shape) + A = A.transpose(axes=(dim - 2, dim - 1)) + if transB: + dim = len(B.shape) + B = B.transpose(axes=(dim - 2, dim - 1)) + a = np.array(alpha, dtype=A.dtype) + return (A @ B * a,)