1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file intrin_rule_rocm.cc |
22 | */ |
23 | #ifdef TVM_LLVM_VERSION |
24 | |
25 | #include <tvm/runtime/registry.h> |
26 | #include <tvm/tir/builtin.h> |
27 | #include <tvm/tir/expr.h> |
28 | #include <tvm/tir/op.h> |
29 | #include <tvm/tir/op_attr_types.h> |
30 | |
31 | #include <sstream> |
32 | |
33 | namespace tvm { |
34 | namespace codegen { |
35 | |
36 | inline PrimExpr DispatchPureExternOCML(const PrimExpr& e) { |
37 | using namespace tir; |
38 | const CallNode* call = e.as<CallNode>(); |
39 | ICHECK(call != nullptr); |
40 | |
41 | const OpNode* op = call->op.as<OpNode>(); |
42 | ICHECK(op != nullptr); |
43 | std::string name = op->name; |
44 | ICHECK_EQ(name.substr(0, 4), "tir." ); |
45 | |
46 | std::ostringstream intrinsic_name; |
47 | intrinsic_name << "__ocml_" << name.substr(4) << "_f" << call->dtype.bits(); |
48 | |
49 | Array<PrimExpr> new_args = {StringImm(intrinsic_name.str())}; |
50 | for (auto arg : call->args) { |
51 | new_args.push_back(arg); |
52 | } |
53 | |
54 | return Call(call->dtype, builtin::call_pure_extern(), new_args); |
55 | } |
56 | |
57 | inline PrimExpr DispatchShuffle(const PrimExpr& e) { |
58 | using namespace tir; |
59 | const CallNode* call = e.as<CallNode>(); |
60 | ICHECK(call != nullptr); |
61 | ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size |
62 | PrimExpr var = call->args[1]; |
63 | ICHECK_EQ(var.dtype().bits(), 32); |
64 | |
65 | // get own lane in self (__lane_id) |
66 | PrimExpr minus_one = tir::make_const(DataType::Int(32), -1); |
67 | PrimExpr zero = tir::make_zero(DataType::Int(32)); |
68 | PrimExpr lo = Call(DataType::Int(32), builtin::call_pure_extern(), |
69 | {StringImm("llvm.amdgcn.mbcnt.lo" ), minus_one, zero}); |
70 | PrimExpr self = Call(DataType::Int(32), builtin::call_pure_extern(), |
71 | {StringImm("llvm.amdgcn.mbcnt.hi" ), minus_one, lo}); |
72 | |
73 | // compute lane to get from |
74 | PrimExpr width = call->args[3]; |
75 | PrimExpr index; |
76 | if (call->op.same_as(builtin::tvm_warp_shuffle())) { |
77 | PrimExpr src_lane = call->args[2]; |
78 | index = src_lane + (self & ~(width - 1)); |
79 | } else if (call->op.same_as(builtin::tvm_warp_shuffle_up())) { |
80 | PrimExpr delta = call->args[2]; |
81 | index = self - delta; |
82 | index = Select(index < (self & ~(width - 1)), self, index); |
83 | } else { |
84 | ICHECK(call->op.same_as(builtin::tvm_warp_shuffle_down())); |
85 | PrimExpr delta = call->args[2]; |
86 | index = self + delta; |
87 | index = Select((self & (width - 1)) + delta >= width, self, index); |
88 | } |
89 | PrimExpr res = Call(var.dtype(), builtin::call_pure_extern(), |
90 | {StringImm("llvm.amdgcn.ds.bpermute" ), index << 2, var}); |
91 | return res; |
92 | } |
93 | |
94 | namespace llvm { |
95 | using tir::FLowerIntrinsic; |
96 | |
97 | // dummy because we don't have the activemask |
98 | TVM_REGISTER_OP("tir.tvm_warp_activemask" ) |
99 | .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic" , [](const PrimExpr& e) -> PrimExpr { |
100 | PrimExpr zero = tir::make_zero(DataType::Int(32)); |
101 | return zero; |
102 | }); |
103 | |
104 | TVM_REGISTER_OP("tir.tvm_warp_shuffle" ) |
105 | .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic" , DispatchShuffle); |
106 | |
107 | TVM_REGISTER_OP("tir.tvm_warp_shuffle_up" ) |
108 | .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic" , DispatchShuffle); |
109 | |
110 | TVM_REGISTER_OP("tir.tvm_warp_shuffle_down" ) |
111 | .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic" , DispatchShuffle); |
112 | |
113 | TVM_REGISTER_OP("tir.floor" ) |
114 | .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic" , DispatchPureExternOCML); |
115 | |
116 | TVM_REGISTER_OP("tir.ceil" ) |
117 | .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic" , DispatchPureExternOCML); |
118 | |
119 | TVM_REGISTER_OP("tir.round" ) |
120 | .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic" , DispatchPureExternOCML); |
121 | |
122 | TVM_REGISTER_OP("tir.nearbyint" ) |
123 | .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic" , DispatchPureExternOCML); |
124 | |
125 | TVM_REGISTER_OP("tir.trunc" ) |
126 | .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic" , DispatchPureExternOCML); |
127 | |
128 | TVM_REGISTER_OP("tir.fabs" ) |
129 | .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic" , DispatchPureExternOCML); |
130 | |
131 | TVM_REGISTER_OP("tir.exp" ).set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic" , |
132 | DispatchPureExternOCML); |
133 | |
134 | TVM_REGISTER_OP("tir.exp2" ) |
135 | .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic" , DispatchPureExternOCML); |
136 | |
137 | TVM_REGISTER_OP("tir.exp10" ) |
138 | .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic" , DispatchPureExternOCML); |
139 | |
140 | TVM_REGISTER_OP("tir.erf" ).set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic" , |
141 | DispatchPureExternOCML); |
142 | |
143 | TVM_REGISTER_OP("tir.fma" ).set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic" , |
144 | DispatchPureExternOCML); |
145 | |
146 | TVM_REGISTER_OP("tir.log" ).set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic" , |
147 | DispatchPureExternOCML); |
148 | |
149 | TVM_REGISTER_OP("tir.log2" ) |
150 | .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic" , DispatchPureExternOCML); |
151 | |
152 | TVM_REGISTER_OP("tir.log10" ) |
153 | .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic" , DispatchPureExternOCML); |
154 | |
155 | TVM_REGISTER_OP("tir.sqrt" ) |
156 | .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic" , DispatchPureExternOCML); |
157 | |
158 | TVM_REGISTER_OP("tir.pow" ).set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic" , |
159 | DispatchPureExternOCML); |
160 | |
161 | TVM_REGISTER_OP("tir.tanh" ) |
162 | .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic" , DispatchPureExternOCML); |
163 | |
164 | TVM_REGISTER_OP("tir.tan" ).set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic" , |
165 | DispatchPureExternOCML); |
166 | |
167 | TVM_REGISTER_OP("tir.cos" ).set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic" , |
168 | DispatchPureExternOCML); |
169 | |
170 | TVM_REGISTER_OP("tir.cosh" ) |
171 | .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic" , DispatchPureExternOCML); |
172 | |
173 | TVM_REGISTER_OP("tir.sin" ).set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic" , |
174 | DispatchPureExternOCML); |
175 | |
176 | TVM_REGISTER_OP("tir.sinh" ) |
177 | .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic" , DispatchPureExternOCML); |
178 | |
179 | TVM_REGISTER_OP("tir.atan" ) |
180 | .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic" , DispatchPureExternOCML); |
181 | |
182 | } // namespace llvm |
183 | } // namespace codegen |
184 | } // namespace tvm |
185 | |
#endif  // TVM_LLVM_VERSION
187 | |