1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file bf16_legalize.cc
22 * \brief legalize bf16 type by adding cast_to_fp32
23 */
24
#include <tvm/runtime/registry.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/op.h>
#include <tvm/tir/transform.h>

#include <cmath>
#include <cstring>
#include <tuple>

#include "../../arith/ir_mutator_with_analyzer.h"
#include "../../arith/ir_visitor_with_analyzer.h"
35
36namespace tvm {
37namespace tir {
38
39using arith::Analyzer;
40using arith::IRMutatorWithAnalyzer;
41
42class BF16PromoteRewriter : public StmtExprMutator {
43 public:
44 BF16PromoteRewriter() {}
45
46 Stmt operator()(Stmt s) { return VisitStmt(s); }
47
48 PrimExpr VisitExpr_(const AddNode* op) final;
49 PrimExpr VisitExpr_(const SubNode* op) final;
50 PrimExpr VisitExpr_(const MulNode* op) final;
51 PrimExpr VisitExpr_(const DivNode* op) final;
52 PrimExpr VisitExpr_(const MinNode* op) final;
53 PrimExpr VisitExpr_(const MaxNode* op) final;
54 PrimExpr VisitExpr_(const LTNode* op) final;
55 PrimExpr VisitExpr_(const LENode* op) final;
56 PrimExpr VisitExpr_(const GTNode* op) final;
57 PrimExpr VisitExpr_(const GENode* op) final;
58 PrimExpr VisitExpr_(const CallNode* op) final;
59};
60
61#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC, NEEDCAST) \
62 PrimExpr BF16PromoteRewriter::VisitExpr_(const OP* op) { \
63 PrimExpr origin_a = this->VisitExpr(op->a); \
64 PrimExpr origin_b = this->VisitExpr(op->b); \
65 bool a_is_bfloat16 = origin_a->dtype.is_bfloat16(); \
66 bool b_is_bfloat16 = origin_b->dtype.is_bfloat16(); \
67 bool both_bfloat16 = a_is_bfloat16 && b_is_bfloat16; \
68 bool none_bfloat16 = !(a_is_bfloat16 || b_is_bfloat16); \
69 if (none_bfloat16) { \
70 return GetRef<PrimExpr>(op); \
71 } \
72 DataType float32_dtype(kDLFloat, 32, 1); \
73 PrimExpr float32_a = a_is_bfloat16 ? Cast(float32_dtype, origin_a) : origin_a; \
74 PrimExpr float32_b = b_is_bfloat16 ? Cast(float32_dtype, origin_b) : origin_b; \
75 PrimExpr result = FUNC(float32_a, float32_b); \
76 DataType bfloat16_dtype(kDLBfloat, 16, 1); \
77 bool do_cast = both_bfloat16 && NEEDCAST; \
78 return do_cast ? Cast(bfloat16_dtype, result) : result; \
79 }
80
81DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(AddNode, operator+, true)
82DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(SubNode, operator-, true)
83DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MulNode, operator*, true)
84DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(DivNode, div, true)
85DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min, true)
86DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max, true)
87DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LTNode, operator<, false)
88DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=, false)
89DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator>, false)
90DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=, false)
91
92PrimExpr BF16PromoteRewriter::VisitExpr_(const CallNode* op) {
93 Array<PrimExpr> args;
94 for (auto& arg : op->args) {
95 PrimExpr x = this->VisitExpr(arg);
96 if (x.dtype().is_bfloat16()) {
97 DataType fp32_dtype(kDLFloat, 32, x.dtype().lanes());
98 args.push_back(Cast(fp32_dtype, {x}, op->span));
99 } else {
100 args.push_back(x);
101 }
102 }
103 if (op->dtype.is_bfloat16()) {
104 DataType fp32_dtype(kDLFloat, 32, op->dtype.lanes());
105 PrimExpr result_fp32 = Call(fp32_dtype, op->op, args, op->span);
106 return Cast(op->dtype, {result_fp32}, op->span);
107 } else {
108 return Call(op->dtype, op->op, args, op->span);
109 }
110}
111
112/*
113 * Eliminate verbose casting between fp32 and bf16
114 * Checks if the AST has the pattern:
115 * castto32(castto16(some_fp32_op(...)))
116 * The verbose casting is generated by BF16Promote for multiple
117 * bf16 Ops in a row. e.g.:
118 * X[i] + Y[i] + T[i] =>
119 * bf16((float32(bf16((float32(X[i]) + float32(Y[i])))) + float32(T[i])))
120 * After this pass:
121 * bf16(float32(X[i]) + float32(Y[i]) + float32(T[i]))
122 */
123class BF16CastEliminationRewriter : public StmtExprMutator {
124 public:
125 BF16CastEliminationRewriter() {}
126
127 Stmt operator()(Stmt s) { return VisitStmt(s); }
128
129 PrimExpr VisitExpr_(const CastNode* op) final {
130 auto op_val = StmtExprMutator::VisitExpr(op->value);
131 if (op->dtype.is_float() && op->dtype.bits() == 32) {
132 // if is cast_to_fp32, check if op->value is cast_to_fp16
133 // and op->value->value is a float32
134 if (auto innercast = op_val.as<CastNode>()) {
135 if (innercast->dtype.is_bfloat16() && innercast->value->dtype.is_float() &&
136 innercast->value->dtype.bits() == 32) {
137 return innercast->value;
138 }
139 }
140 }
141 if (op->value.same_as(op_val)) return GetRef<PrimExpr>(op);
142 return Cast(op->dtype, op_val);
143 }
144};
145
146union FloatCaster {
147 uint32_t u32;
148 float f32;
149};
150
151uint16_t RoundToNearestEven(float src) {
152 if (std::isnan(src)) {
153 return UINT16_C(0x7FC0);
154 } else {
155 FloatCaster caster;
156 caster.f32 = src;
157 uint32_t rounding_bias = ((caster.u32 >> 16) & 1) + UINT32_C(0x7FFF);
158 return static_cast<uint16_t>((caster.u32 + rounding_bias) >> 16);
159 }
160}
161
/*
 * Lower the bf16 type to uint16
 * Lower casts between bf16 and fp32
 * Lower bf16 FloatImm to a uint16 IntImm
 */
