1/*
2 * SPDX-License-Identifier: Apache-2.0
3 */
4
5#pragma once
6
7#include <mutex>
8#include <string>
9#include <unordered_map>
10#include <vector>
11
12#include "attr_proto_util.h"
13#include "onnx/common/constants.h"
14#include "onnx/common/status.h"
15#include "onnx/defs/parser.h"
16#include "onnx/defs/schema.h"
17#include "tensor_proto_util.h"
18
19namespace ONNX_NAMESPACE {
// Helper function to expand a function node given the function proto.
//
//   node        - the node that calls the function.
//   func        - the FunctionProto that defines the function being called.
//   g           - graph passed by non-const reference; the expansion is
//                 presumably written into it (NOTE(review): confirm whether
//                 nodes are appended or existing content is replaced).
//   node_prefix - optional string, empty by default; presumably used to
//                 prefix names created during expansion — TODO confirm
//                 against the definition.
void FunctionExpandHelper(
    const NodeProto& node,
    const FunctionProto& func,
    GraphProto& g,
    const std::string& node_prefix = "");
26
27class FunctionBodyHelper {
28 public:
29 struct AttributeProtoWrapper {
30 AttributeProto proto;
31
32 AttributeProtoWrapper() {}
33
34 AttributeProtoWrapper(const AttributeProto& attr_prot) {
35 proto = attr_prot;
36 }
37
38 template <typename T>
39 AttributeProtoWrapper(const std::string& attr_name, const T& value) {
40 proto = MakeAttribute(attr_name, value);
41 }
42 };
43
44 struct NodeDef {
45 NodeDef(
46 std::vector<std::string> outputs,
47 std::string op_type,
48 std::vector<std::string> inputs,
49 std::vector<AttributeProtoWrapper> attributes = {},
50 std::string domain = "")
51 : outputs(std::move(outputs)),
52 op_type(std::move(op_type)),
53 inputs(std::move(inputs)),
54 attributes(std::move(attributes)),
55 domain(std::move(domain)) {}
56
57 std::vector<std::string> outputs;
58 std::string op_type;
59 std::vector<std::string> inputs;
60 std::vector<AttributeProtoWrapper> attributes;
61 std::string domain;
62 };
63
64 /*
65 BuildNodes() is an utility function for easily define a Function Body.
66
67 To build a simple node:
68 {{"Z"}, "Add", {"X", "Y"}} represents Z = Add(X,Y)
69
70 To build a node with attribute:
71 {{"Y"}, "Concat", {"X1", "X2", "X3"}, {{"axis", 1}}}
72 represents Y = Concat(X1,X2,X3) with axis = 1
73 The attribute type are infered from the attribute value's c++ type
74 Supported value types are
75 int64_t -> int, vector<int64_t> -> ints
76 float -> float, vector<float> -> floats
77 string -> string, vector<string> ->strings
78 For refering an attribute from parent, use:
79 {MakeRefAttribute("axes", AttributeProto::INTS)}}
80
81 To build a node which belongs to a domain other than onnx standard domain:
82 {{"Z"}, "Foo", {"X", "Y"}, "customdomain"} represents Z = customdomain.Foo(X,Y)
83 or
84 {{"Y"}, "Bar", {"X1", "X2", "X3"}, {{"axis", 1}}, "customdomain"}
85 represents Y = customdomain.Bar(X1,X2,X3) with axis = 1
86
87 For more examples, please find the references of this function
88 */
89 static std::vector<NodeProto> BuildNodes(const std::vector<NodeDef>& node_defs);
90
91 static void BuildNodes(FunctionProto& functionProto, const std::vector<NodeDef>& node_defs);
92
93 static bool BuildFunctionProto(
94 FunctionProto& functionProto,
95 const OpSchema& schema,
96 const std::vector<NodeDef>& node_defs,
97 const std::vector<OperatorSetIdProto>& relied_opsets);
98
99 template <typename T>
100 static NodeDef Const(const std::string& name, const T& value) {
101 return NodeDef{{name}, "Constant", {}, {{"value", ToTensor<T>(value)}}};
102 }
103
104 template <typename T>
105 static NodeDef Const(const std::string& name, const std::vector<T>& values) {
106 return NodeDef{{name}, "Constant", {}, {{"value", ToTensor<T>(values)}}};
107 }
108};
109
110class FunctionBuilder {
111 public:
112 FunctionBuilder(FunctionProto& funProto_) : funProto(funProto_) {}
113
114 FunctionBuilder& Add(const char* nodes_txt) {
115 OnnxParser parser(nodes_txt);
116 auto& nodes = *funProto.mutable_node();
117
118 while (!parser.EndOfInput()) {
119 auto status = parser.Parse(*nodes.Add());
120 if (!status.IsOK())
121 ONNX_THROW_EX(std::logic_error("Error parsing node:" + status.ErrorMessage()));
122 }
123
124 return *this;
125 }
126
127 FunctionBuilder& Add(const char* node_txt, const AttributeProto& attr) {
128 OnnxParser parser(node_txt);
129 auto& node = *funProto.add_node();
130 auto status = parser.Parse(node);
131 if (!status.IsOK()) {
132 ONNX_THROW_EX(std::logic_error("Error parsing node:" + status.ErrorMessage()));
133 }
134
135 if (!parser.EndOfInput()) {
136 ONNX_THROW_EX(std::logic_error("Error unexpected extra input in node:" + status.ErrorMessage()));
137 }
138
139 *node.add_attribute() = attr;
140
141 return *this;
142 }
143
144 template <typename T>
145 FunctionBuilder& Add(const char* node_txt, const std::string& attr_name, const T& attr_value) {
146 return Add(node_txt, MakeAttribute(attr_name, attr_value));
147 }
148
149 FunctionBuilder& Const(const std::string& name, const TensorProto& tensor) {
150 std::string constant_op(name);
151 constant_op += " = Constant()";
152 return Add(constant_op.c_str(), MakeAttribute("value", tensor));
153 }
154
155 // Creates a scalar constant (a tensor of rank zero).
156 template <typename T>
157 FunctionBuilder& Const(const std::string& name, T const_value) {
158 std::string constant_op(name);
159 constant_op += " = Constant()";
160 return Add(constant_op.c_str(), MakeAttribute("value", ToTensor(const_value)));
161 }
162
163 // Creates a 1D tensor constant consisting of a single value.
164 template <typename T>
165 FunctionBuilder& Const1D(const std::string& name, T const_value) {
166 std::string constant_op(name);
167 constant_op += " = Constant()";
168 auto tensor = ToTensor(const_value);
169 tensor.add_dims(1);
170 return Add(constant_op.c_str(), MakeAttribute("value", tensor));
171 }
172
173 // Creates a 1D tensor constant consisting of zero or more values.
174 template <typename T>
175 FunctionBuilder& Const(const std::string& name, const std::vector<T>& values) {
176 std::string constant_op(name);
177 constant_op += " = Constant()";
178 auto tensor = ToTensor(values);
179 tensor.add_dims(values.size()); // Treat as 1D tensor.
180
181 return Add(constant_op.c_str(), MakeAttribute("value", tensor));
182 }
183
184 FunctionBuilder& AddOpset(const char* domain, int version) {
185 auto* opset = funProto.add_opset_import();
186 opset->set_domain(domain);
187 opset->set_version(version);
188 return *this;
189 }
190
191 private:
192 FunctionProto& funProto;
193};
194
195} // namespace ONNX_NAMESPACE
196