1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_ |
17 | #define TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_ |
18 | |
#include <initializer_list>
#include <map>
#include <string>
#include <unordered_set>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/attr_value_util.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/framework/op_def.pb.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/hash/hash.h"
#include "tensorflow/core/platform/hash.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/stringpiece.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/util/padding.h"
40 | |
41 | namespace tensorflow { |
42 | |
// AttrSlice is defined further down in this header; it is declared here so
// that SummarizeAttrsHelper() can name it before the class definition.
class AttrSlice;
// Forward declarations of the proto message types used in signatures below.
// NOTE(review): the intent is that kernels need not depend on these protos,
// but this header currently #includes node_def.pb.h and op_def.pb.h directly,
// so these declarations alone do not remove that dependency — confirm whether
// those includes can be dropped.
class OpDef;
class AttrValue;
class NameAttrList;
class TensorProto;
class TensorShapeProto;
50 | |
// Name of the attribute used to encode node colocation constraints.
//
// Nodes can be co-located on the same device. The desire for explicit
// co-location is described by a list(string) attribute containing the names
// of colocation groups.
// (All constants below are `extern`: their definitions live in another
// translation unit.)
extern const char* const kColocationAttrName;

// String prefix applied to the operation name for colocation constraints.
extern const char* const kColocationGroupPrefix;

// Constants for the host CPU staging op used by TPUExecute.
extern const char* const kTpuExecuteStagingOp;
extern const char* const kTpuExecuteStagingNodeName;
64 | |
// Produces a human-readable version of a Node or NodeDef that is more concise
// than a text-format proto.
//
// The parameter `max_inputs_in_summary` specifies how many inputs at most to
// serialize in the output (so the string does not become overly large).
// The value `-1` specifies that all inputs will be shown.
std::string SummarizeNodeDef(const NodeDef& node_def,
                             int max_inputs_in_summary = -1);
// Summarizes only the attrs of `node_def` as a string.
std::string SummarizeAttrs(const NodeDef& node_def);
// Shared helper that summarizes an arbitrary attr set. NOTE(review): how
// `device` is rendered into the output is decided by the implementation,
// which is not visible here.
std::string SummarizeAttrsHelper(AttrSlice attrs, StringPiece device);

// Produces a formatted string pattern from the node which can uniquely identify
// this node upstream to produce an informative error message. The pattern
// followed is: {{node <node_name>}}
std::string FormatNodeDefForError(const NodeDef& node_def);
// Same as above, but takes the node name and optional debug info directly
// instead of a whole NodeDef. `experimental_debug_info` is only meaningful
// when `has_experimental_debug_info` is true.
std::string FormatNodeDefForError(
    StringPiece node_name, bool has_experimental_debug_info,
    const NodeDef_ExperimentalDebugInfo& experimental_debug_info);
83 | |
84 | typedef protobuf::Map<string, AttrValue> AttrValueMap; |
85 | |
// Adds an attr with name <name> and value <value> to *node_def.
// The type of the attr is based on the C++ type of value, i.e. on which
// overload below is selected. The list overloads take a gtl::ArraySlice of the
// element type.
// NOTE(review): whether an existing attr with the same name is replaced is
// decided by the implementation, which is not visible in this header.
void AddNodeAttr(StringPiece name, const AttrValue& value, NodeDef* node_def);
void AddNodeAttr(StringPiece name, AttrValue&& value, NodeDef* node_def);
void AddNodeAttr(StringPiece name, StringPiece value, NodeDef* node_def);
void AddNodeAttr(StringPiece name, const char* value, NodeDef* node_def);
void AddNodeAttr(StringPiece name, int32_t value, NodeDef* node_def);
void AddNodeAttr(StringPiece name, int64_t value, NodeDef* node_def);
void AddNodeAttr(StringPiece name, float value, NodeDef* node_def);
void AddNodeAttr(StringPiece name, double value, NodeDef* node_def);
void AddNodeAttr(StringPiece name, bool value, NodeDef* node_def);
void AddNodeAttr(StringPiece name, DataType value, NodeDef* node_def);
void AddNodeAttr(StringPiece name, const PartialTensorShape& value,
                 NodeDef* node_def);
void AddNodeAttr(StringPiece name, const Tensor& value, NodeDef* node_def);
void AddNodeAttr(StringPiece name, const TensorProto& value, NodeDef* node_def);
void AddNodeAttr(StringPiece name, const NameAttrList& value,
                 NodeDef* node_def);
void AddNodeAttr(StringPiece name, gtl::ArraySlice<StringPiece> value,
                 NodeDef* node_def);
void AddNodeAttr(StringPiece name, gtl::ArraySlice<const char*> value,
                 NodeDef* node_def);
void AddNodeAttr(StringPiece name, gtl::ArraySlice<string> value,
                 NodeDef* node_def);
void AddNodeAttr(StringPiece name, gtl::ArraySlice<int32> value,
                 NodeDef* node_def);
void AddNodeAttr(StringPiece name, gtl::ArraySlice<int64_t> value,
                 NodeDef* node_def);
void AddNodeAttr(StringPiece name, gtl::ArraySlice<float> value,
                 NodeDef* node_def);
void AddNodeAttr(StringPiece name, gtl::ArraySlice<bool> value,
                 NodeDef* node_def);
// Dedicated std::vector<bool> overload, presumably because the packed
// std::vector<bool> specialization cannot be viewed as a contiguous
// gtl::ArraySlice<bool>.
void AddNodeAttr(StringPiece name, const std::vector<bool>& value,
                 NodeDef* node_def);
void AddNodeAttr(StringPiece name, gtl::ArraySlice<DataType> value,
                 NodeDef* node_def);
void AddNodeAttr(StringPiece name, gtl::ArraySlice<TensorShape> value,
                 NodeDef* node_def);
void AddNodeAttr(StringPiece name, gtl::ArraySlice<PartialTensorShape> value,
                 NodeDef* node_def);
void AddNodeAttr(StringPiece name, gtl::ArraySlice<TensorShapeProto> value,
                 NodeDef* node_def);
void AddNodeAttr(StringPiece name, gtl::ArraySlice<Tensor> value,
                 NodeDef* node_def);
void AddNodeAttr(StringPiece name, gtl::ArraySlice<NameAttrList> value,
                 NodeDef* node_def);
132 | |
133 | // Version to workaround C++'s "perfect" forwarding not being able to |
134 | // forward {...} initialization. |
135 | template <class T> |
136 | void AddNodeAttr(StringPiece name, std::initializer_list<T> value, |
137 | NodeDef* node_def) { |
138 | AddNodeAttr(name, gtl::ArraySlice<T>(value), node_def); |
139 | } |
140 | |
// Adds an attr named `name` with the given value to *map (an attr value map
// not attached to any NodeDef).
void AddAttr(StringPiece name, const AttrValue& value, AttrValueMap* map);
void AddAttr(StringPiece name, bool value, AttrValueMap* map);
144 | |
145 | class AttrSlice { |
146 | public: |
147 | AttrSlice(const NodeDef& node_def); // NOLINT(runtime/explicit) |
148 | |
149 | AttrSlice(); // Empty |
150 | explicit AttrSlice(const AttrValueMap* a); |
151 | |
152 | int size() const { return attrs()->size(); } |
153 | |
154 | // Returns the attr with attr_name if found. Otherwise, returns |
155 | // nullptr. |
156 | const AttrValue* Find(StringPiece attr_name) const; |
157 | const AttrValue* FindByString(const std::string& attr_name) const; |
158 | |
159 | // Returns the attr_value for attr_name if found. Otherwise, returns a |
160 | // NotFound status. |
161 | Status Find(StringPiece attr_name, const AttrValue** attr_value) const; |
162 | Status FindByString(const std::string& attr_name, |
163 | const AttrValue** attr_value) const; |
164 | |
165 | // Helper class to avoid allocations in EqualAttrs. |
166 | // TODO(irving): Will go away once NodeInfo is used. |
167 | struct Scratch { |
168 | std::string a; |
169 | std::string b; |
170 | }; |
171 | |
172 | // Check if all attrs and attr values match. Does not take defaults into |
173 | // account. |
174 | // |
175 | // TODO(irving): There is a bug in this routine inherited from its |
176 | // OptimizerCSE::EqualAttrs predecessor. The same tensor attr can be |
177 | // represented in more than one way as an AttrValue, since TensorProto is |
178 | // not 1-1. This bug will go away once I replace everything with NodeInfo, |
179 | // which stores a Tensor object directly. The Scratch object will also go |
180 | // away. |
181 | bool EqualAttrs(AttrSlice other, Scratch* scratch) const; |
182 | |
183 | // If this AttrSlice has an attached NodeDef, summarize it. This is for |
184 | // error messages only: we intentionally do not provide direct access to the |
185 | // NodeDef, since it is not always there. |
186 | std::string SummarizeNode() const; |
187 | |
188 | // Iteration over all attrs |
189 | AttrValueMap::const_iterator begin() const { return attrs()->begin(); } |
190 | AttrValueMap::const_iterator end() const { return attrs()->end(); } |
191 | |
192 | std::string DebugString() const; |
193 | |
194 | private: |
195 | const AttrValueMap* attrs() const { |
196 | return ndef_ != nullptr ? &ndef_->attr() : attrs_; |
197 | } |
198 | |
199 | Status CheckFind(StringPiece attr_name, const AttrValue* attr_value) const; |
200 | |
201 | const NodeDef* ndef_; |
202 | const AttrValueMap* attrs_; |
203 | }; |
204 | |
// Returns true if the attr with the name attr_name is defined in node_def.
bool HasNodeAttr(const NodeDef& node_def, StringPiece attr_name);

// Looks up the attr with name attr_name and sets *value to its value. If no
// attr with attr_name is found in `attrs`, or the attr does not have a
// matching type, a non-ok status is returned. The trailing comment on each
// overload names the attr type it accepts.
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
                   std::string* value);  // type: "string"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
                   tstring* value);  // type: "tstring"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
                   int64_t* value);  // type: "int"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
                   int32* value);  // type: "int"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
                   float* value);  // type: "float"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
                   bool* value);  // type: "bool"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
                   DataType* value);  // type: "type"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
                   TensorShapeProto* value);  // type: "shape"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
                   TensorShape* value);  // type: "shape"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
                   PartialTensorShape* value);  // type: "shape"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
                   Tensor* value);  // type: "tensor"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
                   std::vector<string>* value);  // type: "list(string)"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
                   std::vector<tstring>* value);  // type: "list(tstring)"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
                   std::vector<int64_t>* value);  // type: "list(int)"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
                   std::vector<int32>* value);  // type: "list(int)"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
                   std::vector<float>* value);  // type: "list(float)"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
                   std::vector<bool>* value);  // type: "list(bool)"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
                   std::vector<DataType>* value);  // type: "list(type)"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
                   DataTypeVector* value);  // type: "list(type)"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
                   std::vector<TensorShapeProto>* value);  // type: "list(shape)"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
                   std::vector<TensorShape>* value);  // type: "list(shape)"
