1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
#include "tensorflow/core/framework/function.h"

#include <ctype.h>

#include <limits>
#include <map>
#include <unordered_map>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_set.h"
#include "absl/strings/escaping.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_join.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/common_shape_fns.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/lib/strings/proto_serialization.h"
#include "tensorflow/core/platform/fingerprint.h"
#include "tensorflow/core/util/device_name_utils.h"
#include "tensorflow/core/util/equal_graph_def.h"
44 | |
45 | namespace tensorflow { |
46 | |
// Out-of-line definitions of the string-constant static members declared in
// FunctionLibraryDefinition. Before C++17 (where static constexpr data members
// became implicitly inline) such definitions are required to give the
// constants storage, so that they may be ODR-used (e.g. bound to a reference
// or have their address taken). The initializers live in the class
// declaration, which is why none appear here.
/* static */ constexpr const char* const FunctionLibraryDefinition::kArgOp;
/* static */ constexpr const char* const
    FunctionLibraryDefinition::kDeviceArgOp;
/* static */ constexpr const char* const FunctionLibraryDefinition::kRetOp;
/* static */ constexpr const char* const
    FunctionLibraryDefinition::kDeviceRetOp;
/* static */ constexpr const char* const
    FunctionLibraryDefinition::kIntsOnDeviceAttr;
/* static */ constexpr const char* const FunctionLibraryDefinition::kGradientOp;
/* static */ constexpr const char* const FunctionLibraryDefinition::kFuncAttr;
57 | |
58 | // Extracts the actual type from "attr_values" based on its definition |
59 | // "arg_def". |
60 | // |
61 | // If "arg_def" is a N*T type, *is_type_list is set to false, and |
62 | // *dtypes is set to be a vector of size N and each element is T. |
63 | // |
64 | // If "arg_def" is a list(type), *is_type_list is set to true, and |
65 | // *dtypes is set to be a vector of types specified in attrs for |
66 | // arg_def. |
67 | // |
68 | // Otherwise (arg_def is a simple type T), *is_type_list is set to |
69 | // false, and *dtypes is set to a single element vector, whose only |
70 | // element is T. |
71 | Status ArgNumType(AttrSlice attrs, const OpDef::ArgDef& arg_def, |
72 | bool* is_type_list, DataTypeVector* dtypes) { |
73 | dtypes->clear(); |
74 | if (!arg_def.type_list_attr().empty()) { |
75 | const AttrValue* v = attrs.FindByString(arg_def.type_list_attr()); |
76 | if (v == nullptr) { |
77 | return errors::NotFound("type list attr not found: " , |
78 | arg_def.type_list_attr()); |
79 | } |
80 | *is_type_list = true; |
81 | for (int i = 0; i < v->list().type_size(); ++i) { |
82 | dtypes->push_back(v->list().type(i)); |
83 | } |
84 | return OkStatus(); |
85 | } |
86 | |
87 | *is_type_list = false; |
88 | int num = 1; |
89 | if (!arg_def.number_attr().empty()) { |
90 | const AttrValue* v = attrs.FindByString(arg_def.number_attr()); |
91 | if (v == nullptr) { |
92 | return errors::NotFound("number attr not found: " , arg_def.number_attr()); |
93 | } |
94 | num = v->i(); |
95 | } |
96 | |
97 | DataType dtype; |
98 | if (arg_def.type() != DT_INVALID) { |
99 | dtype = arg_def.type(); |
100 | } else if (arg_def.type_attr().empty()) { |
101 | dtype = DT_INVALID; |
102 | } else { |
103 | const AttrValue* v = attrs.FindByString(arg_def.type_attr()); |
104 | if (v == nullptr) { |
105 | return errors::NotFound("type attr not found: " , arg_def.type_attr()); |
106 | } |
107 | dtype = v->type(); |
108 | } |
109 | dtypes->resize(num, dtype); |
110 | return OkStatus(); |
111 | } |
112 | |
113 | namespace { |
114 | |
115 | template <typename T> |
116 | void AddAttr(const string& name, const T& val, NodeDef* ndef) { |
117 | SetAttrValue(val, &((*ndef->mutable_attr())[name])); |
118 | } |
119 | |
120 | Status ValidateSignatureWithAttrs(const OpDef& sig, AttrSlice attr_values) { |
121 | // attr_values should specify all attrs defined in fdef, except for those |
122 | // which have a default value |
123 | for (const auto& attr : sig.attr()) { |
124 | const AttrValue* attr_value = attr_values.FindByString(attr.name()); |
125 | if (attr_value) { |
126 | Status status = AttrValueHasType(*attr_value, attr.type()); |
127 | if (!status.ok()) { |
128 | errors::AppendToMessage(&status, "for attr '" , attr.name(), "'" ); |
129 | return status; |
130 | } |
131 | } else if (!attr.has_default_value()) { |
132 | return errors::NotFound("Attr " , attr.name(), " is not found from " , |
133 | SummarizeOpDef(sig)); |
134 | } |
135 | } |
136 | |
137 | // TODO(josh11b): Enable this code once it works with function gradients. |
138 | // Right now the C++ function gradient code assumes it can pass |
139 | // all the attrs of the function to the gradient, and any attrs that |
140 | // the gradient doesn't care about will be ignored. |
141 | #if 0 |
142 | if (attr_values.size() != sig.attr_size()) { |
143 | for (const auto& a : attr_values) { |
144 | // TODO(josh11b): Possibly should ignore attrs that start with "_" here? |
145 | bool found = false; |
146 | for (const auto& s : sig.attr()) { |
147 | if (a.first == s.name()) { |
148 | found = true; |
149 | break; |
150 | } |
151 | } |
152 | if (!found) { |
153 | return errors::NotFound("Attr " , a.first, " is not found in " , |
154 | SummarizeOpDef(sig)); |
155 | } |
156 | } |
157 | } |
158 | #endif |
159 | |
160 | return OkStatus(); |
161 | } |
162 | |
163 | // A helper class for instantiating functions. This contains shared information |
164 | // like the resulting graph and node name index. |
165 | class FunctionInstantiationHelper { |
166 | public: |
167 | FunctionInstantiationHelper(GetFunctionSignature get_function, |
168 | InstantiationResult* result) |
169 | : get_function_(std ::move(get_function)), result_(*result) { |
170 | result_.nodes.clear(); |
171 | } |
172 | |
173 | // Builds index for nodes that can be used as node's input arguments. |
174 | // `resource_arg_unique_id`: if non-negative, will be populated to the |
175 | // "_resource_arg_unique_id" attribute of the arg node. |
176 | Status BuildInputArgIndex(const OpDef::ArgDef& arg_def, AttrSlice attr_values, |
177 | const FunctionDef::ArgAttrs* arg_attrs, |
178 | bool ints_on_device, |
179 | int64_t resource_arg_unique_id) { |
180 | bool is_type_list; |
181 | DataTypeVector dtypes; |
182 | TF_RETURN_IF_ERROR( |
183 | ArgNumType(attr_values, arg_def, &is_type_list, &dtypes)); |
184 | if (dtypes.size() < size_t{1}) { |
185 | return errors::Internal("Expected a list of at least one dtype" ); |
186 | } |
187 | int arg_index = result_.nodes.size(); |
188 | TF_RETURN_IF_ERROR( |
189 | AddItem(arg_def.name(), {true, arg_index, 0, is_type_list, dtypes})); |
190 | // Creates dtypes.size() nodes in the graph. |
191 | for (size_t i = 0; i < dtypes.size(); ++i) { |
192 | TF_RETURN_IF_ERROR(AddItem(strings::StrCat(arg_def.name(), ":" , i), |
193 | {true, arg_index, 0, false, {dtypes[i]}})); |
194 | if (arg_index != result_.nodes.size()) { |
195 | return errors::Internal( |
196 | "Expected arg_index to be equal to the number of nodes in result." , |
197 | " Got " , arg_index, " and " , result_.nodes.size()); |
198 | } |
199 | string name = arg_def.name(); |
200 | if (dtypes.size() > 1) { |
201 | strings::StrAppend(&name, "_" , i); |
202 | } |
203 | NodeDef* gnode = AddNode(name); |
204 | if (ints_on_device && dtypes[i] == DataType::DT_INT32) { |
205 | gnode->set_op(FunctionLibraryDefinition::kDeviceArgOp); |
206 | } else { |
207 | gnode->set_op(FunctionLibraryDefinition::kArgOp); |
208 | } |
209 | DataType dtype = arg_def.is_ref() ? MakeRefType(dtypes[i]) : dtypes[i]; |
210 | AddAttr("T" , dtype, gnode); |
211 | AddAttr("index" , arg_index, gnode); |
212 | if (resource_arg_unique_id >= 0) { |
213 | AddAttr("_resource_arg_unique_id" , resource_arg_unique_id, gnode); |
214 | } |
215 | if (arg_attrs) { |
216 | for (const auto& arg_attr : arg_attrs->attr()) { |
217 | AddAttr(arg_attr.first, arg_attr.second, gnode->mutable_attr()); |
218 | } |
219 | } |
220 | result_.arg_types.push_back(dtypes[i]); |
221 | ++arg_index; |
222 | } |
223 | return OkStatus(); |
224 | } |
225 | |
226 | Status BuildNodeOutputIndex(const NodeDef& node, AttrSlice attrs, |
227 | const int arg_index) { |
228 | const OpDef* node_sig = nullptr; |
229 | TF_RETURN_IF_ERROR(get_function_(node.op(), &node_sig)); |
230 | if (node_sig->output_arg_size() == 0) { |
231 | return AddItem(node.name(), {false, arg_index, 0, false, {}}); |
232 | } |
233 | const int num_retval = node_sig->output_arg_size(); |
234 | int start = 0; |
235 | bool is_type_list; |
236 | DataTypeVector dtypes; |
237 | for (int i = 0; i < num_retval; ++i) { |
238 | TF_RETURN_IF_ERROR( |
239 | ArgNumType(attrs, node_sig->output_arg(i), &is_type_list, &dtypes)); |
240 | // Note that we rely on the backwards-compatibility test enforcing |
241 | // that output_arg(*).name() doesn't change here. |
242 | const string base_name = |
243 | strings::StrCat(node.name(), ":" , node_sig->output_arg(i).name()); |
244 | TF_RETURN_IF_ERROR( |
245 | AddItem(base_name, {false, arg_index, start, is_type_list, dtypes})); |
246 | for (int j = 0; j < static_cast<int>(dtypes.size()); ++j) { |
247 | TF_RETURN_IF_ERROR( |
248 | AddItem(strings::StrCat(base_name, ":" , j), |
249 | {false, arg_index, start + j, false, {dtypes[j]}})); |
250 | } |
251 | start += dtypes.size(); |
252 | } |
253 | return OkStatus(); |
254 | } |
255 | |
256 | Status InstantiateNode(const NodeDef& fnode, AttrSlice attrs) { |
257 | const OpDef* fnode_sig = nullptr; |
258 | TF_CHECK_OK(get_function_(fnode.op(), &fnode_sig)); |
259 | NodeDef* gnode = AddNode(fnode.name()); |
260 | gnode->set_op(fnode.op()); |
261 | gnode->set_device(fnode.device()); |
262 | int gnode_idx = nodes_.size() - 1; |
263 | |
264 | // Input |
265 | const int num_args = fnode_sig->input_arg_size(); |
266 | bool is_type_list; // ignored |
267 | DataTypeVector dtypes; |
268 | int fnode_arg_index = 0; |
269 | for (int i = 0; i < num_args; ++i) { |
270 | TF_RETURN_IF_ERROR( |
271 | ArgNumType(attrs, fnode_sig->input_arg(i), &is_type_list, &dtypes)); |
272 | // Consume inputs (indexed by fnode_arg_index) until we have |
273 | // matched each element of dtypes (indexed by j). |
274 | for (size_t j = 0; j < dtypes.size(); ++fnode_arg_index) { |
275 | if (fnode_arg_index >= fnode.input_size()) { |
276 | // Should never happen if we computed dtypes correctly. |
277 | return errors::InvalidArgument( |
278 | "Attempt to access beyond input size: " , fnode_arg_index, |
279 | " >= " , fnode.input_size()); |
280 | } |
281 | // Look up the next input. |
282 | const string& input_name = fnode.input(fnode_arg_index); |
283 | const auto* item = GetItemOrNull(input_name); |
284 | if (item == nullptr) { |
285 | return errors::InvalidArgument( |
286 | "input " , input_name, |
287 | " is not found: " , FormatNodeDefForError(fnode)); |
288 | } |
289 | if (item->dtypes.size() > dtypes.size() - j) { |
290 | return errors::InvalidArgument("Input " , input_name, " too long for " , |
291 | fnode_sig->input_arg(i).name()); |
292 | } |
293 | // Match up all the elements of this input (indexed by k) with |
294 | // elements of dtypes (advancing j). |
295 | for (int k = 0; k < item->dtypes.size(); ++k, ++j) { |
296 | if (item->dtypes[k] != dtypes[j]) { |
297 | return errors::InvalidArgument( |
298 | "input " , fnode_sig->input_arg(i).name(), "[" , j, |
299 | "] expected type " , DataTypeString(dtypes[j]), |
300 | " != " , DataTypeString(item->dtypes[k]), ", the type of " , |
301 | input_name, "[" , k, "]" ); |
302 | } |
303 | if (item->is_func_arg) { |
304 | AddInput(gnode_idx, item->nid + k, 0); |
305 | } else { |
306 | AddInput(gnode_idx, item->nid, item->idx + k); |
307 | } |
308 | } |
309 | } |
310 | } |
311 | |
312 | // Control deps. |
313 | for (int i = fnode_arg_index; i < fnode.input_size(); ++i) { |
314 | const string& input = fnode.input(i); |
315 | if (input.empty() || input[0] != '^') { |
316 | return errors::InvalidArgument("Expected input[" , i, "] == '" , input, |
317 | "' to be a control input." ); |
318 | } |
319 | int nid = -1; |
320 | const string node_name = input.substr(1); |
321 | const string node_colon = node_name + ":" ; |
322 | const string node_colon_bound = node_name + ";" ; |
323 | // index_ is a map sorted lexicographically, so the key we are looking for |
324 | // must lie in the range [node_name, node_colon_bound). |
325 | auto it = index_.lower_bound(node_name); |
326 | while (it != index_.end() && it->first <= node_colon_bound) { |
327 | if (it->first == node_name || absl::StartsWith(it->first, node_colon)) { |
328 | nid = it->second.nid; |
329 | break; |
330 | } |
331 | ++it; |
332 | } |
333 | if (nid == -1) { |
334 | return errors::InvalidArgument("input[" , i, "] == '" , input, |
335 | "', is not found." ); |
336 | } |
337 | AddDep(gnode_idx, nid); |
338 | } |
339 | |
340 | // Attrs. |
341 | for (const auto& p : attrs) { |
342 | (*gnode->mutable_attr())[p.first] = p.second; |
343 | } |
344 | |
345 | // Experimental_debug_info. |
346 | if (fnode.has_experimental_debug_info()) { |
347 | gnode->mutable_experimental_debug_info()->MergeFrom( |
348 | fnode.experimental_debug_info()); |
349 | } |
350 | |
351 | // Tye info. |
352 | // TODO(mdan): Might this need adjustment at instantiation? |
353 | if (fnode.has_experimental_type()) { |
354 | *gnode->mutable_experimental_type() = fnode.experimental_type(); |
355 | } |
356 | |
357 | return OkStatus(); |
358 | } |
359 | |
360 | Status AddReturnNode( |
361 | const OpDef::ArgDef& ret_def, AttrSlice attrs, |
362 | const ::tensorflow::protobuf::Map<string, string>& ret_map, |
363 | bool ints_on_device, int* ret_index) { |
364 | auto ret_iter = ret_map.find(ret_def.name()); |
365 | if (ret_iter == ret_map.end()) { |
366 | return errors::InvalidArgument("Return " , ret_def.name(), " missing." ); |
367 | } |
368 | bool is_type_list; |
369 | DataTypeVector dtypes; |
370 | TF_RETURN_IF_ERROR(ArgNumType(attrs, ret_def, &is_type_list, &dtypes)); |
371 | CHECK_GE(dtypes.size(), size_t{1}); |
372 | const auto* item = GetItemOrNull(ret_iter->second); |
373 | if (item == nullptr) { |
374 | return errors::InvalidArgument("Return " , ret_def.name(), " -> " , |
375 | ret_iter->second, " is not found." ); |
376 | } |
377 | if (dtypes != item->dtypes) { |
378 | return errors::InvalidArgument("Invalid ret types " , ret_def.name(), |
379 | " : " , DataTypeVectorString(dtypes), |
380 | " vs. " , |
381 | DataTypeVectorString(item->dtypes)); |
382 | } |
383 | for (size_t i = 0; i < dtypes.size(); ++i) { |
384 | string name = strings::StrCat(ret_def.name(), "_RetVal" ); |
385 | if (dtypes.size() > 1) { |
386 | strings::StrAppend(&name, "_" , i); |
387 | } |
388 | NodeDef* gnode = AddNode(name); |
389 | if (ints_on_device && dtypes[i] == DataType::DT_INT32) { |
390 | gnode->set_op(FunctionLibraryDefinition::kDeviceRetOp); |
391 | } else { |
392 | gnode->set_op(FunctionLibraryDefinition::kRetOp); |
393 | } |
394 | AddInput(nodes_.size() - 1, item->nid, item->idx + i); |
395 | DataType dtype = ret_def.is_ref() ? MakeRefType(dtypes[i]) : dtypes[i]; |
396 | AddAttr("T" , dtype, gnode); |
397 | AddAttr("index" , (*ret_index)++, gnode); |
398 | result_.ret_types.push_back(dtypes[i]); |
399 | } |
400 | return OkStatus(); |
401 | } |
402 | |
403 | // Adds the actual node inputs to the result graph by converting indexes to |
404 | // the node names. |
405 | void AddNodeInputs() { |
406 | for (int i = 0; i < result_.nodes.size(); i++) { |
407 | NodeInfo& node_info = nodes_[i]; |
408 | for (const auto& p : node_info.data_inputs) { |
409 | result_.nodes[i].add_input(Name(p.first, p.second)); |
410 | } |
411 | for (int index : node_info.control_inputs) { |
412 | result_.nodes[i].add_input(Dep(index)); |
413 | } |
414 | } |
415 | } |
416 | |
417 | private: |
418 | // This is used to build a small index for all names that can be used as a |
419 | // node's input arguments. |
420 | // |
421 | // If is_func_arg is true, the name is a function's argument. In |
422 | // this case, the produced graph def has node[nid:nid + dtype.size()]. |
423 | // |
424 | // Otherwise, the name is a function body's node return value. In |
425 | // this case, the produced graph def has one node node[nid] and |
426 | // the node's output index [idx ... idx + num) corresponds to the |
427 | // named outputs. |
428 | // |
429 | // In all cases, "dtype" specifies the data type. |
430 | struct NameInfoItem { |
431 | bool is_func_arg; |
432 | int nid; |
433 | int idx; |
434 | bool is_type_list; |
435 | DataTypeVector dtypes; |
436 | }; |
437 | |
438 | // Adds an item into the input name index. |
439 | Status AddItem(const string& name, const NameInfoItem& item) { |
440 | if (!index_.insert({name, item}).second) { |
441 | return errors::InvalidArgument( |
442 | strings::StrCat("Duplicated " , item.is_func_arg ? "arg" : "ret" , |
443 | " name: " ), |
444 | name); |
445 | } |
446 | return OkStatus(); |
447 | } |
448 | |
449 | const NameInfoItem* GetItemOrNull(const string& name) const { |
450 | return gtl::FindOrNull(index_, name); |
451 | } |
452 | |
453 | string Dep(int node_index) const { |
454 | return strings::StrCat("^" , Name(node_index)); |
455 | } |
456 | |
457 | string Name(int node_index) const { |
458 | CHECK_LT(node_index, nodes_.size()); |
459 | return nodes_[node_index].name; |
460 | } |
461 | |
462 | string Name(int node_index, int output_index) const { |
463 | if (output_index == 0) { |
464 | return Name(node_index); |
465 | } else { |
466 | return strings::StrCat(Name(node_index), ":" , output_index); |
467 | } |
468 | } |
469 | |
470 | NodeDef* AddNode(const string& name) { |
471 | result_.nodes.emplace_back(); |
472 | NodeDef* gnode = &result_.nodes.back(); |
473 | gnode->set_name(name); |
474 | nodes_.push_back({name, {}, {}}); |
475 | CHECK_EQ(result_.nodes.size(), nodes_.size()); |
476 | return gnode; |
477 | } |
478 | |
479 | void AddInput(int node_index, int output_node, int output_index) { |
480 | CHECK_LT(node_index, nodes_.size()); |
481 | nodes_[node_index].data_inputs.push_back( |
482 | std::make_pair(output_node, output_index)); |
483 | } |
484 | |
485 | void AddDep(int node_index, int dep_index) { |
486 | CHECK_LT(node_index, nodes_.size()); |
487 | nodes_[node_index].control_inputs.push_back(dep_index); |
488 | } |
489 | |
490 | GetFunctionSignature get_function_; |
491 | InstantiationResult& result_; |
492 | // A small index for all names that can be used as a node's input arguments. |
493 | std::map<string, NameInfoItem> index_; |
494 | // This contains information about a node in the new graph including the node |
495 | // names and input nodes' indexes. |
496 | struct NodeInfo { |
497 | string name; |
498 | // Data inputs where <n, k> means arg k of node n. |
499 | std::vector<std::pair<int, int>> data_inputs; |
500 | // Control inputs (dependencies). |
501 | std::vector<int> control_inputs; |
502 | }; |
503 | // nodes_[i] is the information about result_.nodes[i]. |
504 | std::vector<NodeInfo> nodes_; |
505 | }; |
506 | |
507 | // Various helpers Print(proto) to print relevant protos to ascii. |
508 | string Print(const OpDef::ArgDef& arg) { |
509 | string out; |
510 | strings::StrAppend(&out, arg.name(), ":" ); |
511 | if (arg.is_ref()) strings::StrAppend(&out, "Ref(" ); |
512 | if (!arg.number_attr().empty()) { |
513 | strings::StrAppend(&out, arg.number_attr(), "*" ); |
514 | } |
515 | if (arg.type() != DT_INVALID) { |
516 | strings::StrAppend(&out, DataTypeString(arg.type())); |
517 | } else { |
518 | strings::StrAppend(&out, arg.type_attr()); |
519 | } |
520 | if (arg.is_ref()) strings::StrAppend(&out, ")" ); |
521 | return out; |
522 | } |
523 | |
524 | // TODO(josh11b): Merge this with SummarizeAttrValue(). |
525 | // When hash_string_attrs = true, string attributes are hashed instead of being |
526 | // truncated with ellipses. This is done to reduce the chance of collisions when |
527 | // looking up functions using the canonical representation. |
528 | string Print(const AttrValue& attr_value, |
529 | const bool hash_string_attrs = false) { |
530 | if (attr_value.value_case() == AttrValue::kType) { |
531 | return DataTypeString(attr_value.type()); |
532 | } else if ((attr_value.value_case() == AttrValue::kList) && |
533 | (attr_value.list().type_size() > 0)) { |
534 | string ret = "{" ; |
535 | for (int i = 0; i < attr_value.list().type_size(); ++i) { |
536 | if (i > 0) strings::StrAppend(&ret, ", " ); |
537 | strings::StrAppend(&ret, DataTypeString(attr_value.list().type(i))); |
538 | } |
539 | strings::StrAppend(&ret, "}" ); |
540 | return ret; |
541 | } else if (attr_value.value_case() == AttrValue::kFunc) { |
542 | if (attr_value.func().attr_size() == 0) { |
543 | return attr_value.func().name(); |
544 | } |
545 | std::vector<string> entries; |
546 | for (const auto& p : attr_value.func().attr()) { |
547 | entries.push_back(strings::StrCat(p.first, "=" , Print(p.second))); |
548 | } |
549 | std::sort(entries.begin(), entries.end()); |
550 | return strings::StrCat(attr_value.func().name(), "[" , |
551 | absl::StrJoin(entries, ", " ), "]" ); |
552 | } else if (attr_value.value_case() == AttrValue::kS && hash_string_attrs) { |
553 | return strings::StrCat(Fingerprint64(attr_value.s())); |
554 | } |
555 | return SummarizeAttrValue(attr_value); |
556 | } |
557 | |
558 | // TODO(josh11b): Merge this with SummarizeNodeDef(). |
559 | string Print(const NodeDef& n) { |
560 | string out; |
561 | strings::StrAppend(&out, n.name(), " = " , n.op()); |
562 | if (n.attr_size() > 0) { |
563 | std::vector<string> entries; |
564 | for (auto& a : n.attr()) { |
565 | entries.push_back(strings::StrCat(a.first, "=" , Print(a.second))); |
566 | } |
567 | std::sort(entries.begin(), entries.end()); |
568 | // Add a short device string at the end of all attributes. |
569 | if (!n.device().empty()) { |
570 | DeviceNameUtils::ParsedName parsed; |
571 | if (DeviceNameUtils::ParseFullName(n.device(), &parsed)) { |
572 | entries.push_back( |
573 | strings::StrCat("device=" , parsed.type, ":" , parsed.id)); |
574 | } else { |
575 | entries.push_back("device=<FAILED_TO_PARSE>" ); |
576 | } |
577 | } |
578 | strings::StrAppend(&out, "[" , absl::StrJoin(entries, ", " ), "]" ); |
579 | } |
580 | strings::StrAppend(&out, "(" ); |
581 | std::vector<StringPiece> dat; |
582 | std::vector<string> dep; |
583 | for (StringPiece s : n.input()) { |
584 | if (absl::ConsumePrefix(&s, "^" )) { |
585 | dep.emplace_back(s); |
586 | } else { |
587 | dat.push_back(s); |
588 | } |
589 | } |
590 | strings::StrAppend(&out, absl::StrJoin(dat, ", " ), ")" ); |
591 | if (!dep.empty()) { |
592 | strings::StrAppend(&out, " @ " , absl::StrJoin(dep, ", " )); |
593 | } |
594 | return out; |
595 | } |
596 | |
597 | string Print(const FunctionDef& fdef) { |
598 | string out; |
599 | const OpDef& sig = fdef.signature(); |
600 | strings::StrAppend(&out, "\n" , sig.name()); |
601 | if (sig.attr_size() > 0) { |
602 | strings::StrAppend(&out, "[" ); |
603 | for (int i = 0; i < sig.attr_size(); ++i) { |
604 | const auto& a = sig.attr(i); |
605 | if (i > 0) strings::StrAppend(&out, ", " ); |
606 | if (a.type() == "type" ) { |
607 | strings::StrAppend(&out, a.name(), ":" , Print(a.allowed_values())); |
608 | } else { |
609 | strings::StrAppend(&out, a.name(), ":" , a.type()); |
610 | } |
611 | } |
612 | strings::StrAppend(&out, "]" ); |
613 | } |
614 | strings::StrAppend(&out, "(" ); |
615 | for (int i = 0; i < sig.input_arg_size(); ++i) { |
616 | if (i > 0) strings::StrAppend(&out, ", " ); |
617 | strings::StrAppend(&out, Print(sig.input_arg(i))); |
618 | } |
619 | strings::StrAppend(&out, ") -> (" ); |
620 | for (int i = 0; i < sig.output_arg_size(); ++i) { |
621 | if (i > 0) strings::StrAppend(&out, ", " ); |
622 | strings::StrAppend(&out, Print(sig.output_arg(i))); |
623 | } |
624 | strings::StrAppend(&out, ") {\n" ); |
625 | for (const auto& n : fdef.node_def()) { |
626 | strings::StrAppend(&out, " " , Print(n), "\n" ); |
627 | } |
628 | for (const auto& cr : fdef.control_ret()) { |
629 | strings::StrAppend(&out, " @return " , cr.first, " = " , cr.second, "\n" ); |
630 | } |
631 | for (const auto& r : fdef.ret()) { |
632 | strings::StrAppend(&out, " return " , r.first, " = " , r.second, "\n" ); |
633 | } |
634 | strings::StrAppend(&out, "}\n" ); |
635 | return out; |
636 | } |
637 | |
638 | string Print(gtl::ArraySlice<const NodeDef*> nodes) { |
639 | std::vector<const NodeDef*> arg; |
640 | std::vector<const NodeDef*> ret; |
641 | std::vector<const NodeDef*> body; |
642 | for (const NodeDef* n : nodes) { |
643 | if (n->op() == FunctionLibraryDefinition::kArgOp || |
644 | n->op() == FunctionLibraryDefinition::kDeviceArgOp) { |
645 | arg.push_back(n); |
646 | } else if (n->op() == FunctionLibraryDefinition::kRetOp || |
647 | n->op() == FunctionLibraryDefinition::kDeviceRetOp) { |
648 | ret.push_back(n); |
649 | } else { |
650 | body.push_back(n); |
651 | } |
652 | } |
653 | auto comp = [](const NodeDef* x, const NodeDef* y) { |
654 | int xi; |
655 | TF_CHECK_OK(GetNodeAttr(*x, "index" , &xi)); |
656 | int yi; |
657 | TF_CHECK_OK(GetNodeAttr(*y, "index" , &yi)); |
658 | return xi < yi; |
659 | }; |
660 | std::sort(arg.begin(), arg.end(), comp); |
661 | std::sort(ret.begin(), ret.end(), comp); |
662 | string out; |
663 | strings::StrAppend(&out, "\n(" ); |
664 | auto get_type_and_device = [](const NodeDef& n) { |
665 | DataType dt; |
666 | if (!TryGetNodeAttr(n, "T" , &dt)) { |
667 | dt = DT_INVALID; |
668 | } |
669 | if (!n.device().empty()) { |
670 | DeviceNameUtils::ParsedName parsed; |
671 | if (DeviceNameUtils::ParseFullName(n.device(), &parsed)) { |
672 | return strings::StrCat(DataTypeString(dt), "@" , parsed.type, ":" , |
673 | parsed.id); |
674 | } else { |
675 | LOG(WARNING) << "Failed to parse device \"" << n.device() << "\" in " |
676 | << n.op() << ":" << n.name(); |
677 | return strings::StrCat(DataTypeString(dt), "@" , |
678 | "<FAILED_TO_PARSE_DEVICE>" ); |
679 | } |
680 | } |
681 | return DataTypeString(dt); |
682 | }; |
683 | for (size_t i = 0; i < arg.size(); ++i) { |
684 | const NodeDef* n = arg[i]; |
685 | if (i > 0) strings::StrAppend(&out, ", " ); |
686 | CHECK_GE(n->attr_size(), 2); |
687 | strings::StrAppend(&out, n->name(), ":" , get_type_and_device(*n)); |
688 | } |
689 | strings::StrAppend(&out, ") -> (" ); |
690 | for (size_t i = 0; i < ret.size(); ++i) { |
691 | const NodeDef* n = ret[i]; |
692 | if (i > 0) strings::StrAppend(&out, ", " ); |
693 | CHECK_LE(2, n->attr_size()); |
694 | |
695 | // The _RetVal op should have a unique non-control input. We assert that |
696 | // here and add it to the output. |
697 | bool found_non_control_input = false; |
698 | for (const string& input : n->input()) { |
699 | if (!input.empty() && input[0] != '^') { |
700 | DCHECK_EQ(found_non_control_input, false) |
701 | << "RetVal node has more than one non-control input: " |
702 | << absl::StrJoin(n->input(), ", " ); |
703 | strings::StrAppend(&out, n->input(0), ":" , get_type_and_device(*n)); |
704 | found_non_control_input = true; |
705 | } |
706 | } |
707 | DCHECK_EQ(found_non_control_input, true) |
708 | << "RetVal did not have any non-control inputs: " |
709 | << absl::StrJoin(n->input(), ", " ); |
710 | } |
711 | strings::StrAppend(&out, ") {\n" ); |
712 | for (size_t i = 0; i < body.size(); ++i) { |
713 | strings::StrAppend(&out, " " , Print(*body[i]), "\n" ); |
714 | } |
715 | strings::StrAppend(&out, "}\n" ); |
716 | return out; |
717 | } |
718 | |
719 | Status AddDefaultAttrs(const string& op, |
720 | const GetFunctionSignature& get_function, |
721 | AttrValueMap* attrs) { |
722 | const OpDef* op_def = nullptr; |
723 | TF_RETURN_IF_ERROR(get_function(op, &op_def)); |
724 | AttrSlice attr_slice(attrs); |
725 | for (const auto& attr_def : op_def->attr()) { |
726 | if (attr_def.has_default_value() && !attr_slice.Find(attr_def.name())) { |
727 | if (!attrs->insert({attr_def.name(), attr_def.default_value()}).second) { |
728 | return errors::Internal("Somehow duplicated: " , attr_def.name()); |
729 | } |
730 | } |
731 | } |
732 | return OkStatus(); |
733 | } |
734 | |
735 | } // end namespace |
736 | |
737 | Status InstantiateFunction(const FunctionDef& fdef, AttrSlice attr_values, |
738 | GetFunctionSignature get_function, |
739 | InstantiationResult* result) { |
740 | if (VLOG_IS_ON(5)) { |
741 | const auto& signature = fdef.signature(); |
742 | VLOG(5) << "Instantiate function definition: name=" << signature.name() |
743 | << " #input_args=" << signature.input_arg_size() |
744 | << " #output_args=" << signature.output_arg_size() |
745 | << " #control_output=" << signature.control_output_size(); |
746 | for (const auto& line : str_util::Split(Print(fdef), '\n')) { |
747 | VLOG(5) << "|| " << line; |
748 | } |
749 | } |
750 | |
751 | const OpDef& sig = fdef.signature(); |
752 | TF_RETURN_IF_ERROR(ValidateSignatureWithAttrs(sig, attr_values)); |
753 | |
754 | const AttrValue* attr_values_ints_on_device = |
755 | attr_values.Find(FunctionLibraryDefinition::kIntsOnDeviceAttr); |
756 | bool ints_on_device = |
757 | (fdef.attr().count(FunctionLibraryDefinition::kIntsOnDeviceAttr) != 0 && |
758 | fdef.attr().at(FunctionLibraryDefinition::kIntsOnDeviceAttr).b()) || |
759 | (attr_values_ints_on_device != nullptr && |
760 | attr_values_ints_on_device->b()); |
761 | |
762 | FunctionInstantiationHelper helper(get_function, result); |
763 | Status s; |
764 | for (int i = 0, e = sig.input_arg_size(); i < e; ++i) { |
765 | const OpDef::ArgDef& arg_def = sig.input_arg(i); |
766 | auto it = fdef.arg_attr().find(i); |
767 | const FunctionDef::ArgAttrs* arg_attrs = |
768 | it != fdef.arg_attr().end() ? &it->second : nullptr; |
769 | auto resource_id_it = fdef.resource_arg_unique_id().find(i); |
770 | int64_t resource_arg_unique_id = |
771 | resource_id_it != fdef.resource_arg_unique_id().end() |
772 | ? resource_id_it->second |
773 | : -1LL; |
774 | s = helper.BuildInputArgIndex(arg_def, attr_values, arg_attrs, |
775 | ints_on_device, resource_arg_unique_id); |
776 | |
777 | if (!s.ok()) { |
778 | errors::AppendToMessage(&s, "In " , Print(arg_def)); |
779 | return s; |
780 | } |
781 | } |
782 | |
783 | auto substitute = [attr_values, &sig](const string& name, AttrValue* val) { |
784 | // Look for a specified value... |
785 | if (const AttrValue* v = attr_values.FindByString(name)) { |
786 | *val = *v; |
787 | return true; |
788 | } |
789 | // .. and if not, then check for a default value. |
790 | if (const OpDef::AttrDef* attr = FindAttr(name, sig)) { |
791 | if (attr->has_default_value()) { |
792 | *val = attr->default_value(); |
793 | return true; |
794 | } |
795 | } |
796 | // No luck finding a substitution. |
797 | return false; |
798 | }; |
799 | |
800 | // Makes a copy of all attrs in fdef and substitutes placeholders. |
801 | // After this step, every attr is bound to a concrete value. |
802 | std::vector<AttrValueMap> node_attrs; |
803 | node_attrs.resize(fdef.node_def_size()); |
804 | for (int i = 0; i < fdef.node_def_size(); ++i) { |
805 | for (auto attr : fdef.node_def(i).attr()) { |
806 | if (!SubstitutePlaceholders(substitute, &attr.second)) { |
807 | return errors::InvalidArgument("Failed to bind all placeholders in " , |
808 | SummarizeAttrValue(attr.second)); |
809 | } |
810 | if (!node_attrs[i].insert(attr).second) { |
811 | return errors::Internal("Somehow duplicated: " , attr.first); |
812 | } |
813 | } |
814 | TF_RETURN_IF_ERROR( |
815 | AddDefaultAttrs(fdef.node_def(i).op(), get_function, &node_attrs[i])); |
816 | } |
817 | |
818 | for (int i = 0; i < fdef.node_def_size(); ++i) { |
819 | s = helper.BuildNodeOutputIndex(fdef.node_def(i), AttrSlice(&node_attrs[i]), |
820 | result->nodes.size() + i); |
821 | if (!s.ok()) { |
822 | errors::AppendToMessage(&s, "In " , |
823 | FormatNodeDefForError(fdef.node_def(i))); |
824 | return s; |
825 | } |
826 | } |
827 | // Emits one node for each fdef.node_def. |
828 | for (int i = 0; i < fdef.node_def_size(); ++i) { |
829 | s = helper.InstantiateNode(fdef.node_def(i), AttrSlice(&node_attrs[i])); |
830 | if (!s.ok()) { |
831 | errors::AppendToMessage(&s, "In " , |
832 | FormatNodeDefForError(fdef.node_def(i))); |
833 | return s; |
834 | } |
835 | } |
836 | |
837 | // Emits nodes for the function's return values. |
838 | int ret_index = 0; |
839 | for (const OpDef::ArgDef& ret_def : sig.output_arg()) { |
840 | s = helper.AddReturnNode(ret_def, attr_values, fdef.ret(), ints_on_device, |
841 | &ret_index); |
842 | if (!s.ok()) { |
843 | errors::AppendToMessage(&s, "In function output " , Print(ret_def)); |
844 | return s; |
845 | } |
846 | } |
847 | |
848 | // Adds the actual node inputs using the input indexes. |
849 | helper.AddNodeInputs(); |
850 | |
851 | return OkStatus(); |
852 | } |
853 | |
854 | string DebugString(const FunctionDef& func_def) { return Print(func_def); } |
855 | |
856 | string DebugString(const GraphDef& instantiated_func_def) { |
857 | std::vector<const NodeDef*> ptrs; |
858 | for (const NodeDef& n : instantiated_func_def.node()) { |
859 | ptrs.push_back(&n); |
860 | } |
861 | return Print(ptrs); |
862 | } |
863 | |
864 | string DebugString(gtl::ArraySlice<NodeDef> instantiated_func_nodes) { |
865 | std::vector<const NodeDef*> ptrs; |
866 | for (const NodeDef& n : instantiated_func_nodes) { |
867 | ptrs.push_back(&n); |
868 | } |
869 | return Print(ptrs); |
870 | } |
871 | |
872 | string DebugStringWhole(const GraphDef& gdef) { |
873 | string ret; |
874 | for (const auto& fdef : gdef.library().function()) { |
875 | strings::StrAppend(&ret, Print(fdef)); |
876 | } |
877 | strings::StrAppend(&ret, "\n" ); |
878 | for (const auto& ndef : gdef.node()) { |
879 | strings::StrAppend(&ret, Print(ndef), "\n" ); |
880 | } |
881 | return ret; |
882 | } |
883 | |
884 | namespace { |
885 | |
886 | // Returns the name -> attr mapping of fdef's attrs that have a value set. In |
887 | // Python, it's possible to access unset attrs, which returns a default value |
888 | // and adds an unset attr to the map. |
889 | std::map<string, AttrValue> GetSetAttrs(const FunctionDef& fdef) { |
890 | std::map<string, AttrValue> set_attrs; |
891 | for (const auto& pair : fdef.attr()) { |
892 | if (pair.second.value_case() != AttrValue::VALUE_NOT_SET) { |
893 | set_attrs[pair.first] = pair.second; |
894 | } |
895 | } |
896 | return set_attrs; |
897 | } |
898 | |
899 | } // end namespace |
900 | |
901 | bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2) { |
902 | if (!OpDefEqual(f1.signature(), f2.signature())) return false; |
903 | |
904 | std::map<string, AttrValue> f1_attrs = GetSetAttrs(f1); |
905 | std::map<string, AttrValue> f2_attrs = GetSetAttrs(f2); |
906 | if (f1_attrs.size() != f2_attrs.size()) return false; |
907 | for (const auto& iter1 : f1_attrs) { |
908 | auto iter2 = f2_attrs.find(iter1.first); |
909 | if (iter2 == f2_attrs.end()) return false; |
910 | if (!AreAttrValuesEqual(iter1.second, iter2->second)) return false; |
911 | } |
912 | |
913 | if (!EqualRepeatedNodeDef(f1.node_def(), f2.node_def(), nullptr)) { |
914 | return false; |
915 | } |
916 | |
917 | std::map<string, string> ret1(f1.ret().begin(), f1.ret().end()); |
918 | std::map<string, string> ret2(f2.ret().begin(), f2.ret().end()); |
919 | if (ret1 != ret2) return false; |
920 | |
921 | std::map<string, string> control_ret1(f1.control_ret().begin(), |
922 | f1.control_ret().end()); |
923 | std::map<string, string> control_ret2(f2.control_ret().begin(), |
924 | f2.control_ret().end()); |
925 | if (control_ret1 != control_ret2) return false; |
926 | |
927 | return true; |
928 | } |
929 | |
930 | uint64 FunctionDefHash(const FunctionDef& fdef) { |
931 | // signature |
932 | uint64 h = OpDefHash(fdef.signature()); |
933 | |
934 | // attrs |
935 | std::map<string, AttrValue> attrs = GetSetAttrs(fdef); |
936 | for (const auto& p : attrs) { |
937 | h = Hash64(p.first.data(), p.first.size(), h); |
938 | h = Hash64Combine(AttrValueHash(p.second), h); |
939 | } |
940 | |
941 | // node defs |
942 | h = Hash64Combine(RepeatedNodeDefHash(fdef.node_def()), h); |
943 | |
944 | // output names |
945 | std::map<string, string> ret(fdef.ret().begin(), fdef.ret().end()); |
946 | for (const auto& p : ret) { |
947 | h = Hash64(p.first.data(), p.first.size(), h); |
948 | h = Hash64(p.second.data(), p.second.size(), h); |
949 | } |
950 | |
951 | // control output names |
952 | std::map<string, string> control_ret(fdef.control_ret().begin(), |
953 | fdef.control_ret().end()); |
954 | for (const auto& p : control_ret) { |
955 | h = Hash64(p.first.data(), p.first.size(), h); |
956 | h = Hash64(p.second.data(), p.second.size(), h); |
957 | } |
958 | |
959 | return h; |
960 | } |
961 | |
// Name of the attr that selects the executor type for a function
// instantiation. Read by FunctionLibraryRuntime::ExecutorType(); Canonicalize()
// skips it among the user attrs and re-adds the resolved executor type.
static constexpr const char* const kExecutorAttr = "_executor";
963 | |
964 | /* static */ |
965 | string FunctionLibraryRuntime::ExecutorType(const InstantiateOptions& options, |
966 | AttrSlice attrs) { |
967 | if (!options.executor_type.empty()) { |
968 | return options.executor_type; |
969 | } else if (const AttrValue* executor_attr = attrs.Find(kExecutorAttr)) { |
970 | return executor_attr->s(); |
971 | } else { |
972 | return string(); |
973 | } |
974 | } |
975 | |
976 | namespace { |
977 | class AttrKeyAndValue { |
978 | public: |
979 | enum ValueRepresentationOp { |
980 | kRaw, |
981 | kCEscape, |
982 | }; |
983 | AttrKeyAndValue(absl::string_view key_name, int key_suffix, string value, |
984 | ValueRepresentationOp value_op = kRaw) |
985 | : key_name_(key_name), |
986 | key_suffix_(key_suffix), |
987 | value_op_(value_op), |
988 | value_(std::move(value)) {} |
989 | |
990 | bool operator<(const AttrKeyAndValue& b) const { |
991 | if (key_name_ != b.key_name_) { |
992 | return key_name_ < b.key_name_; |
993 | } else if (key_suffix_ != b.key_suffix_) { |
994 | return key_suffix_ < b.key_suffix_; |
995 | } else { |
996 | return value_ < b.value_; |
997 | } |
998 | } |
999 | |
1000 | void AppendTo(bool first, string* s) const { |
1001 | absl::string_view v; |
1002 | bool add_escaped = false; |
1003 | if ((value_op_ == kCEscape) && NeedsEscaping(value_)) { |
1004 | // Use CEscape call below |
1005 | add_escaped = true; |
1006 | } else { |
1007 | // Add raw value contents directly |
1008 | v = value_; |
1009 | } |
1010 | if (key_suffix_ >= 0) { |
1011 | strings::StrAppend(s, first ? "" : "," , key_name_, key_suffix_, "=" , v); |
1012 | } else { |
1013 | strings::StrAppend(s, first ? "" : "," , key_name_, "=" , v); |
1014 | } |
1015 | if (add_escaped) { |
1016 | strings::StrAppend(s, absl::CEscape(value_)); |
1017 | } |
1018 | } |
1019 | |
1020 | private: |
1021 | static bool NeedsEscaping(const string& s) { |
1022 | for (auto c : s) { |
1023 | if (!isalnum(c) && (c != ' ')) { |
1024 | return true; |
1025 | } |
1026 | } |
1027 | return false; |
1028 | } |
1029 | |
1030 | absl::string_view key_name_; |
1031 | int key_suffix_; // -1 if missing |
1032 | ValueRepresentationOp value_op_; |
1033 | string value_; |
1034 | }; |
1035 | } // namespace |
1036 | |
1037 | string GetFunctionResourceInputDevice( |
1038 | const Tensor& input, const int arg_index, const FunctionDef& function_def, |
1039 | absl::flat_hash_map<string, std::vector<string>>* composite_devices) { |
1040 | const auto& handles = input.flat<ResourceHandle>(); |
1041 | const ResourceHandle& handle0 = handles(0); |
1042 | string composite_device; |
1043 | auto iter = function_def.arg_attr().find(arg_index); |
1044 | if (iter != function_def.arg_attr().end()) { |
1045 | auto arg_attr = iter->second.attr().find("_composite_device" ); |
1046 | if (arg_attr != iter->second.attr().end()) { |
1047 | composite_device = arg_attr->second.s(); |
1048 | } |
1049 | } |
1050 | if (!composite_device.empty()) { |
1051 | if (composite_devices->find(composite_device) == composite_devices->end()) { |
1052 | for (int i = 0; i < handles.size(); ++i) { |
1053 | (*composite_devices)[composite_device].push_back(handles(i).device()); |
1054 | } |
1055 | } |
1056 | return composite_device; |
1057 | } else { |
1058 | return handle0.device(); |
1059 | } |
1060 | } |
1061 | |
1062 | string Canonicalize(const string& funcname, AttrSlice attrs, |
1063 | const FunctionLibraryRuntime::InstantiateOptions& options) { |
1064 | absl::InlinedVector<AttrKeyAndValue, 8> entries; |
1065 | entries.reserve(attrs.size() + static_cast<int>(!options.target.empty()) + |
1066 | options.input_devices.size()); |
1067 | for (const auto& p : attrs) { |
1068 | if (p.first != kExecutorAttr) { |
1069 | entries.push_back(AttrKeyAndValue( |
1070 | p.first, -1, Print(p.second, /*hash_string_attrs=*/true))); |
1071 | } |
1072 | } |
1073 | if (!options.target.empty()) { |
1074 | entries.push_back(AttrKeyAndValue("_target" , -1, options.target, |
1075 | AttrKeyAndValue::kCEscape)); |
1076 | } |
1077 | for (int i = 0; i < options.input_devices.size(); ++i) { |
1078 | entries.push_back(AttrKeyAndValue("_input_dev" , i, options.input_devices[i], |
1079 | AttrKeyAndValue::kCEscape)); |
1080 | } |
1081 | for (int i = 0; i < options.output_devices.size(); ++i) { |
1082 | entries.push_back(AttrKeyAndValue("_output_dev" , i, |
1083 | options.output_devices[i], |
1084 | AttrKeyAndValue::kCEscape)); |
1085 | } |
1086 | for (const auto& iter : options.input_resource_dtypes_and_shapes) { |
1087 | entries.push_back(AttrKeyAndValue("_input_resource_dtype" , iter.first, |
1088 | DataTypeString(iter.second.dtype))); |
1089 | entries.push_back(AttrKeyAndValue("_input_resource_shape" , iter.first, |
1090 | iter.second.shape.DebugString(), |
1091 | AttrKeyAndValue::kCEscape)); |
1092 | } |
1093 | if (options.lib_def) { |
1094 | entries.push_back(AttrKeyAndValue( |
1095 | "_lib_def" , -1, |
1096 | absl::StrCat("" , reinterpret_cast<uintptr_t>(options.lib_def)))); |
1097 | } |
1098 | if (!options.state_handle.empty()) { |
1099 | entries.push_back( |
1100 | AttrKeyAndValue("_state_handle" , -1, options.state_handle)); |
1101 | } |
1102 | string executor_type = FunctionLibraryRuntime::ExecutorType(options, attrs); |
1103 | if (!executor_type.empty()) { |
1104 | entries.push_back(AttrKeyAndValue(kExecutorAttr, -1, executor_type)); |
1105 | } |
1106 | if (options.config_proto.ByteSize() > 0) { |
1107 | string config_proto_serialized; |
1108 | SerializeToStringDeterministic(options.config_proto, |
1109 | &config_proto_serialized); |
1110 | entries.push_back(AttrKeyAndValue("_config_proto" , -1, |
1111 | config_proto_serialized, |
1112 | AttrKeyAndValue::kCEscape)); |
1113 | } |
1114 | std::sort(entries.begin(), entries.end()); |
1115 | string result = strings::StrCat(funcname, "[" ); |
1116 | bool first = true; |
1117 | for (const auto& entry : entries) { |
1118 | entry.AppendTo(first, &result); |
1119 | first = false; |
1120 | } |
1121 | result += "]" ; |
1122 | return result; |
1123 | } |
1124 | |
1125 | string Canonicalize(const string& funcname, AttrSlice attrs) { |
1126 | static const FunctionLibraryRuntime::InstantiateOptions* kEmptyOptions = |
1127 | new FunctionLibraryRuntime::InstantiateOptions; |
1128 | return Canonicalize(funcname, attrs, *kEmptyOptions); |
1129 | } |
1130 | |
1131 | FunctionCallFrame::FunctionCallFrame(DataTypeSlice arg_types, |
1132 | DataTypeSlice ret_types) |
1133 | : arg_types_(arg_types.begin(), arg_types.end()), |
1134 | ret_types_(ret_types.begin(), ret_types.end()) { |
1135 | args_.resize(arg_types_.size()); |
1136 | rets_.resize(ret_types_.size()); |
1137 | } |
1138 | |
1139 | FunctionCallFrame::~FunctionCallFrame() {} |
1140 | |
1141 | Status FunctionCallFrame::SetArgs(gtl::ArraySlice<Tensor> args) { |
1142 | // Input type checks. |
1143 | if (args.size() != arg_types_.size()) { |
1144 | return errors::InvalidArgument("Expects " , arg_types_.size(), |
1145 | " arguments, but " , args.size(), |
1146 | " is provided" ); |
1147 | } |
1148 | for (size_t i = 0; i < args.size(); ++i) { |
1149 | if (arg_types_[i] != args[i].dtype()) { |
1150 | return errors::InvalidArgument( |
1151 | "Expects arg[" , i, "] to be " , DataTypeString(arg_types_[i]), " but " , |
1152 | DataTypeString(args[i].dtype()), " is provided" ); |
1153 | } |
1154 | args_[i] = args[i]; |
1155 | } |
1156 | return OkStatus(); |
1157 | } |
1158 | |
1159 | Status FunctionCallFrame::GetRetvals(std::vector<Tensor>* rets) const { |
1160 | rets->clear(); |
1161 | rets->reserve(rets_.size()); |
1162 | for (size_t i = 0; i < rets_.size(); ++i) { |
1163 | const auto& item = rets_[i]; |
1164 | if (item.has_val) { |
1165 | rets->push_back(item.val); |
1166 | } else { |
1167 | return errors::Internal("Retval[" , i, "] does not have value" ); |
1168 | } |
1169 | } |
1170 | return OkStatus(); |
1171 | } |
1172 | |
1173 | Status FunctionCallFrame::ConsumeRetvals(std::vector<Tensor>* rets, |
1174 | bool allow_dead_tensors) { |
1175 | rets->clear(); |
1176 | rets->reserve(rets_.size()); |
1177 | for (size_t i = 0; i < rets_.size(); ++i) { |
1178 | if (rets_[i].has_val) { |
1179 | rets->emplace_back(std::move(rets_[i].val)); |
1180 | } else if (allow_dead_tensors) { |
1181 | rets->emplace_back(); |
1182 | } else { |
1183 | return errors::Internal("Retval[" , i, "] does not have value" ); |
1184 | } |
1185 | } |
1186 | return OkStatus(); |
1187 | } |
1188 | |
1189 | Status FunctionCallFrame::GetArg(int index, const Tensor** val) { |
1190 | if (index < 0 || static_cast<size_t>(index) >= args_.size()) { |
1191 | return errors::InvalidArgument("GetArg " , index, " is not within [0, " , |
1192 | args_.size(), ")" ); |
1193 | } |
1194 | *val = &args_[index]; |
1195 | return OkStatus(); |
1196 | } |
1197 | |
1198 | Status FunctionCallFrame::SetRetval(int index, const Tensor& val) { |
1199 | if (index < 0 || static_cast<size_t>(index) >= rets_.size()) { |
1200 | return errors::InvalidArgument("SetRetval " , index, " is not within [0, " , |
1201 | rets_.size(), ")" ); |
1202 | } |
1203 | if (val.dtype() != ret_types_[index]) { |
1204 | return errors::InvalidArgument( |
1205 | "Expects ret[" , index, "] to be " , DataTypeString(ret_types_[index]), |
1206 | ", but " , DataTypeString(val.dtype()), " is provided." ); |
1207 | } |
1208 | Retval* item = &rets_[index]; |
1209 | if (!item->has_val) { |
1210 | item->has_val = true; |
1211 | item->val = val; |
1212 | } else { |
1213 | return errors::Internal("Retval[" , index, "] has already been set." ); |
1214 | } |
1215 | return OkStatus(); |
1216 | } |
1217 | |
// Bundles a copy of `fdef_in` with OpRegistrationData built from its
// signature (flagged as a function), plus the `stack_traces` passed in.
FunctionLibraryDefinition::FunctionDefAndOpRegistration::
    FunctionDefAndOpRegistration(const FunctionDef& fdef_in,
                                 const StackTracesMap& stack_traces)
    : fdef(fdef_in),
      // Exact shape inference for functions is handled by ShapeRefiner.
      // Here we pass a dummy shape inference function for legacy code paths.
      // NOTE(review): `fdef` below is the member initialized just above. This
      // is only valid if `fdef` is declared before `op_registration_data` in
      // the class definition (not visible here) — keep that order.
      op_registration_data(fdef.signature(), shape_inference::UnknownShape,
                           true /* is_function */),
      stack_traces(stack_traces) {}
