1/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_IR_TF_OP_WRAPPER_H_
17#define TENSORFLOW_CORE_IR_TF_OP_WRAPPER_H_
18
#include <algorithm>
#include <cassert>
#include <cstddef>
#include <utility>

#include "llvm/ADT/iterator_range.h"
#include "llvm/Support/Casting.h"
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/OperationSupport.h"  // from @llvm-project
#include "mlir/IR/TypeRange.h"  // from @llvm-project
#include "tensorflow/core/ir/dialect.h"
#include "tensorflow/core/ir/types/dialect.h"
#include "tensorflow/core/ir/utility.h"
28
29namespace mlir {
30namespace detail {
31// This class iterates over the control dependencies of the values.
32template <typename ValueIteratorT>
33class ControlRetIterator final
34 : public llvm::mapped_iterator_base<ControlRetIterator<ValueIteratorT>,
35 ValueIteratorT, Value> {
36 public:
37 using llvm::mapped_iterator_base<ControlRetIterator<ValueIteratorT>,
38 ValueIteratorT, Value>::mapped_iterator_base;
39
40 Value mapElement(Value value) const {
41 return value.getType().isa<tf_type::ControlType>()
42 ? value
43 : tfg::LookupControlDependency(value);
44 }
45};
46} // namespace detail
47
48namespace tfg {
49
50// Wrapper class exposing convenience methods to manipulate TensorFlow graph
51// nodes uniformly.
52class TFOp {
53 public:
54 // Wrap an operation. The operation can be null. The constructor must be
55 // marked as implicit to support `llvm::dyn_cast`.
56 TFOp(Operation *op = nullptr); // NOLINT
57
58 explicit TFOp(Operation &op) : TFOp(&op) {}
59
60 // Support LLVM-style RTTI.
61 static bool classof(Operation *op) {
62 return isa<TFGraphDialect>(op->getDialect());
63 }
64
65 // Get the wrapped operation.
66 Operation *getOperation() { return op_; }
67
68 // Returns a pointer to the TensorFlow Graph Dialect. It nevers returns
69 // nullptr.
70 TFGraphDialect *getDialect() {
71 return cast<TFGraphDialect>(op_->getDialect());
72 }
73
74 // Split the operands into data and control operands.
75 std::pair<OperandRange, OperandRange> splitOperands() {
76 ControlType ctl_type = getDialect()->getControlType();
77 return SplitDataAndControlValues(op_->getOperands(), ctl_type);
78 }
79
80 // Returns the regular operands, the control operands will be excluded.
81 OperandRange getNonControlOperands() { return splitOperands().first; }
82
83 // The control operands are always after the regular inputs.
84 OperandRange getControlOperands() { return splitOperands().second; }
85
86 // Returns the control token produced by this operation.
87 Value controlRet() { return op_->getResult(op_->getNumResults() - 1); }
88
89 // Returns the non-control results produced by this operation.
90 ResultRange getNonControlResults() {
91 return op_->getResults().slice(0, op_->getNumResults() - 1);
92 }
93
94 // Returns the node name for this operation.
95 StringAttr nameAttr();
96 StringRef name();
97 // Set a new node name for this operation.
98 void setName(const Twine &name);
99 void setName(StringAttr name);
100
101 // Returns the requested device, which is also the "device" field in a
102 // GraphDef.
103 StringAttr requestedDeviceAttr();
104 StringRef requestedDevice();
105 // Set a new requested device for this operation.
106 void setRequestedDevice(const Twine &requested_device);
107 void setRequestedDevice(StringAttr requested_device);
108
109 // Returns the assigned device, this field is set by placer in general.
110 StringAttr assignedDeviceAttr();
111 StringRef assignedDevice();
112 // Set a new assigned device for this operation.
113 void setAssignedDevice(const Twine &assigned_device);
114 void setAssignedDevice(StringAttr assigned_device);
115
116 // Returns the assigned TPU cluster name.
117 StringAttr tpuReplicate();
118 // Set the assigned TPU cluster name.
119 void setTpuReplicate(StringAttr tpu_replicate);
120
121 // Returns the device, preferring the assigned device if set, and the
122 // requested device otherwise.
123 StringAttr deviceAttr() {
124 StringAttr device = assignedDeviceAttr();
125 if (device) {
126 assert(!device.getValue().empty());
127 return device;
128 }
129 return requestedDeviceAttr();
130 }
131 StringRef device() {
132 StringAttr device_attr = deviceAttr();
133 if (device_attr) return device_attr.getValue();
134 return "";
135 }
136
137 // Forward `->` to the underlying operation, exposing the `Operation` methods.
138 Operation *operator->() { return op_; }
139 Operation &operator*() { return *op_; }
140
141 // Converts to true if there is a wrapped operation.
142 explicit operator bool() const { return op_; }
143
144 private:
145 // The wrapped operation.
146 Operation *op_;
147};
148
149// A range iterator to get the control tokens associated with a value range.
150// This range allows to wrap a ValueRange (or an OperandRange) and iterates on
151// the control token associated to the producer of each value. For example, if
152// you wrap the operands of an operation:
153// OperandControlRetRange range = op->getOperands();
154// iterating this range will yield the control edges from each of the operations
155// (or block arguments) producing these operands.
156template <typename ValueRangeT>
157class ControlRetRange final
158 : public llvm::iterator_range<
159 ::mlir::detail::ControlRetIterator<typename ValueRangeT::iterator>> {
160 public:
161 using Base = llvm::iterator_range<
162 ::mlir::detail::ControlRetIterator<typename ValueRangeT::iterator>>;
163 explicit ControlRetRange(ValueRangeT c) : Base(c.begin(), c.end()) {}
164
165 /// Return the value at the given index.
166 Value operator[](size_t index) const {
167 assert(index < size() && "invalid index into value range");
168 return *(this->begin() + index);
169 }
170
171 // Return the size of this range.
172 size_t size() const { return llvm::size(*this); }
173
174 // Return first value in the range.
175 Value front() { return (*this)[0]; }
176
177 // Compare this range with another.
178 template <typename OtherT>
179 bool operator==(const OtherT &other) const {
180 return llvm::size(*this) == llvm::size(other) &&
181 std::equal(this->begin(), this->end(), other.begin());
182 }
183 template <typename OtherT>
184 bool operator!=(const OtherT &other) const {
185 return !(*this == other);
186 }
187};
188
189using OperandControlRetRange = ControlRetRange<OperandRange>;
190using ValueControlRetRange = ControlRetRange<ValueRange>;
191
192} // namespace tfg
193} // namespace mlir
194
195#endif // TENSORFLOW_CORE_IR_TF_OP_WRAPPER_H_
196