Status GetNodeAttr(
    const AttrSlice& attrs, StringPiece attr_name,
    std::vector<PartialTensorShape>* value);  // type: "list(shape)"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
                   std::vector<Tensor>* value);  // type: "list(tensor)"
258 | |
259 | template <typename T> |
260 | StatusOr<T> GetNodeAttr(const NodeDef& ndef, absl::string_view attr_name) { |
261 | T val; |
262 | TF_RETURN_IF_ERROR(GetNodeAttr(ndef, attr_name, &val)); |
263 | return val; |
264 | } |
265 | |
// This version avoids copying the TensorProto: *value is pointed at the proto
// held inside the attrs rather than at a copy.
// REQUIRES: Must not use *value beyond the lifetime of the NodeDef or
// AttrValueMap that `attrs` views.
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
                   const TensorProto** value);  // type: "tensor"
bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
                    const TensorProto** value);  // type: "tensor"

// This version avoids copying the NameAttrList: *value is pointed at the
// NameAttrList held inside the attrs rather than at a copy.
// REQUIRES: Must not use *value beyond the lifetime of the NodeDef or
// AttrValueMap that `attrs` views.
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
                   const NameAttrList** value);  // type: "func"
bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
                    const NameAttrList** value);  // type: "func"

// These versions copy the NameAttrList(s) into *value, so the result stays
// valid independently of `attrs`.
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
                   NameAttrList* value);  // type: "func"
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
                   std::vector<NameAttrList>* value);  // type: "list(func)"
