1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 *
22 * \file annotate.cc
23 *
24 * \brief Annotating the graph with simulated quantize operators.
25 */
26
27#include <tvm/relay/analysis.h>
28#include <tvm/relay/transform.h>
29
30#include "./quantize.h"
31
32namespace tvm {
33namespace relay {
34namespace quantize {
35
// Pass-infrastructure names used by QuantizeAnnotate in this file; presumably
// supplies CreateFunctionPass and the Pass/PassContext aliases -- confirm.
using namespace relay::transform;
37
class QAnnotateExpr;
/*!
 * \brief Temporary expression node used while annotating the graph with
 *  simulated quantize operators.
 *
 * It pairs a Relay expression with the QAnnotateKind describing how that
 * expression has been annotated. Realize() drops the kind and returns the
 * wrapped expression.
 */
class QAnnotateExprNode : public TempExprNode {
 public:
  /*! \brief The wrapped Relay expression. */
  Expr expr;
  /*! \brief The annotation kind attached to `expr`. */
  QAnnotateKind kind;

  // Reflection hook: exposes both fields under their member names.
  void VisitAttrs(tvm::AttrVisitor* v) {
    v->Visit("expr", &expr);
    v->Visit("kind", &kind);
  }

  /*! \brief Return the wrapped expression, discarding the annotation kind. */
  Expr Realize() const final;

  static constexpr const char* _type_key = "relay.QAnnotateExpr";
  TVM_DECLARE_FINAL_OBJECT_INFO(QAnnotateExprNode, TempExprNode);
};
54
/*!
 * \brief Managed reference to QAnnotateExprNode.
 */
class QAnnotateExpr : public TempExpr {
 public:
  /*!
   * \brief The constructor
   * \param expr The original relay expression.
   * \param kind The annotation kind.
   */
  TVM_DLL QAnnotateExpr(Expr expr, QAnnotateKind kind);

  TVM_DEFINE_OBJECT_REF_METHODS(QAnnotateExpr, TempExpr, QAnnotateExprNode);
};
66
67Expr QAnnotateExprNode::Realize() const { return expr; }
68
69QAnnotateExpr::QAnnotateExpr(Expr expr, QAnnotateKind kind) {
70 auto rnode = make_object<QAnnotateExprNode>();
71 rnode->expr = std::move(expr);
72 rnode->kind = kind;
73 data_ = std::move(rnode);
74}
75
76TVM_REGISTER_GLOBAL("relay._quantize.make_annotate_expr").set_body_typed([](Expr expr, int kind) {
77 return QAnnotateExpr(expr, static_cast<QAnnotateKind>(kind));
78});
79
80Pass QuantizeAnnotate() {
81 // TODO(tvm-teams): since partition has added cast_hint in different
82 // branches, try to remove this in the future.
83 std::function<Expr(const Expr&)> fmulti_ref = [](const Expr& e) {
84 if (e->IsInstance<TempExprNode>()) {
85 const auto* n = e.as<QAnnotateExprNode>();
86 ICHECK(n);
87 const PackedFunc* f = runtime::Registry::Get("relay.quantize.attach_simulated_quantize");
88 Expr ret = (*f)(n->expr, static_cast<int>(kQInput));
89 return static_cast<Expr>(QAnnotateExpr(ret, kQInput));
90 }
91 return e;
92 };
93
94 runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
95 [=](Function f, IRModule m, PassContext pc) {
96 auto func = Downcast<Function>(ForwardRewrite(f, "FQAnnotateRewrite", nullptr, fmulti_ref));
97 auto new_params = func->params;
98 for (const auto& x : FreeVars(func)) {
99 new_params.push_back(x);
100 }
101 return WithFields(func, new_params);
102 };
103 return CreateFunctionPass(pass_func, 1, "QuantizeAnnotate", {});
104}
105
106TVM_REGISTER_GLOBAL("relay._quantize.QuantizeAnnotate").set_body_typed(QuantizeAnnotate);
107
// Register QAnnotateExprNode (type key "relay.QAnnotateExpr") with the node
// type system.
TVM_REGISTER_NODE_TYPE(QAnnotateExprNode);
109
110} // namespace quantize
111} // namespace relay
112} // namespace tvm
113