1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file bf16_legalize.cc |
22 | * \brief legalize bf16 type by adding cast_to_fp32 |
23 | */ |
24 | |
25 | #include <tvm/runtime/registry.h> |
26 | #include <tvm/tir/builtin.h> |
27 | #include <tvm/tir/op.h> |
28 | #include <tvm/tir/transform.h> |
29 | |
#include <cmath>
#include <cstring>
#include <tuple>
32 | |
33 | #include "../../arith/ir_mutator_with_analyzer.h" |
34 | #include "../../arith/ir_visitor_with_analyzer.h" |
35 | |
36 | namespace tvm { |
37 | namespace tir { |
38 | |
39 | using arith::Analyzer; |
40 | using arith::IRMutatorWithAnalyzer; |
41 | |
42 | class BF16PromoteRewriter : public StmtExprMutator { |
43 | public: |
44 | BF16PromoteRewriter() {} |
45 | |
46 | Stmt operator()(Stmt s) { return VisitStmt(s); } |
47 | |
48 | PrimExpr VisitExpr_(const AddNode* op) final; |
49 | PrimExpr VisitExpr_(const SubNode* op) final; |
50 | PrimExpr VisitExpr_(const MulNode* op) final; |
51 | PrimExpr VisitExpr_(const DivNode* op) final; |
52 | PrimExpr VisitExpr_(const MinNode* op) final; |
53 | PrimExpr VisitExpr_(const MaxNode* op) final; |
54 | PrimExpr VisitExpr_(const LTNode* op) final; |
55 | PrimExpr VisitExpr_(const LENode* op) final; |
56 | PrimExpr VisitExpr_(const GTNode* op) final; |
57 | PrimExpr VisitExpr_(const GENode* op) final; |
58 | PrimExpr VisitExpr_(const CallNode* op) final; |
59 | }; |
60 | |
// Defines BF16PromoteRewriter::VisitExpr_ for the binary node type OP.
//   OP       - TIR binary node type being rewritten (e.g. AddNode).
//   FUNC     - operator/function used to rebuild the expression on the fp32 operands.
//   NEEDCAST - whether the fp32 result is cast back to bf16 when BOTH operands were bf16
//              (true for arithmetic, false for comparisons, whose result is boolean).
// Both operands are visited first. Each bf16 operand is cast to fp32 with the same number
// of lanes, so vectorized bf16 expressions stay well-typed. When neither operand is bf16,
// the node itself is kept, but any rewrite made inside the operands is still propagated.
#define DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(OP, FUNC, NEEDCAST)                       \
  PrimExpr BF16PromoteRewriter::VisitExpr_(const OP* op) {                                \
    PrimExpr origin_a = this->VisitExpr(op->a);                                           \
    PrimExpr origin_b = this->VisitExpr(op->b);                                           \
    bool a_is_bfloat16 = origin_a->dtype.is_bfloat16();                                   \
    bool b_is_bfloat16 = origin_b->dtype.is_bfloat16();                                   \
    if (!a_is_bfloat16 && !b_is_bfloat16) {                                               \
      /* Nothing to promote at this node; keep rewrites made inside the operands. */      \
      if (origin_a.same_as(op->a) && origin_b.same_as(op->b)) {                           \
        return GetRef<PrimExpr>(op);                                                      \
      }                                                                                   \
      auto node = make_object<OP>(*op);                                                   \
      node->a = std::move(origin_a);                                                      \
      node->b = std::move(origin_b);                                                      \
      return PrimExpr(node);                                                              \
    }                                                                                     \
    /* Preserve each operand's lane count: Cast requires matching lanes. */               \
    PrimExpr float32_a =                                                                  \
        a_is_bfloat16 ? Cast(DataType(kDLFloat, 32, origin_a->dtype.lanes()), origin_a)   \
                      : origin_a;                                                         \
    PrimExpr float32_b =                                                                  \
        b_is_bfloat16 ? Cast(DataType(kDLFloat, 32, origin_b->dtype.lanes()), origin_b)   \
                      : origin_b;                                                         \
    PrimExpr result = FUNC(float32_a, float32_b);                                         \
    bool do_cast = a_is_bfloat16 && b_is_bfloat16 && (NEEDCAST);                          \
    if (!do_cast) {                                                                       \
      return result;                                                                      \
    }                                                                                     \
    return Cast(DataType(kDLBfloat, 16, result->dtype.lanes()), result);                  \
  }
80 | |
81 | DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(AddNode, operator+, true) |
82 | DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(SubNode, operator-, true) |
83 | DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MulNode, operator*, true) |
84 | DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(DivNode, div, true) |
85 | DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MinNode, min, true) |
86 | DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(MaxNode, max, true) |
87 | DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LTNode, operator<, false) |
88 | DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(LENode, operator<=, false) |
89 | DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GTNode, operator>, false) |
90 | DEFINE_BIOP_EXPR_MUTATE_WITH_TYPE_MATCH(GENode, operator>=, false) |
91 | |
92 | PrimExpr BF16PromoteRewriter::VisitExpr_(const CallNode* op) { |
93 | Array<PrimExpr> args; |
94 | for (auto& arg : op->args) { |
95 | PrimExpr x = this->VisitExpr(arg); |
96 | if (x.dtype().is_bfloat16()) { |
97 | DataType fp32_dtype(kDLFloat, 32, x.dtype().lanes()); |
98 | args.push_back(Cast(fp32_dtype, {x}, op->span)); |
99 | } else { |
100 | args.push_back(x); |
101 | } |
102 | } |
103 | if (op->dtype.is_bfloat16()) { |
104 | DataType fp32_dtype(kDLFloat, 32, op->dtype.lanes()); |
105 | PrimExpr result_fp32 = Call(fp32_dtype, op->op, args, op->span); |
106 | return Cast(op->dtype, {result_fp32}, op->span); |
107 | } else { |
108 | return Call(op->dtype, op->op, args, op->span); |
109 | } |
110 | } |
111 | |
112 | /* |
113 | * Eliminate verbose casting between fp32 and bf16 |
114 | * Checks if the AST has the pattern: |
115 | * castto32(castto16(some_fp32_op(...))) |
116 | * The verbose casting is generated by BF16Promote for multiple |
117 | * bf16 Ops in a row. e.g.: |
118 | * X[i] + Y[i] + T[i] => |
119 | * bf16((float32(bf16((float32(X[i]) + float32(Y[i])))) + float32(T[i]))) |
120 | * After this pass: |
121 | * bf16(float32(X[i]) + float32(Y[i]) + float32(T[i])) |
122 | */ |
123 | class BF16CastEliminationRewriter : public StmtExprMutator { |
124 | public: |
125 | BF16CastEliminationRewriter() {} |
126 | |
127 | Stmt operator()(Stmt s) { return VisitStmt(s); } |
128 | |
129 | PrimExpr VisitExpr_(const CastNode* op) final { |
130 | auto op_val = StmtExprMutator::VisitExpr(op->value); |
131 | if (op->dtype.is_float() && op->dtype.bits() == 32) { |
132 | // if is cast_to_fp32, check if op->value is cast_to_fp16 |
133 | // and op->value->value is a float32 |
134 | if (auto innercast = op_val.as<CastNode>()) { |
135 | if (innercast->dtype.is_bfloat16() && innercast->value->dtype.is_float() && |
136 | innercast->value->dtype.bits() == 32) { |
137 | return innercast->value; |
138 | } |
139 | } |
140 | } |
141 | if (op->value.same_as(op_val)) return GetRef<PrimExpr>(op); |
142 | return Cast(op->dtype, op_val); |
143 | } |
144 | }; |
145 | |
/*!
 * \brief Convert an fp32 value to the bit pattern of the nearest bf16 value.
 *
 * Rounds to nearest, ties to even. It adds a bias to the fp32 bit pattern and keeps the
 * upper 16 bits. Every NaN maps to the quiet NaN pattern 0x7FC0.
 * \param src The fp32 value to convert.
 * \return The bf16 bit pattern, as a uint16_t.
 */
