1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "llvm/ADT/DenseMap.h" |
17 | #include "llvm/ADT/SmallPtrSet.h" |
18 | #include "llvm/Support/FormatVariadic.h" |
19 | #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project |
20 | #include "mlir/IR/Builders.h" // from @llvm-project |
21 | #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project |
22 | #include "mlir/IR/Operation.h" // from @llvm-project |
23 | #include "mlir/IR/UseDefLists.h" // from @llvm-project |
24 | #include "mlir/IR/Value.h" // from @llvm-project |
25 | #include "mlir/Support/LogicalResult.h" // from @llvm-project |
26 | #include "mlir/Transforms/Passes.h" // from @llvm-project |
27 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" |
28 | #include "tensorflow/dtensor/mlir/collectives_common.h" |
29 | #include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h" |
30 | #include "tensorflow/dtensor/mlir/group_assignment.h" |
31 | #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h" |
32 | #include "tensorflow/dtensor/mlir/layout_parsing.h" |
33 | #include "tensorflow/dtensor/mlir/spmd_expander_common.h" |
34 | |
35 | namespace tensorflow { |
36 | namespace dtensor { |
37 | |
38 | namespace { |
39 | #define GEN_PASS_DEF_DTENSORALLREDUCESCATTEROPTIMIZATION |
40 | #include "tensorflow/dtensor/mlir/dtensor_passes.h.inc" |
41 | |
42 | // Returns true if both group assignments are constant and equal. |
43 | bool same_group_assignments(mlir::DenseIntElementsAttr attr_a, |
44 | mlir::DenseIntElementsAttr attr_b) { |
45 | if (attr_a.getType().getShape() != attr_b.getType().getShape()) { |
46 | return false; |
47 | } |
48 | return std::equal(attr_a.begin(), attr_a.end(), attr_b.begin(), attr_b.end()); |
49 | } |
50 | |
51 | mlir::DenseIntElementsAttr GetScatterGroupAssignment( |
52 | mlir::TF::DTensorAllScatterOp all_scatter, int scatter_dim) { |
53 | const Layout original_layout = all_scatter.input_layout(); |
54 | const Layout desired_layout = all_scatter.output_layout(); |
55 | absl::flat_hash_set<std::string> scattered_dims; |
56 | scattered_dims.insert(desired_layout.sharding_spec(scatter_dim)); |
57 | |
58 | auto partitions = |
59 | GetAllReducePartitionsFromReducedDims(original_layout, scattered_dims) |
60 | .value(); |
61 | const int32 num_partitions = partitions.size(); |
62 | |
63 | // Construct a flattened list of scatter partitions. |
64 | std::vector<int32> partitions_flat; |
65 | for (auto& p : partitions) { |
66 | partitions_flat.insert(partitions_flat.end(), p.second.begin(), |
67 | p.second.end()); |
68 | } |
69 | |
70 | int32 partition_size = partitions.begin()->second.size(); |
71 | mlir::OpBuilder builder(all_scatter); |
72 | auto group_shaped_type = mlir::RankedTensorType::get( |
73 | {num_partitions, partition_size}, |
74 | mlir::IntegerType::get(builder.getContext(), 32)); |
75 | |
76 | return mlir::DenseIntElementsAttr::get(group_shaped_type, partitions_flat); |
77 | } |
78 | |
// Walks `function` and fuses every DTensorAllReduce whose single user is a
// DTensorAllScatter into one DTensorReduceScatter op, provided exactly one
// dimension is scattered and both collectives use identical device groups.
// Returns failure only when a candidate all-reduce's group_assignment is not
// a constant.
mlir::LogicalResult ApplyOptimization(mlir::func::FuncOp function) {
  // Fused ops are collected and erased after the walk, so the walk never
  // mutates the op list it is iterating over.
  std::vector<mlir::Operation*> ops_to_delete;
  function.walk([&](mlir::TF::DTensorAllReduceOp all_reduce) {
    // Fusion is only valid when the all-reduce result feeds exactly one op.
    if (all_reduce->hasOneUse()) {
      if (auto all_scatter = mlir::dyn_cast<mlir::TF::DTensorAllScatterOp>(
              *all_reduce->getUsers().begin())) {
        VLOG(2) << "Found potential AllReduce+AllScatter to fuse.";
        if (VLOG_IS_ON(2)) all_reduce.dump();
        if (VLOG_IS_ON(2)) all_scatter.dump();

        const Layout original_layout = all_scatter.input_layout();
        const Layout desired_layout = all_scatter.output_layout();

        // Find all potential scatter dimensions: every axis whose sharding
        // spec differs between the scatter's input and output layouts.
        std::vector<int> scatter_dims;
        for (int i = 0; i < original_layout.rank(); ++i) {
          if (original_layout.sharding_spec(i) !=
              desired_layout.sharding_spec(i)) {
            scatter_dims.push_back(i);
          }
        }

        // Nothing scattered: leave the pair alone.
        if (scatter_dims.empty()) return mlir::WalkResult::advance();
        // The fused ReduceScatter op only supports a single scatter axis.
        if (scatter_dims.size() > 1) {
          VLOG(2) << "Multiple dimensions are scatter. This is unsupported "
                     "for AllReduce+Scatter fusion.";
          return mlir::WalkResult::advance();
        }

        int scatter_dim = scatter_dims[0];
        VLOG(2) << "Scatter_dim: " << scatter_dim;

        // Check that the all-reduce and all-scatter group assignments are the
        // same.
        mlir::DenseIntElementsAttr all_reduce_group_assignment_attr;
        if (!matchPattern(all_reduce.group_assignment(),
                          m_Constant(&all_reduce_group_assignment_attr))) {
          all_reduce.emitOpError("group_assignment should be a constant");
          return mlir::WalkResult::interrupt();
        }

        mlir::DenseIntElementsAttr all_scatter_group_assignment_attr =
            GetScatterGroupAssignment(all_scatter, scatter_dim);

        VLOG(2) << "All scatter group assignment: ";
        if (VLOG_IS_ON(2)) all_scatter_group_assignment_attr.dump();

        bool same_group =
            same_group_assignments(all_reduce_group_assignment_attr,
                                   all_scatter_group_assignment_attr);

        // Differing device groups make the fusion unsound; skip.
        if (!same_group) return mlir::WalkResult::advance();
        VLOG(2) << "Fuse reduce scatter with scatter_dim: " << scatter_dim;

        // Build the fused op right before the all-reduce. The scatter
        // dimension is passed as a scalar i32 constant operand.
        mlir::OpBuilder builder(all_reduce);
        auto scatter_dim_const_op = builder.create<mlir::TF::ConstOp>(
            all_reduce.getLoc(),
            mlir::DenseIntElementsAttr::get(
                mlir::RankedTensorType::get({}, builder.getI32Type()),
                {scatter_dim}));

        auto reduce_scatter = builder.create<mlir::TF::DTensorReduceScatterOp>(
            all_reduce.getLoc(), all_scatter->getResultTypes(),
            all_reduce.getOperand(0), all_reduce.group_assignment(),
            scatter_dim_const_op, all_reduce.reduce_op(),
            all_reduce.device_type());
        // The fused op produces the all-scatter's output layout.
        SetSingleLayoutOnOp(reduce_scatter, desired_layout);

        all_scatter->replaceAllUsesWith(reduce_scatter);

        // Delete the user (all_scatter) before its producer (all_reduce).
        ops_to_delete.push_back(all_scatter);
        ops_to_delete.push_back(all_reduce);
      }
    }
    return mlir::WalkResult::advance();
  });

  for (mlir::Operation* op : ops_to_delete) {
    op->erase();
  }
  return mlir::success();
}
161 | |
162 | // MLIR pass that combines AllReduce and AllScatter to ReduceScatter. |
163 | struct DTensorAllReduceScatterOptimization |
164 | : public impl::DTensorAllReduceScatterOptimizationBase< |
165 | DTensorAllReduceScatterOptimization> { |
166 | void runOnOperation() override { |
167 | mlir::func::FuncOp function = getOperation(); |
168 | |
169 | if (mlir::failed(ApplyOptimization(function))) return signalPassFailure(); |
170 | } |
171 | }; |
172 | |
173 | } // namespace |
174 | |
175 | std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> |
176 | CreateDTensorAllReduceScatterOptimization() { |
177 | return std::make_unique<DTensorAllReduceScatterOptimization>(); |
178 | } |
179 | |
180 | } // namespace dtensor |
181 | } // namespace tensorflow |
182 | |