1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/grappler/op_types.h" |
17 | |
18 | #include "tensorflow/core/framework/attr_value.pb.h" |
19 | #include "tensorflow/core/framework/op.h" |
20 | #include "tensorflow/core/framework/types.h" |
21 | #include "tensorflow/core/grappler/utils.h" |
22 | #include "tensorflow/core/lib/core/status.h" |
23 | #include "tensorflow/core/lib/gtl/flatset.h" |
24 | #include "tensorflow/core/lib/strings/str_util.h" |
25 | #include "tensorflow/core/platform/logging.h" |
26 | |
27 | namespace tensorflow { |
28 | namespace grappler { |
29 | |
30 | bool IsAdd(const NodeDef& node) { |
31 | if (node.op() == "AddV2" ) { |
32 | return true; |
33 | } |
34 | if (node.op() == "Add" ) { |
35 | DataType type = node.attr().at("T" ).type(); |
36 | return type != DT_STRING; |
37 | } |
38 | return false; |
39 | } |
40 | |
41 | bool IsAddN(const NodeDef& node) { return node.op() == "AddN" ; } |
42 | |
43 | bool IsAll(const NodeDef& node) { return node.op() == "All" ; } |
44 | |
45 | bool IsAngle(const NodeDef& node) { return node.op() == "Angle" ; } |
46 | |
47 | bool IsAny(const NodeDef& node) { return node.op() == "Any" ; } |
48 | |
49 | bool IsAnyDiv(const NodeDef& node) { |
50 | return node.op() == "RealDiv" || node.op() == "Div" || node.op() == "Xdivy" || |
51 | node.op() == "FloorDiv" || node.op() == "TruncateDiv" ; |
52 | } |
53 | |
54 | bool IsAnyBatchMatMul(const NodeDef& node) { |
55 | return node.op() == "BatchMatMul" || node.op() == "BatchMatMulV2" ; |
56 | } |
57 | |
58 | bool IsAnyMatMul(const NodeDef& node) { |
59 | return node.op() == "MatMul" || node.op() == "SparseMatMul" || |
60 | IsAnyBatchMatMul(node) || IsQuantizedMatMul(node); |
61 | } |
62 | |
63 | bool IsAnyMax(const NodeDef& node) { |
64 | const auto& op = node.op(); |
65 | return op == "Max" || op == "SegmentMax" || op == "UnsortedSegmentMax" ; |
66 | } |
67 | |
68 | bool IsAnyMaxPool(const NodeDef& node) { |
69 | const auto& op = node.op(); |
70 | return op == "MaxPool" || op == "MaxPoolV2" || op == "MaxPool3D" || |
71 | op == "MaxPoolWithArgmax" || op == "FractionalMaxPool" ; |
72 | } |
73 | |
74 | bool IsAnyMin(const NodeDef& node) { |
75 | const auto& op = node.op(); |
76 | return op == "Min" || op == "SegmentMin" || op == "UnsortedSegmentMin" ; |
77 | } |
78 | |
79 | bool IsAnySparseSegmentReduction(const NodeDef& node) { |
80 | const auto& op = node.op(); |
81 | return op == "SparseSegmentSum" || op == "SparseSegmentSumWithNumSegments" || |
82 | op == "SparseSegmentMean" || |
83 | op == "SparseSegmentMeanWithNumSegments" || |
84 | op == "SparseSegmentSqrtN" || |
85 | op == "SparseSegmentSqrtNWithNumSegments" ; |
86 | } |
87 | |
88 | bool IsApproximateEqual(const NodeDef& node) { |
89 | return node.op() == "ApproximateEqual" ; |
90 | } |
91 | |
92 | bool IsArg(const NodeDef& node) { |
93 | return node.op() == "_Arg" || node.op() == "_DeviceArg" ; |
94 | } |
95 | |
96 | bool IsArgMax(const NodeDef& node) { return node.op() == "ArgMax" ; } |
97 | |
98 | bool IsArgMin(const NodeDef& node) { return node.op() == "ArgMin" ; } |
99 | |
100 | bool IsAvgPoolGrad(const NodeDef& node) { return node.op() == "AvgPoolGrad" ; } |
101 | |
102 | bool IsAssign(const NodeDef& node) { |
103 | return node.op() == "Assign" || node.op() == "AssignVariableOp" ; |
104 | } |
105 | |
106 | bool IsAssert(const NodeDef& node) { return node.op() == "Assert" ; } |
107 | |
108 | bool IsAsString(const NodeDef& node) { return node.op() == "AsString" ; } |
109 | |
110 | bool IsAtan2(const NodeDef& node) { return node.op() == "Atan2" ; } |
111 | |
112 | bool IsBetainc(const NodeDef& node) { return node.op() == "Betainc" ; } |
113 | |
114 | bool IsBiasAdd(const NodeDef& node) { |
115 | return node.op() == "BiasAdd" || node.op() == "BiasAddV1" ; |
116 | } |
117 | |
118 | bool IsBiasAddV2(const NodeDef& node) { return node.op() == "BiasAdd" ; } |
119 | |
120 | bool IsBiasAddGrad(const NodeDef& node) { return node.op() == "BiasAddGrad" ; } |
121 | |
122 | bool IsBitcast(const NodeDef& node) { return node.op() == "Bitcast" ; } |
123 | |
124 | bool IsBroadcastTo(const NodeDef& node) { return node.op() == "BroadcastTo" ; } |
125 | |
126 | bool IsCast(const NodeDef& node) { return node.op() == "Cast" ; } |
127 | |
128 | bool IsCastLike(const NodeDef& node) { |
129 | static const gtl::FlatSet<string>* const kCastLikeOps = |
130 | CHECK_NOTNULL((new gtl::FlatSet<string>{ |
131 | "Angle" , "Bucketize" , "Cast" , "Dequantize" , "HistogramFixedWidth" , |
132 | "Imag" , "IsFinite" , "IsInf" , "IsNan" , "Quantize" , |
133 | "QuantizeDownAndShrinkRange" , "QuantizeV2" , "QuantizedInstanceNorm" , |
134 | "QuantizedRelu" , "QuantizedRelu6" , "QuantizedReluX" , "Real" , |
135 | "Requantize" })); |
136 | return kCastLikeOps->count(node.op()) > 0; |
137 | } |
138 | |
139 | bool IsCheckNumerics(const NodeDef& node) { |
140 | return node.op() == "CheckNumerics" ; |
141 | } |
142 | |
143 | bool IsCollective(const NodeDef& node) { |
144 | return node.op() == "CollectiveReduce" || |
145 | node.op() == "CollectiveBcastSend" || |
146 | node.op() == "CollectiveBcastRecv" ; |
147 | } |
148 | |
149 | bool IsComplex(const NodeDef& node) { return node.op() == "Complex" ; } |
150 | |
151 | bool IsComplexAbs(const NodeDef& node) { return node.op() == "ComplexAbs" ; } |
152 | |
153 | bool IsConcat(const NodeDef& node) { |
154 | return node.op() == "Concat" || node.op() == "ConcatV2" ; |
155 | } |
156 | |
157 | bool IsConcatOffset(const NodeDef& node) { return node.op() == "ConcatOffset" ; } |
158 | |
159 | bool IsConstant(const NodeDef& node) { return node.op() == "Const" ; } |
160 | |
161 | bool IsConj(const NodeDef& node) { return node.op() == "Conj" ; } |
162 | |
163 | bool IsConjugateTranspose(const NodeDef& node) { |
164 | return node.op() == "ConjugateTranspose" ; |
165 | } |
166 | |
167 | bool IsControlFlow(const NodeDef& node) { |
168 | // clang-format off |
169 | return node.op() == "ControlTrigger" || |
170 | node.op() == "Enter" || |
171 | node.op() == "Exit" || |
172 | node.op() == "LoopCond" || |
173 | node.op() == "Merge" || |
174 | node.op() == "_XlaMerge" || |
175 | node.op() == "NextIteration" || |
176 | node.op() == "Switch" || |
177 | node.op() == "_SwitchN" ; |
178 | // clang-format on |
179 | } |
180 | |
181 | bool IsConv2D(const NodeDef& node) { return node.op() == "Conv2D" ; } |
182 | |
183 | bool IsConv2DBackpropFilter(const NodeDef& node) { |
184 | return node.op() == "Conv2DBackpropFilter" ; |
185 | } |
186 | |
187 | bool IsConv2DBackpropInput(const NodeDef& node) { |
188 | return node.op() == "Conv2DBackpropInput" ; |
189 | } |
190 | |
191 | bool IsConv3D(const NodeDef& node) { return node.op() == "Conv3D" ; } |
192 | |
193 | bool IsConv3DBackpropFilterV2(const NodeDef& node) { |
194 | return node.op() == "Conv3DBackpropFilterV2" ; |
195 | } |
196 | |
197 | bool IsConv3DBackpropInputV2(const NodeDef& node) { |
198 | return node.op() == "Conv3DBackpropInputV2" ; |
199 | } |
200 | |
201 | bool IsDepthwiseConv2dNative(const NodeDef& node) { |
202 | return node.op() == "DepthwiseConv2dNative" ; |
203 | } |
204 | |
205 | bool IsDepthwiseConv2dNativeBackpropFilter(const NodeDef& node) { |
206 | return node.op() == "DepthwiseConv2dNativeBackpropFilter" ; |
207 | } |
208 | |
209 | bool IsDepthwiseConv2dNativeBackpropInput(const NodeDef& node) { |
210 | return node.op() == "DepthwiseConv2dNativeBackpropInput" ; |
211 | } |
212 | |
213 | bool IsDequeueOp(const NodeDef& node) { |
214 | const auto& op = node.op(); |
215 | return op == "QueueDequeueManyV2" || op == "QueueDequeueMany" || |
216 | op == "QueueDequeueV2" || op == "QueueDequeue" || |
217 | op == "QueueDequeueUpToV2" || op == "QueueDequeueUpTo" ; |
218 | } |
219 | |
220 | bool IsDiv(const NodeDef& node) { return node.op() == "Div" ; } |
221 | |
222 | bool IsDivNoNan(const NodeDef& node) { return node.op() == "DivNoNan" ; } |
223 | |
224 | // Returns true if node represents a unary elementwise function that is |
225 | // monotonic. If *is_non_decreasing is true, the function is non-decreasing, |
226 | // e.g. sqrt, exp. *is_non_decreasing is false, the function is non-increasing, |
227 | // e.g. inv. |
228 | bool IsElementWiseMonotonic(const NodeDef& node, bool* is_non_decreasing) { |
229 | static const gtl::FlatSet<string>* const kMonotonicNonDecreasingOps = |
230 | CHECK_NOTNULL((new gtl::FlatSet<string>{ |
231 | "Acosh" , "Asin" , "Asinh" , "Atan" , "Atanh" , "Ceil" , |
232 | "Elu" , "Erf" , "Exp" , "Expm1" , "Floor" , "Log" , |
233 | "Log1p" , "Relu" , "Relu6" , "Rint" , "Selu" , "Sigmoid" , |
234 | "Sign" , "Sinh" , "Softsign" , "Softplus" , "Sqrt" , "Tanh" , |
235 | })); |
236 | static const gtl::FlatSet<string>* const kMonotonicNonIncreasingOps = |
237 | CHECK_NOTNULL((new gtl::FlatSet<string>{"Acos" , "Erfc" , "Neg" , "Rsqrt" })); |
238 | if (kMonotonicNonDecreasingOps->count(node.op()) > 0) { |
239 | if (is_non_decreasing) { |
240 | *is_non_decreasing = true; |
241 | } |
242 | return true; |
243 | } else if (kMonotonicNonIncreasingOps->count(node.op()) > 0) { |
244 | if (is_non_decreasing) { |
245 | *is_non_decreasing = false; |
246 | } |
247 | return true; |
248 | } |
249 | return false; |
250 | } |
251 | |
252 | bool IsElu(const NodeDef& node) { return node.op() == "Elu" ; } |
253 | |
254 | bool IsEluGrad(const NodeDef& node) { return node.op() == "EluGrad" ; } |
255 | |
256 | bool IsQuantizationEmulation(const NodeDef& node) { |
257 | const auto& op = node.op(); |
258 | return absl::StartsWith(op, "QuantizeAndDequantize" ) || |
259 | absl::StartsWith(op, "FakeQuantWithMinMax" ); |
260 | } |
261 | |
262 | bool IsEnter(const NodeDef& node) { |
263 | const auto& op = node.op(); |
264 | return op == "Enter" || op == "RefEnter" ; |
265 | } |
266 | |
267 | bool IsEqual(const NodeDef& node) { return node.op() == "Equal" ; } |
268 | |
269 | bool IsExit(const NodeDef& node) { |
270 | const auto& op = node.op(); |
271 | return op == "Exit" || op == "RefExit" ; |
272 | } |
273 | |
274 | bool IsExp(const NodeDef& node) { return node.op() == "Exp" ; } |
275 | |
276 | bool IsFakeParam(const NodeDef& node) { return node.op() == "FakeParam" ; } |
277 | |
278 | bool IsFill(const NodeDef& node) { return node.op() == "Fill" ; } |
279 | |
280 | bool IsFloorDiv(const NodeDef& node) { return node.op() == "FloorDiv" ; } |
281 | |
282 | bool IsFloorMod(const NodeDef& node) { return node.op() == "FloorMod" ; } |
283 | |
284 | bool IsFusedBatchNorm(const NodeDef& node) { |
285 | const auto& op = node.op(); |
286 | return op == "FusedBatchNorm" || op == "FusedBatchNormV2" || |
287 | op == "FusedBatchNormV3" ; |
288 | } |
289 | |
290 | bool IsFusedBatchNormEx(const NodeDef& node) { |
291 | return node.op() == "_FusedBatchNormEx" ; |
292 | } |
293 | |
294 | bool IsFusedBatchNormGrad(const NodeDef& node) { |
295 | const auto& op = node.op(); |
296 | return op == "FusedBatchNormGrad" || op == "FusedBatchNormGradV2" || |
297 | op == "FusedBatchNormGradV3" ; |
298 | } |
299 | |
300 | bool IsGather(const NodeDef& node) { |
301 | const auto& op = node.op(); |
302 | return op == "Gather" || op == "GatherV2" || op == "ResourceGather" ; |
303 | } |
304 | |
305 | bool IsGreater(const NodeDef& node) { return node.op() == "Greater" ; } |
306 | |
307 | bool IsGreaterEqual(const NodeDef& node) { return node.op() == "GreaterEqual" ; } |
308 | |
309 | bool IsHostConstant(const NodeDef& node) { return node.op() == "HostConst" ; } |
310 | |
311 | bool IsHistogramSummary(const NodeDef& node) { |
312 | return node.op() == "HistogramSummary" ; |
313 | } |
314 | |
315 | bool IsIdentity(const NodeDef& node) { |
316 | const auto& op = node.op(); |
317 | return op == "Identity" || op == "RefIdentity" ; |
318 | } |
319 | |
320 | bool IsIdentityN(const NodeDef& node) { |
321 | const auto& op = node.op(); |
322 | return op == "IdentityN" ; |
323 | } |
324 | |
325 | bool IsIdentityNSingleInput(const NodeDef& node) { |
326 | return IsIdentityN(node) && node.attr().count("T" ) != 0 && |
327 | node.attr().at("T" ).list().type_size() == 1; |
328 | } |
329 | |
330 | bool IsIf(const NodeDef& node) { |
331 | const auto& op = node.op(); |
332 | return op == "If" || op == "StatelessIf" ; |
333 | } |
334 | |
335 | bool IsIgamma(const NodeDef& node) { return node.op() == "Igamma" ; } |
336 | |
337 | bool IsIgammac(const NodeDef& node) { return node.op() == "Igammac" ; } |
338 | |
339 | bool IsImag(const NodeDef& node) { return node.op() == "Imag" ; } |
340 | |
341 | bool IsImmutableConst(const NodeDef& node) { |
342 | return node.op() == "ImmutableConst" ; |
343 | } |
344 | |
345 | bool IsInvGrad(const NodeDef& node) { return node.op() == "InvGrad" ; } |
346 | |
347 | bool IsLeakyRelu(const NodeDef& node) { return node.op() == "LeakyRelu" ; } |
348 | |
349 | bool IsLeakyReluGrad(const NodeDef& node) { |
350 | return node.op() == "LeakyReluGrad" ; |
351 | } |
352 | |
353 | bool IsLess(const NodeDef& node) { return node.op() == "Less" ; } |
354 | |
355 | bool IsLessEqual(const NodeDef& node) { return node.op() == "LessEqual" ; } |
356 | |
357 | bool IsLog(const NodeDef& node) { return node.op() == "Log" ; } |
358 | |
359 | bool IsLogicalAnd(const NodeDef& node) { return node.op() == "LogicalAnd" ; } |
360 | |
361 | bool IsLogicalNot(const NodeDef& node) { return node.op() == "LogicalNot" ; } |
362 | |
363 | bool IsLogicalOr(const NodeDef& node) { return node.op() == "LogicalOr" ; } |
364 | |
365 | bool IsLoopCond(const NodeDef& node) { return node.op() == "LoopCond" ; } |
366 | |
367 | bool IsMatMul(const NodeDef& node) { return node.op() == "MatMul" ; } |
368 | |
369 | bool IsMax(const NodeDef& node) { return node.op() == "Max" ; } |
370 | |
371 | bool IsMaximum(const NodeDef& node) { return node.op() == "Maximum" ; } |
372 | |
373 | bool IsMaxPoolGrad(const NodeDef& node) { return node.op() == "MaxPoolGrad" ; } |
374 | |
375 | bool IsMean(const NodeDef& node) { return node.op() == "Mean" ; } |
376 | |
377 | bool IsMerge(const NodeDef& node) { |
378 | const auto& op = node.op(); |
379 | return op == "Merge" || op == "RefMerge" || op == "_XlaMerge" ; |
380 | } |
381 | |
382 | bool IsMin(const NodeDef& node) { return node.op() == "Min" ; } |
383 | |
384 | bool IsMinimum(const NodeDef& node) { return node.op() == "Minimum" ; } |
385 | |
386 | bool IsMirrorPad(const NodeDef& node) { return node.op() == "MirrorPad" ; } |
387 | |
388 | bool IsMirrorPadGrad(const NodeDef& node) { |
389 | return node.op() == "MirrorPadGrad" ; |
390 | } |
391 | |
392 | bool IsMod(const NodeDef& node) { return node.op() == "Mod" ; } |
393 | |
394 | bool IsMul(const NodeDef& node) { return node.op() == "Mul" ; } |
395 | bool IsMulNoNan(const NodeDef& node) { return node.op() == "MulNoNan" ; } |
396 | bool IsAnyMul(const NodeDef& node) { return IsMul(node) || IsMulNoNan(node); } |
397 | |
398 | bool IsNeg(const NodeDef& node) { return node.op() == "Neg" ; } |
399 | |
400 | bool IsNoOp(const NodeDef& node) { return node.op() == "NoOp" ; } |
401 | |
402 | bool IsNotEqual(const NodeDef& node) { return node.op() == "NotEqual" ; } |
403 | |
404 | bool IsNextIteration(const NodeDef& node) { |
405 | const auto& op = node.op(); |
406 | return op == "NextIteration" || op == "RefNextIteration" ; |
407 | } |
408 | |
409 | bool IsOnesLike(const NodeDef& node) { return node.op() == "OnesLike" ; } |
410 | |
411 | bool IsPack(const NodeDef& node) { return node.op() == "Pack" ; } |
412 | |
413 | bool IsPad(const NodeDef& node) { |
414 | const auto& op = node.op(); |
415 | return op == "Pad" || op == "PadV2" ; |
416 | } |
417 | |
418 | bool IsPartitionedCall(const NodeDef& node) { |
419 | return node.op() == "PartitionedCall" ; |
420 | } |
421 | |
422 | bool IsPlaceholder(const NodeDef& node) { |
423 | const auto& op = node.op(); |
424 | return op == "Placeholder" || op == "PlaceholderV2" || |
425 | op == "PlaceholderWithDefault" ; |
426 | } |
427 | |
428 | bool IsPolygamma(const NodeDef& node) { return node.op() == "Polygamma" ; } |
429 | |
430 | bool IsPow(const NodeDef& node) { return node.op() == "Pow" ; } |
431 | |
432 | bool IsPrint(const NodeDef& node) { |
433 | return node.op() == "Print" || node.op() == "PrintV2" ; |
434 | } |
435 | |
436 | bool IsProd(const NodeDef& node) { return node.op() == "Prod" ; } |
437 | |
438 | bool IsQuantizedMatMul(const NodeDef& node) { |
439 | return node.op() == "QuantizedMatMul" || node.op() == "QuantizedMatMulV2" ; |
440 | } |
441 | |
442 | bool IsQueue(const NodeDef& node) { |
443 | return str_util::EndsWith(node.op(), "QueueV2" ); |
444 | } |
445 | |
446 | bool IsRandomShuffle(const NodeDef& node) { |
447 | return node.op() == "RandomShuffle" ; |
448 | } |
449 | |
450 | bool IsRank(const NodeDef& node) { return node.op() == "Rank" ; } |
451 | |
452 | bool IsReadVariableOp(const NodeDef& node) { |
453 | return node.op() == "ReadVariableOp" ; |
454 | } |
455 | |
456 | bool IsReadVariablesOp(const NodeDef& node) { |
457 | return node.op() == "_ReadVariablesOp" ; |
458 | } |
459 | |
460 | bool IsReal(const NodeDef& node) { return node.op() == "Real" ; } |
461 | |
462 | bool IsRealDiv(const NodeDef& node) { return node.op() == "RealDiv" ; } |
463 | |
464 | bool IsReciprocalGrad(const NodeDef& node) { |
465 | return node.op() == "ReciprocalGrad" ; |
466 | } |
467 | |
468 | bool IsRecv(const NodeDef& node) { |
469 | return node.op() == "_Recv" || node.op() == "_HostRecv" ; |
470 | } |
471 | |
472 | bool IsReduction(const NodeDef& node) { |
473 | const auto& op = node.op(); |
474 | return op == "Sum" || op == "Prod" || op == "Min" || op == "Max" || |
475 | op == "Mean" || op == "Any" || op == "All" ; |
476 | } |
477 | |
478 | bool IsRelu(const NodeDef& node) { return node.op() == "Relu" ; } |
479 | |
480 | bool IsRelu6(const NodeDef& node) { return node.op() == "Relu6" ; } |
481 | |
482 | bool IsReluGrad(const NodeDef& node) { return node.op() == "ReluGrad" ; } |
483 | |
484 | bool IsRelu6Grad(const NodeDef& node) { return node.op() == "Relu6Grad" ; } |
485 | |
486 | bool IsReshape(const NodeDef& node) { return (node.op() == "Reshape" ); } |
487 | |
488 | bool IsRestore(const NodeDef& node) { |
489 | return (node.op() == "Restore" || node.op() == "RestoreV2" || |
490 | node.op() == "RestoreSlice" ); |
491 | } |
492 | |
493 | bool IsRetval(const NodeDef& node) { |
494 | return node.op() == "_Retval" || node.op() == "_DeviceRetval" ; |
495 | } |
496 | |
497 | bool IsReverse(const NodeDef& node) { |
498 | return node.op() == "Reverse" || node.op() == "ReverseV2" ; |
499 | } |
500 | |
501 | bool IsReverseV2(const NodeDef& node) { return node.op() == "ReverseV2" ; } |
502 | |
503 | bool IsRsqrt(const NodeDef& node) { return node.op() == "Rsqrt" ; } |
504 | |
505 | bool IsRsqrtGrad(const NodeDef& node) { return node.op() == "RsqrtGrad" ; } |
506 | |
507 | bool IsSelect(const NodeDef& node) { |
508 | return node.op() == "Select" || node.op() == "SelectV2" ; |
509 | } |
510 | |
511 | bool IsSeluGrad(const NodeDef& node) { return node.op() == "SeluGrad" ; } |
512 | |
513 | bool IsSend(const NodeDef& node) { |
514 | return node.op() == "_Send" || node.op() == "_HostSend" ; |
515 | } |
516 | |
517 | bool IsShape(const NodeDef& node) { return node.op() == "Shape" ; } |
518 | |
519 | bool IsShapeN(const NodeDef& node) { return node.op() == "ShapeN" ; } |
520 | |
521 | bool IsShuffle(const NodeDef& node) { return node.op() == "Shuffle" ; } |
522 | |
523 | bool IsSigmoid(const NodeDef& node) { return node.op() == "Sigmoid" ; } |
524 | |
525 | bool IsSigmoidGrad(const NodeDef& node) { return node.op() == "SigmoidGrad" ; } |
526 | |
527 | bool IsSize(const NodeDef& node) { return node.op() == "Size" ; } |
528 | |
529 | bool IsSlice(const NodeDef& node) { return node.op() == "Slice" ; } |
530 | |
531 | bool IsSnapshot(const NodeDef& node) { return node.op() == "Snapshot" ; } |
532 | |
533 | bool IsSoftmax(const NodeDef& node) { return node.op() == "Softmax" ; } |
534 | |
535 | bool IsSoftplusGrad(const NodeDef& node) { return node.op() == "SoftplusGrad" ; } |
536 | |
537 | bool IsSoftsignGrad(const NodeDef& node) { return node.op() == "SoftsignGrad" ; } |
538 | |
539 | bool IsSplit(const NodeDef& node) { return node.op() == "Split" ; } |
540 | |
541 | bool IsSplitV(const NodeDef& node) { return node.op() == "SplitV" ; } |
542 | |
543 | bool IsSqrt(const NodeDef& node) { return node.op() == "Sqrt" ; } |
544 | |
545 | bool IsSqrtGrad(const NodeDef& node) { return node.op() == "SqrtGrad" ; } |
546 | |
547 | bool IsSquare(const NodeDef& node) { return node.op() == "Square" ; } |
548 | |
549 | bool IsSquaredDifference(const NodeDef& node) { |
550 | return node.op() == "SquaredDifference" ; |
551 | } |
552 | |
553 | bool IsSqueeze(const NodeDef& node) { return node.op() == "Squeeze" ; } |
554 | |
555 | bool IsStackOp(const NodeDef& node) { |
556 | return node.op() == "Stack" || node.op() == "StackV2" ; |
557 | } |
558 | bool IsStackCloseOp(const NodeDef& node) { |
559 | return node.op() == "StackClose" || node.op() == "StackCloseV2" ; |
560 | } |
561 | bool IsStackPushOp(const NodeDef& node) { |
562 | return node.op() == "StackPush" || node.op() == "StackPushV2" ; |
563 | } |
564 | bool IsStackPopOp(const NodeDef& node) { |
565 | return node.op() == "StackPop" || node.op() == "StackPopV2" ; |
566 | } |
567 | |
568 | bool IsStatefulPartitionedCall(const NodeDef& node) { |
569 | return node.op() == "StatefulPartitionedCall" ; |
570 | } |
571 | |
572 | bool IsStopGradient(const NodeDef& node) { |
573 | const auto& op = node.op(); |
574 | return op == "StopGradient" || op == "PreventGradient" ; |
575 | } |
576 | |
577 | bool IsStridedSlice(const NodeDef& node) { return node.op() == "StridedSlice" ; } |
578 | |
579 | bool IsStridedSliceGrad(const NodeDef& node) { |
580 | return node.op() == "StridedSliceGrad" ; |
581 | } |
582 | |
583 | bool IsStringToHashBucketFast(const NodeDef& node) { |
584 | return node.op() == "StringToHashBucketFast" ; |
585 | } |
586 | |
587 | bool IsSub(const NodeDef& node) { return node.op() == "Sub" ; } |
588 | |
589 | bool IsSum(const NodeDef& node) { return node.op() == "Sum" ; } |
590 | |
591 | bool IsSwitch(const NodeDef& node) { |
592 | const auto& op = node.op(); |
593 | return op == "_SwitchN" || op == "Switch" || op == "RefSwitch" ; |
594 | } |
595 | |
596 | bool IsSymbolicGradient(const NodeDef& node) { |
597 | return node.op() == "SymbolicGradient" ; |
598 | } |
599 | |
600 | bool IsTanh(const NodeDef& node) { return node.op() == "Tanh" ; } |
601 | |
602 | bool IsTanhGrad(const NodeDef& node) { return node.op() == "TanhGrad" ; } |
603 | |
604 | bool IsTensorArray(const NodeDef& node) { |
605 | static const gtl::FlatSet<string>* const kTensorArrayOps = |
606 | CHECK_NOTNULL((new gtl::FlatSet<string>{ |
607 | "TensorArray" , |
608 | "TensorArrayV2" , |
609 | "TensorArrayV3" , |
610 | "TensorArrayGrad" , |
611 | "TensorArrayGradV2" , |
612 | "TensorArrayGradV3" , |
613 | "TensorArrayGradWithShape" , |
614 | "TensorArrayWrite" , |
615 | "TensorArrayWriteV2" , |
616 | "TensorArrayWriteV3" , |
617 | "TensorArrayRead" , |
618 | "TensorArrayReadV2" , |
619 | "TensorArrayReadV3" , |
620 | "TensorArrayConcat" , |
621 | "TensorArrayConcatV2" , |
622 | "TensorArrayConcatV3" , |
623 | "TensorArraySplit" , |
624 | "TensorArraySplitV2" , |
625 | "TensorArraySplitV3" , |
626 | "TensorArraySize" , |
627 | "TensorArraySizeV2" , |
628 | "TensorArraySizeV3" , |
629 | "TensorArrayClose" , |
630 | "TensorArrayCloseV2" , |
631 | "TensorArrayCloseV3" , |
632 | })); |
633 | return kTensorArrayOps->count(node.op()) > 0; |
634 | } |
635 | |
636 | bool IsTile(const NodeDef& node) { return node.op() == "Tile" ; } |
637 | |
638 | bool IsTranspose(const NodeDef& node) { return node.op() == "Transpose" ; } |
639 | |
640 | bool IsTruncateDiv(const NodeDef& node) { return node.op() == "TruncateDiv" ; } |
641 | |
642 | bool IsTruncateMod(const NodeDef& node) { return node.op() == "TruncateMod" ; } |
643 | |
644 | bool IsUnique(const NodeDef& node) { |
645 | const auto& op = node.op(); |
646 | return op == "Unique" || op == "UniqueV2" ; |
647 | } |
648 | |
649 | bool IsUnpack(const NodeDef& node) { return node.op() == "Unpack" ; } |
650 | |
651 | bool IsVariable(const NodeDef& node) { |
652 | const auto& op = node.op(); |
653 | return op == "Variable" || op == "VariableV2" || op == "AutoReloadVariable" || |
654 | op == "VarHandleOp" || op == "ReadVariableOp" || |
655 | op == "_VarHandlesOp" || op == "_ReadVariablesOp" ; |
656 | } |
657 | |
658 | bool IsWhile(const NodeDef& node) { |
659 | const auto& op = node.op(); |
660 | return op == "While" || op == "StatelessWhile" ; |
661 | } |
662 | |
663 | bool IsXdivy(const NodeDef& node) { return node.op() == "Xdivy" ; } |
664 | |
665 | bool IsZerosLike(const NodeDef& node) { return node.op() == "ZerosLike" ; } |
666 | |
667 | bool IsZeta(const NodeDef& node) { return node.op() == "Zeta" ; } |
668 | |
669 | namespace { |
670 | bool GetBoolAttr(const NodeDef& node, const string& name) { |
671 | return node.attr().count(name) > 0 && node.attr().at(name).b(); |
672 | } |
673 | } // namespace |
674 | |
675 | bool IsPersistent(const NodeDef& node) { |
676 | return IsConstant(node) || IsVariable(node) || IsHostConstant(node); |
677 | } |
678 | |
679 | bool HasRefInput(const NodeDef& node) { |
680 | const OpDef* op_def; |
681 | Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); |
682 | if (!status.ok()) { |
683 | return false; |
684 | } |
685 | // Nodes such as Assign or AssignAdd modify one of their inputs. |
686 | for (const auto& input : op_def->input_arg()) { |
687 | if (input.is_ref()) { |
688 | return true; |
689 | } |
690 | } |
691 | return false; |
692 | } |
693 | |
694 | bool IsDataset(const NodeDef& node) { |
695 | const string& op = node.op(); |
696 | // See `GetNodeClassForOp` in core/graph/graph.cc. |
697 | return op == "IteratorGetNext" || op == "IteratorGetNextSync" || |
698 | op == "DatasetToSingleElement" || op == "ReduceDataset" ; |
699 | } |
700 | |
701 | bool IsStateful(const NodeDef node, const OpRegistryInterface* op_registry) { |
702 | const OpDef* op_def = nullptr; |
703 | const string& op_name = node.op(); |
704 | Status status = op_registry->LookUpOpDef(op_name, &op_def); |
705 | if (!status.ok()) { |
706 | LOG(WARNING) << "Failed to lookup OpDef for " << op_name |
707 | << ". Error: " << status.error_message(); |
708 | return false; |
709 | } |
710 | return op_def->is_stateful(); |
711 | } |
712 | |
713 | bool IsStateful(const NodeDef node) { |
714 | return IsStateful(node, OpRegistry::Global()); |
715 | } |
716 | |
717 | bool IsFreeOfSideEffect(const NodeDef& node, |
718 | const OpRegistryInterface* op_registry) { |
719 | // Placeholders must be preserved to keep the graph feedable. |
720 | if (IsPlaceholder(node)) { |
721 | return false; |
722 | } |
723 | const OpDef* op_def = nullptr; |
724 | const string& op_name = node.op(); |
725 | Status status = op_registry->LookUpOpDef(op_name, &op_def); |
726 | if (!status.ok()) { |
727 | return false; |
728 | } |
729 | if (op_def->is_stateful()) { |
730 | return false; |
731 | } |
732 | // Nodes such as Assign or AssignAdd modify one of their inputs. |
733 | for (const auto& input : op_def->input_arg()) { |
734 | if (input.is_ref()) { |
735 | return false; |
736 | } |
737 | } |
738 | // Queue ops modify the queue which is a side effect. |
739 | if (node.op().find("Queue" ) != string::npos) { |
740 | return false; |
741 | } |
742 | // Sending a tensor via a network is a side effect. |
743 | if (IsSend(node)) { |
744 | return false; |
745 | } |
746 | return !ModifiesInputsInPlace(node); |
747 | } |
748 | |
749 | bool IsFreeOfSideEffect(const NodeDef& node) { |
750 | return IsFreeOfSideEffect(node, OpRegistry::Global()); |
751 | } |
752 | |
753 | bool ModifiesInputsInPlace(const NodeDef& node) { |
754 | // Some nodes do in-place updates on regular tensor inputs. |
755 | const string& op_name = node.op(); |
756 | |
757 | // Ops that modify resource variables effectively modify one of their inputs. |
758 | if (op_name == "AssignVariableOp" || op_name == "AssignAddVariableOp" || |
759 | op_name == "AssignSubVariableOp" || op_name == "ResourceScatterUpdate" || |
760 | op_name == "ResourceScatterAdd" || op_name == "ResourceScatterSub" || |
761 | op_name == "ResourceScatterMul" || op_name == "ResourceScatterDiv" || |
762 | op_name == "ResourceScatterMin" || op_name == "ResourceScatterMax" ) { |
763 | return false; |
764 | } |
765 | |
766 | string lower_op_name = op_name; |
767 | std::transform(lower_op_name.begin(), lower_op_name.end(), |
768 | lower_op_name.begin(), ::tolower); |
769 | if (absl::StrContains(lower_op_name, "inplace" )) { |
770 | return true; |
771 | } |
772 | return GetBoolAttr(node, "in_place" ) || GetBoolAttr(node, "inplace" ); |
773 | } |
774 | |
775 | bool ModifiesFrameInfo(const NodeDef& node) { |
776 | return IsEnter(node) || IsExit(node) || IsNextIteration(node); |
777 | } |
778 | |
779 | #define OPDEF_PROPERTY_HELPER(PROPERTY_CAP, PROPERTY) \ |
780 | bool Is##PROPERTY_CAP(const NodeDef& node) { \ |
781 | if (node.op() == "Add") { \ |
782 | /* Workaround for "Add" not being marked is_commutative and */ \ |
783 | /* is_aggregate. (See cl/173915048). */ \ |
784 | const auto type = GetDataTypeFromAttr(node, "T"); \ |
785 | return type != DT_INVALID && type != DT_STRING; \ |
786 | } \ |
787 | const OpDef* op_def = nullptr; \ |
788 | Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); \ |
789 | return status.ok() && op_def->is_##PROPERTY(); \ |
790 | } |
791 | |
792 | OPDEF_PROPERTY_HELPER(Aggregate, aggregate) |
793 | OPDEF_PROPERTY_HELPER(Commutative, commutative) |
794 | |
795 | bool IsInvolution(const NodeDef& node) { |
796 | static const gtl::FlatSet<string>* const kInvolutionOps = |
797 | CHECK_NOTNULL((new gtl::FlatSet<string>{"Conj" , "Reciprocal" , "Invert" , |
798 | "Neg" , "LogicalNot" })); |
799 | return kInvolutionOps->count(node.op()) > 0; |
800 | } |
801 | |
802 | bool IsValueAndOrderAndShapePreserving(const NodeDef& node) { |
803 | if (NumNonControlInputs(node) == 1 && IsAggregate(node)) { |
804 | return true; |
805 | } |
806 | static const gtl::FlatSet<string>* const kValueAndOrderAndShapePreservingOps = |
807 | CHECK_NOTNULL((new const gtl::FlatSet<string>{ |
808 | "CheckNumerics" , |
809 | "DebugGradientIdentity" , |
810 | "DeepCopy" , |
811 | "Enter" , |
812 | "Exit" , |
813 | "PreventGradient" , |
814 | "Print" , |
815 | "Snapshot" , |
816 | "StopGradient" , |
817 | })); |
818 | return kValueAndOrderAndShapePreservingOps->count(node.op()) > 0 || |
819 | IsIdentity(node); |
820 | } |
821 | |
822 | bool IsValueAndOrderPreserving(const NodeDef& node) { |
823 | if (NumNonControlInputs(node) == 1 && IsAggregate(node)) { |
824 | return true; |
825 | } |
826 | static const gtl::FlatSet<string>* const kValueAndOrderPreservingOps = |
827 | CHECK_NOTNULL((new const gtl::FlatSet<string>{ |
828 | "ExpandDims" , |
829 | "Reshape" , |
830 | "Squeeze" , |
831 | })); |
832 | return kValueAndOrderPreservingOps->count(node.op()) > 0 || |
833 | IsValueAndOrderAndShapePreserving(node); |
834 | } |
835 | |
836 | bool IsValuePreserving(const NodeDef& node) { |
837 | static const gtl::FlatSet<string>* const kValuePreservingOps = |
838 | CHECK_NOTNULL((new gtl::FlatSet<string>{ |
839 | "InvertPermutation" , |
840 | "Reverse" , |
841 | "ReverseV2" , |
842 | "Roll" , |
843 | "Transpose" , |
844 | "DepthToSpace" , |
845 | "SpaceToDepth" , |
846 | "BatchToSpace" , |
847 | "BatchToSpaceND" , |
848 | "SpaceToBatch" , |
849 | "SpaceToBatchND" , |
850 | })); |
851 | return IsValueAndOrderPreserving(node) || |
852 | kValuePreservingOps->count(node.op()) > 0; |
853 | } |
854 | |
855 | bool IsUnaryElementWise(const NodeDef& node) { |
856 | static const gtl::FlatSet<string>* const kElementWiseOps = |
857 | CHECK_NOTNULL((new gtl::FlatSet<string>{ |
858 | "Abs" , "Acos" , "Acosh" , "Asin" , "Asinh" , |
859 | "Atan" , "Atanh" , "Ceil" , "ComplexAbs" , "Conj" , |
860 | "Cos" , "Cosh" , "Digamma" , "Elu" , "Erf" , |
861 | "Erfc" , "Exp" , "Expm1" , "Floor" , "Inv" , |
862 | "Invert" , "Isinf" , "Isnan" , "Isfinite" , "Lgamma" , |
863 | "Log" , "Log1p" , "LogicalNot" , "Neg" , "Reciprocal" , |
864 | "Relu" , "Relu6" , "Rint" , "Round" , "Selu" , |
865 | "Rsqrt" , "Sigmoid" , "Sign" , "Sin" , "SinH" , |
866 | "Softplus" , "Softsign" , "Sqrt" , "Square" , "Tan" , |
867 | "Tanh" , |
868 | })); |
869 | return kElementWiseOps->count(node.op()) > 0 || |
870 | IsValueAndOrderAndShapePreserving(node); |
871 | } |
872 | |
873 | bool HasOpDef(const NodeDef& node) { |
874 | const OpDef* op_def = nullptr; |
875 | return OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok(); |
876 | } |
877 | |
878 | bool IsIdempotent(const NodeDef& node) { |
879 | return IsValueAndOrderAndShapePreserving(node) && IsFreeOfSideEffect(node) && |
880 | !ModifiesFrameInfo(node); |
881 | } |
882 | |
883 | bool NeverForwardsInputs(const NodeDef& node) { |
884 | static const gtl::FlatSet<string>* const kNonForwardingOps = CHECK_NOTNULL( |
885 | (new gtl::FlatSet<string>{"ArgMax" , |
886 | "ArgMin" , |
887 | "AudioSpectrogram" , |
888 | "AvgPool" , |
889 | "BatchMatMul" , |
890 | "BatchMatMulV2" , |
891 | "BatchNormWithGlobalNormalization" , |
892 | "BatchToSpace" , |
893 | "BatchToSpaceND" , |
894 | "Bincount" , |
895 | "BroadcastArgs" , |
896 | "BroadcastGradientArgs" , |
897 | "Bucketize" , |
898 | "CTCBeamSearchDecoder" , |
899 | "CTCGreedyDecoder" , |
900 | "CTCLoss" , |
901 | "CompareAndBitpack" , |
902 | "ComplexAbs" , |
903 | "Concat" , |
904 | "ConcatOffset" , |
905 | "ConcatV2" , |
906 | "Conv2D" , |
907 | "Copy" , |
908 | "CopyHost" , |
909 | "Cross" , |
910 | "CudnnRNN" , |
911 | "CudnnRNNBackprop" , |
912 | "CudnnRNNBackpropV2" , |
913 | "CudnnRNNBackpropV3" , |
914 | "CudnnRNNCanonicalToParams" , |
915 | "CudnnRNNCanonicalToParamsV2" , |
916 | "CudnnRNNParamsSize" , |
917 | "CudnnRNNParamsToCanonical" , |
918 | "CudnnRNNParamsToCanonicalV2" , |
919 | "CudnnRNNV2" , |
920 | "CudnnRNNV3" , |
921 | "CumProd" , |
922 | "CumSum" , |
923 | "DebugNanCount" , |
924 | "DebugNumericSummary" , |
925 | "DecodeProtoV2" , |
926 | "DecodeWav" , |
927 | "DeepCopy" , |
928 | "DepthToSpace" , |
929 | "Dequantize" , |
930 | "Diag" , |
931 | "DiagPart" , |
932 | "EditDistance" , |
933 | "Empty" , |
934 | "EncodeProtoV2" , |
935 | "EncodeWav" , |
936 | "ExtractImagePatches" , |
937 | "ExtractVolumePatches" , |
938 | "Fill" , |
939 | "Gather" , |
940 | "GatherNd" , |
941 | "GatherV2" , |
942 | "HistogramFixedWidth" , |
943 | "InvertPermutation" , |
944 | "IsInf" , |
945 | "IsNan" , |
946 | "Isfinite" , |
947 | "LinSpace" , |
948 | "LowerBound" , |
949 | "MatMul" , |
950 | "MatrixDiag" , |
951 | "MatrixDiagPart" , |
952 | "MatrixDiagPartV2" , |
953 | "MatrixDiagV2" , |
954 | "Mfcc" , |
955 | "Multinomial" , |
956 | "OneHot" , |
957 | "Pack" , |
958 | "ParameterizedTruncatedNormal" , |
959 | "PopulationCount" , |
960 | "RandomGamma" , |
961 | "RandomPoisson" , |
962 | "RandomPoissonV2" , |
963 | "RandomStandardNormal" , |
964 | "RandomUniform" , |
965 | "RandomUniformInt" , |
966 | "Range" , |
967 | "Rank" , |
968 | "RequantizationRange" , |
969 | "Requantize" , |
970 | "ReverseSequence" , |
971 | "Shape" , |
972 | "ShapeN" , |
973 | "Size" , |
974 | "SpaceToBatch" , |
975 | "SpaceToBatchND" , |
976 | "SpaceToDepth" , |
977 | "SparseMatMul" , |
978 | "Split" , |
979 | "SplitV" , |
980 | "TruncatedNormal" , |
981 | "Unique" , |
982 | "UniqueV2" , |
983 | "UniqueWithCounts" , |
984 | "UniqueWithCountsV2" , |
985 | "Unpack" , |
986 | "UnravelIndex" , |
987 | "UpperBound" , |
988 | "Where" })); |
989 | const string& op_name = node.op(); |
990 | return kNonForwardingOps->count(op_name) > 0 || |
991 | absl::StrContains(op_name, "Segment" ) || |
992 | absl::StartsWith(op_name, "Quantize" ); |
993 | } |
994 | |
995 | bool IsXlaLaunch(const NodeDef& node) { return node.op() == "XlaLaunch" ; } |
996 | |
997 | } // namespace grappler |
998 | } // end namespace tensorflow |
999 | |