uint16_t RoundToNearestEven(float src) {
  if (std::isnan(src)) {
    return UINT16_C(0x7FC0);
  }
  // Copy the object representation rather than reading an inactive union member, which is
  // undefined behavior in C++.
  static_assert(sizeof(uint32_t) == sizeof(float), "float is expected to be 32 bits wide");
  uint32_t bits = 0;
  std::memcpy(&bits, &src, sizeof(bits));
  // 0x7FFF plus the lowest kept bit: exact halfway cases round toward the even result.
  uint32_t rounding_bias = ((bits >> 16) & 1) + UINT32_C(0x7FFF);
  return static_cast<uint16_t>((bits + rounding_bias) >> 16);
}
161 | |
162 | /* |
163 | * Lower the bf16 type to int16 |
164 | * Lower cast between bf16 and fp32 |
165 | * Lower bf16 FloatImm to int16 |
166 | */ |
167 | class BF16LowerRewriter : public StmtExprMutator { |
168 | public: |
169 | BF16LowerRewriter() {} |
170 | |
171 | using StmtExprMutator::operator(); |
172 | |
173 | PrimExpr VisitExpr_(const CastNode* op) final { |
174 | PrimExpr op_val = StmtExprMutator::VisitExpr(op->value); |
175 | DataType uint32_dtype(kDLUInt, 32, op_val->dtype.lanes()); |
176 | DataType float32_dtype(kDLFloat, 32, op_val->dtype.lanes()); |
177 | if (op->value->dtype.is_bfloat16()) { // cast from bf16 |
178 | PrimExpr uint32_v = Cast(uint32_dtype, op_val); |
179 | PrimExpr float32_v = Call(float32_dtype, builtin::reinterpret(), {uint32_v << 16}); |
180 | bool is_to_float32 = op->dtype.is_float() && op->dtype.bits() == 32; |
181 | return is_to_float32 ? float32_v : Cast(op->dtype, float32_v); |
182 | } else if (op->dtype.is_bfloat16()) { // cast to bf16 |
183 | bool is_from_float32 = op->value->dtype.is_float() && op->value->dtype.bits() == 32; |
184 | PrimExpr float32_v = is_from_float32 ? op_val : Cast(float32_dtype, op_val); |
185 | PrimExpr uint32_v = Call(uint32_dtype, builtin::reinterpret(), {float32_v}); |
186 | DataType uint16_dtype(kDLUInt, 16, op_val->dtype.lanes()); |
187 | /* the following TIR is equivalent to the C++ code below: |
188 | uint32_t rounding_bias = ((U32 >> 16) & 1) + UINT32_C(0x7FFF); |
189 | return static_cast<uint16_t>((U32 + rounding_bias) >> 16);*/ |
190 | PrimExpr rounding_bias = ((uint32_v >> 16) & 1) + make_const(uint16_dtype, 0x7FFF); |
191 | return Cast(uint16_dtype, {(uint32_v + rounding_bias) >> 16}); |
192 | } |
193 | if (op->value.same_as(op_val)) return GetRef<PrimExpr>(op); |
194 | return Cast(op->dtype, op_val); |
195 | } |
196 | |
197 | PrimExpr VisitExpr_(const VarNode* op) final { |
198 | Var var = GetRef<Var>(op); |
199 | |
200 | auto itr = var_remap_.find(var); |
201 | if (itr != var_remap_.end()) { |
202 | return itr->second; |
203 | } else { |
204 | return std::move(var); |
205 | } |
206 | } |
207 | |
208 | Stmt VisitStmt_(const AllocateNode* op) final { |
209 | if (op->dtype.is_bfloat16()) { |
210 | DataType dtype = DataType::UInt(16, op->dtype.lanes()); |
211 | Var buffer_var = Var(op->buffer_var->name_hint, PointerType(PrimType(dtype))); |
212 | var_remap_[op->buffer_var] = buffer_var; |
213 | return VisitStmt(Allocate(buffer_var, dtype, op->extents, op->condition, op->body)); |
214 | } else { |
215 | return StmtExprMutator::VisitStmt_(op); |
216 | } |
217 | } |
218 | |
219 | Stmt VisitStmt_(const BufferStoreNode* op) final { |
220 | Stmt ret = StmtExprMutator::VisitStmt_(op); |
221 | op = ret.as<BufferStoreNode>(); |
222 | |
223 | Buffer new_buf = GetRemappedBuffer(op->buffer); |
224 | if (new_buf.same_as(op->buffer)) { |
225 | return ret; |
226 | } else { |
227 | return BufferStore(new_buf, op->value, op->indices); |
228 | } |
229 | } |
230 | |
231 | Stmt VisitStmt_(const AttrStmtNode* op) final { |
232 | Stmt ret = StmtExprMutator::VisitStmt_(op); |
233 | op = ret.as<AttrStmtNode>(); |
234 | |
235 | if (auto* buffer = op->node.as<BufferNode>()) { |
236 | auto it = buffer_remap_.find(GetRef<Buffer>(buffer)); |
237 | if (it != buffer_remap_.end()) { |
238 | return AttrStmt(it->second, op->attr_key, op->value, op->body); |
239 | } |
240 | } else if (auto* var = op->node.as<VarNode>()) { |
241 | auto it = var_remap_.find(GetRef<Var>(var)); |
242 | if (it != var_remap_.end()) { |
243 | return AttrStmt(it->second, op->attr_key, op->value, op->body); |
244 | } |
245 | } |
246 | return ret; |
247 | } |
248 | |
249 | Stmt VisitStmt_(const BufferRealizeNode* op) final { |
250 | Stmt ret = StmtExprMutator::VisitStmt_(op); |
251 | op = ret.as<BufferRealizeNode>(); |
252 | |
253 | Buffer new_buf = GetRemappedBuffer(op->buffer); |
254 | if (new_buf.same_as(op->buffer)) { |
255 | return ret; |
256 | } else { |
257 | return BufferRealize(new_buf, op->bounds, op->condition, op->body); |
258 | } |
259 | } |
260 | |
261 | Stmt VisitStmt_(const StoreNode* op) final { |
262 | LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead." ; |
263 | } |
264 | |
265 | PrimExpr VisitExpr_(const BufferLoadNode* op) final { |
266 | PrimExpr ret = StmtExprMutator::VisitExpr_(op); |
267 | op = ret.as<BufferLoadNode>(); |
268 | |
269 | Buffer new_buf = GetRemappedBuffer(op->buffer); |
270 | if (new_buf.same_as(op->buffer)) { |
271 | return ret; |
272 | } else { |
273 | return BufferLoad(new_buf, op->indices); |
274 | } |
275 | } |
276 | |
277 | PrimExpr VisitExpr_(const LoadNode* op) final { |
278 | LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead." ; |
279 | } |
280 | |
281 | PrimExpr VisitExpr_(const FloatImmNode* op) final { |
282 | if (op->dtype.is_bfloat16()) { |
283 | return IntImm(DataType::UInt(16, op->dtype.lanes()), |
284 | RoundToNearestEven(static_cast<float>(op->value))); |
285 | } |
286 | return StmtExprMutator::VisitExpr_(op); |
287 | } |
288 | |
289 | void AlterBuffers(PrimFuncNode* op) { |
290 | Map<Var, Buffer> new_buffer_map; |
291 | |
292 | for (auto& itr : op->buffer_map) { |
293 | auto param_var = itr.first; |
294 | auto oldbuf = itr.second; |
295 | if (oldbuf->dtype.is_bfloat16()) { |
296 | DataType dtype = DataType::UInt(16, oldbuf->dtype.lanes()); |
297 | Var buffer_var = Var(oldbuf->data->name_hint, PointerType(PrimType(dtype))); |
298 | auto newbuf = Buffer(buffer_var, dtype, oldbuf->shape, oldbuf->strides, oldbuf->elem_offset, |
299 | oldbuf->name, oldbuf->data_alignment, oldbuf->offset_factor, |
300 | oldbuf->buffer_type); |
301 | buffer_remap_[oldbuf] = newbuf; |
302 | var_remap_[oldbuf->data] = buffer_var; |
303 | new_buffer_map.Set(param_var, newbuf); |
304 | } else { |
305 | new_buffer_map.Set(param_var, oldbuf); |
306 | } |
307 | } |
308 | |
309 | if (buffer_remap_.size() != 0) { |
310 | op->buffer_map = new_buffer_map; |
311 | } |
312 | } |
313 | |
314 | private: |
315 | Buffer GetRemappedBuffer(Buffer buf) { |
316 | auto buf_it = buffer_remap_.find(buf); |
317 | if (buf_it != buffer_remap_.end()) { |
318 | return buf_it->second; |
319 | } |
320 | |
321 | Buffer new_buf = buf; |
322 | |
323 | auto var_it = var_remap_.find(buf->data); |
324 | if (var_it != var_remap_.end()) { |
325 | DataType dtype = |
326 | buf->dtype.is_bfloat16() ? DataType::UInt(16, buf->dtype.lanes()) : buf->dtype; |
327 | new_buf = Buffer(var_it->second, dtype, buf->shape, buf->strides, buf->elem_offset, buf->name, |
328 | buf->data_alignment, buf->offset_factor, buf->buffer_type, |
329 | buf->axis_separators, buf->span); |
330 | } |
331 | |
332 | buffer_remap_[buf] = new_buf; |
333 | |
334 | return new_buf; |
335 | } |
336 | |
337 | std::unordered_map<Buffer, Buffer, ObjectPtrHash, ObjectPtrEqual> buffer_remap_; |
338 | std::unordered_map<Var, Var, ObjectPtrHash, ObjectPtrEqual> var_remap_; |
339 | }; |
340 | |
341 | namespace transform { |
342 | |
343 | Pass BF16Promote() { |
344 | auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { |
345 | auto* n = f.CopyOnWrite(); |
346 | n->body = BF16PromoteRewriter()(std::move(n->body)); |
347 | return f; |
348 | }; |
349 | return CreatePrimFuncPass(pass_func, 0, "tir.BF16Promote" , {}); |
350 | } |
351 | |
352 | TVM_REGISTER_GLOBAL("tir.transform.BF16Promote" ).set_body_typed(BF16Promote); |
353 | |
354 | Pass BF16CastElimination() { |
355 | auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { |
356 | auto* n = f.CopyOnWrite(); |
357 | n->body = BF16CastEliminationRewriter()(std::move(n->body)); |
358 | return f; |
359 | }; |
360 | return CreatePrimFuncPass(pass_func, 0, "tir.BF16CastElimination" , {}); |
361 | } |
362 | |
363 | TVM_REGISTER_GLOBAL("tir.transform.BF16CastElimination" ).set_body_typed(BF16CastElimination); |
364 | |
365 | Pass BF16TypeLowering() { |
366 | auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { |
367 | auto* n = f.CopyOnWrite(); |
368 | BF16LowerRewriter lowerer; |
369 | lowerer.AlterBuffers(n); |
370 | n->body = lowerer(std::move(n->body)); |
371 | return f; |
372 | }; |
373 | return CreatePrimFuncPass(pass_func, 0, "tir.BF16TypeLowering" , {}); |
374 | } |
375 | |
376 | TVM_REGISTER_GLOBAL("tir.transform.BF16TypeLowering" ).set_body_typed(BF16TypeLowering); |
377 | |
378 | Pass BF16Legalize() { |
379 | return Sequential({BF16Promote(), BF16CastElimination(), BF16TypeLowering()}, "tir.BF16Legalize" ); |
380 | } |
381 | |
382 | TVM_REGISTER_GLOBAL("tir.transform.BF16Legalize" ).set_body_typed(BF16Legalize); |
383 | |
384 | } // namespace transform |
385 | } // namespace tir |
386 | } // namespace tvm |
387 | |