1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_IR_TF_OP_WRAPPER_H_ |
17 | #define TENSORFLOW_CORE_IR_TF_OP_WRAPPER_H_ |
18 | |
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>

#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/TypeRange.h"  // from @llvm-project
#include "tensorflow/core/ir/dialect.h"
#include "tensorflow/core/ir/types/dialect.h"
#include "tensorflow/core/ir/utility.h"
28 | |
29 | namespace mlir { |
30 | namespace detail { |
31 | // This class iterates over the control dependencies of the values. |
32 | template <typename ValueIteratorT> |
33 | class ControlRetIterator final |
34 | : public llvm::mapped_iterator_base<ControlRetIterator<ValueIteratorT>, |
35 | ValueIteratorT, Value> { |
36 | public: |
37 | using llvm::mapped_iterator_base<ControlRetIterator<ValueIteratorT>, |
38 | ValueIteratorT, Value>::mapped_iterator_base; |
39 | |
40 | Value mapElement(Value value) const { |
41 | return value.getType().isa<tf_type::ControlType>() |
42 | ? value |
43 | : tfg::LookupControlDependency(value); |
44 | } |
45 | }; |
46 | } // namespace detail |
47 | |
48 | namespace tfg { |
49 | |
50 | // Wrapper class exposing convenience methods to manipulate TensorFlow graph |
51 | // nodes uniformly. |
52 | class TFOp { |
53 | public: |
54 | // Wrap an operation. The operation can be null. The constructor must be |
55 | // marked as implicit to support `llvm::dyn_cast`. |
56 | TFOp(Operation *op = nullptr); // NOLINT |
57 | |
58 | explicit TFOp(Operation &op) : TFOp(&op) {} |
59 | |
60 | // Support LLVM-style RTTI. |
61 | static bool classof(Operation *op) { |
62 | return isa<TFGraphDialect>(op->getDialect()); |
63 | } |
64 | |
65 | // Get the wrapped operation. |
66 | Operation *getOperation() { return op_; } |
67 | |
68 | // Returns a pointer to the TensorFlow Graph Dialect. It nevers returns |
69 | // nullptr. |
70 | TFGraphDialect *getDialect() { |
71 | return cast<TFGraphDialect>(op_->getDialect()); |
72 | } |
73 | |
74 | // Split the operands into data and control operands. |
75 | std::pair<OperandRange, OperandRange> splitOperands() { |
76 | ControlType ctl_type = getDialect()->getControlType(); |
77 | return SplitDataAndControlValues(op_->getOperands(), ctl_type); |
78 | } |
79 | |
80 | // Returns the regular operands, the control operands will be excluded. |
81 | OperandRange getNonControlOperands() { return splitOperands().first; } |
82 | |
83 | // The control operands are always after the regular inputs. |
84 | OperandRange getControlOperands() { return splitOperands().second; } |
85 | |
86 | // Returns the control token produced by this operation. |
87 | Value controlRet() { return op_->getResult(op_->getNumResults() - 1); } |
88 | |
89 | // Returns the non-control results produced by this operation. |
90 | ResultRange getNonControlResults() { |
91 | return op_->getResults().slice(0, op_->getNumResults() - 1); |
92 | } |
93 | |
94 | // Returns the node name for this operation. |
95 | StringAttr nameAttr(); |
96 | StringRef name(); |
97 | // Set a new node name for this operation. |
98 | void setName(const Twine &name); |
99 | void setName(StringAttr name); |
100 | |
101 | // Returns the requested device, which is also the "device" field in a |
102 | // GraphDef. |
103 | StringAttr requestedDeviceAttr(); |
104 | StringRef requestedDevice(); |
105 | // Set a new requested device for this operation. |
106 | void setRequestedDevice(const Twine &requested_device); |
107 | void setRequestedDevice(StringAttr requested_device); |
108 | |
109 | // Returns the assigned device, this field is set by placer in general. |
110 | StringAttr assignedDeviceAttr(); |
111 | StringRef assignedDevice(); |
112 | // Set a new assigned device for this operation. |
113 | void setAssignedDevice(const Twine &assigned_device); |
114 | void setAssignedDevice(StringAttr assigned_device); |
115 | |
116 | // Returns the assigned TPU cluster name. |
117 | StringAttr tpuReplicate(); |
118 | // Set the assigned TPU cluster name. |
119 | void setTpuReplicate(StringAttr tpu_replicate); |
120 | |
121 | // Returns the device, preferring the assigned device if set, and the |
122 | // requested device otherwise. |
123 | StringAttr deviceAttr() { |
124 | StringAttr device = assignedDeviceAttr(); |
125 | if (device) { |
126 | assert(!device.getValue().empty()); |
127 | return device; |
128 | } |
129 | return requestedDeviceAttr(); |
130 | } |
131 | StringRef device() { |
132 | StringAttr device_attr = deviceAttr(); |
133 | if (device_attr) return device_attr.getValue(); |
134 | return "" ; |
135 | } |
136 | |
137 | // Forward `->` to the underlying operation, exposing the `Operation` methods. |
138 | Operation *operator->() { return op_; } |
139 | Operation &operator*() { return *op_; } |
140 | |
141 | // Converts to true if there is a wrapped operation. |
142 | explicit operator bool() const { return op_; } |
143 | |
144 | private: |
145 | // The wrapped operation. |
146 | Operation *op_; |
147 | }; |
148 | |
149 | // A range iterator to get the control tokens associated with a value range. |
150 | // This range allows to wrap a ValueRange (or an OperandRange) and iterates on |
151 | // the control token associated to the producer of each value. For example, if |
152 | // you wrap the operands of an operation: |
153 | // OperandControlRetRange range = op->getOperands(); |
154 | // iterating this range will yield the control edges from each of the operations |
155 | // (or block arguments) producing these operands. |
156 | template <typename ValueRangeT> |
157 | class ControlRetRange final |
158 | : public llvm::iterator_range< |
159 | ::mlir::detail::ControlRetIterator<typename ValueRangeT::iterator>> { |
160 | public: |
161 | using Base = llvm::iterator_range< |
162 | ::mlir::detail::ControlRetIterator<typename ValueRangeT::iterator>>; |
163 | explicit ControlRetRange(ValueRangeT c) : Base(c.begin(), c.end()) {} |
164 | |
165 | /// Return the value at the given index. |
166 | Value operator[](size_t index) const { |
167 | assert(index < size() && "invalid index into value range" ); |
168 | return *(this->begin() + index); |
169 | } |
170 | |
171 | // Return the size of this range. |
172 | size_t size() const { return llvm::size(*this); } |
173 | |
174 | // Return first value in the range. |
175 | Value front() { return (*this)[0]; } |
176 | |
177 | // Compare this range with another. |
178 | template <typename OtherT> |
179 | bool operator==(const OtherT &other) const { |
180 | return llvm::size(*this) == llvm::size(other) && |
181 | std::equal(this->begin(), this->end(), other.begin()); |
182 | } |
183 | template <typename OtherT> |
184 | bool operator!=(const OtherT &other) const { |
185 | return !(*this == other); |
186 | } |
187 | }; |
188 | |
189 | using OperandControlRetRange = ControlRetRange<OperandRange>; |
190 | using ValueControlRetRange = ControlRetRange<ValueRange>; |
191 | |
192 | } // namespace tfg |
193 | } // namespace mlir |
194 | |
195 | #endif // TENSORFLOW_CORE_IR_TF_OP_WRAPPER_H_ |
196 | |