1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file intrin_rule_rocm.cc
22 */
23#ifdef TVM_LLVM_VERSION
24
25#include <tvm/runtime/registry.h>
26#include <tvm/tir/builtin.h>
27#include <tvm/tir/expr.h>
28#include <tvm/tir/op.h>
29#include <tvm/tir/op_attr_types.h>
30
31#include <sstream>
32
33namespace tvm {
34namespace codegen {
35
36inline PrimExpr DispatchPureExternOCML(const PrimExpr& e) {
37 using namespace tir;
38 const CallNode* call = e.as<CallNode>();
39 ICHECK(call != nullptr);
40
41 const OpNode* op = call->op.as<OpNode>();
42 ICHECK(op != nullptr);
43 std::string name = op->name;
44 ICHECK_EQ(name.substr(0, 4), "tir.");
45
46 std::ostringstream intrinsic_name;
47 intrinsic_name << "__ocml_" << name.substr(4) << "_f" << call->dtype.bits();
48
49 Array<PrimExpr> new_args = {StringImm(intrinsic_name.str())};
50 for (auto arg : call->args) {
51 new_args.push_back(arg);
52 }
53
54 return Call(call->dtype, builtin::call_pure_extern(), new_args);
55}
56
57inline PrimExpr DispatchShuffle(const PrimExpr& e) {
58 using namespace tir;
59 const CallNode* call = e.as<CallNode>();
60 ICHECK(call != nullptr);
61 ICHECK_EQ(call->args.size(), 5); // mask, value, warp_id, width, warp_size
62 PrimExpr var = call->args[1];
63 ICHECK_EQ(var.dtype().bits(), 32);
64
65 // get own lane in self (__lane_id)
66 PrimExpr minus_one = tir::make_const(DataType::Int(32), -1);
67 PrimExpr zero = tir::make_zero(DataType::Int(32));
68 PrimExpr lo = Call(DataType::Int(32), builtin::call_pure_extern(),
69 {StringImm("llvm.amdgcn.mbcnt.lo"), minus_one, zero});
70 PrimExpr self = Call(DataType::Int(32), builtin::call_pure_extern(),
71 {StringImm("llvm.amdgcn.mbcnt.hi"), minus_one, lo});
72
73 // compute lane to get from
74 PrimExpr width = call->args[3];
75 PrimExpr index;
76 if (call->op.same_as(builtin::tvm_warp_shuffle())) {
77 PrimExpr src_lane = call->args[2];
78 index = src_lane + (self & ~(width - 1));
79 } else if (call->op.same_as(builtin::tvm_warp_shuffle_up())) {
80 PrimExpr delta = call->args[2];
81 index = self - delta;
82 index = Select(index < (self & ~(width - 1)), self, index);
83 } else {
84 ICHECK(call->op.same_as(builtin::tvm_warp_shuffle_down()));
85 PrimExpr delta = call->args[2];
86 index = self + delta;
87 index = Select((self & (width - 1)) + delta >= width, self, index);
88 }
89 PrimExpr res = Call(var.dtype(), builtin::call_pure_extern(),
90 {StringImm("llvm.amdgcn.ds.bpermute"), index << 2, var});
91 return res;
92}
93
94namespace llvm {
95using tir::FLowerIntrinsic;
96
97// dummy because we don't have the activemask
98TVM_REGISTER_OP("tir.tvm_warp_activemask")
99 .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", [](const PrimExpr& e) -> PrimExpr {
100 PrimExpr zero = tir::make_zero(DataType::Int(32));
101 return zero;
102 });
103
104TVM_REGISTER_OP("tir.tvm_warp_shuffle")
105 .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchShuffle);
106
107TVM_REGISTER_OP("tir.tvm_warp_shuffle_up")
108 .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchShuffle);
109
110TVM_REGISTER_OP("tir.tvm_warp_shuffle_down")
111 .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchShuffle);
112
113TVM_REGISTER_OP("tir.floor")
114 .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
115
116TVM_REGISTER_OP("tir.ceil")
117 .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
118
119TVM_REGISTER_OP("tir.round")
120 .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
121
122TVM_REGISTER_OP("tir.nearbyint")
123 .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
124
125TVM_REGISTER_OP("tir.trunc")
126 .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
127
128TVM_REGISTER_OP("tir.fabs")
129 .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
130
131TVM_REGISTER_OP("tir.exp").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
132 DispatchPureExternOCML);
133
134TVM_REGISTER_OP("tir.exp2")
135 .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
136
137TVM_REGISTER_OP("tir.exp10")
138 .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
139
140TVM_REGISTER_OP("tir.erf").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
141 DispatchPureExternOCML);
142
143TVM_REGISTER_OP("tir.fma").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
144 DispatchPureExternOCML);
145
146TVM_REGISTER_OP("tir.log").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
147 DispatchPureExternOCML);
148
149TVM_REGISTER_OP("tir.log2")
150 .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
151
152TVM_REGISTER_OP("tir.log10")
153 .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
154
155TVM_REGISTER_OP("tir.sqrt")
156 .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
157
158TVM_REGISTER_OP("tir.pow").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
159 DispatchPureExternOCML);
160
161TVM_REGISTER_OP("tir.tanh")
162 .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
163
164TVM_REGISTER_OP("tir.tan").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
165 DispatchPureExternOCML);
166
167TVM_REGISTER_OP("tir.cos").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
168 DispatchPureExternOCML);
169
170TVM_REGISTER_OP("tir.cosh")
171 .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
172
173TVM_REGISTER_OP("tir.sin").set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic",
174 DispatchPureExternOCML);
175
176TVM_REGISTER_OP("tir.sinh")
177 .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
178
179TVM_REGISTER_OP("tir.atan")
180 .set_attr<FLowerIntrinsic>("rocm.FLowerIntrinsic", DispatchPureExternOCML);
181
182} // namespace llvm
183} // namespace codegen
184} // namespace tvm
185
#endif  // TVM_LLVM_VERSION
187