1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/ir/tf_op_wrapper.h" |
17 | |
18 | #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project |
19 | #include "tensorflow/core/ir/dialect.h" |
20 | |
21 | namespace mlir { |
22 | namespace tfg { |
23 | |
24 | TFOp::TFOp(Operation *op) : op_(op) { |
25 | assert(!op || classof(op) && "expected a TFG op" ); |
26 | } |
27 | |
28 | StringAttr TFOp::nameAttr() { |
29 | return op_->getAttrOfType<StringAttr>(getDialect()->getNameAttrIdentifier()); |
30 | } |
31 | |
32 | StringRef TFOp::name() { return nameAttr().getValue(); } |
33 | |
34 | void TFOp::setName(const Twine &name) { |
35 | setName(StringAttr::get(op_->getContext(), name.str())); |
36 | } |
37 | |
38 | void TFOp::setName(StringAttr name) { |
39 | op_->setAttr(getDialect()->getNameAttrIdentifier(), name); |
40 | } |
41 | |
42 | StringAttr TFOp::requestedDeviceAttr() { |
43 | return op_->getAttrOfType<StringAttr>( |
44 | getDialect()->getDeviceAttrIdentifier()); |
45 | } |
46 | |
47 | StringRef TFOp::requestedDevice() { return requestedDeviceAttr().getValue(); } |
48 | |
49 | void TFOp::setRequestedDevice(const Twine &device) { |
50 | setRequestedDevice(StringAttr::get(op_->getContext(), device.str())); |
51 | } |
52 | |
53 | void TFOp::setRequestedDevice(StringAttr device) { |
54 | op_->setAttr(getDialect()->getDeviceAttrIdentifier(), device); |
55 | } |
56 | |
57 | StringAttr TFOp::assignedDeviceAttr() { |
58 | return op_->getAttrOfType<StringAttr>( |
59 | getDialect()->getAssignedDeviceAttrIdentifier()); |
60 | } |
61 | |
62 | StringRef TFOp::assignedDevice() { return assignedDeviceAttr().getValue(); } |
63 | |
64 | void TFOp::setAssignedDevice(const Twine &device) { |
65 | setAssignedDevice(StringAttr::get(op_->getContext(), device.str())); |
66 | } |
67 | |
68 | void TFOp::setAssignedDevice(StringAttr device) { |
69 | op_->setAttr(getDialect()->getAssignedDeviceAttrIdentifier(), device); |
70 | } |
71 | |
72 | StringAttr TFOp::tpuReplicate() { |
73 | return op_->getAttrOfType<StringAttr>("_tpu_replicate" ); |
74 | } |
75 | |
76 | void TFOp::setTpuReplicate(StringAttr tpu_replicate) { |
77 | op_->setAttr("_tpu_replicate" , tpu_replicate); |
78 | } |
79 | |
80 | } // namespace tfg |
81 | } // namespace mlir |
82 | |