285 | |
286 | // Look up the attr with name attr_name and set *value to its value. If no |
287 | // attr with attr_name is found in node_def, or the attr does not have |
288 | // a matching type, false is returned. |
289 | bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, |
290 | std::string* value); // type: "string" |
291 | bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, |
292 | int64_t* value); // type: "int" |
293 | bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, |
294 | std::vector<int64_t>* value); // type: "int" |
295 | bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, |
296 | int32* value); // type: "int" |
297 | bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, |
298 | float* value); // type: "float" |
299 | bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, |
300 | bool* value); // type: "bool" |
301 | bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, |
302 | DataType* value); // type: "type" |
303 | bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, |
304 | TensorShape* value); // type: "shape" |
305 | |
306 | bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, |
307 | std::vector<string>* value); // type: "list(string)" |
308 | bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, |
309 | std::vector<tstring>* value); // type: "list(tstring)" |
310 | bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, |
311 | std::vector<int32>* value); // type: "list(int)" |
312 | bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, |
313 | std::vector<float>* value); // type: "list(float)" |
314 | bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, |
315 | std::vector<bool>* value); // type: "list(bool)" |
316 | bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, |
317 | std::vector<DataType>* value); // type: "list(type)" |
318 | bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name, |
319 | std::vector<TensorShape> value); // type: "shape" |
320 | |
// Overloads of TryGetNodeAttr() that avoid copying the non-POD attribute
// values. Presumably the stored pointers refer to values held inside `attrs`,
// so they must not be used after the viewed NodeDef / AttrValueMap is
// destroyed — confirm against the implementation.
bool TryGetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
                    std::vector<const string*>* value);  // type: "list(string)"