1227 | |
1228 | FunctionLibraryDefinition::FunctionLibraryDefinition( |
1229 | const FunctionLibraryDefinition& other) |
1230 | : default_registry_(other.default_registry_) { |
1231 | tf_shared_lock l(other.mu_); |
1232 | function_defs_ = other.function_defs_; |
1233 | func_grad_ = other.func_grad_; |
1234 | } |
1235 | |
1236 | FunctionLibraryDefinition::FunctionLibraryDefinition( |
1237 | const OpRegistryInterface* default_registry, |
1238 | const FunctionDefLibrary& def_lib) |
1239 | : default_registry_(default_registry), |
1240 | function_defs_(def_lib.function_size()) { |
1241 | for (const auto& fdef : def_lib.function()) { |
1242 | // The latter function definition wins. |
1243 | auto& ptr = function_defs_[fdef.signature().name()]; |
1244 | ptr.reset(new FunctionDefAndOpRegistration(fdef)); |
1245 | } |
1246 | for (const auto& grad : def_lib.gradient()) { |
1247 | func_grad_[grad.function_name()] = grad.gradient_func(); |
1248 | } |
1249 | } |
1250 | |
1251 | FunctionLibraryDefinition::~FunctionLibraryDefinition() {} |
1252 | |
1253 | bool FunctionLibraryDefinition::Contains(const string& func) const { |
1254 | tf_shared_lock l(mu_); |
1255 | return function_defs_.find(func) != function_defs_.end(); |
1256 | } |
1257 | |
1258 | const FunctionDef* FunctionLibraryDefinition::Find(const string& func) const { |
1259 | tf_shared_lock l(mu_); |
1260 | auto result = FindHelper(func); |
1261 | if (result) { |
1262 | return &result->fdef; |
1263 | } else { |
1264 | return nullptr; |
1265 | } |
1266 | } |
1267 | |
1268 | std::shared_ptr<FunctionLibraryDefinition::FunctionDefAndOpRegistration> |
1269 | FunctionLibraryDefinition::FindHelper(const string& func) const { |
1270 | auto iter = function_defs_.find(func); |
1271 | if (iter == function_defs_.end()) { |
1272 | return nullptr; |
1273 | } else { |
1274 | return iter->second; |
1275 | } |
1276 | } |
1277 | |
1278 | Status FunctionLibraryDefinition::AddFunctionDef( |
1279 | const FunctionDef& fdef, const StackTracesMap& stack_traces) { |
1280 | mutex_lock l(mu_); |
1281 | bool added; |
1282 | return AddFunctionDefHelper(fdef, stack_traces, &added); |
1283 | } |
1284 | |
1285 | Status FunctionLibraryDefinition::AddFunctionDefHelper( |
1286 | const FunctionDef& fdef, const StackTracesMap& stack_traces, bool* added) { |
1287 | *added = false; |
1288 | std::shared_ptr<FunctionDefAndOpRegistration>& entry = |
1289 | function_defs_[fdef.signature().name()]; |
1290 | if (entry) { |
1291 | if (!FunctionDefsEqual(entry->fdef, fdef)) { |
1292 | return errors::InvalidArgument( |
1293 | "Cannot add function '" , fdef.signature().name(), |
1294 | "' because a different function with the same name already " |
1295 | "exists." ); |
1296 | } |
1297 | // Ignore duplicate FunctionDefs. |
1298 | return OkStatus(); |
1299 | } |
1300 | const OpDef* op_def; |
1301 | if (default_registry_->LookUpOpDef(fdef.signature().name(), &op_def).ok()) { |
1302 | return errors::InvalidArgument( |
1303 | "Cannot add function '" , fdef.signature().name(), |
1304 | "' because an op with the same name already exists." ); |
1305 | } |
1306 | entry = std::make_shared<FunctionDefAndOpRegistration>(fdef, stack_traces); |
1307 | *added = true; |
1308 | return OkStatus(); |
1309 | } |
1310 | |
1311 | Status FunctionLibraryDefinition::AddHelper( |
1312 | std::shared_ptr<FunctionDefAndOpRegistration> registration, bool* added) { |
1313 | *added = false; |
1314 | std::shared_ptr<FunctionDefAndOpRegistration>& entry = |
1315 | function_defs_[registration->fdef.signature().name()]; |
1316 | if (entry) { |
1317 | if (!FunctionDefsEqual(entry->fdef, registration->fdef)) { |
1318 | return errors::InvalidArgument( |
1319 | "Cannot add function '" , registration->fdef.signature().name(), |
1320 | "' because a different function with the same name already " |
1321 | "exists." ); |
1322 | } |
1323 | // Ignore duplicate FunctionDefs. |
1324 | return OkStatus(); |
1325 | } |
1326 | const OpDef* op_def; |
1327 | if (default_registry_ |
1328 | ->LookUpOpDef(registration->fdef.signature().name(), &op_def) |
1329 | .ok()) { |
1330 | return errors::InvalidArgument( |
1331 | "Cannot add function '" , registration->fdef.signature().name(), |
1332 | "' because an op with the same name already exists." ); |
1333 | } |
1334 | entry = std::move(registration); |
1335 | *added = true; |
1336 | return OkStatus(); |
1337 | } |
1338 | |
1339 | Status FunctionLibraryDefinition::CopyFunctionDefFrom( |
1340 | const string& func, const FunctionLibraryDefinition& other) { |
1341 | if (default_registry_ != other.default_registry_) { |
1342 | return errors::InvalidArgument( |
1343 | "Cannot copy function '" , func, |
1344 | "' because CopyFunctionDefFrom() requires that both libraries have the " |
1345 | "same default registry." ); |
1346 | } |
1347 | std::shared_ptr<FunctionDefAndOpRegistration> function_def; |
1348 | { |
1349 | tf_shared_lock l(other.mu_); |
1350 | function_def = other.FindHelper(func); |
1351 | } |
1352 | if (!function_def) { |
1353 | return errors::InvalidArgument( |
1354 | "Cannot copy function '" , func, |
1355 | "' because no function with that name exists in the other library." ); |
1356 | } |
1357 | { |
1358 | mutex_lock l(mu_); |
1359 | std::shared_ptr<FunctionDefAndOpRegistration>& entry = function_defs_[func]; |
1360 | if (entry) { |
1361 | if (!FunctionDefsEqual(entry->fdef, function_def->fdef)) { |
1362 | return errors::InvalidArgument( |
1363 | "Cannot copy function '" , func, |
1364 | "' because a different function with the same name already " |
1365 | "exists." ); |
1366 | } |
1367 | } else { |
1368 | entry = std::move(function_def); |
1369 | } |
1370 | } |
1371 | return OkStatus(); |
1372 | } |
1373 | |
1374 | Status FunctionLibraryDefinition::AddGradientDef(const GradientDef& grad) { |
1375 | mutex_lock l(mu_); |
1376 | bool added; |
1377 | return AddGradientDefHelper(grad, &added); |
1378 | } |
1379 | |
1380 | Status FunctionLibraryDefinition::AddGradientDefHelper(const GradientDef& grad, |
1381 | bool* added) { |
1382 | *added = false; |
1383 | string* entry = &func_grad_[grad.function_name()]; |
1384 | if (!entry->empty()) { |
1385 | if (*entry != grad.gradient_func()) { |
1386 | return errors::InvalidArgument( |
1387 | "Cannot assign gradient function '" , grad.gradient_func(), "' to '" , |
1388 | grad.function_name(), "' because it already has gradient function " , |
1389 | "'" , *entry, "'" ); |
1390 | } |
1391 | // Ignore duplicate GradientDefs |
1392 | return OkStatus(); |
1393 | } |
1394 | *entry = grad.gradient_func(); |
1395 | *added = true; |
1396 | return OkStatus(); |
1397 | } |
1398 | |
1399 | Status FunctionLibraryDefinition::AddLibrary( |
1400 | const FunctionLibraryDefinition& other) { |
1401 | // Clone `other` to ensure thread-safety (grabbing `other`'s lock for |
1402 | // the duration of the function could lead to deadlock). |
1403 | FunctionLibraryDefinition clone(other); |
1404 | mutex_lock l(mu_); |
1405 | mutex_lock l2(clone.mu_); |
1406 | // Remember the funcs and grads that we added successfully so that |
1407 | // we can roll them back on error. |
1408 | std::vector<string> funcs; |
1409 | std::vector<string> funcs_with_grads; |
1410 | Status s; |
1411 | bool added; |
1412 | for (auto iter : clone.function_defs_) { |
1413 | s = AddHelper(iter.second, &added); |
1414 | if (!s.ok()) { |
1415 | Status remove_status = Remove(funcs, funcs_with_grads); |
1416 | if (!remove_status.ok()) { |
1417 | return remove_status; |
1418 | } |
1419 | return s; |
1420 | } |
1421 | if (added) { |
1422 | funcs.push_back(iter.second->fdef.signature().name()); |
1423 | } |
1424 | } |
1425 | for (auto iter : clone.func_grad_) { |
1426 | GradientDef grad; |
1427 | grad.set_function_name(iter.first); |
1428 | grad.set_gradient_func(iter.second); |
1429 | s = AddGradientDefHelper(grad, &added); |
1430 | if (!s.ok()) { |
1431 | Status remove_status = Remove(funcs, funcs_with_grads); |
1432 | if (!remove_status.ok()) { |
1433 | return remove_status; |
1434 | } |
1435 | return s; |
1436 | } |
1437 | if (added) { |
1438 | funcs_with_grads.push_back(grad.function_name()); |
1439 | } |
1440 | } |
1441 | return OkStatus(); |
1442 | } |
1443 | |
1444 | Status FunctionLibraryDefinition::AddLibrary( |
1445 | const FunctionDefLibrary& lib_def) { |
1446 | // Remember the funcs and grads that we added successfully so that |
1447 | // we can roll them back on error. |
1448 | mutex_lock l(mu_); |
1449 | std::vector<string> funcs; |
1450 | std::vector<string> funcs_with_grads; |
1451 | Status s; |
1452 | bool added; |
1453 | for (const FunctionDef& fdef : lib_def.function()) { |
1454 | s = AddFunctionDefHelper(fdef, /*stack_traces=*/{}, &added); |
1455 | if (!s.ok()) { |
1456 | Status remove_status = Remove(funcs, funcs_with_grads); |
1457 | if (!remove_status.ok()) { |
1458 | return remove_status; |
1459 | } |
1460 | return s; |
1461 | } |
1462 | if (added) { |
1463 | funcs.push_back(fdef.signature().name()); |
1464 | } |
1465 | } |
1466 | for (const GradientDef& grad : lib_def.gradient()) { |
1467 | s = AddGradientDefHelper(grad, &added); |
1468 | if (!s.ok()) { |
1469 | Status remove_status = Remove(funcs, funcs_with_grads); |
1470 | if (!remove_status.ok()) { |
1471 | return remove_status; |
1472 | } |
1473 | return s; |
1474 | } |
1475 | if (added) { |
1476 | funcs_with_grads.push_back(grad.function_name()); |
1477 | } |
1478 | } |
1479 | return OkStatus(); |
1480 | } |
1481 | |
1482 | Status FunctionLibraryDefinition::ReplaceFunction( |
1483 | const string& func, const FunctionDef& fdef, |
1484 | const StackTracesMap& stack_traces) { |
1485 | mutex_lock l(mu_); |
1486 | bool added; |
1487 | TF_RETURN_IF_ERROR(RemoveFunctionHelper(func)); |
1488 | TF_RETURN_IF_ERROR(AddFunctionDefHelper(fdef, stack_traces, &added)); |
1489 | return OkStatus(); |
1490 | } |
1491 | |
1492 | Status FunctionLibraryDefinition::ReplaceGradient(const GradientDef& grad) { |
1493 | mutex_lock l(mu_); |
1494 | bool added; |
1495 | TF_RETURN_IF_ERROR(RemoveGradient(grad.function_name())); |
1496 | TF_RETURN_IF_ERROR(AddGradientDefHelper(grad, &added)); |
1497 | return OkStatus(); |
1498 | } |
1499 | |
1500 | Status FunctionLibraryDefinition::RemoveFunction(const string& func) { |
1501 | mutex_lock l(mu_); |
1502 | TF_RETURN_IF_ERROR(RemoveFunctionHelper(func)); |
1503 | return OkStatus(); |
1504 | } |
1505 | |
1506 | Status FunctionLibraryDefinition::RemoveFunctionHelper(const string& func) { |
1507 | const auto& i = function_defs_.find(func); |
1508 | if (i == function_defs_.end()) { |
1509 | return errors::InvalidArgument("Tried to remove non-existent function '" , |
1510 | func, "'." ); |
1511 | } |
1512 | function_defs_.erase(i); |
1513 | return OkStatus(); |
1514 | } |
1515 | |
1516 | void FunctionLibraryDefinition::Clear() { |
1517 | mutex_lock l(mu_); |
1518 | function_defs_.clear(); |
1519 | func_grad_.clear(); |
1520 | } |
1521 | |
1522 | Status FunctionLibraryDefinition::RemoveGradient(const string& func) { |
1523 | const auto& i = func_grad_.find(func); |
1524 | if (i == func_grad_.end()) { |
1525 | return errors::InvalidArgument("Tried to remove non-existent gradient '" , |
1526 | func, "'." ); |
1527 | } |
1528 | func_grad_.erase(i); |
1529 | return OkStatus(); |
1530 | } |
1531 | |
1532 | Status FunctionLibraryDefinition::Remove( |
1533 | const std::vector<string>& funcs, |
1534 | const std::vector<string>& funcs_with_grads) { |
1535 | Status s; |
1536 | for (const string& f : funcs) { |
1537 | s = RemoveFunctionHelper(f); |
1538 | if (!s.ok()) { |
1539 | return s; |
1540 | } |
1541 | } |
1542 | for (const string& f : funcs_with_grads) { |
1543 | s = RemoveGradient(f); |
1544 | if (!s.ok()) { |
1545 | return s; |
1546 | } |
1547 | } |
1548 | return OkStatus(); |
1549 | } |
1550 | |
1551 | string FunctionLibraryDefinition::FindGradient(const string& func) const { |
1552 | tf_shared_lock l(mu_); |
1553 | return gtl::FindWithDefault(func_grad_, func, "" ); |
1554 | } |
1555 | |
1556 | string FunctionLibraryDefinition::FindGradientHelper(const string& func) const { |
1557 | return gtl::FindWithDefault(func_grad_, func, "" ); |
1558 | } |
1559 | |
1560 | Status FunctionLibraryDefinition::LookUp( |
1561 | const string& op, const OpRegistrationData** op_reg_data) const { |
1562 | tf_shared_lock l(mu_); |
1563 | auto iter = function_defs_.find(op); |
1564 | if (iter != function_defs_.end()) { |
1565 | *op_reg_data = &iter->second->op_registration_data; |
1566 | return OkStatus(); |
1567 | } |
1568 | return default_registry_->LookUp(op, op_reg_data); |
1569 | } |
1570 | |
1571 | string FunctionLibraryDefinition::UniqueFunctionName(StringPiece prefix) const { |
1572 | tf_shared_lock l(mu_); |
1573 | int index = 0; |
1574 | string name = strings::StrCat(prefix, index); |
1575 | while (function_defs_.find(name) != function_defs_.end()) { |
1576 | ++index; |
1577 | name = strings::StrCat(prefix, index); |
1578 | } |
1579 | return name; |
1580 | } |
1581 | |
1582 | const FunctionDef* FunctionLibraryDefinition::GetAttrImpl( |
1583 | const NodeDef& ndef) const { |
1584 | if (ndef.op() != kGradientOp) { |
1585 | // If 'ndef' calls a function and the function's def has the attr, |
1586 | // returns it. |
1587 | return Find(ndef.op()); |
1588 | } |
1589 | |
1590 | // If ndef is SymbolicGradient[f=Foo], we use Foo's gradient or |
1591 | // Foo's attributes. |
1592 | const NameAttrList* forward_func_attrs; |
1593 | if (!TryGetNodeAttr(ndef, kFuncAttr, &forward_func_attrs)) { |
1594 | return nullptr; |
1595 | } |
1596 | const string& func_name = forward_func_attrs->name(); |
1597 | { |
1598 | tf_shared_lock l(mu_); |
1599 | const string& grad_name = FindGradientHelper(func_name); |
1600 | // If 'func' has a user-defined gradient function, uses the grad |
1601 | // function's attrs to see if noinline is specified. Otherwise, |
1602 | // uses func's attrs. |
1603 | if (!grad_name.empty()) { |
1604 | if (const auto helper = FindHelper(grad_name)) { |
1605 | return &(helper->fdef); |
1606 | } else { |
1607 | return nullptr; |
1608 | } |
1609 | } |
1610 | if (const auto helper = FindHelper(func_name)) { |
1611 | return &(helper->fdef); |
1612 | } else { |
1613 | return nullptr; |
1614 | } |
1615 | } |
1616 | } |
1617 | |
1618 | std::vector<string> FunctionLibraryDefinition::ListFunctionNames() const { |
1619 | std::vector<string> function_names; |
1620 | tf_shared_lock l(mu_); |
1621 | function_names.reserve(function_defs_.size()); |
1622 | for (const auto& it : function_defs_) { |
1623 | function_names.emplace_back(it.first); |
1624 | } |
1625 | return function_names; |
1626 | } |
1627 | |
1628 | FunctionDefLibrary FunctionLibraryDefinition::ToProto() const { |
1629 | FunctionDefLibrary lib; |
1630 | tf_shared_lock l(mu_); |
1631 | for (const auto& f : function_defs_) { |
1632 | *lib.add_function() = f.second->fdef; |
1633 | } |
1634 | for (const auto& g : func_grad_) { |
1635 | GradientDef* gd = lib.add_gradient(); |
1636 | gd->set_function_name(g.first); |
1637 | gd->set_gradient_func(g.second); |
1638 | } |
1639 | return lib; |
1640 | } |
1641 | |
1642 | template <typename T> |
1643 | Status FunctionLibraryDefinition::GetAttr(const NodeDef& ndef, |
1644 | const string& attr, T* value) const { |
1645 | const FunctionDef* fdef = GetAttrImpl(ndef); |
1646 | if (fdef && TryGetNodeAttr(AttrSlice(&fdef->attr()), attr, value)) { |
1647 | return OkStatus(); |
1648 | } |
1649 | return errors::InvalidArgument("Attr " , attr, " is not defined." ); |
1650 | } |
1651 | |
1652 | template <typename T> |
1653 | Status FunctionLibraryDefinition::GetAttr(const Node& node, const string& attr, |
1654 | T* value) const { |
1655 | return GetAttr(node.def(), attr, value); |
1656 | } |
1657 | |
1658 | #define GET_ATTR(T) \ |
1659 | template Status FunctionLibraryDefinition::GetAttr(const Node&, \ |
1660 | const string&, T*) const; \ |
1661 | template Status FunctionLibraryDefinition::GetAttr(const NodeDef&, \ |
1662 | const string&, T*) const; |
1663 | GET_ATTR(string) |
1664 | GET_ATTR(bool) |
1665 | #undef GET_ATTR |
1666 | |
1667 | namespace { |
1668 | |
1669 | constexpr char kApiImplements[] = "api_implements" ; |
1670 | |
1671 | std::set<string> ReachableFunctions( |
1672 | const FunctionLibraryDefinition& flib, |
1673 | const protobuf::RepeatedPtrField<NodeDef>& nodes) { |
1674 | // Functions that are reachable from the graph. |
1675 | std::set<string> reachable_funcs; |
1676 | |
1677 | // For any functions, if it has attribute "api_implements" = |
1678 | // "some_interface" and it is reachable, then it means any other |
1679 | // function with same attribute name and value could also be potentially |
1680 | // reachable, eg via implementation_selector swapping the nodedef. |
1681 | absl::flat_hash_set<string> reachable_api_interface; |
1682 | |
1683 | // Functions might be reachable from the nested function calls, so we keep a |
1684 | // queue of functions that we have to check. |
1685 | gtl::InlinedVector<const FunctionDef*, 4> func_queue; |
1686 | |
1687 | // Add reachable and not already processed functions to the functions queue. |
1688 | const auto add_to_func_queue = [&](const string& func_name) { |
1689 | const FunctionDef* func = flib.Find(func_name); |
1690 | if (func && reachable_funcs.find(func_name) == reachable_funcs.end()) { |
1691 | func_queue.push_back(func); |
1692 | } |
1693 | }; |
1694 | |
1695 | // If any function with certain API name is reachable, all the other functions |
1696 | // with same API name should also be checked. |
1697 | const auto add_function_with_api_interface = [&](const string& api_name) { |
1698 | if (!reachable_api_interface.contains(api_name)) { |
1699 | reachable_api_interface.insert(api_name); |
1700 | for (const auto& func_name : flib.ListFunctionNames()) { |
1701 | const auto& func_def = flib.Find(func_name); |
1702 | const auto attr_it = func_def->attr().find(kApiImplements); |
1703 | if (attr_it != func_def->attr().end() && |
1704 | attr_it->second.s() == api_name) { |
1705 | add_to_func_queue(func_name); |
1706 | } |
1707 | } |
1708 | } |
1709 | }; |
1710 | |
1711 | // Add all the functions that are reachable from the given node to the queue. |
1712 | const auto process_node = [&](const NodeDef& node) { |
1713 | // Node itself can be a call to the function. |
1714 | add_to_func_queue(node.op()); |
1715 | |
1716 | // Or node can have an attribute referencing a function. |
1717 | for (const auto& attr : node.attr()) { |
1718 | const auto& attr_value = attr.second; |
1719 | |
1720 | // 1. AttrValue.func |
1721 | if (attr_value.has_func()) { |
1722 | add_to_func_queue(attr_value.func().name()); |
1723 | } |
1724 | |
1725 | // 2. AttrValue.ListValue.func |
1726 | if (attr_value.has_list()) { |
1727 | for (const auto& func : attr_value.list().func()) { |
1728 | add_to_func_queue(func.name()); |
1729 | } |
1730 | } |
1731 | } |
1732 | }; |
1733 | |
1734 | // Add all functions that are directly called from the optimized graph. |
1735 | std::for_each(nodes.begin(), nodes.end(), process_node); |
1736 | |
1737 | // Process all reachable functions. |
1738 | while (!func_queue.empty()) { |
1739 | const FunctionDef* func = func_queue.back(); |
1740 | func_queue.pop_back(); |
1741 | |
1742 | const string& func_name = func->signature().name(); |
1743 | reachable_funcs.insert(func_name); |
1744 | |
1745 | const auto attr_it = func->attr().find(kApiImplements); |
1746 | if (attr_it != func->attr().end()) { |
1747 | add_function_with_api_interface(attr_it->second.s()); |
1748 | } |
1749 | |
1750 | // Find all the functions called from the function body. |
1751 | const auto& func_body = func->node_def(); |
1752 | std::for_each(func_body.begin(), func_body.end(), process_node); |
1753 | |
1754 | // Check if the function has a registered gradient. |
1755 | const string grad_func_name = flib.FindGradient(func_name); |
1756 | if (!grad_func_name.empty()) add_to_func_queue(grad_func_name); |
1757 | } |
1758 | |
1759 | return reachable_funcs; |
1760 | } |
1761 | |
1762 | FunctionLibraryDefinition ReachableFunctionLibraryDefinition( |
1763 | const FunctionLibraryDefinition& flib, |
1764 | const protobuf::RepeatedPtrField<NodeDef>& nodes) { |
1765 | std::set<string> reachable_funcs = ReachableFunctions(flib, nodes); |
1766 | |
1767 | FunctionLibraryDefinition reachable_flib(flib.default_registry(), |
1768 | FunctionDefLibrary()); |
1769 | |
1770 | for (const string& func_name : reachable_funcs) { |
1771 | // This should never fail, because we copy functions from a valid flib and |
1772 | // use the same default registry. |
1773 | Status added = reachable_flib.CopyFunctionDefFrom(func_name, flib); |
1774 | TF_DCHECK_OK(added); |
1775 | |
1776 | const string grad_func_name = flib.FindGradient(func_name); |
1777 | if (!grad_func_name.empty()) { |
1778 | GradientDef grad; |
1779 | grad.set_function_name(func_name); |
1780 | grad.set_gradient_func(grad_func_name); |
1781 | // It can only fail if function already has a gradient function. |
1782 | const Status added_grad = reachable_flib.AddGradientDef(grad); |
1783 | TF_DCHECK_OK(added_grad); |
1784 | } |
1785 | } |
1786 | |
1787 | return reachable_flib; |
1788 | } |
1789 | |
1790 | string AllocatorAttributesToString( |
1791 | const std::vector<AllocatorAttributes>& attrs) { |
1792 | string result("[" ); |
1793 | // AllocatorAttribute::DebugString produces around 85 bytes now. |
1794 | result.reserve(100 * attrs.size()); |
1795 | for (const AllocatorAttributes& attr : attrs) { |
1796 | result.append(attr.DebugString()); |
1797 | result.append(", " ); |
1798 | } |
1799 | if (!attrs.empty()) { |
1800 | result.resize(result.size() - 2); |
1801 | } |
1802 | result.append("]" ); |
1803 | return result; |
1804 | } |
1805 | |
1806 | const char* IsSet(void* ptr) { return ptr == nullptr ? "unset" : "set" ; } |
1807 | |
1808 | } // namespace |
1809 | |
1810 | FunctionLibraryDefinition FunctionLibraryDefinition::ReachableDefinitions( |
1811 | const GraphDef& graph) const { |
1812 | return ReachableFunctionLibraryDefinition(*this, graph.node()); |
1813 | } |
1814 | |
1815 | FunctionLibraryDefinition FunctionLibraryDefinition::ReachableDefinitions( |
1816 | const FunctionDef& func) const { |
1817 | return ReachableFunctionLibraryDefinition(*this, func.node_def()); |
1818 | } |
1819 | |
1820 | string FunctionLibraryRuntime::Options::DebugString() const { |
1821 | return absl::StrCat( |
1822 | "FLR::Options(step_id=" , step_id, " rendezvous=" , IsSet(rendezvous), |
1823 | " cancellation_manager=" , IsSet(cancellation_manager), |
1824 | " collective_executor=" , IsSet(collective_executor), |
1825 | " step_container=" , IsSet(step_container), |
1826 | " stats_collector=" , IsSet(stats_collector), " runner=" , IsSet(runner), |
1827 | " remote_execution=" , remote_execution, " source_device=" , source_device, |
1828 | " create_rendezvous=" , create_rendezvous, |
1829 | " allow_dead_tensors=" , allow_dead_tensors, |
1830 | " args_alloc_attrs=" , AllocatorAttributesToString(args_alloc_attrs), |
1831 | " rets_alloc_attrs=" , AllocatorAttributesToString(rets_alloc_attrs), ")" ); |
1832 | } |
1833 | |
1834 | void FunctionDefHelper::AttrValueWrapper::InitFromString(StringPiece val) { |
1835 | if (val.size() >= 2 && val[0] == '$') { |
1836 | proto.set_placeholder(val.data() + 1, val.size() - 1); |
1837 | } else { |
1838 | SetAttrValue(val, &proto); |
1839 | } |
1840 | } |
1841 | |
1842 | FunctionDefHelper::AttrValueWrapper FunctionDefHelper::FunctionRef( |
1843 | const string& name, |
1844 | gtl::ArraySlice<std::pair<string, AttrValueWrapper>> attrs) { |
1845 | AttrValueWrapper ret; |
1846 | ret.proto.mutable_func()->set_name(name); |
1847 | for (const auto& a : attrs) { |
1848 | ret.proto.mutable_func()->mutable_attr()->insert({a.first, a.second.proto}); |
1849 | } |
1850 | return ret; |
1851 | } |
1852 | |
1853 | NodeDef FunctionDefHelper::Node::ToNodeDef() const { |
1854 | NodeDef n; |
1855 | n.set_op(this->op); |
1856 | n.set_name(GetName()); |
1857 | for (const auto& a : this->attr) { |
1858 | n.mutable_attr()->insert({a.first, a.second.proto}); |
1859 | } |
1860 | for (const string& a : this->arg) { |
1861 | n.add_input(a); |
1862 | } |
1863 | for (const string& d : this->dep) { |
1864 | n.add_input(strings::StrCat("^" , d)); |
1865 | } |
1866 | if (!this->device.empty()) { |
1867 | n.set_device(this->device); |
1868 | } |
1869 | if (!this->original_node_names.empty()) { |
1870 | *n.mutable_experimental_debug_info()->mutable_original_node_names() = { |
1871 | this->original_node_names.begin(), this->original_node_names.end()}; |
1872 | } |
1873 | if (!this->original_func_names.empty()) { |
1874 | *n.mutable_experimental_debug_info()->mutable_original_func_names() = { |
1875 | this->original_func_names.begin(), this->original_func_names.end()}; |
1876 | } |
1877 | return n; |
1878 | } |
1879 | |
1880 | /* static */ |
1881 | FunctionDef FunctionDefHelper::Create( |
1882 | const string& function_name, gtl::ArraySlice<string> in_def, |
1883 | gtl::ArraySlice<string> out_def, gtl::ArraySlice<string> attr_def, |
1884 | gtl::ArraySlice<Node> node_def, |
1885 | gtl::ArraySlice<std::pair<string, string>> ret_def, |
1886 | gtl::ArraySlice<std::pair<string, string>> control_ret_def) { |
1887 | FunctionDef fdef; |
1888 | |
1889 | // Signature |
1890 | OpDefBuilder b(function_name); |
1891 | for (const auto& i : in_def) b.Input(i); |
1892 | for (const auto& o : out_def) b.Output(o); |
1893 | for (const auto& a : attr_def) b.Attr(a); |
1894 | for (const auto& c : control_ret_def) b.ControlOutput(c.first); |
1895 | |
1896 | OpRegistrationData op_reg_data; |
1897 | TF_CHECK_OK(b.Finalize(&op_reg_data)); |
1898 | fdef.mutable_signature()->Swap(&op_reg_data.op_def); |
1899 | |
1900 | // Function body |
1901 | for (const auto& n : node_def) { |
1902 | *(fdef.add_node_def()) = n.ToNodeDef(); |
1903 | } |
1904 | |
1905 | // Returns |
1906 | for (const auto& r : ret_def) { |
1907 | fdef.mutable_ret()->insert({r.first, r.second}); |
1908 | } |
1909 | |
1910 | // Control returns |
1911 | for (const auto& cr : control_ret_def) { |
1912 | fdef.mutable_control_ret()->insert({cr.first, cr.second}); |
1913 | } |
1914 | |
1915 | auto* op_def_registry = OpRegistry::Global(); |
1916 | // Check if any op is stateful. |
1917 | for (const auto& n : node_def) { |
1918 | const OpDef* op_def = nullptr; |
1919 | auto status = op_def_registry->LookUpOpDef(n.op, &op_def); |
1920 | // Lookup can fail if e.g. we are calling a function that was not yet |
1921 | // defined. If it happens, conservatively assume the op is stateful. |
1922 | if (!status.ok() || op_def->is_stateful()) { |
1923 | fdef.mutable_signature()->set_is_stateful(true); |
1924 | } |
1925 | } |
1926 | |
1927 | return fdef; |
1928 | } |
1929 | |
1930 | /* static */ |
1931 | FunctionDef FunctionDefHelper::Create( |
1932 | const string& function_name, gtl::ArraySlice<string> in_def, |
1933 | gtl::ArraySlice<string> out_def, gtl::ArraySlice<string> attr_def, |
1934 | gtl::ArraySlice<Node> node_def, |
1935 | gtl::ArraySlice<std::pair<string, string>> ret_def) { |
1936 | return Create(function_name, in_def, out_def, attr_def, node_def, ret_def, |
1937 | /*control_ret_def=*/{}); |
1938 | } |
1939 | |
1940 | /* static */ |
1941 | FunctionDef FunctionDefHelper::Define(const string& name, |
1942 | gtl::ArraySlice<string> arg_def, |
1943 | gtl::ArraySlice<string> ret_def, |
1944 | gtl::ArraySlice<string> attr_def, |
1945 | gtl::ArraySlice<Node> node_def) { |
1946 | FunctionDef fdef; |
1947 | OpDefBuilder b(name); |
1948 | for (const auto& a : arg_def) b.Input(a); |
1949 | for (const auto& r : ret_def) b.Output(r); |
1950 | for (const auto& a : attr_def) b.Attr(a); |
1951 | |
1952 | OpRegistrationData op_reg_data; |
1953 | TF_CHECK_OK(b.Finalize(&op_reg_data)); |
1954 | fdef.mutable_signature()->Swap(&op_reg_data.op_def); |
1955 | |
1956 | // Mapping from legacy output names to NodeDef outputs. |
1957 | std::unordered_map<string, string> ret_index; |
1958 | for (const auto& a : fdef.signature().input_arg()) { |
1959 | ret_index[a.name()] = a.name(); |
1960 | } |
1961 | |
1962 | // For looking up OpDefs |
1963 | auto* op_def_registry = OpRegistry::Global(); |
1964 | |
1965 | // Function body |
1966 | for (const auto& src : node_def) { |
1967 | NodeDef* n = fdef.add_node_def(); |
1968 | n->set_op(src.op); |
1969 | n->set_name(src.GetName()); |
1970 | for (const auto& a : src.attr) { |
1971 | n->mutable_attr()->insert({a.first, a.second.proto}); |
1972 | } |
1973 | for (const string& a : src.arg) { |
1974 | const auto iter = ret_index.find(a); |
1975 | CHECK(iter != ret_index.end()) |
1976 | << "Node input '" << a << "' in '" << n->name() << "' of " << name; |
1977 | n->add_input(iter->second); |
1978 | } |
1979 | for (const string& d : src.dep) { |
1980 | n->add_input(strings::StrCat("^" , d)); |
1981 | } |
1982 | |
1983 | // Add the outputs of this node to ret_index. |
1984 | const OpDef* op_def = nullptr; |
1985 | TF_CHECK_OK(op_def_registry->LookUpOpDef(n->op(), &op_def)) << n->op(); |
1986 | CHECK(op_def != nullptr) << n->op(); |
1987 | NameRangeMap output_names; |
1988 | TF_CHECK_OK(NameRangesForNode(*n, *op_def, nullptr, &output_names)); |
1989 | for (const auto& o : output_names) { |
1990 | CHECK_LE(o.second.second, src.ret.size()) |
1991 | << "Missing ret for output '" << o.first << "' in '" << n->name() |
1992 | << "' of " << name; |
1993 | for (int i = o.second.first; i < o.second.second; ++i) { |
1994 | ret_index[src.ret[i]] = |
1995 | strings::StrCat(n->name(), ":" , o.first, ":" , i - o.second.first); |
1996 | } |
1997 | } |
1998 | if (op_def->is_stateful()) fdef.mutable_signature()->set_is_stateful(true); |
1999 | } |
2000 | |
2001 | // Returns |
2002 | for (const auto& r : fdef.signature().output_arg()) { |
2003 | const auto iter = ret_index.find(r.name()); |
2004 | CHECK(iter != ret_index.end()) << "Return '" << r.name() << "' in " << name; |
2005 | fdef.mutable_ret()->insert({r.name(), iter->second}); |
2006 | } |
2007 | return fdef; |
2008 | } |
2009 | |
2010 | FunctionDef FunctionDefHelper::Define(gtl::ArraySlice<string> arg_def, |
2011 | gtl::ArraySlice<string> ret_def, |
2012 | gtl::ArraySlice<string> attr_def, |
2013 | gtl::ArraySlice<Node> node_def) { |
2014 | return Define("_" , arg_def, ret_def, attr_def, node_def); |
2015 | } |
2016 | |
2017 | namespace gradient { |
2018 | |
2019 | typedef std::unordered_map<string, Creator> OpGradFactory; |
2020 | |
2021 | OpGradFactory* GetOpGradFactory() { |
2022 | static OpGradFactory* factory = new OpGradFactory; |
2023 | return factory; |
2024 | } |
2025 | |
2026 | bool RegisterOp(const string& op, Creator func) { |
2027 | CHECK(GetOpGradFactory()->insert({op, func}).second) |
2028 | << "Duplicated gradient for " << op; |
2029 | return true; |
2030 | } |
2031 | |
2032 | Status GetOpGradientCreator(const string& op, Creator* creator) { |
2033 | auto fac = GetOpGradFactory(); |
2034 | auto iter = fac->find(op); |
2035 | if (iter == fac->end()) { |
2036 | return errors::NotFound("No gradient defined for op: " , op); |
2037 | } |
2038 | *creator = iter->second; |
2039 | return OkStatus(); |
2040 | } |
2041 | |
2042 | } // end namespace gradient |
2043 | |
2044 | } // namespace tensorflow |
2045 | |