/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <string>

#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/IR/Block.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/UseDefLists.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_attributes.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/dtensor/cc/constants.h"
#include "tensorflow/dtensor/mlir/dtensor_dialect/ir/dialect.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"

namespace tensorflow {
namespace dtensor {

namespace {
#define GEN_PASS_DEF_DTENSORHANDLECROSSCLUSTERDEPENDENCIES
#include "tensorflow/dtensor/mlir/dtensor_passes.h.inc"

constexpr char kMissingMeshErrorMsg[] =
    "Failed to extract mesh for DTensorHandleCrossClusterDependencies pass. "
    "All clusters must have a specified mesh.";

constexpr char kInvalidTensorTransferErrorMsg[] =
    "CopyToMeshOp must be used to send data across meshes.";

constexpr char kInvalidLayoutMsg[] =
    "Found CopyToMesh op with invalid layout: {0}.";

// Extracts the mesh from `cluster`. Emits an error if the cluster does not
// have a specified mesh.
mlir::LogicalResult ExtractMeshFromCluster(mlir::tf_device::ClusterOp cluster,
                                           Mesh* mesh_output) {
  auto mesh_or_status = ExtractDeviceMeshFromOp(cluster);
  if (!mesh_or_status.ok()) return cluster.emitOpError(kMissingMeshErrorMsg);

  const auto& mesh_or_null = mesh_or_status.value();
  if (!mesh_or_null.has_value())
    return cluster.emitOpError(kMissingMeshErrorMsg);

  *mesh_output = mesh_or_null.value();
  return mlir::success();
}

// Returns the Const op if `op` is a Const op, or a DTensorLayout op whose
// input is produced by a Const op. Returns nullptr otherwise.
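// For illustration, a sketch of the second case in textual IR (value names
// and attributes are hypothetical):
//
//   %cst = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
//   %0 = "tf.DTensorLayout"(%cst) {...} : (tensor<f32>) -> tensor<f32>
//
// Calling GetConstOp on the defining op of %0 returns the tf.Const op.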
mlir::Operation* GetConstOp(mlir::Operation* op) {
  if (llvm::isa<mlir::TF::ConstOp>(op)) return op;

  if (auto layout = llvm::dyn_cast<mlir::TF::DTensorLayout>(op)) {
    mlir::Operation* input_op = layout.input().getDefiningOp();
    if (input_op && llvm::isa<mlir::TF::ConstOp>(input_op)) return input_op;
  }
  return nullptr;
}

// Creates a clone of `const_op` at the beginning of the `cluster` body region
// and replaces uses of the CopyToMesh op's output within `cluster` with the
// output value of the cloned op.
mlir::LogicalResult CloneOpToCluster(mlir::Operation* const_op,
                                     mlir::tf_device::ClusterOp cluster,
                                     mlir::OpOperand* operand) {
  auto copy_to_mesh =
      llvm::dyn_cast<mlir::TF::CopyToMeshOp>(operand->getOwner());
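  // The caller only invokes this function for operands owned by a CopyToMesh
  // op, so the cast is expected to succeed.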
  assert(copy_to_mesh);
  const std::string layout_attr = copy_to_mesh.layout().str();
  StatusOr<Layout> layout = Layout::FromString(layout_attr);
  if (!layout.ok())
    return copy_to_mesh.emitOpError(
        llvm::formatv(kInvalidLayoutMsg, layout_attr));

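  // Clone the constant at the top of the cluster body and attach a
  // DTensorLayout op that records the target layout and static shape of the
  // cloned value.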
  mlir::OpBuilder builder(&cluster.GetBody().front());
  mlir::Operation* cloned_op = builder.clone(*const_op);
  mlir::TensorType type =
      cloned_op->getResult(0).getType().cast<mlir::TensorType>();
  auto layout_op = builder.create<mlir::TF::DTensorLayout>(
      const_op->getLoc(), cloned_op->getResult(0),
      mlir::dtensor::LayoutAttr::get(builder.getContext(), *layout),
      mlir::TF::ShapeAttr::get(builder.getContext(), type));

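  // Redirect only those uses of the CopyToMesh result that live inside
  // `cluster` to the cloned constant (through the newly created layout op).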
  copy_to_mesh.output().replaceUsesWithIf(
      layout_op.output(), [&](mlir::OpOperand& operand) {
        return cluster.getOperation()->isProperAncestor(operand.getOwner());
      });

  if (copy_to_mesh->getUsers().empty()) copy_to_mesh.erase();

  return mlir::success();
}

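// Sets `val_output` to the value that produces `operand`, looking through
// tf_device.cluster outputs to the underlying TensorFlow op result. If the
// operand is not an op result (e.g. a block argument), `val_output` is left
// unset.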
mlir::LogicalResult GetInputProducingValue(mlir::OpOperand& operand,
                                           mlir::Value* val_output) {
  auto input_value = operand.get().dyn_cast<mlir::OpResult>();
  if (!input_value) return mlir::success();

  auto input_cluster =
      llvm::dyn_cast<mlir::tf_device::ClusterOp>(input_value.getOwner());
  if (input_cluster) {
    // If the value is an output of another tf_device.cluster, query the input
    // cluster's terminator to get the mlir::Value from the TensorFlow
    // operation that actually produces it.
    *val_output = input_cluster.GetBody().getTerminator()->getOperand(
        input_value.getResultNumber());
  } else {
    *val_output = input_value;
  }
  return mlir::success();
}

// Clones constant operations into mesh clusters when a constant is used
// across multiple mesh computations. This is needed for two reasons:
// a) Cloning constants across meshes can reduce send/recvs during execution.
// b) DTensor SPMD expansion for some ops (like tf.reduce_sum) requires the
//    inputs to the computation to be constants.
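//
// For illustration, a sketch of the transformation in textual IR (mesh and
// layout strings are hypothetical):
//
// Before:
//   %0 = "tf_device.cluster"() ({
//     %1 = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
//     tf_device.return %1 : tensor<i32>
//   }) {_mesh="mesh:CPU,x=2,y=2"} : () -> (tensor<i32>)
//   %2 = "tf_device.cluster"() ({
//     %3 = "tf.CopyToMesh"(%0) {layout = "..."} : (tensor<i32>) -> tensor<i32>
//     tf_device.return %3 : tensor<i32>
//   }) {_mesh="mesh:TPU,x=2,y=2"} : () -> (tensor<i32>)
//
// After, the constant is cloned into the consuming cluster (wrapped in a
// tf.DTensorLayout op carrying the target layout), so no cross-mesh transfer
// of the constant is needed.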
mlir::LogicalResult CloneConstantsAcrossMesh(
    mlir::tf_device::ClusterOp cluster) {
  auto& body_region = cluster.getBody();
  Mesh mesh;
  if (mlir::failed(ExtractMeshFromCluster(cluster, &mesh)))
    return mlir::failure();

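  // Visit every value that is used inside the cluster body but defined above
  // it; each such use is a potential cross-mesh dependency.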
  mlir::LogicalResult result(mlir::success());
  mlir::visitUsedValuesDefinedAbove(
      body_region, body_region, [&](mlir::OpOperand* operand) {
        if (mlir::failed(result)) return;

        mlir::Value input_value;
        result = GetInputProducingValue(*operand, &input_value);
        if (mlir::failed(result) || !input_value) return;

        auto input_cluster =
            input_value.getDefiningOp()
                ->getParentOfType<mlir::tf_device::ClusterOp>();
        Mesh input_mesh;
        if (mlir::failed(ExtractMeshFromCluster(input_cluster, &input_mesh))) {
          result = mlir::failure();
          return;
        }

        if (input_mesh == mesh) return;
        if (!llvm::isa<mlir::TF::CopyToMeshOp>(operand->getOwner())) {
          result =
              operand->getOwner()->emitOpError(kInvalidTensorTransferErrorMsg);
          return;
        }

        mlir::Operation* const_op = GetConstOp(input_value.getDefiningOp());
        if (const_op) result = CloneOpToCluster(const_op, cluster, operand);
      });

  return result;
}

// Transforms a CopyToMesh op into a pair of DTensorSend/DTensorRecv
// operations.
mlir::LogicalResult LowerToSendRecv(mlir::TF::CopyToMeshOp copy_to_mesh,
                                    mlir::MLIRContext* context,
                                    int* send_recv_counter) {
  const mlir::OpResult copied_value =
      copy_to_mesh.input().cast<mlir::OpResult>();
  const int result_index = copied_value.getResultNumber();
  auto src_cluster =
      llvm::cast<mlir::tf_device::ClusterOp>(copied_value.getDefiningOp());
  mlir::Value value_to_send =
      src_cluster.GetBody().getTerminator()->getOperand(result_index);

  // Create a DTensorSend op that sends `value_to_send` to the target mesh
  // cluster.
  mlir::OpBuilder builder(value_to_send.getParentBlock()->getTerminator());

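  // Construct a communication key that uniquely pairs this DTensorSend with
  // its matching DTensorRecv. The counter keeps keys distinct when multiple
  // transfers share the same layout.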
  const std::string op_key =
      llvm::formatv("communication_key_{0}_{1}", copy_to_mesh.layout(),
                    *send_recv_counter)
          .str();
  const std::string layout_attr = copy_to_mesh.layout().str();
  auto layout_or_status = Layout::FromString(layout_attr);
  if (!layout_or_status.ok())
    return copy_to_mesh.emitOpError(
        llvm::formatv(kInvalidLayoutMsg, layout_attr));

  // Create a send op that sends data from the input cluster to the target
  // cluster.
  const Layout& target_layout = layout_or_status.value();
  builder.create<mlir::TF::DTensorSend>(
      copy_to_mesh.getLoc(), value_to_send, builder.getStringAttr(op_key),
      mlir::dtensor::LayoutAttr::get(context, target_layout));

  // Create a recv op that receives the data from the send op.
  auto tensor_type = value_to_send.getType().dyn_cast<mlir::TensorType>();
  if (!tensor_type)
    return copy_to_mesh.emitOpError(
        "found CopyToMesh sending value with unknown shape. Inputs to "
        "CopyToMesh op must have static shape.");

  builder.setInsertionPoint(copy_to_mesh);
  auto recv_op = builder.create<mlir::TF::DTensorRecv>(
      copy_to_mesh.getLoc(), value_to_send.getType(),
      builder.getStringAttr(op_key),
      mlir::TF::ShapeAttr::get(context, tensor_type),
      mlir::dtensor::LayoutAttr::get(context, target_layout));

  // Replace all uses of the `copy_to_mesh` op with the output of the recv op.
  copy_to_mesh.replaceAllUsesWith(recv_op.output());

  // Remove the now-dead CopyToMesh op.
  copy_to_mesh.erase();

  *send_recv_counter += 1;

  return mlir::success();
}

// Lowers tf.CopyToMesh to a pair of DTensorSend/DTensorRecv operations.
//
// For example:
//
//   %0 = "tf_device.cluster"() ({
//     %1 = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
//     tf_device.return %1 : tensor<i32>
//   }) {_mesh="mesh:CPU,x=2,y=2"} : () -> (tensor<i32>)
//
//   %2 = "tf_device.cluster"() ({
//     %3 = "tf.CopyToMesh"(%0)
//       {layout = "mesh:TPU,x=2,y=2 layout:x,replicated"} :
//       (tensor<i32>) -> (tensor<i32>)
//     %4 = "tf.Neg"(%3) : (tensor<i32>) -> tensor<i32>
//     tf_device.return %4 : tensor<i32>
//   }) {_mesh="mesh:TPU,x=2,y=2"} : () -> (tensor<i32>)
//   return
//
// is transformed into:
//
//   %0 = "tf_device.cluster"() ({
//     %1 = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
//     "tf.DTensorSend"(%1) {...} : (tensor<i32>) -> ()
//     tf_device.return %1 : tensor<i32>
//   }) {_mesh="mesh:CPU,x=2,y=2"} : () -> (tensor<i32>)
//
//   %2 = "tf_device.cluster"() ({
//     %3 = "tf.DTensorRecv"() {...} : () -> (tensor<i32>)
//     %4 = "tf.Neg"(%3) : (tensor<i32>) -> tensor<i32>
//     tf_device.return %4 : tensor<i32>
//   }) {_mesh="mesh:TPU,x=2,y=2"} : () -> (tensor<i32>)
//   return
mlir::LogicalResult ReplaceCopyToMeshWithVirtualSendRecv(
    mlir::tf_device::ClusterOp cluster, mlir::MLIRContext* context,
    int* send_recv_counter) {
  Mesh current_mesh;
  if (mlir::failed(ExtractMeshFromCluster(cluster, &current_mesh)))
    return mlir::failure();

  mlir::Region& cluster_region = cluster.getBody();
  mlir::LogicalResult result = mlir::success();

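  // As in CloneConstantsAcrossMesh, visit every value used inside the cluster
  // but defined above it, and lower each cross-mesh use of CopyToMesh.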
  mlir::visitUsedValuesDefinedAbove(
      cluster_region, cluster_region, [&](mlir::OpOperand* operand) {
        mlir::Value input_value;
        if (mlir::failed(GetInputProducingValue(*operand, &input_value))) {
          result = mlir::failure();
          return;
        }
        if (!input_value) return;

        auto input_cluster =
            input_value.getDefiningOp()
                ->getParentOfType<mlir::tf_device::ClusterOp>();
        Mesh input_mesh;
        if (mlir::failed(ExtractMeshFromCluster(input_cluster, &input_mesh))) {
          result = mlir::failure();
          return;
        }

        if (current_mesh == input_mesh) return;

        // Check that values that cross mesh boundaries go through a
        // CopyToMesh op.
        mlir::Operation* input_op = operand->getOwner();
        mlir::TF::CopyToMeshOp copy_to_mesh =
            llvm::dyn_cast<mlir::TF::CopyToMeshOp>(input_op);
        if (!copy_to_mesh) {
          result =
              operand->getOwner()->emitOpError(kInvalidTensorTransferErrorMsg);
          return;
        }

        // Lower the CopyToMesh op to a pair of virtual Send/Recv ops.
        if (mlir::failed(
                LowerToSendRecv(copy_to_mesh, context, send_recv_counter))) {
          result = mlir::failure();
          return;
        }
      });
  return result;
}

struct DTensorHandleCrossClusterDependencies
    : public impl::DTensorHandleCrossClusterDependenciesBase<
          DTensorHandleCrossClusterDependencies> {
  void getDependentDialects(mlir::DialectRegistry& registry) const override {
    registry.insert<mlir::dtensor::DTensorDialect>();
  }

  void runOnOperation() override {
    mlir::MLIRContext& context = getContext();
    mlir::ModuleOp module = getOperation();
    llvm::SmallVector<mlir::tf_device::ClusterOp, 4> clusters;
    module.walk([&](mlir::tf_device::ClusterOp cluster) {
      clusters.emplace_back(cluster);
    });

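    // Monotonically increasing counter used to generate unique communication
    // keys for each lowered send/recv pair.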
    int send_recv_counter = 0;
    for (auto cluster : clusters) {
      if (mlir::failed(CloneConstantsAcrossMesh(cluster)))
        return signalPassFailure();

      if (mlir::failed(ReplaceCopyToMeshWithVirtualSendRecv(
              cluster, &context, &send_recv_counter)))
        return signalPassFailure();
    }

    // Once CopyToMesh ops have been lowered to DTensorSend/DTensorRecv
    // operations, tf_device.cluster ops may have dangling/unused result
    // values. Remove all such return values.
    for (auto cluster : clusters) RemoveUnusedClusterResults(cluster);
  }
};

}  // namespace

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorHandleCrossClusterDependencies() {
  return std::make_unique<DTensorHandleCrossClusterDependencies>();
}

}  // namespace dtensor
}  // namespace tensorflow