1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file intrin_rule_cuda.cc |
22 | * \brief CUDA intrinsic rules. |
23 | */ |
24 | #include <tvm/tir/builtin.h> |
25 | #include <tvm/tir/op_attr_types.h> |
26 | |
27 | #include "../intrin_rule.h" |
28 | |
29 | namespace tvm { |
30 | namespace codegen { |
31 | namespace intrin { |
32 | // Add float suffix to the intrinsics, CUDA fast math. |
33 | using tir::FLowerIntrinsic; |
34 | |
35 | struct CUDAMath { |
36 | std::string operator()(DataType t, std::string name) const { |
37 | if (t.is_float()) { |
38 | switch (t.bits()) { |
39 | case 64: |
40 | return name; |
41 | case 32: |
42 | return name + 'f'; |
43 | case 16: { |
44 | if (name == "fabs" ) { |
45 | return "__habs" ; |
46 | } else if (name == "round" ) { |
47 | return "hrint" ; |
48 | } else { |
49 | return "h" + name; |
50 | } |
51 | } |
52 | default: |
53 | return "" ; |
54 | } |
55 | } else if (t.is_bfloat16()) { |
56 | return 'h' + name; |
57 | } |
58 | return "" ; |
59 | } |
60 | }; |
61 | |
62 | struct CUDAFastMath : public CUDAMath { |
63 | std::string operator()(DataType t, std::string name) const { |
64 | if (t.is_float() && t.bits() == 32) { |
65 | return "__" + name + 'f'; |
66 | } else { |
67 | return CUDAMath::operator()(t, name); |
68 | } |
69 | return "" ; |
70 | } |
71 | }; |
72 | |
73 | struct CUDAFastMathTan : public CUDAMath { |
74 | std::string operator()(DataType t, std::string name) const { |
75 | if (t.is_float()) { |
76 | switch (t.bits()) { |
77 | case 64: |
78 | return name; |
79 | // `__tanf` seems to produce some values too deviant from numpy tan version. |
80 | // So, let's use just `tanf` instead. |
81 | case 32: |
82 | return name + 'f'; |
83 | case 16: |
84 | return 'h' + name; |
85 | default: |
86 | return "" ; |
87 | } |
88 | } |
89 | return "" ; |
90 | } |
91 | }; |
92 | |
93 | struct CUDAPopcount { |
94 | std::string operator()(DataType t, std::string name) const { |
95 | if (t.is_uint()) { |
96 | switch (t.bits()) { |
97 | case 32: |
98 | return "__popc" ; |
99 | case 64: |
100 | return "__popcll" ; |
101 | default: |
102 | return "" ; |
103 | } |
104 | } |
105 | return "" ; |
106 | } |
107 | }; |
108 | |
109 | struct CUDAWarpIntrinsic { |
110 | const Op operator()(DataType t, const Op& orig_op) const { |
111 | if (orig_op.same_as(builtin::tvm_warp_shuffle())) { |
112 | return Op::Get("tir.cuda.__shfl_sync" ); |
113 | } else if (orig_op.same_as(builtin::tvm_warp_shuffle_up())) { |
114 | return Op::Get("tir.cuda.__shfl_up_sync" ); |
115 | } else { |
116 | ICHECK(orig_op.same_as(builtin::tvm_warp_shuffle_down())); |
117 | return Op::Get("tir.cuda.__shfl_down_sync" ); |
118 | } |
119 | } |
120 | }; |
121 | |
122 | static PrimExpr DispatchCUDAWarpActiveMask(const PrimExpr& e) { |
123 | const CallNode* call = e.as<CallNode>(); |
124 | return Call(call->dtype, Op::Get("tir.cuda.__activemask" ), call->args); |
125 | } |
126 | |
127 | template <typename T> |
128 | static PrimExpr DispatchCUDAShuffle(const PrimExpr& e) { |
129 | const CallNode* call = e.as<CallNode>(); |
130 | ICHECK(call != nullptr); |
131 | ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size |
132 | Array<PrimExpr> cuda_args{{call->args[0], call->args[1], call->args[2], call->args[3]}}; |
133 | return Call(call->dtype, T()(call->dtype, Downcast<Op>(call->op)), cuda_args); |
134 | } |
135 | |
136 | TVM_REGISTER_OP("tir.floor" ) |
137 | .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic" , DispatchPureExtern<CUDAMath>); |
138 | |
139 | TVM_REGISTER_OP("tir.ceil" ) |
140 | .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic" , DispatchPureExtern<CUDAMath>); |
141 | |
142 | TVM_REGISTER_OP("tir.trunc" ) |
143 | .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic" , DispatchPureExtern<CUDAMath>); |
144 | |
145 | TVM_REGISTER_OP("tir.fabs" ) |
146 | .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic" , DispatchPureExtern<CUDAMath>); |
147 | |
148 | TVM_REGISTER_OP("tir.round" ) |
149 | .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic" , DispatchPureExtern<CUDAMath>); |
150 | |
151 | TVM_REGISTER_OP("tir.nearbyint" ) |
152 | .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic" , DispatchPureExtern<CUDAMath>); |
153 | |
154 | TVM_REGISTER_OP("tir.exp" ).set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic" , |
155 | DispatchPureExtern<CUDAFastMath>); |
156 | |
157 | TVM_REGISTER_OP("tir.exp2" ) |
158 | .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic" , DispatchPureExtern<CUDAMath>); |
159 | |
160 | TVM_REGISTER_OP("tir.exp10" ) |
161 | .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic" , DispatchPureExtern<CUDAFastMath>); |
162 | |
163 | TVM_REGISTER_OP("tir.erf" ).set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic" , |
164 | DispatchPureExtern<CUDAMath>); |
165 | |
166 | TVM_REGISTER_OP("tir.log" ).set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic" , |
167 | DispatchPureExtern<CUDAFastMath>); |
168 | |
169 | TVM_REGISTER_OP("tir.log2" ) |
170 | .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic" , DispatchPureExtern<CUDAFastMath>); |
171 | |
172 | TVM_REGISTER_OP("tir.log10" ) |
173 | .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic" , DispatchPureExtern<CUDAFastMath>); |
174 | |
175 | TVM_REGISTER_OP("tir.tan" ).set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic" , |
176 | DispatchPureExtern<CUDAFastMathTan>); |
177 | |
178 | TVM_REGISTER_OP("tir.cos" ).set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic" , |
179 | DispatchPureExtern<CUDAFastMath>); |
180 | |
181 | TVM_REGISTER_OP("tir.cosh" ) |
182 | .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic" , DispatchPureExtern<CUDAMath>); |
183 | |
184 | TVM_REGISTER_OP("tir.sin" ).set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic" , |
185 | DispatchPureExtern<CUDAFastMath>); |
186 | |
187 | TVM_REGISTER_OP("tir.sinh" ) |
188 | .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic" , DispatchPureExtern<CUDAMath>); |
189 | |
190 | TVM_REGISTER_OP("tir.atan" ) |
191 | .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic" , DispatchPureExtern<CUDAMath>); |
192 | |
193 | TVM_REGISTER_OP("tir.tanh" ) |
194 | .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic" , DispatchPureExtern<CUDAMath>); |
195 | |
196 | TVM_REGISTER_OP("tir.sqrt" ) |
197 | .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic" , DispatchPureExtern<CUDAMath>); |
198 | |
199 | TVM_REGISTER_OP("tir.pow" ).set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic" , |
200 | DispatchPureExtern<CUDAMath>); |
201 | |
202 | TVM_REGISTER_OP("tir.popcount" ) |
203 | .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic" , DispatchPureExtern<CUDAPopcount>); |
204 | |
205 | TVM_REGISTER_OP("tir.tvm_warp_shuffle" ) |
206 | .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic" , DispatchCUDAShuffle<CUDAWarpIntrinsic>); |
207 | |
208 | TVM_REGISTER_OP("tir.tvm_warp_shuffle_up" ) |
209 | .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic" , DispatchCUDAShuffle<CUDAWarpIntrinsic>); |
210 | |
211 | TVM_REGISTER_OP("tir.tvm_warp_shuffle_down" ) |
212 | .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic" , DispatchCUDAShuffle<CUDAWarpIntrinsic>); |
213 | |
214 | TVM_REGISTER_OP("tir.tvm_warp_activemask" ) |
215 | .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic" , DispatchCUDAWarpActiveMask); |
216 | |
217 | TVM_REGISTER_OP("tir.fmod" ) |
218 | .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic" , DispatchPureExtern<CUDAMath>); |
219 | |
// Register low-level builtin ops.
// TODO(tvm-team): consider make CUDA its own subfolder and create a file for low-level builtins.
// Each op below mirrors a CUDA warp intrinsic 1:1: TGlobalSymbol is the CUDA
// function name emitted by codegen, kOpaque prevents reordering/elimination of
// the shuffle calls, and "cuda.need_warp_shuffle" signals codegen that warp
// shuffle support is required for the target.
TVM_REGISTER_OP("tir.cuda.__shfl_sync")
    .set_num_inputs(4)
    .add_argument("mask", "Expr", "The thread mask.")
    .add_argument("var", "Expr", "The variable to sync.")
    .add_argument("lane", "Expr", "The source thread id.")
    .add_argument("width", "Expr", "The warp thread width, must be a power of 2.")
    .set_attr<TGlobalSymbol>("TGlobalSymbol", "__shfl_sync")
    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque))
    .set_attr<bool>("cuda.need_warp_shuffle", true);

