1/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "absl/strings/string_view.h"
17#include "llvm/Support/FormatVariadic.h"
18#include "mlir/Dialect/Func/IR/FuncOps.h" // from @llvm-project
19#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
20#include "mlir/IR/Visitors.h" // from @llvm-project
21#include "mlir/Support/LLVM.h" // from @llvm-project
22#include "mlir/Support/LogicalResult.h" // from @llvm-project
23#include "mlir/Transforms/Passes.h" // from @llvm-project
24#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
25#include "tensorflow/dtensor/cc/dtensor_utils.h"
26#include "tensorflow/dtensor/mlir/dtensor_mlir_passes.h"
27#include "tensorflow/dtensor/mlir/ir/tf_dtensor.h"
28#include "tensorflow/dtensor/mlir/layout_parsing.h"
29#include "tensorflow/dtensor/mlir/spmd_expander_common.h"
30
31namespace tensorflow {
32namespace dtensor {
33
34namespace {
35#define GEN_PASS_DEF_DTENSORMIXEDPRECISIONREDUCE
36#include "tensorflow/dtensor/mlir/dtensor_passes.h.inc"
37
38// Extracts the reduction group size from the group_assignment operand of the
39// reduce op. group_assignment is a 2-dimensional array where each element is
40// the list of devices that are a part of the same reduction group.
41template <class ReduceOpType>
42mlir::LogicalResult GetAllReduceGroupSize(ReduceOpType reduce_op,
43 int32* group_size) {
44 mlir::DenseIntElementsAttr group_assignment_attr;
45 if (!matchPattern(reduce_op.group_assignment(),
46 m_Constant(&group_assignment_attr)))
47 return mlir::emitError(reduce_op.getLoc(),
48 "group_assigment must be a constant.");
49 if (group_assignment_attr.getType().getRank() != 2)
50 return mlir::emitError(reduce_op.getLoc(),
51 "group_assignment should have two dimensions.");
52
53 *group_size = group_assignment_attr.getType().getShape()[1];
54 return mlir::success();
55}
56
57// For large enough reduction groups, we compute reductions in a higher
58// precision type to ensure accuracy is not lost with sequential addition
59// of large numbers in a lower precision type. If the given reduce op meets the
60// following criteria:
61// - the tensors being reduced are of type bfloat16,
62// - the reduction group is at least as large as the configurable env var
63// DTENSOR_REDUCE_IN_BFLOAT16_MAX_GROUP_SIZE,
64// then the tensors are upcasted to float32 for the reduction before being
65// downcasted again.
66template <class ReduceOpType>
67mlir::LogicalResult MaybeUpcastForReduction(ReduceOpType reduce_op,
68 bool* changed) {
69 const mlir::RankedTensorType& input_type =
70 reduce_op.input().getType().template dyn_cast<mlir::RankedTensorType>();
71 if (!input_type.getElementType().isBF16()) {
72 // Upcast only applies for bfloat16 input.
73 return mlir::success();
74 }
75
76 mlir::OpBuilder builder(reduce_op);
77 const mlir::Location loc = reduce_op.getLoc();
78
79 int32 group_size;
80 if (mlir::failed(GetAllReduceGroupSize(reduce_op, &group_size)))
81 return mlir::failure();
82 if (group_size <= ReduceInBfloat16MaxGroupSize())
83 // Reduce group size is not sufficient, so we do not modify the ops.
84 return mlir::success();
85
86 const auto reduce_layout = ExtractRequiredSingleLayoutFromOp(reduce_op);
87 if (!reduce_layout.ok())
88 return reduce_op.emitOpError(llvm::formatv(
89 "Malformed layout specification for DTensor reduce op found: {0}",
90 reduce_layout.status().error_message()));
91
92 // The original output tensor type that would have been used by all users of
93 // the reduce op.
94 const mlir::RankedTensorType& output_type =
95 reduce_op.output().getType().template dyn_cast<mlir::RankedTensorType>();
96
97 mlir::TF::CastOp upcast = builder.create<mlir::TF::CastOp>(
98 loc,
99 mlir::RankedTensorType::get(input_type.getShape(), builder.getF32Type()),
100 reduce_op.input());
101 reduce_op->setOperand(0, upcast.y());
102 reduce_op.output().setType(upcast.y().getType());
103
104 builder.setInsertionPointAfter(reduce_op);
105 mlir::TF::CastOp downcast = builder.create<mlir::TF::CastOp>(
106 loc,
107 mlir::RankedTensorType::get(output_type.getShape(),
108 output_type.getElementType()),
109 reduce_op);
110 // Match the layout of the downcast with the reduce op, this is required for
111 // the later passes.
112 SetSingleLayoutOnOp(downcast, *reduce_layout);
113 reduce_op.output().replaceAllUsesExcept(downcast.y(), downcast);
114
115 *changed = true;
116 return mlir::success();
117}
118
119template <class ReduceOpType>
120mlir::LogicalResult TryMixedPrecisionReduce(mlir::func::FuncOp function,
121 absl::string_view opName) {
122 int32_t reduceOpsCounter = 0;
123 int32_t changedReduceOpsCounter = 0;
124
125 mlir::WalkResult walk_result = function.walk([&](ReduceOpType reduce_op) {
126 if (reduce_op.reduce_op().str() == kReduceOpAdd) {
127 reduceOpsCounter += 1;
128 bool changed = false;
129 if (mlir::failed(MaybeUpcastForReduction(reduce_op, &changed)))
130 return mlir::WalkResult::interrupt();
131 if (changed) changedReduceOpsCounter += 1;
132 }
133 return mlir::WalkResult::advance();
134 });
135 if (walk_result.wasInterrupted()) return mlir::failure();
136
137 VLOG(2) << "Applied mixed precision to " << changedReduceOpsCounter << " of "
138 << reduceOpsCounter << " Add " << opName << " ops.";
139
140 return mlir::success();
141}
142
143// MLIR pass that enables tensor upcasting within mixed-precision reduction.
144struct DTensorMixedPrecisionReducePass
145 : public impl::DTensorMixedPrecisionReduceBase<
146 DTensorMixedPrecisionReducePass> {
147 void runOnOperation() override {
148 mlir::func::FuncOp function = getOperation();
149
150 if (mlir::failed(TryMixedPrecisionReduce<mlir::TF::DTensorAllReduceOp>(
151 function, "DTensorAllReduce")))
152 return signalPassFailure();
153 if (mlir::failed(TryMixedPrecisionReduce<mlir::TF::DTensorReduceScatterOp>(
154 function, "DTensorReduceScatter")))
155 return signalPassFailure();
156 }
157};
158
159} // namespace
160
161std::unique_ptr<mlir::OperationPass<mlir::func::FuncOp>>
162CreateDTensorMixedPrecisionReducePass() {
163 return std::make_unique<DTensorMixedPrecisionReducePass>();
164}
165
166} // namespace dtensor
167} // namespace tensorflow
168