/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/op_types.h"

#include <algorithm>
#include <cctype>

#include "absl/strings/match.h"
#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {
namespace grappler {

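// "AddV2" is always numeric addition. Legacy "Add" also accepts DT_STRING,
// where it performs string concatenation rather than arithmetic, so it only
// counts as an add when T is not DT_STRING.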
bool IsAdd(const NodeDef& node) {
  if (node.op() == "AddV2") {
    return true;
  }
  if (node.op() == "Add") {
    DataType type = node.attr().at("T").type();
    return type != DT_STRING;
  }
  return false;
}

bool IsAddN(const NodeDef& node) { return node.op() == "AddN"; }

bool IsAll(const NodeDef& node) { return node.op() == "All"; }

bool IsAngle(const NodeDef& node) { return node.op() == "Angle"; }

bool IsAny(const NodeDef& node) { return node.op() == "Any"; }

bool IsAnyDiv(const NodeDef& node) {
  return node.op() == "RealDiv" || node.op() == "Div" ||
         node.op() == "Xdivy" || node.op() == "FloorDiv" ||
         node.op() == "TruncateDiv";
}

bool IsAnyBatchMatMul(const NodeDef& node) {
  return node.op() == "BatchMatMul" || node.op() == "BatchMatMulV2";
}

bool IsAnyMatMul(const NodeDef& node) {
  return node.op() == "MatMul" || node.op() == "SparseMatMul" ||
         IsAnyBatchMatMul(node) || IsQuantizedMatMul(node);
}

bool IsAnyMax(const NodeDef& node) {
  const auto& op = node.op();
  return op == "Max" || op == "SegmentMax" || op == "UnsortedSegmentMax";
}

bool IsAnyMaxPool(const NodeDef& node) {
  const auto& op = node.op();
  return op == "MaxPool" || op == "MaxPoolV2" || op == "MaxPool3D" ||
         op == "MaxPoolWithArgmax" || op == "FractionalMaxPool";
}

bool IsAnyMin(const NodeDef& node) {
  const auto& op = node.op();
  return op == "Min" || op == "SegmentMin" || op == "UnsortedSegmentMin";
}

bool IsAnySparseSegmentReduction(const NodeDef& node) {
  const auto& op = node.op();
  return op == "SparseSegmentSum" || op == "SparseSegmentSumWithNumSegments" ||
         op == "SparseSegmentMean" ||
         op == "SparseSegmentMeanWithNumSegments" ||
         op == "SparseSegmentSqrtN" ||
         op == "SparseSegmentSqrtNWithNumSegments";
}

bool IsApproximateEqual(const NodeDef& node) {
  return node.op() == "ApproximateEqual";
}

bool IsArg(const NodeDef& node) {
  return node.op() == "_Arg" || node.op() == "_DeviceArg";
}

bool IsArgMax(const NodeDef& node) { return node.op() == "ArgMax"; }

bool IsArgMin(const NodeDef& node) { return node.op() == "ArgMin"; }

bool IsAvgPoolGrad(const NodeDef& node) { return node.op() == "AvgPoolGrad"; }

bool IsAssign(const NodeDef& node) {
  return node.op() == "Assign" || node.op() == "AssignVariableOp";
}

bool IsAssert(const NodeDef& node) { return node.op() == "Assert"; }

bool IsAsString(const NodeDef& node) { return node.op() == "AsString"; }

bool IsAtan2(const NodeDef& node) { return node.op() == "Atan2"; }

bool IsBetainc(const NodeDef& node) { return node.op() == "Betainc"; }

bool IsBiasAdd(const NodeDef& node) {
  return node.op() == "BiasAdd" || node.op() == "BiasAddV1";
}

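// Note: "BiasAdd" is itself the successor of the legacy "BiasAddV1" op,
// which is presumably why IsBiasAddV2 matches plain "BiasAdd" only.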
bool IsBiasAddV2(const NodeDef& node) { return node.op() == "BiasAdd"; }

bool IsBiasAddGrad(const NodeDef& node) { return node.op() == "BiasAddGrad"; }

bool IsBitcast(const NodeDef& node) { return node.op() == "Bitcast"; }

bool IsBroadcastTo(const NodeDef& node) { return node.op() == "BroadcastTo"; }

bool IsCast(const NodeDef& node) { return node.op() == "Cast"; }

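// Ops that, like Cast, produce an output whose type generally differs from
// the type of their primary input.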
bool IsCastLike(const NodeDef& node) {
  static const gtl::FlatSet<string>* const kCastLikeOps =
      CHECK_NOTNULL((new gtl::FlatSet<string>{
          "Angle", "Bucketize", "Cast", "Dequantize", "HistogramFixedWidth",
          "Imag", "IsFinite", "IsInf", "IsNan", "Quantize",
          "QuantizeDownAndShrinkRange", "QuantizeV2", "QuantizedInstanceNorm",
          "QuantizedRelu", "QuantizedRelu6", "QuantizedReluX", "Real",
          "Requantize"}));
  return kCastLikeOps->count(node.op()) > 0;
}

bool IsCheckNumerics(const NodeDef& node) {
  return node.op() == "CheckNumerics";
}

bool IsCollective(const NodeDef& node) {
  return node.op() == "CollectiveReduce" ||
         node.op() == "CollectiveBcastSend" ||
         node.op() == "CollectiveBcastRecv";
}

bool IsComplex(const NodeDef& node) { return node.op() == "Complex"; }

bool IsComplexAbs(const NodeDef& node) { return node.op() == "ComplexAbs"; }

bool IsConcat(const NodeDef& node) {
  return node.op() == "Concat" || node.op() == "ConcatV2";
}

bool IsConcatOffset(const NodeDef& node) { return node.op() == "ConcatOffset"; }

bool IsConstant(const NodeDef& node) { return node.op() == "Const"; }

bool IsConj(const NodeDef& node) { return node.op() == "Conj"; }

bool IsConjugateTranspose(const NodeDef& node) {
  return node.op() == "ConjugateTranspose";
}

bool IsControlFlow(const NodeDef& node) {
  // clang-format off
  return node.op() == "ControlTrigger" ||
         node.op() == "Enter" ||
         node.op() == "Exit" ||
         node.op() == "LoopCond" ||
         node.op() == "Merge" ||
         node.op() == "_XlaMerge" ||
         node.op() == "NextIteration" ||
         node.op() == "Switch" ||
         node.op() == "_SwitchN";
  // clang-format on
}

bool IsConv2D(const NodeDef& node) { return node.op() == "Conv2D"; }

bool IsConv2DBackpropFilter(const NodeDef& node) {
  return node.op() == "Conv2DBackpropFilter";
}

bool IsConv2DBackpropInput(const NodeDef& node) {
  return node.op() == "Conv2DBackpropInput";
}

bool IsConv3D(const NodeDef& node) { return node.op() == "Conv3D"; }

bool IsConv3DBackpropFilterV2(const NodeDef& node) {
  return node.op() == "Conv3DBackpropFilterV2";
}

bool IsConv3DBackpropInputV2(const NodeDef& node) {
  return node.op() == "Conv3DBackpropInputV2";
}

bool IsDepthwiseConv2dNative(const NodeDef& node) {
  return node.op() == "DepthwiseConv2dNative";
}

bool IsDepthwiseConv2dNativeBackpropFilter(const NodeDef& node) {
  return node.op() == "DepthwiseConv2dNativeBackpropFilter";
}

bool IsDepthwiseConv2dNativeBackpropInput(const NodeDef& node) {
  return node.op() == "DepthwiseConv2dNativeBackpropInput";
}

bool IsDequeueOp(const NodeDef& node) {
  const auto& op = node.op();
  return op == "QueueDequeueManyV2" || op == "QueueDequeueMany" ||
         op == "QueueDequeueV2" || op == "QueueDequeue" ||
         op == "QueueDequeueUpToV2" || op == "QueueDequeueUpTo";
}

bool IsDiv(const NodeDef& node) { return node.op() == "Div"; }

bool IsDivNoNan(const NodeDef& node) { return node.op() == "DivNoNan"; }

// Returns true if `node` is a unary elementwise function that is monotonic.
// If it is and `is_non_decreasing` is non-null, *is_non_decreasing is set to
// true when the function is non-decreasing (e.g. Sqrt, Exp) and to false when
// it is non-increasing (e.g. Neg, Rsqrt).
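//
// Example use (a sketch): a rewriter can swap a monotonic unary op with a
// Max/Min reduction, flipping Max <-> Min when the function is non-increasing:
//   bool non_decreasing = false;
//   if (IsElementWiseMonotonic(node, &non_decreasing)) {
//     // ... reorder `node` with the reduction ...
//   }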
bool IsElementWiseMonotonic(const NodeDef& node, bool* is_non_decreasing) {
  static const gtl::FlatSet<string>* const kMonotonicNonDecreasingOps =
      CHECK_NOTNULL((new gtl::FlatSet<string>{
          "Acosh", "Asin", "Asinh", "Atan", "Atanh", "Ceil",
          "Elu", "Erf", "Exp", "Expm1", "Floor", "Log",
          "Log1p", "Relu", "Relu6", "Rint", "Selu", "Sigmoid",
          "Sign", "Sinh", "Softsign", "Softplus", "Sqrt", "Tanh",
      }));
  static const gtl::FlatSet<string>* const kMonotonicNonIncreasingOps =
      CHECK_NOTNULL((new gtl::FlatSet<string>{"Acos", "Erfc", "Neg", "Rsqrt"}));
  if (kMonotonicNonDecreasingOps->count(node.op()) > 0) {
    if (is_non_decreasing) {
      *is_non_decreasing = true;
    }
    return true;
  } else if (kMonotonicNonIncreasingOps->count(node.op()) > 0) {
    if (is_non_decreasing) {
      *is_non_decreasing = false;
    }
    return true;
  }
  return false;
}

bool IsElu(const NodeDef& node) { return node.op() == "Elu"; }

bool IsEluGrad(const NodeDef& node) { return node.op() == "EluGrad"; }

bool IsQuantizationEmulation(const NodeDef& node) {
  const auto& op = node.op();
  return absl::StartsWith(op, "QuantizeAndDequantize") ||
         absl::StartsWith(op, "FakeQuantWithMinMax");
}

bool IsEnter(const NodeDef& node) {
  const auto& op = node.op();
  return op == "Enter" || op == "RefEnter";
}

bool IsEqual(const NodeDef& node) { return node.op() == "Equal"; }

bool IsExit(const NodeDef& node) {
  const auto& op = node.op();
  return op == "Exit" || op == "RefExit";
}

bool IsExp(const NodeDef& node) { return node.op() == "Exp"; }

bool IsFakeParam(const NodeDef& node) { return node.op() == "FakeParam"; }

bool IsFill(const NodeDef& node) { return node.op() == "Fill"; }

bool IsFloorDiv(const NodeDef& node) { return node.op() == "FloorDiv"; }

bool IsFloorMod(const NodeDef& node) { return node.op() == "FloorMod"; }

bool IsFusedBatchNorm(const NodeDef& node) {
  const auto& op = node.op();
  return op == "FusedBatchNorm" || op == "FusedBatchNormV2" ||
         op == "FusedBatchNormV3";
}

bool IsFusedBatchNormEx(const NodeDef& node) {
  return node.op() == "_FusedBatchNormEx";
}

bool IsFusedBatchNormGrad(const NodeDef& node) {
  const auto& op = node.op();
  return op == "FusedBatchNormGrad" || op == "FusedBatchNormGradV2" ||
         op == "FusedBatchNormGradV3";
}

bool IsGather(const NodeDef& node) {
  const auto& op = node.op();
  return op == "Gather" || op == "GatherV2" || op == "ResourceGather";
}

bool IsGreater(const NodeDef& node) { return node.op() == "Greater"; }

bool IsGreaterEqual(const NodeDef& node) { return node.op() == "GreaterEqual"; }

bool IsHostConstant(const NodeDef& node) { return node.op() == "HostConst"; }

bool IsHistogramSummary(const NodeDef& node) {
  return node.op() == "HistogramSummary";
}

bool IsIdentity(const NodeDef& node) {
  const auto& op = node.op();
  return op == "Identity" || op == "RefIdentity";
}

bool IsIdentityN(const NodeDef& node) {
  const auto& op = node.op();
  return op == "IdentityN";
}

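// An IdentityN that passes through exactly one tensor: its "T" attr holds the
// list of input types, so a type list of size one means a single data input.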
bool IsIdentityNSingleInput(const NodeDef& node) {
  return IsIdentityN(node) && node.attr().count("T") != 0 &&
         node.attr().at("T").list().type_size() == 1;
}

bool IsIf(const NodeDef& node) {
  const auto& op = node.op();
  return op == "If" || op == "StatelessIf";
}

bool IsIgamma(const NodeDef& node) { return node.op() == "Igamma"; }

bool IsIgammac(const NodeDef& node) { return node.op() == "Igammac"; }

bool IsImag(const NodeDef& node) { return node.op() == "Imag"; }

bool IsImmutableConst(const NodeDef& node) {
  return node.op() == "ImmutableConst";
}

bool IsInvGrad(const NodeDef& node) { return node.op() == "InvGrad"; }

bool IsLeakyRelu(const NodeDef& node) { return node.op() == "LeakyRelu"; }

bool IsLeakyReluGrad(const NodeDef& node) {
  return node.op() == "LeakyReluGrad";
}

bool IsLess(const NodeDef& node) { return node.op() == "Less"; }

bool IsLessEqual(const NodeDef& node) { return node.op() == "LessEqual"; }

bool IsLog(const NodeDef& node) { return node.op() == "Log"; }

bool IsLogicalAnd(const NodeDef& node) { return node.op() == "LogicalAnd"; }

bool IsLogicalNot(const NodeDef& node) { return node.op() == "LogicalNot"; }

bool IsLogicalOr(const NodeDef& node) { return node.op() == "LogicalOr"; }

bool IsLoopCond(const NodeDef& node) { return node.op() == "LoopCond"; }

bool IsMatMul(const NodeDef& node) { return node.op() == "MatMul"; }

bool IsMax(const NodeDef& node) { return node.op() == "Max"; }

bool IsMaximum(const NodeDef& node) { return node.op() == "Maximum"; }

bool IsMaxPoolGrad(const NodeDef& node) { return node.op() == "MaxPoolGrad"; }

bool IsMean(const NodeDef& node) { return node.op() == "Mean"; }

bool IsMerge(const NodeDef& node) {
  const auto& op = node.op();
  return op == "Merge" || op == "RefMerge" || op == "_XlaMerge";
}

bool IsMin(const NodeDef& node) { return node.op() == "Min"; }

bool IsMinimum(const NodeDef& node) { return node.op() == "Minimum"; }

bool IsMirrorPad(const NodeDef& node) { return node.op() == "MirrorPad"; }

bool IsMirrorPadGrad(const NodeDef& node) {
  return node.op() == "MirrorPadGrad";
}

bool IsMod(const NodeDef& node) { return node.op() == "Mod"; }

bool IsMul(const NodeDef& node) { return node.op() == "Mul"; }
bool IsMulNoNan(const NodeDef& node) { return node.op() == "MulNoNan"; }
bool IsAnyMul(const NodeDef& node) { return IsMul(node) || IsMulNoNan(node); }

bool IsNeg(const NodeDef& node) { return node.op() == "Neg"; }

bool IsNoOp(const NodeDef& node) { return node.op() == "NoOp"; }

bool IsNotEqual(const NodeDef& node) { return node.op() == "NotEqual"; }

bool IsNextIteration(const NodeDef& node) {
  const auto& op = node.op();
  return op == "NextIteration" || op == "RefNextIteration";
}

bool IsOnesLike(const NodeDef& node) { return node.op() == "OnesLike"; }

bool IsPack(const NodeDef& node) { return node.op() == "Pack"; }

bool IsPad(const NodeDef& node) {
  const auto& op = node.op();
  return op == "Pad" || op == "PadV2";
}

bool IsPartitionedCall(const NodeDef& node) {
  return node.op() == "PartitionedCall";
}

bool IsPlaceholder(const NodeDef& node) {
  const auto& op = node.op();
  return op == "Placeholder" || op == "PlaceholderV2" ||
         op == "PlaceholderWithDefault";
}

bool IsPolygamma(const NodeDef& node) { return node.op() == "Polygamma"; }

bool IsPow(const NodeDef& node) { return node.op() == "Pow"; }

bool IsPrint(const NodeDef& node) {
  return node.op() == "Print" || node.op() == "PrintV2";
}

bool IsProd(const NodeDef& node) { return node.op() == "Prod"; }

bool IsQuantizedMatMul(const NodeDef& node) {
  return node.op() == "QuantizedMatMul" || node.op() == "QuantizedMatMulV2";
}

bool IsQueue(const NodeDef& node) {
  return str_util::EndsWith(node.op(), "QueueV2");
}

bool IsRandomShuffle(const NodeDef& node) {
  return node.op() == "RandomShuffle";
}

bool IsRank(const NodeDef& node) { return node.op() == "Rank"; }

bool IsReadVariableOp(const NodeDef& node) {
  return node.op() == "ReadVariableOp";
}

bool IsReadVariablesOp(const NodeDef& node) {
  return node.op() == "_ReadVariablesOp";
}

bool IsReal(const NodeDef& node) { return node.op() == "Real"; }

bool IsRealDiv(const NodeDef& node) { return node.op() == "RealDiv"; }

bool IsReciprocalGrad(const NodeDef& node) {
  return node.op() == "ReciprocalGrad";
}

bool IsRecv(const NodeDef& node) {
  return node.op() == "_Recv" || node.op() == "_HostRecv";
}

bool IsReduction(const NodeDef& node) {
  const auto& op = node.op();
  return op == "Sum" || op == "Prod" || op == "Min" || op == "Max" ||
         op == "Mean" || op == "Any" || op == "All";
}

bool IsRelu(const NodeDef& node) { return node.op() == "Relu"; }

bool IsRelu6(const NodeDef& node) { return node.op() == "Relu6"; }

bool IsReluGrad(const NodeDef& node) { return node.op() == "ReluGrad"; }

bool IsRelu6Grad(const NodeDef& node) { return node.op() == "Relu6Grad"; }

bool IsReshape(const NodeDef& node) { return (node.op() == "Reshape"); }

bool IsRestore(const NodeDef& node) {
  return (node.op() == "Restore" || node.op() == "RestoreV2" ||
          node.op() == "RestoreSlice");
}

bool IsRetval(const NodeDef& node) {
  return node.op() == "_Retval" || node.op() == "_DeviceRetval";
}

bool IsReverse(const NodeDef& node) {
  return node.op() == "Reverse" || node.op() == "ReverseV2";
}

bool IsReverseV2(const NodeDef& node) { return node.op() == "ReverseV2"; }

bool IsRsqrt(const NodeDef& node) { return node.op() == "Rsqrt"; }

bool IsRsqrtGrad(const NodeDef& node) { return node.op() == "RsqrtGrad"; }

bool IsSelect(const NodeDef& node) {
  return node.op() == "Select" || node.op() == "SelectV2";
}

bool IsSeluGrad(const NodeDef& node) { return node.op() == "SeluGrad"; }

bool IsSend(const NodeDef& node) {
  return node.op() == "_Send" || node.op() == "_HostSend";
}

bool IsShape(const NodeDef& node) { return node.op() == "Shape"; }

bool IsShapeN(const NodeDef& node) { return node.op() == "ShapeN"; }

bool IsShuffle(const NodeDef& node) { return node.op() == "Shuffle"; }

bool IsSigmoid(const NodeDef& node) { return node.op() == "Sigmoid"; }

bool IsSigmoidGrad(const NodeDef& node) { return node.op() == "SigmoidGrad"; }

bool IsSize(const NodeDef& node) { return node.op() == "Size"; }

bool IsSlice(const NodeDef& node) { return node.op() == "Slice"; }

bool IsSnapshot(const NodeDef& node) { return node.op() == "Snapshot"; }

bool IsSoftmax(const NodeDef& node) { return node.op() == "Softmax"; }

bool IsSoftplusGrad(const NodeDef& node) { return node.op() == "SoftplusGrad"; }

bool IsSoftsignGrad(const NodeDef& node) { return node.op() == "SoftsignGrad"; }

bool IsSplit(const NodeDef& node) { return node.op() == "Split"; }

bool IsSplitV(const NodeDef& node) { return node.op() == "SplitV"; }

bool IsSqrt(const NodeDef& node) { return node.op() == "Sqrt"; }

bool IsSqrtGrad(const NodeDef& node) { return node.op() == "SqrtGrad"; }

bool IsSquare(const NodeDef& node) { return node.op() == "Square"; }

bool IsSquaredDifference(const NodeDef& node) {
  return node.op() == "SquaredDifference";
}

bool IsSqueeze(const NodeDef& node) { return node.op() == "Squeeze"; }

bool IsStackOp(const NodeDef& node) {
  return node.op() == "Stack" || node.op() == "StackV2";
}
bool IsStackCloseOp(const NodeDef& node) {
  return node.op() == "StackClose" || node.op() == "StackCloseV2";
}
bool IsStackPushOp(const NodeDef& node) {
  return node.op() == "StackPush" || node.op() == "StackPushV2";
}
bool IsStackPopOp(const NodeDef& node) {
  return node.op() == "StackPop" || node.op() == "StackPopV2";
}

bool IsStatefulPartitionedCall(const NodeDef& node) {
  return node.op() == "StatefulPartitionedCall";
}

bool IsStopGradient(const NodeDef& node) {
  const auto& op = node.op();
  return op == "StopGradient" || op == "PreventGradient";
}

bool IsStridedSlice(const NodeDef& node) { return node.op() == "StridedSlice"; }

bool IsStridedSliceGrad(const NodeDef& node) {
  return node.op() == "StridedSliceGrad";
}

bool IsStringToHashBucketFast(const NodeDef& node) {
  return node.op() == "StringToHashBucketFast";
}

bool IsSub(const NodeDef& node) { return node.op() == "Sub"; }

bool IsSum(const NodeDef& node) { return node.op() == "Sum"; }

bool IsSwitch(const NodeDef& node) {
  const auto& op = node.op();
  return op == "_SwitchN" || op == "Switch" || op == "RefSwitch";
}

bool IsSymbolicGradient(const NodeDef& node) {
  return node.op() == "SymbolicGradient";
}

bool IsTanh(const NodeDef& node) { return node.op() == "Tanh"; }

bool IsTanhGrad(const NodeDef& node) { return node.op() == "TanhGrad"; }

bool IsTensorArray(const NodeDef& node) {
  static const gtl::FlatSet<string>* const kTensorArrayOps =
      CHECK_NOTNULL((new gtl::FlatSet<string>{
          "TensorArray",
          "TensorArrayV2",
          "TensorArrayV3",
          "TensorArrayGrad",
          "TensorArrayGradV2",
          "TensorArrayGradV3",
          "TensorArrayGradWithShape",
          "TensorArrayWrite",
          "TensorArrayWriteV2",
          "TensorArrayWriteV3",
          "TensorArrayRead",
          "TensorArrayReadV2",
          "TensorArrayReadV3",
          "TensorArrayConcat",
          "TensorArrayConcatV2",
          "TensorArrayConcatV3",
          "TensorArraySplit",
          "TensorArraySplitV2",
          "TensorArraySplitV3",
          "TensorArraySize",
          "TensorArraySizeV2",
          "TensorArraySizeV3",
          "TensorArrayClose",
          "TensorArrayCloseV2",
          "TensorArrayCloseV3",
      }));
  return kTensorArrayOps->count(node.op()) > 0;
}

bool IsTile(const NodeDef& node) { return node.op() == "Tile"; }

bool IsTranspose(const NodeDef& node) { return node.op() == "Transpose"; }

bool IsTruncateDiv(const NodeDef& node) { return node.op() == "TruncateDiv"; }

bool IsTruncateMod(const NodeDef& node) { return node.op() == "TruncateMod"; }

bool IsUnique(const NodeDef& node) {
  const auto& op = node.op();
  return op == "Unique" || op == "UniqueV2";
}

bool IsUnpack(const NodeDef& node) { return node.op() == "Unpack"; }

bool IsVariable(const NodeDef& node) {
  const auto& op = node.op();
  return op == "Variable" || op == "VariableV2" || op == "AutoReloadVariable" ||
         op == "VarHandleOp" || op == "ReadVariableOp" ||
         op == "_VarHandlesOp" || op == "_ReadVariablesOp";
}

bool IsWhile(const NodeDef& node) {
  const auto& op = node.op();
  return op == "While" || op == "StatelessWhile";
}

bool IsXdivy(const NodeDef& node) { return node.op() == "Xdivy"; }

bool IsZerosLike(const NodeDef& node) { return node.op() == "ZerosLike"; }

bool IsZeta(const NodeDef& node) { return node.op() == "Zeta"; }

namespace {
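// Returns the value of the boolean attribute `name`, or false if the
// attribute is not set.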
bool GetBoolAttr(const NodeDef& node, const string& name) {
  return node.attr().count(name) > 0 && node.attr().at(name).b();
}
}  // namespace

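// Persistent nodes (constants and variables) own buffers that outlive a
// single execution of the graph.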
bool IsPersistent(const NodeDef& node) {
  return IsConstant(node) || IsVariable(node) || IsHostConstant(node);
}

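// True if the registered OpDef for this node declares any ref-typed input.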
bool HasRefInput(const NodeDef& node) {
  const OpDef* op_def;
  Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def);
  if (!status.ok()) {
    return false;
  }
  // Nodes such as Assign or AssignAdd modify one of their inputs.
  for (const auto& input : op_def->input_arg()) {
    if (input.is_ref()) {
      return true;
    }
  }
  return false;
}

bool IsDataset(const NodeDef& node) {
  const string& op = node.op();
  // See `GetNodeClassForOp` in core/graph/graph.cc.
  return op == "IteratorGetNext" || op == "IteratorGetNextSync" ||
         op == "DatasetToSingleElement" || op == "ReduceDataset";
}

bool IsStateful(const NodeDef node, const OpRegistryInterface* op_registry) {
  const OpDef* op_def = nullptr;
  const string& op_name = node.op();
  Status status = op_registry->LookUpOpDef(op_name, &op_def);
  if (!status.ok()) {
    LOG(WARNING) << "Failed to lookup OpDef for " << op_name
                 << ". Error: " << status.error_message();
    return false;
  }
  return op_def->is_stateful();
}

bool IsStateful(const NodeDef node) {
  return IsStateful(node, OpRegistry::Global());
}

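// Returns true if executing `node` can have no observable effect beyond
// producing its outputs: it must not be stateful, take ref inputs, touch a
// queue, send data, or update its inputs in place. Placeholders are treated
// as having a side effect so they are never optimized away.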
bool IsFreeOfSideEffect(const NodeDef& node,
                        const OpRegistryInterface* op_registry) {
  // Placeholders must be preserved to keep the graph feedable.
  if (IsPlaceholder(node)) {
    return false;
  }
  const OpDef* op_def = nullptr;
  const string& op_name = node.op();
  Status status = op_registry->LookUpOpDef(op_name, &op_def);
  if (!status.ok()) {
    return false;
  }
  if (op_def->is_stateful()) {
    return false;
  }
  // Nodes such as Assign or AssignAdd modify one of their inputs.
  for (const auto& input : op_def->input_arg()) {
    if (input.is_ref()) {
      return false;
    }
  }
  // Queue ops modify the queue which is a side effect.
  if (node.op().find("Queue") != string::npos) {
    return false;
  }
  // Sending a tensor via a network is a side effect.
  if (IsSend(node)) {
    return false;
  }
  return !ModifiesInputsInPlace(node);
}

bool IsFreeOfSideEffect(const NodeDef& node) {
  return IsFreeOfSideEffect(node, OpRegistry::Global());
}

bool ModifiesInputsInPlace(const NodeDef& node) {
  // Some nodes do in-place updates on regular tensor inputs.
  const string& op_name = node.op();

  // Resource variable ops mutate state through the resource handle they take
  // as input, not the handle tensor itself, so they do not count as in-place
  // updates here. (They are stateful and already rejected by the is_stateful
  // check in IsFreeOfSideEffect.)
  if (op_name == "AssignVariableOp" || op_name == "AssignAddVariableOp" ||
      op_name == "AssignSubVariableOp" || op_name == "ResourceScatterUpdate" ||
      op_name == "ResourceScatterAdd" || op_name == "ResourceScatterSub" ||
      op_name == "ResourceScatterMul" || op_name == "ResourceScatterDiv" ||
      op_name == "ResourceScatterMin" || op_name == "ResourceScatterMax") {
    return false;
  }

  string lower_op_name = op_name;
  std::transform(lower_op_name.begin(), lower_op_name.end(),
                 lower_op_name.begin(), ::tolower);
  if (absl::StrContains(lower_op_name, "inplace")) {
    return true;
  }
  return GetBoolAttr(node, "in_place") || GetBoolAttr(node, "inplace");
}

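// Enter, Exit and NextIteration change which control-flow frame a tensor
// belongs to, so optimizers must not rewire them casually.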
bool ModifiesFrameInfo(const NodeDef& node) {
  return IsEnter(node) || IsExit(node) || IsNextIteration(node);
}

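// Generates Is<Property>(node) helpers that read the corresponding bit from
// the registered OpDef; see IsAggregate and IsCommutative below.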
#define OPDEF_PROPERTY_HELPER(PROPERTY_CAP, PROPERTY)                      \
  bool Is##PROPERTY_CAP(const NodeDef& node) {                             \
    if (node.op() == "Add") {                                              \
      /* Workaround for "Add" not being marked is_commutative and */       \
      /* is_aggregate. (See cl/173915048). */                              \
      const auto type = GetDataTypeFromAttr(node, "T");                    \
      return type != DT_INVALID && type != DT_STRING;                      \
    }                                                                      \
    const OpDef* op_def = nullptr;                                         \
    Status status = OpRegistry::Global()->LookUpOpDef(node.op(), &op_def); \
    return status.ok() && op_def->is_##PROPERTY();                         \
  }

OPDEF_PROPERTY_HELPER(Aggregate, aggregate)
OPDEF_PROPERTY_HELPER(Commutative, commutative)

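// An involution is a function that is its own inverse: f(f(x)) == x.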
bool IsInvolution(const NodeDef& node) {
  static const gtl::FlatSet<string>* const kInvolutionOps =
      CHECK_NOTNULL((new gtl::FlatSet<string>{"Conj", "Reciprocal", "Invert",
                                              "Neg", "LogicalNot"}));
  return kInvolutionOps->count(node.op()) > 0;
}

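// The *Preserving predicates below form a hierarchy: preserving values, order
// and shape implies preserving values and order, which in turn implies
// preserving the values alone (e.g. Reshape keeps values and order but not
// shape; Transpose keeps only the multiset of values).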
bool IsValueAndOrderAndShapePreserving(const NodeDef& node) {
  if (NumNonControlInputs(node) == 1 && IsAggregate(node)) {
    return true;
  }
  static const gtl::FlatSet<string>* const kValueAndOrderAndShapePreservingOps =
      CHECK_NOTNULL((new const gtl::FlatSet<string>{
          "CheckNumerics",
          "DebugGradientIdentity",
          "DeepCopy",
          "Enter",
          "Exit",
          "PreventGradient",
          "Print",
          "Snapshot",
          "StopGradient",
      }));
  return kValueAndOrderAndShapePreservingOps->count(node.op()) > 0 ||
         IsIdentity(node);
}

bool IsValueAndOrderPreserving(const NodeDef& node) {
  if (NumNonControlInputs(node) == 1 && IsAggregate(node)) {
    return true;
  }
  static const gtl::FlatSet<string>* const kValueAndOrderPreservingOps =
      CHECK_NOTNULL((new const gtl::FlatSet<string>{
          "ExpandDims",
          "Reshape",
          "Squeeze",
      }));
  return kValueAndOrderPreservingOps->count(node.op()) > 0 ||
         IsValueAndOrderAndShapePreserving(node);
}

bool IsValuePreserving(const NodeDef& node) {
  static const gtl::FlatSet<string>* const kValuePreservingOps =
      CHECK_NOTNULL((new gtl::FlatSet<string>{
          "InvertPermutation",
          "Reverse",
          "ReverseV2",
          "Roll",
          "Transpose",
          "DepthToSpace",
          "SpaceToDepth",
          "BatchToSpace",
          "BatchToSpaceND",
          "SpaceToBatch",
          "SpaceToBatchND",
      }));
  return IsValueAndOrderPreserving(node) ||
         kValuePreservingOps->count(node.op()) > 0;
}

bool IsUnaryElementWise(const NodeDef& node) {
  static const gtl::FlatSet<string>* const kElementWiseOps =
      CHECK_NOTNULL((new gtl::FlatSet<string>{
          "Abs", "Acos", "Acosh", "Asin", "Asinh",
          "Atan", "Atanh", "Ceil", "ComplexAbs", "Conj",
          "Cos", "Cosh", "Digamma", "Elu", "Erf",
          "Erfc", "Exp", "Expm1", "Floor", "Inv",
          "Invert", "IsInf", "IsNan", "IsFinite", "Lgamma",
          "Log", "Log1p", "LogicalNot", "Neg", "Reciprocal",
          "Relu", "Relu6", "Rint", "Round", "Selu",
          "Rsqrt", "Sigmoid", "Sign", "Sin", "Sinh",
          "Softplus", "Softsign", "Sqrt", "Square", "Tan",
          "Tanh",
      }));
  return kElementWiseOps->count(node.op()) > 0 ||
         IsValueAndOrderAndShapePreserving(node);
}

bool HasOpDef(const NodeDef& node) {
  const OpDef* op_def = nullptr;
  return OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok();
}

bool IsIdempotent(const NodeDef& node) {
  return IsValueAndOrderAndShapePreserving(node) && IsFreeOfSideEffect(node) &&
         !ModifiesFrameInfo(node);
}

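// Ops known to always allocate fresh output buffers rather than forwarding an
// input buffer to the output (segment ops and most Quantize* ops are matched
// by name below), so the output never aliases an input.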
bool NeverForwardsInputs(const NodeDef& node) {
  static const gtl::FlatSet<string>* const kNonForwardingOps = CHECK_NOTNULL(
      (new gtl::FlatSet<string>{"ArgMax",
                                "ArgMin",
                                "AudioSpectrogram",
                                "AvgPool",
                                "BatchMatMul",
                                "BatchMatMulV2",
                                "BatchNormWithGlobalNormalization",
                                "BatchToSpace",
                                "BatchToSpaceND",
                                "Bincount",
                                "BroadcastArgs",
                                "BroadcastGradientArgs",
                                "Bucketize",
                                "CTCBeamSearchDecoder",
                                "CTCGreedyDecoder",
                                "CTCLoss",
                                "CompareAndBitpack",
                                "ComplexAbs",
                                "Concat",
                                "ConcatOffset",
                                "ConcatV2",
                                "Conv2D",
                                "Copy",
                                "CopyHost",
                                "Cross",
                                "CudnnRNN",
                                "CudnnRNNBackprop",
                                "CudnnRNNBackpropV2",
                                "CudnnRNNBackpropV3",
                                "CudnnRNNCanonicalToParams",
                                "CudnnRNNCanonicalToParamsV2",
                                "CudnnRNNParamsSize",
                                "CudnnRNNParamsToCanonical",
                                "CudnnRNNParamsToCanonicalV2",
                                "CudnnRNNV2",
                                "CudnnRNNV3",
921 "CumProd",
922 "CumSum",
923 "DebugNanCount",
924 "DebugNumericSummary",
925 "DecodeProtoV2",
926 "DecodeWav",
927 "DeepCopy",
928 "DepthToSpace",
929 "Dequantize",
930 "Diag",
931 "DiagPart",
932 "EditDistance",
933 "Empty",
934 "EncodeProtoV2",
935 "EncodeWav",
936 "ExtractImagePatches",
937 "ExtractVolumePatches",
938 "Fill",
939 "Gather",
940 "GatherNd",
941 "GatherV2",
942 "HistogramFixedWidth",
943 "InvertPermutation",
944 "IsInf",
945 "IsNan",
946 "Isfinite",
947 "LinSpace",
948 "LowerBound",
949 "MatMul",
950 "MatrixDiag",
951 "MatrixDiagPart",
952 "MatrixDiagPartV2",
953 "MatrixDiagV2",
954 "Mfcc",
955 "Multinomial",
956 "OneHot",
957 "Pack",
958 "ParameterizedTruncatedNormal",
959 "PopulationCount",
960 "RandomGamma",
961 "RandomPoisson",
962 "RandomPoissonV2",
963 "RandomStandardNormal",
964 "RandomUniform",
965 "RandomUniformInt",
966 "Range",
967 "Rank",
968 "RequantizationRange",
969 "Requantize",
970 "ReverseSequence",
971 "Shape",
972 "ShapeN",
973 "Size",
974 "SpaceToBatch",
975 "SpaceToBatchND",
976 "SpaceToDepth",
977 "SparseMatMul",
978 "Split",
979 "SplitV",
980 "TruncatedNormal",
981 "Unique",
982 "UniqueV2",
983 "UniqueWithCounts",
984 "UniqueWithCountsV2",
985 "Unpack",
986 "UnravelIndex",
987 "UpperBound",
988 "Where"}));
989 const string& op_name = node.op();
990 return kNonForwardingOps->count(op_name) > 0 ||
991 absl::StrContains(op_name, "Segment") ||
992 absl::StartsWith(op_name, "Quantize");
993}
994
995bool IsXlaLaunch(const NodeDef& node) { return node.op() == "XlaLaunch"; }
996
997} // namespace grappler
998} // end namespace tensorflow
999