1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_
17#define TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_
18
#include <initializer_list>
#include <map>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/hash.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/padding.h"
40
41namespace tensorflow {
42
43class AttrSlice;
44// We forward declare protos so that kernels don't need to depend on them
45class OpDef;
46class AttrValue;
47class NameAttrList;
48class TensorProto;
49class TensorShapeProto;
50
51// Name of the attribute used to encode node colocation constraints.
52//
53// Nodes can be co-located on the same device. Desire for explicit co-location
54// is described by list(string) attribute containing the name of colocation
55// groups.
56extern const char* const kColocationAttrName;
57
58// String prefix applied to the operation name for colocation constraints.
59extern const char* const kColocationGroupPrefix;
60
61// Constants for host CPU staging op for TPUExecute.
62extern const char* const kTpuExecuteStagingOp;
63extern const char* const kTpuExecuteStagingNodeName;
64
65// Produce a human-readable version of a Node or NodeDef that is more concise
66// than a text-format proto.
67//
68// The parameter `max_inputs_in_summary` specifies how many inputs at most to
69// serialize in the output (in order not to get a string which is overly large).
70// The value `-1` specifies that all inputs will be shown.
71std::string SummarizeNodeDef(const NodeDef& node_def,
72 int max_inputs_in_summary = -1);
73std::string SummarizeAttrs(const NodeDef& node_def);
74std::string SummarizeAttrsHelper(AttrSlice attrs, StringPiece device);
75
76// Produces a formatted string pattern from the node which can uniquely identify
77// this node upstream to produce an informative error message. The pattern
78// followed is: {{node <node_name>}}
79std::string FormatNodeDefForError(const NodeDef& node_def);
80std::string FormatNodeDefForError(
81 StringPiece node_name, bool has_experimental_debug_info,
82 const NodeDef_ExperimentalDebugInfo& experimental_debug_info);
83
84typedef protobuf::Map<string, AttrValue> AttrValueMap;
85
86// Adds an attr with name <name> and value <value> to *node_def.
87// The type of the attr is based on the type of value.
88void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def);
89void AddNodeAttr(StringPiece name, AttrValue&& value, NodeDef* node_def);
90void AddNodeAttr(StringPiece name, StringPiece value, NodeDef* node_def);
91void AddNodeAttr(StringPiece name, const char* value, NodeDef* node_def);
92void AddNodeAttr(StringPiece name, int32_t value, NodeDef* node_def);
93void AddNodeAttr(StringPiece name, int64_t value, NodeDef* node_def);
94void AddNodeAttr(StringPiece name, float value, NodeDef* node_def);
95void AddNodeAttr(StringPiece name, double value, NodeDef* node_def);
96void AddNodeAttr(StringPiece name, bool value, NodeDef* node_def);
97void AddNodeAttr(StringPiece name, DataType value, NodeDef* node_def);
98void AddNodeAttr(StringPiece name, const PartialTensorShape& value,
99 NodeDef* node_def);
100void AddNodeAttr(StringPiece name, const Tensor& value, NodeDef* node_def);
101void AddNodeAttr(StringPiece name, const TensorProto& value, NodeDef* node_def);
102void AddNodeAttr(StringPiece name, const NameAttrList& value,
103 NodeDef* node_def);
104void AddNodeAttr(StringPiece name, gtl::ArraySlice<StringPiece> value,
105 NodeDef* node_def);
106void AddNodeAttr(StringPiece name, gtl::ArraySlice<const char*> value,
107 NodeDef* node_def);
108void AddNodeAttr(StringPiece name, gtl::ArraySlice<string> value,
109 NodeDef* node_def);
110void AddNodeAttr(StringPiece name, gtl::ArraySlice<int32> value,
111 NodeDef* node_def);
112void AddNodeAttr(StringPiece name, gtl::ArraySlice<int64_t> value,
113 NodeDef* node_def);
114void AddNodeAttr(StringPiece name, gtl::ArraySlice<float> value,
115 NodeDef* node_def);
116void AddNodeAttr(StringPiece name, gtl::ArraySlice<bool> value,
117 NodeDef* node_def);
118void AddNodeAttr(StringPiece name, const std::vector<bool>& value,
119 NodeDef* node_def);
120void AddNodeAttr(StringPiece name, gtl::ArraySlice<DataType> value,
121 NodeDef* node_def);
122void AddNodeAttr(StringPiece name, gtl::ArraySlice<TensorShape> value,
123 NodeDef* node_def);
124void AddNodeAttr(StringPiece name, gtl::ArraySlice<PartialTensorShape> value,
125 NodeDef* node_def);
126void AddNodeAttr(StringPiece name, gtl::ArraySlice<TensorShapeProto> value,
127 NodeDef* node_def);
128void AddNodeAttr(StringPiece name, gtl::ArraySlice<Tensor> value,
129 NodeDef* node_def);
130void AddNodeAttr(StringPiece name, gtl::ArraySlice<NameAttrList> value,
131 NodeDef* node_def);
132
133// Version to workaround C++'s "perfect" forwarding not being able to
134// forward {...} initialization.
135template <class T>
136void AddNodeAttr(StringPiece name, std::initializer_list<T> value,
137 NodeDef* node_def) {
138 AddNodeAttr(name, gtl::ArraySlice<T>(value), node_def);
139}
140
141// Adds an attr to an attr value map.
142void AddAttr(StringPiece name, const AttrValue& value, AttrValueMap* map);
143void AddAttr(StringPiece name, bool value, AttrValueMap* map);
144
145class AttrSlice {
146 public:
147 AttrSlice(const NodeDef& node_def); // NOLINT(runtime/explicit)
148
149 AttrSlice(); // Empty
150 explicit AttrSlice(const AttrValueMap* a);
151
152 int size() const { return attrs()->size(); }
153
154 // Returns the attr with attr_name if found. Otherwise, returns
155 // nullptr.
156 const AttrValue* Find(StringPiece attr_name) const;
157 const AttrValue* FindByString(const std::string& attr_name) const;
158
159 // Returns the attr_value for attr_name if found. Otherwise, returns a
160 // NotFound status.
161 Status Find(StringPiece attr_name, const AttrValue** attr_value) const;
162 Status FindByString(const std::string& attr_name,
163 const AttrValue** attr_value) const;
164
165 // Helper class to avoid allocations in EqualAttrs.
166 // TODO(irving): Will go away once NodeInfo is used.
167 struct Scratch {
168 std::string a;
169 std::string b;
170 };
171
172 // Check if all attrs and attr values match. Does not take defaults into
173 // account.
174 //
175 // TODO(irving): There is a bug in this routine inherited from its
176 // OptimizerCSE::EqualAttrs predecessor. The same tensor attr can be
177 // represented in more than one way as an AttrValue, since TensorProto is
178 // not 1-1. This bug will go away once I replace everything with NodeInfo,
179 // which stores a Tensor object directly. The Scratch object will also go
180 // away.
181 bool EqualAttrs(AttrSlice other, Scratch* scratch) const;
182
183 // If this AttrSlice has an attached NodeDef, summarize it. This is for
184 // error messages only: we intentionally do not provide direct access to the
185 // NodeDef, since it is not always there.
186 std::string SummarizeNode() const;
187
188 // Iteration over all attrs
189 AttrValueMap::const_iterator begin() const { return attrs()->begin(); }
190 AttrValueMap::const_iterator end() const { return attrs()->end(); }
191
192 std::string DebugString() const;
193
194 private:
195 const AttrValueMap* attrs() const {
196 return ndef_ != nullptr ? &ndef_->attr() : attrs_;
197 }
198
199 Status CheckFind(StringPiece attr_name, const AttrValue* attr_value) const;
200
201 const NodeDef* ndef_;
202 const AttrValueMap* attrs_;
203};
204
205// Return true if the attr with the name attr_name is defined in node_def.
206bool HasNodeAttr(const NodeDef& node_def, StringPiece attr_name);
207
208// Look up the attr with name attr_name and set *value to its value. If no
209// attr with attr_name is found in node_def, or the attr does not have
210// a matching type, a non-ok status will be returned.
211Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
212 std::string* value); // type: "string"
213Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
214 tstring* value); // type: "tstring"
215Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
216 int64_t* value); // type: "int"
217Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
218 int32* value); // type: "int"
219Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
220 float* value); // type: "float"
221Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
222 bool* value); // type: "bool"
223Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
224 DataType* value); // type: "type"
225Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
226 TensorShapeProto* value); // type: "shape"
227Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
228 TensorShape* value); // type: "shape"
229Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
230 PartialTensorShape* value); // type: "shape"
231Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
232 Tensor* value); // type: "tensor"
233Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
234 std::vector<string>* value); // type "list(string)"
235Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
236 std::vector<tstring>* value); // type "list(tstring)"
237Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
238 std::vector<int64_t>* value); // type "list(int)"
239Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
240 std::vector<int32>* value); // type "list(int)"
241Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
242 std::vector<float>* value); // type "list(float)"
243Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
244 std::vector<bool>* value); // type "list(bool)"
245Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
246 std::vector<DataType>* value); // type "list(type)"
247Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
248 DataTypeVector* value); // type "list(type)"
249Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
250 std::vector<TensorShapeProto>* value); // type "list(shape)"
251Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
252 std::vector<TensorShape>* value); // type "list(shape)"
253Status GetNodeAttr(
254 const AttrSlice& attrs, StringPiece attr_name,
255 std::vector<PartialTensorShape>* value); // type "list(shape)"
256Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
257 std::vector<Tensor>* value); // type: "list(tensor)"
258
259template <typename T>
260StatusOr<T> GetNodeAttr(const NodeDef& ndef, absl::string_view attr_name) {
261 T val;
262 TF_RETURN_IF_ERROR(GetNodeAttr(ndef, attr_name, &val));
263 return val;
264}
265
266// This version avoids copying the TensorProto.
267// REQUIRES: Must not use *value beyond the lifetime of node_def.
268Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
269 const TensorProto** value); // type: "tensor"
270bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
271 const TensorProto** value); // type: "tensor"
272
273// This version avoids copying the NameAttrList.
274// REQUIRES: Must not use *value beyond the lifetime of node_def.
275Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
276 const NameAttrList** value); // type: "func"
277bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
278 const NameAttrList** value); // type: "func"
279
280// These versions copies the NameAttrList(s).
281Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
282 NameAttrList* value); // type: "func"
283Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
284 std::vector<NameAttrList>* value); // type: "list(func)"
285
286// Look up the attr with name attr_name and set *value to its value. If no
287// attr with attr_name is found in node_def, or the attr does not have
288// a matching type, false is returned.
289bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
290 std::string* value); // type: "string"
291bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
292 int64_t* value); // type: "int"
293bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
294 std::vector<int64_t>* value); // type: "int"
295bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
296 int32* value); // type: "int"
297bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
298 float* value); // type: "float"
299bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
300 bool* value); // type: "bool"
301bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
302 DataType* value); // type: "type"
303bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
304 TensorShape* value); // type: "shape"
305
306bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
307 std::vector<string>* value); // type: "list(string)"
308bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
309 std::vector<tstring>* value); // type: "list(tstring)"
310bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
311 std::vector<int32>* value); // type: "list(int)"
312bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
313 std::vector<float>* value); // type: "list(float)"
314bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
315 std::vector<bool>* value); // type: "list(bool)"
316bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
317 std::vector<DataType>* value); // type: "list(type)"
318bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
319 std::vector<TensorShape> value); // type: "shape"
320
321// Overloads of TryGetNodeAttr() that avoid copying the non-POD attribute
322// values.
323bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
324 std::vector<const string*>* value); // type: "list(string)"
325bool TryGetNodeAttr(
326 const AttrSlice& attrs, StringPiece attr_name,
327 std::vector<const TensorShapeProto*>* value); // type: "list(shape)"
328
329// Look up the attr with name attr_name and return a reference to its value.
330// If no attr with attr_name is found in node_def, or the attr does not have
331// a matching type, a reference to an empty string is returned.
332// REQUIRES: Must not use the returned value beyond the lifetime of node_def.
333const std::string& GetNodeAttrString(const AttrSlice& attrs,
334 StringPiece attr_name);
335
336// Specialization to parse an attribute directly into a Padding enum.
337Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
338 Padding* value);
339
340// Computes the input type for a specific node input.
341// REQUIRES: ValidateOpDef(op_def).ok()
342Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
343 int input_port, DataType* input_type);
344// Computes the input types for a specific node.
345// REQUIRES: ValidateOpDef(op_def).ok()
346Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
347 DataTypeVector* inputs);
348// Computes the output type for a specific node output.
349// REQUIRES: ValidateOpDef(op_def).ok()
350Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
351 int output_port, DataType* output_type);
352// Computes the output types for a specific node.
353// REQUIRES: ValidateOpDef(op_def).ok()
354Status OutputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
355 DataTypeVector* outputs);
356Status OutputTypesForNode(const AttrSlice& attrs, const OpDef& op_def,
357 DataTypeVector* outputs);
358
359// Computes the input and output types for a specific node.
360// REQUIRES: ValidateOpDef(op_def).ok()
361Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
362 DataTypeVector* inputs, DataTypeVector* outputs);
363// Computes the number of outputs for a specific node.
364// REQUIRES: ValidateOpDef(op_def).ok()
365Status NumOutputsForNode(const NodeDef& node_def, const OpDef& op_def,
366 int* num_outputs);
367
368// Map a node/op's input/output port_id to arg_id.
369//
370// The port_id refers to the n-th tensor of the node, while the arg_id refers to
371// the n-th arg of the op. These two can be different if an op's arg is a list
372// of tensors.
373//
374// We return -1 for any invalid port_id (i.e., no corresponding arg_id).
375int OpPortIdToArgId(const NodeDef& node,
376 const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
377 int port_id);
378
379// Validates that the NodeDef:
380// * Defines all expected attrs from the OpDef.
381// * All attrs satisfies constraints from the OpDef.
382// * Has a signature matching SignatureForNode().
383// etc.
384Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def);
385
386// Computes the mapping from input/output argument name to the
387// corresponding input/output index range. For example,
388// input "foo" corresponds to input indices
389// [ (*inputs)["foo"].first, (*inputs)["foo"].second ).
390// NOTE(mrry): To reduce allocations when the map is used and save
391// space, the returned `NameRangeMap` objects borrow the input/output
392// argument names from `op_def`. The `op_def` must outlive the
393// returned `NameRangeMap` objects.
394typedef gtl::FlatMap<StringPiece, std::pair<int, int>, hash<StringPiece>>
395 NameRangeMap;
396Status NameRangesForNode(const AttrSlice& attrs, const OpDef& op_def,
397 NameRangeMap* inputs, NameRangeMap* outputs);
398// Adds default values to *node_def for unspecified attrs from op_def.
399void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def);
400
401// Remove attributes from node_def when the value is the default from the
402// op_def.
403void StripDefaultsFromNodeDef(const OpDef& op_def, NodeDef* node_def);
404
405// Validates the syntax of a NodeDef provided externally.
406//
407// The following is an EBNF-style syntax for NodeDef objects. Note that
408// Node objects are actually specified as tensorflow::NodeDef protocol buffers,
409// which contain many other fields that are not (currently) validated.
410//
411// Node = NodeName, Inputs
412// Inputs = ( DataInput * ), ( ControlInput * )
413// DataInput = NodeName, ( ":", [1-9], [0-9] * ) ?
414// ControlInput = "^", NodeName
415// NodeName = [A-Za-z0-9.], [A-Za-z0-9_./] *
416Status ValidateExternalNodeDefSyntax(const NodeDef& node_def);
417
418// Returns "status" with formatted NodeDef attached as additional text
419// in the error message. If 'allow_multiple_formatted_node' is false and there
420// is already a formatted NodeDef present in 'status', we simply attach the name
421// of the NodeDef instead of the formatted string.
422Status AttachDef(const Status& status, const NodeDef& node_def,
423 bool allow_multiple_formatted_node = false);
424// Appends the given prefix and suffix to the original node name in order to
425// make the name unique. If it's an "Enter" node and uniquify_frame_name is
426// true, use the same way to reset attribute "frame_name".
427Status AddPrefixAndSuffixToNode(StringPiece prefix, StringPiece suffix,
428 NodeDef* node_def,
429 bool uniquify_frame_name = true);
430
431// Appends the given prefix to the colocation group name if the name exists
432// in `to_match`.
433Status MaybeAddPrefixToColocationConstraints(
434 const std::unordered_set<string>& match, StringPiece prefix,
435 NodeDef* node_def);
436
437// Updates the colocation constraint name with the one provided in the map (if
438// it exists in the map) for node_def.
439Status MaybeUpdateColocationConstraintsWithMap(
440 const std::map<absl::string_view, absl::string_view>& node_name_map,
441 NodeDef* node_def);
442
443} // namespace tensorflow
444
445#endif // TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_
446