/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <string>

#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/Support/FormatVariadic.h"
#include "mlir/Dialect/Func/IR/FuncOps.h"  // from @llvm-project
#include "mlir/IR/Builders.h"  // from @llvm-project
#include "mlir/IR/BuiltinAttributes.h"  // from @llvm-project
#include "mlir/IR/Operation.h"  // from @llvm-project
#include "mlir/IR/UseDefLists.h"  // from @llvm-project
#include "mlir/IR/Value.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/Passes.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
#include "tensorflow/dtensor/mlir/layout_parsing.h"
#include "tensorflow/dtensor/mlir/spmd_expander_common.h"

namespace tensorflow {
namespace dtensor {

namespace {
#define GEN_PASS_DEF_DTENSORALLREDUCESUMOPTIMIZATION
#include "tensorflow/dtensor/mlir/dtensor_passes.h.inc"

constexpr int kMaxIteration = 10;

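// Returns the value feeding `val`, skipping over any chain of tf.Identity ops.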
mlir::Value GetIdentitySkippedInputs(mlir::Value val) {
  mlir::Value input = val;
  while (auto identity = llvm::dyn_cast_or_null<mlir::TF::IdentityOp>(
             input.getDefiningOp())) {
    input = identity.input();
  }
  return input;
}

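// Returns true if `val` is (possibly behind a chain of tf.Identity ops) a
// splat floating-point constant whose value is zero.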
bool IsZeroConstant(mlir::Value val) {
  auto const_input = llvm::dyn_cast_or_null<mlir::TF::ConstOp>(
      GetIdentitySkippedInputs(val).getDefiningOp());
  if (!const_input) return false;
  mlir::DenseFPElementsAttr attr =
      const_input.value().dyn_cast<mlir::DenseFPElementsAttr>();
  // Zero-filled constant attributes are stored as splats, so we only need to
  // check a single value.
  if (!attr || !attr.isSplat()) return false;
  return attr.getSplatValue<mlir::FloatAttr>().getValue().isZero();
}

// Extracts the inputs/ops required for the optimization and checks whether the
// graph meets the criteria for the reduction + sum optimization. The criteria
// are:
// a) All DTensorAllReduce operations must be sum operations.
// b) All DTensorAllReduce operations must use the same statically known group
//    assignment.
// c) All operands of the Add op must be DTensorAllReduce operations or zero
//    constants.
mlir::LogicalResult CheckReduceAndSumOptimizationCriteria(
    mlir::Operation* add_op,
    llvm::SmallVectorImpl<mlir::Value>* reduction_inputs,
    llvm::SmallVectorImpl<mlir::TF::DTensorAllReduceOp>* reduction_ops,
    bool* can_be_reordered) {
  for (mlir::Value operand : add_op->getOperands()) {
    if (IsZeroConstant(operand)) {
      reduction_inputs->emplace_back(operand);
      continue;
    }

    auto reduction_op = llvm::dyn_cast_or_null<mlir::TF::DTensorAllReduceOp>(
        operand.getDefiningOp());
    if (!reduction_op) {
      *can_be_reordered = false;
      return mlir::success();
    }

    reduction_ops->emplace_back(reduction_op);
  }

  llvm::SmallDenseSet<mlir::Attribute> reduction_group_assignments;
  for (mlir::TF::DTensorAllReduceOp reduction : *reduction_ops) {
    if (reduction.reduce_op().str() != kReduceOpAdd) {
      *can_be_reordered = false;
      return mlir::success();
    }

    mlir::DenseIntElementsAttr group_assignment;
    if (!matchPattern(reduction.group_assignment(),
                      m_Constant(&group_assignment))) {
      *can_be_reordered = false;
      return mlir::success();
    }

    reduction_group_assignments.insert(group_assignment);
    reduction_inputs->emplace_back(reduction.input());
  }

  *can_be_reordered = (reduction_group_assignments.size() == 1);
  return mlir::success();
}

// Applies an optimization that reorders AllReduce + Add operations.
// For example:
// %3 = DTensorAllReduce(%0)
// %4 = DTensorAllReduce(%1)
// %5 = Add(%3, %4)
//
// is transformed to:
// %2 = Add(%0, %1)
// %3 = DTensorAllReduce(%2)
//
// thereby reducing the number of reductions / cross-device communications.
mlir::LogicalResult OptimizeAllReduceAndSum(mlir::Operation* op,
                                            bool* changed) {
  bool can_be_reordered;
  llvm::SmallVector<mlir::TF::DTensorAllReduceOp, 4> reduction_ops;
  llvm::SmallVector<mlir::Value, 4> reduction_op_inputs;
  if (mlir::failed(CheckReduceAndSumOptimizationCriteria(
          op, &reduction_op_inputs, &reduction_ops, &can_be_reordered)))
    return mlir::failure();

  if (!can_be_reordered || reduction_ops.empty()) return mlir::success();

  // Forward the inputs of the DTensorAllReduce ops to the add op. Calling
  // getOperand(i).getDefiningOp() is safe since
  // CheckReduceAndSumOptimizationCriteria checks that each input is fed by a
  // DTensorAllReduce or a zero constant.
  for (int i = 0; i < op->getNumOperands(); ++i) {
    if (mlir::isa<mlir::TF::DTensorAllReduceOp>(
            op->getOperand(i).getDefiningOp()))
      op->setOperand(i, op->getOperand(i).getDefiningOp()->getOperand(0));
  }

  mlir::TF::DTensorAllReduceOp first_reduction_op = reduction_ops.front();

  // Invoke the reduction operation on the locally added tensor once.
  // From `CheckReduceAndSumOptimizationCriteria()` above, we know that all
  // reduction operations being fused use the same group assignment value.
  // Get the mlir::Value that represents the group assignment used for the
  // reduction.
  mlir::Value group_assignment = first_reduction_op.group_assignment();

  // Create a single reduction operation that reduces the result of the locally
  // added tensor.
  mlir::OpBuilder builder(op);
  builder.setInsertionPointAfterValue(op->getResult(0));
  mlir::TF::DTensorAllReduceOp all_reduce =
      builder.create<mlir::TF::DTensorAllReduceOp>(
          op->getLoc(), op->getResult(0).getType(), op->getResult(0),
          group_assignment, builder.getStringAttr(std::string(kReduceOpAdd)),
          builder.getStringAttr(first_reduction_op.device_type()));

  const auto layout_or_status = ExtractSingleLayoutFromOp(first_reduction_op);
  if (!layout_or_status.ok())
    return first_reduction_op->emitOpError(llvm::formatv(
        "Malformed layout specification for DTensorAllReduce op found: {0}",
        layout_or_status.status().error_message()));

  if (!layout_or_status->has_value())
    return first_reduction_op->emitOpError(
        "DTensorAllReduce op must have layout specification.");

  // Set a target layout equivalent to the original DTensorAllReduce op in the
  // graph. This is used during later optimization passes.
  SetSingleLayoutOnOp(all_reduce, layout_or_status->value());

  // Replace usages of the original tf.Add op with the newly created output of
  // `all_reduce`.
  op->getResult(0).replaceAllUsesExcept(
      all_reduce.output(),
      llvm::SmallPtrSet<mlir::Operation*, 1>{all_reduce.getOperation()});

  // TODO(hongjunchoi, bfontain): Consider adding an optimization for the case
  // when a `tree` of Add operations with DTensorAllReduce ops as inputs
  // exists.
  // If a DTensorAllReduce op that fed the original `op` is no longer used,
  // remove it as well.
  for (mlir::Operation* original_reduction_op : reduction_ops) {
    if (original_reduction_op->use_empty()) original_reduction_op->erase();
  }

  *changed = true;
  return mlir::success();
}

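// Walks forward through identity-like ops (Cast/Reshape/Identity) that are the
// sole user of `val` and returns the output of the last such op.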
mlir::Value SkipIdentityLikeOpsOutputs(mlir::Value val) {
  while (val.hasOneUse() &&
         llvm::isa<mlir::TF::CastOp, mlir::TF::ReshapeOp,
                   mlir::TF::IdentityOp>(*val.user_begin())) {
    val = val.user_begin()->getResult(0);
  }
  return val;
}

// TODO(hongjunchoi): Consider using a tracing algorithm to virtually transform
// the IR and only apply optimizations when the total number of
// DTensorAllReduce ops in the graph is reduced.
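// Returns true if the result of `op`, after any chain of single-use
// identity-like ops, is consumed by a single Add/AddV2/AddN op. In that case,
// moving a DTensorAllReduce past `op` may allow it to be folded into that Add.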
bool MayRemoveAllReduce(mlir::Operation* op) {
  mlir::Value op_output = op->getResult(0);
  mlir::Value value_after_identity_like_ops =
      SkipIdentityLikeOpsOutputs(op_output);
  if (value_after_identity_like_ops.hasOneUse() &&
      llvm::isa<mlir::TF::AddNOp, mlir::TF::AddV2Op, mlir::TF::AddOp>(
          *value_after_identity_like_ops.user_begin()))
    return true;

  return false;
}

// Moves a DTensorAllReduce op after identity-like operations if the chain is
// connected to an Add operation, which may lead to further optimization.
// For example:
//
// %0 = "tf.Const"() {value = dense<0> : tensor<2x64xi32>}
// %2 = "tf.Const"() {value = dense<0.0> : tensor<8192x916xbf16>}
// %4 = "tf.DTensorAllReduce"(%2, %0) {reduce_op = "Add"}
// %5 = "tf.Cast"(%4) {Truncate = false, device = ""}
// %6 = "tf.Identity"(%5) {Truncate = false, device = ""}
// %7 = "tf.Const"() {value = dense<[916,8192]> : tensor<2xi32>}
// %8 = "tf.Reshape"(%6, %7)
//
// becomes:
//
// %0 = "tf.Const"()
// %2 = "tf.Const"()
// %3 = "tf.Cast"(%2)
// %4 = "tf.Identity"(%3)
// %7 = "tf.Const"()
// %8 = "tf.Reshape"(%4, %7)
// %9 = "tf.DTensorAllReduce"(%8, %0) {reduce_op = "Add"}
void OptimizeIdentityLikeOps(mlir::Operation* op, bool* changed) {
  auto dtensor_all_reduce =
      llvm::dyn_cast_or_null<mlir::TF::DTensorAllReduceOp>(
          op->getOperand(0).getDefiningOp());
  if (!dtensor_all_reduce) return;
  // TODO(hongjunchoi, bfontain): Consider allowing pushing DTensorAllReduce op
  // with multiple usages if it can lead to performance optimization.
  if (!dtensor_all_reduce->hasOneUse()) return;
  if (!MayRemoveAllReduce(op)) return;

  dtensor_all_reduce->moveAfter(op);
  mlir::Value input = dtensor_all_reduce.input();
  op->setOperand(0, input);

  mlir::Value op_output = op->getResult(0);
  dtensor_all_reduce.setOperand(0, op_output);
  dtensor_all_reduce.input().setType(op_output.getType());
  dtensor_all_reduce.output().setType(op_output.getType());

  llvm::SmallPtrSet<mlir::Operation*, 4> exceptions{dtensor_all_reduce};
  op_output.replaceAllUsesExcept(dtensor_all_reduce.output(), exceptions);
  *changed = true;
}

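// Checks whether the while-loop body output at `index` matches the pattern
// Add(DTensorAllReduce(...), block_arg), where the block argument is unused in
// the loop condition and its corresponding while-loop operand is a zero
// constant. On success, populates `add_op`, `all_reduce_op`, and `add_input`
// (the Add operand holding the block argument) and returns true.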
bool CheckWhileLoopOptimizationCriteria(
    const int index, mlir::TF::WhileRegionOp while_op, mlir::Value while_output,
    mlir::Operation** add_op, mlir::TF::DTensorAllReduceOp* all_reduce_op,
    mlir::OpOperand** add_input) {
  // The loop-variant input that is being optimized should not be used in the
  // loop condition.
  mlir::Value loop_condition_input = while_op.cond().getArgument(index);
  if (!loop_condition_input.use_empty()) return false;

  // The while-loop output should be connected to an Add op. If the operand of
  // the while-loop body terminator comes from Identity ops, skip through them.
  mlir::Value output_value = GetIdentitySkippedInputs(while_output);
  mlir::Operation* output_defining_op = output_value.getDefiningOp();
  if (!output_defining_op) return false;

  // TODO(hongjunchoi): Handle AddN op as well.
  if (!llvm::isa<mlir::TF::AddV2Op, mlir::TF::AddOp>(output_defining_op)) {
    return false;
  }

  // The operands of the Add operation should be
  // 1) a DTensorAllReduce op, and
  // 2) a block argument of the while loop.
  mlir::OpOperand& first_operand = output_defining_op->getOpOperand(0);
  mlir::OpOperand& second_operand = output_defining_op->getOpOperand(1);
  mlir::BlockArgument block_arg;
  mlir::TF::DTensorAllReduceOp all_reduce =
      llvm::dyn_cast_or_null<mlir::TF::DTensorAllReduceOp>(
          first_operand.get().getDefiningOp());
  if (all_reduce) {
    block_arg = second_operand.get().dyn_cast<mlir::BlockArgument>();
    *add_input = &second_operand;
  } else {
    all_reduce = llvm::dyn_cast_or_null<mlir::TF::DTensorAllReduceOp>(
        second_operand.get().getDefiningOp());
    block_arg = first_operand.get().dyn_cast<mlir::BlockArgument>();
    *add_input = &first_operand;
  }
  if (!block_arg || !all_reduce) return false;

  // The DTensorAllReduce should compute a sum across devices and its group
  // assignment must be statically known.
  mlir::Operation* group_assignment =
      all_reduce.group_assignment().getDefiningOp();
  if (!group_assignment || !llvm::isa<mlir::TF::ConstOp>(group_assignment))
    return false;

  if (all_reduce.reduce_op().str() != kReduceOpAdd) return false;

  // The while-loop operand corresponding to the block argument connected to
  // the Add op should be a constant with zero value.
  const int block_arg_index = block_arg.getArgNumber();
  mlir::OpOperand& while_input = while_op->getOpOperand(block_arg_index);
  if (!IsZeroConstant(while_input.get())) return false;

  // TODO(hongjunchoi): Handle the case when the input is from a
  // DTensorAllReduce op. If the group assignment is the same, then the input
  // DTensorAllReduce op can also be optimized away.

  *add_op = output_defining_op;
  *all_reduce_op = all_reduce;
  return true;
}

// Extracts a DTensorAllReduce operation out of a while op if
// a) the while op contains a DTensorAllReduce op followed by an Add operation,
//    and
// b) the remaining operand of the Add operation is a loop-variant input of the
//    while operation with zero initial value.
//
// For example:
//
// %0 = "tf.Const"() {value = dense<0> : tensor<2x64xi32>}
// %2 = "tf.Const"() {value = dense<0.0> : tensor<8192x916xbf16>}
// WhileRegionOp(%2) {
//   %0 = "tf.A"(%2)
//   "tf.Yield"(%0)
// }, {
// ^bb0(%barg0: tensor<8192x916xbf16>):
//   ...
//   %0 = "tf.Const"()
//   %1 = "tf.Const"()
//   %2 = "tf.DTensorAllReduce"(%1, %0) {reduce_op = "Add"}
//   %3 = "tf.Add"(%2, %barg0)
//   "tf.Yield"(%3)
// })
//
// becomes:
//
// %0 = "tf.Const"() {value = dense<0> : tensor<2x64xi32>}
// %2 = "tf.Const"() {value = dense<0.0> : tensor<8192x916xbf16>}
// %4 = WhileRegionOp(%2) {
//   %0 = "tf.A"(%2)
//   "tf.Yield"(%0)
// }, {
// ^bb0(%barg0: tensor<8192x916xbf16>):
//   ...
//   %0 = "tf.Const"()
//   %1 = "tf.Const"()
//   %3 = "tf.Add"(%1, %barg0)
//   "tf.Yield"(%3)
// })
// "tf.DTensorAllReduce"(%4, %0) {reduce_op = "Add"}
mlir::LogicalResult ExtractAllReduceFromWhileOp(
    const int output_index, mlir::TF::DTensorAllReduceOp all_reduce,
    mlir::TF::WhileRegionOp while_op, mlir::OpOperand& add_input,
    mlir::Operation* add_op, bool* changed) {
  // Replace the Add operand that was fed by the all-reduce with the
  // all-reduce's input, i.e. perform the addition on the unreduced tensor.
  mlir::Value all_reduce_input = all_reduce.input();
  const int replacement_add_input_index =
      add_input.getOperandNumber() == 0 ? 1 : 0;
  add_op->setOperand(replacement_add_input_index, all_reduce_input);

  mlir::OpBuilder builder(while_op);
  builder.setInsertionPointAfter(while_op);

  mlir::Value while_output = while_op.getResult(output_index);
  mlir::Operation* group_assignment_const =
      all_reduce.group_assignment().getDefiningOp();
  mlir::Operation* cloned_group_assignment =
      builder.clone(*group_assignment_const);

  // Create a single reduction operation that reduces the result of the locally
  // added tensor.
  auto new_all_reduce = builder.create<mlir::TF::DTensorAllReduceOp>(
      all_reduce.getLoc(), while_output.getType(), while_output,
      cloned_group_assignment->getResult(0),
      builder.getStringAttr(std::string(kReduceOpAdd)),
      builder.getStringAttr(all_reduce.device_type()));

  const auto layout_or_status = ExtractSingleLayoutFromOp(all_reduce);
  if (!layout_or_status.ok())
    return all_reduce->emitOpError(llvm::formatv(
        "Malformed layout specification for DTensorAllReduce op found: {0}",
        layout_or_status.status().error_message()));

  if (!layout_or_status->has_value())
    return all_reduce->emitOpError(
        "DTensorAllReduce op must have layout specification.");

  // Set a target layout equivalent to the original DTensorAllReduce op in the
  // graph. This is used during later optimization passes.
  SetSingleLayoutOnOp(new_all_reduce, layout_or_status->value());

  llvm::SmallPtrSet<mlir::Operation*, 4> exceptions;
  exceptions.insert(new_all_reduce.getOperation());
  while_output.replaceAllUsesExcept(new_all_reduce.output(), exceptions);

  if (all_reduce.use_empty()) all_reduce.erase();

  *changed = true;
  return mlir::success();
}

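// Scans the yielded values of the while-loop body and extracts a
// DTensorAllReduce out of the loop for every output that matches the lazy
// all-reduce pattern.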
mlir::LogicalResult OptimizeWhileLoopLazyAllReduce(
    mlir::TF::WhileRegionOp while_op, bool* changed) {
  mlir::Operation* while_body_terminator =
      while_op.body().front().getTerminator();
  for (const auto& data :
       llvm::enumerate(while_body_terminator->getOpOperands())) {
    const int index = data.index();
    mlir::OpOperand& operand = data.value();

    mlir::Operation* add_op = nullptr;
    mlir::TF::DTensorAllReduceOp all_reduce;
    mlir::OpOperand* add_input = nullptr;
    if (!CheckWhileLoopOptimizationCriteria(index, while_op, operand.get(),
                                            &add_op, &all_reduce, &add_input))
      continue;

    // Perform the while-loop lazy all-reduce optimization.
    if (mlir::failed(ExtractAllReduceFromWhileOp(index, all_reduce, while_op,
                                                 *add_input, add_op, changed)))
      return mlir::failure();
  }

  return mlir::success();
}

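// Runs the three optimizations (AllReduce + Add folding, moving all-reduces
// past identity-like ops, and extracting all-reduces out of while loops) once
// over the collected candidate ops.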
mlir::LogicalResult ApplyOptimization(
    mlir::func::FuncOp function,
    const llvm::SmallVectorImpl<mlir::Operation*>& identity_like_ops,
    const llvm::SmallVectorImpl<mlir::TF::WhileRegionOp>& while_ops,
    const llvm::SmallVectorImpl<mlir::Operation*>& add_ops, bool* changed) {
  // Fold the reduction operations into the Add ops within the function.
  for (mlir::Operation* add_op : add_ops)
    if (mlir::failed(OptimizeAllReduceAndSum(add_op, changed)))
      return mlir::failure();

  for (mlir::Operation* op : identity_like_ops)
    OptimizeIdentityLikeOps(op, changed);

  for (mlir::TF::WhileRegionOp op : while_ops)
    if (mlir::failed(OptimizeWhileLoopLazyAllReduce(op, changed)))
      return mlir::failure();

  return mlir::success();
}

// Finds all ops that could lead to all-reduce optimizations. These are:
// a) identity-like ops (e.g. Identity/Reshape/Cast),
// b) WhileRegion ops, and
// c) Add operations.
void CollectOptimizationCandidates(
    mlir::func::FuncOp func,
    llvm::SmallVectorImpl<mlir::Operation*>* identity_like_ops,
    llvm::SmallVectorImpl<mlir::Operation*>* add_ops,
    llvm::SmallVectorImpl<mlir::TF::WhileRegionOp>* while_ops) {
  func.walk([&](mlir::Operation* op) {
    if (llvm::isa<mlir::TF::IdentityOp, mlir::TF::CastOp, mlir::TF::ReshapeOp>(
            op))
      identity_like_ops->emplace_back(op);

    if (auto while_op = llvm::dyn_cast<mlir::TF::WhileRegionOp>(op))
      while_ops->emplace_back(while_op);

    if (llvm::isa<mlir::TF::AddOp, mlir::TF::AddV2Op, mlir::TF::AddNOp>(op))
      add_ops->emplace_back(op);
  });
}

// MLIR pass that reduces the number of DTensorAllReduce (sum) ops by folding
// them into Add operations, pushing them past identity-like ops, and hoisting
// them out of while loops.
struct DTensorAllReduceSumOptimization
    : public impl::DTensorAllReduceSumOptimizationBase<
          DTensorAllReduceSumOptimization> {
  void runOnOperation() override {
    mlir::func::FuncOp function = getOperation();
    bool changed = true;
    int iteration = 0;

    llvm::SmallVector<mlir::Operation*, 4> identity_like_ops;
    llvm::SmallVector<mlir::Operation*, 4> add_ops;
    llvm::SmallVector<mlir::TF::WhileRegionOp, 4> while_ops;
    CollectOptimizationCandidates(function, &identity_like_ops, &add_ops,
                                  &while_ops);
    bool is_optimized = false;
    while (changed && iteration < kMaxIteration) {
      changed = false;
      if (mlir::failed(ApplyOptimization(function, identity_like_ops,
                                         while_ops, add_ops, &changed)))
        return signalPassFailure();
      iteration++;
      if (changed) is_optimized = true;
    }
  }
};

}  // namespace

std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
CreateDTensorAllReduceSumOptimization() {
  return std::make_unique<DTensorAllReduceSumOptimization>();
}

}  // namespace dtensor
}  // namespace tensorflow