/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "absl/types/optional.h"
#include "llvm/Support/Casting.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Pass/PassManager.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "mlir/Transforms/RegionUtils.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/dtensor/cc/constants.h"
#include "tensorflow/dtensor/cc/dstatus.h"
#include "tensorflow/dtensor/cc/tensor_layout.h"
#include "tensorflow/dtensor/mlir/device_utils.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
#include "tensorflow/dtensor/mlir/dtensor_send_recv.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"
#include "tensorflow/dtensor/mlir/value_utils.h"

namespace tensorflow {
namespace dtensor {

namespace {
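// Defining GEN_PASS_DEF_DTENSORLOWERSENDRECV before including the generated
// pass header pulls in impl::DTensorLowerSendRecvBase, the tablegen-generated
// base class used by the pass below.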
#define GEN_PASS_DEF_DTENSORLOWERSENDRECV
#include "tensorflow/dtensor/mlir/dtensor_passes.h.inc"

constexpr char kMissingMeshErrorMsg[] =
    "Failed to extract mesh for DTensorLowerSendRecv pass. "
    "All clusters must have a specified mesh.";

// Extracts mesh from `cluster`.
mlir::LogicalResult ExtractMeshFromCluster(mlir::tf_device::ClusterOp cluster,
                                           Mesh* mesh_output) {
  auto mesh_or_status = ExtractDeviceMeshFromOp(cluster);
  if (!mesh_or_status.ok()) return cluster.emitOpError(kMissingMeshErrorMsg);

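  // An empty optional from ExtractDeviceMeshFromOp means the cluster carries
  // no mesh attribute; treat that the same as an extraction failure.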
  const absl::optional<Mesh>& mesh_or_null = *mesh_or_status;
  if (!mesh_or_null.has_value())
    return cluster.emitOpError(kMissingMeshErrorMsg);

  *mesh_output = mesh_or_null.value();
  return mlir::success();
}

// Finds all DTensorSend/DTensorRecv ops and lowers them into TF/XLA Send/Recv
// operations with execution kernels.
mlir::LogicalResult LowerDTensorSendRecvsOps(mlir::ModuleOp module) {
  mlir::LogicalResult result = mlir::success();
  module.walk([&](mlir::TF::DTensorSend send_op) {
    if (mlir::failed(result)) return;

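    // DTensorSend/DTensorRecv pairs are matched across the module by their
    // shared key attribute; a send without a matching recv is an error.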
    auto recv_op = GetCorrespondingDTensorSendRecvOp<mlir::TF::DTensorSend>(
        module, send_op);
    if (!recv_op.ok()) {
      result = send_op.emitOpError(recv_op.status().error_message());
      return;
    }
    auto dtensor_recv = llvm::dyn_cast<mlir::TF::DTensorRecv>(*recv_op);
    if (!dtensor_recv) {
      result = send_op.emitOpError(
          "Cannot find a matching DTensorRecv op for this DTensorSend op");
      return;
    }
    const Mesh recv_mesh = dtensor_recv.layout().mesh();

    Mesh send_mesh;
    if (mlir::failed(ExtractMeshFromCluster(
            send_op->getParentOfType<mlir::tf_device::ClusterOp>(),
            &send_mesh))) {
      result = mlir::failure();
      return;
    }

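    // Cross-mesh transfers are currently lowered only through XLA host
    // transfer ops, so at least one side of the transfer must be a TPU (XLA)
    // mesh.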
    if (!send_mesh.is_tpu_mesh() && !recv_mesh.is_tpu_mesh()) {
      result = send_op->emitOpError(
          "Multi-mesh tensor transfers between non-XLA devices are not yet "
          "supported.");
      return;
    }

    const Layout recv_layout =
        Layout::ReplicatedOnMesh(recv_mesh, ValueRank(dtensor_recv.output()));
    const Layout send_input_layout =
        Layout::ReplicatedOnMesh(send_mesh, ValueRank(send_op.input()));

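    // Lower the recv side first; uses of the DTensorRecv result are redirected
    // to the lowered op before the original op is erased.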
    StatusOr<mlir::Operation*> lowered_recv =
        LowerDTensorRecvToXlaOp(dtensor_recv);
    if (!lowered_recv.ok()) {
      result = dtensor_recv->emitOpError(lowered_recv.status().error_message());
      return;
    }
    dtensor_recv->replaceAllUsesWith(*lowered_recv);
    dtensor_recv.erase();

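    // Lower the send side. The DTensorSend op produces no results, so there
    // are no uses to rewire.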
    auto lowered_send_or =
        LowerDTensorSendToXlaOp(send_input_layout, send_op.input(), send_op,
                                /*from_spmd_expander=*/false);
    if (!lowered_send_or.ok()) {
      result = send_op->emitOpError(lowered_send_or.status().error_message());
      return;
    }
  });
  return result;
}

// Adds an Identity op that uses the device_id argument as input for clusters
// that do not have any device_id usages. When send/recv operations exist in
// tf_device.Cluster ops to transfer data across mesh clusters, the device_id
// argument is required. However, the mlir::func::FuncOp's created by
// transforming tf_device.Cluster to tf_device.ClusterFunc during the
// ClusterOutlining pass will **not** include device_id as an input argument if
// there are no usages within the cluster op body. Therefore, add an Identity
// op that uses the device_id argument from the main function in all
// tf_device.Cluster ops so that the device_id argument is retained when
// converting tf_device.Cluster to functions.
void PropagateDeviceIdToClusters(mlir::ModuleOp module) {
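  // Walk the module to check whether any cross-mesh/host transfer ops exist;
  // only in that case does device_id need to be threaded into every cluster.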
  mlir::WalkResult result = module.walk([&](mlir::Operation* op) {
    if (llvm::isa<mlir::TF::_XlaSendFromHostOp, mlir::TF::_XlaRecvAtHostV2Op,
                  mlir::TF::XlaSendToHostOp, mlir::TF::XlaRecvFromHostOp,
                  mlir::TF::_HostSendOp, mlir::TF::_HostRecvOp,
                  mlir::TF::SendOp, mlir::TF::RecvOp>(op))
      return mlir::WalkResult::interrupt();
    return mlir::WalkResult::advance();
  });

  const bool has_cross_mesh_send_recv = result.wasInterrupted();
  if (!has_cross_mesh_send_recv) return;

  mlir::func::FuncOp main_func =
      module.lookupSymbol<mlir::func::FuncOp>("main");
  if (!main_func) return;

  auto device_id = DeviceId(main_func);
  if (!device_id.ok()) return;

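  // Insert an Identity op on device_id at the start of each cluster body so
  // that the argument is referenced inside the cluster.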
  module.walk([&](mlir::tf_device::ClusterOp op) {
    mlir::OpBuilder builder(&op.GetBody().front());
    builder.create<mlir::TF::IdentityOp>(main_func.getLoc(),
                                         device_id->getType(), *device_id);
  });
}

// Pass that lowers DTensorSend/DTensorRecv ops used for cross-mesh transfer
// into TF/XLA Send/Recv ops, and propagates the device_id argument into all
// tf_device.Cluster ops so that it is retained during cluster outlining.
struct DTensorLowerSendRecv
    : public impl::DTensorLowerSendRecvBase<DTensorLowerSendRecv> {
  void runOnOperation() override {
    mlir::MLIRContext& context = getContext();
    mlir::OpBuilder op_builder(&context);
    auto module = getOperation();

    // Earlier passes that merge clusters and decompose control flow may have
    // created new DTensorSend/DTensorRecv ops; lower any such ops here.
    if (mlir::failed(LowerDTensorSendRecvsOps(module)))
      return signalPassFailure();

    // Ensure that every mesh cluster has at least one usage of the device_id
    // argument from the main function, so that the argument is retained after
    // ClusterOutlining.
    PropagateDeviceIdToClusters(module);
  }
};

}  // namespace

std::unique_ptr<mlir::OperationPass<mlir::ModuleOp>>
CreateDTensorLowerSendRecv() {
  return std::make_unique<DTensorLowerSendRecv>();
}

}  // namespace dtensor
}  // namespace tensorflow