1/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <algorithm>
17#include <unordered_map>
18#include <unordered_set>
19
20#include "absl/strings/match.h"
21#include "tensorflow/c/c_api_internal.h"
22#include "tensorflow/c/tf_buffer_internal.h"
23#include "tensorflow/core/framework/attr_value_util.h"
24#include "tensorflow/core/framework/function.pb.h"
25#include "tensorflow/core/framework/graph_to_functiondef.h"
26#include "tensorflow/core/framework/node_def.pb.h"
27#include "tensorflow/core/framework/node_def_util.h"
28#include "tensorflow/core/framework/tensor.pb.h" // NOLINT
29#include "tensorflow/core/framework/types.h"
30#include "tensorflow/core/graph/graph.h"
31#include "tensorflow/core/platform/base64.h"
32#include "tensorflow/core/platform/strcat.h"
33
34using tensorflow::errors::InvalidArgument;
35
36namespace tensorflow {
37namespace {
38
39Status ValidateNonRefOutput(const Node* node, int idx) {
40 const DataType& dt = node->output_type(idx);
41 return IsRefType(dt)
42 ? InvalidArgument("Output ", idx, " of node '", node->name(),
43 "' has a reference type ", DataTypeString(dt))
44 : OkStatus();
45}
46
47// Converts `ninputs` and `inputs` into `inputs_tensors` and `input_nodes` and
48// does various checks while doing so. `input_nodes` will contain the same
49// information as input_tensors just in a different structure to make
50// following processing easier. TODO(iga): Simplify this nested structure.
51Status ProcessInputs(
52 const TF_Graph* fn_body, const char* fn_name, int ninputs,
53 const TF_Output* inputs, std::vector<OutputTensor>* input_tensors,
54 std::unordered_map<const Node*, std::vector<int>>* input_nodes)
55 TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
56 input_tensors->reserve(ninputs);
57 for (int i = 0; i < ninputs; ++i) {
58 Node* node = inputs[i].oper ? &inputs[i].oper->node : nullptr;
59 int idx = inputs[i].index;
60
61 TF_RETURN_WITH_CONTEXT_IF_ERROR(
62 fn_body->graph.IsValidOutputTensor(node, idx),
63 "Encountered while processing input ", i, " into function '", fn_name,
64 "'");
65 TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(node, idx),
66 "Encountered while processing input ", i,
67 " into function '", fn_name, "'");
68
69 input_tensors->emplace_back(node, idx);
70
71 const auto& iter = input_nodes->find(node);
72 if (iter == input_nodes->end()) {
73 input_nodes->insert({node, {idx}});
74 } else {
75 auto& indices = iter->second;
76 if (std::find(indices.begin(), indices.end(), idx) != indices.end()) {
77 return InvalidArgument("TF_Output ", node->name(), ":", idx,
78 " appears more than once in the input list");
79 }
80 indices.push_back(idx);
81 }
82 }
83 return OkStatus();
84}
85
86// Converts `noutputs` and `outputs` into `outputs_tensors` and does various
87// checks while doing so.
88Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name,
89 int noutputs, const TF_Output* outputs,
90 std::vector<OutputTensor>* output_tensors)
91 TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
92 output_tensors->reserve(noutputs);
93 for (int i = 0; i < noutputs; ++i) {
94 Node* node = outputs[i].oper ? &outputs[i].oper->node : nullptr;
95 int idx = outputs[i].index;
96 TF_RETURN_WITH_CONTEXT_IF_ERROR(
97 fn_body->graph.IsValidOutputTensor(node, idx),
98 "Encountered while processing output ", i, " from function '", fn_name,
99 "'");
100 TF_RETURN_WITH_CONTEXT_IF_ERROR(ValidateNonRefOutput(node, idx),
101 "Encountered while creating function '",
102 fn_name, "'");
103 output_tensors->emplace_back(node, idx);
104 }
105 return OkStatus();
106}
107
108// Populates `body_nodes` with the nodes that will become function's body.
109// Performs various checks.
110Status ComputeBodyNodes(
111 const TF_Graph* fn_body, const char* fn_name, int num_opers,
112 const TF_Operation* const* opers,
113 const std::unordered_map<const Node*, std::vector<int>>& input_nodes,
114 std::vector<const Node*>* body_nodes)
115 TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
116 if (num_opers == -1) {
117 for (const Node* node : fn_body->graph.op_nodes()) {
118 const auto& iter = input_nodes.find(node);
119 if (iter == input_nodes.end()) {
120 // This node is not referenced in inputs. Add it to the body.
121 body_nodes->push_back(node);
122 } else {
123 // This node is referenced in inputs. Currently, we place an
124 // artificial restriction and require that when num_opers=-1, such
125 // nodes must have a single output.
126 if (node->num_outputs() != 1) {
127 return InvalidArgument(
128 "When `num_opers` is set to -1, nodes referenced in `inputs` "
129 "must have a single output. Node ",
130 node->name(), " has ", node->num_outputs(),
131 " outputs. Encountered while creating function '", fn_name, "'");
132 }
133 }
134 }
135 } else {
136 body_nodes->reserve(num_opers);
137 for (int i = 0; i < num_opers; ++i) {
138 const Node* node = &opers[i]->node;
139 body_nodes->push_back(node);
140 }
141 }
142 return OkStatus();
143}
144
145} // namespace
146} // namespace tensorflow
147
148using tensorflow::Node;
149using tensorflow::string;
150
151TF_Function* TF_GraphToFunctionWithControlOutputs(
152 const TF_Graph* fn_body, const char* fn_name,
153 unsigned char append_hash_to_fn_name, int num_opers,
154 const TF_Operation* const* opers, int ninputs, const TF_Output* inputs,
155 int noutputs, const TF_Output* outputs, const char* const* output_names,
156 int ncontrol_outputs, const TF_Operation* const* control_outputs,
157 const char* const* control_output_names, const TF_FunctionOptions* opts,
158 const char* description, TF_Status* status) {
159 tensorflow::mutex_lock l(fn_body->mu);
160
161 // Process inputs.
162 std::vector<tensorflow::OutputTensor> input_tensors;
163 std::unordered_map<const Node*, std::vector<int>> input_nodes;
164 status->status = tensorflow::ProcessInputs(fn_body, fn_name, ninputs, inputs,
165 &input_tensors, &input_nodes);
166 if (TF_GetCode(status) != TF_OK) return nullptr;
167
168 // Process outputs.
169 std::vector<tensorflow::OutputTensor> output_tensors;
170 status->status = tensorflow::ProcessOutputs(fn_body, fn_name, noutputs,
171 outputs, &output_tensors);
172 if (TF_GetCode(status) != TF_OK) return nullptr;
173
174 // Process output names.
175 std::vector<string> output_names_vec;
176 if (output_names) {
177 output_names_vec.reserve(noutputs);
178 for (int i = 0; i < noutputs; ++i) {
179 output_names_vec.push_back(string(output_names[i]));
180 }
181 }
182
183 // Process control output names.
184 std::vector<string> control_output_names_vec;
185 if (control_output_names) {
186 control_output_names_vec.reserve(ncontrol_outputs);
187 for (int i = 0; i < ncontrol_outputs; ++i) {
188 control_output_names_vec.push_back(string(output_names[i]));
189 }
190 }
191
192 // Compute body nodes.
193 std::vector<const Node*> body_nodes;
194 status->status = tensorflow::ComputeBodyNodes(
195 fn_body, fn_name, num_opers, opers, input_nodes, &body_nodes);
196 if (TF_GetCode(status) != TF_OK) return nullptr;
197
198 // Compute body nodes.
199 std::vector<const Node*> control_output_nodes;
200 control_output_nodes.reserve(ncontrol_outputs);
201 for (int i = 0; i < ncontrol_outputs; ++i) {
202 control_output_nodes.push_back(&control_outputs[i]->node);
203 }
204
205 // Do the actual function creation.
206 TF_Function* tf_function = new TF_Function();
207 DCHECK(append_hash_to_fn_name <= 1);
208 status->status = tensorflow::GraphToFunctionDef(
209 fn_body->graph, fn_name, append_hash_to_fn_name != 0,
210 /*set_stateful_from_nodes=*/true,
211 /*copy_placeholder_attrs_from_nodes=*/true, body_nodes, input_tensors,
212 output_tensors, output_names_vec, control_output_nodes,
213 control_output_names_vec, description, &tf_function->fdef);
214 if (TF_GetCode(status) != TF_OK) {
215 TF_DeleteFunction(tf_function);
216 return nullptr;
217 }
218
219 for (const Node* n : fn_body->graph.nodes()) {
220 tf_function->stack_traces[n->name()] = n->GetStackTrace();
221 }
222
223 return tf_function;
224}
225
226TF_Function* TF_GraphToFunction(const TF_Graph* fn_body, const char* fn_name,
227 unsigned char append_hash_to_fn_name,
228 int num_opers, const TF_Operation* const* opers,
229 int ninputs, const TF_Output* inputs,
230 int noutputs, const TF_Output* outputs,
231 const char* const* output_names,
232 const TF_FunctionOptions* opts,
233 const char* description, TF_Status* status) {
234 return TF_GraphToFunctionWithControlOutputs(
235 fn_body, fn_name, append_hash_to_fn_name, num_opers, opers, ninputs,
236 inputs, noutputs, outputs, output_names, 0, nullptr, nullptr, opts,
237 description, status);
238}
239
240const char* TF_FunctionName(TF_Function* func) {
241 return func->fdef.signature().name().c_str();
242}
243
244void TF_GraphCopyFunction(TF_Graph* g, const TF_Function* func,
245 const TF_Function* grad, TF_Status* status) {
246 if (func == nullptr) {
247 status->status = InvalidArgument(
248 "'func' argument to TF_GraphCopyFunction cannot be null");
249 return;
250 }
251
252 // TODO(iga): Add AddFunctionDef() and AddGradientDef() methods to graph
253 // to avoid the extra copy here.
254 tensorflow::FunctionDefLibrary fdef_lib;
255 *fdef_lib.add_function() = func->fdef;
256 if (grad) {
257 *fdef_lib.add_function() = grad->fdef;
258 tensorflow::GradientDef* gdef = fdef_lib.add_gradient();
259 gdef->set_function_name(func->fdef.signature().name());
260 gdef->set_gradient_func(grad->fdef.signature().name());
261 }
262
263 tensorflow::mutex_lock l(g->mu);
264 status->status = g->graph.AddFunctionLibrary(fdef_lib);
265}
266
267int TF_GraphNumFunctions(TF_Graph* g) {
268 tensorflow::mutex_lock l(g->mu);
269 return g->graph.flib_def().num_functions();
270}
271
272int TF_GraphGetFunctions(TF_Graph* g, TF_Function** funcs, int max_func,
273 TF_Status* status) {
274 tensorflow::FunctionDefLibrary lib;
275 {
276 tensorflow::mutex_lock l(g->mu);
277 lib = g->graph.flib_def().ToProto();
278 }
279 const auto len = std::min(max_func, static_cast<int>(lib.function_size()));
280 for (int i = 0; i < len; ++i) {
281 TF_Function* func = new TF_Function();
282 func->fdef = lib.function(i);
283 funcs[i] = func;
284 }
285 status->status = ::tensorflow::OkStatus();
286 return len;
287}
288
289void TF_FunctionToFunctionDef(TF_Function* func, TF_Buffer* output_func_def,
290 TF_Status* status) {
291 status->status = MessageToBuffer(func->fdef, output_func_def);
292}
293
294TF_Function* TF_FunctionImportFunctionDef(const void* proto, size_t proto_len,
295 TF_Status* status) {
296 TF_Function* func = new TF_Function();
297 if (!func->fdef.ParseFromArray(proto, proto_len)) {
298 status->status = InvalidArgument(
299 "Invalid FunctionDef given to TF_FunctionImportFunctionDef");
300 TF_DeleteFunction(func);
301 return nullptr;
302 }
303 status->status = ::tensorflow::OkStatus();
304 return func;
305}
306
307void TF_FunctionSetAttrValueProto(TF_Function* func, const char* attr_name,
308 const void* proto, size_t proto_len,
309 TF_Status* status) {
310 tensorflow::AttrValue attr_value;
311 if (!attr_value.ParseFromArray(proto, proto_len)) {
312 status->status = InvalidArgument(
313 "Unparseable AttrValue proto passed to "
314 "TF_FunctionSetAttrValueProto");
315 return;
316 }
317 (*func->fdef.mutable_attr())[string(attr_name)] = attr_value;
318 status->status = ::tensorflow::OkStatus();
319}
320
321void TF_FunctionGetAttrValueProto(TF_Function* func, const char* attr_name,
322 TF_Buffer* output_attr_value,
323 TF_Status* status) {
324 const auto& it = func->fdef.attr().find(attr_name);
325 if (it == func->fdef.attr().end()) {
326 status->status =
327 InvalidArgument("Function '", func->fdef.signature().name(),
328 "' has no attr named '", attr_name, "'.");
329 return;
330 }
331 status->status = MessageToBuffer(it->second, output_attr_value);
332}
333
334void TF_DeleteFunction(TF_Function* func) { delete func; }
335