1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <algorithm> |
17 | #include <memory> |
#include <set>
#include <string>
19 | #include <vector> |
20 | |
21 | #include "llvm/ADT/DenseMap.h" |
22 | #include "llvm/ADT/StringRef.h" |
23 | #include "llvm/Support/FormatVariadic.h" |
24 | #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project |
25 | #include "mlir/IR/BuiltinOps.h" // from @llvm-project |
26 | #include "mlir/IR/Types.h" // from @llvm-project |
27 | #include "mlir/IR/Visitors.h" // from @llvm-project |
28 | #include "mlir/Pass/Pass.h" // from @llvm-project |
29 | #include "mlir/Support/LogicalResult.h" // from @llvm-project |
30 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" |
31 | #include "tensorflow/compiler/mlir/tensorflow/transforms/collection_ops_util.h" |
32 | #include "tensorflow/core/platform/str_util.h" |
33 | #include "tensorflow/dtensor/cc/constants.h" |
34 | #include "tensorflow/dtensor/cc/dtensor_utils.h" |
35 | #include "tensorflow/dtensor/mlir/dtensor_location.h" |
36 | #include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h" |
37 | #include "tensorflow/dtensor/mlir/group_assignment.h" |
38 | #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h" |
39 | #include "tensorflow/dtensor/mlir/layout_parsing.h" |
40 | |
41 | namespace tensorflow { |
42 | namespace dtensor { |
43 | |
44 | namespace { |
45 | #define GEN_PASS_DEF_DTENSORALLREDUCECOMBINEOPTIMIZATION |
46 | #include "tensorflow/dtensor/mlir/dtensor_passes.h.inc" |
47 | |
48 | namespace ops_util = ::mlir::TF::collection_ops_util; |
49 | |
// Pad the merged tensor size to a multiple of kAllReducePadding (1024)
// elements, so the delinearization-skipping optimization in XLA can be
// activated.
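// For example, a group whose tensors hold 1500 elements in total is padded
// up to 2048 elements before the merged all-reduce is created.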
52 | constexpr int32 kAllReducePadding = 1024; |
53 | |
54 | // Returns true if `successor` depends on `predecessor`. |
55 | // TODO(jiawenhao): Repeatedly computing dependency sets for a large cluster can |
56 | // get expensive when the number of all-reduces is high. Consider building a |
57 | // cluster-scope op dependency graph ahead of time to amortize the cost. |
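// For example, given `%1 = tf.A()` and `%2 = tf.B(%1)`, DependsOn(B, A)
// returns true because B (transitively) uses a value produced by A.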
58 | bool DependsOn(mlir::Operation* successor, mlir::Operation* predecessor) { |
59 | llvm::SmallVector<mlir::Operation*, 4> to_visit; |
60 | llvm::SmallPtrSet<mlir::Operation*, 4> visited; |
61 | to_visit.push_back(predecessor); |
62 | while (!to_visit.empty()) { |
63 | mlir::Operation* producer = to_visit.pop_back_val(); |
64 | if (visited.contains(producer)) continue; |
65 | visited.insert(producer); |
66 | if (successor == producer) return true; |
67 | for (mlir::Operation* user : producer->getUsers()) { |
68 | if (visited.contains(user)) continue; |
69 | to_visit.push_back(user); |
70 | } |
71 | } |
72 | return false; |
73 | } |
74 | |
75 | // Moves all usages of `a` (direct and transitive) to right after `b` in |
76 | // `cluster`, preserving the original order of moved ops. |
77 | // `a` and `b` must be in `cluster`. `a` must appear before `b` originally. |
78 | // `a` itself is not moved. |
79 | // |
80 | // For example, this program: |
81 | // |
82 | // tf_device.cluster() ({ |
83 | // %a = tf.A() |
84 | // %1 = tf.C(%a) |
85 | // %2 = tf.D(%a) |
86 | // %3 = tf.E(%1, %2) |
87 | // %b = tf.B() |
88 | // %4 = tf.F(%3) |
89 | // %5 = tf.G(%b) |
90 | // tf_device.return() |
91 | // }) |
92 | // |
93 | // will become this: |
94 | // |
95 | // tf_device.cluster() ({ |
96 | // %a = tf.A() |
97 | // %b = tf.B() |
98 | // %1 = tf.C(%a) |
99 | // %2 = tf.D(%a) |
100 | // %3 = tf.E(%1, %2) |
101 | // %4 = tf.F(%3) |
102 | // %5 = tf.G(%b) |
103 | // tf_device.return() |
104 | // }) |
105 | void MoveUsagesAfter(mlir::tf_device::ClusterOp cluster, mlir::Operation* a, |
106 | mlir::Operation* b) { |
107 | llvm::SmallVector<mlir::Operation*, 4> to_visit; |
108 | llvm::SmallPtrSet<mlir::Operation*, 4> visited; |
109 | to_visit.push_back(a); |
110 | while (!to_visit.empty()) { |
111 | mlir::Operation* producer = to_visit.pop_back_val(); |
112 | if (visited.contains(producer)) continue; |
113 | visited.insert(producer); |
114 | for (mlir::Operation* user : producer->getUsers()) { |
115 | if (visited.contains(user)) continue; |
116 | to_visit.push_back(user); |
117 | } |
118 | } |
119 | |
120 | llvm::SmallVector<mlir::Operation*, 4> to_move; |
121 | cluster.GetBody().walk([&](mlir::Operation* op) { |
122 | if (op != a && visited.contains(op) && op->isBeforeInBlock(b)) { |
123 | to_move.push_back(op); |
124 | } |
125 | }); |
126 | |
127 | mlir::Operation* last = b; |
128 | for (mlir::Operation* op : to_move) { |
    if (mlir::isa<mlir::TF::YieldOp>(op)) {
      LOG(FATAL) << "Should never move YieldOp";  // Crash OK
131 | } |
132 | op->moveAfter(last); |
133 | last = op; |
134 | } |
135 | } |
136 | |
137 | // Merge all-reduces in the group into one all-reduce. |
138 | // |
139 | // Requirements: |
140 | // - The group should have at least two all-reduces. |
141 | // - They should be located next to each other in the parent block. |
142 | // - They should all have the same element type. |
143 | // - They should all have the same group assignment. |
144 | // |
// The merged all-reduce operates on a 1D tensor whose size is the sum of the
// sizes of all merged all-reduce tensors, padded up to a multiple of 1024
// elements. (The padding is necessary for the XLA delinearization-skipping
// logic.) Each to-be-merged all-reduce flattens
148 | // its input tensor and writes the resulting 1D tensor into the corresponding |
149 | // offset in the merged 1D tensor. After the merged all-reduce is done, the |
150 | // reverse happens: results are sliced out and reshaped to the original shape. |
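// For example (illustrative shapes): merging a 4x4 and an 8-element f32
// all-reduce produces a 1024-element merged tensor (16 + 8 = 24, padded to
// 1024). The first input is flattened into offsets [0, 16), the second into
// [16, 24); afterwards the results are sliced back out and reshaped to 4x4
// and 8 respectively.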
151 | mlir::LogicalResult MergeAllReduceGroup( |
152 | std::vector<mlir::TF::DTensorAllReduceOp>& all_reduce_group) { |
153 | // Create the initial all-zero merged tensor. |
154 | // The merged tensor's size is the sum of all individual all-reduces' sizes. |
155 | int num_all_reduces = all_reduce_group.size(); |
156 | DCHECK(num_all_reduces > 1) |
157 | << "All reduce group size expected to be greater than 1." ; |
158 | int total_num_elements = 0; |
159 | std::vector<llvm::ArrayRef<int64_t>> all_reduce_shapes; |
160 | all_reduce_shapes.reserve(num_all_reduces); |
161 | for (mlir::TF::DTensorAllReduceOp& all_reduce : all_reduce_group) { |
162 | auto all_reduce_ranked_type = |
163 | all_reduce.getType().dyn_cast<mlir::RankedTensorType>(); |
164 | if (!all_reduce_ranked_type || !all_reduce_ranked_type.hasStaticShape()) { |
165 | return all_reduce.emitOpError(llvm::formatv( |
166 | "requires static shape for DTensorAllReduceOp, but got : {0}" , |
167 | all_reduce_ranked_type)); |
168 | } |
169 | int num_elements = all_reduce_ranked_type.getNumElements(); |
170 | total_num_elements += num_elements; |
171 | all_reduce_shapes.push_back(all_reduce_ranked_type.getShape()); |
172 | } |
173 | |
  // Pad the merged tensor size to a multiple of kAllReducePadding (1024)
  // elements, so the delinearization-skipping optimization in XLA can be
  // activated.
176 | if (total_num_elements % kAllReducePadding != 0) { |
177 | total_num_elements = |
178 | total_num_elements / kAllReducePadding * kAllReducePadding + |
179 | kAllReducePadding; |
180 | } |
181 | |
182 | // Fill the merged tensor with 0 initially. |
183 | mlir::OpBuilder builder(all_reduce_group[0]); |
184 | mlir::Location loc = all_reduce_group[0].getLoc(); |
185 | mlir::Type elem_type = all_reduce_group[0].getType().getElementType(); |
186 | auto zero_scalar = ops_util::CreateScalarConst(0, builder, loc); |
187 | auto zero_scalar_elem_type = builder.create<mlir::TF::CastOp>( |
188 | loc, mlir::RankedTensorType::get({}, elem_type), zero_scalar); |
189 | auto merged = builder.create<mlir::TF::FillOp>( |
190 | loc, ops_util::GetR1Const({total_num_elements}, builder, loc), |
191 | zero_scalar_elem_type); |
192 | |
193 | // Store every all-reduce's input at an offset location in the merged tensor, |
194 | // as a 1D tensor. |
195 | int offset_num_elements = 0; |
196 | std::vector<mlir::Type> flattened_types; |
197 | flattened_types.reserve(num_all_reduces); |
198 | mlir::Value updated; |
199 | for (int i = 0; i < all_reduce_group.size(); ++i) { |
200 | mlir::TF::DTensorAllReduceOp& all_reduce = all_reduce_group[i]; |
201 | mlir::Location loc = all_reduce.getLoc(); |
202 | auto all_reduce_ranked_type = |
203 | all_reduce.getType().dyn_cast<mlir::RankedTensorType>(); |
204 | if (!all_reduce_ranked_type || !all_reduce_ranked_type.hasStaticShape()) { |
205 | return all_reduce.emitOpError(llvm::formatv( |
206 | "requires static shape for DTensorAllReduceOp, but got : {0}" , |
207 | all_reduce_ranked_type)); |
208 | } |
209 | |
210 | int num_elements = all_reduce_ranked_type.getNumElements(); |
211 | auto flattened = builder.create<mlir::TF::ReshapeOp>( |
212 | DT_LOC2(loc, "CombinedReduceFlatten" ), all_reduce.input(), |
213 | ops_util::GetR1Const({num_elements}, builder, loc)); |
214 | flattened_types.push_back(flattened.getType()); |
215 | auto indices = ops_util::GetR1Const({offset_num_elements}, builder, loc); |
216 | |
217 | if (all_reduce.device_type().contains("TPU" )) { |
218 | updated = builder.create<mlir::TF::XlaDynamicUpdateSliceOp>( |
219 | DT_LOC2(loc, "CombinedReduceUpdateSlice" ), merged.getType(), |
220 | /*input=*/i == 0 ? merged.getResult() : updated, |
221 | /*update=*/flattened, indices); |
222 | } else { |
223 | auto end = ops_util::GetR1Const({offset_num_elements + num_elements}, |
224 | builder, loc); |
225 | auto strides = ops_util::GetR1Const({1}, builder, loc); |
226 | updated = builder.create<mlir::TF::TensorStridedSliceUpdateOp>( |
227 | DT_LOC2(loc, "CombinedReduceUpdateSlice" ), merged.getType(), |
228 | /*input=*/i == 0 ? merged.getResult() : updated, indices, end, |
229 | strides, |
230 | /*value=*/flattened); |
231 | } |
232 | offset_num_elements += num_elements; |
233 | } |
234 | |
235 | // All-reduce the updated merged tensor. |
236 | auto merged_all_reduce = builder.create<mlir::TF::DTensorAllReduceOp>( |
237 | all_reduce_group[0].getLoc(), updated.getType(), updated, |
238 | all_reduce_group[0].group_assignment(), all_reduce_group[0].reduce_op(), |
239 | all_reduce_group[0].device_type()); |
240 | SetSingleLayoutOnOp( |
241 | merged_all_reduce, |
242 | ExtractSingleLayoutFromOp(all_reduce_group[0]).value().value()); |
243 | |
244 | // Slice out the original all-reduces, and reshape back to the original shape. |
245 | offset_num_elements = 0; |
246 | std::vector<mlir::TF::ReshapeOp> replacements; |
247 | replacements.reserve(num_all_reduces); |
248 | for (int i = 0; i < all_reduce_group.size(); ++i) { |
249 | mlir::TF::DTensorAllReduceOp& all_reduce = all_reduce_group[i]; |
250 | mlir::Location loc = all_reduce.getLoc(); |
251 | auto all_reduce_ranked_type = |
252 | all_reduce.getType().dyn_cast<mlir::RankedTensorType>(); |
253 | if (!all_reduce_ranked_type || !all_reduce_ranked_type.hasStaticShape()) { |
254 | return all_reduce.emitOpError(llvm::formatv( |
255 | "requires static shape for DTensorAllReduceOp, but got : {0}" , |
256 | all_reduce_ranked_type)); |
257 | } |
258 | int num_elements = all_reduce_ranked_type.getNumElements(); |
259 | auto slice = builder.create<mlir::TF::SliceOp>( |
260 | DT_LOC2(loc, "PostCombinedReduceSlice" ), flattened_types[i], |
261 | /*input=*/merged_all_reduce, |
262 | /*begin=*/ops_util::GetR1Const({offset_num_elements}, builder, loc), |
263 | /*size=*/ops_util::GetR1Const({num_elements}, builder, loc)); |
264 | auto replacement = builder.create<mlir::TF::ReshapeOp>( |
265 | DT_LOC2(loc, "PostCombinedReduceReshape" ), slice.getResult(), |
266 | ops_util::GetR1Const(all_reduce_shapes[i], builder, loc)); |
267 | replacements.push_back(replacement); |
268 | offset_num_elements += num_elements; |
269 | } |
270 | |
271 | // Replace usages and clean up. |
272 | for (int i = 0; i < all_reduce_group.size(); ++i) { |
273 | mlir::TF::DTensorAllReduceOp& all_reduce = all_reduce_group[i]; |
274 | mlir::TF::ReshapeOp& replacement = replacements[i]; |
275 | all_reduce.replaceAllUsesWith(replacement.getResult()); |
276 | all_reduce.erase(); |
277 | } |
278 | return mlir::success(); |
279 | } |
280 | |
// Returns a DOT graph describing the dependencies between AllReduce ops.
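// For example, with three all-reduces where #2 depends on #0, the output is:
//
//   digraph all_reduces {
//   0
//   1
//   2
//   0 -> 2
//   }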
282 | std::string DrawAllReduceDependencies( |
283 | std::vector<mlir::TF::DTensorAllReduceOp> all_reduces) { |
284 | std::vector<std::vector<int>> dependents(all_reduces.size(), |
285 | std::vector<int>()); |
286 | for (int j = 0; j < all_reduces.size(); ++j) { |
287 | mlir::TF::DTensorAllReduceOp later = all_reduces[j]; |
288 | for (int i = 0; i < j; ++i) { |
289 | mlir::TF::DTensorAllReduceOp earlier = all_reduces[i]; |
290 | DCHECK(!DependsOn(earlier, later)); |
291 | if (earlier->getBlock() != later->getBlock() || |
292 | DependsOn(later, earlier)) { |
293 | dependents[i].push_back(j); |
294 | } |
295 | } |
296 | } |
297 | std::string output = "digraph all_reduces {\n" ; |
298 | for (int i = 0; i < dependents.size(); i++) { |
299 | strings::StrAppend(&output, i); |
300 | strings::StrAppend(&output, "\n" ); |
301 | } |
302 | for (int i = 0; i < dependents.size(); i++) { |
303 | for (int j : dependents[i]) { |
304 | strings::StrAppend(&output, i, " -> " , j, "\n" ); |
305 | } |
306 | } |
307 | output += "}" ; |
308 | return output; |
309 | } |
310 | |
311 | // Combine cross-slice DTensorAllReduce ops of the same element type and group |
312 | // assignment into as few groups as possible. Only independent ops can be |
313 | // combined together. |
314 | // |
315 | // For example, this program: |
316 | // |
317 | // clang-format off |
318 | // NOLINTBEGIN(whitespace/line_length) |
319 | // %0 = "tf_device.cluster"() ({ |
320 | // %1 = "tf.Const"() {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32> |
321 | // %2 = "tf.Const"() {value = dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32> |
322 | // %3 = "tf.DTensorAllReduce"(%1, %2) {reduce_op = "Add"} : (tensor<4x4xf32>, tensor<2x2xi32>) -> tensor<4x4xf32> |
323 | // %4 = "tf.Const"() {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32> |
324 | // %5 = "tf.Const"() {value = dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32> |
325 | // %6 = "tf.DTensorAllReduce"(%4, %5) {reduce_op = "Add"} : (tensor<4x4xf32>, tensor<2x2xi32>) -> tensor<4x4xf32> |
326 | // %7 = "tf.Add"(%3, %6) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> |
327 | // "tf_device.return"(%7) : (tensor<4x4xf32>) -> () |
328 | // }) : () -> tensor<4x4xf32> |
329 | // NOLINTEND |
330 | // clang-format on |
331 | // |
332 | // will become this: |
333 | // |
334 | // clang-format off |
335 | // NOLINTBEGIN(whitespace/line_length) |
336 | // %0 = "tf_device.cluster"() ( { |
337 | // %cst = "tf.Const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32> |
338 | // %cst_0 = "tf.Const"() {value = dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32> |
339 | // %cst_1 = "tf.Const"() {value = dense<0.000000e+00> : tensor<4x4xf32>} : () -> tensor<4x4xf32> |
340 | // %cst_2 = "tf.Const"() {value = dense<[[0, 1], [2, 3]]> : tensor<2x2xi32>} : () -> tensor<2x2xi32> |
341 | // %cst_3 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32> |
342 | // %1 = "tf.Cast"(%cst_3) {Truncate = false} : (tensor<i32>) -> tensor<f32> |
343 | // %cst_4 = "tf.Const"() {value = dense<1024> : tensor<1xi32>} : () -> tensor<1xi32> |
344 | // %2 = "tf.Fill"(%cst_4, %1) : (tensor<1xi32>, tensor<f32>) -> tensor<1024xf32> |
345 | // %cst_5 = "tf.Const"() {value = dense<16> : tensor<1xi32>} : () -> tensor<1xi32> |
346 | // %3 = "tf.Reshape"(%cst, %cst_5) : (tensor<4x4xf32>, tensor<1xi32>) -> tensor<16xf32> |
347 | // %cst_6 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> |
348 | // %4 = "tf.XlaDynamicUpdateSlice"(%2, %3, %cst_6) : (tensor<1024xf32>, tensor<16xf32>, tensor<1xi32>) -> tensor<1024xf32> |
349 | // %cst_7 = "tf.Const"() {value = dense<16> : tensor<1xi32>} : () -> tensor<1xi32> |
350 | // %5 = "tf.Reshape"(%cst_1, %cst_7) : (tensor<4x4xf32>, tensor<1xi32>) -> tensor<16xf32> |
351 | // %cst_8 = "tf.Const"() {value = dense<16> : tensor<1xi32>} : () -> tensor<1xi32> |
352 | // %6 = "tf.XlaDynamicUpdateSlice"(%4, %5, %cst_8) : (tensor<1024xf32>, tensor<16xf32>, tensor<1xi32>) -> tensor<1024xf32> |
353 | // %7 = "tf.DTensorAllReduce"(%6, %cst_0) {reduce_op = "Add"} : (tensor<1024xf32>, tensor<2x2xi32>) -> tensor<1024xf32> |
354 | // %cst_9 = "tf.Const"() {value = dense<0> : tensor<1xi32>} : () -> tensor<1xi32> |
355 | // %cst_10 = "tf.Const"() {value = dense<16> : tensor<1xi32>} : () -> tensor<1xi32> |
356 | // %8 = "tf.Slice"(%7, %cst_9, %cst_10) : (tensor<1024xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<16xf32> |
357 | // %cst_11 = "tf.Const"() {value = dense<4> : tensor<2xi32>} : () -> tensor<2xi32> |
358 | // %9 = "tf.Reshape"(%8, %cst_11) : (tensor<16xf32>, tensor<2xi32>) -> tensor<4x4xf32> |
359 | // %cst_12 = "tf.Const"() {value = dense<16> : tensor<1xi32>} : () -> tensor<1xi32> |
360 | // %cst_13 = "tf.Const"() {value = dense<16> : tensor<1xi32>} : () -> tensor<1xi32> |
361 | // %10 = "tf.Slice"(%7, %cst_12, %cst_13) : (tensor<1024xf32>, tensor<1xi32>, tensor<1xi32>) -> tensor<16xf32> |
362 | // %cst_14 = "tf.Const"() {value = dense<4> : tensor<2xi32>} : () -> tensor<2xi32> |
363 | // %11 = "tf.Reshape"(%10, %cst_14) : (tensor<16xf32>, tensor<2xi32>) -> tensor<4x4xf32> |
364 | // %12 = "tf.Add"(%9, %11) : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32> |
365 | // tf_device.return %12 : tensor<4x4xf32> |
366 | // }) : () -> tensor<4x4xf32> |
367 | // NOLINTEND |
368 | // clang-format on |
369 | mlir::LogicalResult CombineAllReduceOpsOfSameTypeAndGroupAssignment( |
370 | mlir::tf_device::ClusterOp cluster, |
371 | const std::vector<mlir::TF::DTensorAllReduceOp>& all_reduces) { |
372 | // Drop within-slice all-reduces. |
373 | std::vector<mlir::TF::DTensorAllReduceOp> cross_slice_all_reduces; |
374 | for (mlir::TF::DTensorAllReduceOp all_reduce : all_reduces) { |
375 | mlir::DenseIntElementsAttr group_assignment_attr; |
376 | if (!matchPattern(all_reduce.group_assignment(), |
377 | m_Constant(&group_assignment_attr))) { |
378 | return all_reduce.emitOpError("group_assignment should be a constant" ); |
379 | } |
380 | // LINT.IfChange |
381 | // TODO(ishark): Confirm the right check for GPUs. |
382 | int num_slices = NumClients(); |
383 | int slice_size = kTpuDonutSize; |
384 | if (group_assignment_attr.getNumElements() < kTpuDonutSize) { |
385 | DCHECK_EQ(num_slices, 1) << "Num slices expected to be equal to 1." ; |
386 | slice_size = group_assignment_attr.getNumElements(); |
387 | } |
388 | StatusOr<GroupAssignment> group_assignment = GroupAssignment::FromMLIR( |
389 | group_assignment_attr, |
390 | GroupAssignment::ReplicaToDeviceMap::DefaultReplicaToDeviceMap( |
391 | num_slices, slice_size)); |
392 | // LINT.ThenChange(//tensorflow/dtensor/mlir/utils/collective_lowering_google.inc) |
393 | if (!group_assignment.ok()) { |
394 | return all_reduce.emitOpError( |
395 | llvm::formatv("Failed to create a GroupAssignment due to {0}" , |
396 | group_assignment.status().error_message())); |
397 | } |
398 | // Unit tests have only one slice. Always combine all all-reduces in them. |
399 | if (group_assignment->num_slices() == 1 || |
400 | !group_assignment->IsWithinSlices()) { |
401 | cross_slice_all_reduces.push_back(all_reduce); |
402 | } |
403 | } |
404 | |
405 | // A single op has nothing to combine with. |
406 | int num_all_reduces = cross_slice_all_reduces.size(); |
407 | if (num_all_reduces <= 1) return mlir::success(); |
408 | |
409 | // Export the all reduces as a DOT graph. |
410 | VLOG(4) << "Visualizing AllReduce dependencies:\n" |
411 | << DrawAllReduceDependencies(cross_slice_all_reduces); |
412 | |
413 | // Build a reverse adjacency matrix from dependents to requirements. |
414 | std::vector<std::vector<int>> requirements(num_all_reduces, |
415 | std::vector<int>()); |
416 | for (int i = 0; i < num_all_reduces - 1; ++i) { |
417 | mlir::TF::DTensorAllReduceOp requirement = cross_slice_all_reduces[i]; |
418 | for (int j = i + 1; j < num_all_reduces; ++j) { |
419 | mlir::TF::DTensorAllReduceOp dependent = cross_slice_all_reduces[j]; |
420 | DCHECK( |
421 | !DependsOn(requirement, dependent)); // guaranteed by program order |
      // In the example below, all three DTensorAllReduce ops are independent
      // of each other according to the MLIR value use-def chains considered
      // by DependsOn. However, moving all three after the WhileRegion and
      // combining them would break the program.
426 | // |
427 | // %3 = tf.DTensorAllReduce(%1, %2) |
428 | // %4 = tf.WhileRegion(%1) ({ |
429 | // ^bb0(%arg): |
430 | // %5 = tf.TooBool(%arg) |
431 | // tf.Yield(%5) |
432 | // }, { |
433 | // %6 = tf.DTensorAllReduce(%1, %2) |
434 | // tf.Yield(%5) |
435 | // }) |
436 | // %7 = tf.DTensorAllReduce(%1, %2) |
437 | // |
438 | // Therefore, in addition to DependsOn, we also check if two |
439 | // DTensorAllReduceOps belong to different blocks. If they do, since they |
440 | // exist in the same ClusterOp, one or both of them must be inside a |
441 | // control flow region block. We treat them as if there is a dependency |
442 | // between them. |
443 | // |
444 | // In the example above, the second DTensorAllReduceOp would "depend on" |
445 | // the first one, and the third on the second. This effectively prevents |
      // any two DTensorAllReduce ops from merging together.
447 | if (requirement->getBlock() != dependent->getBlock() || |
448 | DependsOn(dependent, requirement)) { |
449 | requirements[j].push_back(i); |
450 | } |
451 | } |
452 | } |
453 | |
454 | // Traverse the adjacency matrix layer by layer to find combination groups. |
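  // For example, if requirements is {0: [], 1: [0], 2: [0]}, the first layer
  // fulfills {0} and the second layer fulfills {1, 2}; each layer becomes one
  // candidate combination group.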
455 | std::vector<std::vector<mlir::TF::DTensorAllReduceOp>> all_reduce_groups; |
456 | std::set<int> fulfilled; |
457 | while (fulfilled.size() < cross_slice_all_reduces.size()) { |
458 | std::vector<int> fulfilled_this_layer; |
459 | for (int j = 0; j < requirements.size(); ++j) { |
460 | if (fulfilled.count(j) > 0) continue; |
461 | bool requirements_met = true; |
462 | for (int i : requirements[j]) { |
463 | if (fulfilled.count(i) == 0) { |
464 | requirements_met = false; |
465 | break; |
466 | } |
467 | } |
468 | if (requirements_met) { |
469 | fulfilled_this_layer.push_back(j); |
470 | } |
471 | } |
472 | VLOG(4) << "Fulfilled: " << str_util::Join(fulfilled_this_layer, ", " ); |
473 | all_reduce_groups.push_back({}); |
474 | for (int i : fulfilled_this_layer) { |
475 | fulfilled.insert(i); |
476 | all_reduce_groups.back().push_back(cross_slice_all_reduces[i]); |
477 | } |
478 | } |
479 | VLOG(2) << num_all_reduces << " all-reduce ops in " |
480 | << all_reduce_groups.size() << " groups" ; |
481 | |
482 | // Move all-reduces in the same group together and combine them. |
483 | for (auto& all_reduce_group : all_reduce_groups) { |
484 | int num_all_reduces = all_reduce_group.size(); |
485 | if (num_all_reduces <= 1) continue; |
486 | mlir::TF::DTensorAllReduceOp final_all_reduce = |
487 | all_reduce_group[num_all_reduces - 1]; |
488 | for (int i = num_all_reduces - 2; i >= 0; --i) { |
489 | mlir::TF::DTensorAllReduceOp all_reduce = all_reduce_group[i]; |
490 | MoveUsagesAfter(cluster, all_reduce, final_all_reduce); |
491 | } |
492 | for (int i = 0; i < num_all_reduces - 1; ++i) { |
493 | mlir::TF::DTensorAllReduceOp all_reduce = all_reduce_group[i]; |
494 | all_reduce->moveBefore(final_all_reduce); |
495 | } |
496 | auto merge_result = MergeAllReduceGroup(all_reduce_group); |
497 | if (merge_result.failed()) return merge_result; |
498 | } |
499 | |
500 | return mlir::success(); |
501 | } |
502 | |
// Returns true if the two group assignments are the same value, or are both
// constant and element-wise equal.
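// For example, two distinct tf.Const ops that both hold
// dense<[[0, 1], [2, 3]]> compare equal even though they are different
// mlir::Values.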
504 | bool same_group_assignments(mlir::Value group_assignment_a, |
505 | mlir::Value group_assignment_b) { |
506 | if (group_assignment_a == group_assignment_b) { |
507 | return true; |
508 | } |
509 | mlir::DenseIntElementsAttr attr_a; |
510 | if (!matchPattern(group_assignment_a, m_Constant(&attr_a))) { |
511 | return false; |
512 | } |
513 | mlir::DenseIntElementsAttr attr_b; |
514 | if (!matchPattern(group_assignment_b, m_Constant(&attr_b))) { |
515 | return false; |
516 | } |
517 | if (attr_a.getType().getShape() != attr_b.getType().getShape()) { |
518 | return false; |
519 | } |
520 | return std::equal(attr_a.begin(), attr_a.end(), attr_b.begin(), attr_b.end()); |
521 | } |
522 | |
523 | // Combines DTensorAllReduce ops of the same element type into as few groups as |
524 | // possible. Only ops with the same group assignment can be combined together. |
525 | mlir::LogicalResult CombineAllReduceOpsOfSameType( |
526 | mlir::tf_device::ClusterOp cluster, |
527 | const std::vector<mlir::TF::DTensorAllReduceOp>& all_reduces) { |
528 | // Maintain a list of seen group assignments, sorted by first appearance. |
529 | // Also find and store all-reduces by group assignment. Use the first |
530 | // mlir::Value that contains a certain group assignment to represent all the |
531 | // same group assignments. |
532 | std::vector<mlir::Value> group_assignments; |
533 | llvm::DenseMap<mlir::Value, std::vector<mlir::TF::DTensorAllReduceOp>> |
534 | all_reduces_by_group_assignment; |
535 | for (mlir::TF::DTensorAllReduceOp all_reduce : all_reduces) { |
536 | mlir::Value group_assignment = all_reduce.group_assignment(); |
537 | bool seen = false; |
538 | for (mlir::Value seen_group_assignment : group_assignments) { |
539 | if (same_group_assignments(group_assignment, seen_group_assignment)) { |
540 | group_assignment = seen_group_assignment; |
541 | seen = true; |
542 | break; |
543 | } |
544 | } |
545 | if (!seen) group_assignments.push_back(group_assignment); |
546 | all_reduces_by_group_assignment[group_assignment].push_back(all_reduce); |
547 | } |
548 | |
549 | // Combine all-reduces of the same group assignment in first-appearance order. |
550 | for (mlir::Value group_assignment : group_assignments) { |
551 | mlir::LogicalResult result = |
552 | CombineAllReduceOpsOfSameTypeAndGroupAssignment( |
553 | cluster, all_reduces_by_group_assignment[group_assignment]); |
554 | if (mlir::failed(result)) return result; |
555 | } |
556 | |
557 | return mlir::success(); |
558 | } |
559 | |
560 | struct DTensorAllReduceCombineOptimization |
561 | : public impl::DTensorAllReduceCombineOptimizationBase< |
562 | DTensorAllReduceCombineOptimization> { |
563 | void runOnOperation() override { |
564 | mlir::func::FuncOp function = getOperation(); |
565 | function.walk([&](mlir::tf_device::ClusterOp cluster) { |
566 | // Maintain a list of seen element types, sorted by first appearance. |
567 | // Also find and store all-reduces by element type. |
568 | std::vector<mlir::Type> elem_types; |
569 | llvm::DenseMap<mlir::Type, std::vector<mlir::TF::DTensorAllReduceOp>> |
570 | all_reduces_by_elem_type; |
571 | cluster.GetBody().walk([&](mlir::TF::DTensorAllReduceOp all_reduce) { |
572 | mlir::Type elem_type = all_reduce.getType().getElementType(); |
573 | if (std::find(elem_types.begin(), elem_types.end(), elem_type) == |
574 | elem_types.end()) { |
575 | elem_types.push_back(elem_type); |
576 | } |
577 | all_reduces_by_elem_type[elem_type].push_back(all_reduce); |
578 | }); |
579 | |
580 | // Combine all-reduces of the same element type in first-appearance order. |
581 | for (mlir::Type elem_type : elem_types) { |
582 | // Combine all-reduces for the same attribute reduce_op for the element |
583 | // type. |
584 | auto& all_reduces_for_elem_type = all_reduces_by_elem_type[elem_type]; |
585 | llvm::DenseMap<llvm::StringRef, |
586 | std::vector<mlir::TF::DTensorAllReduceOp>> |
587 | all_reduces_by_attr_reduce_op; |
588 | for (mlir::TF::DTensorAllReduceOp all_reduce : |
589 | all_reduces_for_elem_type) { |
590 | llvm::StringRef attr_reduce_op = all_reduce.reduce_op(); |
591 | all_reduces_by_attr_reduce_op[attr_reduce_op].push_back(all_reduce); |
592 | } |
593 | for (auto& all_reduces_to_merge : all_reduces_by_attr_reduce_op) { |
594 | if (mlir::failed(CombineAllReduceOpsOfSameType( |
595 | cluster, all_reduces_to_merge.second))) { |
596 | return signalPassFailure(); |
597 | } |
598 | } |
599 | } |
600 | }); |
601 | } |
602 | }; |
603 | |
604 | } // namespace |
605 | |
606 | std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> |
607 | CreateDTensorAllReduceCombineOptimization() { |
608 | return std::make_unique<DTensorAllReduceCombineOptimization>(); |
609 | } |
610 | |
611 | } // namespace dtensor |
612 | } // namespace tensorflow |
613 | |