1/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
#include <algorithm>
#include <cstdint>
#include <memory>
#include <string>
#include <vector>

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/UseDefLists.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/dtensor/mlir/collectives_common.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
#include "tensorflow/dtensor/mlir/group_assignment.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"
34
35namespace tensorflow {
36namespace dtensor {
37
38namespace {
39#define GEN_PASS_DEF_DTENSORALLREDUCESCATTEROPTIMIZATION
40#include "tensorflow/dtensor/mlir/dtensor_passes.h.inc"
41
42// Returns true if both group assignments are constant and equal.
43bool same_group_assignments(mlir::DenseIntElementsAttr attr_a,
44 mlir::DenseIntElementsAttr attr_b) {
45 if (attr_a.getType().getShape() != attr_b.getType().getShape()) {
46 return false;
47 }
48 return std::equal(attr_a.begin(), attr_a.end(), attr_b.begin(), attr_b.end());
49}
50
51mlir::DenseIntElementsAttr GetScatterGroupAssignment(
52 mlir::TF::DTensorAllScatterOp all_scatter, int scatter_dim) {
53 const Layout original_layout = all_scatter.input_layout();
54 const Layout desired_layout = all_scatter.output_layout();
55 absl::flat_hash_set<std::string> scattered_dims;
56 scattered_dims.insert(desired_layout.sharding_spec(scatter_dim));
57
58 auto partitions =
59 GetAllReducePartitionsFromReducedDims(original_layout, scattered_dims)
60 .value();
61 const int32 num_partitions = partitions.size();
62
63 // Construct a flattened list of scatter partitions.
64 std::vector<int32> partitions_flat;
65 for (auto& p : partitions) {
66 partitions_flat.insert(partitions_flat.end(), p.second.begin(),
67 p.second.end());
68 }
69
70 int32 partition_size = partitions.begin()->second.size();
71 mlir::OpBuilder builder(all_scatter);
72 auto group_shaped_type = mlir::RankedTensorType::get(
73 {num_partitions, partition_size},
74 mlir::IntegerType::get(builder.getContext(), 32));
75
76 return mlir::DenseIntElementsAttr::get(group_shaped_type, partitions_flat);
77}
78
79mlir::LogicalResult ApplyOptimization(mlir::func::FuncOp function) {
80 std::vector<mlir::Operation*> ops_to_delete;
81 function.walk([&](mlir::TF::DTensorAllReduceOp all_reduce) {
82 if (all_reduce->hasOneUse()) {
83 if (auto all_scatter = mlir::dyn_cast<mlir::TF::DTensorAllScatterOp>(
84 *all_reduce->getUsers().begin())) {
85 VLOG(2) << "Found potential AllReduce+AllScatter to fuse.";
86 if (VLOG_IS_ON(2)) all_reduce.dump();
87 if (VLOG_IS_ON(2)) all_scatter.dump();
88
89 const Layout original_layout = all_scatter.input_layout();
90 const Layout desired_layout = all_scatter.output_layout();
91
92 // Find all potential scatter dimensions.
93 std::vector<int> scatter_dims;
94 for (int i = 0; i < original_layout.rank(); ++i) {
95 if (original_layout.sharding_spec(i) !=
96 desired_layout.sharding_spec(i)) {
97 scatter_dims.push_back(i);
98 }
99 }
100
101 if (scatter_dims.empty()) return mlir::WalkResult::advance();
102 if (scatter_dims.size() > 1) {
103 VLOG(2) << "Multiple dimensions are scatter. This is unsupported "
104 "for AllReduce+Scatter fusion.";
105 return mlir::WalkResult::advance();
106 }
107
108 int scatter_dim = scatter_dims[0];
109 VLOG(2) << "Scatter_dim: " << scatter_dim;
110
111 // Check that the all-reduce and all-scatter group assignments are the
112 // same.
113 mlir::DenseIntElementsAttr all_reduce_group_assignment_attr;
114 if (!matchPattern(all_reduce.group_assignment(),
115 m_Constant(&all_reduce_group_assignment_attr))) {
116 all_reduce.emitOpError("group_assignment should be a constant");
117 return mlir::WalkResult::interrupt();
118 }
119
120 mlir::DenseIntElementsAttr all_scatter_group_assignment_attr =
121 GetScatterGroupAssignment(all_scatter, scatter_dim);
122
123 VLOG(2) << "All scatter group assignment: ";
124 if (VLOG_IS_ON(2)) all_scatter_group_assignment_attr.dump();
125
126 bool same_group =
127 same_group_assignments(all_reduce_group_assignment_attr,
128 all_scatter_group_assignment_attr);
129
130 if (!same_group) return mlir::WalkResult::advance();
131 VLOG(2) << "Fuse reduce scatter with scatter_dim: " << scatter_dim;
132
133 mlir::OpBuilder builder(all_reduce);
134 auto scatter_dim_const_op = builder.create<mlir::TF::ConstOp>(
135 all_reduce.getLoc(),
136 mlir::DenseIntElementsAttr::get(
137 mlir::RankedTensorType::get({}, builder.getI32Type()),
138 {scatter_dim}));
139
140 auto reduce_scatter = builder.create<mlir::TF::DTensorReduceScatterOp>(
141 all_reduce.getLoc(), all_scatter->getResultTypes(),
142 all_reduce.getOperand(0), all_reduce.group_assignment(),
143 scatter_dim_const_op, all_reduce.reduce_op(),
144 all_reduce.device_type());
145 SetSingleLayoutOnOp(reduce_scatter, desired_layout);
146
147 all_scatter->replaceAllUsesWith(reduce_scatter);
148
149 ops_to_delete.push_back(all_scatter);
150 ops_to_delete.push_back(all_reduce);
151 }
152 }
153 return mlir::WalkResult::advance();
154 });
155
156 for (mlir::Operation* op : ops_to_delete) {
157 op->erase();
158 }
159 return mlir::success();
160}
161
162// MLIR pass that combines AllReduce and AllScatter to ReduceScatter.
163struct DTensorAllReduceScatterOptimization
164 : public impl::DTensorAllReduceScatterOptimizationBase<
165 DTensorAllReduceScatterOptimization> {
166 void runOnOperation() override {
167 mlir::func::FuncOp function = getOperation();
168
169 if (mlir::failed(ApplyOptimization(function))) return signalPassFailure();
170 }
171};
172
173} // namespace
174
175std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
176CreateDTensorAllReduceScatterOptimization() {
177 return std::make_unique<DTensorAllReduceScatterOptimization>();
178}
179
180} // namespace dtensor
181} // namespace tensorflow
182