1/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "llvm/ADT/STLExtras.h"
17#include "llvm/ADT/SmallVector.h"
18#include "llvm/ADT/StringRef.h"
19#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
20#include "mlir/IR/Attributes.h" // from @llvm-project
21#include "mlir/IR/Builders.h" // from @llvm-project
22#include "mlir/IR/BuiltinOps.h" // from @llvm-project
23#include "mlir/IR/Diagnostics.h" // from @llvm-project
24#include "mlir/IR/Operation.h" // from @llvm-project
25#include "mlir/IR/Types.h" // from @llvm-project
26#include "mlir/Support/LogicalResult.h" // from @llvm-project
27#include "mlir/Transforms/Passes.h" // from @llvm-project
28#include "mlir/Transforms/RegionUtils.h" // from @llvm-project
29#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
30#include "tensorflow/dtensor/cc/constants.h"
31#include "tensorflow/dtensor/cc/tensor_layout.h"
32#include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
33#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
34#include "tensorflow/dtensor/mlir/layout_parsing.h"
35
36namespace tensorflow {
37namespace dtensor {
38
39namespace {
40#define GEN_PASS_DEF_DTENSORDEVICEMESHCLUSTERCOARSENING
41#include "tensorflow/dtensor/mlir/dtensor_passes.h.inc"
42
// Diagnostic attached to a tf_device.cluster that takes part in coarsening but
// carries no mesh attribute.
constexpr char kMissingMeshAttributeErrorMessage[] =
    "failed to merge mesh cluster as cluster does not have mesh attribute. This "
    "is likely due to problem in mesh propagation.";
46
47// Determines whether two adjoining clusters should be merged.
48mlir::LogicalResult ShouldMergeClusters(mlir::tf_device::ClusterOp cluster_a,
49 mlir::tf_device::ClusterOp cluster_b,
50 bool* should_merge) {
51 if (cluster_a->getParentRegion() != cluster_b->getParentRegion()) {
52 *should_merge = false;
53 return mlir::success();
54 }
55
56 auto mesh_a_or_status = ExtractDeviceMeshFromOp(cluster_a.getOperation());
57 if (!mesh_a_or_status.ok())
58 return cluster_a.emitOpError(mesh_a_or_status.status().error_message());
59
60 auto mesh_b_or_status = ExtractDeviceMeshFromOp(cluster_b.getOperation());
61 if (!mesh_b_or_status.ok())
62 return cluster_b.emitOpError(mesh_b_or_status.status().error_message());
63
64 auto mesh_a = mesh_a_or_status.value();
65 auto mesh_b = mesh_b_or_status.value();
66 if (!mesh_a || !mesh_b) {
67 return !mesh_a ? cluster_a.emitOpError(kMissingMeshAttributeErrorMessage)
68 : cluster_b.emitOpError(kMissingMeshAttributeErrorMessage);
69 }
70
71 *should_merge = mesh_a == mesh_b;
72 return mlir::success();
73}
74
75// Moves all ops (except tf_device.return op) inside `src_cluster` to
76// block inside `target_cluster`. Ops are moved before the `exit_op`
77// inside the `target_cluster`.
78void MoveOpsInsideCluster(mlir::tf_device::ClusterOp src_cluster,
79 mlir::tf_device::ClusterOp target_cluster,
80 mlir::Operation* exit_op) {
81 auto& cluster_body = src_cluster.GetBody().getOperations();
82 target_cluster.GetBody().getOperations().splice(
83 exit_op->getIterator(), cluster_body, cluster_body.begin(),
84 std::prev(cluster_body.end()));
85}
86
87// Returns a list of pair of mlir Values that represent <return values of ops
88// inside the merged_cluster, output values of merged cluster>.
89//
90// If outputs of `current_cluster` is used as operands to ops in
91// `merging_cluster`, then make sure to replace operands such that
92// results values from the inner ops of `current_cluster` is used instead.
93//
94// For example,
95// %0 = "tf_device.cluster"() ({
96// %1 = "tf.A"() : () -> tensor<i32>
97// "tf_device.return"(%1) : (tensor<i32>) -> ()
98// }) { mesh = "mesh_config: cpu[1, 1]"} : () -> (tensor<i32>)
99//
100// %2 = "tf_device.cluster"() ({
101// %3 = "tf.B"(%0) : (tenosr<i32>) -> tensor<f32>
102// "tf_device.return"(%3) : (tensor<f32>) -> ()
103// }) { mesh = "mesh_config: cpu[1, 1]"} : () -> (tensor<f32>)
104//
105// will be:
106// %0 = "tf_device.cluster"() ({
107// %1 = "tf.A"() : () -> tensor<i32>
108//
109// # NOTE: `tf.B` op now takes operand directly from
110// # `tf.A` instead of `tf_dtensor.cluster op.
111// %2 = "tf.B"(%1) : (tenosr<i32>) -> tensor<f32>
112// "tf_device.return"(%1, %2) : (tensor<i32>, tensor<f32>)) -> ()
113// }) {mesh = "mesh_config: cpu[1, 1]"} : () -> (tensor<i32>, tensor<f32>)
114llvm::SmallVector<std::pair<mlir::Value, mlir::Value>, 8>
115GetMergedMeshClusterResults(mlir::tf_device::ClusterOp current_cluster,
116 mlir::tf_device::ClusterOp merging_cluster) {
117 llvm::SmallVector<std::pair<mlir::Value, mlir::Value>, 8>
118 merged_cluster_results;
119 merged_cluster_results.reserve(current_cluster.getNumResults() +
120 merging_cluster.getNumResults());
121
122 auto current_cluster_return_op = current_cluster.GetBody().getTerminator();
123 for (auto result : llvm::zip(current_cluster_return_op->getOpOperands(),
124 current_cluster.getResults())) {
125 mlir::Value inner_op_result = std::get<0>(result).get();
126 mlir::Value outer_op_result = std::get<1>(result);
127
128 // If the output value of `current_cluster` is only used by ops
129 // inside the `merged_cluster`, do not add the value as a return
130 // value for newly created tf_device.cluster op.
131 bool result_only_used_by_merging_cluster = true;
132 for (auto& use : llvm::make_early_inc_range(outer_op_result.getUses())) {
133 if (merging_cluster.GetBody().findAncestorOpInBlock(*use.getOwner())) {
134 use.set(inner_op_result);
135 } else {
136 result_only_used_by_merging_cluster = false;
137 }
138 }
139
140 if (!result_only_used_by_merging_cluster) {
141 merged_cluster_results.emplace_back(inner_op_result, outer_op_result);
142 }
143 }
144
145 auto merging_cluster_return_op = merging_cluster.GetBody().getTerminator();
146 for (auto result : llvm::zip(merging_cluster_return_op->getOpOperands(),
147 merging_cluster.getResults())) {
148 mlir::Value inner_op_result = std::get<0>(result).get();
149 mlir::Value outer_op_result = std::get<1>(result);
150
151 if (!outer_op_result.getUses().empty())
152 merged_cluster_results.emplace_back(inner_op_result, outer_op_result);
153 }
154
155 return merged_cluster_results;
156}
157
158// Updates the users of `merging_cluster` so that they use values
159// from `merged_cluster` instead.
160void ReplaceOperandUsagesWithMergedClusterOutputs(
161 const llvm::SmallVectorImpl<mlir::Value>& values_to_replace,
162 mlir::tf_device::ClusterOp merged_cluster) {
163 for (auto result :
164 llvm::zip(values_to_replace, merged_cluster.getResults())) {
165 std::get<0>(result).replaceAllUsesWith(std::get<1>(result));
166 }
167}
168
169// Creates a new tf_device.cluster op that merges
170// `current_cluster` and `merging_cluster`.
171mlir::LogicalResult CreateMergedMeshCluster(
172 mlir::OpBuilder* builder, mlir::tf_device::ClusterOp current_cluster,
173 mlir::tf_device::ClusterOp merging_cluster,
174 mlir::tf_device::ClusterOp* merged_cluster) {
175 auto return_values =
176 GetMergedMeshClusterResults(current_cluster, merging_cluster);
177
178 llvm::SmallVector<mlir::Type, 8> merged_cluster_output_types;
179 llvm::SmallVector<mlir::Value, 8> merged_cluster_output_values;
180 llvm::SmallVector<mlir::Value, 8> output_values_to_replace;
181 merged_cluster_output_types.reserve(return_values.size());
182 merged_cluster_output_values.reserve(return_values.size());
183 output_values_to_replace.reserve(return_values.size());
184 for (auto cluster_return_value : return_values) {
185 auto inner_op_return_value = std::get<0>(cluster_return_value);
186 merged_cluster_output_types.emplace_back(inner_op_return_value.getType());
187 merged_cluster_output_values.emplace_back(inner_op_return_value);
188 output_values_to_replace.emplace_back(std::get<1>(cluster_return_value));
189 }
190
191 *merged_cluster = builder->create<mlir::tf_device::ClusterOp>(
192 current_cluster.getLoc(), merged_cluster_output_types);
193 auto mesh_attr = current_cluster->getAttrOfType<mlir::StringAttr>(kMeshAttr);
194 if (!mesh_attr)
195 return current_cluster.emitOpError(kMissingMeshAttributeErrorMessage);
196
197 (*merged_cluster)->setAttr(kMeshAttr, mesh_attr);
198
199 // Create a terminator op that returns all return values from
200 // `current_cluster` and `merging_cluster`.
201 merged_cluster->getBody().push_back(new mlir::Block);
202 builder->setInsertionPointToEnd(&merged_cluster->GetBody());
203 builder->create<mlir::tf_device::ReturnOp>(merged_cluster->getLoc(),
204 merged_cluster_output_values);
205
206 // Make sure to replace usages of tf_device.cluster ops to be merged-away with
207 // newly created tf_device.cluster op.
208 ReplaceOperandUsagesWithMergedClusterOutputs(output_values_to_replace,
209 *merged_cluster);
210
211 return mlir::success();
212}
213
214// Merges `current_cluster` and `merging_cluster` and returns a new merged
215// tf_device.cluster.
216mlir::LogicalResult MergeClusters(mlir::OpBuilder* builder,
217 mlir::tf_device::ClusterOp current_cluster,
218 mlir::tf_device::ClusterOp merging_cluster,
219 mlir::tf_device::ClusterOp* merged_cluster) {
220 builder->setInsertionPoint(current_cluster);
221
222 // Create new tf_device.cluster op that outputs results of both
223 // `current_cluster` and `merging_cluster`.
224 if (mlir::failed(CreateMergedMeshCluster(builder, current_cluster,
225 merging_cluster, merged_cluster)))
226 return mlir::failure();
227
228 // Move all ops to newly created merged cluster.
229 auto exit_op = merged_cluster->GetBody().getTerminator();
230 MoveOpsInsideCluster(current_cluster, *merged_cluster, exit_op);
231 MoveOpsInsideCluster(merging_cluster, *merged_cluster, exit_op);
232
233 // Remove mesh clusters as they are now merged to a new cluster.
234 current_cluster.erase();
235 merging_cluster.erase();
236 return mlir::success();
237}
238
239// Loops through tf_device.Cluster ops and merge clusters with same execution
240// device set.
241mlir::LogicalResult ClusterDeviceClusterOpsInBlock(mlir::OpBuilder* builder,
242 mlir::Block* block) {
243 llvm::SmallVector<mlir::tf_device::ClusterOp, 4> block_ops;
244 block->walk([&](mlir::Operation* op) {
245 if (auto cluster = llvm::dyn_cast<mlir::tf_device::ClusterOp>(op))
246 block_ops.emplace_back(cluster);
247 });
248
249 llvm::Optional<mlir::tf_device::ClusterOp> current_cluster;
250 for (mlir::tf_device::ClusterOp cluster :
251 llvm::make_early_inc_range(block_ops)) {
252 if (!current_cluster.has_value()) {
253 current_cluster = cluster;
254 continue;
255 }
256 bool should_merge;
257 if (failed(ShouldMergeClusters(*current_cluster, cluster, &should_merge)))
258 return mlir::failure();
259
260 if (should_merge) {
261 mlir::tf_device::ClusterOp new_cluster;
262 if (mlir::failed(
263 MergeClusters(builder, *current_cluster, cluster, &new_cluster)))
264 return mlir::failure();
265
266 current_cluster.emplace(new_cluster);
267 } else {
268 current_cluster.emplace(cluster);
269 }
270 }
271 return mlir::success();
272}
273
274} // namespace
275
276// MLIR pass that merges cluster ops with the same mesh attribute.
277struct DTensorDeviceMeshClusterCoarsening
278 : public impl::DTensorDeviceMeshClusterCoarseningBase<
279 DTensorDeviceMeshClusterCoarsening> {
280 void runOnOperation() override {
281 mlir::MLIRContext& context = getContext();
282 mlir::OpBuilder builder(&context);
283 for (mlir::Block& block : getOperation())
284 if (mlir::failed(ClusterDeviceClusterOpsInBlock(&builder, &block)))
285 return signalPassFailure();
286 }
287};
288
289std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
290CreateDTensorDeviceMeshClusterCoarsening() {
291 return std::make_unique<DTensorDeviceMeshClusterCoarsening>();
292}
293
294} // namespace dtensor
295} // namespace tensorflow
296