1/*
2 * SPDX-License-Identifier: Apache-2.0
3 */
4
5#include "onnx/defs/function.h"
6#include "onnx/defs/schema.h"
7#include "onnx/string_utils.h"
8
9namespace ONNX_NAMESPACE {
// Builds the name given to a tensor that is internal to an expanded function
// body: the literal prefix "Func_", followed by `node_name`, followed by
// `internal_name`, concatenated with no separator.
std::string InteralTensorNameGenerator(const std::string& node_name, const std::string& internal_name) {
  return "Func_" + node_name + internal_name;
}
14
15void FunctionExpandHelper(
16 const NodeProto& node,
17 const FunctionProto& func,
18 GraphProto& g,
19 const std::string& node_prefix) {
20 // Create a temporary unique node prefix for tensor names
21 std::string uniq_prefix = node_prefix;
22 if (uniq_prefix.empty()) {
23 const void* address = static_cast<const void*>(&node);
24 std::stringstream ss;
25 ss << address;
26 uniq_prefix = ss.str();
27 }
28 std::string node_name = node.has_name() ? node.name() : func.name() + uniq_prefix;
29 std::unordered_map<std::string, std::string> io_names_map;
30 std::unordered_map<std::string, AttributeProto> attr_map;
31
32 for (int idx = 0; idx < node.input_size(); ++idx) {
33 if (idx >= func.input_size()) {
34 ONNX_THROW("Input for function node " + node_name + " is out of bounds");
35 }
36 io_names_map[func.input().Get(idx)] = node.input().Get(idx);
37 }
38 for (int idx = 0; idx < node.output_size(); ++idx) {
39 if (idx >= func.output_size()) {
40 ONNX_THROW("Output for function node " + node_name + " is out of bounds");
41 }
42 // If the node output is missing, the corresponding function output should
43 // be treated as an internal value (not as missing) because it could also be
44 // an intermediate value.
45 if (node.output().Get(idx) == "") {
46 continue;
47 }
48 io_names_map[func.output().Get(idx)] = node.output().Get(idx);
49 }
50
51 for (auto& attr : node.attribute()) {
52 attr_map[attr.name()] = attr;
53 }
54
55 // For undefined attributes of the function node
56 // add default values obtained from the function schema.
57 // get the domain version for function schema
58 int domain_version = -1;
59 for (const auto& opset_import : func.opset_import()) {
60 if (opset_import.domain() == node.domain()) {
61 domain_version = static_cast<int>(opset_import.version());
62 }
63 }
64 if (domain_version == -1) {
65 ONNX_THROW("No opset import registered for domain '" + node.domain() + "' in function proto");
66 }
67
68 const OpSchemaRegistry* schema_registry = OpSchemaRegistry::Instance();
69 const auto schema = schema_registry->GetSchema(node.op_type(), domain_version, node.domain());
70 std::map<std::string, OpSchema::Attribute> default_attrs = schema->attributes();
71
72 for (const auto& pair : default_attrs) {
73 const auto& attr_name = pair.first;
74 const auto& attr = pair.second;
75 if (!attr_map.count(attr_name)) {
76 attr_map[attr_name] = attr.default_value;
77 }
78 }
79
80 for (auto& function_node : func.node()) {
81 NodeProto* new_node = g.add_node();
82 new_node->CopyFrom(function_node);
83 new_node->clear_input();
84 new_node->clear_output();
85 new_node->clear_attribute();
86 for (auto& input : function_node.input()) {
87 if (io_names_map.count(input)) {
88 new_node->add_input(io_names_map[input]);
89 } else {
90 new_node->add_input(InteralTensorNameGenerator(node_name, input));
91 }
92 }
93 for (auto& output : function_node.output()) {
94 if (io_names_map.count(output)) {
95 new_node->add_output(io_names_map[output]);
96 } else {
97 new_node->add_output(InteralTensorNameGenerator(node_name, output));
98 }
99 }
100 for (auto& attr : function_node.attribute()) {
101 if (attr.has_ref_attr_name()) {
102 if (attr_map.count(attr.ref_attr_name())) {
103 AttributeProto* new_attr = new_node->add_attribute();
104 new_attr->CopyFrom(attr_map[attr.ref_attr_name()]);
105 new_attr->set_name(attr.name());
106 }
107 } else {
108 AttributeProto* new_attr = new_node->add_attribute();
109 new_attr->CopyFrom(attr);
110 }
111 }
112 }
113}
114
115std::vector<NodeProto> FunctionBodyHelper::BuildNodes(const std::vector<NodeDef>& node_defs) {
116 std::vector<NodeProto> nodes(node_defs.size());
117
118 for (size_t i = 0; i < node_defs.size(); i++) {
119 const NodeDef& node = node_defs[i];
120 NodeProto& n = nodes[i];
121
122 n.set_op_type(node.op_type);
123 n.set_domain(node.domain);
124 for (const auto& i : node.inputs) {
125 n.add_input(i);
126 }
127 for (const auto& o : node.outputs) {
128 n.add_output(o);
129 }
130 for (const auto& attr : node.attributes) {
131 *(n.add_attribute()) = attr.proto;
132 }
133 }
134
135 return nodes;
136}
137
138void FunctionBodyHelper::BuildNodes(FunctionProto& functionProto, const std::vector<NodeDef>& node_defs) {
139 for (size_t i = 0; i < node_defs.size(); i++) {
140 const NodeDef& node = node_defs[i];
141 auto* np = functionProto.add_node();
142
143 np->set_op_type(node.op_type);
144 np->set_domain(node.domain);
145 for (const auto& inp : node.inputs) {
146 np->add_input(inp);
147 }
148 for (const auto& o : node.outputs) {
149 np->add_output(o);
150 }
151 for (const auto& attr : node.attributes) {
152 *(np->add_attribute()) = attr.proto;
153 }
154 }
155}
156
157bool FunctionBodyHelper::BuildFunctionProto(
158 FunctionProto& functionProto,
159 const OpSchema& schema,
160 const std::vector<NodeDef>& node_defs,
161 const std::vector<OperatorSetIdProto>& relied_opsets) {
162 BuildNodes(functionProto, node_defs);
163
164 for (auto& relied_opset : relied_opsets) {
165 *(functionProto.mutable_opset_import()->Add()) = relied_opset;
166 }
167
168 schema.BuildFunction(functionProto);
169 return true;
170}
171
172} // namespace ONNX_NAMESPACE
173