1 | /* |
2 | * SPDX-License-Identifier: Apache-2.0 |
3 | */ |
4 | |
5 | #pragma once |
6 | |
7 | #include <mutex> |
8 | #include <string> |
9 | #include <unordered_map> |
10 | #include <vector> |
11 | |
12 | #include "attr_proto_util.h" |
13 | #include "onnx/common/constants.h" |
14 | #include "onnx/common/status.h" |
15 | #include "onnx/defs/parser.h" |
16 | #include "onnx/defs/schema.h" |
17 | #include "tensor_proto_util.h" |
18 | |
19 | namespace ONNX_NAMESPACE { |
20 | // Helper function to expand a function node given the function proto |
21 | void FunctionExpandHelper( |
22 | const NodeProto& node, |
23 | const FunctionProto& func, |
24 | GraphProto& g, |
25 | const std::string& node_prefix = "" ); |
26 | |
27 | class FunctionBodyHelper { |
28 | public: |
29 | struct AttributeProtoWrapper { |
30 | AttributeProto proto; |
31 | |
32 | AttributeProtoWrapper() {} |
33 | |
34 | AttributeProtoWrapper(const AttributeProto& attr_prot) { |
35 | proto = attr_prot; |
36 | } |
37 | |
38 | template <typename T> |
39 | AttributeProtoWrapper(const std::string& attr_name, const T& value) { |
40 | proto = MakeAttribute(attr_name, value); |
41 | } |
42 | }; |
43 | |
44 | struct NodeDef { |
45 | NodeDef( |
46 | std::vector<std::string> outputs, |
47 | std::string op_type, |
48 | std::vector<std::string> inputs, |
49 | std::vector<AttributeProtoWrapper> attributes = {}, |
50 | std::string domain = "" ) |
51 | : outputs(std::move(outputs)), |
52 | op_type(std::move(op_type)), |
53 | inputs(std::move(inputs)), |
54 | attributes(std::move(attributes)), |
55 | domain(std::move(domain)) {} |
56 | |
57 | std::vector<std::string> outputs; |
58 | std::string op_type; |
59 | std::vector<std::string> inputs; |
60 | std::vector<AttributeProtoWrapper> attributes; |
61 | std::string domain; |
62 | }; |
63 | |
64 | /* |
65 | BuildNodes() is an utility function for easily define a Function Body. |
66 | |
67 | To build a simple node: |
68 | {{"Z"}, "Add", {"X", "Y"}} represents Z = Add(X,Y) |
69 | |
70 | To build a node with attribute: |
71 | {{"Y"}, "Concat", {"X1", "X2", "X3"}, {{"axis", 1}}} |
72 | represents Y = Concat(X1,X2,X3) with axis = 1 |
73 | The attribute type are infered from the attribute value's c++ type |
74 | Supported value types are |
75 | int64_t -> int, vector<int64_t> -> ints |
76 | float -> float, vector<float> -> floats |
77 | string -> string, vector<string> ->strings |
78 | For refering an attribute from parent, use: |
79 | {MakeRefAttribute("axes", AttributeProto::INTS)}} |
80 | |
81 | To build a node which belongs to a domain other than onnx standard domain: |
82 | {{"Z"}, "Foo", {"X", "Y"}, "customdomain"} represents Z = customdomain.Foo(X,Y) |
83 | or |
84 | {{"Y"}, "Bar", {"X1", "X2", "X3"}, {{"axis", 1}}, "customdomain"} |
85 | represents Y = customdomain.Bar(X1,X2,X3) with axis = 1 |
86 | |
87 | For more examples, please find the references of this function |
88 | */ |
89 | static std::vector<NodeProto> BuildNodes(const std::vector<NodeDef>& node_defs); |
90 | |
91 | static void BuildNodes(FunctionProto& functionProto, const std::vector<NodeDef>& node_defs); |
92 | |
93 | static bool BuildFunctionProto( |
94 | FunctionProto& functionProto, |
95 | const OpSchema& schema, |
96 | const std::vector<NodeDef>& node_defs, |
97 | const std::vector<OperatorSetIdProto>& relied_opsets); |
98 | |
99 | template <typename T> |
100 | static NodeDef Const(const std::string& name, const T& value) { |
101 | return NodeDef{{name}, "Constant" , {}, {{"value" , ToTensor<T>(value)}}}; |
102 | } |
103 | |
104 | template <typename T> |
105 | static NodeDef Const(const std::string& name, const std::vector<T>& values) { |
106 | return NodeDef{{name}, "Constant" , {}, {{"value" , ToTensor<T>(values)}}}; |
107 | } |
108 | }; |
109 | |
110 | class FunctionBuilder { |
111 | public: |
112 | FunctionBuilder(FunctionProto& funProto_) : funProto(funProto_) {} |
113 | |
114 | FunctionBuilder& Add(const char* nodes_txt) { |
115 | OnnxParser parser(nodes_txt); |
116 | auto& nodes = *funProto.mutable_node(); |
117 | |
118 | while (!parser.EndOfInput()) { |
119 | auto status = parser.Parse(*nodes.Add()); |
120 | if (!status.IsOK()) |
121 | ONNX_THROW_EX(std::logic_error("Error parsing node:" + status.ErrorMessage())); |
122 | } |
123 | |
124 | return *this; |
125 | } |
126 | |
127 | FunctionBuilder& Add(const char* node_txt, const AttributeProto& attr) { |
128 | OnnxParser parser(node_txt); |
129 | auto& node = *funProto.add_node(); |
130 | auto status = parser.Parse(node); |
131 | if (!status.IsOK()) { |
132 | ONNX_THROW_EX(std::logic_error("Error parsing node:" + status.ErrorMessage())); |
133 | } |
134 | |
135 | if (!parser.EndOfInput()) { |
136 | ONNX_THROW_EX(std::logic_error("Error unexpected extra input in node:" + status.ErrorMessage())); |
137 | } |
138 | |
139 | *node.add_attribute() = attr; |
140 | |
141 | return *this; |
142 | } |
143 | |
144 | template <typename T> |
145 | FunctionBuilder& Add(const char* node_txt, const std::string& attr_name, const T& attr_value) { |
146 | return Add(node_txt, MakeAttribute(attr_name, attr_value)); |
147 | } |
148 | |
149 | FunctionBuilder& Const(const std::string& name, const TensorProto& tensor) { |
150 | std::string constant_op(name); |
151 | constant_op += " = Constant()" ; |
152 | return Add(constant_op.c_str(), MakeAttribute("value" , tensor)); |
153 | } |
154 | |
155 | // Creates a scalar constant (a tensor of rank zero). |
156 | template <typename T> |
157 | FunctionBuilder& Const(const std::string& name, T const_value) { |
158 | std::string constant_op(name); |
159 | constant_op += " = Constant()" ; |
160 | return Add(constant_op.c_str(), MakeAttribute("value" , ToTensor(const_value))); |
161 | } |
162 | |
163 | // Creates a 1D tensor constant consisting of a single value. |
164 | template <typename T> |
165 | FunctionBuilder& Const1D(const std::string& name, T const_value) { |
166 | std::string constant_op(name); |
167 | constant_op += " = Constant()" ; |
168 | auto tensor = ToTensor(const_value); |
169 | tensor.add_dims(1); |
170 | return Add(constant_op.c_str(), MakeAttribute("value" , tensor)); |
171 | } |
172 | |
173 | // Creates a 1D tensor constant consisting of zero or more values. |
174 | template <typename T> |
175 | FunctionBuilder& Const(const std::string& name, const std::vector<T>& values) { |
176 | std::string constant_op(name); |
177 | constant_op += " = Constant()" ; |
178 | auto tensor = ToTensor(values); |
179 | tensor.add_dims(values.size()); // Treat as 1D tensor. |
180 | |
181 | return Add(constant_op.c_str(), MakeAttribute("value" , tensor)); |
182 | } |
183 | |
184 | FunctionBuilder& AddOpset(const char* domain, int version) { |
185 | auto* opset = funProto.add_opset_import(); |
186 | opset->set_domain(domain); |
187 | opset->set_version(version); |
188 | return *this; |
189 | } |
190 | |
191 | private: |
192 | FunctionProto& funProto; |
193 | }; |
194 | |
195 | } // namespace ONNX_NAMESPACE |
196 | |