1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
---|---|
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/ir/tf_op_registry.h" |
17 | |
18 | #include "tensorflow/core/framework/op.h" |
19 | #include "tensorflow/core/framework/op_def_builder.h" |
20 | #include "tensorflow/core/ir/ops.h" |
21 | |
22 | namespace mlir { |
23 | namespace tfg { |
24 | TensorFlowOpRegistryInterface::TensorFlowOpRegistryInterface(Dialect *dialect) |
25 | : TensorFlowOpRegistryInterface(dialect, tensorflow::OpRegistry::Global()) { |
26 | } |
27 | |
28 | // Returns true if the op is stateful. |
29 | static bool IsStatefulImpl(const tensorflow::OpRegistry *registry, |
30 | StringRef op_name) { |
31 | const tensorflow::OpRegistrationData *op_reg_data = |
32 | registry->LookUp(op_name.str()); |
33 | // If an op definition was not found, conservatively assume stateful. |
34 | if (!op_reg_data) return true; |
35 | return op_reg_data->op_def.is_stateful(); |
36 | } |
37 | |
38 | bool TensorFlowOpRegistryInterface::isStateful(Operation *op) const { |
39 | // Handle TFG internal ops. |
40 | if (op->hasTrait<OpTrait::IntrinsicOperation>()) return false; |
41 | if (auto func = dyn_cast<GraphFuncOp>(op)) return func.getIsStateful(); |
42 | // Handle TFG region ops. |
43 | // TODO(jeffniu): Region ops should be marked with a trait. |
44 | StringRef op_name = op->getName().stripDialect(); |
45 | if (op->getNumRegions() && op_name.endswith("Region")) |
46 | op_name = op_name.drop_back(/*len("Region")=*/6); |
47 | return IsStatefulImpl(registry_, op_name); |
48 | } |
49 | } // namespace tfg |
50 | } // namespace mlir |
51 |