|
1 | 1 | import sys
|
2 | 2 | from functools import partial
|
3 |
| -from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union |
| 3 | +from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union |
4 | 4 | import numpy as np
|
5 | 5 | from onnx.defs import onnx_opset_version
|
6 | 6 | import onnx.helper as oh
|
@@ -30,6 +30,51 @@ def __init__(
|
30 | 30 | self.constant_size = constant_size
|
31 | 31 |
|
32 | 32 |
|
| 33 | +class NodePattern: |
| 34 | + """ |
| 35 | + Class defining a matching pattern able to find nodes in a set of nodes. |
| 36 | + """ |
| 37 | + |
| 38 | + def __init__( |
| 39 | + self, |
| 40 | + index: Optional[int] = None, |
| 41 | + op_type: Optional[str] = None, |
| 42 | + name: Optional[None] = None, |
| 43 | + ): |
| 44 | + self.index = index |
| 45 | + self.op_type = op_type |
| 46 | + self.name = name |
| 47 | + |
| 48 | + def __repr__(self): |
| 49 | + "usual" |
| 50 | + args = ["index", "op_type", "name"] |
| 51 | + sargs = [] |
| 52 | + for a in args: |
| 53 | + if a: |
| 54 | + sargs.append(f"{a}={getattr(self, a)!r}") |
| 55 | + return f"{self.__class__.__name__}({', '.join(sargs)})" |
| 56 | + |
| 57 | + def find(self, graph: "GraphBuilder") -> Iterator: |
| 58 | + """ |
| 59 | + Iterates on nodes matching the pattern. |
| 60 | + """ |
| 61 | + for index, node in enumerate(graph.nodes): |
| 62 | + if self.match(index, node): |
| 63 | + yield node |
| 64 | + |
| 65 | + def match(self, index, node: NodeProto) -> bool: |
| 66 | + """ |
| 67 | + Tells if a node is matching this pattern. |
| 68 | + """ |
| 69 | + if self.index is not None and self.index != index: |
| 70 | + return False |
| 71 | + if self.op_type is not None and self.op_type != node.op_type: |
| 72 | + return False |
| 73 | + if self.name is not None and self.name != node.name: |
| 74 | + return False |
| 75 | + return True |
| 76 | + |
| 77 | + |
33 | 78 | class Opset:
|
34 | 79 | # defined for opset >= 18
|
35 | 80 | # name: number of expected outputs
|
@@ -168,7 +213,7 @@ def __init__(
|
168 | 213 | f"{type(target_opset_or_existing_proto)} is not supported."
|
169 | 214 | )
|
170 | 215 |
|
171 |
| - self.op = Opset(self, self.opsets[""]) |
| 216 | + self.op = Opset(self, self.opsets[""]) if "" in self.opsets else None |
172 | 217 | self._cache_array = []
|
173 | 218 |
|
174 | 219 | def _get_tensor_shape(
|
@@ -749,7 +794,6 @@ def constant_folding(self):
|
749 | 794 | Folds all constants. Constants are marked during the creation of the graph.
|
750 | 795 | There is no need to propagate this information.
|
751 | 796 | """
|
752 |
| - |
753 | 797 | updates = {}
|
754 | 798 | node_to_remove = set()
|
755 | 799 | for k, v in self.constants_.items():
|
@@ -840,3 +884,74 @@ def remove_identity_nodes(self):
|
840 | 884 | self.nodes.append(new_node)
|
841 | 885 | else:
|
842 | 886 | self.nodes.append(node)
|
| 887 | + |
| 888 | + def np( |
| 889 | + self, |
| 890 | + index: Optional[int] = None, |
| 891 | + op_type: Optional[str] = None, |
| 892 | + name: Optional[str] = None, |
| 893 | + ) -> NodePattern: |
| 894 | + """ |
| 895 | + Returns an instance of :class:`NodePattern |
| 896 | + <onnx_array_api.graph_api.graph_builder.NodePattern>`. |
| 897 | + """ |
| 898 | + return NodePattern(index=index, op_type=op_type, name=name) |
| 899 | + |
| 900 | + def update_attribute( |
| 901 | + self, |
| 902 | + pat: NodePattern, |
| 903 | + recursive: bool = False, |
| 904 | + **kwargs: Dict[str, Any], |
| 905 | + ) -> int: |
| 906 | + """ |
| 907 | + Udates attributes for nodes matching the |
| 908 | +
|
| 909 | + :param pat: returned by method :meth:`GraphBuilder.np` |
| 910 | + :param recursive: walk through subgraph |
| 911 | + :param kwargs: attributes to modify |
| 912 | + :return: number of modified nodes |
| 913 | + """ |
| 914 | + assert not recursive, "recursive=True is not implemented." |
| 915 | + modified = 0 |
| 916 | + for node in pat.find(self): |
| 917 | + up = self.update_node(node, **kwargs) |
| 918 | + if up: |
| 919 | + modified += 1 |
| 920 | + return modified |
| 921 | + |
| 922 | + DELETE = object() |
| 923 | + |
| 924 | + def update_node(self, node: NodeProto, **kwargs) -> bool: |
| 925 | + """ |
| 926 | + Updates attributes of a node proto. |
| 927 | + Returns True if the node was updated. |
| 928 | + """ |
| 929 | + processed = set() |
| 930 | + modified = True |
| 931 | + atts = [] |
| 932 | + for att in node.attribute: |
| 933 | + if att.name in kwargs: |
| 934 | + processed.add(att.name) |
| 935 | + if kwargs[att.name] is GraphBuilder.DELETE: |
| 936 | + continue |
| 937 | + new_att = oh.make_attribute(att.name, kwargs[att.name]) |
| 938 | + assert new_att.type == att.type, ( |
| 939 | + f"Mismatch value for attribute {att.name!r} has type " |
| 940 | + f"{att.type} but the new value leads to " |
| 941 | + f"type={new_att.type}." |
| 942 | + ) |
| 943 | + atts.append(new_att) |
| 944 | + modified = True |
| 945 | + continue |
| 946 | + atts.append(att) |
| 947 | + for k, v in kwargs.items(): |
| 948 | + if k in processed or v is GraphBuilder.DELETE: |
| 949 | + continue |
| 950 | + modified = True |
| 951 | + new_att = oh.make_attribute(k, v) |
| 952 | + atts.append(new_att) |
| 953 | + |
| 954 | + if modified: |
| 955 | + del node.attribute[:] |
| 956 | + node.attribute.extend(atts) |
| 957 | + return modified |
0 commit comments