1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file canonicalize_ops.cc |
22 | * \brief Canonicalize special operators to basic operators. |
 *        This can simplify later analysis (e.g. expand bias_add into expand_dims and broadcast_add).
24 | */ |
25 | #include <tvm/relay/analysis.h> |
26 | #include <tvm/relay/attrs/nn.h> |
27 | #include <tvm/relay/expr_functor.h> |
28 | #include <tvm/relay/op.h> |
29 | #include <tvm/relay/transform.h> |
30 | |
31 | #include "pattern_utils.h" |
32 | |
33 | namespace tvm { |
34 | namespace relay { |
35 | |
36 | class BiasAddSimplifier : public ExprRewriter { |
37 | public: |
38 | BiasAddSimplifier() : bias_add_op_(Op::Get("nn.bias_add" )) {} |
39 | |
40 | Expr Rewrite_(const CallNode* n, const Expr& post) override { |
41 | auto new_n = post; |
42 | if (n->op == bias_add_op_) { |
43 | Call call = Downcast<Call>(new_n); |
44 | ICHECK_EQ(call->args.size(), 2); |
45 | const BiasAddAttrs* param = call->attrs.as<BiasAddAttrs>(); |
46 | |
47 | auto ttype = n->args[0]->type_as<TensorTypeNode>(); |
48 | size_t n_dim = ttype->shape.size(); |
49 | int axis = param->axis; |
50 | if (axis < 0) { |
51 | axis += n_dim; |
52 | } |
53 | Expr expanded_bias = ExpandBiasToMatchAxis(call->args[1], n_dim, {axis}); |
54 | Expr ret = Add(call->args[0], expanded_bias); |
55 | ret->checked_type_ = n->checked_type_; |
56 | return ret; |
57 | } |
58 | return new_n; |
59 | } |
60 | |
61 | private: |
62 | // Cache the bias_add for equivalence checking. |
63 | const Op& bias_add_op_; |
64 | }; |
65 | |
66 | Expr CanonicalizeOps(const Expr& e) { |
67 | auto rewriter = BiasAddSimplifier(); |
68 | return PostOrderRewrite(e, &rewriter); |
69 | } |
70 | |
71 | namespace transform { |
72 | |
73 | Pass CanonicalizeOps() { |
74 | runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = |
75 | [=](Function f, IRModule m, PassContext pc) { |
76 | return Downcast<Function>(CanonicalizeOps(f)); |
77 | }; |
78 | return CreateFunctionPass(pass_func, 3, "CanonicalizeOps" , {"InferType" }); |
79 | } |
80 | |
81 | TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeOps" ).set_body_typed(CanonicalizeOps); |
82 | |
83 | } // namespace transform |
84 | |
85 | } // namespace relay |
86 | } // namespace tvm |
87 | |