diff --git a/LICENSE.txt b/LICENSE.txt index e027853..1a46a8e 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -1,4 +1,4 @@ -Copyright (c) 2023-2024, Xavier Dupré +Copyright (c) 2023-2025, Xavier Dupré Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal diff --git a/onnx_array_api/reference/evaluator_yield.py b/onnx_array_api/reference/evaluator_yield.py index 5b77e8b..6ae005c 100644 --- a/onnx_array_api/reference/evaluator_yield.py +++ b/onnx_array_api/reference/evaluator_yield.py @@ -3,6 +3,7 @@ from enum import IntEnum import numpy as np from onnx import ModelProto, TensorProto, ValueInfoProto, load +from onnx.reference import ReferenceEvaluator from onnx.helper import tensor_dtype_to_np_dtype from onnx.shape_inference import infer_shapes from . import to_array_extended @@ -138,17 +139,23 @@ class YieldEvaluator: :param onnx_model: model to run :param recursive: dig into subgraph and functions as well + :param cls: evaluator to use, default value is :class:`ExtendedReferenceEvaluator + ` """ def __init__( self, onnx_model: ModelProto, recursive: bool = False, - cls=ExtendedReferenceEvaluator, + cls: Optional[type[ExtendedReferenceEvaluator]] = None, ): assert not recursive, "recursive=True is not yet implemented" self.onnx_model = onnx_model - self.evaluator = cls(onnx_model) if cls is not None else None + self.evaluator = ( + cls(onnx_model) + if cls is not None + else ExtendedReferenceEvaluator(onnx_model) + ) def enumerate_results( self, @@ -166,9 +173,9 @@ def enumerate_results( Returns: iterator on tuple(result kind, name, value, node.op_type or None) """ - assert isinstance(self.evaluator, ExtendedReferenceEvaluator), ( + assert isinstance(self.evaluator, ReferenceEvaluator), ( f"This implementation only works with " - f"ExtendedReferenceEvaluator not {type(self.evaluator)}" + f"ReferenceEvaluator not {type(self.evaluator)}" ) attributes = {} if output_names is None: @@ -595,6 +602,7 @@ def compare_onnx_execution( raise_exc: bool = True, mode: str = "execute", keep_tensor: bool = False, + cls: Optional[type[ReferenceEvaluator]] = None, ) -> Tuple[List[ResultExecution], List[ResultExecution], List[Tuple[int, int]]]: """ Compares the execution of two onnx models. @@ -611,6 +619,7 @@ def compare_onnx_execution( :param mode: the model should be executed but the function can be executed but the comparison may append on nodes only :param keep_tensor: keeps the tensor in order to compute a precise distance + :param cls: evaluator class to use :return: four results, a sequence of results for the first model and the second model, the alignment between the two, DistanceExecution @@ -634,7 +643,7 @@ def compare_onnx_execution( print(f"[compare_onnx_execution] execute with {len(inputs)} inputs") print("[compare_onnx_execution] execute first model") res1 = list( - YieldEvaluator(model1).enumerate_summarized( + YieldEvaluator(model1, cls=cls).enumerate_summarized( None, feeds1, raise_exc=raise_exc, keep_tensor=keep_tensor ) ) @@ -642,7 +651,7 @@ def compare_onnx_execution( print(f"[compare_onnx_execution] got {len(res1)} results") print("[compare_onnx_execution] execute second model") res2 = list( - YieldEvaluator(model2).enumerate_summarized( + YieldEvaluator(model2, cls=cls).enumerate_summarized( None, feeds2, raise_exc=raise_exc, keep_tensor=keep_tensor ) )