1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
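/*!
 * \file intrin_rule_nvptx.cc
 * \brief NVPTX intrinsic lowering rules that rewrite TIR math ops into
 *        calls to the corresponding libdevice (__nv_*) functions.
 */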
23 | #ifdef TVM_LLVM_VERSION |
24 | |
25 | #include <tvm/runtime/registry.h> |
26 | #include <tvm/tir/builtin.h> |
27 | #include <tvm/tir/expr.h> |
28 | #include <tvm/tir/op.h> |
29 | #include <tvm/tir/op_attr_types.h> |
30 | |
31 | #include <sstream> |
32 | |
33 | namespace tvm { |
34 | namespace codegen { |
35 | |
36 | inline PrimExpr DispatchPureExternLibDevice(const PrimExpr& e) { |
37 | using namespace tir; |
38 | const CallNode* call = e.as<CallNode>(); |
39 | ICHECK(call != nullptr); |
40 | ICHECK(call->dtype.bits() == 32 || call->dtype.bits() == 64) |
41 | << "Only support float32 or float64." ; |
42 | |
43 | const OpNode* op = call->op.as<OpNode>(); |
44 | ICHECK(op != nullptr); |
45 | std::string name = op->name; |
46 | ICHECK_EQ(name.substr(0, 4), "tir." ); |
47 | |
48 | std::ostringstream intrinsic_name; |
49 | intrinsic_name << "__nv_" << name.substr(4); |
50 | if (call->dtype.bits() == 32) intrinsic_name << "f" ; |
51 | |
52 | Array<PrimExpr> new_args = {StringImm(intrinsic_name.str())}; |
53 | for (auto arg : call->args) { |
54 | new_args.push_back(arg); |
55 | } |
56 | return Call(call->dtype, builtin::call_pure_extern(), new_args); |
57 | } |
58 | |
59 | namespace llvm { |
60 | using tir::FLowerIntrinsic; |
61 | |
62 | TVM_REGISTER_OP("tir.floor" ) |
63 | .set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic" , DispatchPureExternLibDevice); |
64 | |
65 | TVM_REGISTER_OP("tir.ceil" ) |
66 | .set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic" , DispatchPureExternLibDevice); |
67 | |
68 | TVM_REGISTER_OP("tir.round" ) |
69 | .set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic" , DispatchPureExternLibDevice); |
70 | |
71 | TVM_REGISTER_OP("tir.nearbyint" ) |
72 | .set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic" , DispatchPureExternLibDevice); |
73 | |
74 | TVM_REGISTER_OP("tir.trunc" ) |
75 | .set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic" , DispatchPureExternLibDevice); |
76 | |
77 | TVM_REGISTER_OP("tir.fabs" ) |
78 | .set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic" , DispatchPureExternLibDevice); |
79 | |
80 | TVM_REGISTER_OP("tir.exp" ).set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic" , |
81 | DispatchPureExternLibDevice); |
82 | |
83 | TVM_REGISTER_OP("tir.exp2" ) |
84 | .set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic" , DispatchPureExternLibDevice); |
85 | |
86 | TVM_REGISTER_OP("tir.exp10" ) |
87 | .set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic" , DispatchPureExternLibDevice); |
88 | |
89 | TVM_REGISTER_OP("tir.erf" ).set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic" , |
90 | DispatchPureExternLibDevice); |
91 | |
92 | TVM_REGISTER_OP("tir.fma" ).set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic" , |
93 | DispatchPureExternLibDevice); |
94 | |
95 | TVM_REGISTER_OP("tir.log" ).set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic" , |
96 | DispatchPureExternLibDevice); |
97 | |
98 | TVM_REGISTER_OP("tir.log2" ) |
99 | .set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic" , DispatchPureExternLibDevice); |
100 | |
101 | TVM_REGISTER_OP("tir.log10" ) |
102 | .set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic" , DispatchPureExternLibDevice); |
103 | |
104 | TVM_REGISTER_OP("tir.sqrt" ) |
105 | .set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic" , DispatchPureExternLibDevice); |
106 | |
107 | TVM_REGISTER_OP("tir.pow" ).set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic" , |
108 | DispatchPureExternLibDevice); |
109 | |
110 | TVM_REGISTER_OP("tir.tanh" ) |
111 | .set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic" , DispatchPureExternLibDevice); |
112 | |
113 | TVM_REGISTER_OP("tir.tan" ).set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic" , |
114 | DispatchPureExternLibDevice); |
115 | |
116 | TVM_REGISTER_OP("tir.cos" ).set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic" , |
117 | DispatchPureExternLibDevice); |
118 | |
119 | TVM_REGISTER_OP("tir.cosh" ) |
120 | .set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic" , DispatchPureExternLibDevice); |
121 | |
122 | TVM_REGISTER_OP("tir.sin" ).set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic" , |
123 | DispatchPureExternLibDevice); |
124 | |
125 | TVM_REGISTER_OP("tir.sinh" ) |
126 | .set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic" , DispatchPureExternLibDevice); |
127 | |
128 | TVM_REGISTER_OP("tir.atan" ) |
129 | .set_attr<FLowerIntrinsic>("nvptx.FLowerIntrinsic" , DispatchPureExternLibDevice); |
130 | |
131 | } // namespace llvm |
132 | } // namespace codegen |
133 | } // namespace tvm |
134 | |
135 | #endif // LLVM_VERSION |
136 | |