/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <algorithm>
#include <memory>
#include <set>
#include <string>
#include <vector>

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/Matchers.h"  // from @llvm-project
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/IR/Visitors.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/str_util.h"
#include "tensorflow/core/platform/strcat.h"
#include "tensorflow/dtensor/cc/constants.h"
#include "tensorflow/dtensor/cc/dtensor_utils.h"
#include "tensorflow/dtensor/mlir/dtensor_location.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
#include "tensorflow/dtensor/mlir/group_assignment.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"

namespace tensorflow {
namespace dtensor {

namespace {
#define GEN_PASS_DEF_DTENSORALLREDUCECOMBINEOPTIMIZATION
#include "tensorflow/dtensor/mlir/dtensor_passes.h.inc"

namespace ops_util = ::mlir::TF::collection_ops_util;

// Pad the merged tensor shape to multiples of 1024B, so the
// delinearization-skipping optimization in XLA can be activated.
constexpr int32 kAllReducePadding = 1024;

// Returns true if `successor` depends on `predecessor`.
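// The check runs a worklist traversal over the transitive users of
// `predecessor` and reports whether `successor` is reachable from it.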
// TODO(jiawenhao): Repeatedly computing dependency sets for a large cluster can
// get expensive when the number of all-reduces is high. Consider building a
// cluster-scope op dependency graph ahead of time to amortize the cost.
bool DependsOn(mlir::Operation* successor, mlir::Operation* predecessor) {
  llvm::SmallVector<mlir::Operation*, 4> to_visit;
  llvm::SmallPtrSet<mlir::Operation*, 4> visited;
  to_visit.push_back(predecessor);
  while (!to_visit.empty()) {
    mlir::Operation* producer = to_visit.pop_back_val();
    if (visited.contains(producer)) continue;
    visited.insert(producer);
    if (successor == producer) return true;
    for (mlir::Operation* user : producer->getUsers()) {
      if (visited.contains(user)) continue;
      to_visit.push_back(user);
    }
  }
  return false;
}

// Moves all usages of `a` (direct and transitive) to right after `b` in
// `cluster`, preserving the original order of moved ops.
// `a` and `b` must be in `cluster`. `a` must appear before `b` originally.
// `a` itself is not moved.
//
// For example, this program:
//
//   tf_device.cluster() ({
//     %a = tf.A()
//     %1 = tf.C(%a)
//     %2 = tf.D(%a)
//     %3 = tf.E(%1, %2)
//     %b = tf.B()
//     %4 = tf.F(%3)
//     %5 = tf.G(%b)
//     tf_device.return()
//   })
//
// will become this:
//
//   tf_device.cluster() ({
//     %a = tf.A()
//     %b = tf.B()
//     %1 = tf.C(%a)
//     %2 = tf.D(%a)
//     %3 = tf.E(%1, %2)
//     %4 = tf.F(%3)
//     %5 = tf.G(%b)
//     tf_device.return()
//   })
void MoveUsagesAfter(mlir::tf_device::ClusterOp cluster, mlir::Operation* a,
                     mlir::Operation* b) {
  llvm::SmallVector<mlir::Operation*, 4> to_visit;
  llvm::SmallPtrSet<mlir::Operation*, 4> visited;
  to_visit.push_back(a);
  while (!to_visit.empty()) {
    mlir::Operation* producer = to_visit.pop_back_val();
    if (visited.contains(producer)) continue;
    visited.insert(producer);
    for (mlir::Operation* user : producer->getUsers()) {
      if (visited.contains(user)) continue;
      to_visit.push_back(user);
    }
  }

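  // Collect the transitive users of `a` that currently appear before `b`. The
  // cluster body is walked in block order, so `to_move` preserves the original
  // program order of those ops.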
  llvm::SmallVector<mlir::Operation*, 4> to_move;
  cluster.GetBody().walk([&](mlir::Operation* op) {
    if (op != a && visited.contains(op) && op->isBeforeInBlock(b)) {
      to_move.push_back(op);
    }
  });

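  // Relocate the collected ops one after another so that they end up directly
  // after `b` in their original relative order.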
  mlir::Operation* last = b;
  for (mlir::Operation* op : to_move) {
    if (mlir::dyn_cast<mlir::TF::YieldOp>(op)) {
      LOG(FATAL) << "Should never move YieldOp";  // Crash OK
    }
    op->moveAfter(last);
    last = op;
  }
}

// Merge all-reduces in the group into one all-reduce.
//
// Requirements:
//   - The group should have at least two all-reduces.
//   - They should be located next to each other in the parent block.
//   - They should all have the same element type.
//   - They should all have the same group assignment.
//
// The merged all-reduce operates on a 1D tensor, whose size is the sum of all
// merged all-reduce tensors padded to 1024B. (The padding is necessary for the
// XLA delinearization skipping logic.) Each to-be-merged all-reduce flattens
// its input tensor and writes the resulting 1D tensor into the corresponding
// offset in the merged 1D tensor. After the merged all-reduce is done, the
// reverse happens: results are sliced out and reshaped to the original shape.
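//
// For example, merging a tensor<4x4xf32> all-reduce with a tensor<8xf32>
// all-reduce (same element type) yields a single all-reduce over a
// tensor<1024xf32>: 16 + 8 = 24 elements, padded up to the next multiple of
// 1024.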
mlir::LogicalResult MergeAllReduceGroup(
    std::vector<mlir::TF::DTensorAllReduceOp>& all_reduce_group) {
  // Create the initial all-zero merged tensor.
  // The merged tensor's size is the sum of all individual all-reduces' sizes.
  int num_all_reduces = all_reduce_group.size();
  DCHECK(num_all_reduces > 1)
      << "All reduce group size expected to be greater than 1.";
  int total_num_elements = 0;
  std::vector<llvm::ArrayRef<int64_t>> all_reduce_shapes;
  all_reduce_shapes.reserve(num_all_reduces);
  for (mlir::TF::DTensorAllReduceOp& all_reduce : all_reduce_group) {
    auto all_reduce_ranked_type =
        all_reduce.getType().dyn_cast<mlir::RankedTensorType>();
    if (!all_reduce_ranked_type || !all_reduce_ranked_type.hasStaticShape()) {
      return all_reduce.emitOpError(llvm::formatv(
          "requires static shape for DTensorAllReduceOp, but got : {0}",
          all_reduce_ranked_type));
    }
    int num_elements = all_reduce_ranked_type.getNumElements();
    total_num_elements += num_elements;
    all_reduce_shapes.push_back(all_reduce_ranked_type.getShape());
  }

  // Pad the merged tensor shape to multiples of 1024B, so the
  // delinearization-skipping optimization in XLA can be activated.
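  // For example, 24 elements are padded up to 1024, and 1500 elements are
  // padded up to 2048.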
  if (total_num_elements % kAllReducePadding != 0) {
    total_num_elements =
        total_num_elements / kAllReducePadding * kAllReducePadding +
        kAllReducePadding;
  }

  // Fill the merged tensor with 0 initially.
  mlir::OpBuilder builder(all_reduce_group[0]);
  mlir::Location loc = all_reduce_group[0].getLoc();
  mlir::Type elem_type = all_reduce_group[0].getType().getElementType();
  auto zero_scalar = ops_util::CreateScalarConst(0, builder, loc);
  auto zero_scalar_elem_type = builder.create<mlir::TF::CastOp>(
      loc, mlir::RankedTensorType::get({}, elem_type), zero_scalar);
  auto merged = builder.create<mlir::TF::FillOp>(
      loc, ops_util::GetR1Const({total_num_elements}, builder, loc),
      zero_scalar_elem_type);

  // Store every all-reduce's input at an offset location in the merged tensor,
  // as a 1D tensor.
  int offset_num_elements = 0;
  std::vector<mlir::Type> flattened_types;
  flattened_types.reserve(num_all_reduces);
  mlir::Value updated;
  for (int i = 0; i < all_reduce_group.size(); ++i) {
    mlir::TF::DTensorAllReduceOp& all_reduce = all_reduce_group[i];
    mlir::Location loc = all_reduce.getLoc();
    auto all_reduce_ranked_type =
        all_reduce.getType().dyn_cast<mlir::RankedTensorType>();
    if (!all_reduce_ranked_type || !all_reduce_ranked_type.hasStaticShape()) {
      return all_reduce.emitOpError(llvm::formatv(
          "requires static shape for DTensorAllReduceOp, but got : {0}",
          all_reduce_ranked_type));
    }

    int num_elements = all_reduce_ranked_type.getNumElements();
    auto flattened = builder.create<mlir::TF::ReshapeOp>(
        DT_LOC2(loc, "CombinedReduceFlatten"), all_reduce.input(),
        ops_util::GetR1Const({num_elements}, builder, loc));
    flattened_types.push_back(flattened.getType());
    auto indices = ops_util::GetR1Const({offset_num_elements}, builder, loc);

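    // On TPU, the flattened tensor is written into the merged buffer with
    // XlaDynamicUpdateSlice; other device types use TensorStridedSliceUpdate
    // to perform the same offset write.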
    if (all_reduce.device_type().contains("TPU")) {
      updated = builder.create<mlir::TF::XlaDynamicUpdateSliceOp>(
          DT_LOC2(loc, "CombinedReduceUpdateSlice"), merged.getType(),
          /*input=*/i == 0 ? merged.getResult() : updated,
          /*update=*/flattened, indices);
    } else {
      auto end = ops_util::GetR1Const({offset_num_elements + num_elements},
                                      builder, loc);
      auto strides = ops_util::GetR1Const({1}, builder, loc);
      updated = builder.create<mlir::TF::TensorStridedSliceUpdateOp>(
          DT_LOC2(loc, "CombinedReduceUpdateSlice"), merged.getType(),
          /*input=*/i == 0 ? merged.getResult() : updated, indices, end,
          strides,
          /*value=*/flattened);
    }
    offset_num_elements += num_elements;
  }

  // All-reduce the updated merged tensor.
  auto merged_all_reduce = builder.create<mlir::TF::DTensorAllReduceOp>(
      all_reduce_group[0].getLoc(), updated.getType(), updated,
      all_reduce_group[0].group_assignment(), all_reduce_group[0].reduce_op(),
      all_reduce_group[0].device_type());
  SetSingleLayoutOnOp(
      merged_all_reduce,
      ExtractSingleLayoutFromOp(all_reduce_group[0]).value().value());

  // Slice out the original all-reduces, and reshape back to the original
  // shape.
  offset_num_elements = 0;
  std::vector<mlir::TF::ReshapeOp> replacements;
  replacements.reserve(num_all_reduces);
  for (int i = 0; i < all_reduce_group.size(); ++i) {
    mlir::TF::DTensorAllReduceOp& all_reduce = all_reduce_group[i];
    mlir::Location loc = all_reduce.getLoc();
    auto all_reduce_ranked_type =
        all_reduce.getType().dyn_cast<mlir::RankedTensorType>();
    if (!all_reduce_ranked_type || !all_reduce_ranked_type.hasStaticShape()) {
      return all_reduce.emitOpError(llvm::formatv(
          "requires static shape for DTensorAllReduceOp, but got : {0}",
          all_reduce_ranked_type));
    }
    int num_elements = all_reduce_ranked_type.getNumElements();
    auto slice = builder.create<mlir::TF::SliceOp>(
        DT_LOC2(loc, "PostCombinedReduceSlice"), flattened_types[i],
        /*input=*/merged_all_reduce,
        /*begin=*/ops_util::GetR1Const({offset_num_elements}, builder, loc),
        /*size=*/ops_util::GetR1Const({num_elements}, builder, loc));
    auto replacement = builder.create<mlir::TF::ReshapeOp>(
        DT_LOC2(loc, "PostCombinedReduceReshape"), slice.getResult(),
        ops_util::GetR1Const(all_reduce_shapes[i], builder, loc));
    replacements.push_back(replacement);
    offset_num_elements += num_elements;
  }

  // Replace usages and clean up.
  for (int i = 0; i < all_reduce_group.size(); ++i) {
    mlir::TF::DTensorAllReduceOp& all_reduce = all_reduce_group[i];
    mlir::TF::ReshapeOp& replacement = replacements[i];
    all_reduce.replaceAllUsesWith(replacement.getResult());
    all_reduce.erase();
  }
  return mlir::success();
}

// Dump the dependencies between AllReduce ops as a DOT graph.
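// For example, with three all-reduces where the third depends on the first,
// the output looks like:
//
//   digraph all_reduces {
//   0
//   1
//   2
//   0 -> 2
//   }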
std::string DrawAllReduceDependencies(
    std::vector<mlir::TF::DTensorAllReduceOp> all_reduces) {
  std::vector<std::vector<int>> dependents(all_reduces.size(),
                                           std::vector<int>());
  for (int j = 0; j < all_reduces.size(); ++j) {
    mlir::TF::DTensorAllReduceOp later = all_reduces[j];
    for (int i = 0; i < j; ++i) {
      mlir::TF::DTensorAllReduceOp earlier = all_reduces[i];
      DCHECK(!DependsOn(earlier, later));
      if (earlier->getBlock() != later->getBlock() ||
          DependsOn(later, earlier)) {
        dependents[i].push_back(j);
      }
    }
  }
  std::string output = "digraph all_reduces {\n";
  for (int i = 0; i < dependents.size(); i++) {
    strings::StrAppend(&output, i);
    strings::StrAppend(&output, "\n");
  }
  for (int i = 0; i < dependents.size(); i++) {
    for (int j : dependents[i]) {
      strings::StrAppend(&output, i, " -> ", j, "\n");
    }
  }
  output += "}";
  return output;
}

// Combine cross-slice DTensorAllReduce ops of the same element type and group
// assignment into as few groups as possible. Only independent ops can be
// combined together.
//
// For example, this program:
//
// clang-format off
// NOLINTBEGIN(whitespace/line_length)
// %0 = "tf_device.cluster"() ({
//   %1 = "tf.Const"() {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
//   %2 = "tf.Const"() {value = dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
//   %3 = "tf.DTensorAllReduce"(%1, %2) {reduce_op = "Add"} : (tensor<4x4xf32>, tensor<2x2xi32>) -> tensor<4x4xf32>
//   %4 = "tf.Const"() {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
//   %5 = "tf.Const"() {value = dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
//   %6 = "tf.DTensorAllReduce"(%4, %5) {reduce_op = "Add"} : (tensor<4x4xf32>, tensor<2x2xi32>) -> tensor<4x4xf32>
//   %7 = "tf.Add"(%3, %6) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
//   "tf_device.return"(%7) : (tensor<4x4xf32>) -> ()
// }) : () -> tensor<4x4xf32>
// NOLINTEND
// clang-format on
//
// will become this:
//
// clang-format off
// NOLINTBEGIN(whitespace/line_length)
// %0 = "tf_device.cluster"() ( {
//   %cst = "tf.Const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
//   %cst_0 = "tf.Const"() {value = dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
//   %cst_1 = "tf.Const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32>
//   %cst_2 = "tf.Const"() {value = dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32>
//   %cst_3 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
//   %1 = "tf.Cast"(%cst_3) {Truncate = false} : (tensor<i32>) -> tensor<f32>
//   %cst_4 = "tf.Const"() {value = dense<1024> : tensor<1xi32>} : () -> tensor<1xi32>
//   %2 = "tf.Fill"(%cst_4, %1) : (tensor<1xi32>, tensor<f32>) -> tensor<1024xf32>
//   %cst_5 = "tf.Const"() {value = dense<16> : tensor<1xi32>} : () -> tensor<1xi32>
//   %3 = "tf.Reshape"(%cst, %cst_5) : (tensor<4x4xf32>, tensor<1xi32>) -> tensor<16xf32>
//   %cst_6 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
//   %4 = "tf.XlaDynamicUpdateSlice"(%2, %3, %cst_6) : (tensor<1024xf32>, tensor<16xf32>, tensor<1xi32>) -> tensor<1024xf32>
//   %cst_7 = "tf.Const"() {value = dense<16> : tensor<1xi32>} : () -> tensor<1xi32>
//   %5 = "tf.Reshape"(%cst_1, %cst_7) : (tensor<4x4xf32>, tensor<1xi32>) -> tensor<16xf32>
//   %cst_8 = "tf.Const"() {value = dense<16> : tensor<1xi32>} : () -> tensor<1xi32>
//   %6 = "tf.XlaDynamicUpdateSlice"(%4, %5, %cst_8) : (tensor<1024xf32>, tensor<16xf32>, tensor<1xi32>) -> tensor<1024xf32>
//   %7 = "tf.DTensorAllReduce"(%6, %cst_0) {reduce_op = "Add"} : (tensor<1024xf32>, tensor<2x2xi32>) -> tensor<1024xf32>
//   %cst_9 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32>
//   %cst_10 = "tf.Const"() {value = dense<16> : tensor<1xi32>} : () -> tensor<1xi32>
//   %8 = "tf.Slice"(%7, %cst_9, %cst_10) : (tensor<1024xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<16xf32>
//   %cst_11 = "tf.Const"() {value = dense<4> : tensor<2xi32>} : () -> tensor<2xi32>
//   %9 = "tf.Reshape"(%8, %cst_11) : (tensor<16xf32>, tensor<2xi32>) -> tensor<4x4xf32>
//   %cst_12 = "tf.Const"() {value = dense<16> : tensor<1xi32>} : () -> tensor<1xi32>
//   %cst_13 = "tf.Const"() {value = dense<16> : tensor<1xi32>} : () -> tensor<1xi32>
//   %10 = "tf.Slice"(%7, %cst_12, %cst_13) : (tensor<1024xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<16xf32>
//   %cst_14 = "tf.Const"() {value = dense<4> : tensor<2xi32>} : () -> tensor<2xi32>
//   %11 = "tf.Reshape"(%10, %cst_14) : (tensor<16xf32>, tensor<2xi32>) -> tensor<4x4xf32>
//   %12 = "tf.Add"(%9, %11) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
//   tf_device.return %12 : tensor<4x4xf32>
// }) : () -> tensor<4x4xf32>
// NOLINTEND
// clang-format on
mlir::LogicalResult CombineAllReduceOpsOfSameTypeAndGroupAssignment(
    mlir::tf_device::ClusterOp cluster,
    const std::vector<mlir::TF::DTensorAllReduceOp>& all_reduces) {
  // Drop within-slice all-reduces.
  std::vector<mlir::TF::DTensorAllReduceOp> cross_slice_all_reduces;
  for (mlir::TF::DTensorAllReduceOp all_reduce : all_reduces) {
    mlir::DenseIntElementsAttr group_assignment_attr;
    if (!matchPattern(all_reduce.group_assignment(),
                      m_Constant(&group_assignment_attr))) {
      return all_reduce.emitOpError("group_assignment should be a constant");
    }
    // LINT.IfChange
    // TODO(ishark): Confirm the right check for GPUs.
    int num_slices = NumClients();
    int slice_size = kTpuDonutSize;
    if (group_assignment_attr.getNumElements() < kTpuDonutSize) {
      DCHECK_EQ(num_slices, 1) << "Num slices expected to be equal to 1.";
      slice_size = group_assignment_attr.getNumElements();
    }
    StatusOr<GroupAssignment> group_assignment = GroupAssignment::FromMLIR(
        group_assignment_attr,
        GroupAssignment::ReplicaToDeviceMap::DefaultReplicaToDeviceMap(
            num_slices, slice_size));
    // LINT.ThenChange(//tensorflow/dtensor/mlir/utils/collective_lowering_google.inc)
    if (!group_assignment.ok()) {
      return all_reduce.emitOpError(
          llvm::formatv("Failed to create a GroupAssignment due to {0}",
                        group_assignment.status().error_message()));
    }
    // Unit tests have only one slice. Always combine all all-reduces in them.
    if (group_assignment->num_slices() == 1 ||
        !group_assignment->IsWithinSlices()) {
      cross_slice_all_reduces.push_back(all_reduce);
    }
  }

  // A single op has nothing to combine with.
  int num_all_reduces = cross_slice_all_reduces.size();
  if (num_all_reduces <= 1) return mlir::success();

  // Export the all reduces as a DOT graph.
  VLOG(4) << "Visualizing AllReduce dependencies:\n"
          << DrawAllReduceDependencies(cross_slice_all_reduces);

  // Build a reverse adjacency matrix from dependents to requirements.
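  // requirements[j] holds the indices of earlier all-reduces that must be
  // placed in an earlier combination group before cross_slice_all_reduces[j]
  // can be combined.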
  std::vector<std::vector<int>> requirements(num_all_reduces,
                                             std::vector<int>());
  for (int i = 0; i < num_all_reduces - 1; ++i) {
    mlir::TF::DTensorAllReduceOp requirement = cross_slice_all_reduces[i];
    for (int j = i + 1; j < num_all_reduces; ++j) {
      mlir::TF::DTensorAllReduceOp dependent = cross_slice_all_reduces[j];
      DCHECK(
          !DependsOn(requirement, dependent));  // guaranteed by program order
      // In the example below, all three DTensorAllReduce ops are independent
      // of each other according to the MLIR value use-def chains considered by
      // DependsOn. However, moving all three to after the WhileRegion and
      // combining them would break the program.
      //
      //   %3 = tf.DTensorAllReduce(%1, %2)
      //   %4 = tf.WhileRegion(%1) ({
      //   ^bb0(%arg):
      //     %5 = tf.ToBool(%arg)
      //     tf.Yield(%5)
      //   }, {
      //     %6 = tf.DTensorAllReduce(%1, %2)
      //     tf.Yield(%6)
      //   })
      //   %7 = tf.DTensorAllReduce(%1, %2)
      //
      // Therefore, in addition to DependsOn, we also check if two
      // DTensorAllReduceOps belong to different blocks. If they do, since they
      // exist in the same ClusterOp, one or both of them must be inside a
      // control flow region block. We treat them as if there is a dependency
      // between them.
      //
      // In the example above, the second DTensorAllReduceOp would "depend on"
      // the first one, and the third on the second. This effectively prevents
      // any two DTensorAllReduce ops from merging together.
      if (requirement->getBlock() != dependent->getBlock() ||
          DependsOn(dependent, requirement)) {
        requirements[j].push_back(i);
      }
    }
  }

  // Traverse the adjacency matrix layer by layer to find combination groups.
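  // Each iteration peels off the all-reduces whose requirements have all been
  // fulfilled in earlier iterations; the ops peeled off together form one
  // combination group (essentially a layered topological sort).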
  std::vector<std::vector<mlir::TF::DTensorAllReduceOp>> all_reduce_groups;
  std::set<int> fulfilled;
  while (fulfilled.size() < cross_slice_all_reduces.size()) {
    std::vector<int> fulfilled_this_layer;
    for (int j = 0; j < requirements.size(); ++j) {
      if (fulfilled.count(j) > 0) continue;
      bool requirements_met = true;
      for (int i : requirements[j]) {
        if (fulfilled.count(i) == 0) {
          requirements_met = false;
          break;
        }
      }
      if (requirements_met) {
        fulfilled_this_layer.push_back(j);
      }
    }
    VLOG(4) << "Fulfilled: " << str_util::Join(fulfilled_this_layer, ", ");
    all_reduce_groups.push_back({});
    for (int i : fulfilled_this_layer) {
      fulfilled.insert(i);
      all_reduce_groups.back().push_back(cross_slice_all_reduces[i]);
    }
  }
  VLOG(2) << num_all_reduces << " all-reduce ops in "
          << all_reduce_groups.size() << " groups";

  // Move all-reduces in the same group together and combine them.
  for (auto& all_reduce_group : all_reduce_groups) {
    int num_all_reduces = all_reduce_group.size();
    if (num_all_reduces <= 1) continue;
    mlir::TF::DTensorAllReduceOp final_all_reduce =
        all_reduce_group[num_all_reduces - 1];
    for (int i = num_all_reduces - 2; i >= 0; --i) {
      mlir::TF::DTensorAllReduceOp all_reduce = all_reduce_group[i];
      MoveUsagesAfter(cluster, all_reduce, final_all_reduce);
    }
    for (int i = 0; i < num_all_reduces - 1; ++i) {
      mlir::TF::DTensorAllReduceOp all_reduce = all_reduce_group[i];
      all_reduce->moveBefore(final_all_reduce);
    }
    auto merge_result = MergeAllReduceGroup(all_reduce_group);
    if (merge_result.failed()) return merge_result;
  }

  return mlir::success();
}

// Returns true if the two group assignments are the same value, or if both are
// constants with identical shapes and elements.
bool same_group_assignments(mlir::Value group_assignment_a,
                            mlir::Value group_assignment_b) {
  if (group_assignment_a == group_assignment_b) {
    return true;
  }
  mlir::DenseIntElementsAttr attr_a;
  if (!matchPattern(group_assignment_a, m_Constant(&attr_a))) {
    return false;
  }
  mlir::DenseIntElementsAttr attr_b;
  if (!matchPattern(group_assignment_b, m_Constant(&attr_b))) {
    return false;
  }
  if (attr_a.getType().getShape() != attr_b.getType().getShape()) {
    return false;
  }
  return std::equal(attr_a.begin(), attr_a.end(), attr_b.begin(), attr_b.end());
}

// Combines DTensorAllReduce ops of the same element type into as few groups as
// possible. Only ops with the same group assignment can be combined together.
mlir::LogicalResult CombineAllReduceOpsOfSameType(
    mlir::tf_device::ClusterOp cluster,
    const std::vector<mlir::TF::DTensorAllReduceOp>& all_reduces) {
  // Maintain a list of seen group assignments, sorted by first appearance.
  // Also find and store all-reduces by group assignment. Use the first
  // mlir::Value that contains a certain group assignment to represent all the
  // same group assignments.
  std::vector<mlir::Value> group_assignments;
  llvm::DenseMap<mlir::Value, std::vector<mlir::TF::DTensorAllReduceOp>>
      all_reduces_by_group_assignment;
  for (mlir::TF::DTensorAllReduceOp all_reduce : all_reduces) {
    mlir::Value group_assignment = all_reduce.group_assignment();
    bool seen = false;
    for (mlir::Value seen_group_assignment : group_assignments) {
      if (same_group_assignments(group_assignment, seen_group_assignment)) {
        group_assignment = seen_group_assignment;
        seen = true;
        break;
      }
    }
    if (!seen) group_assignments.push_back(group_assignment);
    all_reduces_by_group_assignment[group_assignment].push_back(all_reduce);
  }

  // Combine all-reduces of the same group assignment in first-appearance
  // order.
  for (mlir::Value group_assignment : group_assignments) {
    mlir::LogicalResult result =
        CombineAllReduceOpsOfSameTypeAndGroupAssignment(
            cluster, all_reduces_by_group_assignment[group_assignment]);
    if (mlir::failed(result)) return result;
  }

  return mlir::success();
}

struct DTensorAllReduceCombineOptimization
    : public impl::DTensorAllReduceCombineOptimizationBase<
          DTensorAllReduceCombineOptimization> {
  void runOnOperation() override {
    mlir::func::FuncOp function = getOperation();
    function.walk([&](mlir::tf_device::ClusterOp cluster) {
      // Maintain a list of seen element types, sorted by first appearance.
      // Also find and store all-reduces by element type.
      std::vector<mlir::Type> elem_types;
      llvm::DenseMap<mlir::Type, std::vector<mlir::TF::DTensorAllReduceOp>>
          all_reduces_by_elem_type;
      cluster.GetBody().walk([&](mlir::TF::DTensorAllReduceOp all_reduce) {
        mlir::Type elem_type = all_reduce.getType().getElementType();
        if (std::find(elem_types.begin(), elem_types.end(), elem_type) ==
            elem_types.end()) {
          elem_types.push_back(elem_type);
        }
        all_reduces_by_elem_type[elem_type].push_back(all_reduce);
      });

      // Combine all-reduces of the same element type in first-appearance
      // order.
      for (mlir::Type elem_type : elem_types) {
        // Combine all-reduces for the same attribute reduce_op for the
        // element type.
        auto& all_reduces_for_elem_type = all_reduces_by_elem_type[elem_type];
        llvm::DenseMap<llvm::StringRef,
                       std::vector<mlir::TF::DTensorAllReduceOp>>
            all_reduces_by_attr_reduce_op;
        for (mlir::TF::DTensorAllReduceOp all_reduce :
             all_reduces_for_elem_type) {
          llvm::StringRef attr_reduce_op = all_reduce.reduce_op();
          all_reduces_by_attr_reduce_op[attr_reduce_op].push_back(all_reduce);
        }
        for (auto& all_reduces_to_merge : all_reduces_by_attr_reduce_op) {
          if (mlir::failed(CombineAllReduceOpsOfSameType(
                  cluster, all_reduces_to_merge.second))) {
            return signalPassFailure();
          }
        }
      }
    });
  }
};

}  // namespace

std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
CreateDTensorAllReduceCombineOptimization() {
  return std::make_unique<DTensorAllReduceCombineOptimization>();
}

}  // namespace dtensor
}  // namespace tensorflow