// Shuffle from the lane `delta` below the caller (lane id minus delta).
TVM_REGISTER_OP("tir.cuda.__shfl_up_sync")
    .set_num_inputs(4)
    .add_argument("mask", "Expr", "The thread mask.")
    .add_argument("var", "Expr", "The variable to sync.")
    .add_argument("delta", "Expr", "The source lane id offset to be added.")
    .add_argument("width", "Expr", "The warp thread width, must be a power of 2.")
    .set_attr<TGlobalSymbol>("TGlobalSymbol", "__shfl_up_sync")
    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque))
    .set_attr<bool>("cuda.need_warp_shuffle", true);

// Shuffle from the lane `delta` above the caller (lane id plus delta).
TVM_REGISTER_OP("tir.cuda.__shfl_down_sync")
    .set_num_inputs(4)
    .add_argument("mask", "Expr", "The thread mask.")
    .add_argument("var", "Expr", "The variable to sync.")
    .add_argument("delta", "Expr", "The source lane id offset to be subtracted.")
    .add_argument("width", "Expr", "The warp thread width, must be a power of 2.")
    .set_attr<TGlobalSymbol>("TGlobalSymbol", "__shfl_down_sync")
    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque))
    .set_attr<bool>("cuda.need_warp_shuffle", true);

// Zero-input query of the active-lane mask; marked kPure (no side effects).
TVM_REGISTER_OP("tir.cuda.__activemask")
    .set_num_inputs(0)
    .set_attr<TGlobalSymbol>("TGlobalSymbol", "__activemask")
    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
    .set_attr<bool>("cuda.need_warp_shuffle", true);
257 | |
258 | } // namespace intrin |
259 | } // namespace codegen |
260 | } // namespace tvm |
261 | |