1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * |
22 | * \file partition.cc |
23 | * |
24 | * \brief Partition a graph into sections for quantization. |
25 | */ |
26 | |
27 | #include <tvm/relay/transform.h> |
28 | |
29 | #include "../op/annotation/annotation.h" |
30 | #include "./quantize.h" |
31 | |
32 | namespace tvm { |
33 | namespace relay { |
34 | namespace quantize { |
35 | |
36 | using namespace relay::transform; |
37 | |
38 | class QPartitionExpr; |
39 | class QPartitionExprNode : public TempExprNode { |
40 | public: |
41 | /*! \brief The original expression */ |
42 | Expr expr; |
43 | |
44 | void VisitAttrs(tvm::AttrVisitor* v) { v->Visit("expr" , &expr); } |
45 | |
46 | Expr Realize() const final; |
47 | |
48 | static constexpr const char* _type_key = "relay.QPartitionExpr" ; |
49 | TVM_DECLARE_FINAL_OBJECT_INFO(QPartitionExprNode, TempExprNode); |
50 | }; |
51 | |
52 | class QPartitionExpr : public TempExpr { |
53 | public: |
54 | /*! |
55 | * \brief The constructor |
56 | * \param expr The original relay expression. |
57 | */ |
58 | TVM_DLL explicit QPartitionExpr(Expr expr); |
59 | |
60 | TVM_DEFINE_OBJECT_REF_METHODS(QPartitionExpr, TempExpr, QPartitionExprNode); |
61 | }; |
62 | |
63 | Expr QPartitionExprNode::Realize() const { |
64 | // insert cast hint and stop fusion |
65 | const QConfig& cfg = QConfig::Current(); |
66 | Expr ret = CastHint(this->expr, cfg->dtype_input); |
67 | return StopFusion(ret); |
68 | } |
69 | |
70 | QPartitionExpr::QPartitionExpr(Expr expr) { |
71 | auto rnode = make_object<QPartitionExprNode>(); |
72 | rnode->expr = std::move(expr); |
73 | data_ = std::move(rnode); |
74 | } |
75 | |
76 | TVM_REGISTER_GLOBAL("relay._quantize.make_partition_expr" ).set_body_typed([](Expr expr) { |
77 | return QPartitionExpr(expr); |
78 | }); |
79 | |
80 | Pass QuantizePartition() { |
81 | runtime::TypedPackedFunc<Function(Function, IRModule, PassContext)> pass_func = |
82 | [=](Function f, IRModule m, PassContext pc) { |
83 | auto ret = Downcast<Function>(ForwardRewrite(f, "FQPartitionRewrite" , nullptr, nullptr)); |
84 | return ret; |
85 | }; |
86 | return CreateFunctionPass(pass_func, 1, "QuantizePartition" , {}); |
87 | } |
88 | |
89 | TVM_REGISTER_GLOBAL("relay._quantize.QuantizePartition" ).set_body_typed(QuantizePartition); |
90 | |
91 | TVM_REGISTER_NODE_TYPE(QPartitionExprNode); |
92 | |
93 | } // namespace quantize |
94 | } // namespace relay |
95 | } // namespace tvm |
96 | |