Skip to content

Commit 71aa3a0

Browse files
authored
Add methods to update nodes in GraphAPI (#59)
* Add methods to update nodes * update doc
1 parent 6718ee8 commit 71aa3a0

File tree

5 files changed

+188
-4
lines changed

5 files changed

+188
-4
lines changed

CHANGELOGS.rst

+5
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,11 @@
11
Change Logs
22
===========
33

4+
0.2.0
5+
+++++
6+
7+
* :pr:`59`: add methods to update nodes in GraphAPI
8+
49
0.1.3
510
+++++
611

_doc/api/graph_api.rst

+6
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,12 @@ GraphBuilder
99
.. autoclass:: onnx_array_api.graph_api.GraphBuilder
1010
:members:
1111

12+
NodePattern
13+
===========
14+
15+
.. autoclass:: onnx_array_api.graph_api.NodePattern
16+
:members:
17+
1218
OptimizationOptions
1319
===================
1420

_unittests/ut_graph_api/test_graph_builder.py

+58
Original file line numberDiff line numberDiff line change
@@ -376,6 +376,64 @@ def test_make_nodes_noprefix(self):
376376
got = ref.run(None, feeds)
377377
self.assertEqualArray(expected, got[0])
378378

379+
def test_node_pattern(self):
380+
model = onnx.parser.parse_model(
381+
"""
382+
<ir_version: 8, opset_import: [ "": 18]>
383+
agraph (float[N] x) => (float[N] z) {
384+
two = Constant <value_float=2.0> ()
385+
four = Add(two, two)
386+
z = Mul(x, four)
387+
}"""
388+
)
389+
gr = GraphBuilder(model)
390+
p = gr.np(index=0)
391+
r = repr(p)
392+
self.assertEqual("NodePattern(index=0, op_type=None, name=None)", r)
393+
394+
def test_update_node_attribute(self):
395+
model = onnx.parser.parse_model(
396+
"""
397+
<ir_version: 8, opset_import: [ "": 18]>
398+
agraph (float[N] x) => (float[N] z) {
399+
two = Constant <value_float=2.0> ()
400+
four = Add(two, two)
401+
z = Mul(x, four)
402+
}"""
403+
)
404+
gr = GraphBuilder(model)
405+
self.assertEqual(len(gr.nodes), 3)
406+
m = gr.update_attribute(gr.np(op_type="Constant"), value_float=float(1))
407+
self.assertEqual(m, 1)
408+
self.assertEqual(len(gr.nodes), 3)
409+
onx = gr.to_onnx()
410+
self.assertEqual(len(onx.graph.node), 3)
411+
node = onx.graph.node[0]
412+
self.assertIn("f: 1", str(node))
413+
414+
def test_delete_node_attribute(self):
415+
model = onnx.parser.parse_model(
416+
"""
417+
<ir_version: 8, opset_import: [ "": 18]>
418+
agraph (float[N] x) => (float[N] z) {
419+
two = Constant <value_float=2.0> ()
420+
four = Add(two, two)
421+
z = Mul(x, four)
422+
}"""
423+
)
424+
gr = GraphBuilder(model)
425+
self.assertEqual(len(gr.nodes), 3)
426+
m = gr.update_attribute(
427+
gr.np(op_type="Constant"), value_float=gr.DELETE, value_int=1
428+
)
429+
self.assertEqual(m, 1)
430+
self.assertEqual(len(gr.nodes), 3)
431+
onx = gr.to_onnx()
432+
self.assertEqual(len(onx.graph.node), 3)
433+
node = onx.graph.node[0]
434+
self.assertNotIn('name: "value_float"', str(node))
435+
self.assertIn("i: 1", str(node))
436+
379437

380438
if __name__ == "__main__":
381439
unittest.main(verbosity=2)

onnx_array_api/graph_api/__init__.py

+1-1
Original file line numberDiff line numberDiff line change
@@ -1 +1 @@
1-
from .graph_builder import GraphBuilder
1+
from .graph_builder import GraphBuilder, NodePattern

onnx_array_api/graph_api/graph_builder.py

+118-3
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
import sys
22
from functools import partial
3-
from typing import Any, Dict, List, Optional, Sequence, Set, Tuple, Union
3+
from typing import Any, Dict, Iterator, List, Optional, Sequence, Set, Tuple, Union
44
import numpy as np
55
from onnx.defs import onnx_opset_version
66
import onnx.helper as oh
@@ -30,6 +30,51 @@ def __init__(
3030
self.constant_size = constant_size
3131

3232

33+
class NodePattern:
34+
"""
35+
Class defining a matching pattern able to find nodes in a set of nodes.
36+
"""
37+
38+
def __init__(
39+
self,
40+
index: Optional[int] = None,
41+
op_type: Optional[str] = None,
42+
name: Optional[None] = None,
43+
):
44+
self.index = index
45+
self.op_type = op_type
46+
self.name = name
47+
48+
def __repr__(self):
49+
"usual"
50+
args = ["index", "op_type", "name"]
51+
sargs = []
52+
for a in args:
53+
if a:
54+
sargs.append(f"{a}={getattr(self, a)!r}")
55+
return f"{self.__class__.__name__}({', '.join(sargs)})"
56+
57+
def find(self, graph: "GraphBuilder") -> Iterator:
58+
"""
59+
Iterates on nodes matching the pattern.
60+
"""
61+
for index, node in enumerate(graph.nodes):
62+
if self.match(index, node):
63+
yield node
64+
65+
def match(self, index, node: NodeProto) -> bool:
66+
"""
67+
Tells if a node is matching this pattern.
68+
"""
69+
if self.index is not None and self.index != index:
70+
return False
71+
if self.op_type is not None and self.op_type != node.op_type:
72+
return False
73+
if self.name is not None and self.name != node.name:
74+
return False
75+
return True
76+
77+
3378
class Opset:
3479
# defined for opset >= 18
3580
# name: number of expected outputs
@@ -168,7 +213,7 @@ def __init__(
168213
f"{type(target_opset_or_existing_proto)} is not supported."
169214
)
170215

171-
self.op = Opset(self, self.opsets[""])
216+
self.op = Opset(self, self.opsets[""]) if "" in self.opsets else None
172217
self._cache_array = []
173218

174219
def _get_tensor_shape(
@@ -749,7 +794,6 @@ def constant_folding(self):
749794
Folds all constants. Constants are marked during the creation of the graph.
750795
There is no need to propagate this information.
751796
"""
752-
753797
updates = {}
754798
node_to_remove = set()
755799
for k, v in self.constants_.items():
@@ -840,3 +884,74 @@ def remove_identity_nodes(self):
840884
self.nodes.append(new_node)
841885
else:
842886
self.nodes.append(node)
887+
888+
def np(
889+
self,
890+
index: Optional[int] = None,
891+
op_type: Optional[str] = None,
892+
name: Optional[str] = None,
893+
) -> NodePattern:
894+
"""
895+
Returns an instance of :class:`NodePattern
896+
<onnx_array_api.graph_api.graph_builder.NodePattern>`.
897+
"""
898+
return NodePattern(index=index, op_type=op_type, name=name)
899+
900+
def update_attribute(
901+
self,
902+
pat: NodePattern,
903+
recursive: bool = False,
904+
**kwargs: Dict[str, Any],
905+
) -> int:
906+
"""
907+
Udates attributes for nodes matching the
908+
909+
:param pat: returned by method :meth:`GraphBuilder.np`
910+
:param recursive: walk through subgraph
911+
:param kwargs: attributes to modify
912+
:return: number of modified nodes
913+
"""
914+
assert not recursive, "recursive=True is not implemented."
915+
modified = 0
916+
for node in pat.find(self):
917+
up = self.update_node(node, **kwargs)
918+
if up:
919+
modified += 1
920+
return modified
921+
922+
DELETE = object()
923+
924+
def update_node(self, node: NodeProto, **kwargs) -> bool:
925+
"""
926+
Updates attributes of a node proto.
927+
Returns True if the node was updated.
928+
"""
929+
processed = set()
930+
modified = True
931+
atts = []
932+
for att in node.attribute:
933+
if att.name in kwargs:
934+
processed.add(att.name)
935+
if kwargs[att.name] is GraphBuilder.DELETE:
936+
continue
937+
new_att = oh.make_attribute(att.name, kwargs[att.name])
938+
assert new_att.type == att.type, (
939+
f"Mismatch value for attribute {att.name!r} has type "
940+
f"{att.type} but the new value leads to "
941+
f"type={new_att.type}."
942+
)
943+
atts.append(new_att)
944+
modified = True
945+
continue
946+
atts.append(att)
947+
for k, v in kwargs.items():
948+
if k in processed or v is GraphBuilder.DELETE:
949+
continue
950+
modified = True
951+
new_att = oh.make_attribute(k, v)
952+
atts.append(new_att)
953+
954+
if modified:
955+
del node.attribute[:]
956+
node.attribute.extend(atts)
957+
return modified

0 commit comments

Comments
 (0)