1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file intrin_rule_metal.cc
22 * \brief Metal intrinsic rules.
23 */
24#include <tvm/tir/op_attr_types.h>
25#include <tvm/topi/elemwise.h>
26
27#include "../intrin_rule.h"
28
29namespace tvm {
30namespace codegen {
31namespace intrin {
32using tir::FLowerIntrinsic;
33
34TVM_REGISTER_OP("tir.floor")
35 .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);
36
37TVM_REGISTER_OP("tir.ceil")
38 .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);
39
40TVM_REGISTER_OP("tir.trunc")
41 .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);
42
43TVM_REGISTER_OP("tir.fabs")
44 .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);
45
46TVM_REGISTER_OP("tir.round")
47 .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);
48
49TVM_REGISTER_OP("tir.nearbyint")
50 .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);
51
52TVM_REGISTER_OP("tir.exp").set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic",
53 DispatchPureExtern<Direct>);
54
55TVM_REGISTER_OP("tir.exp2")
56 .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);
57
58TVM_REGISTER_OP("tir.exp10")
59 .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);
60
61TVM_REGISTER_OP("tir.log").set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic",
62 DispatchPureExtern<Direct>);
63
64TVM_REGISTER_OP("tir.log2")
65 .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);
66
67TVM_REGISTER_OP("tir.log10")
68 .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);
69
70TVM_REGISTER_OP("tir.tanh")
71 .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);
72
73TVM_REGISTER_OP("tir.sqrt")
74 .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);
75
76TVM_REGISTER_OP("tir.pow").set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic",
77 DispatchPureExtern<Direct>);
78
79TVM_REGISTER_OP("tir.popcount")
80 .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);
81
82TVM_REGISTER_OP("tir.fmod")
83 .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);
84
85TVM_REGISTER_OP("tir.sin").set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic",
86 DispatchPureExtern<Direct>);
87
88TVM_REGISTER_OP("tir.sinh")
89 .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);
90
91TVM_REGISTER_OP("tir.cos").set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic",
92 DispatchPureExtern<Direct>);
93
94TVM_REGISTER_OP("tir.cosh")
95 .set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchPureExtern<Direct>);
96
97// There is no erf function in Metal. When erf is used, we use fast_erf instead
98static PrimExpr DispatchFastErf(const PrimExpr& e) {
99 LOG(WARNING) << " Metal doesn't have built-in erf function. fast_erf will be used instead.";
100 const CallNode* call = e.as<CallNode>();
101 ICHECK(call != nullptr);
102 ICHECK_EQ(call->args.size(), 1);
103 PrimExpr arg = call->args[0];
104 int bits = arg.dtype().bits();
105 bool isFloat = arg.dtype().is_float();
106 PrimExpr res;
107 if (isFloat && (bits == 16 || bits == 32))
108 res = topi::fast_erf_float_expr(arg, bits);
109 else
110 LOG(FATAL) << "Unsupported type in Metal fast_erf";
111 return res;
112}
113TVM_REGISTER_OP("tir.erf").set_attr<FLowerIntrinsic>("metal.FLowerIntrinsic", DispatchFastErf);
114
115} // namespace intrin
116} // namespace codegen
117} // namespace tvm
118