/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#include "./te_compiler_cache.h"

#include <tvm/driver/driver_api.h>
#include <tvm/ir/name_supply.h>
#include <tvm/ir/type_functor.h>
#include <tvm/meta_schedule/database.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/op_strategy.h>
#include <tvm/runtime/builtin_fp16.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule.h>
#include <tvm/te/schedule_pass.h>
#include <tvm/tir/function.h>
#include <tvm/tir/index_map.h>
#include <tvm/tir/schedule/schedule.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/topi/tags.h>

#include <algorithm>
#include <deque>
#include <functional>
#include <limits>
#include <memory>
#include <mutex>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "../../te/operation/create_primfunc.h"
#include "../op/memory/memory.h"
#include "../src/meta_schedule/module_equality.h"
#include "../src/meta_schedule/trace_apply.h"
#include "../transforms/meta_schedule_layout_rewrite.h"
#include "utils.h"

namespace tvm {
namespace relay {
namespace tec {

TVM_REGISTER_NODE_TYPE(LoweredOutputNode);
TVM_REGISTER_NODE_TYPE(CachedFuncNode);
TVM_REGISTER_NODE_TYPE(CCacheKeyNode);
TVM_REGISTER_NODE_TYPE(CCacheValueNode);

LoweredOutput::LoweredOutput(tvm::Array<te::Tensor> outputs, OpImplementation impl) {
  auto n = make_object<LoweredOutputNode>();
  n->outputs = std::move(outputs);
  n->implementation = std::move(impl);
  data_ = std::move(n);
}

CCacheKey::CCacheKey(Function source_func, Target target, VirtualDevice vd) {
  auto n = make_object<CCacheKeyNode>();
  n->source_func = std::move(source_func);
  n->target = std::move(target);
  n->virtual_device = std::move(vd);
  data_ = std::move(n);
}

CachedFunc::CachedFunc(tvm::Target target, GlobalVar prim_fn_var, tvm::Array<te::Tensor> inputs,
                       tvm::Array<te::Tensor> outputs, te::Schedule schedule,
                       tir::PrimFunc prim_func, tvm::Array<Integer> shape_func_param_states,
                       IRModule funcs,
                       std::unordered_map<const ConstantNode*, te::Tensor> constant_tensors) {
  auto n = make_object<CachedFuncNode>();
  n->target = target;
  n->prim_fn_var = prim_fn_var;
  n->inputs = inputs;
  n->outputs = outputs;
  n->schedule = schedule;
  n->prim_func = prim_func;
  n->shape_func_param_states = shape_func_param_states;
  n->funcs = funcs;
  n->constant_tensors = constant_tensors;
  data_ = std::move(n);
}

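// Legalizes a shape coming from type inference for use as a TE placeholder shape:
// constant int64 dimensions are narrowed to int32 when they fit (unless
// TVM_INDEX_DEFAULT_I64 is defined), and `Any` dimensions become size vars.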
Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
  // for now, we always use int32 shape when possible
  // even if the result of shape inference becomes int64.
  Array<IndexExpr> res;
  for (IndexExpr val : shape) {
    const int64_t* pval = tir::as_const_int(val);
    if (pval != nullptr) {
#ifndef TVM_INDEX_DEFAULT_I64
      ICHECK_LE(pval[0], std::numeric_limits<int32_t>::max())
          << "dimension must not exceed int32_t's max value";
      ICHECK_GE(pval[0], std::numeric_limits<int32_t>::min())
          << "dimension must not be less than int32_t's min value";
      res.push_back(IntImm(DataType::Int(32), *pval));
#else
      res.push_back(val);
#endif  // TVM_INDEX_DEFAULT_I64
    } else if (val->IsInstance<tir::AnyNode>()) {
      // currently all 'any' we meet in shape function are non-negative.
      res.push_back(val.as<tir::AnyNode>()->ToSizeVar());
    } else {
      res.push_back(val);
    }
  }
  return res;
}

// Helper class used during lowering to TE.
// It matches a sequence of ops and lowers them into a single TOPI operation. All supported
// patterns are enumerated in "supported_patterns_".
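//
// For example (illustrative), visiting the chain
//   qnn.conv2d -> add (bias) -> qnn.requantize
// registers P_QConv2d, P_BiasAdd and P_QRequantize; once the leaf op
// (qnn.requantize) is reached, the whole chain is lowered through the anchor
// op (qnn.conv2d) in a single lower_call invocation.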
class QnnPatternMatcher {
 public:
  QnnPatternMatcher()
      : qnn_conv2d_op_(Op::Get("qnn.conv2d")),
        qnn_dense_op_(Op::Get("qnn.dense")),
        qnn_dense_pack_op_(Op::Get("qnn.contrib_dense_pack")),
        qnn_requantize_op_(Op::Get("qnn.requantize")),
        bias_add_op_(Op::Get("add")) {}

  // Memoize visited operations.
  void Register(const CallNode* call_node) {
    ICHECK(call_node->op.as<OpNode>());
    Op op = Downcast<Op>(call_node->op);
    if (op == qnn_conv2d_op_) {
      registered_ops_.push_front(P_QConv2d);
      ICHECK(anchor_op_ == nullptr);
      anchor_op_ = call_node;
    } else if (op == qnn_requantize_op_) {
      registered_ops_.push_front(P_QRequantize);
    } else if (op == bias_add_op_) {
      registered_ops_.push_front(P_BiasAdd);
    } else if (op == qnn_dense_op_) {
      registered_ops_.push_front(P_QDense);
      ICHECK(anchor_op_ == nullptr);
      anchor_op_ = call_node;
    } else if (op == qnn_dense_pack_op_) {
      registered_ops_.push_front(P_QDensePack);
      ICHECK(anchor_op_ == nullptr);
      anchor_op_ = call_node;
    } else {
      registered_ops_.push_front(P_Opaque);
    }
  }

  // Check whether the given op is part of a matched pattern.
  bool find(const Op& op) {
    if (registered_ops_.empty()) return false;

    if (op == qnn_conv2d_op_ || op == qnn_requantize_op_ || op == bias_add_op_ ||
        op == qnn_dense_op_ || op == qnn_dense_pack_op_) {
      for (const auto& pat : supported_patterns_) {
        auto it =
            std::search(registered_ops_.begin(), registered_ops_.end(), pat.begin(), pat.end());
        if (it != registered_ops_.end()) return true;
      }
    }
    return false;
  }

  // Returns whether the given op is the last one in the pattern sequence.
  bool IsLeafOp(const Op& op) { return op == qnn_requantize_op_; }

  const CallNode* GetAnchorOp() { return anchor_op_; }

  void Clear() { registered_ops_.clear(); }

 private:
  const Op& qnn_conv2d_op_;
  const Op& qnn_dense_op_;
  const Op& qnn_dense_pack_op_;
  const Op& qnn_requantize_op_;
  const Op& bias_add_op_;

  // The main (complicated) operation in the primitive (for example qnn.conv2d, qnn.dense, etc.).
  const CallNode* anchor_op_ = nullptr;

  enum POper { P_QConv2d, P_QDense, P_QDensePack, P_BiasAdd, P_QRequantize, P_Opaque };

  std::deque<POper> registered_ops_;

  const std::vector<std::deque<POper>> supported_patterns_ = {
      {P_QDense, P_BiasAdd, P_QRequantize},      // qnn.dense -> bias_add -> qnn.requantize
      {P_QDense, P_QRequantize},                 // qnn.dense -> qnn.requantize
      {P_QDensePack, P_BiasAdd, P_QRequantize},  // qnn.contrib_dense_pack -> bias -> qnn.requantize
      {P_QDensePack, P_QRequantize},             // qnn.contrib_dense_pack -> qnn.requantize
      {P_QConv2d, P_BiasAdd, P_QRequantize},     // qnn.conv2d -> bias_add -> qnn.requantize
      {P_QConv2d, P_QRequantize}                 // qnn.conv2d -> qnn.requantize
  };
};

// Lowers a Relay primitive Function to a TE compute.
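// Besides the output tensors returned by Lower(), the results of lowering are
// exposed through the public fields below: the flattened input placeholders
// (fn_inputs_), scalar constant ops (scalars_), non-scalar constants mapped to
// placeholders (constant_tensors_), the chosen op implementations
// (op_implementations_), and the readable candidate function name
// (candidate_name_).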
class LowerToTECompute : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
 public:
  LowerToTECompute(Target target, NameSupply constants_name_supply)
      : target_(target),
        device_copy_op_(Op::Get("device_copy")),
        constants_name_supply_(constants_name_supply) {}

  Array<te::Tensor> Lower(const Function& relay_func) {
    for (Var param : relay_func->params) {
      Array<tvm::te::Tensor> inputs;
      for (const auto& ttype : FlattenTupleType(param->checked_type())) {
        auto name_hint = param->vid->name_hint;
        tvm::te::Tensor tensor = tvm::te::placeholder(
            GetShape(ttype->shape), ttype->dtype, (name_hint == "") ? "placeholder" : name_hint);
        inputs.push_back(tensor);
        fn_inputs_.push_back(tensor);
      }
      memo_[param] = inputs;
    }
    readable_name_stream_ << "fused";

    Array<te::Tensor> outputs = this->VisitExpr(relay_func->body);

    candidate_name_ = readable_name_stream_.str();

    constexpr static size_t kMaxFuncNameLength = 80;
    // WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME
    // whenever the value of kMaxFuncNameLength changes
    if (candidate_name_.size() > kMaxFuncNameLength) {
      std::stringstream truncated_name;
      truncated_name << candidate_name_.substr(0, kMaxFuncNameLength);
      truncated_name << "_" << std::hex << std::hash<std::string>{}(candidate_name_) << "_";
      candidate_name_ = truncated_name.str();
    }

    return outputs;
  }

  Array<te::Tensor> VisitExpr_(const VarNode* op) final {
    LOG(FATAL) << "Unexpected free variable " << PrettyPrint(GetRef<Var>(op));
  }

  Array<te::Tensor> VisitExpr_(const ConstantNode* op) final {
    using tir::make_const;
    void* data = op->data->data;
    DataType dtype = DataType(op->data->dtype);
    if (op->is_scalar()) {
      auto value = te::compute(
          {},
          [&](const Array<tvm::tir::Var>&) {
            if (dtype == DataType::Int(16)) {
              return make_const(dtype, static_cast<const int16_t*>(data)[0]);
            } else if (dtype == DataType::Int(8)) {
              return make_const(dtype, static_cast<const int8_t*>(data)[0]);
            } else if (dtype == DataType::UInt(8) || dtype == DataType::Bool()) {
              return make_const(dtype, static_cast<const uint8_t*>(data)[0]);
            } else if (dtype == DataType::Int(32)) {
              return make_const(dtype, static_cast<const int32_t*>(data)[0]);
            } else if (dtype == DataType::Int(64)) {
              return make_const(dtype, static_cast<const int64_t*>(data)[0]);
            } else if (dtype == DataType::Float(16)) {
              return make_const(dtype, __gnu_h2f_ieee(static_cast<const uint16_t*>(data)[0]));
            } else if (dtype == DataType::Float(32)) {
              return make_const(dtype, static_cast<const float*>(data)[0]);
            } else if (dtype == DataType::Float(64)) {
              return make_const(dtype, static_cast<const double*>(data)[0]);
            } else {
              LOG(FATAL) << dtype << " not handled";
            }
          },
          "compile_engine_const", topi::kBroadcast);
      scalars_.push_back(value->op);
      return {value};
    } else {
      const auto* ttype = op->checked_type().as<TensorTypeNode>();
      std::stringstream ss;
      std::string s = readable_name_stream_.str();
      std::replace(s.begin(), s.end(), '.', '_');
      ss << constants_name_supply_->FreshName(s + "_constant");
      tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype, ss.str());
      constant_tensors_[op] = tensor;
      return {tensor};
    }
  }

  Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
    static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call");
    ICHECK(flower_call) << "relay.backend.lower_call is not registered.";

    pattern_matcher_.Register(call_node);

    Array<te::Tensor> inputs;
    for (Expr arg : call_node->args) {
      for (te::Tensor tensor : VisitExpr(arg)) {
        inputs.push_back(tensor);
      }
    }

    ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
    Op op = Downcast<Op>(call_node->op);

    // TODO(mbs): device_copy cleanup
    ICHECK_NE(op, device_copy_op_) << "device_copy cannot be lowered";

    Array<te::Tensor> outputs;

    if (pattern_matcher_.find(op)) {
      if (pattern_matcher_.IsLeafOp(op)) {
        // Lower the anchor op once the pattern's leaf op has been reached.
        auto anchor_op = pattern_matcher_.GetAnchorOp();
        LoweredOutput lowered_out =
            (*flower_call)(GetRef<Call>(anchor_op), inputs, target_, call_node->checked_type());
        outputs = lowered_out->outputs;
        Op a_op = Downcast<Op>(anchor_op->op);
        op_implementations_[a_op.operator->()] = lowered_out->implementation;

        pattern_matcher_.Clear();
      } else {
        // Forward inputs as "outputs" for the successor op.
        readable_name_stream_ << '_' << op->name;
        return inputs;
      }
    } else {
      LoweredOutput lowered_out = (*flower_call)(GetRef<Call>(call_node), inputs, target_);
      outputs = lowered_out->outputs;
      op_implementations_[op.operator->()] = lowered_out->implementation;
    }

    if (outputs.size() != 1) {
      const auto* tuple_type = call_node->checked_type().as<TupleTypeNode>();
      ICHECK(tuple_type) << "Expected output to be a tuple type "
                         << PrettyPrint(call_node->checked_type());

      ICHECK_EQ(tuple_type->fields.size(), outputs.size());
    }

    readable_name_stream_ << '_' << op->name;
    return outputs;
  }

  Array<te::Tensor> VisitExpr_(const FunctionNode* op) final {
    LOG(FATAL) << "Primitive functions cannot contain nested functions.";
  }

  Array<te::Tensor> VisitExpr_(const LetNode* op) final {
    Array<te::Tensor> val = VisitExpr(op->value);
    ICHECK(!memo_.count(op->var));
    memo_[op->var] = val;
    return VisitExpr(op->body);
  }

  Array<te::Tensor> VisitExpr_(const TupleNode* op) final {
    Array<te::Tensor> fields;
    for (Expr field : op->fields) {
      // TODO(mbs): Generalize to be equivalent to FlattenTupleType.
      ICHECK(field->checked_type().as<TensorTypeNode>()) << "Only allow Tuple of Tensor";
      Array<te::Tensor> res = VisitExpr(field);
      ICHECK_EQ(res.size(), 1);
      fields.push_back(res[0]);
    }
    return fields;
  }

  Array<te::Tensor> VisitExpr_(const TupleGetItemNode* op) final {
    const auto* tuple_type = op->tuple->type_as<TupleTypeNode>();
    Array<te::Tensor> tuple = VisitExpr(op->tuple);
    ICHECK_EQ(tuple_type->fields.size(), tuple.size());
    ICHECK_GE(op->index, 0);
    ICHECK_LT(static_cast<size_t>(op->index), tuple.size());
    return {tuple[op->index]};
  }

 public:
  // Additional outputs
  Array<tvm::te::Tensor> fn_inputs_;
  Array<te::Operation> scalars_;
  std::unordered_map<const ConstantNode*, te::Tensor> constant_tensors_;
  std::unordered_map<const OpNode*, OpImplementation> op_implementations_;
  std::string candidate_name_;

 private:
  QnnPatternMatcher pattern_matcher_;

  tvm::Target target_;
  std::ostringstream readable_name_stream_;
  // Cache device copy op for equivalence checking to reduce registry lookup
  // overhead for each invocation of call node when retrieving schedules.
  const Op& device_copy_op_;
  // A NameSupply object passed from a caller, used to assign unique names to constants
  // across different invocations of LowerToTECompute.
  NameSupply constants_name_supply_;
};

using namespace tvm::tir;

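// Collects the constants of AllocateConst nodes whose buffers are marked as
// "layout_free_placeholders" in a block annotation, i.e. the constants whose
// layout MetaSchedule's RewriteLayout may transform.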
class LayoutFreeConstantCollector : public StmtVisitor {
 public:
  Array<runtime::NDArray> constants;

 private:
  void VisitStmt_(const BlockNode* op) final {
    StmtVisitor::VisitStmt_(op);
    if (Optional<ObjectRef> ann = op->annotations.Get("layout_free_placeholders")) {
      for (Buffer buffer : Downcast<Array<Buffer>>(ann)) {
        layout_free_buffer_vars_.insert(buffer->data.get());
      }
    }
  }

  void VisitStmt_(const AllocateConstNode* op) final {
    StmtVisitor::VisitStmt_(op);
    if (auto it = layout_free_buffer_vars_.find(op->buffer_var.get());
        it != layout_free_buffer_vars_.end()) {
      constants.push_back(op->data.value());
    }
  }

  std::unordered_set<const tir::VarNode*> layout_free_buffer_vars_;
};

using NDArrayMap =
    std::unordered_map<runtime::NDArray, runtime::NDArray, ObjectPtrHash, ObjectPtrEqual>;

// Replace constants in AllocateConst nodes according to the given mapping.
class AllocateConstReplaceConstant : public StmtExprMutator {
 public:
  explicit AllocateConstReplaceConstant(const NDArrayMap& constant_map)
      : constant_map_(constant_map) {}

  static PrimFunc Rewrite(PrimFunc f, const NDArrayMap& constant_map) {
    AllocateConstReplaceConstant rewriter(constant_map);
    PrimFuncNode* n = f.CopyOnWrite();
    n->body = rewriter(std::move(n->body));
    return f;
  }

 private:
  Stmt VisitStmt_(const AllocateConstNode* op) final {
    if (auto it = constant_map_.find(op->data.value()); it != constant_map_.end()) {
      auto rewritten_constant = it->second;
      Array<PrimExpr> rewritten_extents;
      for (auto s : rewritten_constant.Shape()) {
        rewritten_extents.push_back(PrimExpr(static_cast<int>(s)));
      }
      return AllocateConst(op->buffer_var, op->dtype, rewritten_extents, rewritten_constant,
                           op->body, op->annotations, op->span);
    }
    return StmtExprMutator::VisitStmt_(op);
  }

  NDArrayMap constant_map_;
};

// Construct a schedule for a given Relay primitive function and target.
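// Three scheduling paths are supported: an auto_scheduler (Ansor) schedule, a
// tuned TIR PrimFunc looked up from a MetaSchedule database, or a TOPI schedule
// obtained from the anchor op's registered implementation.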
class ScheduleBuilder : public ExprVisitor {
 public:
  explicit ScheduleBuilder(Target target)
      : target_(target),
        mod_eq_structural_(meta_schedule::ModuleEquality::Create("ignore-ndarray")) {
    // Whether to use auto_scheduler schedule.
    use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
    if (backend::IsMetaScheduleEnabled()) {
      database_ = meta_schedule::Database::Current();
      CHECK(database_.defined()) << "ValueError: `use_meta_schedule` is enabled in Relay "
                                    "build, but no `meta_schedule.Database` context is provided.";
    } else {
      database_ = NullOpt;
    }
  }

  CachedFunc Create(const Function& relay_func, GlobalVarSupply global_var_supply,
                    NameSupply constant_name_supply) {
    LowerToTECompute lower_te_compute(target_, constant_name_supply);
    Array<te::Tensor> tensor_outs = lower_te_compute.Lower(relay_func);
    Array<te::Tensor> fn_inputs = lower_te_compute.fn_inputs_;
    VisitExpr(relay_func->body);

    // TODO(mbs): This should be the definitive global by which the PrimFunc is known and
    // no other GlobalVar ctors should appear inside the lowering machinery.
    auto prim_fn_var = global_var_supply->FreshGlobal(lower_te_compute.candidate_name_);
    prim_fn_var->checked_type_ = relay_func->checked_type();

    // Fusion over tupled results may leave identity relationships between inputs and
    // outputs. Copy such identity output tensors, since TIR lowering does not support
    // aliasing an output to an input buffer.
    for (size_t i = 0; i < tensor_outs.size(); ++i) {
      if (tensor_outs[i]->op.as<te::PlaceholderOpNode>()) {
        tensor_outs.Set(i, topi::identity(tensor_outs[i]));
      }
    }

    te::Schedule schedule{nullptr};
    tir::PrimFunc prim_func{nullptr};
    // No need to register a schedule for the device copy op.
    if (anchor_attrs_.as<DeviceCopyAttrs>() == nullptr) {
      if (use_auto_scheduler_) {
        const auto* fauto_schedule =
            runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute");
        ICHECK(fauto_schedule != nullptr)
            << "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered";
        ObjectRef obj = (*fauto_schedule)(prim_fn_var->name_hint, tensor_outs);
        if (obj.defined()) {
          schedule = Downcast<te::Schedule>(obj);
        }
      }
      if (database_) {
        using tvm::meta_schedule::TuningRecord;
        using tvm::tir::IndexMap;
        using tvm::tir::Instruction;
        using tvm::tir::InstructionKind;
        using tvm::tir::PrimFunc;
        using tvm::tir::Schedule;
        backend::FTECompilerTIRConverter tir_converter = backend::GetTIRConverter();
        Array<te::Tensor> te_args = Concat(fn_inputs, tensor_outs);
        Array<runtime::NDArray> constants;
        for (auto [const_node, te_tensor] : lower_te_compute.constant_tensors_) {
          te_args.push_back(te_tensor);
          constants.push_back(const_node->data);
        }
        if (Optional<PrimFunc> f = tir_converter(te_args, constants)) {
          IRModule query_mod = backend::PrimFuncToIRModule(f.value());
          if (Optional<TuningRecord> opt_record = database_.value()->QueryTuningRecord(
                  /*mod=*/query_mod,
                  /*target=*/target_,
                  /*workload_name=*/prim_fn_var->name_hint)) {
            LayoutFreeConstantCollector const_collector;
            const_collector(f.value()->body);

            static InstructionKind kind_transform_layout = InstructionKind::Get("TransformLayout");
            TuningRecord record = opt_record.value();
            for (const Instruction& inst : record->trace->insts) {
              if (inst->kind.same_as(kind_transform_layout)) {
                ICHECK_EQ(inst->inputs.size(), 2);
                auto index_map = Downcast<IndexMap>(inst->inputs[1]);

                if (!const_collector.constants.empty()) {
                  // In this case, RewriteLayout is acting on an AllocateConst node.
                  // After tuning, we reach this code path twice: first by
                  // the Relay MetaScheduleLayoutRewrite pass, and next by the final
                  // compilation (Relay to TE schedule lowering).
                  //
                  // Due to the Relay MetaScheduleLayoutRewrite and FoldConstant passes,
                  // the Relay subgraph for which we query the database during the
                  // final compilation has its weight tensor transformed according to
                  // the index map, determined during tuning. For example,
                  //
                  //   fn (%p0: Tensor[(1, 56, 56, 64), float32]) {
                  //     %0 = nn.conv2d(%p0, meta[relay.Constant][0],
                  //                    /*ty=Tensor[(4, 2, 2, 3, 3, 32, 8), float32]*/, ...);
                  //     add(%0, meta[relay.Constant][1])
                  //   }
                  //
                  // Note that the database does not have an entry corresponding to such subgraphs,
                  // since an input subgraph to the tuning system always has its weight tensor in
                  // the original layout, e.g.
                  //
                  //   fn (%p0: Tensor[(1, 56, 56, 64), float32]) {
                  //     %0 = nn.conv2d(%p0, meta[relay.Constant][0],
                  //                    /*ty=Tensor[(3, 3, 64, 64), float32]*/, ...);
                  //     add(%0, meta[relay.Constant][1])
                  //   }
                  //
                  // Thus, in both of the two cases where we reach this code path, we need careful
                  // logic to make sure that (1) the database lookup during the final compilation
                  // succeeds and (2) the application of a schedule trace is well defined.

                  ICHECK(const_collector.constants.size() == 1)
                      << "Only one layout-free constant is supported by RewriteLayout for now";
                  auto constant = const_collector.constants[0];

                  // A constant is considered already transformed if its rank does not match
                  // the index map's inputs, or if mapping its shape through the index map
                  // changes the total number of elements.
                  auto is_constant_transformed = [index_map](runtime::NDArray c) {
                    if (c.Shape().size() != index_map->initial_indices.size()) {
                      return true;
                    }
                    size_t src_size_1d = 1;
                    Array<PrimExpr> orig_shape;
                    for (size_t i = 0; i < c.Shape().size(); ++i) {
                      src_size_1d *= c->shape[i];
                      orig_shape.push_back(PrimExpr(static_cast<int>(c->shape[i])));
                    }
                    auto dst_shape = index_map->MapShape(orig_shape);
                    size_t dst_size_1d = 1;
                    for (size_t i = 0; i < dst_shape.size(); ++i) {
                      dst_size_1d *= dst_shape[i].as<IntImmNode>()->value;
                    }
                    return src_size_1d != dst_size_1d;
                  };

                  if (!is_constant_transformed(constant)) {
                    // This is the first case, reached during the MetaScheduleLayoutRewrite pass.
                    //
                    // A layout-free constant having the same rank as an input to the index map
                    // is assumed to be transformed by this index map.
                    // TODO(masahi): If there are multiple layout-free constants in one
                    // TIR mod (e.g. conv2d -> conv2d fusion), this assumption does not hold.
                    // We need to determine which constant the given index map acts on.
                    //
                    // We know that, during the final compilation, we will query the database
                    // for a subgraph that the tuner has never seen. We work around this problem
                    // by adding a dummy entry to the database. The dummy entry is carefully
                    // constructed so that the lookup during the final compilation will succeed.
                    runtime::NDArray rewritten_constant = index_map->MapNDArray(constant);
                    auto f_dummy = AllocateConstReplaceConstant::Rewrite(
                        f.value(), {{constant, rewritten_constant}});
                    auto workload_dummy =
                        database_.value()->CommitWorkload(backend::PrimFuncToIRModule(f_dummy));
                    TuningRecord rec_dummy(record->trace, workload_dummy, record->run_secs,
                                           record->target, record->args_info);
                    database_.value()->CommitTuningRecord(rec_dummy);
                  } else {
                    // The constant is already transformed, so this is the second case, reached
                    // during the final compilation.
                    //
                    // The schedule trace is supposed to be applied to the weight in its original
                    // layout. But as explained above, the Relay subgraph we get in this case
                    // has its weight tensor transformed according to the corresponding index map.
                    // So effectively, we undo the layout transformation on the weight to restore
                    // the original PrimFunc that the schedule trace is supposed to act on.
                    ICHECK(index_map->inverse_index_map);
                    auto inverse_map = Downcast<IndexMap>(index_map->inverse_index_map.value());
                    ICHECK(constant.Shape().size() == inverse_map->initial_indices.size());
                    runtime::NDArray orig_constant = inverse_map->MapNDArray(constant);
                    auto f_ = AllocateConstReplaceConstant::Rewrite(f.value(),
                                                                    {{constant, orig_constant}});
                    query_mod = backend::PrimFuncToIRModule(f_);
                  }
                }
                MetaScheduleLayoutRewriter::LayoutQueuePush(index_map);
              }
            }

            Schedule sch = Schedule::Traced(query_mod, /*seed=*/-1, /*debug_mask=*/0,
                                            tir::ScheduleErrorRenderLevel::kDetail);

            if (!mod_eq_structural_->Equal(query_mod, opt_record.value()->workload->mod)) {
              // When the database lookup succeeds but the structural equality check fails,
              // anchor-block based equality must have been used during tuning, so the trace
              // in the record cannot be applied to this query module directly.
              meta_schedule::ScheduleUsingAnchorTrace(sch, record->trace, target_);
            } else {
              record->trace->ApplyToSchedule(sch, /*remove_postproc=*/false);
            }

            IRModule mod = sch->mod();
            ICHECK_EQ(mod->functions.size(), 1);
            mod = tir::transform::RemoveWeightLayoutRewriteBlock(/*skip_ndarray_rewrite*/ false)(
                std::move(mod));
            prim_func = Downcast<PrimFunc>(mod->Lookup("main"));
            // Need to copy attrs from the Relay function over to the prim func, most notably
            // the structural hash.
            prim_func = WithAttrs(prim_func, relay_func->attrs->dict);
          } else {
            int dispatch = backend::UseMetaScheduleDispatch();
            // (dispatch & 2): controls whether to print TVMScript for missing TIR
            // (dispatch & 4): controls whether to raise fatal errors for missing TIR
            if (dispatch & 2) {
              LOG(WARNING) << "Cannot find workload: " << prim_fn_var->name_hint << "\n"
                           << f.value();
            } else {
              LOG(WARNING) << "Cannot find workload: " << prim_fn_var->name_hint;
            }
            if (dispatch & 4) {
              LOG(FATAL);
            }
          }
        }
      }
      // Use the TOPI schedule if the user specified one, or if the function has no
      // auto_scheduler schedule.
      if (!schedule.defined() && !prim_func.defined()) {
        if (anchor_op_.defined()) {
          auto anchor_impl = lower_te_compute.op_implementations_.find(anchor_op_.operator->());
          ICHECK(anchor_impl != lower_te_compute.op_implementations_.end());
          schedule = anchor_impl->second.Schedule(anchor_attrs_, tensor_outs, target_);
        } else {
          auto default_sched = GenericFunc::Get("schedule_injective");
          ICHECK(default_sched.defined()) << "schedule_injective not registered for " << target_;
          With<Target> tctx(target_);
          schedule = default_sched(tensor_outs);
        }
      }
      if (schedule.defined()) {
        for (const auto& scalar : lower_te_compute.scalars_) {
          if (schedule->Contain(scalar)) {
            schedule[scalar].compute_inline();
          }
        }
      }
    }

    IRModule funcs = IRModule(Map<GlobalVar, BaseFunc>({}));
    return CachedFunc(target_, prim_fn_var, fn_inputs, tensor_outs, schedule, prim_func, {}, funcs,
                      lower_te_compute.constant_tensors_);
  }

  void VisitExpr_(const CallNode* call_node) final {
    static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");

    ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
    Op op = Downcast<Op>(call_node->op);

    for (Expr arg : call_node->args) {
      VisitExpr(arg);
    }

    int op_pattern = fpattern[op];
    if (!use_auto_scheduler_ && !database_.defined() && op_pattern >= kCommReduce) {
      ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce)
          << "Cannot apply TOPI schedule to a primitive function with two complicated ops"
          << " anchor=" << anchor_op_ << " current=" << op;
    }
    if (op_pattern >= anchor_op_pattern_) {
      anchor_op_ = op;
      anchor_attrs_ = call_node->attrs;
      anchor_op_pattern_ = op_pattern;
    }
  }

 private:
  tvm::Target target_;
  Op anchor_op_;
  Attrs anchor_attrs_;
  int anchor_op_pattern_{0};
  bool use_auto_scheduler_;
  Optional<meta_schedule::Database> database_;
  std::unique_ptr<meta_schedule::ModuleEquality> mod_eq_structural_;
};

/*!
 * \brief Create a schedule for the target.
 * \param source_func The primitive function to be lowered.
 * \param target The target we want to create a schedule for.
 * \param global_var_supply The GlobalVarSupply used to name the lowered PrimFunc.
 * \param constant_name_supply The NameSupply used to assign unique names to constants.
 * \return The CachedFunc for the lowered function.
 *         The funcs field in the cache is not yet populated.
 */
CachedFunc PrimFuncFor(const Function& source_func, const Target& target,
                       GlobalVarSupply global_var_supply, NameSupply constant_name_supply) {
  return ScheduleBuilder(target).Create(source_func, global_var_supply, constant_name_supply);
}

// Creates the shape function for a given primitive function.
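// For each parameter it prepares both a data placeholder and a shape placeholder;
// which of the two is actually needed (and hence passed at runtime) is recorded
// per parameter in shape_func_param_states as a bitmask of kNeedInputData and
// kNeedInputShape.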
class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
 public:
  MakeShapeFunc() {}

  CachedFunc Create(const Function& prim_func, const Target& target,
                    GlobalVarSupply global_var_supply) {
    VLOG_CONTEXT << "MakeShapeFunc";
    TShapeDataDependent shape_func_param_states;

    for (auto param : prim_func->params) {
      param_states_[param] = kNoNeed;
      Array<tvm::te::Tensor> data_inputs;
      Array<tvm::te::Tensor> shape_inputs;

      for (const auto& ttype : FlattenTupleType(param->checked_type())) {
        // Add a data placeholder (in case we discover we need it below)
        Shape shape = GetShape(ttype->shape);
        tvm::te::Tensor data_tensor =
            tvm::te::placeholder(shape, ttype->dtype, "data_" + param->vid->name_hint);
        data_inputs.push_back(data_tensor);
        // Add a shape placeholder (in case we discover we need it below)
        int64_t ndim = shape.size();
        Shape sshape;
        if (ndim > 0) {
          sshape.push_back(tvm::Integer(ndim));
        }
        tvm::te::Tensor shape_tensor =
            tvm::te::placeholder(sshape, DataType::Int(64), "shape_" + param->vid->name_hint);
        shape_inputs.push_back(shape_tensor);
      }
      param_data_[param] = data_inputs;
      param_shapes_[param] = shape_inputs;
    }

    // Set up the name.
    readable_name_stream_ << "shape_func";

    // Create the tensor expressions representing the output shapes.
    Array<te::Tensor> outputs = VisitExpr(prim_func->body);

    // Generate a name.
    auto candidate_name = readable_name_stream_.str();

    constexpr static size_t kMaxFuncNameLength = 80;
    // WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME
    // whenever the value of kMaxFuncNameLength changes
    if (candidate_name.size() > kMaxFuncNameLength) {
      std::stringstream truncated_name;
      truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
      truncated_name << "_" << std::hex << std::hash<std::string>{}(candidate_name) << "_";
      candidate_name = truncated_name.str();
    }

    // Set all the inputs correctly, and accumulate their types from the p.o.v. of the
    // shape function rather than the primitive it is derived for.
    Array<te::Tensor> inputs;
    Array<Type> shape_function_arg_types;
    for (auto param : prim_func->params) {
      int state = param_states_[param];
      shape_func_param_states.push_back(IntImm(DataType::Int(32), state));
      if (state & kNeedInputData) {
        // Pass the primitive arguments directly (though in flattened form and on the host)
        for (auto t : param_data_[param]) {
          inputs.push_back(t);
          shape_function_arg_types.push_back(TensorType(t->GetShape(), t->GetDataType()));
        }
      }
      if (state & kNeedInputShape) {
        // Pass the shapes of the primitive arguments (also on the host)
        for (auto t : param_shapes_[param]) {
          inputs.push_back(t);
          shape_function_arg_types.push_back(TensorType(t->GetShape(), t->GetDataType()));
        }
      }
    }

    // TODO(mbs): This should be the definitive global by which the PrimFunc is known and
    // no other GlobalVar ctors should appear inside the lowering machinery.
    auto prim_fn_gvar = global_var_supply->FreshGlobal(candidate_name);

    // Gather the result types, again from the p.o.v. of the shape function rather than
    // the primitive it is derived for.
    Array<Type> shape_function_res_types;
    for (const auto& t : outputs) {
      shape_function_res_types.push_back(TensorType(t->GetShape(), t->GetDataType()));
    }

    // Assign the shape function its true type.
    FuncType type(shape_function_arg_types, TupleType(shape_function_res_types),
                  /*type_params=*/{}, /*type_constraints=*/{});
    VLOG(1) << "shape function '" << prim_fn_gvar->name_hint << "' has type:" << std::endl
            << PrettyPrint(type) << std::endl
            << "corresponding to primitive of type:" << std::endl
            << PrettyPrint(prim_func->checked_type());
    prim_fn_gvar->checked_type_ = std::move(type);

    // Generate a schedule for the shape function.
    Array<te::Operation> out_ops;
    for (auto t : outputs) {
      out_ops.push_back(t->op);
    }
    te::Schedule schedule = te::create_schedule(out_ops);
    tvm::te::AutoInlineInjective(schedule);
    for (const auto& scalar : scalars_) {
      auto scalar_op = scalar->op;
      if (schedule->Contain(scalar_op)) {
        schedule[scalar_op].compute_inline();
      }
    }

    Array<te::Tensor> all_args = Array<te::Tensor>(inputs);
    for (te::Tensor arg : outputs) {
      all_args.push_back(arg);
    }

    using tvm::transform::PassContext;
    With<PassContext> fresh_pass_ctx_scope(PassContext::Create());

    std::unordered_map<te::Tensor, tir::Buffer> binds;
    IRModule lowered_module =
        tvm::LowerSchedule(schedule, all_args, prim_fn_gvar->name_hint, binds, global_var_supply);
    return CachedFunc(target, prim_fn_gvar, inputs, outputs, schedule, tir::PrimFunc{nullptr},
                      shape_func_param_states, lowered_module);
  }

  Array<te::Tensor> VisitExpr(const Expr& expr) final {
    if (expr.as<VarNode>()) {
      // Do not memoize vars because shape functions could use either the data
      // or the shape of a var each time.
      return ExprFunctor::VisitExpr(expr);
    }
    // For all other cases, do a memoized visit.
    return backend::MemoizedExprTranslator<Array<te::Tensor>>::VisitExpr(expr);
  }

  Array<te::Tensor> VisitExpr_(const VarNode* var_node) final {
    auto var = GetRef<Var>(var_node);
    auto it = param_arg_map_.find(var);
    if (it != param_arg_map_.end()) {
      // This var is a parameter of a nested function. Visit the corresponding argument at the
      // function call site.
      return VisitExpr(it->second);
    }
    if (param_states_.find(var) == param_states_.end()) {
      LOG(FATAL) << "Unexpected free variable " << PrettyPrint(var);
    } else {
      ICHECK(data_dependents_per_input_.size());
      auto data_dependent = data_dependents_per_input_.back();
      if (data_dependent) {
        param_states_[var] |= kNeedInputData;
        return param_data_[var];
      } else {
        param_states_[var] |= kNeedInputShape;
        return param_shapes_[var];
      }
    }
  }

  Array<te::Tensor> VisitExpr_(const ConstantNode* op) final {
    using tir::make_const;
    ICHECK(data_dependents_per_input_.size());
    bool data_dependent = data_dependents_per_input_.back();
    if (!op->is_scalar()) {
      // This is a constant weight, extract the shape of the weight tensor.
      // This cannot be data dependent.
      CHECK(!data_dependent);
      auto ttype = op->checked_type().as<TensorTypeNode>();
      int ndim = static_cast<int>(ttype->shape.size());
      Array<PrimExpr> out_shape{ndim};
      te::Tensor value = tvm::te::compute(
          out_shape,
          [&](const Array<tvm::tir::Var>& indices) {
            auto idx = indices[0];
            PrimExpr ret = make_const(DataType::Int(64), 0);
            for (int i = 0; i < ndim; i++) {
              ret = tvm::if_then_else(idx == i, ttype->shape[i], ret);
            }
            return ret;
          },
          "shape_const", topi::kBroadcast);
      scalars_.push_back(value);
      return {value};
    }
    if (data_dependent) {
      void* data = op->data->data;
      DataType dtype = DataType(op->data->dtype);
      auto value = tvm::te::compute(
          {},
          [&](const Array<tvm::tir::Var>&) {
            if (dtype == DataType::Int(32)) {
              return make_const(dtype, static_cast<const int32_t*>(data)[0]);
            } else if (dtype == DataType::Int(64)) {
              return make_const(dtype, static_cast<const int64_t*>(data)[0]);
            } else if (dtype == DataType::Float(32)) {
              return make_const(dtype, static_cast<const float*>(data)[0]);
            } else if (dtype == DataType::Float(64)) {
              return make_const(dtype, static_cast<const double*>(data)[0]);
            } else if (dtype == DataType::Bool()) {
              return make_const(dtype, static_cast<const uint8_t*>(data)[0]);
            } else {
              LOG(FATAL) << dtype << " not handled";
            }
          },
          "data_const", topi::kBroadcast);
      scalars_.push_back(value);
      return {value};
    } else {
      auto value = tvm::te::compute(
          {}, [&](const Array<tvm::tir::Var>&) { return tir::make_const(DataType::Int(64), 0); },
          "shape_const", topi::kBroadcast);
      scalars_.push_back(value);
      return {value};
    }
  }

  Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
    VLOG(1) << "considering call:" << std::endl << PrettyPrint(GetRef<Call>(call_node));
    if (auto* func = call_node->op.as<FunctionNode>()) {
      VLOG(1) << "user function";
      for (size_t i = 0; i < func->params.size(); ++i) {
        param_arg_map_[func->params[i]] = call_node->args[i];
      }
      return VisitExpr(func->body);
    }

    static auto fshape_func = Op::GetAttrMap<FShapeFunc>("FShapeFunc");
    static auto tshape_data_dependent = Op::GetAttrMap<TShapeDataDependent>("TShapeDataDependent");
    ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
    Op op = Downcast<Op>(call_node->op);
    ICHECK(data_dependents_per_input_.empty() || !data_dependents_per_input_.back())
        << "Error in op fusion: output of the shape func is fed to a "
        << "data-dependent shape func";
    ICHECK_GT(fshape_func.count(op), 0) << "Internal error, cannot find ShapeFunc for " << op->name;
    ICHECK_GT(tshape_data_dependent.count(op), 0)
        << "Internal error, cannot find TShapeDataDependent for " << op->name;

    Array<Integer> dep_spec = tshape_data_dependent[op];
    if (dep_spec.size() == 1) {
      // This is for cases when data dependence is specified per op.
      // Replicate the 0 or 1 flag to all arguments.
      for (size_t i = 1; i < call_node->args.size(); ++i) {
        dep_spec.push_back(dep_spec[0]);
      }
    }

    // Visit all inputs
    Array<te::Tensor> inputs;
    int count_tuple = 0;
    for (size_t i = 0; i < call_node->args.size(); ++i) {
      Expr arg = call_node->args[i];
      if (arg->checked_type().as<TupleTypeNode>()) {
        ++count_tuple;
      }
      data_dependents_per_input_.push_back(dep_spec[i]->value != 0);
      for (te::Tensor tensor : VisitExpr(arg)) {
        inputs.push_back(tensor);
      }
      data_dependents_per_input_.pop_back();
    }
    if (count_tuple) {
      ICHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input";
    }
    // Get output ndims
    auto ret_type = call_node->checked_type();
    Array<IndexExpr> out_ndims;
    for (const auto& ttype : FlattenTupleType(ret_type)) {
      out_ndims.push_back(IntImm(DataType::Int(32), ttype->shape.size()));
    }

    // Call the shape function
    Array<te::Tensor> outputs = fshape_func[op](call_node->attrs, inputs, out_ndims);
    VLOG(1) << "shape function for '" << op->name << "' with inputs:" << std::endl
            << inputs << std::endl
            << "yielded outputs:" << std::endl
            << outputs;
    readable_name_stream_ << "_" << op->name;
    return outputs;
  }

  Array<te::Tensor> VisitExpr_(const FunctionNode* op) final {
    LOG(FATAL) << "Nested functions are not allowed to be visited.";
  }

  Array<te::Tensor> VisitExpr_(const LetNode* op) final {
    Array<te::Tensor> val = VisitExpr(op->value);
    ICHECK(!memo_.count(op->var));
    memo_[op->var] = val;
    return VisitExpr(op->body);
  }

  Array<te::Tensor> VisitExpr_(const TupleNode* op) final {
    Array<te::Tensor> fields;
    for (Expr field : op->fields) {
      ICHECK(field->checked_type().as<TensorTypeNode>())
          << "Expected a Tuple of Tensor, but got " << PrettyPrint(field->checked_type());
      Array<te::Tensor> res = VisitExpr(field);
      ICHECK_EQ(res.size(), 1);
      fields.push_back(res[0]);
    }
    return fields;
  }

  Array<te::Tensor> VisitExpr_(const TupleGetItemNode* op) final {
    Array<te::Tensor> input_shapes = VisitExpr(op->tuple);
    Array<te::Tensor> out;
    out.push_back(input_shapes[op->index]);
    return out;
  }

 private:
  /*! \brief String stream for function name */
  std::ostringstream readable_name_stream_;
  /*! \brief Map from parameter to its shape function usage state */
  std::unordered_map<Expr, int, ObjectPtrHash, ObjectPtrEqual> param_states_;
  /*! \brief Map from parameter to list of data placeholders */
  std::unordered_map<Expr, Array<te::Tensor>, ObjectPtrHash, ObjectPtrEqual> param_data_;
  /*! \brief Map from parameter to list of shape placeholders */
  std::unordered_map<Expr, Array<te::Tensor>, ObjectPtrHash, ObjectPtrEqual> param_shapes_;
  /*! \brief Stack of data dependencies for the shape function, specified per op input */
  std::vector<bool> data_dependents_per_input_;
  /*! \brief Scalars used in the shape function */
  Array<te::Tensor> scalars_;
  /*! \brief Map from parameters of a nested function to corresponding arguments in a function
   * call site.
   */
  std::unordered_map<Var, Expr, ObjectPtrHash, ObjectPtrEqual> param_arg_map_;
};

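/*!
 * \brief Create a shape function for the given primitive function.
 * \param prim_func The primitive function whose output shapes are to be computed.
 * \param target The target for the shape function.
 * \param global_var_supply The GlobalVarSupply used to name the shape function.
 * \return The CachedFunc holding the lowered shape function.
 */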
CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target,
                        GlobalVarSupply global_var_supply) {
  return MakeShapeFunc().Create(prim_func, target, global_var_supply);
}

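/*!
 * \brief Lower a Relay primitive function to a TE compute and collect its constants.
 * \param source_func The primitive function to be lowered.
 * \param target The compilation target.
 * \param constant_name_supply The NameSupply used to assign unique names to constants.
 * \param return_inputs Whether to prepend the input placeholders to the returned tensors.
 * \return A tuple of (TE tensors, constant data, fused function name).
 */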
std::tuple<Array<te::Tensor>, Array<runtime::NDArray>, std::string> LowerTECompute(
    const Function& source_func, Target target, NameSupply constant_name_supply,
    bool return_inputs) {
  LowerToTECompute lower_te_compute(target, constant_name_supply);
  Array<te::Tensor> outputs = lower_te_compute.Lower(source_func);
  // Following ScheduleBuilder, remove placeholder ops from the outputs.
  tvm::Array<te::Tensor> tensor_outs;
  for (const auto& tensor : outputs) {
    if (!tensor->op.as<te::PlaceholderOpNode>()) {
      tensor_outs.push_back(tensor);
    }
  }

  tvm::Array<runtime::NDArray> constants;
  for (auto [const_node, te_tensor] : lower_te_compute.constant_tensors_) {
    tensor_outs.push_back(te_tensor);
    constants.push_back(const_node->data);
  }

  if (return_inputs) {
    return std::make_tuple(Concat(lower_te_compute.fn_inputs_, tensor_outs), constants,
                           lower_te_compute.candidate_name_);
  }
  return std::make_tuple(tensor_outs, constants, lower_te_compute.candidate_name_);
}

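/*!
 * \brief Lower a Relay primitive function all the way to a TIR PrimFunc.
 * \param relay_func The Relay primitive function (must have the kPrimitive attribute).
 * \param target The compilation target.
 * \param constant_name_supply The NameSupply used to assign unique names to constants.
 * \return A pair of the (possibly undefined) PrimFunc and the fused function name.
 */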
std::pair<Optional<tir::PrimFunc>, std::string> LowerToPrimFunc(const Function& relay_func,
                                                                Target target,
                                                                NameSupply constant_name_supply) {
  ICHECK(relay_func->HasNonzeroAttr(attr::kPrimitive))
      << "The input must be a Relay primitive function.";

  auto [inputs_outputs, constants, fused_name] =
      tec::LowerTECompute(relay_func, target, constant_name_supply, /*return_inputs=*/true);
  auto tir_converter = backend::GetTIRConverter();
  return std::make_pair(tir_converter(inputs_outputs, constants), fused_name);
}

tir::PrimFunc LowerToPrimFunc(const Function& relay_func, Target target) {
  auto [f_opt, _] = LowerToPrimFunc(relay_func, target, NameSupply(""));
  (void)_;  // to suppress -Werror=unused-variable warnings
  if (f_opt) {
    return f_opt.value();
  }
  LOG(FATAL) << "Failed to convert the Relay function: " << AsText(relay_func, false);
  return PrimFunc();
}

TVM_REGISTER_GLOBAL("relay.backend.LowerToPrimFunc")
    .set_body_typed([](Function relay_func, Target target) {
      return LowerToPrimFunc(relay_func, target);
    });

TVM_REGISTER_GLOBAL("relay.backend.LowerToTE").set_body_typed([](Function prim_func) {
  auto tgt = tvm::Target("ext_dev");
  LowerToTECompute lower_te_compute(tgt, NameSupply(""));
  auto outputs = lower_te_compute.Lower(prim_func);
  return CachedFunc(tgt, GlobalVar(lower_te_compute.candidate_name_), lower_te_compute.fn_inputs_,
                    outputs, te::Schedule(), tir::PrimFunc(), {},
                    IRModule(Map<GlobalVar, BaseFunc>({})), lower_te_compute.constant_tensors_);
});

}  // namespace tec
}  // namespace relay
}  // namespace tvm