/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

#include "./te_compiler_cache.h"

#include <tvm/driver/driver_api.h>
#include <tvm/ir/name_supply.h>
#include <tvm/ir/type_functor.h>
#include <tvm/meta_schedule/database.h>
#include <tvm/relay/analysis.h>
#include <tvm/relay/attrs/device_copy.h>
#include <tvm/relay/expr.h>
#include <tvm/relay/expr_functor.h>
#include <tvm/relay/op.h>
#include <tvm/relay/op_attr_types.h>
#include <tvm/relay/op_strategy.h>
#include <tvm/runtime/builtin_fp16.h>
#include <tvm/runtime/device_api.h>
#include <tvm/runtime/registry.h>
#include <tvm/te/operation.h>
#include <tvm/te/schedule.h>
#include <tvm/te/schedule_pass.h>
#include <tvm/tir/function.h>
#include <tvm/tir/index_map.h>
#include <tvm/tir/schedule/schedule.h>
#include <tvm/tir/stmt_functor.h>
#include <tvm/tir/transform.h>
#include <tvm/topi/tags.h>

#include <algorithm>
#include <deque>
#include <functional>
#include <limits>
#include <memory>
#include <mutex>
#include <sstream>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <utility>
#include <vector>

#include "../../te/operation/create_primfunc.h"
#include "../op/memory/memory.h"
#include "../src/meta_schedule/module_equality.h"
#include "../src/meta_schedule/trace_apply.h"
#include "../transforms/meta_schedule_layout_rewrite.h"
#include "utils.h"

namespace tvm {
namespace relay {
namespace tec {

TVM_REGISTER_NODE_TYPE(LoweredOutputNode);
TVM_REGISTER_NODE_TYPE(CachedFuncNode);
TVM_REGISTER_NODE_TYPE(CCacheKeyNode);
TVM_REGISTER_NODE_TYPE(CCacheValueNode);

LoweredOutput::LoweredOutput(tvm::Array<te::Tensor> outputs, OpImplementation impl) {
  auto n = make_object<LoweredOutputNode>();
  n->outputs = std::move(outputs);
  n->implementation = std::move(impl);
  data_ = std::move(n);
}

CCacheKey::CCacheKey(Function source_func, Target target, VirtualDevice vd) {
  auto n = make_object<CCacheKeyNode>();
  n->source_func = std::move(source_func);
  n->target = std::move(target);
  n->virtual_device = std::move(vd);
  data_ = std::move(n);
}

CachedFunc::CachedFunc(tvm::Target target, GlobalVar prim_fn_var, tvm::Array<te::Tensor> inputs,
                       tvm::Array<te::Tensor> outputs, te::Schedule schedule,
                       tir::PrimFunc prim_func, tvm::Array<Integer> shape_func_param_states,
                       IRModule funcs,
                       std::unordered_map<const ConstantNode*, te::Tensor> constant_tensors) {
  auto n = make_object<CachedFuncNode>();
  n->target = target;
  n->prim_fn_var = prim_fn_var;
  n->inputs = inputs;
  n->outputs = outputs;
  n->schedule = schedule;
  n->prim_func = prim_func;
  n->shape_func_param_states = shape_func_param_states;
  n->funcs = funcs;
  n->constant_tensors = constant_tensors;
  data_ = std::move(n);
}
Array<IndexExpr> GetShape(const Array<IndexExpr>& shape) {
  // For now, we always use int32 shapes when possible,
  // even if the result of shape inference becomes int64.
  Array<IndexExpr> res;
  for (IndexExpr val : shape) {
    const int64_t* pval = tir::as_const_int(val);
    if (pval != nullptr) {
#ifndef TVM_INDEX_DEFAULT_I64
      ICHECK_LE(pval[0], std::numeric_limits<int32_t>::max())
          << "dimension must not exceed int32_t's max value";
      ICHECK_GE(pval[0], std::numeric_limits<int32_t>::min())
          << "dimension must not be less than int32_t's min value";
      res.push_back(IntImm(DataType::Int(32), *pval));
#else
      res.push_back(val);
#endif  // TVM_INDEX_DEFAULT_I64
    } else if (val->IsInstance<tir::AnyNode>()) {
      // Currently, all 'any' dimensions we meet in shape functions are non-negative.
      res.push_back(val.as<tir::AnyNode>()->ToSizeVar());
    } else {
      res.push_back(val);
    }
  }
  return res;
}
// Helper class used during lowering to TE.
// It matches sequences of ops and lowers them into a single TOPI operation. All supported
// patterns are enumerated in "supported_patterns_".
class QnnPatternMatcher {
 public:
  QnnPatternMatcher()
      : qnn_conv2d_op_(Op::Get("qnn.conv2d")),
        qnn_dense_op_(Op::Get("qnn.dense")),
        qnn_dense_pack_op_(Op::Get("qnn.contrib_dense_pack")),
        qnn_requantize_op_(Op::Get("qnn.requantize")),
        bias_add_op_(Op::Get("add")) {}

  // Memoize visited operations.
  void Register(const CallNode* call_node) {
    ICHECK(call_node->op.as<OpNode>());
    Op op = Downcast<Op>(call_node->op);
    if (op == qnn_conv2d_op_) {
      registered_ops_.push_front(P_QConv2d);
      ICHECK(anchor_op_ == nullptr);
      anchor_op_ = call_node;
    } else if (op == qnn_requantize_op_) {
      registered_ops_.push_front(P_QRequantize);
    } else if (op == bias_add_op_) {
      registered_ops_.push_front(P_BiasAdd);
    } else if (op == qnn_dense_op_) {
      registered_ops_.push_front(P_QDense);
      ICHECK(anchor_op_ == nullptr);
      anchor_op_ = call_node;
    } else if (op == qnn_dense_pack_op_) {
      registered_ops_.push_front(P_QDensePack);
      ICHECK(anchor_op_ == nullptr);
      anchor_op_ = call_node;
    } else {
      registered_ops_.push_front(P_Opaque);
    }
  }

  // Checks whether the given op is part of a matched pattern.
  bool find(const Op& op) {
    if (registered_ops_.empty()) return false;

    if (op == qnn_conv2d_op_ || op == qnn_requantize_op_ || op == bias_add_op_ ||
        op == qnn_dense_op_ || op == qnn_dense_pack_op_) {
      for (const auto& pat : supported_patterns_) {
        auto it =
            std::search(registered_ops_.begin(), registered_ops_.end(), pat.begin(), pat.end());
        if (it != registered_ops_.end()) return true;
      }
    }
    return false;
  }

  // Returns whether the given op is the last one in the pattern sequence.
  bool IsLeafOp(const Op& op) { return op == qnn_requantize_op_; }

  const CallNode* GetAnchorOp() { return anchor_op_; }

  void Clear() { registered_ops_.clear(); }

 private:
  const Op& qnn_conv2d_op_;
  const Op& qnn_dense_op_;
  const Op& qnn_dense_pack_op_;
  const Op& qnn_requantize_op_;
  const Op& bias_add_op_;

  // The main (most complex) operation in the primitive (for example qnn.conv2d or qnn.dense).
  const CallNode* anchor_op_ = nullptr;

  enum POper { P_QConv2d, P_QDense, P_QDensePack, P_BiasAdd, P_QRequantize, P_Opaque };

  std::deque<POper> registered_ops_;

  const std::vector<std::deque<POper>> supported_patterns_ = {
      {P_QDense, P_BiasAdd, P_QRequantize},      // qnn.dense -> bias_add -> qnn.requantize
      {P_QDense, P_QRequantize},                 // qnn.dense -> qnn.requantize
      {P_QDensePack, P_BiasAdd, P_QRequantize},  // qnn.contrib_dense_pack -> bias -> qnn.requantize
      {P_QDensePack, P_QRequantize},             // qnn.contrib_dense_pack -> qnn.requantize
      {P_QConv2d, P_BiasAdd, P_QRequantize},     // qnn.conv2d -> bias_add -> qnn.requantize
      {P_QConv2d, P_QRequantize}                 // qnn.conv2d -> qnn.requantize
  };
};

// Lowers a Relay primitive Function to TE compute.
class LowerToTECompute : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
 public:
  LowerToTECompute(Target target, NameSupply constants_name_supply)
      : target_(target),
        device_copy_op_(Op::Get("device_copy")),
        constants_name_supply_(constants_name_supply) {}

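  // Lowers the body of a Relay primitive function to TE tensors: placeholders are created
  // for every (flattened) parameter, the body is visited to build the compute graph, and a
  // readable candidate function name is recorded (hashed and truncated if it grows too long).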
  Array<te::Tensor> Lower(const Function& relay_func) {
    for (Var param : relay_func->params) {
      Array<tvm::te::Tensor> inputs;
      for (const auto& ttype : FlattenTupleType(param->checked_type())) {
        auto name_hint = param->vid->name_hint;
        tvm::te::Tensor tensor = tvm::te::placeholder(
            GetShape(ttype->shape), ttype->dtype, (name_hint == "") ? "placeholder" : name_hint);
        inputs.push_back(tensor);
        fn_inputs_.push_back(tensor);
      }
      memo_[param] = inputs;
    }
    readable_name_stream_ << "fused";

    Array<te::Tensor> outputs = this->VisitExpr(relay_func->body);

    candidate_name_ = readable_name_stream_.str();

    constexpr static size_t kMaxFuncNameLength = 80;
    // WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME
    // whenever the value of kMaxFuncNameLength changes.
    if (candidate_name_.size() > kMaxFuncNameLength) {
      std::stringstream truncated_name;
      truncated_name << candidate_name_.substr(0, kMaxFuncNameLength);
      truncated_name << "_" << std::hex << std::hash<std::string>{}(candidate_name_) << "_";
      candidate_name_ = truncated_name.str();
    }

    return outputs;
  }

  Array<te::Tensor> VisitExpr_(const VarNode* op) final {
    LOG(FATAL) << "Unexpected free variable " << PrettyPrint(GetRef<Var>(op));
  }

  Array<te::Tensor> VisitExpr_(const ConstantNode* op) final {
    using tir::make_const;
    void* data = op->data->data;
    DataType dtype = DataType(op->data->dtype);
    if (op->is_scalar()) {
      auto value = te::compute(
          {},
          [&](const Array<tvm::tir::Var>&) {
            if (dtype == DataType::Int(16)) {
              return make_const(dtype, static_cast<const int16_t*>(data)[0]);
            } else if (dtype == DataType::Int(8)) {
              return make_const(dtype, static_cast<const int8_t*>(data)[0]);
            } else if (dtype == DataType::UInt(8) || dtype == DataType::Bool()) {
              return make_const(dtype, static_cast<const uint8_t*>(data)[0]);
            } else if (dtype == DataType::Int(32)) {
              return make_const(dtype, static_cast<const int32_t*>(data)[0]);
            } else if (dtype == DataType::Int(64)) {
              return make_const(dtype, static_cast<const int64_t*>(data)[0]);
            } else if (dtype == DataType::Float(16)) {
              return make_const(dtype, __gnu_h2f_ieee(static_cast<const uint16_t*>(data)[0]));
            } else if (dtype == DataType::Float(32)) {
              return make_const(dtype, static_cast<const float*>(data)[0]);
            } else if (dtype == DataType::Float(64)) {
              return make_const(dtype, static_cast<const double*>(data)[0]);
            } else {
              LOG(FATAL) << dtype << " not handled";
            }
          },
          "compile_engine_const", topi::kBroadcast);
      scalars_.push_back(value->op);
      return {value};
    } else {
      const auto* ttype = op->checked_type().as<TensorTypeNode>();
      std::stringstream ss;
      std::string s = readable_name_stream_.str();
      std::replace(s.begin(), s.end(), '.', '_');
      ss << constants_name_supply_->FreshName(s + "_constant");
      tvm::te::Tensor tensor = tvm::te::placeholder(GetShape(ttype->shape), ttype->dtype, ss.str());
      constant_tensors_[op] = tensor;
      return {tensor};
    }
  }

  Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
    static auto flower_call = tvm::runtime::Registry::Get("relay.backend.lower_call");
    ICHECK(flower_call) << "relay.backend.lower_call is not registered.";

    pattern_matcher_.Register(call_node);

    Array<te::Tensor> inputs;
    // int count_tuple = 0;
    for (Expr arg : call_node->args) {
      if (arg->checked_type().as<TupleTypeNode>()) {
        // ++count_tuple;
      }
      for (te::Tensor tensor : VisitExpr(arg)) {
        inputs.push_back(tensor);
      }
    }

    ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
    Op op = Downcast<Op>(call_node->op);

    // TODO(mbs): device_copy cleanup
    ICHECK_NE(op, device_copy_op_) << "device_copy cannot be lowered";

    Array<te::Tensor> outputs;

    if (pattern_matcher_.find(op)) {
      if (pattern_matcher_.IsLeafOp(op)) {
        // Lower the anchor op once the leaf op of the pattern has been reached.
        auto anchor_op = pattern_matcher_.GetAnchorOp();
        LoweredOutput lowered_out =
            (*flower_call)(GetRef<Call>(anchor_op), inputs, target_, call_node->checked_type());
        outputs = lowered_out->outputs;
        Op a_op = Downcast<Op>(anchor_op->op);
        op_implementations_[a_op.operator->()] = lowered_out->implementation;

        pattern_matcher_.Clear();
      } else {
        // Forward inputs as "outputs" for the successor op.
        readable_name_stream_ << '_' << op->name;
        return inputs;
      }
    } else {
      LoweredOutput lowered_out = (*flower_call)(GetRef<Call>(call_node), inputs, target_);
      outputs = lowered_out->outputs;
      op_implementations_[op.operator->()] = lowered_out->implementation;
    }

    if (outputs.size() != 1) {
      const auto* tuple_type = call_node->checked_type().as<TupleTypeNode>();
      ICHECK(tuple_type) << "Expected output to be a tuple type "
                         << PrettyPrint(call_node->checked_type());

      ICHECK_EQ(tuple_type->fields.size(), outputs.size());
    }

    readable_name_stream_ << '_' << op->name;
    return outputs;
  }

  Array<te::Tensor> VisitExpr_(const FunctionNode* op) final {
    LOG(FATAL) << "Primitive functions cannot contain nested functions.";
  }

  Array<te::Tensor> VisitExpr_(const LetNode* op) final {
    Array<te::Tensor> val = VisitExpr(op->value);
    ICHECK(!memo_.count(op->var));
    memo_[op->var] = val;
    return VisitExpr(op->body);
  }

  Array<te::Tensor> VisitExpr_(const TupleNode* op) final {
    Array<te::Tensor> fields;
    for (Expr field : op->fields) {
      // TODO(mbs): Generalize to be equivalent to FlattenTupleType.
      ICHECK(field->checked_type().as<TensorTypeNode>()) << "Only allow Tuple of Tensor";
      Array<te::Tensor> res = VisitExpr(field);
      ICHECK_EQ(res.size(), 1);
      fields.push_back(res[0]);
    }
    return fields;
  }

  Array<te::Tensor> VisitExpr_(const TupleGetItemNode* op) final {
    const auto* tuple_type = op->tuple->type_as<TupleTypeNode>();
    Array<te::Tensor> tuple = VisitExpr(op->tuple);
    ICHECK_EQ(tuple_type->fields.size(), tuple.size());
    ICHECK_GE(op->index, 0);
    ICHECK_LT(static_cast<size_t>(op->index), tuple.size());
    return {tuple[op->index]};
  }

 public:
  // Additional outputs of the lowering, populated by Lower().
  Array<tvm::te::Tensor> fn_inputs_;
  Array<te::Operation> scalars_;
  std::unordered_map<const ConstantNode*, te::Tensor> constant_tensors_;
  std::unordered_map<const OpNode*, OpImplementation> op_implementations_;
  std::string candidate_name_;

 private:
  QnnPatternMatcher pattern_matcher_;

  tvm::Target target_;
  std::ostringstream readable_name_stream_;
  // Cache the device copy op for equivalence checking, to reduce registry lookup
  // overhead for each invocation of a call node when retrieving schedules.
  const Op& device_copy_op_;
  // A NameSupply object passed in by the caller, used to assign unique names to constants
  // across different invocations of LowerToTECompute.
  NameSupply constants_name_supply_;
};

using namespace tvm::tir;

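// Collects the NDArray constants of AllocateConst nodes whose buffers appear in a block's
// "layout_free_placeholders" annotation. These are the constants that MetaSchedule's
// layout rewriting may transform.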
class LayoutFreeConstantCollector : public StmtVisitor {
 public:
  Array<runtime::NDArray> constants;

 private:
  void VisitStmt_(const BlockNode* op) final {
    StmtVisitor::VisitStmt_(op);
    if (Optional<ObjectRef> ann = op->annotations.Get("layout_free_placeholders")) {
      for (Buffer buffer : Downcast<Array<Buffer>>(ann)) {
        layout_free_buffer_vars_.insert(buffer->data.get());
      }
    }
  }

  void VisitStmt_(const AllocateConstNode* op) final {
    StmtVisitor::VisitStmt_(op);
    if (auto it = layout_free_buffer_vars_.find(op->buffer_var.get());
        it != layout_free_buffer_vars_.end()) {
      constants.push_back(op->data.value());
    }
  }

  std::unordered_set<const tir::VarNode*> layout_free_buffer_vars_;
};

using NDArrayMap =
    std::unordered_map<runtime::NDArray, runtime::NDArray, ObjectPtrHash, ObjectPtrEqual>;

// Replaces constants in AllocateConst nodes according to the given mapping.
class AllocateConstReplaceConstant : public StmtExprMutator {
 public:
  explicit AllocateConstReplaceConstant(const NDArrayMap& constant_map)
      : constant_map_(constant_map) {}

  static PrimFunc Rewrite(PrimFunc f, const NDArrayMap& constant_map) {
    AllocateConstReplaceConstant rewriter(constant_map);
    PrimFuncNode* n = f.CopyOnWrite();
    n->body = rewriter(std::move(n->body));
    return f;
  }

 private:
  Stmt VisitStmt_(const AllocateConstNode* op) final {
    if (auto it = constant_map_.find(op->data.value()); it != constant_map_.end()) {
      auto rewritten_constant = it->second;
      Array<PrimExpr> rewritten_extents;
      for (auto s : rewritten_constant.Shape()) {
        rewritten_extents.push_back(PrimExpr(static_cast<int>(s)));
      }
      return AllocateConst(op->buffer_var, op->dtype, rewritten_extents, rewritten_constant,
                           op->body, op->annotations, op->span);
    }
    return StmtExprMutator::VisitStmt_(op);
  }

  NDArrayMap constant_map_;
};

// Construct a schedule for a given Relay primitive function and target.
class ScheduleBuilder : public ExprVisitor {
 public:
  explicit ScheduleBuilder(Target target)
      : target_(target),
        mod_eq_structural_(meta_schedule::ModuleEquality::Create("ignore-ndarray")) {
    // Whether to use auto_scheduler schedule.
    use_auto_scheduler_ = backend::IsAutoSchedulerEnabled();
    if (backend::IsMetaScheduleEnabled()) {
      database_ = meta_schedule::Database::Current();
      CHECK(database_.defined()) << "ValueError: `use_meta_schedule` is enabled in Relay "
                                    "build, but no `meta_schedule.Database` context is provided. ";
    } else {
      database_ = NullOpt;
    }
  }

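  // Lowers a Relay primitive function to TE and attaches either a TE schedule
  // (auto_scheduler or TOPI) or a scheduled TIR PrimFunc (on a MetaSchedule database hit)
  // to the resulting CachedFunc.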
  CachedFunc Create(const Function& relay_func, GlobalVarSupply global_var_supply,
                    NameSupply constant_name_supply) {
    LowerToTECompute lower_te_compute(target_, constant_name_supply);
    Array<te::Tensor> tensor_outs = lower_te_compute.Lower(relay_func);
    Array<te::Tensor> fn_inputs = lower_te_compute.fn_inputs_;
    VisitExpr(relay_func->body);

    // TODO(mbs): This should be the definitive global by which the PrimFunc is known and
    // no other GlobalVar ctors should appear inside the lowering machinery.
    auto prim_fn_var = global_var_supply->FreshGlobal(lower_te_compute.candidate_name_);
    prim_fn_var->checked_type_ = relay_func->checked_type();

    // Fusion over tupled results may leave identity relationships between inputs and
    // outputs; copy such identity output tensors, since TIR lowering does not support
    // aliasing an output to an input buffer.
    for (size_t i = 0; i < tensor_outs.size(); ++i) {
      if (tensor_outs[i]->op.as<te::PlaceholderOpNode>()) {
        tensor_outs.Set(i, topi::identity(tensor_outs[i]));
      }
    }

    te::Schedule schedule{nullptr};
    tir::PrimFunc prim_func{nullptr};
    // No need to register a schedule for the device copy op.
    if (anchor_attrs_.as<DeviceCopyAttrs>() == nullptr) {
      if (use_auto_scheduler_) {
        const auto* fauto_schedule =
            runtime::Registry::Get("auto_scheduler.relay_integration.auto_schedule_topi_compute");
        ICHECK(fauto_schedule != nullptr)
            << "auto_scheduler.relay_integration.auto_schedule_topi_compute is not registered";
        ObjectRef obj = (*fauto_schedule)(prim_fn_var->name_hint, tensor_outs);
        if (obj.defined()) {
          schedule = Downcast<te::Schedule>(obj);
        }
      }
      if (database_) {
        using tvm::meta_schedule::TuningRecord;
        using tvm::tir::IndexMap;
        using tvm::tir::Instruction;
        using tvm::tir::InstructionKind;
        using tvm::tir::PrimFunc;
        using tvm::tir::Schedule;
        backend::FTECompilerTIRConverter tir_converter = backend::GetTIRConverter();
        Array<te::Tensor> te_args = Concat(fn_inputs, tensor_outs);
        Array<runtime::NDArray> constants;
        for (auto [const_node, te_tensor] : lower_te_compute.constant_tensors_) {
          te_args.push_back(te_tensor);
          constants.push_back(const_node->data);
        }
        if (Optional<PrimFunc> f = tir_converter(te_args, constants)) {
          IRModule query_mod = backend::PrimFuncToIRModule(f.value());
          if (Optional<TuningRecord> opt_record = database_.value()->QueryTuningRecord(
                  /*mod=*/query_mod,
                  /*target=*/target_,
                  /*workload_name=*/prim_fn_var->name_hint)) {
            LayoutFreeConstantCollector const_collector;
            const_collector(f.value()->body);

            static InstructionKind kind_transform_layout = InstructionKind::Get("TransformLayout");
            TuningRecord record = opt_record.value();
            for (const Instruction& inst : record->trace->insts) {
              if (inst->kind.same_as(kind_transform_layout)) {
                ICHECK_EQ(inst->inputs.size(), 2);
                auto index_map = Downcast<IndexMap>(inst->inputs[1]);

                if (!const_collector.constants.empty()) {
                  // In this case, RewriteLayout is acting on an AllocateConst node.
                  // After tuning, we reach this code path twice: first via the Relay
                  // MetaScheduleLayoutRewrite pass, and then via the final compilation
                  // (Relay to TE schedule lowering).
                  //
                  // Due to the Relay MetaScheduleLayoutRewrite and FoldConstant passes,
                  // the Relay subgraph for which we query the database during the
                  // final compilation has its weight tensor transformed according to
                  // the index map determined during tuning. For example,
                  //
                  //   fn (%p0: Tensor[(1, 56, 56, 64), float32]) {
                  //     %0 = nn.conv2d(%p0, meta[relay.Constant][0],
                  //                    /*ty=Tensor[(4, 2, 2, 3, 3, 32, 8), float32]*/, ...);
                  //     add(%0, meta[relay.Constant][1])
                  //   }
                  //
                  // Note that the database does not have an entry corresponding to such subgraphs,
                  // since an input subgraph to the tuning system always has its weight tensor in
                  // the original layout, e.g.
                  //
                  //   fn (%p0: Tensor[(1, 56, 56, 64), float32]) {
                  //     %0 = nn.conv2d(%p0, meta[relay.Constant][0],
                  //                    /*ty=Tensor[(3, 3, 64, 64), float32]*/, ...);
                  //     add(%0, meta[relay.Constant][1])
                  //   }
                  //
                  // Thus, in both of the two cases where we reach this code path, we need careful
                  // logic to make sure that (1) the database lookup during the final compilation
                  // succeeds and (2) the application of a schedule trace is well defined.

                  ICHECK(const_collector.constants.size() == 1)
                      << "Only one layout-free constant is supported by RewriteLayout for now";
                  auto constant = const_collector.constants[0];

                  auto is_constant_transformed = [index_map](runtime::NDArray c) {
                    if (c.Shape().size() != index_map->initial_indices.size()) {
                      return true;
                    }
                    size_t src_size_1d = 1;
                    Array<PrimExpr> orig_shape;
                    for (size_t i = 0; i < c.Shape().size(); ++i) {
                      src_size_1d *= c->shape[i];
                      orig_shape.push_back(PrimExpr(static_cast<int>(c->shape[i])));
                    }
                    auto dst_shape = index_map->MapShape(orig_shape);
                    size_t dst_size_1d = 1;
                    for (size_t i = 0; i < dst_shape.size(); ++i) {
                      dst_size_1d *= dst_shape[i].as<IntImmNode>()->value;
                    }
                    return src_size_1d != dst_size_1d;
                  };

                  if (!is_constant_transformed(constant)) {
                    // This is the first case, reached during the MetaScheduleLayoutRewrite pass.
                    //
                    // A layout-free constant having the same rank as an input to the index map
                    // is assumed to be transformed by this index map.
                    // TODO(masahi): If there are multiple layout-free constants in one
                    // TIR mod (e.g. conv2d -> conv2d fusion), this assumption does not hold.
                    // We need to determine which constant the given index map acts on.
                    //
                    // We know that, during the final compilation, we will query the database
                    // for a subgraph that the tuner has never seen. We work around this problem
                    // by adding a dummy entry to the database. The dummy entry is carefully
                    // constructed so that the lookup during the final compilation succeeds.
                    runtime::NDArray rewritten_constant = index_map->MapNDArray(constant);
                    auto f_dummy = AllocateConstReplaceConstant::Rewrite(
                        f.value(), {{constant, rewritten_constant}});
                    auto workload_dummy =
                        database_.value()->CommitWorkload(backend::PrimFuncToIRModule(f_dummy));
                    TuningRecord rec_dummy(record->trace, workload_dummy, record->run_secs,
                                           record->target, record->args_info);
                    database_.value()->CommitTuningRecord(rec_dummy);
                  } else {
                    // The constant is already transformed, so this is the second case, reached
                    // during the final compilation.
                    //
                    // The schedule trace is supposed to be applied to the weight in its original
                    // layout. But as explained above, the Relay subgraph we get in this case
                    // has its weight tensor transformed according to the corresponding index map.
                    // So effectively, we undo the layout transformation on the weight to restore
                    // the original PrimFunc that the schedule trace is supposed to act on.
                    ICHECK(index_map->inverse_index_map);
                    auto inverse_map = Downcast<IndexMap>(index_map->inverse_index_map.value());
                    ICHECK(constant.Shape().size() == inverse_map->initial_indices.size());
                    runtime::NDArray orig_constant = inverse_map->MapNDArray(constant);
                    auto f_ = AllocateConstReplaceConstant::Rewrite(f.value(),
                                                                    {{constant, orig_constant}});
                    query_mod = backend::PrimFuncToIRModule(f_);
                  }
                }
                MetaScheduleLayoutRewriter::LayoutQueuePush(index_map);
              }
            }

            Schedule sch = Schedule::Traced(query_mod, /*seed=*/-1, /*debug_mask=*/0,
                                            tir::ScheduleErrorRenderLevel::kDetail);

            if (!mod_eq_structural_->Equal(query_mod, opt_record.value()->workload->mod)) {
              // When the database lookup succeeds but the structural equality check fails,
              // it implies that anchor-block based equality was used during tuning.
              // The trace in the record cannot be applied directly to this query module.
              meta_schedule::ScheduleUsingAnchorTrace(sch, record->trace, target_);
            } else {
              record->trace->ApplyToSchedule(sch, /*remove_postproc=*/false);
            }

            IRModule mod = sch->mod();
            ICHECK_EQ(mod->functions.size(), 1);
            mod = tir::transform::RemoveWeightLayoutRewriteBlock(/*skip_ndarray_rewrite*/ false)(
                std::move(mod));
            prim_func = Downcast<PrimFunc>(mod->Lookup("main"));
            // Need to copy attrs from the Relay function over to the PrimFunc, most notably
            // the structural hash.
            prim_func = WithAttrs(prim_func, relay_func->attrs->dict);
          } else {
            int dispatch = backend::UseMetaScheduleDispatch();
            // (dispatch & 2): controls whether to print TVMScript for missing TIR
            // (dispatch & 4): controls whether to raise fatal errors for missing TIR
            if (dispatch & 2) {
              LOG(WARNING) << "Cannot find workload: " << prim_fn_var->name_hint << "\n"
                           << f.value();
            } else {
              LOG(WARNING) << "Cannot find workload: " << prim_fn_var->name_hint;
            }
            if (dispatch & 4) {
              LOG(FATAL);
            }
          }
        }
      }
      // Fall back to a TOPI schedule when neither auto_scheduler nor MetaSchedule produced one.
      if (!schedule.defined() && !prim_func.defined()) {
        if (anchor_op_.defined()) {
          auto anchor_impl = lower_te_compute.op_implementations_.find(anchor_op_.operator->());
          ICHECK(anchor_impl != lower_te_compute.op_implementations_.end());
          schedule = anchor_impl->second.Schedule(anchor_attrs_, tensor_outs, target_);
        } else {
          auto default_sched = GenericFunc::Get("schedule_injective");
          ICHECK(default_sched.defined()) << "schedule_injective not registered for " << target_;
          With<Target> tctx(target_);
          schedule = default_sched(tensor_outs);
        }
      }
      if (schedule.defined()) {
        for (const auto& scalar : lower_te_compute.scalars_) {
          if (schedule->Contain(scalar)) {
            schedule[scalar].compute_inline();
          }
        }
      }
    }

    IRModule funcs = IRModule(Map<GlobalVar, BaseFunc>({}));
    return CachedFunc(target_, prim_fn_var, fn_inputs, tensor_outs, schedule, prim_func, {}, funcs,
                      lower_te_compute.constant_tensors_);
  }

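  // Records the "anchor" call of the primitive function: the op with the highest TOpPattern
  // value seen so far. Its attributes are later used to select the TOPI schedule.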
  void VisitExpr_(const CallNode* call_node) final {
    static auto fpattern = Op::GetAttrMap<TOpPattern>("TOpPattern");

    ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
    Op op = Downcast<Op>(call_node->op);

    for (Expr arg : call_node->args) {
      VisitExpr(arg);
    }

    int op_pattern = fpattern[op];
    if (!use_auto_scheduler_ && !database_.defined() && op_pattern >= kCommReduce) {
      ICHECK(!anchor_op_.defined() || anchor_op_pattern_ < kCommReduce)
          << "Cannot apply TOPI schedule to a primitive function with two complicated ops"
          << " anchor=" << anchor_op_ << " current=" << op;
    }
    if (op_pattern >= anchor_op_pattern_) {
      anchor_op_ = op;
      anchor_attrs_ = call_node->attrs;
      anchor_op_pattern_ = op_pattern;
    }
  }

 private:
  tvm::Target target_;
  Op anchor_op_;
  Attrs anchor_attrs_;
  int anchor_op_pattern_{0};
  bool use_auto_scheduler_;
  Optional<meta_schedule::Database> database_;
  std::unique_ptr<meta_schedule::ModuleEquality> mod_eq_structural_;
};

/*!
 * \brief Create a schedule for the given target.
 * \param source_func The primitive function to be lowered.
 * \param target The target we want to create a schedule for.
 * \param global_var_supply The GlobalVarSupply used to create the PrimFunc's GlobalVar.
 * \param constant_name_supply The NameSupply used to assign unique names to constants.
 * \return The resulting CachedFunc. Its funcs field is not yet populated.
 */
CachedFunc PrimFuncFor(const Function& source_func, const Target& target,
                       GlobalVarSupply global_var_supply, NameSupply constant_name_supply) {
  return ScheduleBuilder(target).Create(source_func, global_var_supply, constant_name_supply);
}

// Creates the shape function for a Relay primitive function.
class MakeShapeFunc : public backend::MemoizedExprTranslator<Array<te::Tensor>> {
 public:
  MakeShapeFunc() {}

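  // Builds the shape function for a Relay primitive function. For every parameter, both a
  // data placeholder and a shape placeholder are prepared; visiting the body then decides
  // which of the two each parameter actually needs (recorded in shape_func_param_states),
  // and the resulting TE ops are scheduled and lowered.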
  CachedFunc Create(const Function& prim_func, const Target& target,
                    GlobalVarSupply global_var_supply) {
    VLOG_CONTEXT << "MakeShapeFunc";
    TShapeDataDependent shape_func_param_states;

    for (auto param : prim_func->params) {
      param_states_[param] = kNoNeed;
      Array<tvm::te::Tensor> data_inputs;
      Array<tvm::te::Tensor> shape_inputs;

      for (const auto& ttype : FlattenTupleType(param->checked_type())) {
        // Add the data placeholder (in case we discover we need it below).
        Shape shape = GetShape(ttype->shape);
        tvm::te::Tensor data_tensor =
            tvm::te::placeholder(shape, ttype->dtype, "data_" + param->vid->name_hint);
        data_inputs.push_back(data_tensor);
        // Add the shape placeholder (in case we discover we need it below).
        int64_t ndim = shape.size();
        Shape sshape;
        if (ndim > 0) {
          sshape.push_back(tvm::Integer(ndim));
        }
        tvm::te::Tensor shape_tensor =
            tvm::te::placeholder(sshape, DataType::Int(64), "shape_" + param->vid->name_hint);
        shape_inputs.push_back(shape_tensor);
      }
      param_data_[param] = data_inputs;
      param_shapes_[param] = shape_inputs;
    }

    // Set up the name.
    readable_name_stream_ << "shape_func";

    // Create the tensor expressions representing the output shapes.
    Array<te::Tensor> outputs = VisitExpr(prim_func->body);

    // Generate a name.
    auto candidate_name = readable_name_stream_.str();

    constexpr static size_t kMaxFuncNameLength = 80;
    // WARNING: Please make sure to also update TVM_CRT_MAX_STRLEN_FUNCTION_NAME
    // whenever the value of kMaxFuncNameLength changes.
    if (candidate_name.size() > kMaxFuncNameLength) {
      std::stringstream truncated_name;
      truncated_name << candidate_name.substr(0, kMaxFuncNameLength);
      truncated_name << "_" << std::hex << std::hash<std::string>{}(candidate_name) << "_";
      candidate_name = truncated_name.str();
    }

    // Set all the inputs correctly, and accumulate their types from the p.o.v. of the
    // shape function rather than the primitive it is derived for.
    Array<te::Tensor> inputs;
    Array<Type> shape_function_arg_types;
    for (auto param : prim_func->params) {
      int state = param_states_[param];
      shape_func_param_states.push_back(IntImm(DataType::Int(32), state));
      if (state & kNeedInputData) {
        // Pass the primitive arguments directly (though in flattened form and on the host).
        for (auto t : param_data_[param]) {
          inputs.push_back(t);
          shape_function_arg_types.push_back(TensorType(t->GetShape(), t->GetDataType()));
        }
      }
      if (state & kNeedInputShape) {
        // Pass the shapes of the primitive arguments (also on the host).
        for (auto t : param_shapes_[param]) {
          inputs.push_back(t);
          shape_function_arg_types.push_back(TensorType(t->GetShape(), t->GetDataType()));
        }
      }
    }

    // TODO(mbs): This should be the definitive global by which the PrimFunc is known and
    // no other GlobalVar ctors should appear inside the lowering machinery.
    auto prim_fn_gvar = global_var_supply->FreshGlobal(candidate_name);

    // Gather the result types, again from the p.o.v. of the shape function rather than
    // the primitive it is derived for.
    Array<Type> shape_function_res_types;
    for (const auto& t : outputs) {
      shape_function_res_types.push_back(TensorType(t->GetShape(), t->GetDataType()));
    }

    // Assign the shape function its true type.
    FuncType type(shape_function_arg_types, TupleType(shape_function_res_types),
                  /*type_params=*/{}, /*type_constraints=*/{});
    VLOG(1) << "shape function '" << prim_fn_gvar->name_hint << "' has type:" << std::endl
            << PrettyPrint(type) << std::endl
            << "corresponding to primitive of type:" << std::endl
            << PrettyPrint(prim_func->checked_type());
    prim_fn_gvar->checked_type_ = std::move(type);

    // Generate the schedule for the shape function.
    Array<te::Operation> out_ops;
    for (auto t : outputs) {
      out_ops.push_back(t->op);
    }
    te::Schedule schedule = te::create_schedule(out_ops);
    tvm::te::AutoInlineInjective(schedule);
    for (const auto& scalar : scalars_) {
      auto scalar_op = scalar->op;
      if (schedule->Contain(scalar_op)) {
        schedule[scalar_op].compute_inline();
      }
    }

    Array<te::Tensor> all_args = Array<te::Tensor>(inputs);
    for (te::Tensor arg : outputs) {
      all_args.push_back(arg);
    }

    using tvm::transform::PassContext;
    With<PassContext> fresh_pass_ctx_scope(PassContext::Create());

    std::unordered_map<te::Tensor, tir::Buffer> binds;
    IRModule lowered_module =
        tvm::LowerSchedule(schedule, all_args, prim_fn_gvar->name_hint, binds, global_var_supply);
    return CachedFunc(target, prim_fn_gvar, inputs, outputs, schedule, tir::PrimFunc{nullptr},
                      shape_func_param_states, lowered_module);
  }

  Array<te::Tensor> VisitExpr(const Expr& expr) final {
    if (expr.as<VarNode>()) {
      // Do not memoize vars, because shape functions could use either the data
      // or the shape of a var each time.
      return ExprFunctor::VisitExpr(expr);
    }
    // For other cases, do a memoized visit.
    return backend::MemoizedExprTranslator<Array<te::Tensor>>::VisitExpr(expr);
  }

  Array<te::Tensor> VisitExpr_(const VarNode* var_node) final {
    auto var = GetRef<Var>(var_node);
    auto it = param_arg_map_.find(var);
    if (it != param_arg_map_.end()) {
      // This var is a parameter of a nested function. Visit the corresponding argument in the
      // function call site.
      return VisitExpr(it->second);
    }
    if (param_states_.find(var) == param_states_.end()) {
      LOG(FATAL) << "Unexpected free variable " << PrettyPrint(var);
    } else {
      ICHECK(data_dependents_per_input_.size());
      auto data_dependent = data_dependents_per_input_.back();
      if (data_dependent) {
        param_states_[var] |= kNeedInputData;
        return param_data_[var];
      } else {
        param_states_[var] |= kNeedInputShape;
        return param_shapes_[var];
      }
    }
  }

  Array<te::Tensor> VisitExpr_(const ConstantNode* op) final {
    using tir::make_const;
    ICHECK(data_dependents_per_input_.size());
    bool data_dependent = data_dependents_per_input_.back();
    if (!op->is_scalar()) {
      // This is a constant weight; extract the shape of the weight tensor.
      // This cannot be data dependent.
      CHECK(!data_dependent);
      auto ttype = op->checked_type().as<TensorTypeNode>();
      int ndim = static_cast<int>(ttype->shape.size());
      Array<PrimExpr> out_shape{ndim};
      te::Tensor value = tvm::te::compute(
          out_shape,
          [&](const Array<tvm::tir::Var>& indices) {
            auto idx = indices[0];
            PrimExpr ret = make_const(DataType::Int(64), 0);
            for (int i = 0; i < ndim; i++) {
              ret = tvm::if_then_else(idx == i, ttype->shape[i], ret);
            }
            return ret;
          },
          "shape_const", topi::kBroadcast);
      scalars_.push_back(value);
      return {value};
    }
    if (data_dependent) {
      void* data = op->data->data;
      DataType dtype = DataType(op->data->dtype);
      auto value = tvm::te::compute(
          {},
          [&](const Array<tvm::tir::Var>&) {
            if (dtype == DataType::Int(32)) {
              return make_const(dtype, static_cast<const int32_t*>(data)[0]);
            } else if (dtype == DataType::Int(64)) {
              return make_const(dtype, static_cast<const int64_t*>(data)[0]);
            } else if (dtype == DataType::Float(32)) {
              return make_const(dtype, static_cast<const float*>(data)[0]);
            } else if (dtype == DataType::Float(64)) {
              return make_const(dtype, static_cast<const double*>(data)[0]);
            } else if (dtype == DataType::Bool()) {
              return make_const(dtype, static_cast<const uint8_t*>(data)[0]);
            } else {
              LOG(FATAL) << dtype << " not handled";
            }
          },
          "data_const", topi::kBroadcast);
      scalars_.push_back(value);
      return {value};
    } else {
      auto value = tvm::te::compute(
          {}, [&](const Array<tvm::tir::Var>&) { return tir::make_const(DataType::Int(64), 0); },
          "shape_const", topi::kBroadcast);
      scalars_.push_back(value);
      return {value};
    }
  }

  Array<te::Tensor> VisitExpr_(const CallNode* call_node) final {
    VLOG(1) << "considering call:" << std::endl << PrettyPrint(GetRef<Call>(call_node));
    if (auto* func = call_node->op.as<FunctionNode>()) {
      VLOG(1) << "user function";
      for (size_t i = 0; i < func->params.size(); ++i) {
        param_arg_map_[func->params[i]] = call_node->args[i];
      }
      return VisitExpr(func->body);
    }

    static auto fshape_func = Op::GetAttrMap<FShapeFunc>("FShapeFunc");
    static auto tshape_data_dependent = Op::GetAttrMap<TShapeDataDependent>("TShapeDataDependent");
    ICHECK(call_node->op.as<OpNode>()) << "Primitive function only allows call into primitive ops";
    Op op = Downcast<Op>(call_node->op);
    ICHECK(data_dependents_per_input_.empty() || !data_dependents_per_input_.back())
        << "Error in op fusion: output of the shape func is fed to a "
        << "data-dependent shape func";
    ICHECK_GT(fshape_func.count(op), 0) << "Internal error, cannot find ShapeFunc for " << op->name;
    ICHECK_GT(tshape_data_dependent.count(op), 0)
        << "Internal error, cannot find TShapeDataDependent for " << op->name;

    Array<Integer> dep_spec = tshape_data_dependent[op];
    if (dep_spec.size() == 1) {
      // This handles the case where data dependence is specified per op:
      // replicate the single 0/1 flag to all arguments.
      for (size_t i = 1; i < call_node->args.size(); ++i) {
        dep_spec.push_back(dep_spec[0]);
      }
    }

    // Visit all inputs.
    Array<te::Tensor> inputs;
    int count_tuple = 0;
    for (size_t i = 0; i < call_node->args.size(); ++i) {
      Expr arg = call_node->args[i];
      if (arg->checked_type().as<TupleTypeNode>()) {
        ++count_tuple;
      }
      data_dependents_per_input_.push_back(dep_spec[i]->value != 0);
      for (te::Tensor tensor : VisitExpr(arg)) {
        inputs.push_back(tensor);
      }
      data_dependents_per_input_.pop_back();
    }
    if (count_tuple) {
      ICHECK_EQ(call_node->args.size(), 1U) << "Only allow function with a single tuple input";
    }
    // Get output ndims.
    auto ret_type = call_node->checked_type();
    Array<IndexExpr> out_ndims;
    for (const auto& ttype : FlattenTupleType(ret_type)) {
      out_ndims.push_back(IntImm(DataType::Int(32), ttype->shape.size()));
    }

    // Call the shape function.
    Array<te::Tensor> outputs = fshape_func[op](call_node->attrs, inputs, out_ndims);
    VLOG(1) << "shape function for '" << op->name << "' with inputs:" << std::endl
            << inputs << std::endl
            << "yielded outputs:" << std::endl
            << outputs;
    readable_name_stream_ << "_" << op->name;
    return outputs;
  }

  Array<te::Tensor> VisitExpr_(const FunctionNode* op) final {
    LOG(FATAL) << "Nested functions are not allowed to be visited.";
  }

  Array<te::Tensor> VisitExpr_(const LetNode* op) final {
    Array<te::Tensor> val = VisitExpr(op->value);
    ICHECK(!memo_.count(op->var));
    memo_[op->var] = val;
    return VisitExpr(op->body);
  }

  Array<te::Tensor> VisitExpr_(const TupleNode* op) final {
    Array<te::Tensor> fields;
    for (Expr field : op->fields) {
      ICHECK(field->checked_type().as<TensorTypeNode>())
          << "Expected a Tuple of Tensor, but got " << PrettyPrint(field->checked_type());
      Array<te::Tensor> res = VisitExpr(field);
      ICHECK_EQ(res.size(), 1);
      fields.push_back(res[0]);
    }
    return fields;
  }

  Array<te::Tensor> VisitExpr_(const TupleGetItemNode* op) final {
    Array<te::Tensor> input_shapes = VisitExpr(op->tuple);
    Array<te::Tensor> out;
    out.push_back(input_shapes[op->index]);
    return out;
  }

 private:
  /*! \brief String stream for the function name */
  std::ostringstream readable_name_stream_;
  /*! \brief Map from parameter to its shape function usage state */
  std::unordered_map<Expr, int, ObjectPtrHash, ObjectPtrEqual> param_states_;
  /*! \brief Map from parameter to list of data placeholders */
  std::unordered_map<Expr, Array<te::Tensor>, ObjectPtrHash, ObjectPtrEqual> param_data_;
  /*! \brief Map from parameter to list of shape placeholders */
  std::unordered_map<Expr, Array<te::Tensor>, ObjectPtrHash, ObjectPtrEqual> param_shapes_;
  /*! \brief Stack of data dependencies for the shape function, specified per op input */
  std::vector<bool> data_dependents_per_input_;
  /*! \brief Scalars used in the shape function */
  Array<te::Tensor> scalars_;
  /*! \brief Map from parameters of a nested function to corresponding arguments in a function
   * call site.
   */
  std::unordered_map<Var, Expr, ObjectPtrHash, ObjectPtrEqual> param_arg_map_;
};

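// Creates the shape function (see MakeShapeFunc) for the given primitive function and target.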
CachedFunc ShapeFuncFor(const Function& prim_func, const Target& target,
                        GlobalVarSupply global_var_supply) {
  return MakeShapeFunc().Create(prim_func, target, global_var_supply);
}

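// Lowers a Relay primitive function to TE tensors without scheduling them. Returns the TE
// tensors (prefixed with the function inputs if return_inputs is true), the NDArray data of
// any constant placeholders, and the generated "fused_..." candidate name.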
std::tuple<Array<te::Tensor>, Array<runtime::NDArray>, std::string> LowerTECompute(
    const Function& source_func, Target target, NameSupply constant_name_supply,
    bool return_inputs) {
  LowerToTECompute lower_te_compute(target, constant_name_supply);
  Array<te::Tensor> outputs = lower_te_compute.Lower(source_func);
  // Following ScheduleBuilder, remove placeholder ops from outputs.
  tvm::Array<te::Tensor> tensor_outs;
  for (const auto& tensor : outputs) {
    if (!tensor->op.as<te::PlaceholderOpNode>()) {
      tensor_outs.push_back(tensor);
    }
  }

  tvm::Array<runtime::NDArray> constants;
  for (auto [const_node, te_tensor] : lower_te_compute.constant_tensors_) {
    tensor_outs.push_back(te_tensor);
    constants.push_back(const_node->data);
  }

  if (return_inputs) {
    return std::make_tuple(Concat(lower_te_compute.fn_inputs_, tensor_outs), constants,
                           lower_te_compute.candidate_name_);
  }
  return std::make_tuple(tensor_outs, constants, lower_te_compute.candidate_name_);
}

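// Lowers a Relay primitive function directly to a TIR PrimFunc using the registered
// TE-to-TIR converter, bypassing TE scheduling. The PrimFunc is absent when the converter
// declines the function; the fused function name is always returned.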
std::pair<Optional<tir::PrimFunc>, std::string> LowerToPrimFunc(const Function& relay_func,
                                                                Target target,
                                                                NameSupply constant_name_supply) {
  ICHECK(relay_func->HasNonzeroAttr(attr::kPrimitive))
      << "The input must be a Relay primitive function.";

  auto [inputs_outputs, constants, fused_name] =
      tec::LowerTECompute(relay_func, target, constant_name_supply, /*return_inputs=*/true);
  auto tir_converter = backend::GetTIRConverter();
  return std::make_pair(tir_converter(inputs_outputs, constants), fused_name);
}

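// Convenience overload that raises a fatal error instead of returning an empty Optional
// when the function cannot be converted.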
tir::PrimFunc LowerToPrimFunc(const Function& relay_func, Target target) {
  auto [f_opt, _] = LowerToPrimFunc(relay_func, target, NameSupply(""));
  (void)_;  // to suppress -Werror=unused-variable warning
  if (f_opt) {
    return f_opt.value();
  }
  LOG(FATAL) << "Failed to convert the Relay function: " << AsText(relay_func, false);
  return PrimFunc();
}

TVM_REGISTER_GLOBAL("relay.backend.LowerToPrimFunc")
    .set_body_typed([](Function relay_func, Target target) {
      return LowerToPrimFunc(relay_func, target);
    });

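// Exposes LowerToTECompute over the FFI (e.g. for inspecting the TE representation of a fused
// function from Python). A placeholder "ext_dev" target is used, and no schedule or PrimFunc
// is attached to the returned CachedFunc.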
TVM_REGISTER_GLOBAL("relay.backend.LowerToTE").set_body_typed([](Function prim_func) {
  auto tgt = tvm::Target("ext_dev");
  LowerToTECompute lower_te_compute(tgt, NameSupply(""));
  auto outputs = lower_te_compute.Lower(prim_func);
  return CachedFunc(tgt, GlobalVar(lower_te_compute.candidate_name_), lower_te_compute.fn_inputs_,
                    outputs, te::Schedule(), tir::PrimFunc(), {},
                    IRModule(Map<GlobalVar, BaseFunc>({})), lower_te_compute.constant_tensors_);
});

}  // namespace tec
}  // namespace relay
}  // namespace tvm