1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/analysis/get_calibration_data.cc
22 *
23 * \brief To get the calibration data, we need to perform two
24 * steps. First, we need to prepare the module that generates
25 * the tensor values (GetCalibrateModule). Second, we need to
26 * generate the mapping between the values and the functions
27 * (GetCalibrateOutputMap).
28 */
29
30#include <tvm/relay/analysis.h>
31#include <tvm/relay/expr.h>
32#include <tvm/relay/expr_functor.h>
33
34namespace tvm {
35namespace relay {
36
37/*!
38 * \brief This function returns a module that will be used by
39 * the relay graph executor for collecting the calibration data.
40 * To do that, we first make all inputs and outputs of each
41 * function into the final output (i.e., the final output is a
42 * tuple of tensors). Then, we change the compiler attribute of
 * each function. Finally, we mark all functions to be inlined.
44 */
45
46class Collector : public ExprRewriter {
47 public:
48 explicit Collector(const IRModule& module) : module_(module) {}
49
50 Expr Rewrite_(const CallNode* call, const Expr& post) final {
51 // check if the function implementation is available
52 // intrinsic functions are excluded for now
53 if (call->op->IsInstance<GlobalVarNode>()) {
54 auto var = Downcast<GlobalVar>(call->op);
55 ICHECK(module_->ContainGlobalVar(var->name_hint)) << "Function " << var << " is not defined";
56 // we only handle functions with Compiler attribute set
57 auto func = Downcast<Function>(module_->Lookup(var));
58 if (func->GetAttr<String>(attr::kCompiler)) {
59 // collect all the inputs and outputs
60 for (const auto& it : call->args) new_outputs_.push_back(it);
61 new_outputs_.push_back(post);
62 }
63 }
64 return post;
65 }
66
67 Array<Expr> GetNewOutputs() { return new_outputs_; }
68
69 private:
70 const IRModule& module_;
71 Array<Expr> new_outputs_;
72};
73
74Expr FlattenOutputTuple(const Array<Expr>& exprs) {
75 Array<Expr> fields;
76 for (const auto& it : exprs) {
77 ICHECK(it->checked_type_.defined());
78 if (auto* tn = it->checked_type_.as<TupleTypeNode>()) {
79 // TODO(seanlatias): for now input argument cannot be a tuple
80 ICHECK(it->IsInstance<CallNode>());
81 for (size_t i = 0; i < tn->fields.size(); i++) {
82 fields.push_back(TupleGetItem(it, i));
83 }
84 } else {
85 fields.push_back(it);
86 }
87 }
88 return Tuple(fields);
89}
90
91IRModule GetCalibrateModule(IRModule module) {
92 auto glob_funcs = module->functions;
93 // module is mutable, hence, we make a copy of it.
94 module.CopyOnWrite();
95 for (const auto& pair : glob_funcs) {
96 if (auto* fn = pair.second.as<FunctionNode>()) {
97 auto func = GetRef<Function>(fn);
98 // we only collect the outputs for main function
99 if (pair.first->name_hint == "main") {
100 Collector collector(module);
101 PostOrderRewrite(func->body, &collector);
102 auto new_outputs = collector.GetNewOutputs();
103 Expr tuple = FlattenOutputTuple(new_outputs);
104 func = Function(func->params, tuple, tuple->checked_type_, func->type_params, func->attrs);
105 module->Update(pair.first, func);
106 }
107 }
108 }
109 // reset the attribute of functions for running graph executor
110 for (const auto& pair : glob_funcs) {
111 if (auto* fn = pair.second.as<FunctionNode>()) {
112 auto func = GetRef<Function>(fn);
113 if (func->GetAttr<String>(attr::kCompiler)) {
114 // we need to inline the functions in order to run grpah runtime
115 func = WithAttr(std::move(func), attr::kInline, tvm::Integer(1));
116 // reset the compiler attribute to null for llvm execution
117 func = WithAttr(std::move(func), attr::kCompiler, NullValue<ObjectRef>());
118 module->Update(pair.first, func);
119 }
120 }
121 }
122 return module;
123}
124
125/*!
126 * \brief This function generates the output mapping between
127 * the calibration data and each function. The key is a
128 * GlobalVar that corresponds to each function and the value
129 * is an array of integers. The size of the array is always
 * three. The first value is the offset that points to the start.
131 * The second value is the number of inputs. The third value
132 * is the number of outputs.
133 */
134
135class OutputMapper : public ExprRewriter {
136 public:
137 OutputMapper(Map<GlobalVar, Array<Integer>>* output_map, const IRModule& module, size_t* offset)
138 : output_map_(output_map), module_(module), offset_(offset) {}
139
140 Expr Rewrite_(const CallNode* call, const Expr& post) final {
141 if (call->op->IsInstance<GlobalVarNode>()) {
142 auto var = Downcast<GlobalVar>(call->op);
143 ICHECK(module_->ContainGlobalVar(var->name_hint)) << "Function " << var << " is not defined";
144 ICHECK_EQ(output_map_->count(var), 0)
145 << "Repeated function call " << var << " is not supported.";
146 auto func = Downcast<Function>(module_->Lookup(var));
147 // we only handle functions with Compiler attribute set
148 if (func->GetAttr<String>(attr::kCompiler)) {
149 Array<Integer> info;
150 // the first value is the offset
151 info.push_back(Integer(*offset_));
152 // the second value is the number of inputs
153 info.push_back(Integer(call->args.size()));
154 // the third value is the number of outputs
155 // we need to check if the output is a tuple
156 size_t out_size = 1;
157 if (auto* tn = func->body.as<TupleNode>()) {
158 info.push_back(Integer(tn->fields.size()));
159 out_size = tn->fields.size();
160 } else {
161 info.push_back(Integer(1));
162 }
163 output_map_->Set(var, info);
164 // calculate the offset for the next function
165 *offset_ = *offset_ + call->args.size() + out_size;
166 }
167 }
168 return post;
169 }
170
171 private:
172 Map<GlobalVar, Array<Integer>>* output_map_;
173 const IRModule& module_;
174 size_t* offset_;
175};
176
177Map<GlobalVar, Array<Integer>> GetCalibrateOutputMap(const IRModule& module) {
178 Map<GlobalVar, Array<Integer>> output_map;
179 size_t offset = 0;
180 auto glob_funcs = module->functions;
181 for (const auto& pair : glob_funcs) {
182 if (auto* fn = pair.second.as<FunctionNode>()) {
183 if (pair.first->name_hint == "main") {
184 OutputMapper output_mapper(&output_map, module, &offset);
185 auto func = GetRef<Function>(fn);
186 PostOrderRewrite(func->body, &output_mapper);
187 }
188 }
189 }
190
191 return output_map;
192}
193
194TVM_REGISTER_GLOBAL("relay.analysis.get_calibrate_module").set_body_typed([](IRModule mod) {
195 return GetCalibrateModule(mod);
196});
197
198TVM_REGISTER_GLOBAL("relay.analysis.get_calibrate_output_map")
199 .set_body_typed([](const IRModule& mod) { return GetCalibrateOutputMap(mod); });
200
201} // namespace relay
202} // namespace tvm
203