/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
19 | |
/*!
 * \file fast_math.cc
 * \brief Replaces non-linear activation functions with their fast but approximate counterparts.
 */
24 | #include <tvm/relay/analysis.h> |
25 | #include <tvm/relay/attrs/nn.h> |
26 | #include <tvm/relay/expr_functor.h> |
27 | #include <tvm/relay/op.h> |
28 | #include <tvm/relay/transform.h> |
29 | |
30 | #include "pattern_utils.h" |
31 | |
32 | namespace tvm { |
33 | namespace relay { |
34 | |
35 | class FastMathMutator : public ExprRewriter { |
36 | public: |
37 | FastMathMutator() |
38 | : exp_op_(Op::Get("exp" )), |
39 | erf_op_(Op::Get("erf" )), |
40 | tanh_op_(Op::Get("tanh" )), |
41 | softmax_op_(Op::Get("nn.softmax" )) {} |
42 | |
43 | Expr Rewrite_(const CallNode* pre, const Expr& post) override { |
44 | if (pre->op == exp_op_) { |
45 | return FastExp(post.as<CallNode>()->args[0]); |
46 | } else if (pre->op == erf_op_) { |
47 | return FastErf(post.as<CallNode>()->args[0]); |
48 | } else if (pre->op == tanh_op_) { |
49 | return FastTanh(post.as<CallNode>()->args[0]); |
50 | } else if (pre->op == softmax_op_) { |
51 | return FastSoftmax(post.as<CallNode>()->args[0], post.as<CallNode>()->attrs); |
52 | } |
53 | return post; |
54 | } |
55 | |
56 | private: |
57 | // Cache the following ops. They will be used in the passes repeatedly for |
58 | // operator equivalence checking so that the registry lookup overhead can be |
59 | // reduced. |
60 | const Op& exp_op_; |
61 | const Op& erf_op_; |
62 | const Op& tanh_op_; |
63 | const Op& softmax_op_; |
64 | }; |
65 | |
66 | Expr FastMath(const Expr& e) { |
67 | auto rewriter = FastMathMutator(); |
68 | return PostOrderRewrite(e, &rewriter); |
69 | } |
70 | |
71 | namespace transform { |
72 | |
73 | Pass FastMath() { |
74 | runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = |
75 | [=](Function f, IRModule m, PassContext pc) { return Downcast<Function>(FastMath(f)); }; |
76 | return CreateFunctionPass(pass_func, 4, "FastMath" , {"InferType" }); |
77 | } |
78 | |
79 | TVM_REGISTER_GLOBAL("relay._transform.FastMath" ).set_body_typed(FastMath); |
80 | |
81 | } // namespace transform |
82 | |
83 | } // namespace relay |
84 | } // namespace tvm |
85 | |