167class BF16LowerRewriter : public StmtExprMutator {
168 public:
169 BF16LowerRewriter() {}
170
171 using StmtExprMutator::operator();
172
173 PrimExpr VisitExpr_(const CastNode* op) final {
174 PrimExpr op_val = StmtExprMutator::VisitExpr(op->value);
175 DataType uint32_dtype(kDLUInt, 32, op_val->dtype.lanes());
176 DataType float32_dtype(kDLFloat, 32, op_val->dtype.lanes());
177 if (op->value->dtype.is_bfloat16()) { // cast from bf16
178 PrimExpr uint32_v = Cast(uint32_dtype, op_val);
179 PrimExpr float32_v = Call(float32_dtype, builtin::reinterpret(), {uint32_v << 16});
180 bool is_to_float32 = op->dtype.is_float() && op->dtype.bits() == 32;
181 return is_to_float32 ? float32_v : Cast(op->dtype, float32_v);
182 } else if (op->dtype.is_bfloat16()) { // cast to bf16
183 bool is_from_float32 = op->value->dtype.is_float() && op->value->dtype.bits() == 32;
184 PrimExpr float32_v = is_from_float32 ? op_val : Cast(float32_dtype, op_val);
185 PrimExpr uint32_v = Call(uint32_dtype, builtin::reinterpret(), {float32_v});
186 DataType uint16_dtype(kDLUInt, 16, op_val->dtype.lanes());
187 /* the following TIR is equivalent to the C++ code below:
188 uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF);
189 return static_cast<uint16_t>((U32 + rounding_bias) >> 16);*/
190 PrimExpr rounding_bias = ((uint32_v >> 16) & 1) + make_const(uint16_dtype, 0x7FFF);
191 return Cast(uint16_dtype, {(uint32_v + rounding_bias) >> 16});
192 }
193 if (op->value.same_as(op_val)) return GetRef<PrimExpr>(op);
194 return Cast(op->dtype, op_val);
195 }
196
197 PrimExpr VisitExpr_(const VarNode* op) final {
198 Var var = GetRef<Var>(op);
199
200 auto itr = var_remap_.find(var);
201 if (itr != var_remap_.end()) {
202 return itr->second;
203 } else {
204 return std::move(var);
205 }
206 }
207
208 Stmt VisitStmt_(const AllocateNode* op) final {
209 if (op->dtype.is_bfloat16()) {
210 DataType dtype = DataType::UInt(16, op->dtype.lanes());
211 Var buffer_var = Var(op->buffer_var->name_hint, PointerType(PrimType(dtype)));
212 var_remap_[op->buffer_var] = buffer_var;
213 return VisitStmt(Allocate(buffer_var, dtype, op->extents, op->condition, op->body));
214 } else {
215 return StmtExprMutator::VisitStmt_(op);
216 }
217 }
218
219 Stmt VisitStmt_(const BufferStoreNode* op) final {
220 Stmt ret = StmtExprMutator::VisitStmt_(op);
221 op = ret.as<BufferStoreNode>();
222
223 Buffer new_buf = GetRemappedBuffer(op->buffer);
224 if (new_buf.same_as(op->buffer)) {
225 return ret;
226 } else {
227 return BufferStore(new_buf, op->value, op->indices);
228 }
229 }
230
231 Stmt VisitStmt_(const AttrStmtNode* op) final {
232 Stmt ret = StmtExprMutator::VisitStmt_(op);
233 op = ret.as<AttrStmtNode>();
234
235 if (auto* buffer = op->node.as<BufferNode>()) {
236 auto it = buffer_remap_.find(GetRef<Buffer>(buffer));
237 if (it != buffer_remap_.end()) {
238 return AttrStmt(it->second, op->attr_key, op->value, op->body);
239 }
240 } else if (auto* var = op->node.as<VarNode>()) {
241 auto it = var_remap_.find(GetRef<Var>(var));
242 if (it != var_remap_.end()) {
243 return AttrStmt(it->second, op->attr_key, op->value, op->body);
244 }
245 }
246 return ret;
247 }
248
249 Stmt VisitStmt_(const BufferRealizeNode* op) final {
250 Stmt ret = StmtExprMutator::VisitStmt_(op);
251 op = ret.as<BufferRealizeNode>();
252
253 Buffer new_buf = GetRemappedBuffer(op->buffer);
254 if (new_buf.same_as(op->buffer)) {
255 return ret;
256 } else {
257 return BufferRealize(new_buf, op->bounds, op->condition, op->body);
258 }
259 }
260
261 Stmt VisitStmt_(const StoreNode* op) final {
262 LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
263 }
264
265 PrimExpr VisitExpr_(const BufferLoadNode* op) final {
266 PrimExpr ret = StmtExprMutator::VisitExpr_(op);
267 op = ret.as<BufferLoadNode>();
268
269 Buffer new_buf = GetRemappedBuffer(op->buffer);
270 if (new_buf.same_as(op->buffer)) {
271 return ret;
272 } else {
273 return BufferLoad(new_buf, op->indices);
274 }
275 }
276
277 PrimExpr VisitExpr_(const LoadNode* op) final {
278 LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
279 }
280
281 PrimExpr VisitExpr_(const FloatImmNode* op) final {
282 if (op->dtype.is_bfloat16()) {
283 return IntImm(DataType::UInt(16, op->dtype.lanes()),
284 RoundToNearestEven(static_cast<float>(op->value)));
285 }
286 return StmtExprMutator::VisitExpr_(op);
287 }
288
289 void AlterBuffers(PrimFuncNode* op) {
290 Map<Var, Buffer> new_buffer_map;
291
292 for (auto& itr : op->buffer_map) {
293 auto param_var = itr.first;
294 auto oldbuf = itr.second;
295 if (oldbuf->dtype.is_bfloat16()) {
296 DataType dtype = DataType::UInt(16, oldbuf->dtype.lanes());
297 Var buffer_var = Var(oldbuf->data->name_hint, PointerType(PrimType(dtype)));
298 auto newbuf = Buffer(buffer_var, dtype, oldbuf->shape, oldbuf->strides, oldbuf->elem_offset,
299 oldbuf->name, oldbuf->data_alignment, oldbuf->offset_factor,
300 oldbuf->buffer_type);
301 buffer_remap_[oldbuf] = newbuf;
302 var_remap_[oldbuf->data] = buffer_var;
303 new_buffer_map.Set(param_var, newbuf);
304 } else {
305 new_buffer_map.Set(param_var, oldbuf);
306 }
307 }
308
309 if (buffer_remap_.size() != 0) {
310 op->buffer_map = new_buffer_map;
311 }
312 }
313
314 private:
315 Buffer GetRemappedBuffer(Buffer buf) {
316 auto buf_it = buffer_remap_.find(buf);
317 if (buf_it != buffer_remap_.end()) {
318 return buf_it->second;
319 }
320
321 Buffer new_buf = buf;
322
323 auto var_it = var_remap_.find(buf->data);
324 if (var_it != var_remap_.end()) {
325 DataType dtype =
326 buf->dtype.is_bfloat16() ? DataType::UInt(16, buf->dtype.lanes()) : buf->dtype;
327 new_buf = Buffer(var_it->second, dtype, buf->shape, buf->strides, buf->elem_offset, buf->name,
328 buf->data_alignment, buf->offset_factor, buf->buffer_type,
329 buf->axis_separators, buf->span);
330 }
331
332 buffer_remap_[buf] = new_buf;
333
334 return new_buf;
335 }
336
337 std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_remap_;
338 std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> var_remap_;
339};
340
341namespace transform {
342
343Pass BF16Promote() {
344 auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
345 auto* n = f.CopyOnWrite();
346 n->body = BF16PromoteRewriter()(std::move(n->body));
347 return f;
348 };
349 return CreatePrimFuncPass(pass_func, 0, "tir.BF16Promote", {});
350}
351
352TVM_REGISTER_GLOBAL("tir.transform.BF16Promote").set_body_typed(BF16Promote);
353
354Pass BF16CastElimination() {
355 auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
356 auto* n = f.CopyOnWrite();
357 n->body = BF16CastEliminationRewriter()(std::move(n->body));
358 return f;
359 };
360 return CreatePrimFuncPass(pass_func, 0, "tir.BF16CastElimination", {});
361}
362
363TVM_REGISTER_GLOBAL("tir.transform.BF16CastElimination").set_body_typed(BF16CastElimination);
364
365Pass BF16TypeLowering() {
366 auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
367 auto* n = f.CopyOnWrite();
368 BF16LowerRewriter lowerer;
369 lowerer.AlterBuffers(n);
370 n->body = lowerer(std::move(n->body));
371 return f;
372 };
373 return CreatePrimFuncPass(pass_func, 0, "tir.BF16TypeLowering", {});
374}
375
376TVM_REGISTER_GLOBAL("tir.transform.BF16TypeLowering").set_body_typed(BF16TypeLowering);
377
378Pass BF16Legalize() {
379 return Sequential({BF16Promote(), BF16CastElimination(), BF16TypeLowering()}, "tir.BF16Legalize");
380}
381
382TVM_REGISTER_GLOBAL("tir.transform.BF16Legalize").set_body_typed(BF16Legalize);
383
384} // namespace transform
385} // namespace tir
386} // namespace tvm
387