1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file intrin_rule_nvptx.cc
22 */
23#ifdef TVM_LLVM_VERSION
24
25#include <tvm/runtime/registry.h>
26#include <tvm/tir/builtin.h>
27#include <tvm/tir/expr.h>
28#include <tvm/tir/op.h>
29#include <tvm/tir/op_attr_types.h>
30
31#include <sstream>
32
33namespace tvm {
34namespace codegen {
35
36inline PrimExpr DispatchPureExternLibDevice(const PrimExpr& e) {
37 using namespace tir;
38 const CallNode* call = e.as<CallNode>();
39 ICHECK(call != nullptr);
40 ICHECK(call->dtype.bits() == 32 || call->dtype.bits() == 64)
41 << "Only support float32 or float64.";
42
43 const OpNode* op = call->op.as<OpNode>();
44 ICHECK(op != nullptr);
45 std::string name = op->name;
46 ICHECK_EQ(name.substr(0, 4), "tir.");
47
48 std::ostringstream intrinsic_name;
49 intrinsic_name << "__nv_" << name.substr(4);
50 if (call->dtype.bits() == 32) intrinsic_name << "f";
51
52 Array<PrimExpr> new_args = {StringImm(intrinsic_name.str())};
53 for (auto arg : call->args) {
54 new_args.push_back(arg);
55 }
56 return Call(call->dtype, builtin::call_pure_extern(), new_args);
57}
58
59namespace llvm {
60using tir::FLowerIntrinsic;
61
62TVM_REGISTER_OP("tir.floor")
63 .set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice);
64
65TVM_REGISTER_OP("tir.ceil")
66 .set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice);
67
68TVM_REGISTER_OP("tir.round")
69 .set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice);
70
71TVM_REGISTER_OP("tir.nearbyint")
72 .set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice);
73
74TVM_REGISTER_OP("tir.trunc")
75 .set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice);
76
77TVM_REGISTER_OP("tir.fabs")
78 .set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice);
79
80TVM_REGISTER_OP("tir.exp").set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic",
81 DispatchPureExternLibDevice);
82
83TVM_REGISTER_OP("tir.exp2")
84 .set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice);
85
86TVM_REGISTER_OP("tir.exp10")
87 .set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice);
88
89TVM_REGISTER_OP("tir.erf").set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic",
90 DispatchPureExternLibDevice);
91
92TVM_REGISTER_OP("tir.fma").set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic",
93 DispatchPureExternLibDevice);
94
95TVM_REGISTER_OP("tir.log").set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic",
96 DispatchPureExternLibDevice);
97
98TVM_REGISTER_OP("tir.log2")
99 .set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice);
100
101TVM_REGISTER_OP("tir.log10")
102 .set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice);
103
104TVM_REGISTER_OP("tir.sqrt")
105 .set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice);
106
107TVM_REGISTER_OP("tir.pow").set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic",
108 DispatchPureExternLibDevice);
109
110TVM_REGISTER_OP("tir.tanh")
111 .set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice);
112
113TVM_REGISTER_OP("tir.tan").set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic",
114 DispatchPureExternLibDevice);
115
116TVM_REGISTER_OP("tir.cos").set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic",
117 DispatchPureExternLibDevice);
118
119TVM_REGISTER_OP("tir.cosh")
120 .set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice);
121
122TVM_REGISTER_OP("tir.sin").set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic",
123 DispatchPureExternLibDevice);
124
125TVM_REGISTER_OP("tir.sinh")
126 .set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice);
127
128TVM_REGISTER_OP("tir.atan")
129 .set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic", DispatchPureExternLibDevice);
130
131} // namespace llvm
132} // namespace codegen
133} // namespace tvm
134
#endif  // TVM_LLVM_VERSION
136