bool TryGetNodeAttr(
    const AttrSlice& attrs, StringPiece attr_name,
    std::vector<const TensorShapeProto*>* value);  // type: "list(shape)"

// Looks up the attr with name attr_name and returns a reference to its value.
// If no attr with attr_name is found in `attrs`, or the attr does not have
// a matching type, a reference to an empty string is returned.
// REQUIRES: Must not use the returned value beyond the lifetime of the NodeDef
// or AttrValueMap that `attrs` views.
const std::string& GetNodeAttrString(const AttrSlice& attrs,
                                     StringPiece attr_name);
335 | |
// Overload that parses an attribute directly into a Padding enum.
Status GetNodeAttr(const AttrSlice& attrs, StringPiece attr_name,
                   Padding* value);

// Computes the input type for a specific node input.
// REQUIRES: ValidateOpDef(op_def).ok()
Status InputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
                        int input_port, DataType* input_type);
// Computes the input types for a specific node.
// REQUIRES: ValidateOpDef(op_def).ok()
Status InputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
                         DataTypeVector* inputs);
// Computes the output type for a specific node output.
// REQUIRES: ValidateOpDef(op_def).ok()
Status OutputTypeForNode(const NodeDef& node_def, const OpDef& op_def,
                         int output_port, DataType* output_type);
// Computes the output types for a specific node.
// REQUIRES: ValidateOpDef(op_def).ok()
Status OutputTypesForNode(const NodeDef& node_def, const OpDef& op_def,
                          DataTypeVector* outputs);
// Same as above, but reads the node's attrs from an AttrSlice instead of a
// full NodeDef.
Status OutputTypesForNode(const AttrSlice& attrs, const OpDef& op_def,
                          DataTypeVector* outputs);

// Computes the input and output types for a specific node.
// REQUIRES: ValidateOpDef(op_def).ok()
Status InOutTypesForNode(const NodeDef& node_def, const OpDef& op_def,
                         DataTypeVector* inputs, DataTypeVector* outputs);
