Export evaluator type in compare_onnx_execution #93

Merged
merged 3 commits on Jan 6, 2025
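
This change surfaces the evaluator type as a `cls` argument so callers can choose which evaluator runs both models. Below is a minimal sketch of the new keyword, assuming two placeholder ONNX files `model1.onnx` and `model2.onnx` are available and that the helper generates inputs when none are supplied; the unpacking follows the four return values described in the docstring:

```python
# Hedged sketch: "model1.onnx" and "model2.onnx" are placeholder files,
# not artifacts of this PR.
from onnx import load
from onnx.reference import ReferenceEvaluator

from onnx_array_api.reference.evaluator_yield import compare_onnx_execution

model1 = load("model1.onnx")
model2 = load("model2.onnx")

# cls is the keyword added by this PR; leaving it as None keeps the previous
# behaviour (ExtendedReferenceEvaluator runs both models).
# The docstring documents four return values: results for each model,
# the alignment between them, and a DistanceExecution.
res1, res2, align, dist = compare_onnx_execution(model1, model2, cls=ReferenceEvaluator)
print(len(res1), len(res2), len(align))
```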
2 changes: 1 addition & 1 deletion LICENSE.txt
@@ -1,4 +1,4 @@
-Copyright (c) 2023-2024, Xavier Dupré
+Copyright (c) 2023-2025, Xavier Dupré
 
 Permission is hereby granted, free of charge, to any person obtaining a copy
 of this software and associated documentation files (the "Software"), to deal
21 changes: 15 additions & 6 deletions onnx_array_api/reference/evaluator_yield.py
@@ -3,6 +3,7 @@
 from enum import IntEnum
 import numpy as np
 from onnx import ModelProto, TensorProto, ValueInfoProto, load
+from onnx.reference import ReferenceEvaluator
 from onnx.helper import tensor_dtype_to_np_dtype
 from onnx.shape_inference import infer_shapes
 from . import to_array_extended
@@ -138,17 +139,23 @@ class YieldEvaluator:

     :param onnx_model: model to run
     :param recursive: dig into subgraph and functions as well
+    :param cls: evaluator to use, default value is :class:`ExtendedReferenceEvaluator
+        <onnx_array_api.reference.ExtendedReferenceEvaluator>`
     """
 
     def __init__(
         self,
         onnx_model: ModelProto,
         recursive: bool = False,
-        cls=ExtendedReferenceEvaluator,
+        cls: Optional[type[ExtendedReferenceEvaluator]] = None,
     ):
         assert not recursive, "recursive=True is not yet implemented"
         self.onnx_model = onnx_model
-        self.evaluator = cls(onnx_model) if cls is not None else None
+        self.evaluator = (
+            cls(onnx_model)
+            if cls is not None
+            else ExtendedReferenceEvaluator(onnx_model)
+        )
 
     def enumerate_results(
         self,
@@ -166,9 +173,9 @@ def enumerate_results(
         Returns:
             iterator on tuple(result kind, name, value, node.op_type or None)
         """
-        assert isinstance(self.evaluator, ExtendedReferenceEvaluator), (
+        assert isinstance(self.evaluator, ReferenceEvaluator), (
             f"This implementation only works with "
-            f"ExtendedReferenceEvaluator not {type(self.evaluator)}"
+            f"ReferenceEvaluator not {type(self.evaluator)}"
         )
         attributes = {}
         if output_names is None:
@@ -595,6 +602,7 @@ def compare_onnx_execution(
     raise_exc: bool = True,
     mode: str = "execute",
     keep_tensor: bool = False,
+    cls: Optional[type[ReferenceEvaluator]] = None,
 ) -> Tuple[List[ResultExecution], List[ResultExecution], List[Tuple[int, int]]]:
     """
     Compares the execution of two onnx models.
@@ -611,6 +619,7 @@
     :param mode: the model should be executed but the function can be executed
         but the comparison may append on nodes only
     :param keep_tensor: keeps the tensor in order to compute a precise distance
+    :param cls: evaluator class to use
     :return: four results, a sequence of results
         for the first model and the second model,
         the alignment between the two, DistanceExecution
@@ -634,15 +643,15 @@
print(f"[compare_onnx_execution] execute with {len(inputs)} inputs")
print("[compare_onnx_execution] execute first model")
res1 = list(
YieldEvaluator(model1).enumerate_summarized(
YieldEvaluator(model1, cls=cls).enumerate_summarized(
None, feeds1, raise_exc=raise_exc, keep_tensor=keep_tensor
)
)
if verbose:
print(f"[compare_onnx_execution] got {len(res1)} results")
print("[compare_onnx_execution] execute second model")
res2 = list(
YieldEvaluator(model2).enumerate_summarized(
YieldEvaluator(model2, cls=cls).enumerate_summarized(
None, feeds2, raise_exc=raise_exc, keep_tensor=keep_tensor
)
)
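
The same `cls` hook is available on `YieldEvaluator` itself, which `compare_onnx_execution` now forwards its argument to. A sketch under hedged assumptions: the one-node Add model and the feed dictionary below are made up for illustration, and the relaxed assert in `enumerate_results` is what allows a plain `ReferenceEvaluator` here:

```python
# Hedged sketch: the one-node Add model and the feeds are made up for illustration.
import numpy as np
from onnx import TensorProto
from onnx.helper import make_graph, make_model, make_node, make_tensor_value_info
from onnx.reference import ReferenceEvaluator

from onnx_array_api.reference.evaluator_yield import YieldEvaluator

X = make_tensor_value_info("X", TensorProto.FLOAT, [None])
Y = make_tensor_value_info("Y", TensorProto.FLOAT, [None])
model = make_model(make_graph([make_node("Add", ["X", "X"], ["Y"])], "g", [X], [Y]))

# Any ReferenceEvaluator subclass (or ReferenceEvaluator itself) can be
# plugged in through cls; omitting cls keeps ExtendedReferenceEvaluator.
yielder = YieldEvaluator(model, cls=ReferenceEvaluator)
feeds = {"X": np.arange(4, dtype=np.float32)}
for result in yielder.enumerate_summarized(None, feeds, raise_exc=True, keep_tensor=False):
    print(result)
```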