1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file src/relay/analysis/get_calibration_data.cc |
22 | * |
23 | * \brief To get the calibration data, we need to perform two |
24 | * steps. First, we need to prepare the module that generates |
25 | * the tensor values (GetCalibrateModule). Second, we need to |
26 | * generate the mapping between the values and the functions |
27 | * (GetCalibrateOutputMap). |
28 | */ |
29 | |
30 | #include <tvm/relay/analysis.h> |
31 | #include <tvm/relay/expr.h> |
32 | #include <tvm/relay/expr_functor.h> |
33 | |
34 | namespace tvm { |
35 | namespace relay { |
36 | |
37 | /*! |
38 | * \brief This function returns a module that will be used by |
39 | * the relay graph executor for collecting the calibration data. |
40 | * To do that, we first make all inputs and outputs of each |
41 | * function into the final output (i.e., the final output is a |
42 | * tuple of tensors). Then, we change the compiler attribute of |
43 | * each function. Finally, we mark all function to be inlined. |
44 | */ |
45 | |
46 | class Collector : public ExprRewriter { |
47 | public: |
48 | explicit Collector(const IRModule& module) : module_(module) {} |
49 | |
50 | Expr Rewrite_(const CallNode* call, const Expr& post) final { |
51 | // check if the function implementation is available |
52 | // intrinsic functions are excluded for now |
53 | if (call->op->IsInstance<GlobalVarNode>()) { |
54 | auto var = Downcast<GlobalVar>(call->op); |
55 | ICHECK(module_->ContainGlobalVar(var->name_hint)) << "Function " << var << " is not defined" ; |
56 | // we only handle functions with Compiler attribute set |
57 | auto func = Downcast<Function>(module_->Lookup(var)); |
58 | if (func->GetAttr<String>(attr::kCompiler)) { |
59 | // collect all the inputs and outputs |
60 | for (const auto& it : call->args) new_outputs_.push_back(it); |
61 | new_outputs_.push_back(post); |
62 | } |
63 | } |
64 | return post; |
65 | } |
66 | |
67 | Array<Expr> GetNewOutputs() { return new_outputs_; } |
68 | |
69 | private: |
70 | const IRModule& module_; |
71 | Array<Expr> new_outputs_; |
72 | }; |
73 | |
74 | Expr FlattenOutputTuple(const Array<Expr>& exprs) { |
75 | Array<Expr> fields; |
76 | for (const auto& it : exprs) { |
77 | ICHECK(it->checked_type_.defined()); |
78 | if (auto* tn = it->checked_type_.as<TupleTypeNode>()) { |
79 | // TODO(seanlatias): for now input argument cannot be a tuple |
80 | ICHECK(it->IsInstance<CallNode>()); |
81 | for (size_t i = 0; i < tn->fields.size(); i++) { |
82 | fields.push_back(TupleGetItem(it, i)); |
83 | } |
84 | } else { |
85 | fields.push_back(it); |
86 | } |
87 | } |
88 | return Tuple(fields); |
89 | } |
90 | |
91 | IRModule GetCalibrateModule(IRModule module) { |
92 | auto glob_funcs = module->functions; |
93 | // module is mutable, hence, we make a copy of it. |
94 | module.CopyOnWrite(); |
95 | for (const auto& pair : glob_funcs) { |
96 | if (auto* fn = pair.second.as<FunctionNode>()) { |
97 | auto func = GetRef<Function>(fn); |
98 | // we only collect the outputs for main function |
99 | if (pair.first->name_hint == "main" ) { |
100 | Collector collector(module); |
101 | PostOrderRewrite(func->body, &collector); |
102 | auto new_outputs = collector.GetNewOutputs(); |
103 | Expr tuple = FlattenOutputTuple(new_outputs); |
104 | func = Function(func->params, tuple, tuple->checked_type_, func->type_params, func->attrs); |
105 | module->Update(pair.first, func); |
106 | } |
107 | } |
108 | } |
109 | // reset the attribute of functions for running graph executor |
110 | for (const auto& pair : glob_funcs) { |
111 | if (auto* fn = pair.second.as<FunctionNode>()) { |
112 | auto func = GetRef<Function>(fn); |
113 | if (func->GetAttr<String>(attr::kCompiler)) { |
114 | // we need to inline the functions in order to run grpah runtime |
115 | func = WithAttr(std::move(func), attr::kInline, tvm::Integer(1)); |
116 | // reset the compiler attribute to null for llvm execution |
117 | func = WithAttr(std::move(func), attr::kCompiler, NullValue<ObjectRef>()); |
118 | module->Update(pair.first, func); |
119 | } |
120 | } |
121 | } |
122 | return module; |
123 | } |
124 | |
125 | /*! |
126 | * \brief This function generates the output mapping between |
127 | * the calibration data and each function. The key is a |
128 | * GlobalVar that corresponds to each function and the value |
129 | * is an array of integers. The size of the array is always |
130 | * three. The first value is the offset the points to the start. |
131 | * The second value is the number of inputs. The third value |
132 | * is the number of outputs. |
133 | */ |
134 | |
135 | class OutputMapper : public ExprRewriter { |
136 | public: |
137 | OutputMapper(Map<GlobalVar, Array<Integer>>* output_map, const IRModule& module, size_t* offset) |
138 | : output_map_(output_map), module_(module), offset_(offset) {} |
139 | |
140 | Expr Rewrite_(const CallNode* call, const Expr& post) final { |
141 | if (call->op->IsInstance<GlobalVarNode>()) { |
142 | auto var = Downcast<GlobalVar>(call->op); |
143 | ICHECK(module_->ContainGlobalVar(var->name_hint)) << "Function " << var << " is not defined" ; |
144 | ICHECK_EQ(output_map_->count(var), 0) |
145 | << "Repeated function call " << var << " is not supported." ; |
146 | auto func = Downcast<Function>(module_->Lookup(var)); |
147 | // we only handle functions with Compiler attribute set |
148 | if (func->GetAttr<String>(attr::kCompiler)) { |
149 | Array<Integer> info; |
150 | // the first value is the offset |
151 | info.push_back(Integer(*offset_)); |
152 | // the second value is the number of inputs |
153 | info.push_back(Integer(call->args.size())); |
154 | // the third value is the number of outputs |
155 | // we need to check if the output is a tuple |
156 | size_t out_size = 1; |
157 | if (auto* tn = func->body.as<TupleNode>()) { |
158 | info.push_back(Integer(tn->fields.size())); |
159 | out_size = tn->fields.size(); |
160 | } else { |
161 | info.push_back(Integer(1)); |
162 | } |
163 | output_map_->Set(var, info); |
164 | // calculate the offset for the next function |
165 | *offset_ = *offset_ + call->args.size() + out_size; |
166 | } |
167 | } |
168 | return post; |
169 | } |
170 | |
171 | private: |
172 | Map<GlobalVar, Array<Integer>>* output_map_; |
173 | const IRModule& module_; |
174 | size_t* offset_; |
175 | }; |
176 | |
177 | Map<GlobalVar, Array<Integer>> GetCalibrateOutputMap(const IRModule& module) { |
178 | Map<GlobalVar, Array<Integer>> output_map; |
179 | size_t offset = 0; |
180 | auto glob_funcs = module->functions; |
181 | for (const auto& pair : glob_funcs) { |
182 | if (auto* fn = pair.second.as<FunctionNode>()) { |
183 | if (pair.first->name_hint == "main" ) { |
184 | OutputMapper output_mapper(&output_map, module, &offset); |
185 | auto func = GetRef<Function>(fn); |
186 | PostOrderRewrite(func->body, &output_mapper); |
187 | } |
188 | } |
189 | } |
190 | |
191 | return output_map; |
192 | } |
193 | |
// Expose both analysis entry points to the Python frontend via the
// TVM FFI registry (tvm.relay.analysis.get_calibrate_*).
TVM_REGISTER_GLOBAL("relay.analysis.get_calibrate_module").set_body_typed([](IRModule mod) {
  return GetCalibrateModule(mod);
});

TVM_REGISTER_GLOBAL("relay.analysis.get_calibrate_output_map")
    .set_body_typed([](const IRModule& mod) { return GetCalibrateOutputMap(mod); });
200 | |
201 | } // namespace relay |
202 | } // namespace tvm |
203 | |