1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file codegen_stackvm.cc |
22 | */ |
23 | #include "codegen_stackvm.h" |
24 | |
25 | #include <tvm/ir/module.h> |
26 | #include <tvm/runtime/registry.h> |
27 | #include <tvm/tir/builtin.h> |
28 | #include <tvm/tir/function.h> |
29 | #include <tvm/tir/op.h> |
30 | |
31 | #include <limits> |
32 | #include <utility> |
33 | |
34 | #include "../../runtime/stackvm/stackvm_module.h" |
35 | |
36 | namespace tvm { |
37 | namespace codegen { |
38 | |
39 | using namespace tir; |
40 | |
41 | // map struct field kind to runtime variants |
42 | // We keep two separate enums to ensure runtime/compiler isolation. |
43 | StackVM::StructFieldKind MapFieldKind(int64_t kind) { |
44 | auto val = static_cast<builtin::TVMStructFieldKind>(kind); |
45 | switch (val) { |
46 | case builtin::kArrData: |
47 | return StackVM::kArrData; |
48 | case builtin::kArrShape: |
49 | return StackVM::kArrShape; |
50 | case builtin::kArrAddr: |
51 | return StackVM::kArrAddr; |
52 | case builtin::kArrStrides: |
53 | return StackVM::kArrStrides; |
54 | case builtin::kArrNDim: |
55 | return StackVM::kArrNDim; |
56 | case builtin::kArrTypeCode: |
57 | return StackVM::kArrTypeCode; |
58 | case builtin::kArrTypeBits: |
59 | return StackVM::kArrTypeBits; |
60 | case builtin::kArrTypeLanes: |
61 | return StackVM::kArrTypeLanes; |
62 | case builtin::kArrByteOffset: |
63 | return StackVM::kArrByteOffset; |
64 | case builtin::kArrDeviceId: |
65 | return StackVM::kArrDeviceId; |
66 | case builtin::kArrDeviceType: |
67 | return StackVM::kArrDeviceType; |
68 | case builtin::kTVMValueContent: |
69 | return StackVM::kTVMValueContent; |
70 | default: |
71 | LOG(FATAL) << "Do not know how to map field " << kind; |
72 | } |
73 | return StackVM::kArrData; |
74 | } |
75 | |
76 | StackVM CodeGenStackVM::Compile(const PrimFunc& f) { |
77 | ICHECK_EQ(f->buffer_map.size(), 0U) |
78 | << "Cannot codegen function with buffer_map, please lower them first" ; |
79 | for (size_t i = 0; i < f->params.size(); ++i) { |
80 | Var v = f->params[i]; |
81 | int vid = AllocVarID(v.get()); |
82 | ICHECK_EQ(static_cast<size_t>(vid), i); |
83 | } |
84 | this->Push(f->body); |
85 | vm_.InitCache(); |
86 | return std::move(vm_); |
87 | } |
88 | |
// Compile one statement into the code stream. In debug mode, append a
// stack-pointer assertion after every statement to catch stack imbalance.
void CodeGenStackVM::Push(const Stmt& n) {
  VisitStmt(n);
  if (debug_) {
    this->PushOp(StackVM::ASSERT_SP, 0);
  }
}
95 | |
// Append a single instruction (no operand) to the code stream.
void CodeGenStackVM::PushOp(StackVM::OpCode opcode) {
  StackVM::Code code;
  code.op_code = opcode;
  vm_.code.push_back(code);
}
101 | |
// Back-patch the operand slot at `operand_index` once its value is known
// (used to fix up forward/backward jump offsets). The slot holds a 32-bit
// int, so the operand must fit in that range.
void CodeGenStackVM::SetOperand(int64_t operand_index, int64_t operand) {
  ICHECK(operand >= std::numeric_limits<int>::min() && operand <= std::numeric_limits<int>::max());
  vm_.code.at(operand_index).v_int = static_cast<int>(operand);
}
106 | |
// Append an instruction followed by one int operand (two code slots).
// Returns the index of the OPERAND slot (pc + 1), so callers can later
// rewrite it via SetOperand, e.g. to patch a jump target.
int64_t CodeGenStackVM::PushOp(StackVM::OpCode opcode, int operand) {
  int64_t pc = static_cast<int64_t>(vm_.code.size());
  StackVM::Code code;
  code.op_code = opcode;
  vm_.code.push_back(code);
  code.v_int = operand;  // reuse the union; only v_int is meaningful here
  vm_.code.push_back(code);
  return pc + 1;
}
116 | |
117 | int CodeGenStackVM::GetStrID(const std::string& key) { |
118 | auto it = str_idmap_.find(key); |
119 | if (it != str_idmap_.end()) return it->second; |
120 | int sid = static_cast<int>(vm_.str_data.size()); |
121 | vm_.str_data.push_back(key); |
122 | str_idmap_[key] = sid; |
123 | return sid; |
124 | } |
125 | |
126 | int CodeGenStackVM::AllocVarID(const VarNode* v) { |
127 | ICHECK(!var_idmap_.count(v)); |
128 | int vid = static_cast<int>(vm_.heap_size); |
129 | ICHECK_EQ(vm_.heap_size, var_idmap_.size()); |
130 | vm_.heap_id_name.push_back(v->name_hint); |
131 | ++vm_.heap_size; |
132 | var_idmap_[v] = vid; |
133 | return vid; |
134 | } |
135 | |
// Look up the heap slot previously assigned to `v` by AllocVarID; it is a
// codegen error to reference a variable that was never allocated.
int CodeGenStackVM::GetVarID(const VarNode* v) const {
  auto it = var_idmap_.find(v);
  ICHECK(it != var_idmap_.end()) << "Find undefined Variable " << v->name_hint;
  return it->second;
}
141 | |
// LoadNode is removed from lowered TIR; only BufferLoad is accepted here.
void CodeGenStackVM::VisitExpr_(const LoadNode* op) {
  LOG(FATAL) << "Unexpected use of deprecated LoadNode. Please use BufferLoadNode instead.";
}
145 | |
// Emit a scalar load: push the buffer base pointer, then load either via an
// immediate operand (constant index) or via a runtime-computed byte offset.
void CodeGenStackVM::VisitExpr_(const BufferLoadNode* op) {
  ICHECK_EQ(op->indices.size(), 1) << "StackVM expects flat 1-d buffers. "
                                   << "Has StorageFlatten (TE-based schedules) or "
                                   << "FlattenBuffer (TIR-based schedules) been run?";
  auto index = op->indices[0];

  this->Push(op->buffer->data);  // base pointer on the stack
  StackVM::OpCode code = StackVM::GetLoad(op->dtype);
  if (const IntImmNode* int_index = index.as<IntImmNode>()) {
    // Constant index: encode it directly as the instruction operand.
    this->PushOp(code, int_index->value);
  } else {
    // Dynamic index: scale to bytes, fold into the address, load at offset 0.
    this->Push(index);
    this->PushOp(StackVM::PUSH_I64, op->dtype.element_of().bytes());
    this->PushOp(StackVM::MUL_I64);
    this->PushOp(StackVM::ADDR_ADD);
    this->PushOp(code, 0);
  }
}
164 | |
// StoreNode is removed from lowered TIR; only BufferStore is accepted here.
void CodeGenStackVM::VisitStmt_(const StoreNode* op) {
  LOG(FATAL) << "Unexpected use of deprecated StoreNode. Please use BufferStoreNode instead.";
}
168 | |
// Emit a scalar store, mirroring the BufferLoad lowering: address first,
// then the value, then the store opcode.
void CodeGenStackVM::VisitStmt_(const BufferStoreNode* op) {
  ICHECK_EQ(op->indices.size(), 1) << "StackVM expects flat 1-d buffers. "
                                   << "Has StorageFlatten (TE-based schedules) or "
                                   << "FlattenBuffer (TIR-based schedules) been run?";
  auto index = op->indices[0];

  this->Push(op->buffer->data);  // base pointer on the stack
  StackVM::OpCode code = StackVM::GetStore(op->value.dtype());
  if (const IntImmNode* int_index = index.as<IntImmNode>()) {
    // Constant index: push the value, encode index as the operand.
    this->Push(op->value);
    this->PushOp(code, int_index->value);
  } else {
    // Dynamic index: compute the byte address, then push the value and
    // store at offset 0.
    this->Push(index);
    this->PushOp(StackVM::PUSH_I64, op->value.dtype().element_of().bytes());
    this->PushOp(StackVM::MUL_I64);
    this->PushOp(StackVM::ADDR_ADD);
    this->Push(op->value);
    this->PushOp(code, 0);
  }
}
189 | |
// StackVM has no runtime allocation for TIR Allocate nodes; such buffers
// must be lowered away (e.g. to workspace calls) before codegen.
void CodeGenStackVM::VisitStmt_(const AllocateNode* op) {
  LOG(FATAL) << "Dynamic allocation not supported";
}
193 | |
// Lower TIR intrinsic calls to StackVM instruction sequences. Anything not
// recognized here is a codegen error.
void CodeGenStackVM::VisitExpr_(const CallNode* op) {
  if (op->op.same_as(builtin::address_of())) {
    // address_of(buffer[index]): base pointer + index * element_bytes.
    const BufferLoadNode* load = op->args[0].as<BufferLoadNode>();
    ICHECK(op->args.size() == 1 && load);
    ICHECK_EQ(load->indices.size(), 1) << "CodeGenStackVM only supports flat memory allocations.";

    this->PushOp(StackVM::LOAD_HEAP, GetVarID(load->buffer->data.get()));
    this->Push(load->indices[0]);
    this->PushOp(StackVM::PUSH_I64, load->dtype.element_of().bytes());
    this->PushOp(StackVM::MUL_I64);
    this->PushOp(StackVM::ADDR_ADD);
  } else if (op->op.same_as(builtin::reinterpret())) {
    // Values share one union representation in the VM, so reinterpret is a
    // no-op: just evaluate the argument.
    this->Push(op->args[0]);
  } else if (op->op.same_as(builtin::tvm_struct_get())) {
    // tvm_struct_get(ptr, index, field_kind) -> TVM_STRUCT_GET index kind
    ICHECK_EQ(op->args.size(), 3U);
    int kind = op->args[2].as<IntImmNode>()->value;
    this->Push(op->args[0]);
    const IntImmNode* index = op->args[1].as<IntImmNode>();
    ICHECK(index != nullptr);
    StackVM::Code code;
    code.op_code = StackVM::TVM_STRUCT_GET;
    vm_.code.push_back(code);
    code.v_int = index->value;
    vm_.code.push_back(code);
    code.v_int = MapFieldKind(kind);  // translate to the VM-side enum
    vm_.code.push_back(code);
  } else if (op->op.same_as(builtin::tvm_call_packed_lowered())) {
    // Lowered packed call: args[0] is the callee name, args[1]/args[2] are
    // pushed for the call, args[3]/args[4] give the argument range
    // [begin, end).
    ICHECK_GE(op->args.size(), 5U);
    const StringImmNode* s = op->args[0].as<StringImmNode>();
    ICHECK(s != nullptr) << "tvm_call_global expect first argument as function name";
    this->Push(op->args[1]);
    this->Push(op->args[2]);
    int begin = op->args[3].as<IntImmNode>()->value;
    int end = op->args[4].as<IntImmNode>()->value;
    // find the function id (allocating a new one on first use).
    const std::string& func_name = s->value;
    auto it = extern_fun_idmap_.find(func_name);
    int fid;
    if (it != extern_fun_idmap_.end()) {
      fid = it->second;
    } else {
      fid = static_cast<int>(vm_.extern_func_name.size());
      vm_.extern_func_name.push_back(func_name);
      extern_fun_idmap_[func_name] = fid;
    }
    // Emit CALL_PACKED_LOWERED with operands (fid, begin, end).
    StackVM::Code code;
    code.op_code = StackVM::CALL_PACKED_LOWERED;
    vm_.code.push_back(code);
    code.v_int = fid;
    vm_.code.push_back(code);
    code.v_int = begin;
    vm_.code.push_back(code);
    code.v_int = end;
    vm_.code.push_back(code);
  } else if (op->op.same_as(builtin::tvm_stack_alloca())) {
    // Scratch allocation on the VM stack; sizes are rounded up to whole
    // 8-byte (TVMValue-sized) units.
    ICHECK_EQ(op->args.size(), 2U);
    const std::string& type = op->args[0].as<StringImmNode>()->value;
    const IntImmNode* num = op->args[1].as<IntImmNode>();
    ICHECK(num != nullptr);
    static_assert(alignof(TVMValue) % alignof(DLTensor) == 0, "invariant");
    // static_assert(alignof(TVMValue) % alignof(tvm_index_t) == 0, "invariant");
    size_t unit = sizeof(TVMValue);
    size_t size = 0;
    if (type == "shape") {
      size = (num->value * sizeof(tvm_index_t) + unit - 1) / unit;
    } else if (type == "arg_value") {
      size = (num->value * sizeof(TVMValue) + unit - 1) / unit;
    } else if (type == "arg_tcode") {
      size = (num->value * sizeof(int) + unit - 1) / unit;
    } else if (type == "array") {
      size = (num->value * sizeof(DLTensor) + unit - 1) / unit;
    } else {
      LOG(FATAL) << "Unknown stack alloca type " << type;
    }
    // add stack size to be safe.
    vm_.stack_size += size;
    this->PushOp(StackVM::TVM_STACK_ALLOCA_BY_8BYTE, static_cast<int>(size));
  } else if (op->op.same_as(backend_alloc_workspace_op_)) {
    // Push all five arguments, then the device-allocation opcode.
    ICHECK_EQ(op->args.size(), 5U);
    this->Push(op->args[0]);
    this->Push(op->args[1]);
    this->Push(op->args[2]);
    this->Push(op->args[3]);
    this->Push(op->args[4]);
    this->PushOp(StackVM::TVM_DEVICE_ALLOCA);
  } else if (op->op.same_as(backend_free_workspace_op_)) {
    // Push all three arguments, then the device-free opcode.
    ICHECK_EQ(op->args.size(), 3U);
    this->Push(op->args[0]);
    this->Push(op->args[1]);
    this->Push(op->args[2]);
    this->PushOp(StackVM::TVM_DEVICE_FREE);
  } else if (op->op.same_as(builtin::tvm_throw_last_error())) {
    this->PushOp(StackVM::TVM_THROW_LAST_ERROR);
  } else if (op->op.same_as(builtin::isnullptr())) {
    // isnullptr(x): compare the handle against 0.
    ICHECK_EQ(op->args.size(), 1U);
    this->Push(op->args[0]);
    this->PushOp(StackVM::PUSH_I64, 0);
    this->PushOp(StackVM::EQ_HANDLE);
  } else {
    LOG(FATAL) << "unknown function call " << op->op;
  }
}
297 | |
298 | void CodeGenStackVM::PushBinary(StackVM::OpCode op_int64, const PrimExpr& a, const PrimExpr& b) { |
299 | this->Push(a); |
300 | this->Push(b); |
301 | DataType t = a.dtype(); |
302 | if (t.is_int()) { |
303 | this->PushOp(op_int64); |
304 | } else if (t.is_uint()) { |
305 | this->PushOp(op_int64); |
306 | } else { |
307 | this->PushOp(StackVM::CodeI64ToF64(op_int64)); |
308 | } |
309 | } |
310 | |
311 | void CodeGenStackVM::PushCast(DataType dst, DataType src) { |
312 | if (dst.is_int()) { |
313 | if (src.is_int() || src.is_uint()) return; |
314 | } else if (dst.is_uint()) { |
315 | if (src.is_int() || src.is_uint()) return; |
316 | } else if (dst.is_float()) { |
317 | if (src.is_float()) return; |
318 | } |
319 | } |
320 | |
// Strings are interned in the VM string table; push the table id as an
// integer constant.
void CodeGenStackVM::VisitExpr_(const StringImmNode* op) {
  int sid = this->GetStrID(op->value);
  this->PushOp(StackVM::PUSH_I64, sid);
}
325 | |
326 | void CodeGenStackVM::VisitExpr_(const IntImmNode* op) { |
327 | ICHECK(op->value >= std::numeric_limits<int>::min() && |
328 | op->value <= std::numeric_limits<int>::max()) |
329 | << "Int constant exceed bound" ; |
330 | this->PushOp(StackVM::PUSH_I64, static_cast<int>(op->value)); |
331 | } |
332 | |
// There is no float-immediate opcode in this codegen path.
void CodeGenStackVM::VisitExpr_(const FloatImmNode* op) {
  LOG(FATAL) << "Float Imm is not supported";
}
336 | |
// Read a variable: load it from the heap slot assigned by AllocVarID.
void CodeGenStackVM::VisitExpr_(const VarNode* op) {
  int vid = this->GetVarID(op);
  this->PushOp(StackVM::LOAD_HEAP, vid);
}
341 | |
// Evaluate the operand, then emit any conversion PushCast deems necessary.
void CodeGenStackVM::VisitExpr_(const CastNode* op) {
  this->Push(op->value);
  PushCast(op->dtype, op->value.dtype());
}
346 | |
// Arithmetic ops lower through PushBinary (int64 opcode or f64 variant).
void CodeGenStackVM::VisitExpr_(const AddNode* op) { PushBinary(StackVM::ADD_I64, op->a, op->b); }
348 | |
void CodeGenStackVM::VisitExpr_(const SubNode* op) { PushBinary(StackVM::SUB_I64, op->a, op->b); }
350 | |
void CodeGenStackVM::VisitExpr_(const MulNode* op) { PushBinary(StackVM::MUL_I64, op->a, op->b); }
352 | |
void CodeGenStackVM::VisitExpr_(const DivNode* op) { PushBinary(StackVM::DIV_I64, op->a, op->b); }
354 | |
void CodeGenStackVM::VisitExpr_(const ModNode* op) { PushBinary(StackVM::MOD_I64, op->a, op->b); }
356 | |
// min(a, b): duplicate both operands via relative PUSH_VALUE, compare
// a < b, and SELECT the winner (a when the comparison holds, else b).
void CodeGenStackVM::VisitExpr_(const MinNode* op) {
  this->Push(op->a);
  this->Push(op->b);
  this->PushOp(StackVM::PUSH_VALUE, -1);  // re-push a
  this->PushOp(StackVM::PUSH_VALUE, -1);  // re-push b
  this->PushOp(StackVM::LT_I64);          // a < b
  this->PushOp(StackVM::SELECT);
}
365 | |
// max(a, b): duplicate the operands in (b, a) order so the comparison
// computes b < a; SELECT then yields a when b < a, else b.
void CodeGenStackVM::VisitExpr_(const MaxNode* op) {
  this->Push(op->a);
  this->Push(op->b);
  this->PushOp(StackVM::PUSH_VALUE, 0);   // re-push b (stack top)
  this->PushOp(StackVM::PUSH_VALUE, -2);  // re-push a
  this->PushOp(StackVM::LT_I64);          // b < a
  this->PushOp(StackVM::SELECT);
}
374 | |
// Comparisons lower through PushBinary (int64 opcode or f64 variant).
void CodeGenStackVM::VisitExpr_(const EQNode* op) { PushBinary(StackVM::EQ_I64, op->a, op->b); }
376 | |
void CodeGenStackVM::VisitExpr_(const LENode* op) { PushBinary(StackVM::LE_I64, op->a, op->b); }
378 | |
// a != b is encoded as !(a == b).
void CodeGenStackVM::VisitExpr_(const NENode* op) {
  PushBinary(StackVM::EQ_I64, op->a, op->b);
  this->PushOp(StackVM::NOT);
}
383 | |
void CodeGenStackVM::VisitExpr_(const LTNode* op) { PushBinary(StackVM::LT_I64, op->a, op->b); }
385 | |
// a >= b is encoded as !(a < b).
void CodeGenStackVM::VisitExpr_(const GENode* op) {
  PushBinary(StackVM::LT_I64, op->a, op->b);
  this->PushOp(StackVM::NOT);
}
390 | |
// a > b is encoded as !(a <= b).
void CodeGenStackVM::VisitExpr_(const GTNode* op) {
  PushBinary(StackVM::LE_I64, op->a, op->b);
  this->PushOp(StackVM::NOT);
}
395 | |
// Short-circuit AND: if `a` is false, jump past `b`, leaving `a` (false) on
// the stack as the result; otherwise pop `a` and evaluate `b`.
void CodeGenStackVM::VisitExpr_(const AndNode* op) {
  this->Push(op->a);
  int64_t pc_jump = this->GetPC();
  int64_t opr_index = this->PushOp(StackVM::RJUMP_IF_FALSE, 0);  // patched below
  this->PushOp(StackVM::POP);  // discard a; b provides the result
  this->Push(op->b);
  int64_t diff = this->GetPC() - pc_jump;
  this->SetOperand(opr_index, diff);
}
405 | |
406 | void CodeGenStackVM::VisitExpr_(const OrNode* op) { |
407 | this->Push(op->a); |
408 | int64_t pc_jump = this->GetPC(); |
409 | int64_t opr_index = this->PushOp(StackVM::RJUMP_IF_TRUE, 0); |
410 | this->Push(op->b); |
411 | int64_t diff = this->GetPC() - pc_jump; |
412 | this->SetOperand(opr_index, diff); |
413 | } |
414 | |
// Logical negation: evaluate the operand, then invert it.
void CodeGenStackVM::VisitExpr_(const NotNode* op) {
  this->Push(op->a);
  this->PushOp(StackVM::NOT);
}
419 | |
// Lower a serial, zero-based for loop. Layout of the emitted code:
//         PUSH_I64 0                 ; counter
//   head: STORE_HEAP vid / LOAD_HEAP vid / <extent> / LT_I64
//         RJUMP_IF_FALSE end         ; exit when counter >= extent
//         POP                        ; drop (true) condition
//         <body> / LOAD_HEAP vid / PUSH_I64 1 / ADD_I64
//         RJUMP head
//   end:  POP                        ; drop (false) condition
// Both jump operands are back-patched once their targets are known.
void CodeGenStackVM::VisitStmt_(const ForNode* op) {
  ICHECK(is_zero(op->min));  // loops are expected to be normalized to start at 0
  int vid = this->AllocVarID(op->loop_var.get());
  this->PushOp(StackVM::PUSH_I64, 0);
  int64_t loop_head = this->GetPC();
  this->PushOp(StackVM::STORE_HEAP, vid);
  this->PushOp(StackVM::LOAD_HEAP, vid);
  this->Push(op->extent);
  this->PushOp(StackVM::LT_I64);
  int64_t label_fjump = this->GetPC();
  int64_t foward_jump = this->PushOp(StackVM::RJUMP_IF_FALSE, 0);  // patched below
  this->PushOp(StackVM::POP);
  this->Push(op->body);
  this->PushOp(StackVM::LOAD_HEAP, vid);
  this->PushOp(StackVM::PUSH_I64, 1);
  this->PushOp(StackVM::ADD_I64);
  int64_t label_bjump = this->GetPC();
  int64_t backward_jump = this->PushOp(StackVM::RJUMP, 0);  // patched below
  int64_t loop_end = this->GetPC();
  this->PushOp(StackVM::POP);
  this->SetOperand(foward_jump, loop_end - label_fjump);
  this->SetOperand(backward_jump, loop_head - label_bjump);
}
443 | |
444 | void CodeGenStackVM::VisitStmt_(const SeqStmtNode* op) { |
445 | for (Stmt stmt : op->seq) { |
446 | this->Push(stmt); |
447 | } |
448 | } |
449 | |
// Compile an Evaluate node: constant expressions are dead code,
// tvm_struct_set gets a dedicated opcode, and any other expression is
// evaluated and its result discarded.
void CodeGenStackVM::VisitStmt_(const EvaluateNode* ev) {
  if (is_const_int(ev->value)) return;  // evaluating a constant is a no-op
  const CallNode* op = ev->value.as<CallNode>();
  if (op && op->op.same_as(builtin::tvm_struct_set())) {
    // tvm_struct_set(ptr, index, field_kind, value):
    // push ptr and value, then TVM_STRUCT_SET with (index, kind) operands.
    ICHECK_EQ(op->args.size(), 4U);
    this->Push(op->args[0]);
    this->Push(op->args[3]);
    const IntImmNode* index = op->args[1].as<IntImmNode>();
    ICHECK(index != nullptr);
    StackVM::Code code;
    code.op_code = StackVM::TVM_STRUCT_SET;
    vm_.code.push_back(code);
    code.v_int = index->value;
    vm_.code.push_back(code);
    code.v_int = MapFieldKind(op->args[2].as<IntImmNode>()->value);
    vm_.code.push_back(code);
  } else {
    this->Push(ev->value);
    this->PushOp(StackVM::POP);  // discard the expression's result
  }
}
471 | |
// Lower if/then/else with relative jumps. The branch test leaves the
// condition value on the stack, so each path must POP it exactly once.
void CodeGenStackVM::VisitStmt_(const IfThenElseNode* op) {
  this->Push(op->condition);
  int64_t label_ejump = this->GetPC();
  int64_t else_jump = this->PushOp(StackVM::RJUMP_IF_FALSE, 0);  // patched below
  this->PushOp(StackVM::POP);  // drop condition on the then-path
  this->Push(op->then_case);
  if (op->else_case) {
    int64_t label_then_jump = this->GetPC();
    int64_t then_jump = this->PushOp(StackVM::RJUMP, 0);  // skip over the else block
    int64_t else_begin = this->GetPC();
    this->SetOperand(else_jump, else_begin - label_ejump);
    this->PushOp(StackVM::POP);  // drop condition on the else-path
    this->Push(op->else_case.value());
    int64_t if_end = this->GetPC();
    this->SetOperand(then_jump, if_end - label_then_jump);
  } else {
    int64_t if_end = this->GetPC();
    this->SetOperand(else_jump, if_end - label_ejump);
    this->PushOp(StackVM::POP);  // drop condition when the then-case is skipped
  }
}
493 | |
// let var = value; body -- evaluate the value, store it into a freshly
// allocated heap slot, then compile the body with the binding visible.
void CodeGenStackVM::VisitStmt_(const LetStmtNode* op) {
  this->Push(op->value);
  int64_t vid = this->AllocVarID(op->var.get());
  this->PushOp(StackVM::STORE_HEAP, static_cast<int>(vid));
  this->Push(op->body);
}
500 | |
// Vector expressions are not supported by this scalar VM.
void CodeGenStackVM::VisitExpr_(const RampNode* op) { LOG(FATAL) << "Ramp is not supported"; }
502 | |
// Vector expressions are not supported by this scalar VM.
void CodeGenStackVM::VisitExpr_(const BroadcastNode* op) {
  LOG(FATAL) << "Broadcast is not supported";
}
506 | |
// select(cond, t, f): both branches are evaluated eagerly (unlike And/Or,
// there is no short-circuiting), then SELECT picks one by the condition.
void CodeGenStackVM::VisitExpr_(const SelectNode* op) {
  this->Push(op->true_value);
  this->Push(op->false_value);
  this->Push(op->condition);
  this->PushOp(StackVM::SELECT);
}
513 | |
// Emit an ASSERT only when the message is a string literal (its table id
// becomes the operand); asserts with non-string messages are skipped and
// only the body is compiled.
void CodeGenStackVM::VisitStmt_(const AssertStmtNode* op) {
  if (const auto* str = op->message.as<StringImmNode>()) {
    int sid = this->GetStrID(str->value);
    this->Push(op->condition);
    this->PushOp(StackVM::ASSERT, sid);
  }
  this->Push(op->body);
}
522 | |
// Attribute annotations carry no codegen semantics here; compile the body.
void CodeGenStackVM::VisitStmt_(const AttrStmtNode* op) { this->Push(op->body); }
524 | |
// Expression-level let: same lowering as LetStmt — evaluate, bind to a
// fresh heap slot, then compile the body expression.
void CodeGenStackVM::VisitExpr_(const LetNode* op) {
  this->Push(op->value);
  int64_t vid = this->AllocVarID(op->var.get());
  this->PushOp(StackVM::STORE_HEAP, static_cast<int>(vid));
  this->Push(op->body);
}
531 | |
532 | runtime::Module BuildStackVM(IRModule mod, Target target) { |
533 | std::unordered_map<std::string, StackVM> fmap; |
534 | std::string entry_func; |
535 | |
536 | for (auto kv : mod->functions) { |
537 | ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenStackVM: Can only take PrimFunc" ; |
538 | auto f = Downcast<PrimFunc>(kv.second); |
539 | auto global_symbol = f->GetAttr<String>(tvm::attr::kGlobalSymbol); |
540 | ICHECK(global_symbol.defined()) |
541 | << "CodeGenStackVM: Expect PrimFunc to have the global_symbol attribute" ; |
542 | std::string f_name = global_symbol.value(); |
543 | StackVM vm = codegen::CodeGenStackVM().Compile(f); |
544 | ICHECK(!fmap.count(f_name)) << "Function name " << f_name << "already exist in list" ; |
545 | fmap[f_name] = std::move(vm); |
546 | |
547 | if (f->HasNonzeroAttr(tir::attr::kIsEntryFunc)) { |
548 | entry_func = f_name; |
549 | } |
550 | } |
551 | |
552 | return runtime::StackVMModuleCreate(fmap, entry_func); |
553 | } |
554 | |
555 | TVM_REGISTER_GLOBAL("target.build.stackvm" ).set_body_typed(BuildStackVM); |
556 | } // namespace codegen |
557 | } // namespace tvm |
558 | |