/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

/*!
 * \file codegen_stackvm.cc
 * \brief Translate TIR PrimFuncs into StackVM bytecode.
 */
#include "codegen_stackvm.h"

#include <tvm/ir/module.h>
#include <tvm/runtime/registry.h>
#include <tvm/tir/builtin.h>
#include <tvm/tir/function.h>
#include <tvm/tir/op.h>

#include <limits>
#include <utility>

#include "../../runtime/stackvm/stackvm_module.h"

namespace tvm {
namespace codegen {

using namespace tir;

// Map a struct field kind from the compiler-side enum to its runtime variant.
// We keep two separate enums to preserve runtime/compiler isolation.
StackVM::StructFieldKind MapFieldKind(int64_t kind) {
  auto val = static_cast<builtin::TVMStructFieldKind>(kind);
  switch (val) {
    case builtin::kArrData:
      return StackVM::kArrData;
    case builtin::kArrShape:
      return StackVM::kArrShape;
    case builtin::kArrAddr:
      return StackVM::kArrAddr;
    case builtin::kArrStrides:
      return StackVM::kArrStrides;
    case builtin::kArrNDim:
      return StackVM::kArrNDim;
    case builtin::kArrTypeCode:
      return StackVM::kArrTypeCode;
    case builtin::kArrTypeBits:
      return StackVM::kArrTypeBits;
    case builtin::kArrTypeLanes:
      return StackVM::kArrTypeLanes;
    case builtin::kArrByteOffset:
      return StackVM::kArrByteOffset;
    case builtin::kArrDeviceId:
      return StackVM::kArrDeviceId;
    case builtin::kArrDeviceType:
      return StackVM::kArrDeviceType;
    case builtin::kTVMValueContent:
      return StackVM::kTVMValueContent;
    default:
      LOG(FATAL) << "Do not know how to map field " << kind;
  }
  // Unreachable: LOG(FATAL) aborts. Kept to silence missing-return warnings.
  return StackVM::kArrData;
}

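// Compile a PrimFunc into a StackVM program. Parameters are assigned the
// leading heap slots in order, so at runtime argument i lands in slot i.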
StackVM CodeGenStackVM::Compile(const PrimFunc& f) {
  ICHECK_EQ(f->buffer_map.size(), 0U)
      << "Cannot codegen function with buffer_map, please lower it first";
  for (size_t i = 0; i < f->params.size(); ++i) {
    Var v = f->params[i];
    int vid = AllocVarID(v.get());
    ICHECK_EQ(static_cast<size_t>(vid), i);
  }
  this->Push(f->body);
  vm_.InitCache();
  return std::move(vm_);
}

void CodeGenStackVM::Push(const Stmt& n) {
  VisitStmt(n);
  if (debug_) {
    // In debug mode, check that every statement leaves the stack depth unchanged.
    this->PushOp(StackVM::ASSERT_SP, 0);
  }
}

void CodeGenStackVM::PushOp(StackVM::OpCode opcode) {
  StackVM::Code code;
  code.op_code = opcode;
  vm_.code.push_back(code);
}

void CodeGenStackVM::SetOperand(int64_t operand_index, int64_t operand) {
  ICHECK(operand >= std::numeric_limits<int>::min() && operand <= std::numeric_limits<int>::max());
  vm_.code.at(operand_index).v_int = static_cast<int>(operand);
}

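// Emit an instruction that carries one immediate operand. The opcode and the
// operand occupy two consecutive Code slots; the returned index points at the
// operand slot so that callers can backpatch it later via SetOperand.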
int64_t CodeGenStackVM::PushOp(StackVM::OpCode opcode, int operand) {
  int64_t pc = static_cast<int64_t>(vm_.code.size());
  StackVM::Code code;
  code.op_code = opcode;
  vm_.code.push_back(code);
  code.v_int = operand;
  vm_.code.push_back(code);
  return pc + 1;
}

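// Intern a string into the VM's string table and return its id.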
int CodeGenStackVM::GetStrID(const std::string& key) {
  auto it = str_idmap_.find(key);
  if (it != str_idmap_.end()) return it->second;
  int sid = static_cast<int>(vm_.str_data.size());
  vm_.str_data.push_back(key);
  str_idmap_[key] = sid;
  return sid;
}

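// Variables live in the VM heap. AllocVarID assigns the next free heap slot to
// a variable and records its name hint for debugging.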
int CodeGenStackVM::AllocVarID(const VarNode* v) {
  ICHECK(!var_idmap_.count(v));
  int vid = static_cast<int>(vm_.heap_size);
  ICHECK_EQ(vm_.heap_size, var_idmap_.size());
  vm_.heap_id_name.push_back(v->name_hint);
  ++vm_.heap_size;
  var_idmap_[v] = vid;
  return vid;
}

int CodeGenStackVM::GetVarID(const VarNode* v) const {
  auto it = var_idmap_.find(v);
  ICHECK(it != var_idmap_.end()) << "Use of undefined variable " << v->name_hint;
  return it->second;
}

void CodeGenStackVM::VisitExpr_(const LoadNode* op) {
  LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
}

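// Buffer accesses assume flat 1-d memory: the base pointer is read from the
// variable's heap slot and the element address is base + index * element_bytes.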
void CodeGenStackVM::VisitExpr_(const BufferLoadNode* op) {
  ICHECK_EQ(op->indices.size(), 1) << "StackVM expects flat 1-d buffers. "
                                   << "Has StorageFlatten (TE-based schedules) or "
                                   << "FlattenBuffer (TIR-based schedules) been run?";
  auto index = op->indices[0];

  this->Push(op->buffer->data);
  StackVM::OpCode code = StackVM::GetLoad(op->dtype);
  if (const IntImmNode* int_index = index.as<IntImmNode>()) {
    // A constant index becomes the load instruction's immediate operand.
    this->PushOp(code, static_cast<int>(int_index->value));
  } else {
    // Otherwise compute base + index * element_bytes on the stack.
    this->Push(index);
    this->PushOp(StackVM::PUSH_I64, op->dtype.element_of().bytes());
    this->PushOp(StackVM::MUL_I64);
    this->PushOp(StackVM::ADDR_ADD);
    this->PushOp(code, 0);
  }
}

void CodeGenStackVM::VisitStmt_(const StoreNode* op) {
  LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
}

void CodeGenStackVM::VisitStmt_(const BufferStoreNode* op) {
  ICHECK_EQ(op->indices.size(), 1) << "StackVM expects flat 1-d buffers. "
                                   << "Has StorageFlatten (TE-based schedules) or "
                                   << "FlattenBuffer (TIR-based schedules) been run?";
  auto index = op->indices[0];

  this->Push(op->buffer->data);
  StackVM::OpCode code = StackVM::GetStore(op->value.dtype());
  if (const IntImmNode* int_index = index.as<IntImmNode>()) {
    // A constant index becomes the store instruction's immediate operand.
    this->Push(op->value);
    this->PushOp(code, static_cast<int>(int_index->value));
  } else {
    // Otherwise compute base + index * element_bytes on the stack.
    this->Push(index);
    this->PushOp(StackVM::PUSH_I64, op->value.dtype().element_of().bytes());
    this->PushOp(StackVM::MUL_I64);
    this->PushOp(StackVM::ADDR_ADD);
    this->Push(op->value);
    this->PushOp(code, 0);
  }
}

void CodeGenStackVM::VisitStmt_(const AllocateNode* op) {
  LOG(FATAL) << "Dynamic allocation not supported";
}

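// Lower calls to builtin intrinsics. Most map to a dedicated opcode;
// instructions with immediate operands occupy several consecutive Code slots.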
void CodeGenStackVM::VisitExpr_(const CallNode* op) {
  if (op->op.same_as(builtin::address_of())) {
    const BufferLoadNode* load = op->args[0].as<BufferLoadNode>();
    ICHECK(op->args.size() == 1 && load);
    ICHECK_EQ(load->indices.size(), 1) << "CodeGenStackVM only supports flat memory allocations.";
    // Compute the element address: base + index * element_bytes.
    this->PushOp(StackVM::LOAD_HEAP, GetVarID(load->buffer->data.get()));
    this->Push(load->indices[0]);
    this->PushOp(StackVM::PUSH_I64, load->dtype.element_of().bytes());
    this->PushOp(StackVM::MUL_I64);
    this->PushOp(StackVM::ADDR_ADD);
  } else if (op->op.same_as(builtin::reinterpret())) {
    // Stack values are untyped 64-bit slots, so reinterpret is a no-op.
    this->Push(op->args[0]);
  } else if (op->op.same_as(builtin::tvm_struct_get())) {
    ICHECK_EQ(op->args.size(), 3U);
    int kind = op->args[2].as<IntImmNode>()->value;
    this->Push(op->args[0]);
    const IntImmNode* index = op->args[1].as<IntImmNode>();
    ICHECK(index != nullptr);
    // TVM_STRUCT_GET is a three-slot instruction: opcode, struct index, field kind.
    StackVM::Code code;
    code.op_code = StackVM::TVM_STRUCT_GET;
    vm_.code.push_back(code);
    code.v_int = index->value;
    vm_.code.push_back(code);
    code.v_int = MapFieldKind(kind);
    vm_.code.push_back(code);
  } else if (op->op.same_as(builtin::tvm_call_packed_lowered())) {
    ICHECK_GE(op->args.size(), 5U);
    const StringImmNode* s = op->args[0].as<StringImmNode>();
    ICHECK(s != nullptr) << "tvm_call_packed_lowered expects the first argument to be the function name";
    this->Push(op->args[1]);
    this->Push(op->args[2]);
    int begin = op->args[3].as<IntImmNode>()->value;
    int end = op->args[4].as<IntImmNode>()->value;
    // Find the function id, interning the name on first use.
    const std::string& func_name = s->value;
    auto it = extern_fun_idmap_.find(func_name);
    int fid;
    if (it != extern_fun_idmap_.end()) {
      fid = it->second;
    } else {
      fid = static_cast<int>(vm_.extern_func_name.size());
      vm_.extern_func_name.push_back(func_name);
      extern_fun_idmap_[func_name] = fid;
    }
    // CALL_PACKED_LOWERED is a four-slot instruction: opcode, function id,
    // then begin and end of the argument range on the value stack.
    StackVM::Code code;
    code.op_code = StackVM::CALL_PACKED_LOWERED;
    vm_.code.push_back(code);
    code.v_int = fid;
    vm_.code.push_back(code);
    code.v_int = begin;
    vm_.code.push_back(code);
    code.v_int = end;
    vm_.code.push_back(code);
  } else if (op->op.same_as(builtin::tvm_stack_alloca())) {
    ICHECK_EQ(op->args.size(), 2U);
    const std::string& type = op->args[0].as<StringImmNode>()->value;
    const IntImmNode* num = op->args[1].as<IntImmNode>();
    ICHECK(num != nullptr);
    static_assert(alignof(TVMValue) % alignof(DLTensor) == 0, "invariant");
    // static_assert(alignof(TVMValue) % alignof(tvm_index_t) == 0, "invariant");
    // Sizes are expressed in TVMValue (8-byte) units, rounded up.
    size_t unit = sizeof(TVMValue);
    size_t size = 0;
    if (type == "shape") {
      size = (num->value * sizeof(tvm_index_t) + unit - 1) / unit;
    } else if (type == "arg_value") {
      size = (num->value * sizeof(TVMValue) + unit - 1) / unit;
    } else if (type == "arg_tcode") {
      size = (num->value * sizeof(int) + unit - 1) / unit;
    } else if (type == "array") {
      size = (num->value * sizeof(DLTensor) + unit - 1) / unit;
    } else {
      LOG(FATAL) << "Unknown stack alloca type " << type;
    }
    // Grow the reserved stack size to be safe.
    vm_.stack_size += size;
    this->PushOp(StackVM::TVM_STACK_ALLOCA_BY_8BYTE, static_cast<int>(size));
  } else if (op->op.same_as(backend_alloc_workspace_op_)) {
    ICHECK_EQ(op->args.size(), 5U);
    this->Push(op->args[0]);
    this->Push(op->args[1]);
    this->Push(op->args[2]);
    this->Push(op->args[3]);
    this->Push(op->args[4]);
    this->PushOp(StackVM::TVM_DEVICE_ALLOCA);
  } else if (op->op.same_as(backend_free_workspace_op_)) {
    ICHECK_EQ(op->args.size(), 3U);
    this->Push(op->args[0]);
    this->Push(op->args[1]);
    this->Push(op->args[2]);
    this->PushOp(StackVM::TVM_DEVICE_FREE);
  } else if (op->op.same_as(builtin::tvm_throw_last_error())) {
    this->PushOp(StackVM::TVM_THROW_LAST_ERROR);
  } else if (op->op.same_as(builtin::isnullptr())) {
    ICHECK_EQ(op->args.size(), 1U);
    // handle == nullptr is lowered as handle == 0.
    this->Push(op->args[0]);
    this->PushOp(StackVM::PUSH_I64, 0);
    this->PushOp(StackVM::EQ_HANDLE);
  } else {
    LOG(FATAL) << "unknown function call " << op->op;
  }
}

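// Emit a binary op. Operands are evaluated left to right; signed and unsigned
// integers share the 64-bit integer opcode, while floating point uses the
// corresponding F64 opcode.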
void CodeGenStackVM::PushBinary(StackVM::OpCode op_int64, const PrimExpr& a, const PrimExpr& b) {
  this->Push(a);
  this->Push(b);
  DataType t = a.dtype();
  if (t.is_int() || t.is_uint()) {
    this->PushOp(op_int64);
  } else {
    this->PushOp(StackVM::CodeI64ToF64(op_int64));
  }
}

void CodeGenStackVM::PushCast(DataType dst, DataType src) {
  // Stack values are stored in 64-bit slots, so casts within the integer
  // family are no-ops. All other combinations currently emit no conversion
  // code and leave the value unchanged.
  if (dst.is_int()) {
    if (src.is_int() || src.is_uint()) return;
  } else if (dst.is_uint()) {
    if (src.is_int() || src.is_uint()) return;
  } else if (dst.is_float()) {
    if (src.is_float()) return;
  }
}

void CodeGenStackVM::VisitExpr_(const StringImmNode* op) {
  // Strings are interned; push the string id.
  int sid = this->GetStrID(op->value);
  this->PushOp(StackVM::PUSH_I64, sid);
}

void CodeGenStackVM::VisitExpr_(const IntImmNode* op) {
  ICHECK(op->value >= std::numeric_limits<int>::min() &&
         op->value <= std::numeric_limits<int>::max())
      << "Int constant exceeds bound";
  this->PushOp(StackVM::PUSH_I64, static_cast<int>(op->value));
}

void CodeGenStackVM::VisitExpr_(const FloatImmNode* op) {
  LOG(FATAL) << "FloatImm is not supported";
}

void CodeGenStackVM::VisitExpr_(const VarNode* op) {
  int vid = this->GetVarID(op);
  this->PushOp(StackVM::LOAD_HEAP, vid);
}

void CodeGenStackVM::VisitExpr_(const CastNode* op) {
  this->Push(op->value);
  PushCast(op->dtype, op->value.dtype());
}

void CodeGenStackVM::VisitExpr_(const AddNode* op) { PushBinary(StackVM::ADD_I64, op->a, op->b); }

void CodeGenStackVM::VisitExpr_(const SubNode* op) { PushBinary(StackVM::SUB_I64, op->a, op->b); }

void CodeGenStackVM::VisitExpr_(const MulNode* op) { PushBinary(StackVM::MUL_I64, op->a, op->b); }

void CodeGenStackVM::VisitExpr_(const DivNode* op) { PushBinary(StackVM::DIV_I64, op->a, op->b); }

void CodeGenStackVM::VisitExpr_(const ModNode* op) { PushBinary(StackVM::MOD_I64, op->a, op->b); }

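// Min/Max are open-coded with SELECT. PUSH_VALUE with a non-positive operand
// duplicates an entry relative to the stack top (0 = top, -1 = one below);
// SELECT consumes [tval, fval, cond] (cond on top) and leaves cond ? tval : fval.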
void CodeGenStackVM::VisitExpr_(const MinNode* op) {
  this->Push(op->a);
  this->Push(op->b);
  // Duplicate a and b, compute a < b, then select a or b accordingly.
  this->PushOp(StackVM::PUSH_VALUE, -1);
  this->PushOp(StackVM::PUSH_VALUE, -1);
  this->PushOp(StackVM::LT_I64);
  this->PushOp(StackVM::SELECT);
}

void CodeGenStackVM::VisitExpr_(const MaxNode* op) {
  this->Push(op->a);
  this->Push(op->b);
  // Duplicate b and a, compute b < a, then select a or b accordingly.
  this->PushOp(StackVM::PUSH_VALUE, 0);
  this->PushOp(StackVM::PUSH_VALUE, -2);
  this->PushOp(StackVM::LT_I64);
  this->PushOp(StackVM::SELECT);
}

void CodeGenStackVM::VisitExpr_(const EQNode* op) { PushBinary(StackVM::EQ_I64, op->a, op->b); }

void CodeGenStackVM::VisitExpr_(const LENode* op) { PushBinary(StackVM::LE_I64, op->a, op->b); }

void CodeGenStackVM::VisitExpr_(const NENode* op) {
  // a != b is computed as !(a == b).
  PushBinary(StackVM::EQ_I64, op->a, op->b);
  this->PushOp(StackVM::NOT);
}

void CodeGenStackVM::VisitExpr_(const LTNode* op) { PushBinary(StackVM::LT_I64, op->a, op->b); }

void CodeGenStackVM::VisitExpr_(const GENode* op) {
  // a >= b is computed as !(a < b).
  PushBinary(StackVM::LT_I64, op->a, op->b);
  this->PushOp(StackVM::NOT);
}

void CodeGenStackVM::VisitExpr_(const GTNode* op) {
  // a > b is computed as !(a <= b).
  PushBinary(StackVM::LE_I64, op->a, op->b);
  this->PushOp(StackVM::NOT);
}

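// Logical && and || short-circuit via relative jumps. The conditional jumps
// only peek at the stack top, so the first operand is popped explicitly on the
// fall-through path; the jump distance is backpatched via SetOperand once the
// code for the second operand has been emitted.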
void CodeGenStackVM::VisitExpr_(const AndNode* op) {
  this->Push(op->a);
  int64_t pc_jump = this->GetPC();
  int64_t opr_index = this->PushOp(StackVM::RJUMP_IF_FALSE, 0);
  // a is true: discard it and let b's value be the result.
  this->PushOp(StackVM::POP);
  this->Push(op->b);
  int64_t diff = this->GetPC() - pc_jump;
  this->SetOperand(opr_index, diff);
}

void CodeGenStackVM::VisitExpr_(const OrNode* op) {
  this->Push(op->a);
  int64_t pc_jump = this->GetPC();
  int64_t opr_index = this->PushOp(StackVM::RJUMP_IF_TRUE, 0);
  // a is false: discard it and let b's value be the result. The POP mirrors
  // the AndNode case above and keeps the stack balanced.
  this->PushOp(StackVM::POP);
  this->Push(op->b);
  int64_t diff = this->GetPC() - pc_jump;
  this->SetOperand(opr_index, diff);
}

void CodeGenStackVM::VisitExpr_(const NotNode* op) {
  this->Push(op->a);
  this->PushOp(StackVM::NOT);
}

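// Loop layout:
//         PUSH_I64 0
//   head: STORE_HEAP vid; LOAD_HEAP vid; <extent>; LT_I64
//         RJUMP_IF_FALSE end; POP; <body>
//         LOAD_HEAP vid; PUSH_I64 1; ADD_I64; RJUMP head
//   end:  POP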
void CodeGenStackVM::VisitStmt_(const ForNode* op) {
  ICHECK(is_zero(op->min));
  int vid = this->AllocVarID(op->loop_var.get());
  this->PushOp(StackVM::PUSH_I64, 0);
  int64_t loop_head = this->GetPC();
  // head: store the initial or incremented counter, then test counter < extent.
  this->PushOp(StackVM::STORE_HEAP, vid);
  this->PushOp(StackVM::LOAD_HEAP, vid);
  this->Push(op->extent);
  this->PushOp(StackVM::LT_I64);
  int64_t label_fjump = this->GetPC();
  int64_t forward_jump = this->PushOp(StackVM::RJUMP_IF_FALSE, 0);
  // Pop the (true) condition, run the body, then increment and jump back.
  this->PushOp(StackVM::POP);
  this->Push(op->body);
  this->PushOp(StackVM::LOAD_HEAP, vid);
  this->PushOp(StackVM::PUSH_I64, 1);
  this->PushOp(StackVM::ADD_I64);
  int64_t label_bjump = this->GetPC();
  int64_t backward_jump = this->PushOp(StackVM::RJUMP, 0);
  int64_t loop_end = this->GetPC();
  // Pop the (false) condition on loop exit, then backpatch both jumps.
  this->PushOp(StackVM::POP);
  this->SetOperand(forward_jump, loop_end - label_fjump);
  this->SetOperand(backward_jump, loop_head - label_bjump);
}

void CodeGenStackVM::VisitStmt_(const SeqStmtNode* op) {
  for (Stmt stmt : op->seq) {
    this->Push(stmt);
  }
}

void CodeGenStackVM::VisitStmt_(const EvaluateNode* ev) {
  if (is_const_int(ev->value)) return;
  const CallNode* op = ev->value.as<CallNode>();
  if (op && op->op.same_as(builtin::tvm_struct_set())) {
    ICHECK_EQ(op->args.size(), 4U);
    this->Push(op->args[0]);
    this->Push(op->args[3]);
    const IntImmNode* index = op->args[1].as<IntImmNode>();
    ICHECK(index != nullptr);
    // TVM_STRUCT_SET is a three-slot instruction: opcode, struct index, field kind.
    StackVM::Code code;
    code.op_code = StackVM::TVM_STRUCT_SET;
    vm_.code.push_back(code);
    code.v_int = index->value;
    vm_.code.push_back(code);
    code.v_int = MapFieldKind(op->args[2].as<IntImmNode>()->value);
    vm_.code.push_back(code);
  } else {
    // Evaluate for side effects and discard the result.
    this->Push(ev->value);
    this->PushOp(StackVM::POP);
  }
}

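// if/else layout: the condition value stays on the stack across the
// conditional jump and each execution path pops it exactly once.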
void CodeGenStackVM::VisitStmt_(const IfThenElseNode* op) {
  this->Push(op->condition);
  int64_t label_ejump = this->GetPC();
  int64_t else_jump = this->PushOp(StackVM::RJUMP_IF_FALSE, 0);
  this->PushOp(StackVM::POP);
  this->Push(op->then_case);
  if (op->else_case) {
    // End of the then branch: jump over the else branch.
    int64_t label_then_jump = this->GetPC();
    int64_t then_jump = this->PushOp(StackVM::RJUMP, 0);
    int64_t else_begin = this->GetPC();
    this->SetOperand(else_jump, else_begin - label_ejump);
    this->PushOp(StackVM::POP);
    this->Push(op->else_case.value());
    int64_t if_end = this->GetPC();
    this->SetOperand(then_jump, if_end - label_then_jump);
  } else {
    // End of the then branch: jump over the landing POP, so the condition is
    // popped exactly once on either path.
    int64_t label_then_jump = this->GetPC();
    int64_t then_jump = this->PushOp(StackVM::RJUMP, 0);
    int64_t if_end = this->GetPC();
    this->SetOperand(else_jump, if_end - label_ejump);
    this->PushOp(StackVM::POP);
    this->SetOperand(then_jump, this->GetPC() - label_then_jump);
  }
}

void CodeGenStackVM::VisitStmt_(const LetStmtNode* op) {
  this->Push(op->value);
  int64_t vid = this->AllocVarID(op->var.get());
  this->PushOp(StackVM::STORE_HEAP, static_cast<int>(vid));
  this->Push(op->body);
}

void CodeGenStackVM::VisitExpr_(const RampNode* op) { LOG(FATAL) << "Ramp is not supported"; }

void CodeGenStackVM::VisitExpr_(const BroadcastNode* op) {
  LOG(FATAL) << "Broadcast is not supported";
}

void CodeGenStackVM::VisitExpr_(const SelectNode* op) {
  this->Push(op->true_value);
  this->Push(op->false_value);
  this->Push(op->condition);
  this->PushOp(StackVM::SELECT);
}

void CodeGenStackVM::VisitStmt_(const AssertStmtNode* op) {
  // Only string messages are supported; the condition check is skipped otherwise.
  if (const auto* str = op->message.as<StringImmNode>()) {
    int sid = this->GetStrID(str->value);
    this->Push(op->condition);
    this->PushOp(StackVM::ASSERT, sid);
  }
  this->Push(op->body);
}

void CodeGenStackVM::VisitStmt_(const AttrStmtNode* op) { this->Push(op->body); }

void CodeGenStackVM::VisitExpr_(const LetNode* op) {
  this->Push(op->value);
  int64_t vid = this->AllocVarID(op->var.get());
  this->PushOp(StackVM::STORE_HEAP, static_cast<int>(vid));
  this->Push(op->body);
}

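// Compile every PrimFunc in the module into its own StackVM program, keyed by
// the function's global symbol. The function marked tir::attr::kIsEntryFunc,
// if any, becomes the module's entry point.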
runtime::Module BuildStackVM(IRModule mod, Target target) {
  std::unordered_map<std::string, StackVM> fmap;
  std::string entry_func;

  for (auto kv : mod->functions) {
    ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenStackVM: Can only take PrimFunc";
    auto f = Downcast<PrimFunc>(kv.second);
    auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol);
    ICHECK(global_symbol.defined())
        << "CodeGenStackVM: Expect PrimFunc to have the global_symbol attribute";
    std::string f_name = global_symbol.value();
    StackVM vm = codegen::CodeGenStackVM().Compile(f);
    ICHECK(!fmap.count(f_name)) << "Function name " << f_name << " already exists in the list";
    fmap[f_name] = std::move(vm);

    if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) {
      entry_func = f_name;
    }
  }

  return runtime::StackVMModuleCreate(fmap, entry_func);
}

TVM_REGISTER_GLOBAL("target.build.stackvm").set_body_typed(BuildStackVM);
}  // namespace codegen
}  // namespace tvm