1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file canonicalize_ops.cc
22 * \brief Canonicalize special operators to basic operators.
 *        This simplifies later analysis (e.g. expanding bias_add into expand_dims and broadcast_add).
24 */
25#include <tvm/relay/analysis.h>
26#include <tvm/relay/attrs/nn.h>
27#include <tvm/relay/expr_functor.h>
28#include <tvm/relay/op.h>
29#include <tvm/relay/transform.h>
30
31#include "pattern_utils.h"
32
33namespace tvm {
34namespace relay {
35
36class BiasAddSimplifier : public ExprRewriter {
37 public:
38 BiasAddSimplifier() : bias_add_op_(Op::Get("nn.bias_add")) {}
39
40 Expr Rewrite_(const CallNode* n, const Expr& post) override {
41 auto new_n = post;
42 if (n->op == bias_add_op_) {
43 Call call = Downcast<Call>(new_n);
44 ICHECK_EQ(call->args.size(), 2);
45 const BiasAddAttrs* param = call->attrs.as<BiasAddAttrs>();
46
47 auto ttype = n->args[0]->type_as<TensorTypeNode>();
48 size_t n_dim = ttype->shape.size();
49 int axis = param->axis;
50 if (axis < 0) {
51 axis += n_dim;
52 }
53 Expr expanded_bias = ExpandBiasToMatchAxis(call->args[1], n_dim, {axis});
54 Expr ret = Add(call->args[0], expanded_bias);
55 ret->checked_type_ = n->checked_type_;
56 return ret;
57 }
58 return new_n;
59 }
60
61 private:
62 // Cache the bias_add for equivalence checking.
63 const Op& bias_add_op_;
64};
65
66Expr CanonicalizeOps(const Expr& e) {
67 auto rewriter = BiasAddSimplifier();
68 return PostOrderRewrite(e, &rewriter);
69}
70
71namespace transform {
72
73Pass CanonicalizeOps() {
74 runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
75 [=](Function f, IRModule m, PassContext pc) {
76 return Downcast<Function>(CanonicalizeOps(f));
77 };
78 return CreateFunctionPass(pass_func, 3, "CanonicalizeOps", {"InferType"});
79}
80
81TVM_REGISTER_GLOBAL("relay._transform.CanonicalizeOps").set_body_typed(CanonicalizeOps);
82
83} // namespace transform
84
85} // namespace relay
86} // namespace tvm
87