1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "absl/strings/string_view.h" |
17 | #include "llvm/Support/FormatVariadic.h" |
18 | #include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project |
19 | #include "mlir/IR/BuiltinTypes.h" // from @llvm-project |
20 | #include "mlir/IR/Visitors.h" // from @llvm-project |
21 | #include "mlir/Support/LLVM.h" // from @llvm-project |
22 | #include "mlir/Support/LogicalResult.h" // from @llvm-project |
23 | #include "mlir/Transforms/Passes.h" // from @llvm-project |
24 | #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" |
25 | #include "tensorflow/dtensor/cc/dtensor_utils.h" |
26 | #include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h" |
27 | #include "tensorflow/dtensor/mlir/ir/tf_dtensor.h" |
28 | #include "tensorflow/dtensor/mlir/layout_parsing.h" |
29 | #include "tensorflow/dtensor/mlir/spmd_expander_common.h" |
30 | |
31 | namespace tensorflow { |
32 | namespace dtensor { |
33 | |
34 | namespace { |
35 | #define GEN_PASS_DEF_DTENSORMIXEDPRECISIONREDUCE |
36 | #include "tensorflow/dtensor/mlir/dtensor_passes.h.inc" |
37 | |
38 | // Extracts the reduction group size from the group_assignment operand of the |
39 | // reduce op. group_assignment is a 2-dimensional array where each element is |
40 | // the list of devices that are a part of the same reduction group. |
41 | template <class ReduceOpType> |
42 | mlir::LogicalResult GetAllReduceGroupSize(ReduceOpType reduce_op, |
43 | int32* group_size) { |
44 | mlir::DenseIntElementsAttr group_assignment_attr; |
45 | if (!matchPattern(reduce_op.group_assignment(), |
46 | m_Constant(&group_assignment_attr))) |
47 | return mlir::emitError(reduce_op.getLoc(), |
48 | "group_assigment must be a constant." ); |
49 | if (group_assignment_attr.getType().getRank() != 2) |
50 | return mlir::emitError(reduce_op.getLoc(), |
51 | "group_assignment should have two dimensions." ); |
52 | |
53 | *group_size = group_assignment_attr.getType().getShape()[1]; |
54 | return mlir::success(); |
55 | } |
56 | |
57 | // For large enough reduction groups, we compute reductions in a higher |
58 | // precision type to ensure accuracy is not lost with sequential addition |
59 | // of large numbers in a lower precision type. If the given reduce op meets the |
60 | // following criteria: |
61 | // - the tensors being reduced are of type bfloat16, |
62 | // - the reduction group is at least as large as the configurable env var |
63 | // DTENSOR_REDUCE_IN_BFLOAT16_MAX_GROUP_SIZE, |
64 | // then the tensors are upcasted to float32 for the reduction before being |
65 | // downcasted again. |
66 | template <class ReduceOpType> |
67 | mlir::LogicalResult MaybeUpcastForReduction(ReduceOpType reduce_op, |
68 | bool* changed) { |
69 | const mlir::RankedTensorType& input_type = |
70 | reduce_op.input().getType().template dyn_cast<mlir::RankedTensorType>(); |
71 | if (!input_type.getElementType().isBF16()) { |
72 | // Upcast only applies for bfloat16 input. |
73 | return mlir::success(); |
74 | } |
75 | |
76 | mlir::OpBuilder builder(reduce_op); |
77 | const mlir::Location loc = reduce_op.getLoc(); |
78 | |
79 | int32 group_size; |
80 | if (mlir::failed(GetAllReduceGroupSize(reduce_op, &group_size))) |
81 | return mlir::failure(); |
82 | if (group_size <= ReduceInBfloat16MaxGroupSize()) |
83 | // Reduce group size is not sufficient, so we do not modify the ops. |
84 | return mlir::success(); |
85 | |
86 | const auto reduce_layout = ExtractRequiredSingleLayoutFromOp(reduce_op); |
87 | if (!reduce_layout.ok()) |
88 | return reduce_op.emitOpError(llvm::formatv( |
89 | "Malformed layout specification for DTensor reduce op found: {0}" , |
90 | reduce_layout.status().error_message())); |
91 | |
92 | // The original output tensor type that would have been used by all users of |
93 | // the reduce op. |
94 | const mlir::RankedTensorType& output_type = |
95 | reduce_op.output().getType().template dyn_cast<mlir::RankedTensorType>(); |
96 | |
97 | mlir::TF::CastOp upcast = builder.create<mlir::TF::CastOp>( |
98 | loc, |
99 | mlir::RankedTensorType::get(input_type.getShape(), builder.getF32Type()), |
100 | reduce_op.input()); |
101 | reduce_op->setOperand(0, upcast.y()); |
102 | reduce_op.output().setType(upcast.y().getType()); |
103 | |
104 | builder.setInsertionPointAfter(reduce_op); |
105 | mlir::TF::CastOp downcast = builder.create<mlir::TF::CastOp>( |
106 | loc, |
107 | mlir::RankedTensorType::get(output_type.getShape(), |
108 | output_type.getElementType()), |
109 | reduce_op); |
110 | // Match the layout of the downcast with the reduce op, this is required for |
111 | // the later passes. |
112 | SetSingleLayoutOnOp(downcast, *reduce_layout); |
113 | reduce_op.output().replaceAllUsesExcept(downcast.y(), downcast); |
114 | |
115 | *changed = true; |
116 | return mlir::success(); |
117 | } |
118 | |
119 | template <class ReduceOpType> |
120 | mlir::LogicalResult TryMixedPrecisionReduce(mlir::func::FuncOp function, |
121 | absl::string_view opName) { |
122 | int32_t reduceOpsCounter = 0; |
123 | int32_t changedReduceOpsCounter = 0; |
124 | |
125 | mlir::WalkResult walk_result = function.walk([&](ReduceOpType reduce_op) { |
126 | if (reduce_op.reduce_op().str() == kReduceOpAdd) { |
127 | reduceOpsCounter += 1; |
128 | bool changed = false; |
129 | if (mlir::failed(MaybeUpcastForReduction(reduce_op, &changed))) |
130 | return mlir::WalkResult::interrupt(); |
131 | if (changed) changedReduceOpsCounter += 1; |
132 | } |
133 | return mlir::WalkResult::advance(); |
134 | }); |
135 | if (walk_result.wasInterrupted()) return mlir::failure(); |
136 | |
137 | VLOG(2) << "Applied mixed precision to " << changedReduceOpsCounter << " of " |
138 | << reduceOpsCounter << " Add " << opName << " ops." ; |
139 | |
140 | return mlir::success(); |
141 | } |
142 | |
143 | // MLIR pass that enables tensor upcasting within mixed-precision reduction. |
144 | struct DTensorMixedPrecisionReducePass |
145 | : public impl::DTensorMixedPrecisionReduceBase< |
146 | DTensorMixedPrecisionReducePass> { |
147 | void runOnOperation() override { |
148 | mlir::func::FuncOp function = getOperation(); |
149 | |
150 | if (mlir::failed(TryMixedPrecisionReduce<mlir::TF::DTensorAllReduceOp>( |
151 | function, "DTensorAllReduce" ))) |
152 | return signalPassFailure(); |
153 | if (mlir::failed(TryMixedPrecisionReduce<mlir::TF::DTensorReduceScatterOp>( |
154 | function, "DTensorReduceScatter" ))) |
155 | return signalPassFailure(); |
156 | } |
157 | }; |
158 | |
159 | } // namespace |
160 | |
161 | std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>> |
162 | CreateDTensorMixedPrecisionReducePass() { |
163 | return std::make_unique<DTensorMixedPrecisionReducePass>(); |
164 | } |
165 | |
166 | } // namespace dtensor |
167 | } // namespace tensorflow |
168 | |