/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file intrin_rule_cuda.cc
 * \brief CUDA intrinsic rules.
 */
#include <tvm/tir/builtin.h>
#include <tvm/tir/op_attr_types.h>

#include "../intrin_rule.h"

namespace tvm {
namespace codegen {
namespace intrin {
// Add the float suffix to intrinsic names; use CUDA fast-math variants where registered.
using tir::FLowerIntrinsic;

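// Map a TIR math intrinsic to the CUDA device function for the given type:
// the plain libm name for fp64 ("exp"), the f-suffixed form for fp32
// ("expf"), and the h-prefixed half-precision form for fp16 and bf16
// ("hexp"); for fp16, fabs and round map to the special spellings __habs
// and hrint. An empty string means there is no mapping for this type.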
struct CUDAMath {
  std::string operator()(DataType t, std::string name) const {
    if (t.is_float()) {
      switch (t.bits()) {
        case 64:
          return name;
        case 32:
          return name + 'f';
        case 16: {
          if (name == "fabs") {
            return "__habs";
          } else if (name == "round") {
            return "hrint";
          } else {
            return "h" + name;
          }
        }
        default:
          return "";
      }
    } else if (t.is_bfloat16()) {
      return 'h' + name;
    }
    return "";
  }
};

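// Same mapping as CUDAMath, except that fp32 lowers to the faster but less
// accurate `__*f` hardware intrinsics (e.g. __expf, __logf); all other types
// fall back to the plain CUDAMath mapping.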
struct CUDAFastMath : public CUDAMath {
  std::string operator()(DataType t, std::string name) const {
    if (t.is_float() && t.bits() == 32) {
      return "__" + name + 'f';
    } else {
      return CUDAMath::operator()(t, name);
    }
  }
};

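// tan gets its own functor: unlike the other transcendentals it must not use
// the fast-math variant for fp32 (see the comment on the 32-bit case below).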
struct CUDAFastMathTan : public CUDAMath {
  std::string operator()(DataType t, std::string name) const {
    if (t.is_float()) {
      switch (t.bits()) {
        case 64:
          return name;
        // `__tanf` produces values that deviate too much from the numpy tan
        // reference, so use the accurate `tanf` instead.
        case 32:
          return name + 'f';
        case 16:
          return 'h' + name;
        default:
          return "";
      }
    }
    return "";
  }
};

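// Population count: __popc for 32-bit and __popcll for 64-bit unsigned ints.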
struct CUDAPopcount {
  std::string operator()(DataType t, std::string name) const {
    if (t.is_uint()) {
      switch (t.bits()) {
        case 32:
          return "__popc";
        case 64:
          return "__popcll";
        default:
          return "";
      }
    }
    return "";
  }
};

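// Select the low-level tir.cuda.__shfl_*_sync op that corresponds to a
// generic tvm_warp_shuffle* builtin.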
struct CUDAWarpIntrinsic {
  const Op operator()(DataType t, const Op& orig_op) const {
    if (orig_op.same_as(builtin::tvm_warp_shuffle())) {
      return Op::Get("tir.cuda.__shfl_sync");
    } else if (orig_op.same_as(builtin::tvm_warp_shuffle_up())) {
      return Op::Get("tir.cuda.__shfl_up_sync");
    } else {
      ICHECK(orig_op.same_as(builtin::tvm_warp_shuffle_down()));
      return Op::Get("tir.cuda.__shfl_down_sync");
    }
  }
};

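// Lower tvm_warp_activemask() to a call of the tir.cuda.__activemask op.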
static PrimExpr DispatchCUDAWarpActiveMask(const PrimExpr& e) {
  const CallNode* call = e.as<CallNode>();
  ICHECK(call != nullptr);
  return Call(call->dtype, Op::Get("tir.cuda.__activemask"), call->args);
}

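// Lower a tvm_warp_shuffle* call to the CUDA op chosen by T. The TIR builtin
// carries a trailing warp_size argument that the CUDA intrinsics do not take,
// so only the first four arguments are forwarded.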
template <typename T>
static PrimExpr DispatchCUDAShuffle(const PrimExpr& e) {
  const CallNode* call = e.as<CallNode>();
  ICHECK(call != nullptr);
  ICHECK_EQ(call->args.size(), 5);  // mask, value, warp_id, width, warp_size
  Array<PrimExpr> cuda_args{{call->args[0], call->args[1], call->args[2], call->args[3]}};
  return Call(call->dtype, T()(call->dtype, Downcast<Op>(call->op)), cuda_args);
}

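// Register the "cuda.FLowerIntrinsic" lowering rules. Rounding and other
// exact ops dispatch through CUDAMath; most transcendentals use the
// fast-math variants via CUDAFastMath.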
136TVM_REGISTER_OP("tir.floor")
137 .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAMath>);
138
139TVM_REGISTER_OP("tir.ceil")
140 .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAMath>);
141
142TVM_REGISTER_OP("tir.trunc")
143 .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAMath>);
144
145TVM_REGISTER_OP("tir.fabs")
146 .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAMath>);
147
148TVM_REGISTER_OP("tir.round")
149 .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAMath>);
150
151TVM_REGISTER_OP("tir.nearbyint")
152 .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAMath>);
153
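// Transcendentals: exp/log/sin/cos and friends lower to the __*f fast-math
// intrinsics for fp32; exp2, erf, and the hyperbolic ops stay on CUDAMath.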
154TVM_REGISTER_OP("tir.exp").set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic",
155 DispatchPureExtern<CUDAFastMath>);
156
157TVM_REGISTER_OP("tir.exp2")
158 .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAMath>);
159
160TVM_REGISTER_OP("tir.exp10")
161 .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAFastMath>);
162
163TVM_REGISTER_OP("tir.erf").set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic",
164 DispatchPureExtern<CUDAMath>);
165
166TVM_REGISTER_OP("tir.log").set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic",
167 DispatchPureExtern<CUDAFastMath>);
168
169TVM_REGISTER_OP("tir.log2")
170 .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAFastMath>);
171
172TVM_REGISTER_OP("tir.log10")
173 .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAFastMath>);
174
175TVM_REGISTER_OP("tir.tan").set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic",
176 DispatchPureExtern<CUDAFastMathTan>);
177
178TVM_REGISTER_OP("tir.cos").set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic",
179 DispatchPureExtern<CUDAFastMath>);
180
181TVM_REGISTER_OP("tir.cosh")
182 .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAMath>);
183
184TVM_REGISTER_OP("tir.sin").set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic",
185 DispatchPureExtern<CUDAFastMath>);
186
187TVM_REGISTER_OP("tir.sinh")
188 .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAMath>);
189
190TVM_REGISTER_OP("tir.atan")
191 .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAMath>);
192
193TVM_REGISTER_OP("tir.tanh")
194 .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAMath>);
195
196TVM_REGISTER_OP("tir.sqrt")
197 .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAMath>);
198
199TVM_REGISTER_OP("tir.pow").set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic",
200 DispatchPureExtern<CUDAMath>);
201
202TVM_REGISTER_OP("tir.popcount")
203 .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAPopcount>);
204
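// Warp-level builtins lower through the dispatch helpers defined above.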
205TVM_REGISTER_OP("tir.tvm_warp_shuffle")
206 .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchCUDAShuffle<CUDAWarpIntrinsic>);
207
208TVM_REGISTER_OP("tir.tvm_warp_shuffle_up")
209 .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchCUDAShuffle<CUDAWarpIntrinsic>);
210
211TVM_REGISTER_OP("tir.tvm_warp_shuffle_down")
212 .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchCUDAShuffle<CUDAWarpIntrinsic>);
213
214TVM_REGISTER_OP("tir.tvm_warp_activemask")
215 .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchCUDAWarpActiveMask);
216
217TVM_REGISTER_OP("tir.fmod")
218 .set_attr<FLowerIntrinsic>("cuda.FLowerIntrinsic", DispatchPureExtern<CUDAMath>);
219
// Register low-level builtin ops.
// TODO(tvm-team): consider making CUDA its own subfolder and creating a file for
// low-level builtins.
TVM_REGISTER_OP("tir.cuda.__shfl_sync")
    .set_num_inputs(4)
    .add_argument("mask", "Expr", "The thread mask.")
    .add_argument("var", "Expr", "The variable to sync.")
    .add_argument("lane", "Expr", "The source thread id.")
    .add_argument("width", "Expr", "The warp thread width, must be a power of 2.")
    .set_attr<TGlobalSymbol>("TGlobalSymbol", "__shfl_sync")
    .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque))
    .set_attr<bool>("cuda.need_warp_shuffle", true);

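// Note: `cuda.need_warp_shuffle` flags ops that rely on warp-shuffle support;
// the CUDA codegen is expected to check this attribute (e.g. to emit
// compatibility definitions for targets without the *_sync intrinsics).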
232TVM_REGISTER_OP("tir.cuda.__shfl_up_sync")
233 .set_num_inputs(4)
234 .add_argument("mask", "Expr", "The thread mask.")
235 .add_argument("var", "Expr", "The variable to sync.")
236 .add_argument("delta", "Expr", "The source lane id offset to be added.")
237 .add_argument("width", "Expr", "The warp thread width, must be a power of 2.")
238 .set_attr<TGlobalSymbol>("TGlobalSymbol", "__shfl_up_sync")
239 .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque))
240 .set_attr<bool>("cuda.need_warp_shuffle", true);
241
242TVM_REGISTER_OP("tir.cuda.__shfl_down_sync")
243 .set_num_inputs(4)
244 .add_argument("mask", "Expr", "The thread mask.")
245 .add_argument("var", "Expr", "The variable to sync.")
246 .add_argument("delta", "Expr", "The source lane id offset to be subtracted.")
247 .add_argument("width", "Expr", "The warp thread width, must be a power of 2.")
248 .set_attr<TGlobalSymbol>("TGlobalSymbol", "__shfl_down_sync")
249 .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kOpaque))
250 .set_attr<bool>("cuda.need_warp_shuffle", true);
251
252TVM_REGISTER_OP("tir.cuda.__activemask")
253 .set_num_inputs(0)
254 .set_attr<TGlobalSymbol>("TGlobalSymbol", "__activemask")
255 .set_attr<TCallEffectKind>("TCallEffectKind", Integer(CallEffectKind::kPure))
256 .set_attr<bool>("cuda.need_warp_shuffle", true);
257
258} // namespace intrin
259} // namespace codegen
260} // namespace tvm
261