1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
---|---|
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "mlir/IR/Operation.h" // from @llvm-project |
17 | #include "mlir/IR/OperationSupport.h" // from @llvm-project |
18 | #include "tensorflow/core/ir/dialect.h" |
19 | #include "tensorflow/core/ir/tf_op_wrapper.h" |
20 | |
21 | namespace mlir { |
22 | namespace tfg { |
23 | |
24 | bool TFGraphDialect::IsAdd(TFOp op) const { |
25 | StringAttr op_name = op->getName().getIdentifier(); |
26 | |
27 | if (op_name == add_v2_) return true; |
28 | if (op_name == add_) |
29 | return !op->getAttrOfType<TypeAttr>("T").getValue().isa<StringType>(); |
30 | return false; |
31 | } |
32 | |
33 | bool TFGraphDialect::IsAddN(TFOp op) const { |
34 | StringAttr op_name = op->getName().getIdentifier(); |
35 | return op_name == add_n_; |
36 | } |
37 | |
38 | bool TFGraphDialect::IsAll(TFOp op) const { |
39 | StringAttr op_name = op->getName().getIdentifier(); |
40 | return op_name == all_; |
41 | } |
42 | |
43 | bool TFGraphDialect::IsAngle(TFOp op) const { |
44 | StringAttr op_name = op->getName().getIdentifier(); |
45 | return op_name == angle_; |
46 | } |
47 | |
48 | bool TFGraphDialect::IsAny(TFOp op) const { |
49 | StringAttr op_name = op->getName().getIdentifier(); |
50 | return op_name == any_; |
51 | } |
52 | |
53 | bool TFGraphDialect::IsAnyDiv(TFOp op) const { |
54 | StringAttr op_name = op->getName().getIdentifier(); |
55 | return op_name == real_div_ || op_name == div_ || IsXdivy(op) || |
56 | op_name == floor_div_ || op_name == truncate_div_; |
57 | } |
58 | |
59 | bool TFGraphDialect::IsAnyBatchMatMul(TFOp op) const { |
60 | StringAttr op_name = op->getName().getIdentifier(); |
61 | return op_name == batch_matmul_ || op_name == batch_matmul_v2_; |
62 | } |
63 | |
64 | bool TFGraphDialect::IsAnyMatMul(TFOp op) const { |
65 | StringAttr op_name = op->getName().getIdentifier(); |
66 | return op_name == matmul_ || op_name == sparse_matmul_ || |
67 | IsAnyBatchMatMul(op) || IsQuantizedMatMul(op); |
68 | } |
69 | |
70 | bool TFGraphDialect::IsAnyMax(TFOp op) const { |
71 | StringAttr op_name = op->getName().getIdentifier(); |
72 | return op_name == max_ || op_name == segment_max_ || |
73 | op_name == unsorted_segment_max_; |
74 | } |
75 | |
76 | bool TFGraphDialect::IsAnyMaxPool(TFOp op) const { |
77 | StringAttr op_name = op->getName().getIdentifier(); |
78 | return op_name == max_pool_ || op_name == max_pool_v2_ || |
79 | op_name == max_pool_3d_ || op_name == max_pool_with_argmax_ || |
80 | op_name == fractional_max_pool_; |
81 | } |
82 | |
83 | bool TFGraphDialect::IsAnyMin(TFOp op) const { |
84 | StringAttr op_name = op->getName().getIdentifier(); |
85 | return op_name == min_ || op_name == segment_min_ || |
86 | op_name == unsorted_segment_min_; |
87 | } |
88 | |
89 | bool TFGraphDialect::IsAnySparseSegmentReduction(TFOp op) const { |
90 | StringAttr op_name = op->getName().getIdentifier(); |
91 | return op_name == sparse_segment_sum_ || |
92 | op_name == sparse_segment_sum_with_num_segments_ || |
93 | op_name == sparse_segment_mean_ || |
94 | op_name == sparse_segment_mean_with_num_segments_ || |
95 | op_name == sparse_segment_sqrtn_ || |
96 | op_name == sparse_segment_sqrtn_with_num_segments_; |
97 | } |
98 | |
99 | bool TFGraphDialect::IsApproximateEqual(TFOp op) const { |
100 | StringAttr op_name = op->getName().getIdentifier(); |
101 | return op_name == approximate_equal_; |
102 | } |
103 | |
104 | bool TFGraphDialect::IsArg(TFOp op) const { |
105 | StringAttr op_name = op->getName().getIdentifier(); |
106 | return op_name == arg_ || op_name == device_arg_; |
107 | } |
108 | |
109 | bool TFGraphDialect::IsArgMax(TFOp op) const { |
110 | StringAttr op_name = op->getName().getIdentifier(); |
111 | return op_name == arg_max_; |
112 | } |
113 | |
114 | bool TFGraphDialect::IsArgMin(TFOp op) const { |
115 | StringAttr op_name = op->getName().getIdentifier(); |
116 | return op_name == arg_min_; |
117 | } |
118 | |
119 | bool TFGraphDialect::IsAvgPoolGrad(TFOp op) const { |
120 | StringAttr op_name = op->getName().getIdentifier(); |
121 | return op_name == arg_pool_grad_; |
122 | } |
123 | |
124 | bool TFGraphDialect::IsAssign(TFOp op) const { |
125 | StringAttr op_name = op->getName().getIdentifier(); |
126 | return op_name == assign_ || op_name == assign_variable_op_; |
127 | } |
128 | |
129 | bool TFGraphDialect::IsAssert(TFOp op) const { |
130 | StringAttr op_name = op->getName().getIdentifier(); |
131 | return op_name == assert_; |
132 | } |
133 | |
134 | bool TFGraphDialect::IsAsString(TFOp op) const { |
135 | StringAttr op_name = op->getName().getIdentifier(); |
136 | return op_name == as_string_; |
137 | } |
138 | |
139 | bool TFGraphDialect::IsAtan2(TFOp op) const { |
140 | StringAttr op_name = op->getName().getIdentifier(); |
141 | return op_name == atan2_; |
142 | } |
143 | |
144 | bool TFGraphDialect::IsBetainc(TFOp op) const { |
145 | StringAttr op_name = op->getName().getIdentifier(); |
146 | return op_name == betainc_; |
147 | } |
148 | |
149 | bool TFGraphDialect::IsBiasAdd(TFOp op) const { |
150 | StringAttr op_name = op->getName().getIdentifier(); |
151 | return op_name == bias_add_ || op_name == bias_add_v1_; |
152 | } |
153 | |
154 | bool TFGraphDialect::IsBiasAddV2(TFOp op) const { |
155 | StringAttr op_name = op->getName().getIdentifier(); |
156 | return op_name == bias_add_; |
157 | } |
158 | |
159 | bool TFGraphDialect::IsBiasAddGrad(TFOp op) const { |
160 | StringAttr op_name = op->getName().getIdentifier(); |
161 | return op_name == bias_add_grad_; |
162 | } |
163 | |
164 | bool TFGraphDialect::IsBitcast(TFOp op) const { |
165 | StringAttr op_name = op->getName().getIdentifier(); |
166 | return op_name == bitcast_; |
167 | } |
168 | |
169 | bool TFGraphDialect::IsBroadcastTo(TFOp op) const { |
170 | StringAttr op_name = op->getName().getIdentifier(); |
171 | return op_name == broadcast_to_; |
172 | } |
173 | |
174 | bool TFGraphDialect::IsCast(TFOp op) const { |
175 | StringAttr op_name = op->getName().getIdentifier(); |
176 | return op_name == cast_; |
177 | } |
178 | |
179 | bool TFGraphDialect::IsCheckNumerics(TFOp op) const { |
180 | StringAttr op_name = op->getName().getIdentifier(); |
181 | return op_name == check_numerics_; |
182 | } |
183 | |
184 | bool TFGraphDialect::IsCollective(TFOp op) const { |
185 | StringAttr op_name = op->getName().getIdentifier(); |
186 | return op_name == collective_reduce_ || op_name == collective_bcast_send_ || |
187 | op_name == collective_bcast_recv_; |
188 | } |
189 | |
190 | bool TFGraphDialect::IsComplex(TFOp op) const { |
191 | StringAttr op_name = op->getName().getIdentifier(); |
192 | return op_name == complex_; |
193 | } |
194 | |
195 | bool TFGraphDialect::IsComplexAbs(TFOp op) const { |
196 | StringAttr op_name = op->getName().getIdentifier(); |
197 | return op_name == complex_abs_; |
198 | } |
199 | |
200 | bool TFGraphDialect::IsConcat(TFOp op) const { |
201 | StringAttr op_name = op->getName().getIdentifier(); |
202 | return op_name == concat_ || IsConcatV2(op); |
203 | } |
204 | |
205 | bool TFGraphDialect::IsConcatV2(TFOp op) const { |
206 | StringAttr op_name = op->getName().getIdentifier(); |
207 | return op_name == concat_v2_; |
208 | } |
209 | |
210 | bool TFGraphDialect::IsConcatOffset(TFOp op) const { |
211 | StringAttr op_name = op->getName().getIdentifier(); |
212 | return op_name == concat_offset_; |
213 | } |
214 | |
215 | bool TFGraphDialect::IsConstant(TFOp op) const { |
216 | StringAttr op_name = op->getName().getIdentifier(); |
217 | return op_name == const_; |
218 | } |
219 | |
220 | bool TFGraphDialect::IsConj(TFOp op) const { |
221 | StringAttr op_name = op->getName().getIdentifier(); |
222 | return op_name == conj_; |
223 | } |
224 | |
225 | bool TFGraphDialect::IsConjugateTranspose(TFOp op) const { |
226 | StringAttr op_name = op->getName().getIdentifier(); |
227 | return op_name == conjugate_transpose_; |
228 | } |
229 | |
230 | // TODO(chiahungduan): Should we use certain helpers like IsEnter(). |
231 | bool TFGraphDialect::IsControlFlow(TFOp op) const { |
232 | StringAttr op_name = op->getName().getIdentifier(); |
233 | |
234 | return op_name == control_trigger_ || op_name == enter_ || op_name == exit_ || |
235 | op_name == loop_cond_ || op_name == merge_ || op_name == xla_merge_ || |
236 | op_name == next_iteration_ || op_name == switch_ || |
237 | op_name == switch_n_; |
238 | } |
239 | |
240 | bool TFGraphDialect::IsConv2D(TFOp op) const { |
241 | StringAttr op_name = op->getName().getIdentifier(); |
242 | return op_name == conv_2d_; |
243 | } |
244 | |
245 | bool TFGraphDialect::IsConv2DBackpropFilter(TFOp op) const { |
246 | StringAttr op_name = op->getName().getIdentifier(); |
247 | return op_name == conv_2d_back_prop_filter_; |
248 | } |
249 | |
250 | bool TFGraphDialect::IsConv2DBackpropInput(TFOp op) const { |
251 | StringAttr op_name = op->getName().getIdentifier(); |
252 | return op_name == conv_2d_back_prop_input_; |
253 | } |
254 | |
255 | bool TFGraphDialect::IsConv3D(TFOp op) const { |
256 | StringAttr op_name = op->getName().getIdentifier(); |
257 | return op_name == conv_3d_; |
258 | } |
259 | |
260 | bool TFGraphDialect::IsConv3DBackpropFilterV2(TFOp op) const { |
261 | StringAttr op_name = op->getName().getIdentifier(); |
262 | return op_name == conv_3d_back_prop_filter_v2_; |
263 | } |
264 | |
265 | bool TFGraphDialect::IsConv3DBackpropInputV2(TFOp op) const { |
266 | StringAttr op_name = op->getName().getIdentifier(); |
267 | return op_name == conv_3d_back_prop_input_v2_; |
268 | } |
269 | |
270 | bool TFGraphDialect::IsDepthwiseConv2dNative(TFOp op) const { |
271 | StringAttr op_name = op->getName().getIdentifier(); |
272 | return op_name == depth_wise_conv_2d_native_; |
273 | } |
274 | |
275 | bool TFGraphDialect::IsDepthwiseConv2dNativeBackpropFilter(TFOp op) const { |
276 | StringAttr op_name = op->getName().getIdentifier(); |
277 | return op_name == depth_wise_conv_2d_native_back_prop_filter_; |
278 | } |
279 | |
280 | bool TFGraphDialect::IsDepthwiseConv2dNativeBackpropInput(TFOp op) const { |
281 | StringAttr op_name = op->getName().getIdentifier(); |
282 | return op_name == depth_wise_conv_2d_native_back_prop_input_; |
283 | } |
284 | |
285 | bool TFGraphDialect::IsDequeueOp(TFOp op) const { |
286 | StringAttr op_name = op->getName().getIdentifier(); |
287 | return op_name == queue_dequeue_ || op_name == queue_dequeue_v2_ || |
288 | op_name == queue_dequeue_many_ || op_name == queue_dequeue_many_v2_ || |
289 | op_name == queue_dequeue_upto_ || op_name == queue_dequeue_upto_v2_; |
290 | } |
291 | |
292 | bool TFGraphDialect::IsDiv(TFOp op) const { |
293 | StringAttr op_name = op->getName().getIdentifier(); |
294 | return op_name == div_; |
295 | } |
296 | |
297 | bool TFGraphDialect::IsDivNoNan(TFOp op) const { |
298 | StringAttr op_name = op->getName().getIdentifier(); |
299 | return op_name == div_no_nan_; |
300 | } |
301 | |
302 | bool TFGraphDialect::IsElu(TFOp op) const { |
303 | StringAttr op_name = op->getName().getIdentifier(); |
304 | return op_name == elu_; |
305 | } |
306 | |
307 | bool TFGraphDialect::IsEluGrad(TFOp op) const { |
308 | StringAttr op_name = op->getName().getIdentifier(); |
309 | return op_name == elu_grad_; |
310 | } |
311 | |
312 | bool TFGraphDialect::IsQuantizationEmulation(TFOp op) const { |
313 | StringAttr op_name = op->getName().getIdentifier(); |
314 | return op_name == quantize_and_dequantize_ || |
315 | op_name == quantize_and_dequantize_v2_ || |
316 | op_name == quantize_and_dequantize_v3_ || |
317 | op_name == quantize_and_dequantize_v4_ || |
318 | op_name == quantize_and_dequantize_v4_grad_ || |
319 | op_name == fake_quant_with_min_max_args_ || |
320 | op_name == fake_quant_with_min_max_args_gradient_ || |
321 | op_name == fake_quant_with_min_max_vars_ || |
322 | op_name == fake_quant_with_min_max_vars_gradient_ || |
323 | op_name == fake_quant_with_min_max_vars_per_channel_ || |
324 | op_name == fake_quant_with_min_max_vars_per_channel_gradient_; |
325 | } |
326 | |
327 | bool TFGraphDialect::IsEnter(TFOp op) const { |
328 | StringAttr op_name = op->getName().getIdentifier(); |
329 | return op_name == enter_ || op_name == ref_enter_; |
330 | } |
331 | |
332 | bool TFGraphDialect::IsEqual(TFOp op) const { |
333 | StringAttr op_name = op->getName().getIdentifier(); |
334 | return op_name == equal_; |
335 | } |
336 | |
337 | bool TFGraphDialect::IsExit(TFOp op) const { |
338 | StringAttr op_name = op->getName().getIdentifier(); |
339 | return op_name == exit_ || op_name == ref_exit_; |
340 | } |
341 | |
342 | bool TFGraphDialect::IsExp(TFOp op) const { |
343 | StringAttr op_name = op->getName().getIdentifier(); |
344 | return op_name == exp_; |
345 | } |
346 | |
347 | bool TFGraphDialect::IsFakeParam(TFOp op) const { |
348 | StringAttr op_name = op->getName().getIdentifier(); |
349 | return op_name == fake_param_; |
350 | } |
351 | |
352 | bool TFGraphDialect::IsFill(TFOp op) const { |
353 | StringAttr op_name = op->getName().getIdentifier(); |
354 | return op_name == fill_; |
355 | } |
356 | |
357 | bool TFGraphDialect::IsFloorDiv(TFOp op) const { |
358 | StringAttr op_name = op->getName().getIdentifier(); |
359 | return op_name == floor_div_; |
360 | } |
361 | |
362 | bool TFGraphDialect::IsFloorMod(TFOp op) const { |
363 | StringAttr op_name = op->getName().getIdentifier(); |
364 | return op_name == floor_mod_; |
365 | } |
366 | |
367 | bool TFGraphDialect::IsFusedBatchNorm(TFOp op) const { |
368 | StringAttr op_name = op->getName().getIdentifier(); |
369 | return op_name == fused_batch_norm_ || op_name == fused_batch_norm_v2_ || |
370 | op_name == fused_batch_norm_v3_; |
371 | } |
372 | |
373 | bool TFGraphDialect::IsFusedBatchNormEx(TFOp op) const { |
374 | StringAttr op_name = op->getName().getIdentifier(); |
375 | return op_name == fused_batch_norm_ex_; |
376 | } |
377 | |
378 | bool TFGraphDialect::IsFusedBatchNormGrad(TFOp op) const { |
379 | StringAttr op_name = op->getName().getIdentifier(); |
380 | return op_name == fused_batch_norm_grad_ || |
381 | op_name == fused_batch_norm_grad_v2_ || |
382 | op_name == fused_batch_norm_grad_v3_; |
383 | } |
384 | |
385 | bool TFGraphDialect::IsGather(TFOp op) const { |
386 | StringAttr op_name = op->getName().getIdentifier(); |
387 | return op_name == gather_ || op_name == gather_v2_ || |
388 | op_name == resource_gather_; |
389 | } |
390 | |
391 | bool TFGraphDialect::IsGreater(TFOp op) const { |
392 | StringAttr op_name = op->getName().getIdentifier(); |
393 | return op_name == greater_; |
394 | } |
395 | |
396 | bool TFGraphDialect::IsGreaterEqual(TFOp op) const { |
397 | StringAttr op_name = op->getName().getIdentifier(); |
398 | return op_name == greater_equal_; |
399 | } |
400 | |
401 | bool TFGraphDialect::IsHostConstant(TFOp op) const { |
402 | StringAttr op_name = op->getName().getIdentifier(); |
403 | return op_name == host_const_; |
404 | } |
405 | |
406 | bool TFGraphDialect::IsHistogramSummary(TFOp op) const { |
407 | StringAttr op_name = op->getName().getIdentifier(); |
408 | return op_name == histogram_summary_; |
409 | } |
410 | |
411 | bool TFGraphDialect::IsIdentity(TFOp op) const { |
412 | StringAttr op_name = op->getName().getIdentifier(); |
413 | return op_name == identity_ || op_name == ref_identity_; |
414 | } |
415 | |
416 | bool TFGraphDialect::IsIdentityN(TFOp op) const { |
417 | StringAttr op_name = op->getName().getIdentifier(); |
418 | return op_name == identity_n_; |
419 | } |
420 | |
421 | bool TFGraphDialect::IsIdentityNSingleInput(TFOp op) const { |
422 | if (!IsIdentityN(op)) return false; |
423 | auto array_attr = op->getAttrOfType<ArrayAttr>("T"); |
424 | if (!array_attr) return false; |
425 | // TODO(chiahungduan): Do we need to check the content of array_attr? |
426 | return array_attr.size() == 1; |
427 | } |
428 | |
429 | bool TFGraphDialect::IsIf(TFOp op) const { |
430 | StringAttr op_name = op->getName().getIdentifier(); |
431 | return op_name == if_ || op_name == stateless_if_; |
432 | } |
433 | |
434 | bool TFGraphDialect::IsIgamma(TFOp op) const { |
435 | StringAttr op_name = op->getName().getIdentifier(); |
436 | return op_name == igamma_; |
437 | } |
438 | |
439 | bool TFGraphDialect::IsIgammac(TFOp op) const { |
440 | StringAttr op_name = op->getName().getIdentifier(); |
441 | return op_name == igammac_; |
442 | } |
443 | |
444 | bool TFGraphDialect::IsImag(TFOp op) const { |
445 | StringAttr op_name = op->getName().getIdentifier(); |
446 | return op_name == imag_; |
447 | } |
448 | |
449 | bool TFGraphDialect::IsImmutableConst(TFOp op) const { |
450 | StringAttr op_name = op->getName().getIdentifier(); |
451 | return op_name == immutable_const_; |
452 | } |
453 | |
454 | bool TFGraphDialect::IsInvGrad(TFOp op) const { |
455 | StringAttr op_name = op->getName().getIdentifier(); |
456 | return op_name == inv_grad_; |
457 | } |
458 | |
459 | bool TFGraphDialect::IsLeakyRelu(TFOp op) const { |
460 | StringAttr op_name = op->getName().getIdentifier(); |
461 | return op_name == leaky_relu_; |
462 | } |
463 | |
464 | bool TFGraphDialect::IsLeakyReluGrad(TFOp op) const { |
465 | StringAttr op_name = op->getName().getIdentifier(); |
466 | return op_name == leaky_relu_grad_; |
467 | } |
468 | |
469 | bool TFGraphDialect::IsLess(TFOp op) const { |
470 | StringAttr op_name = op->getName().getIdentifier(); |
471 | return op_name == less_; |
472 | } |
473 | |
474 | bool TFGraphDialect::IsLessEqual(TFOp op) const { |
475 | StringAttr op_name = op->getName().getIdentifier(); |
476 | return op_name == less_equal_; |
477 | } |
478 | |
479 | bool TFGraphDialect::IsLog(TFOp op) const { |
480 | StringAttr op_name = op->getName().getIdentifier(); |
481 | return op_name == log_; |
482 | } |
483 | |
484 | bool TFGraphDialect::IsLogicalAnd(TFOp op) const { |
485 | StringAttr op_name = op->getName().getIdentifier(); |
486 | return op_name == logical_and_; |
487 | } |
488 | |
489 | bool TFGraphDialect::IsLogicalNot(TFOp op) const { |
490 | StringAttr op_name = op->getName().getIdentifier(); |
491 | return op_name == logical_not_; |
492 | } |
493 | |
494 | bool TFGraphDialect::IsLogicalOr(TFOp op) const { |
495 | StringAttr op_name = op->getName().getIdentifier(); |
496 | return op_name == logical_or_; |
497 | } |
498 | |
499 | bool TFGraphDialect::IsLoopCond(TFOp op) const { |
500 | StringAttr op_name = op->getName().getIdentifier(); |
501 | return op_name == loop_cond_; |
502 | } |
503 | |
504 | bool TFGraphDialect::IsMatMul(TFOp op) const { |
505 | StringAttr op_name = op->getName().getIdentifier(); |
506 | return op_name == matmul_; |
507 | } |
508 | |
509 | bool TFGraphDialect::IsMax(TFOp op) const { |
510 | StringAttr op_name = op->getName().getIdentifier(); |
511 | return op_name == max_; |
512 | } |
513 | |
514 | bool TFGraphDialect::IsMaximum(TFOp op) const { |
515 | StringAttr op_name = op->getName().getIdentifier(); |
516 | return op_name == maximum_; |
517 | } |
518 | |
519 | bool TFGraphDialect::IsMaxPoolGrad(TFOp op) const { |
520 | StringAttr op_name = op->getName().getIdentifier(); |
521 | return op_name == max_pool_grad_; |
522 | } |
523 | |
524 | bool TFGraphDialect::IsMean(TFOp op) const { |
525 | StringAttr op_name = op->getName().getIdentifier(); |
526 | return op_name == mean_; |
527 | } |
528 | |
529 | bool TFGraphDialect::IsMerge(TFOp op) const { |
530 | StringAttr op_name = op->getName().getIdentifier(); |
531 | return op_name == merge_ || op_name == ref_merge_ || op_name == xla_merge_; |
532 | } |
533 | |
534 | bool TFGraphDialect::IsMin(TFOp op) const { |
535 | StringAttr op_name = op->getName().getIdentifier(); |
536 | return op_name == min_; |
537 | } |
538 | |
539 | bool TFGraphDialect::IsMinimum(TFOp op) const { |
540 | StringAttr op_name = op->getName().getIdentifier(); |
541 | return op_name == minimum_; |
542 | } |
543 | |
544 | bool TFGraphDialect::IsMirrorPad(TFOp op) const { |
545 | StringAttr op_name = op->getName().getIdentifier(); |
546 | return op_name == mirror_pad_; |
547 | } |
548 | |
549 | bool TFGraphDialect::IsMirrorPadGrad(TFOp op) const { |
550 | StringAttr op_name = op->getName().getIdentifier(); |
551 | return op_name == mirror_pad_grad_; |
552 | } |
553 | |
554 | bool TFGraphDialect::IsMod(TFOp op) const { |
555 | StringAttr op_name = op->getName().getIdentifier(); |
556 | return op_name == mod_; |
557 | } |
558 | |
559 | bool TFGraphDialect::IsMul(TFOp op) const { |
560 | StringAttr op_name = op->getName().getIdentifier(); |
561 | return op_name == mul_; |
562 | } |
563 | bool TFGraphDialect::IsMulNoNan(TFOp op) const { |
564 | StringAttr op_name = op->getName().getIdentifier(); |
565 | return op_name == mul_no_nan_; |
566 | } |
567 | bool TFGraphDialect::IsAnyMul(TFOp op) const { |
568 | return IsMul(op) || IsMulNoNan(op); |
569 | } |
570 | |
571 | bool TFGraphDialect::IsNeg(TFOp op) const { |
572 | StringAttr op_name = op->getName().getIdentifier(); |
573 | return op_name == neg_; |
574 | } |
575 | |
576 | bool TFGraphDialect::IsNoOp(TFOp op) const { |
577 | StringAttr op_name = op->getName().getIdentifier(); |
578 | return op_name == no_op_; |
579 | } |
580 | |
581 | bool TFGraphDialect::IsNotEqual(TFOp op) const { |
582 | StringAttr op_name = op->getName().getIdentifier(); |
583 | return op_name == not_equal_; |
584 | } |
585 | |
586 | bool TFGraphDialect::IsNextIteration(TFOp op) const { |
587 | StringAttr op_name = op->getName().getIdentifier(); |
588 | return op_name == next_iteration_ || op_name == ref_next_iteration_; |
589 | } |
590 | |
591 | bool TFGraphDialect::IsOnesLike(TFOp op) const { |
592 | StringAttr op_name = op->getName().getIdentifier(); |
593 | return op_name == ones_like_; |
594 | } |
595 | |
596 | bool TFGraphDialect::IsPack(TFOp op) const { |
597 | StringAttr op_name = op->getName().getIdentifier(); |
598 | return op_name == pack_; |
599 | } |
600 | |
601 | bool TFGraphDialect::IsPad(TFOp op) const { |
602 | StringAttr op_name = op->getName().getIdentifier(); |
603 | return op_name == pad_ || op_name == pad_v2_; |
604 | } |
605 | |
606 | bool TFGraphDialect::IsPartitionedCall(TFOp op) const { |
607 | StringAttr op_name = op->getName().getIdentifier(); |
608 | return op_name == partitioned_call_; |
609 | } |
610 | |
611 | bool TFGraphDialect::IsPlaceholder(TFOp op) const { |
612 | StringAttr op_name = op->getName().getIdentifier(); |
613 | return op_name == placeholder_ || op_name == placeholder_v2_ || |
614 | op_name == placeholder_with_default_; |
615 | } |
616 | |
617 | bool TFGraphDialect::IsPolygamma(TFOp op) const { |
618 | StringAttr op_name = op->getName().getIdentifier(); |
619 | return op_name == poly_gamma_; |
620 | } |
621 | |
622 | bool TFGraphDialect::IsPow(TFOp op) const { |
623 | StringAttr op_name = op->getName().getIdentifier(); |
624 | return op_name == pow_; |
625 | } |
626 | |
627 | bool TFGraphDialect::IsPrint(TFOp op) const { |
628 | StringAttr op_name = op->getName().getIdentifier(); |
629 | return op_name == print_ || op_name == print_v2_; |
630 | } |
631 | |
632 | bool TFGraphDialect::IsProd(TFOp op) const { |
633 | StringAttr op_name = op->getName().getIdentifier(); |
634 | return op_name == prod_; |
635 | } |
636 | |
637 | bool TFGraphDialect::IsQuantizedMatMul(TFOp op) const { |
638 | StringAttr op_name = op->getName().getIdentifier(); |
639 | return op_name == quantized_matmul_ || op_name == quantized_matmul_v2_; |
640 | } |
641 | |
642 | bool TFGraphDialect::IsQueue(TFOp op) const { |
643 | StringAttr op_name = op->getName().getIdentifier(); |
644 | return op_name == random_shuffle_queue_v2_ || op_name == fifo_queue_v2_ || |
645 | op_name == padding_fifo_queue_v2_ || op_name == priority_queue_v2_; |
646 | } |
647 | |
648 | bool TFGraphDialect::IsRandomShuffle(TFOp op) const { |
649 | StringAttr op_name = op->getName().getIdentifier(); |
650 | return op_name == random_shuffle_; |
651 | } |
652 | |
653 | bool TFGraphDialect::IsRank(TFOp op) const { |
654 | StringAttr op_name = op->getName().getIdentifier(); |
655 | return op_name == rank_; |
656 | } |
657 | |
658 | bool TFGraphDialect::IsReadVariableOp(TFOp op) const { |
659 | StringAttr op_name = op->getName().getIdentifier(); |
660 | return op_name == read_variable_op_; |
661 | } |
662 | |
663 | bool TFGraphDialect::IsReadVariablesOp(TFOp op) const { |
664 | StringAttr op_name = op->getName().getIdentifier(); |
665 | return op_name == read_variables_op_; |
666 | } |
667 | |
668 | bool TFGraphDialect::IsReal(TFOp op) const { |
669 | StringAttr op_name = op->getName().getIdentifier(); |
670 | return op_name == real_; |
671 | } |
672 | |
673 | bool TFGraphDialect::IsRealDiv(TFOp op) const { |
674 | StringAttr op_name = op->getName().getIdentifier(); |
675 | return op_name == real_div_; |
676 | } |
677 | |
678 | bool TFGraphDialect::IsReciprocalGrad(TFOp op) const { |
679 | StringAttr op_name = op->getName().getIdentifier(); |
680 | return op_name == reciprocal_grad_; |
681 | } |
682 | |
683 | bool TFGraphDialect::IsRecv(TFOp op) const { |
684 | StringAttr op_name = op->getName().getIdentifier(); |
685 | return op_name == recv_ || op_name == host_recv_; |
686 | } |
687 | |
688 | bool TFGraphDialect::IsReduction(TFOp op) const { |
689 | return IsSum(op) || IsProd(op) || IsMin(op) || IsMax(op) || IsMean(op) || |
690 | IsAny(op) || IsAll(op); |
691 | } |
692 | |
693 | bool TFGraphDialect::IsRelu(TFOp op) const { |
694 | StringAttr op_name = op->getName().getIdentifier(); |
695 | return op_name == relu_; |
696 | } |
697 | |
698 | bool TFGraphDialect::IsRelu6(TFOp op) const { |
699 | StringAttr op_name = op->getName().getIdentifier(); |
700 | return op_name == relu6_; |
701 | } |
702 | |
703 | bool TFGraphDialect::IsReluGrad(TFOp op) const { |
704 | StringAttr op_name = op->getName().getIdentifier(); |
705 | return op_name == relu_grad_; |
706 | } |
707 | |
708 | bool TFGraphDialect::IsRelu6Grad(TFOp op) const { |
709 | StringAttr op_name = op->getName().getIdentifier(); |
710 | return op_name == relu6_grad_; |
711 | } |
712 | |
713 | bool TFGraphDialect::IsReshape(TFOp op) const { |
714 | StringAttr op_name = op->getName().getIdentifier(); |
715 | return op_name == reshape_; |
716 | } |
717 | |
718 | bool TFGraphDialect::IsRestore(TFOp op) const { |
719 | StringAttr op_name = op->getName().getIdentifier(); |
720 | return op_name == restore_ || op_name == restore_v2_ || |
721 | op_name == restore_slice_; |
722 | } |
723 | |
724 | bool TFGraphDialect::IsReturn(TFOp op) const { |
725 | StringAttr op_name = op->getName().getIdentifier(); |
726 | return op_name == return_; |
727 | } |
728 | |
729 | bool TFGraphDialect::IsRetval(TFOp op) const { |
730 | StringAttr op_name = op->getName().getIdentifier(); |
731 | return op_name == retval_ || op_name == device_retval_; |
732 | } |
733 | |
734 | bool TFGraphDialect::IsReverse(TFOp op) const { |
735 | StringAttr op_name = op->getName().getIdentifier(); |
736 | return op_name == reverse_ || IsReverseV2(op); |
737 | } |
738 | |
739 | bool TFGraphDialect::IsReverseV2(TFOp op) const { |
740 | StringAttr op_name = op->getName().getIdentifier(); |
741 | return op_name == reverse_v2_; |
742 | } |
743 | |
744 | bool TFGraphDialect::IsRsqrt(TFOp op) const { |
745 | StringAttr op_name = op->getName().getIdentifier(); |
746 | return op_name == rsqrt_; |
747 | } |
748 | |
749 | bool TFGraphDialect::IsRsqrtGrad(TFOp op) const { |
750 | StringAttr op_name = op->getName().getIdentifier(); |
751 | return op_name == rsqrt_grad_; |
752 | } |
753 | |
754 | bool TFGraphDialect::IsSelect(TFOp op) const { |
755 | StringAttr op_name = op->getName().getIdentifier(); |
756 | return op_name == select_ || op_name == select_v2_; |
757 | } |
758 | |
759 | bool TFGraphDialect::IsSeluGrad(TFOp op) const { |
760 | StringAttr op_name = op->getName().getIdentifier(); |
761 | return op_name == selu_grad_; |
762 | } |
763 | |
764 | bool TFGraphDialect::IsSend(TFOp op) const { |
765 | StringAttr op_name = op->getName().getIdentifier(); |
766 | return op_name == send_ || op_name == host_send_; |
767 | } |
768 | |
769 | bool TFGraphDialect::IsShape(TFOp op) const { |
770 | StringAttr op_name = op->getName().getIdentifier(); |
771 | return op_name == shape_; |
772 | } |
773 | |
774 | bool TFGraphDialect::IsShapeN(TFOp op) const { |
775 | StringAttr op_name = op->getName().getIdentifier(); |
776 | return op_name == shape_n_; |
777 | } |
778 | |
779 | bool TFGraphDialect::IsShuffle(TFOp op) const { |
780 | StringAttr op_name = op->getName().getIdentifier(); |
781 | return op_name == shuffle_; |
782 | } |
783 | |
784 | bool TFGraphDialect::IsSigmoid(TFOp op) const { |
785 | StringAttr op_name = op->getName().getIdentifier(); |
786 | return op_name == sigmoid_; |
787 | } |
788 | |
789 | bool TFGraphDialect::IsSigmoidGrad(TFOp op) const { |
790 | StringAttr op_name = op->getName().getIdentifier(); |
791 | return op_name == sigmoid_grad_; |
792 | } |
793 | |
794 | bool TFGraphDialect::IsSize(TFOp op) const { |
795 | StringAttr op_name = op->getName().getIdentifier(); |
796 | return op_name == size_; |
797 | } |
798 | |
799 | bool TFGraphDialect::IsSlice(TFOp op) const { |
800 | StringAttr op_name = op->getName().getIdentifier(); |
801 | return op_name == slice_; |
802 | } |
803 | |
804 | bool TFGraphDialect::IsSnapshot(TFOp op) const { |
805 | StringAttr op_name = op->getName().getIdentifier(); |
806 | return op_name == snapshot_; |
807 | } |
808 | |
809 | bool TFGraphDialect::IsSoftmax(TFOp op) const { |
810 | StringAttr op_name = op->getName().getIdentifier(); |
811 | return op_name == softmax_; |
812 | } |
813 | |
814 | bool TFGraphDialect::IsSoftplusGrad(TFOp op) const { |
815 | StringAttr op_name = op->getName().getIdentifier(); |
816 | return op_name == softplus_grad_; |
817 | } |
818 | |
819 | bool TFGraphDialect::IsSoftsignGrad(TFOp op) const { |
820 | StringAttr op_name = op->getName().getIdentifier(); |
821 | return op_name == softsign_grad_; |
822 | } |
823 | |
824 | bool TFGraphDialect::IsSplit(TFOp op) const { |
825 | StringAttr op_name = op->getName().getIdentifier(); |
826 | return op_name == split_; |
827 | } |
828 | |
829 | bool TFGraphDialect::IsSplitV(TFOp op) const { |
830 | StringAttr op_name = op->getName().getIdentifier(); |
831 | return op_name == split_v_; |
832 | } |
833 | |
834 | bool TFGraphDialect::IsSqrt(TFOp op) const { |
835 | StringAttr op_name = op->getName().getIdentifier(); |
836 | return op_name == sqrt_; |
837 | } |
838 | |
839 | bool TFGraphDialect::IsSqrtGrad(TFOp op) const { |
840 | StringAttr op_name = op->getName().getIdentifier(); |
841 | return op_name == sqrt_grad_; |
842 | } |
843 | |
844 | bool TFGraphDialect::IsSquare(TFOp op) const { |
845 | StringAttr op_name = op->getName().getIdentifier(); |
846 | return op_name == square_; |
847 | } |
848 | |
849 | bool TFGraphDialect::IsSquaredDifference(TFOp op) const { |
850 | StringAttr op_name = op->getName().getIdentifier(); |
851 | return op_name == squared_difference_; |
852 | } |
853 | |
854 | bool TFGraphDialect::IsSqueeze(TFOp op) const { |
855 | StringAttr op_name = op->getName().getIdentifier(); |
856 | return op_name == squeeze_; |
857 | } |
858 | |
859 | bool TFGraphDialect::IsStackOp(TFOp op) const { |
860 | StringAttr op_name = op->getName().getIdentifier(); |
861 | return op_name == stack_ || op_name == stack_v2_; |
862 | } |
863 | |
864 | bool TFGraphDialect::IsStackCloseOp(TFOp op) const { |
865 | StringAttr op_name = op->getName().getIdentifier(); |
866 | return op_name == stack_close_ || op_name == stack_close_v2_; |
867 | } |
868 | |
869 | bool TFGraphDialect::IsStackPushOp(TFOp op) const { |
870 | StringAttr op_name = op->getName().getIdentifier(); |
871 | return op_name == stack_push_ || op_name == stack_push_v2_; |
872 | } |
873 | |
874 | bool TFGraphDialect::IsStackPopOp(TFOp op) const { |
875 | StringAttr op_name = op->getName().getIdentifier(); |
876 | return op_name == stack_pop_ || op_name == stack_pop_v2_; |
877 | } |
878 | |
879 | bool TFGraphDialect::IsStatefulPartitionedCall(TFOp op) const { |
880 | StringAttr op_name = op->getName().getIdentifier(); |
881 | return op_name == stateful_partitioned_call_; |
882 | } |
883 | |
884 | bool TFGraphDialect::IsStopGradient(TFOp op) const { |
885 | StringAttr op_name = op->getName().getIdentifier(); |
886 | return op_name == stop_gradient_ || op_name == prevent_gradient_; |
887 | } |
888 | |
889 | bool TFGraphDialect::IsStridedSlice(TFOp op) const { |
890 | StringAttr op_name = op->getName().getIdentifier(); |
891 | return op_name == strided_slice_; |
892 | } |
893 | |
894 | bool TFGraphDialect::IsStridedSliceGrad(TFOp op) const { |
895 | StringAttr op_name = op->getName().getIdentifier(); |
896 | return op_name == strided_slice_grad_; |
897 | } |
898 | |
899 | bool TFGraphDialect::IsStringToHashBucketFast(TFOp op) const { |
900 | StringAttr op_name = op->getName().getIdentifier(); |
901 | return op_name == string_to_hashbucket_fast_; |
902 | } |
903 | |
904 | bool TFGraphDialect::IsSub(TFOp op) const { |
905 | StringAttr op_name = op->getName().getIdentifier(); |
906 | return op_name == sub_; |
907 | } |
908 | |
909 | bool TFGraphDialect::IsSum(TFOp op) const { |
910 | StringAttr op_name = op->getName().getIdentifier(); |
911 | return op_name == sum_; |
912 | } |
913 | |
914 | bool TFGraphDialect::IsSwitch(TFOp op) const { |
915 | StringAttr op_name = op->getName().getIdentifier(); |
916 | return op_name == switch_ || op_name == switch_n_ || op_name == ref_switch_; |
917 | } |
918 | |
919 | bool TFGraphDialect::IsSymbolicGradient(TFOp op) const { |
920 | StringAttr op_name = op->getName().getIdentifier(); |
921 | return op_name == symbolic_gradient_; |
922 | } |
923 | |
924 | bool TFGraphDialect::IsTanh(TFOp op) const { |
925 | StringAttr op_name = op->getName().getIdentifier(); |
926 | return op_name == tanh_; |
927 | } |
928 | |
929 | bool TFGraphDialect::IsTanhGrad(TFOp op) const { |
930 | StringAttr op_name = op->getName().getIdentifier(); |
931 | return op_name == tanh_grad_; |
932 | } |
933 | |
934 | bool TFGraphDialect::IsTile(TFOp op) const { |
935 | StringAttr op_name = op->getName().getIdentifier(); |
936 | return op_name == tile_; |
937 | } |
938 | |
939 | bool TFGraphDialect::IsTranspose(TFOp op) const { |
940 | StringAttr op_name = op->getName().getIdentifier(); |
941 | return op_name == transpose_; |
942 | } |
943 | |
944 | bool TFGraphDialect::IsTruncateDiv(TFOp op) const { |
945 | StringAttr op_name = op->getName().getIdentifier(); |
946 | return op_name == truncate_div_; |
947 | } |
948 | |
949 | bool TFGraphDialect::IsTruncateMod(TFOp op) const { |
950 | StringAttr op_name = op->getName().getIdentifier(); |
951 | return op_name == truncate_mod_; |
952 | } |
953 | |
954 | bool TFGraphDialect::IsUnique(TFOp op) const { |
955 | StringAttr op_name = op->getName().getIdentifier(); |
956 | return op_name == unique_ || op_name == unique_v2_; |
957 | } |
958 | |
959 | bool TFGraphDialect::IsUnpack(TFOp op) const { |
960 | StringAttr op_name = op->getName().getIdentifier(); |
961 | return op_name == unpack_; |
962 | } |
963 | |
964 | bool TFGraphDialect::IsVariable(TFOp op) const { |
965 | StringAttr op_name = op->getName().getIdentifier(); |
966 | return op_name == variable_ || op_name == variable_v2_ || |
967 | op_name == auto_reload_variable_ || op_name == var_handle_op_ || |
968 | op_name == var_handles_op_ || IsReadVariableOp(op) || |
969 | IsReadVariablesOp(op); |
970 | } |
971 | |
972 | bool TFGraphDialect::IsWhile(TFOp op) const { |
973 | StringAttr op_name = op->getName().getIdentifier(); |
974 | return op_name == while_ || op_name == stateless_while_; |
975 | } |
976 | |
977 | bool TFGraphDialect::IsXdivy(TFOp op) const { |
978 | StringAttr op_name = op->getName().getIdentifier(); |
979 | return op_name == xdivy_; |
980 | } |
981 | |
982 | bool TFGraphDialect::IsZerosLike(TFOp op) const { |
983 | StringAttr op_name = op->getName().getIdentifier(); |
984 | return op_name == zeros_like_; |
985 | } |
986 | |
987 | bool TFGraphDialect::IsZeta(TFOp op) const { |
988 | StringAttr op_name = op->getName().getIdentifier(); |
989 | return op_name == zeta_; |
990 | } |
991 | |
992 | } // namespace tfg |
993 | } // namespace mlir |
994 |