// Computes the number of outputs for a specific node.
// REQUIRES: ValidateOpDef(op_def).ok()
Status NumOutputsForNode(const NodeDef& node_def, const OpDef& op_def,
                         int* num_outputs);
367 | |
// Maps a node/op's input/output port_id to arg_id.
//
// The port_id refers to the n-th tensor of the node, while the arg_id refers to
// the n-th arg of the op. These two can be different if an op's arg is a list
// of tensors.
//
// Returns -1 for any invalid port_id (i.e., one with no corresponding arg_id).
int OpPortIdToArgId(const NodeDef& node,
                    const protobuf::RepeatedPtrField<OpDef::ArgDef>& args,
                    int port_id);

// Validates that the NodeDef:
// * Defines all expected attrs from the OpDef.
// * Has attrs that all satisfy the constraints from the OpDef.
// * Has a signature matching SignatureForNode().
// etc.
Status ValidateNodeDef(const NodeDef& node_def, const OpDef& op_def);
385 | |
// Computes the mapping from input/output argument name to the
// corresponding input/output index range. For example,
// input "foo" corresponds to input indices
// [ (*inputs)["foo"].first, (*inputs)["foo"].second ), i.e. a half-open range.
// NOTE(mrry): To reduce allocations when the map is used and save
// space, the returned `NameRangeMap` objects borrow the input/output
// argument names from `op_def`. The `op_def` must outlive the
// returned `NameRangeMap` objects.
typedef gtl::FlatMap<StringPiece, std::pair<int, int>, hash<StringPiece>>
    NameRangeMap;
Status NameRangesForNode(const AttrSlice& attrs, const OpDef& op_def,
                         NameRangeMap* inputs, NameRangeMap* outputs);
// Adds default values to *node_def for unspecified attrs from op_def.
void AddDefaultsToNodeDef(const OpDef& op_def, NodeDef* node_def);

// Removes attributes from node_def whose value equals the default from the
// op_def.
void StripDefaultsFromNodeDef(const OpDef& op_def, NodeDef* node_def);
404 | |
// Validates the syntax of a NodeDef provided externally.
//
// The following is an EBNF-style syntax for NodeDef objects. Note that
// Node objects are actually specified as tensorflow::NodeDef protocol buffers,
// which contain many other fields that are not (currently) validated.
//
// Node         = NodeName, Inputs
// Inputs       = ( DataInput * ), ( ControlInput * )
// DataInput    = NodeName, ( ":", [1-9], [0-9] * ) ?
// ControlInput = "^", NodeName
// NodeName     = [A-Za-z0-9.], [A-Za-z0-9_./] *
Status ValidateExternalNodeDefSyntax(const NodeDef& node_def);

// Returns "status" with the formatted NodeDef attached as additional text
// in the error message. If 'allow_multiple_formatted_node' is false and there
// is already a formatted NodeDef present in 'status', only the name of the
// NodeDef is attached instead of the formatted string.
Status AttachDef(const Status& status, const NodeDef& node_def,
                 bool allow_multiple_formatted_node = false);
// Adds the given prefix and suffix to the original node name in order to
// make the name unique. If it's an "Enter" node and uniquify_frame_name is
// true, the attribute "frame_name" is renamed the same way.
Status AddPrefixAndSuffixToNode(StringPiece prefix, StringPiece suffix,
                                NodeDef* node_def,
                                bool uniquify_frame_name = true);
430 | |
431 | // Appends the given prefix to the colocation group name if the name exists |
432 | // in `to_match`. |
433 | Status MaybeAddPrefixToColocationConstraints( |
434 | const std::unordered_set<string>& match, StringPiece prefix, |
435 | NodeDef* node_def); |
436 | |
437 | // Updates the colocation constraint name with the one provided in the map (if |
438 | // it exists in the map) for node_def. |
439 | Status MaybeUpdateColocationConstraintsWithMap( |
440 | const std::map<absl::string_view, absl::string_view>& node_name_map, |
441 | NodeDef* node_def); |
442 | |
443 | } // namespace tensorflow |
444 | |
445 | #endif // TENSORFLOW_CORE_FRAMEWORK_NODE_DEF_UTIL_H_ |
446 | |