1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <algorithm> |
17 | #include <unordered_map> |
18 | #include <unordered_set> |
19 | |
20 | #include "absl/strings/match.h" |
21 | #include "tensorflow/c/c_api_internal.h" |
22 | #include "tensorflow/c/tf_buffer_internal.h" |
23 | #include "tensorflow/core/framework/attr_value_util.h" |
24 | #include "tensorflow/core/framework/function.pb.h" |
25 | #include "tensorflow/core/framework/graph_to_functiondef.h" |
26 | #include "tensorflow/core/framework/node_def.pb.h" |
27 | #include "tensorflow/core/framework/node_def_util.h" |
28 | #include "tensorflow/core/framework/tensor.pb.h" // NOLINT |
29 | #include "tensorflow/core/framework/types.h" |
30 | #include "tensorflow/core/graph/graph.h" |
31 | #include "tensorflow/core/platform/base64.h" |
32 | #include "tensorflow/core/platform/strcat.h" |
33 | |
34 | using tensorflow::errors::InvalidArgument; |
35 | |
36 | namespace tensorflow { |
37 | namespace { |
38 | |
39 | Status ValidateNonRefOutput(const Node* node, int idx) { |
40 | const DataType& dt = node->output_type(idx); |
41 | return IsRefType(dt) |
42 | ? InvalidArgument("Output " , idx, " of node '" , node->name(), |
43 | "' has a reference type " , DataTypeString(dt)) |
44 | : OkStatus(); |
45 | } |
46 | |
47 | // Converts `ninputs` and `inputs` into `inputs_tensors` and `input_nodes` and |
48 | // does various checks while doing so. `input_nodes` will contain the same |
49 | // information as input_tensors just in a different structure to make |
50 | // following processing easier. TODO(iga): Simplify this nested structure. |
51 | Status ProcessInputs( |
52 | const TF_Graph* fn_body, const char* fn_name, int ninputs, |
53 | const TF_Output* inputs, std::vector<OutputTensor>* input_tensors, |
54 | std::unordered_map<const Node*, std::vector<int>>* input_nodes) |
55 | TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) { |
56 | input_tensors->reserve(ninputs); |
57 | for (int i = 0; i < ninputs; ++i) { |
58 | Node* node = inputs[i].oper ? &inputs[i].oper->node : nullptr; |
59 | int idx = inputs[i].index; |
60 | |
61 | TF_RETURN_WITH_CONTEXT_IF_ERROR( |
62 | fn_body->graph.IsValidOutputTensor(node, idx), |
63 | "Encountered while processing input " , i, " into function '" , fn_name, |
64 | "'" ); |
65 | TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(node, idx), |
66 | "Encountered while processing input " , i, |
67 | " into function '" , fn_name, "'" ); |
68 | |
69 | input_tensors->emplace_back(node, idx); |
70 | |
71 | const auto& iter = input_nodes->find(node); |
72 | if (iter == input_nodes->end()) { |
73 | input_nodes->insert({node, {idx}}); |
74 | } else { |
75 | auto& indices = iter->second; |
76 | if (std::find(indices.begin(), indices.end(), idx) != indices.end()) { |
77 | return InvalidArgument("TF_Output " , node->name(), ":" , idx, |
78 | " appears more than once in the input list" ); |
79 | } |
80 | indices.push_back(idx); |
81 | } |
82 | } |
83 | return OkStatus(); |
84 | } |
85 | |
86 | // Converts `noutputs` and `outputs` into `outputs_tensors` and does various |
87 | // checks while doing so. |
88 | Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name, |
89 | int noutputs, const TF_Output* outputs, |
90 | std::vector<OutputTensor>* output_tensors) |
91 | TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) { |
92 | output_tensors->reserve(noutputs); |
93 | for (int i = 0; i < noutputs; ++i) { |
94 | Node* node = outputs[i].oper ? &outputs[i].oper->node : nullptr; |
95 | int idx = outputs[i].index; |
96 | TF_RETURN_WITH_CONTEXT_IF_ERROR( |
97 | fn_body->graph.IsValidOutputTensor(node, idx), |
98 | "Encountered while processing output " , i, " from function '" , fn_name, |
99 | "'" ); |
100 | TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(node, idx), |
101 | "Encountered while creating function '" , |
102 | fn_name, "'" ); |
103 | output_tensors->emplace_back(node, idx); |
104 | } |
105 | return OkStatus(); |
106 | } |
107 | |
108 | // Populates `body_nodes` with the nodes that will become function's body. |
109 | // Performs various checks. |
110 | Status ComputeBodyNodes( |
111 | const TF_Graph* fn_body, const char* fn_name, int num_opers, |
112 | const TF_Operation* const* opers, |
113 | const std::unordered_map<const Node*, std::vector<int>>& input_nodes, |
114 | std::vector<const Node*>* body_nodes) |
115 | TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) { |
116 | if (num_opers == -1) { |
117 | for (const Node* node : fn_body->graph.op_nodes()) { |
118 | const auto& iter = input_nodes.find(node); |
119 | if (iter == input_nodes.end()) { |
120 | // This node is not referenced in inputs. Add it to the body. |
121 | body_nodes->push_back(node); |
122 | } else { |
123 | // This node is referenced in inputs. Currently, we place an |
124 | // artificial restriction and require that when num_opers=-1, such |
125 | // nodes must have a single output. |
126 | if (node->num_outputs() != 1) { |
127 | return InvalidArgument( |
128 | "When `num_opers` is set to -1, nodes referenced in `inputs` " |
129 | "must have a single output. Node " , |
130 | node->name(), " has " , node->num_outputs(), |
131 | " outputs. Encountered while creating function '" , fn_name, "'" ); |
132 | } |
133 | } |
134 | } |
135 | } else { |
136 | body_nodes->reserve(num_opers); |
137 | for (int i = 0; i < num_opers; ++i) { |
138 | const Node* node = &opers[i]->node; |
139 | body_nodes->push_back(node); |
140 | } |
141 | } |
142 | return OkStatus(); |
143 | } |
144 | |
145 | } // namespace |
146 | } // namespace tensorflow |
147 | |
148 | using tensorflow::Node; |
149 | using tensorflow::string; |
150 | |
151 | TF_Function* TF_GraphToFunctionWithControlOutputs( |
152 | const TF_Graph* fn_body, const char* fn_name, |
153 | unsigned char append_hash_to_fn_name, int num_opers, |
154 | const TF_Operation* const* opers, int ninputs, const TF_Output* inputs, |
155 | int noutputs, const TF_Output* outputs, const char* const* output_names, |
156 | int ncontrol_outputs, const TF_Operation* const* control_outputs, |
157 | const char* const* control_output_names, const TF_FunctionOptions* opts, |
158 | const char* description, TF_Status* status) { |
159 | tensorflow::mutex_lock l(fn_body->mu); |
160 | |
161 | // Process inputs. |
162 | std::vector<tensorflow::OutputTensor> input_tensors; |
163 | std::unordered_map<const Node*, std::vector<int>> input_nodes; |
164 | status->status = tensorflow::ProcessInputs(fn_body, fn_name, ninputs, inputs, |
165 | &input_tensors, &input_nodes); |
166 | if (TF_GetCode(status) != TF_OK) return nullptr; |
167 | |
168 | // Process outputs. |
169 | std::vector<tensorflow::OutputTensor> output_tensors; |
170 | status->status = tensorflow::ProcessOutputs(fn_body, fn_name, noutputs, |
171 | outputs, &output_tensors); |
172 | if (TF_GetCode(status) != TF_OK) return nullptr; |
173 | |
174 | // Process output names. |
175 | std::vector<string> output_names_vec; |
176 | if (output_names) { |
177 | output_names_vec.reserve(noutputs); |
178 | for (int i = 0; i < noutputs; ++i) { |
179 | output_names_vec.push_back(string(output_names[i])); |
180 | } |
181 | } |
182 | |
183 | // Process control output names. |
184 | std::vector<string> control_output_names_vec; |
185 | if (control_output_names) { |
186 | control_output_names_vec.reserve(ncontrol_outputs); |
187 | for (int i = 0; i < ncontrol_outputs; ++i) { |
188 | control_output_names_vec.push_back(string(output_names[i])); |
189 | } |
190 | } |
191 | |
192 | // Compute body nodes. |
193 | std::vector<const Node*> body_nodes; |
194 | status->status = tensorflow::ComputeBodyNodes( |
195 | fn_body, fn_name, num_opers, opers, input_nodes, &body_nodes); |
196 | if (TF_GetCode(status) != TF_OK) return nullptr; |
197 | |
198 | // Compute body nodes. |
199 | std::vector<const Node*> control_output_nodes; |
200 | control_output_nodes.reserve(ncontrol_outputs); |
201 | for (int i = 0; i < ncontrol_outputs; ++i) { |
202 | control_output_nodes.push_back(&control_outputs[i]->node); |
203 | } |
204 | |
205 | // Do the actual function creation. |
206 | TF_Function* tf_function = new TF_Function(); |
207 | DCHECK(append_hash_to_fn_name <= 1); |
208 | status->status = tensorflow::GraphToFunctionDef( |
209 | fn_body->graph, fn_name, append_hash_to_fn_name != 0, |
210 | /*set_stateful_from_nodes=*/true, |
211 | /*copy_placeholder_attrs_from_nodes=*/true, body_nodes, input_tensors, |
212 | output_tensors, output_names_vec, control_output_nodes, |
213 | control_output_names_vec, description, &tf_function->fdef); |
214 | if (TF_GetCode(status) != TF_OK) { |
215 | TF_DeleteFunction(tf_function); |
216 | return nullptr; |
217 | } |
218 | |
219 | for (const Node* n : fn_body->graph.nodes()) { |
220 | tf_function->stack_traces[n->name()] = n->GetStackTrace(); |
221 | } |
222 | |
223 | return tf_function; |
224 | } |
225 | |
226 | TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name, |
227 | unsigned char append_hash_to_fn_name, |
228 | int num_opers, const TF_Operation* const* opers, |
229 | int ninputs, const TF_Output* inputs, |
230 | int noutputs, const TF_Output* outputs, |
231 | const char* const* output_names, |
232 | const TF_FunctionOptions* opts, |
233 | const char* description, TF_Status* status) { |
234 | return TF_GraphToFunctionWithControlOutputs( |
235 | fn_body, fn_name, append_hash_to_fn_name, num_opers, opers, ninputs, |
236 | inputs, noutputs, outputs, output_names, 0, nullptr, nullptr, opts, |
237 | description, status); |
238 | } |
239 | |
240 | const char* TF_FunctionName(TF_Function* func) { |
241 | return func->fdef.signature().name().c_str(); |
242 | } |
243 | |
244 | void TF_GraphCopyFunction(TF_Graph* g, const TF_Function* func, |
245 | const TF_Function* grad, TF_Status* status) { |
246 | if (func == nullptr) { |
247 | status->status = InvalidArgument( |
248 | "'func' argument to TF_GraphCopyFunction cannot be null" ); |
249 | return; |
250 | } |
251 | |
252 | // TODO(iga): Add AddFunctionDef() and AddGradientDef() methods to graph |
253 | // to avoid the extra copy here. |
254 | tensorflow::FunctionDefLibrary fdef_lib; |
255 | *fdef_lib.add_function() = func->fdef; |
256 | if (grad) { |
257 | *fdef_lib.add_function() = grad->fdef; |
258 | tensorflow::GradientDef* gdef = fdef_lib.add_gradient(); |
259 | gdef->set_function_name(func->fdef.signature().name()); |
260 | gdef->set_gradient_func(grad->fdef.signature().name()); |
261 | } |
262 | |
263 | tensorflow::mutex_lock l(g->mu); |
264 | status->status = g->graph.AddFunctionLibrary(fdef_lib); |
265 | } |
266 | |
267 | int TF_GraphNumFunctions(TF_Graph* g) { |
268 | tensorflow::mutex_lock l(g->mu); |
269 | return g->graph.flib_def().num_functions(); |
270 | } |
271 | |
272 | int TF_GraphGetFunctions(TF_Graph* g, TF_Function** funcs, int max_func, |
273 | TF_Status* status) { |
274 | tensorflow::FunctionDefLibrary lib; |
275 | { |
276 | tensorflow::mutex_lock l(g->mu); |
277 | lib = g->graph.flib_def().ToProto(); |
278 | } |
279 | const auto len = std::min(max_func, static_cast<int>(lib.function_size())); |
280 | for (int i = 0; i < len; ++i) { |
281 | TF_Function* func = new TF_Function(); |
282 | func->fdef = lib.function(i); |
283 | funcs[i] = func; |
284 | } |
285 | status->status = ::tensorflow::OkStatus(); |
286 | return len; |
287 | } |
288 | |
// Serializes `func`'s FunctionDef proto into `output_func_def`; any
// serialization failure is reported via `status`.
void TF_FunctionToFunctionDef(TF_Function* func, TF_Buffer* output_func_def,
                              TF_Status* status) {
  status->status = MessageToBuffer(func->fdef, output_func_def);
}
293 | |
294 | TF_Function* TF_FunctionImportFunctionDef(const void* proto, size_t proto_len, |
295 | TF_Status* status) { |
296 | TF_Function* func = new TF_Function(); |
297 | if (!func->fdef.ParseFromArray(proto, proto_len)) { |
298 | status->status = InvalidArgument( |
299 | "Invalid FunctionDef given to TF_FunctionImportFunctionDef" ); |
300 | TF_DeleteFunction(func); |
301 | return nullptr; |
302 | } |
303 | status->status = ::tensorflow::OkStatus(); |
304 | return func; |
305 | } |
306 | |
307 | void TF_FunctionSetAttrValueProto(TF_Function* func, const char* attr_name, |
308 | const void* proto, size_t proto_len, |
309 | TF_Status* status) { |
310 | tensorflow::AttrValue attr_value; |
311 | if (!attr_value.ParseFromArray(proto, proto_len)) { |
312 | status->status = InvalidArgument( |
313 | "Unparseable AttrValue proto passed to " |
314 | "TF_FunctionSetAttrValueProto" ); |
315 | return; |
316 | } |
317 | (*func->fdef.mutable_attr())[string(attr_name)] = attr_value; |
318 | status->status = ::tensorflow::OkStatus(); |
319 | } |
320 | |
321 | void TF_FunctionGetAttrValueProto(TF_Function* func, const char* attr_name, |
322 | TF_Buffer* output_attr_value, |
323 | TF_Status* status) { |
324 | const auto& it = func->fdef.attr().find(attr_name); |
325 | if (it == func->fdef.attr().end()) { |
326 | status->status = |
327 | InvalidArgument("Function '" , func->fdef.signature().name(), |
328 | "' has no attr named '" , attr_name, "'." ); |
329 | return; |
330 | } |
331 | status->status = MessageToBuffer(it->second, output_attr_value); |
332 | } |
333 | |
334 | void TF_DeleteFunction(TF_Function* func) { delete func; } |
335 | |