1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 *
22 * \file partition.cc
23 *
24 * \brief Partition a graph into sections for quantization.
25 */
26
27#include <tvm/relay/transform.h>
28
29#include "../op/annotation/annotation.h"
30#include "./quantize.h"
31
32namespace tvm {
33namespace relay {
34namespace quantize {
35
36using namespace relay::transform;
37
38class QPartitionExpr;
39class QPartitionExprNode : public TempExprNode {
40 public:
41 /*! \brief The original expression */
42 Expr expr;
43
44 void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("expr", &expr); }
45
46 Expr Realize() const final;
47
48 static constexpr const char* _type_key = "relay.QPartitionExpr";
49 TVM_DECLARE_FINAL_OBJECT_INFO(QPartitionExprNode, TempExprNode);
50};
51
52class QPartitionExpr : public TempExpr {
53 public:
54 /*!
55 * \brief The constructor
56 * \param expr The original relay expression.
57 */
58 TVM_DLL explicit QPartitionExpr(Expr expr);
59
60 TVM_DEFINE_OBJECT_REF_METHODS(QPartitionExpr, TempExpr, QPartitionExprNode);
61};
62
63Expr QPartitionExprNode::Realize() const {
64 // insert cast hint and stop fusion
65 const QConfig& cfg = QConfig::Current();
66 Expr ret = CastHint(this->expr, cfg->dtype_input);
67 return StopFusion(ret);
68}
69
70QPartitionExpr::QPartitionExpr(Expr expr) {
71 auto rnode = make_object<QPartitionExprNode>();
72 rnode->expr = std::move(expr);
73 data_ = std::move(rnode);
74}
75
76TVM_REGISTER_GLOBAL("relay._quantize.make_partition_expr").set_body_typed([](Expr expr) {
77 return QPartitionExpr(expr);
78});
79
80Pass QuantizePartition() {
81 runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func =
82 [=](Function f, IRModule m, PassContext pc) {
83 auto ret = Downcast<Function>(ForwardRewrite(f, "FQPartitionRewrite", nullptr, nullptr));
84 return ret;
85 };
86 return CreateFunctionPass(pass_func, 1, "QuantizePartition", {});
87}
88
89TVM_REGISTER_GLOBAL("relay._quantize.QuantizePartition").set_body_typed(QuantizePartition);
90
91TVM_REGISTER_NODE_TYPE(QPartitionExprNode);
92
93} // namespace quantize
94} // namespace relay
95} // namespace tvm
96