/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file quantize.cc
 *
 * \brief Transform a graph to a low-bit graph
 *   for compression and acceleration.
 */
#include "./quantize.h"

#include <dmlc/thread_local.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/transform.h>

#include <stack>

namespace tvm {
namespace relay {
namespace quantize {

TVM_REGISTER_NODE_TYPE(SimulatedQuantizeAttrs);
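
/*!
 * \brief Type relation for the simulated_quantize annotation op.
 *
 * The five entries of \p types are [data, dom_scale, clip_min, clip_max, output]:
 * dom_scale, clip_min and clip_max are constrained to float32 scalars, and the
 * output type is the same as the input data type.
 */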
bool SimulatedQuantizeRel(const Array<Type>& types, int num_inputs, const Attrs& attrs,
                          const TypeReporter& reporter) {
  ICHECK_EQ(types.size(), 5);
  const auto param = attrs.as<SimulatedQuantizeAttrs>();
  ICHECK(param != nullptr);

  const auto* data = types[0].as<TensorTypeNode>();

  if (data == nullptr) {
    return false;
  }

  ICHECK_NE(data->shape.size(), 0) << "Input shape cannot be empty";

  reporter->Assign(types[1], TensorType({}, DataType::Float(32)));  // dom_scale
  reporter->Assign(types[2], TensorType({}, DataType::Float(32)));  // clip_min
  reporter->Assign(types[3], TensorType({}, DataType::Float(32)));  // clip_max
  reporter->Assign(types[4], types[0]);                             // output
  return true;
}

RELAY_REGISTER_OP("relay.op.annotation.simulated_quantize")
    .describe(R"code(Simulated quantize op that annotates data with its domain scale
and clipping bounds.)code" TVM_ADD_FILELINE)
    .set_num_inputs(4)
    .add_argument("data", "Tensor", "The input data.")
    .add_argument("dom_scale", "Tensor", "The domain scale of the input data. It should be a scalar.")
    .add_argument("clip_min", "Tensor", "The lower clipping bound. It should be a scalar.")
    .add_argument("clip_max", "Tensor", "The upper clipping bound. It should be a scalar.")
    .set_attrs_type<SimulatedQuantizeAttrs>()
    .set_support_level(11)
    .add_type_rel("SimulatedQuantize", SimulatedQuantizeRel);

TVM_REGISTER_GLOBAL("relay._quantize.simulated_quantize")
    .set_body_typed([](Expr data, Expr dom_scale, Expr clip_min, Expr clip_max, int kind, bool sign,
                       String rounding) {
      auto attrs = make_object<SimulatedQuantizeAttrs>();
      attrs->kind = kind;
      attrs->sign = sign;
      attrs->rounding = rounding;
      static const Op& op = Op::Get("relay.op.annotation.simulated_quantize");
      return Call(op, {data, dom_scale, clip_min, clip_max}, Attrs(attrs), {});
    });
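
// Usage sketch (illustrative only, argument values are placeholders): the packed function
// registered above can be looked up through the global registry and invoked to build the
// annotation call, e.g.
//   const runtime::PackedFunc* f = runtime::Registry::Get("relay._quantize.simulated_quantize");
//   Expr call = (*f)(data, dom_scale, clip_min, clip_max, /*kind=*/1, /*sign=*/true,
//                    /*rounding=*/String("round"));
// In practice this hook is driven from the Python side of relay.quantize.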

/*! \brief Entry to hold the QConfig context stack. */
struct TVMQConfigThreadLocalEntry {
  /*! \brief The default qconfig if the stack is empty */
  QConfig default_config;

  /*! \brief The current qconfig context */
  std::stack<QConfig> context_stack;

  TVMQConfigThreadLocalEntry() : default_config(make_object<QConfigNode>()) {}
};

/*! \brief Thread local store to hold the QConfig context stack. */
typedef dmlc::ThreadLocalStore<TVMQConfigThreadLocalEntry> TVMQConfigThreadLocalStore;

void QConfig::EnterQConfigScope(const QConfig& build_config) {
  TVMQConfigThreadLocalEntry* entry = TVMQConfigThreadLocalStore::Get();
  entry->context_stack.push(build_config);
}

void QConfig::ExitQConfigScope() {
  TVMQConfigThreadLocalEntry* entry = TVMQConfigThreadLocalStore::Get();
  entry->context_stack.pop();
}

QConfig& QConfig::Current() {
  TVMQConfigThreadLocalEntry* entry = TVMQConfigThreadLocalStore::Get();
  if (entry->context_stack.size() > 0) {
    return entry->context_stack.top();
  }

  return entry->default_config;
}
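
// Usage sketch (illustrative only): a configuration is made current for a region of code by
// pushing it onto the thread-local stack and popping it afterwards, e.g.
//   QConfig cfg(make_object<QConfigNode>());
//   QConfig::EnterQConfigScope(cfg);
//   ...  // QConfig::Current() now returns `cfg`
//   QConfig::ExitQConfigScope();
// In practice this is typically driven by the `with relay.quantize.qconfig(...)` context manager
// on the Python side through the FFI hooks registered below.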

TVM_REGISTER_NODE_TYPE(QConfigNode);

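// Pretty-printer for QConfig: renders every configuration field as `qconfig(field=value, ...)`.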
TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
    .set_dispatch<QConfigNode>([](const ObjectRef& ref, ReprPrinter* p) {
      auto* op = static_cast<const QConfigNode*>(ref.get());
      p->stream << "qconfig(";
      p->stream << "nbit_input=" << op->nbit_input << ", ";
      p->stream << "nbit_weight=" << op->nbit_weight << ", ";
      p->stream << "nbit_activation=" << op->nbit_activation << ", ";
      p->stream << "calibrate_mode=" << op->calibrate_mode << ", ";
      p->stream << "global_scale=" << op->global_scale << ", ";
      p->stream << "weight_scale=" << op->weight_scale << ", ";
      p->stream << "skip_conv_layers=" << op->skip_conv_layers << ", ";
      p->stream << "skip_dense_layer=" << op->skip_dense_layer << ", ";
      p->stream << "do_simulation=" << op->do_simulation << ", ";
      p->stream << "round_for_shift=" << op->round_for_shift << ", ";
      p->stream << "debug_enabled_ops=" << op->debug_enabled_ops << ", ";
      p->stream << "rounding=" << op->rounding << ", ";
      p->stream << "partition_conversions=" << op->partition_conversions;
      p->stream << ")";
    });

TVM_REGISTER_GLOBAL("relay._quantize._GetCurrentQConfig").set_body_typed([]() -> QConfig {
  return QConfig::Current();
});

TVM_REGISTER_GLOBAL("relay._quantize._EnterQConfigScope")
    .set_body_typed(QConfig::EnterQConfigScope);

TVM_REGISTER_GLOBAL("relay._quantize._ExitQConfigScope").set_body_typed(QConfig::ExitQConfigScope);

}  // namespace quantize
}  // namespace relay
}  // namespace tvm