1/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "mlir/IR/Operation.h" // from @llvm-project
17#include "mlir/IR/OperationSupport.h" // from @llvm-project
18#include "tensorflow/core/ir/dialect.h"
19#include "tensorflow/core/ir/tf_op_wrapper.h"
20
21namespace mlir {
22namespace tfg {
23
24bool TFGraphDialect::IsAdd(TFOp op) const {
25 StringAttr op_name = op->getName().getIdentifier();
26
27 if (op_name == add_v2_) return true;
28 if (op_name == add_)
29 return !op->getAttrOfType<TypeAttr>("T").getValue().isa<StringType>();
30 return false;
31}
32
33bool TFGraphDialect::IsAddN(TFOp op) const {
34 StringAttr op_name = op->getName().getIdentifier();
35 return op_name == add_n_;
36}
37
38bool TFGraphDialect::IsAll(TFOp op) const {
39 StringAttr op_name = op->getName().getIdentifier();
40 return op_name == all_;
41}
42
43bool TFGraphDialect::IsAngle(TFOp op) const {
44 StringAttr op_name = op->getName().getIdentifier();
45 return op_name == angle_;
46}
47
48bool TFGraphDialect::IsAny(TFOp op) const {
49 StringAttr op_name = op->getName().getIdentifier();
50 return op_name == any_;
51}
52
53bool TFGraphDialect::IsAnyDiv(TFOp op) const {
54 StringAttr op_name = op->getName().getIdentifier();
55 return op_name == real_div_ || op_name == div_ || IsXdivy(op) ||
56 op_name == floor_div_ || op_name == truncate_div_;
57}
58
59bool TFGraphDialect::IsAnyBatchMatMul(TFOp op) const {
60 StringAttr op_name = op->getName().getIdentifier();
61 return op_name == batch_matmul_ || op_name == batch_matmul_v2_;
62}
63
64bool TFGraphDialect::IsAnyMatMul(TFOp op) const {
65 StringAttr op_name = op->getName().getIdentifier();
66 return op_name == matmul_ || op_name == sparse_matmul_ ||
67 IsAnyBatchMatMul(op) || IsQuantizedMatMul(op);
68}
69
70bool TFGraphDialect::IsAnyMax(TFOp op) const {
71 StringAttr op_name = op->getName().getIdentifier();
72 return op_name == max_ || op_name == segment_max_ ||
73 op_name == unsorted_segment_max_;
74}
75
76bool TFGraphDialect::IsAnyMaxPool(TFOp op) const {
77 StringAttr op_name = op->getName().getIdentifier();
78 return op_name == max_pool_ || op_name == max_pool_v2_ ||
79 op_name == max_pool_3d_ || op_name == max_pool_with_argmax_ ||
80 op_name == fractional_max_pool_;
81}
82
83bool TFGraphDialect::IsAnyMin(TFOp op) const {
84 StringAttr op_name = op->getName().getIdentifier();
85 return op_name == min_ || op_name == segment_min_ ||
86 op_name == unsorted_segment_min_;
87}
88
89bool TFGraphDialect::IsAnySparseSegmentReduction(TFOp op) const {
90 StringAttr op_name = op->getName().getIdentifier();
91 return op_name == sparse_segment_sum_ ||
92 op_name == sparse_segment_sum_with_num_segments_ ||
93 op_name == sparse_segment_mean_ ||
94 op_name == sparse_segment_mean_with_num_segments_ ||
95 op_name == sparse_segment_sqrtn_ ||
96 op_name == sparse_segment_sqrtn_with_num_segments_;
97}
98
99bool TFGraphDialect::IsApproximateEqual(TFOp op) const {
100 StringAttr op_name = op->getName().getIdentifier();
101 return op_name == approximate_equal_;
102}
103
104bool TFGraphDialect::IsArg(TFOp op) const {
105 StringAttr op_name = op->getName().getIdentifier();
106 return op_name == arg_ || op_name == device_arg_;
107}
108
109bool TFGraphDialect::IsArgMax(TFOp op) const {
110 StringAttr op_name = op->getName().getIdentifier();
111 return op_name == arg_max_;
112}
113
114bool TFGraphDialect::IsArgMin(TFOp op) const {
115 StringAttr op_name = op->getName().getIdentifier();
116 return op_name == arg_min_;
117}
118
119bool TFGraphDialect::IsAvgPoolGrad(TFOp op) const {
120 StringAttr op_name = op->getName().getIdentifier();
121 return op_name == arg_pool_grad_;
122}
123
124bool TFGraphDialect::IsAssign(TFOp op) const {
125 StringAttr op_name = op->getName().getIdentifier();
126 return op_name == assign_ || op_name == assign_variable_op_;
127}
128
129bool TFGraphDialect::IsAssert(TFOp op) const {
130 StringAttr op_name = op->getName().getIdentifier();
131 return op_name == assert_;
132}
133
134bool TFGraphDialect::IsAsString(TFOp op) const {
135 StringAttr op_name = op->getName().getIdentifier();
136 return op_name == as_string_;
137}
138
139bool TFGraphDialect::IsAtan2(TFOp op) const {
140 StringAttr op_name = op->getName().getIdentifier();
141 return op_name == atan2_;
142}
143
144bool TFGraphDialect::IsBetainc(TFOp op) const {
145 StringAttr op_name = op->getName().getIdentifier();
146 return op_name == betainc_;
147}
148
149bool TFGraphDialect::IsBiasAdd(TFOp op) const {
150 StringAttr op_name = op->getName().getIdentifier();
151 return op_name == bias_add_ || op_name == bias_add_v1_;
152}
153
154bool TFGraphDialect::IsBiasAddV2(TFOp op) const {
155 StringAttr op_name = op->getName().getIdentifier();
156 return op_name == bias_add_;
157}
158
159bool TFGraphDialect::IsBiasAddGrad(TFOp op) const {
160 StringAttr op_name = op->getName().getIdentifier();
161 return op_name == bias_add_grad_;
162}
163
164bool TFGraphDialect::IsBitcast(TFOp op) const {
165 StringAttr op_name = op->getName().getIdentifier();
166 return op_name == bitcast_;
167}
168
169bool TFGraphDialect::IsBroadcastTo(TFOp op) const {
170 StringAttr op_name = op->getName().getIdentifier();
171 return op_name == broadcast_to_;
172}
173
174bool TFGraphDialect::IsCast(TFOp op) const {
175 StringAttr op_name = op->getName().getIdentifier();
176 return op_name == cast_;
177}
178
179bool TFGraphDialect::IsCheckNumerics(TFOp op) const {
180 StringAttr op_name = op->getName().getIdentifier();
181 return op_name == check_numerics_;
182}
183
184bool TFGraphDialect::IsCollective(TFOp op) const {
185 StringAttr op_name = op->getName().getIdentifier();
186 return op_name == collective_reduce_ || op_name == collective_bcast_send_ ||
187 op_name == collective_bcast_recv_;
188}
189
190bool TFGraphDialect::IsComplex(TFOp op) const {
191 StringAttr op_name = op->getName().getIdentifier();
192 return op_name == complex_;
193}
194
195bool TFGraphDialect::IsComplexAbs(TFOp op) const {
196 StringAttr op_name = op->getName().getIdentifier();
197 return op_name == complex_abs_;
198}
199
200bool TFGraphDialect::IsConcat(TFOp op) const {
201 StringAttr op_name = op->getName().getIdentifier();
202 return op_name == concat_ || IsConcatV2(op);
203}
204
205bool TFGraphDialect::IsConcatV2(TFOp op) const {
206 StringAttr op_name = op->getName().getIdentifier();
207 return op_name == concat_v2_;
208}
209
210bool TFGraphDialect::IsConcatOffset(TFOp op) const {
211 StringAttr op_name = op->getName().getIdentifier();
212 return op_name == concat_offset_;
213}
214
215bool TFGraphDialect::IsConstant(TFOp op) const {
216 StringAttr op_name = op->getName().getIdentifier();
217 return op_name == const_;
218}
219
220bool TFGraphDialect::IsConj(TFOp op) const {
221 StringAttr op_name = op->getName().getIdentifier();
222 return op_name == conj_;
223}
224
225bool TFGraphDialect::IsConjugateTranspose(TFOp op) const {
226 StringAttr op_name = op->getName().getIdentifier();
227 return op_name == conjugate_transpose_;
228}
229
230// TODO(chiahungduan): Should we use certain helpers like IsEnter().
231bool TFGraphDialect::IsControlFlow(TFOp op) const {
232 StringAttr op_name = op->getName().getIdentifier();
233
234 return op_name == control_trigger_ || op_name == enter_ || op_name == exit_ ||
235 op_name == loop_cond_ || op_name == merge_ || op_name == xla_merge_ ||
236 op_name == next_iteration_ || op_name == switch_ ||
237 op_name == switch_n_;
238}
239
240bool TFGraphDialect::IsConv2D(TFOp op) const {
241 StringAttr op_name = op->getName().getIdentifier();
242 return op_name == conv_2d_;
243}
244
245bool TFGraphDialect::IsConv2DBackpropFilter(TFOp op) const {
246 StringAttr op_name = op->getName().getIdentifier();
247 return op_name == conv_2d_back_prop_filter_;
248}
249
250bool TFGraphDialect::IsConv2DBackpropInput(TFOp op) const {
251 StringAttr op_name = op->getName().getIdentifier();
252 return op_name == conv_2d_back_prop_input_;
253}
254
255bool TFGraphDialect::IsConv3D(TFOp op) const {
256 StringAttr op_name = op->getName().getIdentifier();
257 return op_name == conv_3d_;
258}
259
260bool TFGraphDialect::IsConv3DBackpropFilterV2(TFOp op) const {
261 StringAttr op_name = op->getName().getIdentifier();
262 return op_name == conv_3d_back_prop_filter_v2_;
263}
264
265bool TFGraphDialect::IsConv3DBackpropInputV2(TFOp op) const {
266 StringAttr op_name = op->getName().getIdentifier();
267 return op_name == conv_3d_back_prop_input_v2_;
268}
269
270bool TFGraphDialect::IsDepthwiseConv2dNative(TFOp op) const {
271 StringAttr op_name = op->getName().getIdentifier();
272 return op_name == depth_wise_conv_2d_native_;
273}
274
275bool TFGraphDialect::IsDepthwiseConv2dNativeBackpropFilter(TFOp op) const {
276 StringAttr op_name = op->getName().getIdentifier();
277 return op_name == depth_wise_conv_2d_native_back_prop_filter_;
278}
279
280bool TFGraphDialect::IsDepthwiseConv2dNativeBackpropInput(TFOp op) const {
281 StringAttr op_name = op->getName().getIdentifier();
282 return op_name == depth_wise_conv_2d_native_back_prop_input_;
283}
284
285bool TFGraphDialect::IsDequeueOp(TFOp op) const {
286 StringAttr op_name = op->getName().getIdentifier();
287 return op_name == queue_dequeue_ || op_name == queue_dequeue_v2_ ||
288 op_name == queue_dequeue_many_ || op_name == queue_dequeue_many_v2_ ||
289 op_name == queue_dequeue_upto_ || op_name == queue_dequeue_upto_v2_;
290}
291
292bool TFGraphDialect::IsDiv(TFOp op) const {
293 StringAttr op_name = op->getName().getIdentifier();
294 return op_name == div_;
295}
296
297bool TFGraphDialect::IsDivNoNan(TFOp op) const {
298 StringAttr op_name = op->getName().getIdentifier();
299 return op_name == div_no_nan_;
300}
301
302bool TFGraphDialect::IsElu(TFOp op) const {
303 StringAttr op_name = op->getName().getIdentifier();
304 return op_name == elu_;
305}
306
307bool TFGraphDialect::IsEluGrad(TFOp op) const {
308 StringAttr op_name = op->getName().getIdentifier();
309 return op_name == elu_grad_;
310}
311
312bool TFGraphDialect::IsQuantizationEmulation(TFOp op) const {
313 StringAttr op_name = op->getName().getIdentifier();
314 return op_name == quantize_and_dequantize_ ||
315 op_name == quantize_and_dequantize_v2_ ||
316 op_name == quantize_and_dequantize_v3_ ||
317 op_name == quantize_and_dequantize_v4_ ||
318 op_name == quantize_and_dequantize_v4_grad_ ||
319 op_name == fake_quant_with_min_max_args_ ||
320 op_name == fake_quant_with_min_max_args_gradient_ ||
321 op_name == fake_quant_with_min_max_vars_ ||
322 op_name == fake_quant_with_min_max_vars_gradient_ ||
323 op_name == fake_quant_with_min_max_vars_per_channel_ ||
324 op_name == fake_quant_with_min_max_vars_per_channel_gradient_;
325}
326
327bool TFGraphDialect::IsEnter(TFOp op) const {
328 StringAttr op_name = op->getName().getIdentifier();
329 return op_name == enter_ || op_name == ref_enter_;
330}
331
332bool TFGraphDialect::IsEqual(TFOp op) const {
333 StringAttr op_name = op->getName().getIdentifier();
334 return op_name == equal_;
335}
336
337bool TFGraphDialect::IsExit(TFOp op) const {
338 StringAttr op_name = op->getName().getIdentifier();
339 return op_name == exit_ || op_name == ref_exit_;
340}
341
342bool TFGraphDialect::IsExp(TFOp op) const {
343 StringAttr op_name = op->getName().getIdentifier();
344 return op_name == exp_;
345}
346
347bool TFGraphDialect::IsFakeParam(TFOp op) const {
348 StringAttr op_name = op->getName().getIdentifier();
349 return op_name == fake_param_;
350}
351
352bool TFGraphDialect::IsFill(TFOp op) const {
353 StringAttr op_name = op->getName().getIdentifier();
354 return op_name == fill_;
355}
356
357bool TFGraphDialect::IsFloorDiv(TFOp op) const {
358 StringAttr op_name = op->getName().getIdentifier();
359 return op_name == floor_div_;
360}
361
362bool TFGraphDialect::IsFloorMod(TFOp op) const {
363 StringAttr op_name = op->getName().getIdentifier();
364 return op_name == floor_mod_;
365}
366
367bool TFGraphDialect::IsFusedBatchNorm(TFOp op) const {
368 StringAttr op_name = op->getName().getIdentifier();
369 return op_name == fused_batch_norm_ || op_name == fused_batch_norm_v2_ ||
370 op_name == fused_batch_norm_v3_;
371}
372
373bool TFGraphDialect::IsFusedBatchNormEx(TFOp op) const {
374 StringAttr op_name = op->getName().getIdentifier();
375 return op_name == fused_batch_norm_ex_;
376}
377
378bool TFGraphDialect::IsFusedBatchNormGrad(TFOp op) const {
379 StringAttr op_name = op->getName().getIdentifier();
380 return op_name == fused_batch_norm_grad_ ||
381 op_name == fused_batch_norm_grad_v2_ ||
382 op_name == fused_batch_norm_grad_v3_;
383}
384
385bool TFGraphDialect::IsGather(TFOp op) const {
386 StringAttr op_name = op->getName().getIdentifier();
387 return op_name == gather_ || op_name == gather_v2_ ||
388 op_name == resource_gather_;
389}
390
391bool TFGraphDialect::IsGreater(TFOp op) const {
392 StringAttr op_name = op->getName().getIdentifier();
393 return op_name == greater_;
394}
395
396bool TFGraphDialect::IsGreaterEqual(TFOp op) const {
397 StringAttr op_name = op->getName().getIdentifier();
398 return op_name == greater_equal_;
399}
400
401bool TFGraphDialect::IsHostConstant(TFOp op) const {
402 StringAttr op_name = op->getName().getIdentifier();
403 return op_name == host_const_;
404}
405
406bool TFGraphDialect::IsHistogramSummary(TFOp op) const {
407 StringAttr op_name = op->getName().getIdentifier();
408 return op_name == histogram_summary_;
409}
410
411bool TFGraphDialect::IsIdentity(TFOp op) const {
412 StringAttr op_name = op->getName().getIdentifier();
413 return op_name == identity_ || op_name == ref_identity_;
414}
415
416bool TFGraphDialect::IsIdentityN(TFOp op) const {
417 StringAttr op_name = op->getName().getIdentifier();
418 return op_name == identity_n_;
419}
420
421bool TFGraphDialect::IsIdentityNSingleInput(TFOp op) const {
422 if (!IsIdentityN(op)) return false;
423 auto array_attr = op->getAttrOfType<ArrayAttr>("T");
424 if (!array_attr) return false;
425 // TODO(chiahungduan): Do we need to check the content of array_attr?
426 return array_attr.size() == 1;
427}
428
429bool TFGraphDialect::IsIf(TFOp op) const {
430 StringAttr op_name = op->getName().getIdentifier();
431 return op_name == if_ || op_name == stateless_if_;
432}
433
434bool TFGraphDialect::IsIgamma(TFOp op) const {
435 StringAttr op_name = op->getName().getIdentifier();
436 return op_name == igamma_;
437}
438
439bool TFGraphDialect::IsIgammac(TFOp op) const {
440 StringAttr op_name = op->getName().getIdentifier();
441 return op_name == igammac_;
442}
443
444bool TFGraphDialect::IsImag(TFOp op) const {
445 StringAttr op_name = op->getName().getIdentifier();
446 return op_name == imag_;
447}
448
449bool TFGraphDialect::IsImmutableConst(TFOp op) const {
450 StringAttr op_name = op->getName().getIdentifier();
451 return op_name == immutable_const_;
452}
453
454bool TFGraphDialect::IsInvGrad(TFOp op) const {
455 StringAttr op_name = op->getName().getIdentifier();
456 return op_name == inv_grad_;
457}
458
459bool TFGraphDialect::IsLeakyRelu(TFOp op) const {
460 StringAttr op_name = op->getName().getIdentifier();
461 return op_name == leaky_relu_;
462}
463
464bool TFGraphDialect::IsLeakyReluGrad(TFOp op) const {
465 StringAttr op_name = op->getName().getIdentifier();
466 return op_name == leaky_relu_grad_;
467}
468
469bool TFGraphDialect::IsLess(TFOp op) const {
470 StringAttr op_name = op->getName().getIdentifier();
471 return op_name == less_;
472}
473
474bool TFGraphDialect::IsLessEqual(TFOp op) const {
475 StringAttr op_name = op->getName().getIdentifier();
476 return op_name == less_equal_;
477}
478
479bool TFGraphDialect::IsLog(TFOp op) const {
480 StringAttr op_name = op->getName().getIdentifier();
481 return op_name == log_;
482}
483
484bool TFGraphDialect::IsLogicalAnd(TFOp op) const {
485 StringAttr op_name = op->getName().getIdentifier();
486 return op_name == logical_and_;
487}
488
489bool TFGraphDialect::IsLogicalNot(TFOp op) const {
490 StringAttr op_name = op->getName().getIdentifier();
491 return op_name == logical_not_;
492}
493
494bool TFGraphDialect::IsLogicalOr(TFOp op) const {
495 StringAttr op_name = op->getName().getIdentifier();
496 return op_name == logical_or_;
497}
498
499bool TFGraphDialect::IsLoopCond(TFOp op) const {
500 StringAttr op_name = op->getName().getIdentifier();
501 return op_name == loop_cond_;
502}
503
504bool TFGraphDialect::IsMatMul(TFOp op) const {
505 StringAttr op_name = op->getName().getIdentifier();
506 return op_name == matmul_;
507}
508
509bool TFGraphDialect::IsMax(TFOp op) const {
510 StringAttr op_name = op->getName().getIdentifier();
511 return op_name == max_;
512}
513
514bool TFGraphDialect::IsMaximum(TFOp op) const {
515 StringAttr op_name = op->getName().getIdentifier();
516 return op_name == maximum_;
517}
518
519bool TFGraphDialect::IsMaxPoolGrad(TFOp op) const {
520 StringAttr op_name = op->getName().getIdentifier();
521 return op_name == max_pool_grad_;
522}
523
524bool TFGraphDialect::IsMean(TFOp op) const {
525 StringAttr op_name = op->getName().getIdentifier();
526 return op_name == mean_;
527}
528
529bool TFGraphDialect::IsMerge(TFOp op) const {
530 StringAttr op_name = op->getName().getIdentifier();
531 return op_name == merge_ || op_name == ref_merge_ || op_name == xla_merge_;
532}
533
534bool TFGraphDialect::IsMin(TFOp op) const {
535 StringAttr op_name = op->getName().getIdentifier();
536 return op_name == min_;
537}
538
539bool TFGraphDialect::IsMinimum(TFOp op) const {
540 StringAttr op_name = op->getName().getIdentifier();
541 return op_name == minimum_;
542}
543
544bool TFGraphDialect::IsMirrorPad(TFOp op) const {
545 StringAttr op_name = op->getName().getIdentifier();
546 return op_name == mirror_pad_;
547}
548
549bool TFGraphDialect::IsMirrorPadGrad(TFOp op) const {
550 StringAttr op_name = op->getName().getIdentifier();
551 return op_name == mirror_pad_grad_;
552}
553
554bool TFGraphDialect::IsMod(TFOp op) const {
555 StringAttr op_name = op->getName().getIdentifier();
556 return op_name == mod_;
557}
558
559bool TFGraphDialect::IsMul(TFOp op) const {
560 StringAttr op_name = op->getName().getIdentifier();
561 return op_name == mul_;
562}
563bool TFGraphDialect::IsMulNoNan(TFOp op) const {
564 StringAttr op_name = op->getName().getIdentifier();
565 return op_name == mul_no_nan_;
566}
567bool TFGraphDialect::IsAnyMul(TFOp op) const {
568 return IsMul(op) || IsMulNoNan(op);
569}
570
571bool TFGraphDialect::IsNeg(TFOp op) const {
572 StringAttr op_name = op->getName().getIdentifier();
573 return op_name == neg_;
574}
575
576bool TFGraphDialect::IsNoOp(TFOp op) const {
577 StringAttr op_name = op->getName().getIdentifier();
578 return op_name == no_op_;
579}
580
581bool TFGraphDialect::IsNotEqual(TFOp op) const {
582 StringAttr op_name = op->getName().getIdentifier();
583 return op_name == not_equal_;
584}
585
586bool TFGraphDialect::IsNextIteration(TFOp op) const {
587 StringAttr op_name = op->getName().getIdentifier();
588 return op_name == next_iteration_ || op_name == ref_next_iteration_;
589}
590
591bool TFGraphDialect::IsOnesLike(TFOp op) const {
592 StringAttr op_name = op->getName().getIdentifier();
593 return op_name == ones_like_;
594}
595
596bool TFGraphDialect::IsPack(TFOp op) const {
597 StringAttr op_name = op->getName().getIdentifier();
598 return op_name == pack_;
599}
600
601bool TFGraphDialect::IsPad(TFOp op) const {
602 StringAttr op_name = op->getName().getIdentifier();
603 return op_name == pad_ || op_name == pad_v2_;
604}
605
606bool TFGraphDialect::IsPartitionedCall(TFOp op) const {
607 StringAttr op_name = op->getName().getIdentifier();
608 return op_name == partitioned_call_;
609}
610
611bool TFGraphDialect::IsPlaceholder(TFOp op) const {
612 StringAttr op_name = op->getName().getIdentifier();
613 return op_name == placeholder_ || op_name == placeholder_v2_ ||
614 op_name == placeholder_with_default_;
615}
616
617bool TFGraphDialect::IsPolygamma(TFOp op) const {
618 StringAttr op_name = op->getName().getIdentifier();
619 return op_name == poly_gamma_;
620}
621
622bool TFGraphDialect::IsPow(TFOp op) const {
623 StringAttr op_name = op->getName().getIdentifier();
624 return op_name == pow_;
625}
626
627bool TFGraphDialect::IsPrint(TFOp op) const {
628 StringAttr op_name = op->getName().getIdentifier();
629 return op_name == print_ || op_name == print_v2_;
630}
631
632bool TFGraphDialect::IsProd(TFOp op) const {
633 StringAttr op_name = op->getName().getIdentifier();
634 return op_name == prod_;
635}
636
637bool TFGraphDialect::IsQuantizedMatMul(TFOp op) const {
638 StringAttr op_name = op->getName().getIdentifier();
639 return op_name == quantized_matmul_ || op_name == quantized_matmul_v2_;
640}
641
642bool TFGraphDialect::IsQueue(TFOp op) const {
643 StringAttr op_name = op->getName().getIdentifier();
644 return op_name == random_shuffle_queue_v2_ || op_name == fifo_queue_v2_ ||
645 op_name == padding_fifo_queue_v2_ || op_name == priority_queue_v2_;
646}
647
648bool TFGraphDialect::IsRandomShuffle(TFOp op) const {
649 StringAttr op_name = op->getName().getIdentifier();
650 return op_name == random_shuffle_;
651}
652
653bool TFGraphDialect::IsRank(TFOp op) const {
654 StringAttr op_name = op->getName().getIdentifier();
655 return op_name == rank_;
656}
657
658bool TFGraphDialect::IsReadVariableOp(TFOp op) const {
659 StringAttr op_name = op->getName().getIdentifier();
660 return op_name == read_variable_op_;
661}
662
663bool TFGraphDialect::IsReadVariablesOp(TFOp op) const {
664 StringAttr op_name = op->getName().getIdentifier();
665 return op_name == read_variables_op_;
666}
667
668bool TFGraphDialect::IsReal(TFOp op) const {
669 StringAttr op_name = op->getName().getIdentifier();
670 return op_name == real_;
671}
672
673bool TFGraphDialect::IsRealDiv(TFOp op) const {
674 StringAttr op_name = op->getName().getIdentifier();
675 return op_name == real_div_;
676}
677
678bool TFGraphDialect::IsReciprocalGrad(TFOp op) const {
679 StringAttr op_name = op->getName().getIdentifier();
680 return op_name == reciprocal_grad_;
681}
682
683bool TFGraphDialect::IsRecv(TFOp op) const {
684 StringAttr op_name = op->getName().getIdentifier();
685 return op_name == recv_ || op_name == host_recv_;
686}
687
688bool TFGraphDialect::IsReduction(TFOp op) const {
689 return IsSum(op) || IsProd(op) || IsMin(op) || IsMax(op) || IsMean(op) ||
690 IsAny(op) || IsAll(op);
691}
692
693bool TFGraphDialect::IsRelu(TFOp op) const {
694 StringAttr op_name = op->getName().getIdentifier();
695 return op_name == relu_;
696}
697
698bool TFGraphDialect::IsRelu6(TFOp op) const {
699 StringAttr op_name = op->getName().getIdentifier();
700 return op_name == relu6_;
701}
702
703bool TFGraphDialect::IsReluGrad(TFOp op) const {
704 StringAttr op_name = op->getName().getIdentifier();
705 return op_name == relu_grad_;
706}
707
708bool TFGraphDialect::IsRelu6Grad(TFOp op) const {
709 StringAttr op_name = op->getName().getIdentifier();
710 return op_name == relu6_grad_;
711}
712
713bool TFGraphDialect::IsReshape(TFOp op) const {
714 StringAttr op_name = op->getName().getIdentifier();
715 return op_name == reshape_;
716}
717
718bool TFGraphDialect::IsRestore(TFOp op) const {
719 StringAttr op_name = op->getName().getIdentifier();
720 return op_name == restore_ || op_name == restore_v2_ ||
721 op_name == restore_slice_;
722}
723
724bool TFGraphDialect::IsReturn(TFOp op) const {
725 StringAttr op_name = op->getName().getIdentifier();
726 return op_name == return_;
727}
728
729bool TFGraphDialect::IsRetval(TFOp op) const {
730 StringAttr op_name = op->getName().getIdentifier();
731 return op_name == retval_ || op_name == device_retval_;
732}
733
734bool TFGraphDialect::IsReverse(TFOp op) const {
735 StringAttr op_name = op->getName().getIdentifier();
736 return op_name == reverse_ || IsReverseV2(op);
737}
738
739bool TFGraphDialect::IsReverseV2(TFOp op) const {
740 StringAttr op_name = op->getName().getIdentifier();
741 return op_name == reverse_v2_;
742}
743
744bool TFGraphDialect::IsRsqrt(TFOp op) const {
745 StringAttr op_name = op->getName().getIdentifier();
746 return op_name == rsqrt_;
747}
748
749bool TFGraphDialect::IsRsqrtGrad(TFOp op) const {
750 StringAttr op_name = op->getName().getIdentifier();
751 return op_name == rsqrt_grad_;
752}
753
754bool TFGraphDialect::IsSelect(TFOp op) const {
755 StringAttr op_name = op->getName().getIdentifier();
756 return op_name == select_ || op_name == select_v2_;
757}
758
759bool TFGraphDialect::IsSeluGrad(TFOp op) const {
760 StringAttr op_name = op->getName().getIdentifier();
761 return op_name == selu_grad_;
762}
763
764bool TFGraphDialect::IsSend(TFOp op) const {
765 StringAttr op_name = op->getName().getIdentifier();
766 return op_name == send_ || op_name == host_send_;
767}
768
769bool TFGraphDialect::IsShape(TFOp op) const {
770 StringAttr op_name = op->getName().getIdentifier();
771 return op_name == shape_;
772}
773
774bool TFGraphDialect::IsShapeN(TFOp op) const {
775 StringAttr op_name = op->getName().getIdentifier();
776 return op_name == shape_n_;
777}
778
779bool TFGraphDialect::IsShuffle(TFOp op) const {
780 StringAttr op_name = op->getName().getIdentifier();
781 return op_name == shuffle_;
782}
783
784bool TFGraphDialect::IsSigmoid(TFOp op) const {
785 StringAttr op_name = op->getName().getIdentifier();
786 return op_name == sigmoid_;
787}
788
789bool TFGraphDialect::IsSigmoidGrad(TFOp op) const {
790 StringAttr op_name = op->getName().getIdentifier();
791 return op_name == sigmoid_grad_;
792}
793
794bool TFGraphDialect::IsSize(TFOp op) const {
795 StringAttr op_name = op->getName().getIdentifier();
796 return op_name == size_;
797}
798
799bool TFGraphDialect::IsSlice(TFOp op) const {
800 StringAttr op_name = op->getName().getIdentifier();
801 return op_name == slice_;
802}
803
804bool TFGraphDialect::IsSnapshot(TFOp op) const {
805 StringAttr op_name = op->getName().getIdentifier();
806 return op_name == snapshot_;
807}
808
809bool TFGraphDialect::IsSoftmax(TFOp op) const {
810 StringAttr op_name = op->getName().getIdentifier();
811 return op_name == softmax_;
812}
813
814bool TFGraphDialect::IsSoftplusGrad(TFOp op) const {
815 StringAttr op_name = op->getName().getIdentifier();
816 return op_name == softplus_grad_;
817}
818
819bool TFGraphDialect::IsSoftsignGrad(TFOp op) const {
820 StringAttr op_name = op->getName().getIdentifier();
821 return op_name == softsign_grad_;
822}
823
824bool TFGraphDialect::IsSplit(TFOp op) const {
825 StringAttr op_name = op->getName().getIdentifier();
826 return op_name == split_;
827}
828
829bool TFGraphDialect::IsSplitV(TFOp op) const {
830 StringAttr op_name = op->getName().getIdentifier();
831 return op_name == split_v_;
832}
833
834bool TFGraphDialect::IsSqrt(TFOp op) const {
835 StringAttr op_name = op->getName().getIdentifier();
836 return op_name == sqrt_;
837}
838
839bool TFGraphDialect::IsSqrtGrad(TFOp op) const {
840 StringAttr op_name = op->getName().getIdentifier();
841 return op_name == sqrt_grad_;
842}
843
844bool TFGraphDialect::IsSquare(TFOp op) const {
845 StringAttr op_name = op->getName().getIdentifier();
846 return op_name == square_;
847}
848
849bool TFGraphDialect::IsSquaredDifference(TFOp op) const {
850 StringAttr op_name = op->getName().getIdentifier();
851 return op_name == squared_difference_;
852}
853
854bool TFGraphDialect::IsSqueeze(TFOp op) const {
855 StringAttr op_name = op->getName().getIdentifier();
856 return op_name == squeeze_;
857}
858
859bool TFGraphDialect::IsStackOp(TFOp op) const {
860 StringAttr op_name = op->getName().getIdentifier();
861 return op_name == stack_ || op_name == stack_v2_;
862}
863
864bool TFGraphDialect::IsStackCloseOp(TFOp op) const {
865 StringAttr op_name = op->getName().getIdentifier();
866 return op_name == stack_close_ || op_name == stack_close_v2_;
867}
868
869bool TFGraphDialect::IsStackPushOp(TFOp op) const {
870 StringAttr op_name = op->getName().getIdentifier();
871 return op_name == stack_push_ || op_name == stack_push_v2_;
872}
873
874bool TFGraphDialect::IsStackPopOp(TFOp op) const {
875 StringAttr op_name = op->getName().getIdentifier();
876 return op_name == stack_pop_ || op_name == stack_pop_v2_;
877}
878
879bool TFGraphDialect::IsStatefulPartitionedCall(TFOp op) const {
880 StringAttr op_name = op->getName().getIdentifier();
881 return op_name == stateful_partitioned_call_;
882}
883
884bool TFGraphDialect::IsStopGradient(TFOp op) const {
885 StringAttr op_name = op->getName().getIdentifier();
886 return op_name == stop_gradient_ || op_name == prevent_gradient_;
887}
888
889bool TFGraphDialect::IsStridedSlice(TFOp op) const {
890 StringAttr op_name = op->getName().getIdentifier();
891 return op_name == strided_slice_;
892}
893
894bool TFGraphDialect::IsStridedSliceGrad(TFOp op) const {
895 StringAttr op_name = op->getName().getIdentifier();
896 return op_name == strided_slice_grad_;
897}
898
899bool TFGraphDialect::IsStringToHashBucketFast(TFOp op) const {
900 StringAttr op_name = op->getName().getIdentifier();
901 return op_name == string_to_hashbucket_fast_;
902}
903
904bool TFGraphDialect::IsSub(TFOp op) const {
905 StringAttr op_name = op->getName().getIdentifier();
906 return op_name == sub_;
907}
908
909bool TFGraphDialect::IsSum(TFOp op) const {
910 StringAttr op_name = op->getName().getIdentifier();
911 return op_name == sum_;
912}
913
914bool TFGraphDialect::IsSwitch(TFOp op) const {
915 StringAttr op_name = op->getName().getIdentifier();
916 return op_name == switch_ || op_name == switch_n_ || op_name == ref_switch_;
917}
918
919bool TFGraphDialect::IsSymbolicGradient(TFOp op) const {
920 StringAttr op_name = op->getName().getIdentifier();
921 return op_name == symbolic_gradient_;
922}
923
924bool TFGraphDialect::IsTanh(TFOp op) const {
925 StringAttr op_name = op->getName().getIdentifier();
926 return op_name == tanh_;
927}
928
929bool TFGraphDialect::IsTanhGrad(TFOp op) const {
930 StringAttr op_name = op->getName().getIdentifier();
931 return op_name == tanh_grad_;
932}
933
934bool TFGraphDialect::IsTile(TFOp op) const {
935 StringAttr op_name = op->getName().getIdentifier();
936 return op_name == tile_;
937}
938
939bool TFGraphDialect::IsTranspose(TFOp op) const {
940 StringAttr op_name = op->getName().getIdentifier();
941 return op_name == transpose_;
942}
943
944bool TFGraphDialect::IsTruncateDiv(TFOp op) const {
945 StringAttr op_name = op->getName().getIdentifier();
946 return op_name == truncate_div_;
947}
948
949bool TFGraphDialect::IsTruncateMod(TFOp op) const {
950 StringAttr op_name = op->getName().getIdentifier();
951 return op_name == truncate_mod_;
952}
953
954bool TFGraphDialect::IsUnique(TFOp op) const {
955 StringAttr op_name = op->getName().getIdentifier();
956 return op_name == unique_ || op_name == unique_v2_;
957}
958
959bool TFGraphDialect::IsUnpack(TFOp op) const {
960 StringAttr op_name = op->getName().getIdentifier();
961 return op_name == unpack_;
962}
963
964bool TFGraphDialect::IsVariable(TFOp op) const {
965 StringAttr op_name = op->getName().getIdentifier();
966 return op_name == variable_ || op_name == variable_v2_ ||
967 op_name == auto_reload_variable_ || op_name == var_handle_op_ ||
968 op_name == var_handles_op_ || IsReadVariableOp(op) ||
969 IsReadVariablesOp(op);
970}
971
972bool TFGraphDialect::IsWhile(TFOp op) const {
973 StringAttr op_name = op->getName().getIdentifier();
974 return op_name == while_ || op_name == stateless_while_;
975}
976
977bool TFGraphDialect::IsXdivy(TFOp op) const {
978 StringAttr op_name = op->getName().getIdentifier();
979 return op_name == xdivy_;
980}
981
982bool TFGraphDialect::IsZerosLike(TFOp op) const {
983 StringAttr op_name = op->getName().getIdentifier();
984 return op_name == zeros_like_;
985}
986
987bool TFGraphDialect::IsZeta(TFOp op) const {
988 StringAttr op_name = op->getName().getIdentifier();
989 return op_name == zeta_;
990}
991
992} // namespace tfg
993} // namespace mlir
994