1 | /* |
2 | * SPDX-License-Identifier: Apache-2.0 |
3 | */ |
4 | |
5 | #include "onnx/defs/function.h" |
6 | #include "onnx/defs/schema.h" |
7 | #include "onnx/string_utils.h" |
8 | |
9 | namespace ONNX_NAMESPACE { |
// Builds the graph-level name of a function-internal tensor produced when a
// function call is expanded: the literal prefix "Func_", then the name of the
// calling node, then the tensor's name inside the function body.
// (The misspelling "Interal" is part of the public name and is kept.)
std::string InteralTensorNameGenerator(const std::string& node_name, const std::string& internal_name) {
  std::string scoped_name("Func_");
  scoped_name.append(node_name);
  scoped_name.append(internal_name);
  return scoped_name;
}
14 | |
15 | void FunctionExpandHelper( |
16 | const NodeProto& node, |
17 | const FunctionProto& func, |
18 | GraphProto& g, |
19 | const std::string& node_prefix) { |
20 | // Create a temporary unique node prefix for tensor names |
21 | std::string uniq_prefix = node_prefix; |
22 | if (uniq_prefix.empty()) { |
23 | const void* address = static_cast<const void*>(&node); |
24 | std::stringstream ss; |
25 | ss << address; |
26 | uniq_prefix = ss.str(); |
27 | } |
28 | std::string node_name = node.has_name() ? node.name() : func.name() + uniq_prefix; |
29 | std::unordered_map<std::string, std::string> io_names_map; |
30 | std::unordered_map<std::string, AttributeProto> attr_map; |
31 | |
32 | for (int idx = 0; idx < node.input_size(); ++idx) { |
33 | if (idx >= func.input_size()) { |
34 | ONNX_THROW("Input for function node " + node_name + " is out of bounds" ); |
35 | } |
36 | io_names_map[func.input().Get(idx)] = node.input().Get(idx); |
37 | } |
38 | for (int idx = 0; idx < node.output_size(); ++idx) { |
39 | if (idx >= func.output_size()) { |
40 | ONNX_THROW("Output for function node " + node_name + " is out of bounds" ); |
41 | } |
42 | // If the node output is missing, the corresponding function output should |
43 | // be treated as an internal value (not as missing) because it could also be |
44 | // an intermediate value. |
45 | if (node.output().Get(idx) == "" ) { |
46 | continue; |
47 | } |
48 | io_names_map[func.output().Get(idx)] = node.output().Get(idx); |
49 | } |
50 | |
51 | for (auto& attr : node.attribute()) { |
52 | attr_map[attr.name()] = attr; |
53 | } |
54 | |
55 | // For undefined attributes of the function node |
56 | // add default values obtained from the function schema. |
57 | // get the domain version for function schema |
58 | int domain_version = -1; |
59 | for (const auto& opset_import : func.opset_import()) { |
60 | if (opset_import.domain() == node.domain()) { |
61 | domain_version = static_cast<int>(opset_import.version()); |
62 | } |
63 | } |
64 | if (domain_version == -1) { |
65 | ONNX_THROW("No opset import registered for domain '" + node.domain() + "' in function proto" ); |
66 | } |
67 | |
68 | const OpSchemaRegistry* schema_registry = OpSchemaRegistry::Instance(); |
69 | const auto schema = schema_registry->GetSchema(node.op_type(), domain_version, node.domain()); |
70 | std::map<std::string, OpSchema::Attribute> default_attrs = schema->attributes(); |
71 | |
72 | for (const auto& pair : default_attrs) { |
73 | const auto& attr_name = pair.first; |
74 | const auto& attr = pair.second; |
75 | if (!attr_map.count(attr_name)) { |
76 | attr_map[attr_name] = attr.default_value; |
77 | } |
78 | } |
79 | |
80 | for (auto& function_node : func.node()) { |
81 | NodeProto* new_node = g.add_node(); |
82 | new_node->CopyFrom(function_node); |
83 | new_node->clear_input(); |
84 | new_node->clear_output(); |
85 | new_node->clear_attribute(); |
86 | for (auto& input : function_node.input()) { |
87 | if (io_names_map.count(input)) { |
88 | new_node->add_input(io_names_map[input]); |
89 | } else { |
90 | new_node->add_input(InteralTensorNameGenerator(node_name, input)); |
91 | } |
92 | } |
93 | for (auto& output : function_node.output()) { |
94 | if (io_names_map.count(output)) { |
95 | new_node->add_output(io_names_map[output]); |
96 | } else { |
97 | new_node->add_output(InteralTensorNameGenerator(node_name, output)); |
98 | } |
99 | } |
100 | for (auto& attr : function_node.attribute()) { |
101 | if (attr.has_ref_attr_name()) { |
102 | if (attr_map.count(attr.ref_attr_name())) { |
103 | AttributeProto* new_attr = new_node->add_attribute(); |
104 | new_attr->CopyFrom(attr_map[attr.ref_attr_name()]); |
105 | new_attr->set_name(attr.name()); |
106 | } |
107 | } else { |
108 | AttributeProto* new_attr = new_node->add_attribute(); |
109 | new_attr->CopyFrom(attr); |
110 | } |
111 | } |
112 | } |
113 | } |
114 | |
115 | std::vector<NodeProto> FunctionBodyHelper::BuildNodes(const std::vector<NodeDef>& node_defs) { |
116 | std::vector<NodeProto> nodes(node_defs.size()); |
117 | |
118 | for (size_t i = 0; i < node_defs.size(); i++) { |
119 | const NodeDef& node = node_defs[i]; |
120 | NodeProto& n = nodes[i]; |
121 | |
122 | n.set_op_type(node.op_type); |
123 | n.set_domain(node.domain); |
124 | for (const auto& i : node.inputs) { |
125 | n.add_input(i); |
126 | } |
127 | for (const auto& o : node.outputs) { |
128 | n.add_output(o); |
129 | } |
130 | for (const auto& attr : node.attributes) { |
131 | *(n.add_attribute()) = attr.proto; |
132 | } |
133 | } |
134 | |
135 | return nodes; |
136 | } |
137 | |
138 | void FunctionBodyHelper::BuildNodes(FunctionProto& functionProto, const std::vector<NodeDef>& node_defs) { |
139 | for (size_t i = 0; i < node_defs.size(); i++) { |
140 | const NodeDef& node = node_defs[i]; |
141 | auto* np = functionProto.add_node(); |
142 | |
143 | np->set_op_type(node.op_type); |
144 | np->set_domain(node.domain); |
145 | for (const auto& inp : node.inputs) { |
146 | np->add_input(inp); |
147 | } |
148 | for (const auto& o : node.outputs) { |
149 | np->add_output(o); |
150 | } |
151 | for (const auto& attr : node.attributes) { |
152 | *(np->add_attribute()) = attr.proto; |
153 | } |
154 | } |
155 | } |
156 | |
157 | bool FunctionBodyHelper::BuildFunctionProto( |
158 | FunctionProto& functionProto, |
159 | const OpSchema& schema, |
160 | const std::vector<NodeDef>& node_defs, |
161 | const std::vector<OperatorSetIdProto>& relied_opsets) { |
162 | BuildNodes(functionProto, node_defs); |
163 | |
164 | for (auto& relied_opset : relied_opsets) { |
165 | *(functionProto.mutable_opset_import()->Add()) = relied_opset; |
166 | } |
167 | |
168 | schema.BuildFunction(functionProto); |
169 | return true; |
170 | } |
171 | |
172 | } // namespace ONNX_NAMESPACE |
173 | |