1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include <string> |
17 | |
18 | #include "llvm/ADT/DenseMap.h" |
19 | #include "llvm/ADT/SmallPtrSet.h" |
20 | #include "llvm/Support/FormatVariadic.h" |
21 | #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project |
22 | #include "mlir/IR/Builders.h" // from @llvm-project |
23 | #include "mlir/IR/BuiltinAttributes.h" // from @llvm-project |
24 | #include "mlir/IR/Operation.h" // from @llvm-project |
25 | #include "mlir/IR/UseDefLists.h" // from @llvm-project |
26 | #include "mlir/IR/Value.h" // from @llvm-project |
27 | #include "mlir/Support/LogicalResult.h" // from @llvm-project |
28 | #include "mlir/Transforms/Passes.h" // from @llvm-project |
29 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" |
30 | #include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h" |
31 | #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h" |
32 | #include "tensorflow/dtensor/mlir/layout_parsing.h" |
33 | #include "tensorflow/dtensor/mlir/spmd_expander_common.h" |
34 | |
35 | namespace tensorflow { |
36 | namespace dtensor { |
37 | |
38 | namespace { |
39 | #define GEN_PASS_DEF_DTENSORALLREDUCESUMOPTIMIZATION |
40 | #include "tensorflow/dtensor/mlir/dtensor_passes.h.inc" |
41 | |
// Upper bound on fixed-point iterations of the optimization loop in
// DTensorAllReduceSumOptimization::runOnOperation().
constexpr int kMaxIteration = 10;
43 | |
44 | mlir::Value GetIdentitySkippedInputs(mlir::Value val) { |
45 | mlir::Value input = val; |
46 | while (auto identity = llvm::dyn_cast_or_null<mlir::TF::IdentityOp>( |
47 | input.getDefiningOp())) { |
48 | input = identity.input(); |
49 | } |
50 | return input; |
51 | } |
52 | |
53 | bool IsZeroConstant(mlir::Value val) { |
54 | auto const_input = llvm::dyn_cast_or_null<mlir::TF::ConstOp>( |
55 | GetIdentitySkippedInputs(val).getDefiningOp()); |
56 | if (!const_input) return false; |
57 | mlir::DenseFPElementsAttr attr = |
58 | const_input.value().dyn_cast<mlir::DenseFPElementsAttr>(); |
59 | // This uses the fact that constant Attrs becomes splats, so we only need to |
60 | // check one value. |
61 | if (!attr || !attr.isSplat()) return false; |
62 | return attr.getSplatValue<mlir::FloatAttr>().getValue().isZero(); |
63 | } |
64 | |
// Extracts inputs/ops required for optimization and checks whether the graph
// meets the criteria for the reduction + sum optimization. The criteria are:
// a) All DTensorAllReduce operations must be sum operations.
// b) Group assignments of the DTensorAllReduceOps must be identical and
//    statically known constants.
// c) Every operand of the Add op must be either a DTensorAllReduce operation
//    or a zero constant.
// On return, `*can_be_reordered` indicates whether the optimization applies;
// `reduction_inputs` holds the values feeding the reductions (plus any
// zero-constant operands) and `reduction_ops` the all-reduce ops themselves.
// Always returns success; disqualification is reported via
// `*can_be_reordered` = false.
mlir::LogicalResult CheckReduceAndSumOptimizationCriteria(
    mlir::Operation* add_op,
    llvm::SmallVectorImpl<mlir::Value>* reduction_inputs,
    llvm::SmallVectorImpl<mlir::TF::DTensorAllReduceOp>* reduction_ops,
    bool* can_be_reordered) {
  for (mlir::Value operand : add_op->getOperands()) {
    // Zero operands are neutral for addition and can be forwarded directly.
    if (IsZeroConstant(operand)) {
      reduction_inputs->emplace_back(operand);
      continue;
    }

    auto reduction_op = llvm::dyn_cast_or_null<mlir::TF::DTensorAllReduceOp>(
        operand.getDefiningOp());
    if (!reduction_op) {
      // Criterion (c) violated: operand is neither zero nor an all-reduce.
      *can_be_reordered = false;
      return mlir::success();
    }

    reduction_ops->emplace_back(reduction_op);
  }

  llvm::SmallDenseSet<mlir::Attribute> reduction_group_assignments;
  for (mlir::TF::DTensorAllReduceOp reduction : *reduction_ops) {
    // Criterion (a): only "Add" reductions commute with the local sum.
    if (reduction.reduce_op().str() != kReduceOpAdd) {
      *can_be_reordered = false;
      return mlir::success();
    }

    // Criterion (b), part 1: group assignment must be a statically known
    // constant so equality across the reductions can be compared below.
    mlir::DenseIntElementsAttr group_assignment;
    if (!matchPattern(reduction.group_assignment(),
                      m_Constant(&group_assignment))) {
      *can_be_reordered = false;
      return mlir::success();
    }

    reduction_group_assignments.insert(group_assignment);
    reduction_inputs->emplace_back(reduction.input());
  }

  // Criterion (b), part 2: all reductions share a single group assignment.
  *can_be_reordered = (reduction_group_assignments.size() == 1);
  return mlir::success();
}
112 | |
// Applies optimization that reorders AllReduce + Add operations.
// For example:
//   %3 = DTensorAllReduce(%0)
//   %4 = DTensorAllReduce(%1)
//   %5 = Add(%3, %4)
//
// Is transformed to:
//   %2 = Add(%0, %1)
//   %3 = DTensorAllReduce(%2)
//
// Therefore reducing the number of Reduction/cross device communication.
// Sets `*changed` to true when the rewrite fires; returns failure only when
// the layout of the original all-reduce cannot be extracted.
mlir::LogicalResult OptimizeAllReduceAndSum(mlir::Operation* op,
                                            bool* changed) {
  bool can_be_reordered;
  llvm::SmallVector<mlir::TF::DTensorAllReduceOp, 4> reduction_ops;
  llvm::SmallVector<mlir::Value, 4> reduction_op_inputs;
  if (mlir::failed(CheckReduceAndSumOptimizationCriteria(
          op, &reduction_op_inputs, &reduction_ops, &can_be_reordered)))
    return mlir::failure();

  if (!can_be_reordered || reduction_ops.empty()) return mlir::success();

  // Forward the inputs from the DTensorAllReduce to the add op. Calling
  // getOperand(i).getDefiningOp() since CheckReduceAndSumOptimizationCriteria
  // checks that each input is fed by a DTensorAllReduce or a Zero constant.
  for (int i = 0; i < op->getNumOperands(); ++i) {
    if (mlir::isa<mlir::TF::DTensorAllReduceOp>(
            op->getOperand(i).getDefiningOp()))
      op->setOperand(i, op->getOperand(i).getDefiningOp()->getOperand(0));
  }

  mlir::TF::DTensorAllReduceOp first_reduction_op = reduction_ops.front();

  // Invoke reduction operation on locally added tensor once.
  // From the check in `CheckReduceAndSumOptimizationCriteria()`, we know that
  // all reduction operations that are fused use the same group assignment
  // value.
  // 1) Get mlir::Value that represents group assignment used for reduction.
  mlir::Value group_assignment = first_reduction_op.group_assignment();

  // Create a single reduction operation that reduces the result of the
  // locally added tensor.
  mlir::OpBuilder builder(op);
  builder.setInsertionPointAfterValue(op->getResult(0));
  mlir::TF::DTensorAllReduceOp all_reduce =
      builder.create<mlir::TF::DTensorAllReduceOp>(
          op->getLoc(), op->getResult(0).getType(), op->getResult(0),
          group_assignment, builder.getStringAttr(std::string(kReduceOpAdd)),
          builder.getStringAttr(first_reduction_op.device_type()));

  const auto layout_or_status = ExtractSingleLayoutFromOp(first_reduction_op);
  if (!layout_or_status.ok())
    return first_reduction_op->emitOpError(llvm::formatv(
        "Malformed layout specification for DTensorAllReduce op found: {0}",
        layout_or_status.status().error_message()));

  if (!layout_or_status->has_value())
    return first_reduction_op->emitOpError(
        "DTensorAllReduce op must have layout specification.");

  // Set target layout that is equivalent to original DTensorReduction op in
  // the graph. This is used during later optimization passes.
  SetSingleLayoutOnOp(all_reduce, layout_or_status->value());

  // Replace usages of original tf.Add op with newly created output of
  // `all_reduce`. The new all-reduce must keep consuming the Add result
  // itself, hence it is excluded from the replacement.
  op->getResult(0).replaceAllUsesExcept(
      all_reduce.output(),
      llvm::SmallPtrSet<mlir::Operation*, 1>{all_reduce.getOperation()});

  // TODO(hongjunchoi, bfontain): Consider adding optimization for the case
  // when `tree` of Add operations with DTensorAllReduce op as inputs exists.
  // If an original DTensorAllReduce fed only the `op`, it is now dead and is
  // removed here.
  for (mlir::Operation* original_reduction_op : reduction_ops) {
    if (original_reduction_op->use_empty()) original_reduction_op->erase();
  }

  *changed = true;
  return mlir::success();
}
193 | |
194 | mlir::Value SkipIdentityLikeOpsOutputs(mlir::Value val) { |
195 | while (val.hasOneUse() && |
196 | llvm::isa<mlir::TF::CastOp, mlir::TF::ReshapeOp, mlir::TF::IdentityOp>( |
197 | *val.user_begin())) { |
198 | val = val.user_begin()->getResult(0); |
199 | } |
200 | return val; |
201 | } |
202 | |
203 | // TODO(hongjunchoi): Consider using tracing algorithm to virtually transform |
204 | // the IR and only apply optimizations when total number of DTensorAllReduce in |
205 | // the graph is reduced. |
206 | bool MayRemoveAllReduce(mlir::Operation* op) { |
207 | mlir::Value op_output = op->getResult(0); |
208 | mlir::Value value_after_identity_like_ops = |
209 | SkipIdentityLikeOpsOutputs(op_output); |
210 | if (value_after_identity_like_ops.hasOneUse() && |
211 | llvm::isa<mlir::TF::AddNOp, mlir::TF::AddV2Op, mlir::TF::AddOp>( |
212 | *value_after_identity_like_ops.user_begin())) |
213 | |
214 | return true; |
215 | |
216 | return false; |
217 | } |
218 | |
// Moves DTensorAllReduce ops after IdentityLike Operations if the operation
// is connected to an Add operation which may lead to optimization.
// For example:
//
// %0 = "tf.Const"() {value = dense<0> : tensor<2x64xi32>}
// %2 = "tf.Const"() {value = dense<0.0> : tensor<8192x916xbf16>}
// %4= "tf.DTensorAllReduce"(%2, %0) {reduce_op = "Add"}
// %5 = "tf.Cast"(%4){Truncate = false, device = ""}
// %6 = "tf.Identity"(%5){Truncate = false, device = ""}
// %7 = "tf.Const"() {value = dense<[916,8192]> : tensor<2xi32>}
// %8 = "tf.Reshape"(%6, %7)
//
// Becomes :
//
// %0 = "tf.Const"()
// %2 = "tf.Const"()
// %3 = "tf.Cast"(%2)
// %4 = "tf.Identity"(%3)
// %7 = "tf.Const"()
// %8 = "tf.Reshape"(%4, %7)
// %9 = "tf.DTensorAllReduce"(%8, %0) {reduce_op = "Add"}
void OptimizeIdentityLikeOps(mlir::Operation* op, bool* changed) {
  // `op` is an identity-like op; only rewrite when its first operand is
  // produced by a DTensorAllReduce.
  auto dtensor_all_reduce =
      llvm::dyn_cast_or_null<mlir::TF::DTensorAllReduceOp>(
          op->getOperand(0).getDefiningOp());
  if (!dtensor_all_reduce) return;
  // TODO(hongjunchoi, bfontain): Consider allowing pushing DTensorAllReduce
  // op with multiple usages if it can lead to performance optimization.
  if (!dtensor_all_reduce->hasOneUse()) return;
  if (!MayRemoveAllReduce(op)) return;

  // Swap the order of the two ops: the all-reduce now consumes the
  // identity-like op's result instead of producing its operand. Note the
  // statement order below matters — `input` must be captured before `op`'s
  // operand is rewritten.
  dtensor_all_reduce->moveAfter(op);
  mlir::Value input = dtensor_all_reduce.input();
  op->setOperand(0, input);

  mlir::Value op_output = op->getResult(0);
  dtensor_all_reduce.setOperand(0, op_output);
  // Propagate the identity-like op's result type (e.g. after Reshape/Cast)
  // onto the moved all-reduce's input/output values.
  dtensor_all_reduce.input().setType(op_output.getType());
  dtensor_all_reduce.output().setType(op_output.getType());

  // Route all remaining users of `op` through the moved all-reduce; the
  // all-reduce itself must keep consuming `op`'s raw result.
  llvm::SmallPtrSet<mlir::Operation*, 4> exceptions{dtensor_all_reduce};
  op_output.replaceAllUsesExcept(dtensor_all_reduce.output(), exceptions);
  *changed = true;
}
263 | |
264 | bool CheckWhileLoopOptimizationCriteria( |
265 | const int index, mlir::TF::WhileRegionOp while_op, mlir::Value while_output, |
266 | mlir::Operation** add_op, mlir::TF::DTensorAllReduceOp* all_reduce_op, |
267 | mlir::OpOperand** add_input) { |
268 | // Loop variant input that is being optimized should not be used in loop |
269 | // condition. |
270 | mlir::Value loop_condition_input = while_op.cond().getArgument(index); |
271 | if (!loop_condition_input.use_empty()) return false; |
272 | |
273 | // While loop output should be connected to add op. |
274 | // If operand to while loop body terminator if from Identity op, |
275 | // skip through the input identity operations. |
276 | mlir::Value output_value = GetIdentitySkippedInputs(while_output); |
277 | mlir::Operation* output_defining_op = output_value.getDefiningOp(); |
278 | if (!output_defining_op) return false; |
279 | |
280 | // TODO(hongjunchoi): Handle AddN op as well. |
281 | if (!output_defining_op || |
282 | !llvm::isa<mlir::TF::AddV2Op, mlir::TF::AddOp>(output_defining_op)) { |
283 | return false; |
284 | } |
285 | |
286 | // Input operand of add operation should be |
287 | // 1) DTensorAllReduce |
288 | // 2) from block argument of while loop |
289 | mlir::OpOperand& first_operand = output_defining_op->getOpOperand(0); |
290 | mlir::OpOperand& second_operand = output_defining_op->getOpOperand(1); |
291 | mlir::BlockArgument block_arg; |
292 | mlir::TF::DTensorAllReduceOp all_reduce = |
293 | llvm::dyn_cast_or_null<mlir::TF::DTensorAllReduceOp>( |
294 | first_operand.get().getDefiningOp()); |
295 | if (all_reduce) { |
296 | block_arg = second_operand.get().dyn_cast<mlir::BlockArgument>(); |
297 | *add_input = &second_operand; |
298 | } else { |
299 | all_reduce = llvm::dyn_cast_or_null<mlir::TF::DTensorAllReduceOp>( |
300 | second_operand.get().getDefiningOp()); |
301 | block_arg = first_operand.get().dyn_cast<mlir::BlockArgument>(); |
302 | *add_input = &first_operand; |
303 | } |
304 | if (!block_arg || !all_reduce) return false; |
305 | |
306 | // DTensorAllReduce should calculate sum across devices and group assignment |
307 | // must be statically known. |
308 | mlir::Operation* group_assignment = |
309 | all_reduce.group_assignment().getDefiningOp(); |
310 | if (!group_assignment || !llvm::isa<mlir::TF::ConstOp>(group_assignment)) |
311 | return false; |
312 | |
313 | if (all_reduce.reduce_op().str() != kReduceOpAdd) return false; |
314 | |
315 | // While loop block argument input connected to Add op should be |
316 | // connected to constant operations with zero value. |
317 | const int block_arg_index = block_arg.getArgNumber(); |
318 | mlir::OpOperand& while_input = while_op->getOpOperand(block_arg_index); |
319 | if (!IsZeroConstant(while_input.get())) return false; |
320 | |
321 | // TODO(hongjunchoi): Handle the case when input is from DTensorAllReduce op. |
322 | // If group assignment is the same, then the input DTensorAllReduce op can |
323 | // also be optimized away. |
324 | |
325 | *add_op = output_defining_op; |
326 | *all_reduce_op = all_reduce; |
327 | return true; |
328 | } |
329 | |
// Extracts out DTensorAllReduce operation from while op if
// a) While op contains DTensorAllReduce op followed by an Add Operation
// b) Remaining operand of Add operation is a loop variant input of the while
//    operation with zero initial value.
//
// For example:
//
// %0 = "tf.Const"() {value = dense<0> : tensor<2x64xi32>}
// %2 = "tf.Const"() {value = dense<0.0> : tensor<8192x916xbf16>}
// WhileRegionOp(%2) {
//   %0 = "tf.A"(%2)
//   "tf.Yield"(%0)
// }, {
// ^bb0(%barg0: tensor<8192x916xbf16>):
//   ...
//   %0 = "tf.Const"()
//   %1 = "tf.Const"()
//   %2 = "tf.DTensorAllReduce"(%1, %0) {reduce_op = "Add"}
//   %3 = "tf.Add"(%2, %barg0)
//   "tf.Yield"(%3)
// })
//
// Becomes :
//
// %0 = "tf.Const"() {value = dense<0> : tensor<2x64xi32>}
// %2 = "tf.Const"() {value = dense<0.0> : tensor<8192x916xbf16>}
// %4 = WhileRegionOp(%2) {
//   %0 = "tf.A"(%2)
//   "tf.Yield"(%0)
// }, {
// ^bb0(%barg0: tensor<8192x916xbf16>):
//   ...
//   %0 = "tf.Const"()
//   %1 = "tf.Const"()
//   %3 = "tf.Add"(%1, %barg0)
//   "tf.Yield"(%3)
// })
// "tf.DTensorAllReduce"(%4, %0) {reduce_op = "Add"}
//
// Sets `*changed` on success; returns failure only when the layout of the
// original all-reduce cannot be extracted.
mlir::LogicalResult ExtractAllReduceFromWhileOp(
    const int output_index, mlir::TF::DTensorAllReduceOp all_reduce,
    mlir::TF::WhileRegionOp while_op, mlir::OpOperand& add_input,
    mlir::Operation* add_op, bool* changed) {
  // Bypass the in-loop all-reduce: the Add now consumes the all-reduce's
  // input directly. `add_input` is the block-argument operand, so the *other*
  // Add operand is the one to replace.
  mlir::Value all_reduce_input = all_reduce.input();
  const int replacement_add_input_index =
      add_input.getOperandNumber() == 0 ? 1 : 0;
  add_op->setOperand(replacement_add_input_index, all_reduce_input);

  mlir::OpBuilder builder(while_op);
  builder.setInsertionPointAfter(while_op);

  // The group assignment constant is defined inside the loop body; clone it
  // after the while op so the hoisted all-reduce can reference it.
  mlir::Value while_output = while_op.getResult(output_index);
  mlir::Operation* group_assignment_const =
      all_reduce.group_assignment().getDefiningOp();
  mlir::Operation* cloned_group_assignment =
      builder.clone(*group_assignment_const);

  // Create a single reduction operation that reduces the result of the
  // locally added tensor.
  auto new_all_reduce = builder.create<mlir::TF::DTensorAllReduceOp>(
      all_reduce.getLoc(), while_output.getType(), while_output,
      cloned_group_assignment->getResult(0),
      builder.getStringAttr(std::string(kReduceOpAdd)),
      builder.getStringAttr(all_reduce.device_type()));

  const auto layout_or_status = ExtractSingleLayoutFromOp(all_reduce);
  if (!layout_or_status.ok())
    return all_reduce->emitOpError(llvm::formatv(
        "Malformed layout specification for DTensorAllReduce op found: {0}",
        layout_or_status.status().error_message()));

  if (!layout_or_status->has_value())
    return all_reduce->emitOpError(
        "DTensorAllReduce op must have layout specification.");

  // Set target layout that is equivalent to original DTensorReduction op in
  // the graph. This is used during later optimization passes.
  SetSingleLayoutOnOp(new_all_reduce, layout_or_status->value());

  // All users of the while output now read the reduced value; the hoisted
  // all-reduce itself keeps consuming the raw while output.
  llvm::SmallPtrSet<mlir::Operation*, 4> exceptions;
  exceptions.insert(new_all_reduce.getOperation());
  while_output.replaceAllUsesExcept(new_all_reduce.output(), exceptions);

  // The in-loop all-reduce should now be dead (its only user was the Add).
  if (all_reduce.use_empty()) all_reduce.erase();

  *changed = true;
  return mlir::success();
}
418 | |
419 | mlir::LogicalResult OptimizeWhileLoopLazyAllReduce( |
420 | mlir::TF::WhileRegionOp while_op, bool* changed) { |
421 | mlir::Operation* while_body_terminator = |
422 | while_op.body().front().getTerminator(); |
423 | for (const auto& data : |
424 | llvm::enumerate(while_body_terminator->getOpOperands())) { |
425 | const int index = data.index(); |
426 | mlir::OpOperand& operand = data.value(); |
427 | |
428 | mlir::Operation* add_op = nullptr; |
429 | mlir::TF::DTensorAllReduceOp all_reduce; |
430 | mlir::OpOperand* add_input = nullptr; |
431 | if (!CheckWhileLoopOptimizationCriteria(index, while_op, operand.get(), |
432 | &add_op, &all_reduce, &add_input)) |
433 | continue; |
434 | |
435 | // Perform while loop lazy all reduce optimization. |
436 | if (mlir::failed(ExtractAllReduceFromWhileOp(index, all_reduce, while_op, |
437 | *add_input, add_op, changed))) |
438 | return mlir::failure(); |
439 | } |
440 | |
441 | return mlir::success(); |
442 | } |
443 | |
444 | mlir::LogicalResult ApplyOptimization( |
445 | mlir::func::FuncOp function, |
446 | const llvm::SmallVectorImpl<mlir::Operation*>& identity_like_ops, |
447 | const llvm::SmallVectorImpl<mlir::TF::WhileRegionOp>& while_ops, |
448 | const llvm::SmallVectorImpl<mlir::Operation*>& add_ops, bool* changed) { |
449 | // Collect and fold the reduction operations within the function. |
450 | for (mlir::Operation* add_op : add_ops) |
451 | if (mlir::failed(OptimizeAllReduceAndSum(add_op, changed))) |
452 | return mlir::failure(); |
453 | |
454 | for (mlir::Operation* op : identity_like_ops) |
455 | OptimizeIdentityLikeOps(op, changed); |
456 | |
457 | for (mlir::TF::WhileRegionOp op : while_ops) |
458 | if (mlir::failed(OptimizeWhileLoopLazyAllReduce(op, changed))) |
459 | return mlir::failure(); |
460 | |
461 | return mlir::success(); |
462 | } |
463 | |
464 | // Finds all potential ops that could lead to all reduce optimizations. Those |
465 | // are: |
466 | // a) Identity like ops (e.g. Identity/Reshape/Cast) ops. |
467 | // b) WhileRegion op |
468 | // c) Add operations. |
469 | void CollectOptimizationCandidates( |
470 | mlir::func::FuncOp func, |
471 | llvm::SmallVectorImpl<mlir::Operation*>* identity_like_ops, |
472 | llvm::SmallVectorImpl<mlir::Operation*>* add_ops, |
473 | llvm::SmallVectorImpl<mlir::TF::WhileRegionOp>* while_ops) { |
474 | func.walk([&](mlir::Operation* op) { |
475 | if (llvm::isa<mlir::TF::IdentityOp, mlir::TF::CastOp, mlir::TF::ReshapeOp>( |
476 | op)) |
477 | identity_like_ops->emplace_back(op); |
478 | |
479 | if (auto while_op = llvm::dyn_cast<mlir::TF::WhileRegionOp>(op)) |
480 | while_ops->emplace_back(while_op); |
481 | |
482 | if (llvm::isa<mlir::TF::AddOp, mlir::TF::AddV2Op, mlir::TF::AddNOp>(op)) |
483 | add_ops->emplace_back(op); |
484 | }); |
485 | } |
486 | |
487 | // MLIR pass that folds constants that can be removed or deduplicated away. |
488 | struct DTensorAllReduceSumOptimization |
489 | : public impl::DTensorAllReduceSumOptimizationBase< |
490 | DTensorAllReduceSumOptimization> { |
491 | void runOnOperation() override { |
492 | mlir::func::FuncOp function = getOperation(); |
493 | bool changed = true; |
494 | int iteration = 0; |
495 | |
496 | llvm::SmallVector<mlir::Operation*, 4> identity_like_ops; |
497 | llvm::SmallVector<mlir::Operation*, 4> add_ops; |
498 | llvm::SmallVector<mlir::TF::WhileRegionOp, 4> while_ops; |
499 | CollectOptimizationCandidates(function, &identity_like_ops, &add_ops, |
500 | &while_ops); |
501 | bool is_optimized = false; |
502 | while (changed && iteration < kMaxIteration) { |
503 | changed = false; |
504 | if (mlir::failed(ApplyOptimization(function, identity_like_ops, while_ops, |
505 | add_ops, &changed))) |
506 | return signalPassFailure(); |
507 | iteration++; |
508 | if (changed) is_optimized = true; |
509 | } |
510 | } |
511 | }; |
512 | |
513 | } // namespace |
514 | |
// Factory for the pass registration machinery: returns a fresh instance of
// the all-reduce + sum optimization pass, which runs on func.func ops.
std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
CreateDTensorAllReduceSumOptimization() {
  return std::make_unique<DTensorAllReduceSumOptimization>();
}
519 | |
520 | } // namespace dtensor |
521 | } // namespace tensorflow |
522 | |