1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * Implementation stack VM.
22 * \file stackvm.cc
23 */
24#include "stackvm.h"
25
26#include <dmlc/thread_local.h>
27#include <tvm/runtime/c_backend_api.h>
28
29#include <algorithm>
30
31namespace tvm {
32namespace runtime {
33
34typedef dmlc::ThreadLocalStore<StackVM::State> StackVMStateStore;
35
36StackVM::State* StackVM::ThreadLocalState() { return StackVMStateStore::Get(); }
37
// Binary arithmetic op: computes stack[sp-1] = stack[sp-1] OP stack[sp] on the
// given TVMValue FIELD, pops one operand and advances one instruction slot.
#define STACK_VM_BINOP(OP, FIELD)                                 \
  {                                                               \
    stack[sp - 1].FIELD = stack[sp - 1].FIELD OP stack[sp].FIELD; \
    sp -= 1;                                                      \
    pc += 1;                                                      \
  }

// Comparison op: like STACK_VM_BINOP, but the boolean result is always stored
// into v_int64 regardless of which FIELD was compared.
#define STACK_VM_CMPOP(OP, FIELD)                                 \
  {                                                               \
    stack[sp - 1].v_int64 = stack[sp - 1].FIELD OP stack[sp].FIELD; \
    sp -= 1;                                                      \
    pc += 1;                                                      \
  }

// Array load: treats the top of stack as a SRC_TYPE* base pointer and replaces
// it with element [index] cast to DST_TYPE. The index comes from the operand
// slot at code[pc + 1]; the instruction occupies two code slots.
#define STACK_VM_LOAD(FIELD, DST_TYPE, SRC_TYPE)                                              \
  {                                                                                           \
    int index = code[pc + 1].v_int;                                                           \
    stack[sp] FIELD = static_cast<DST_TYPE>(static_cast<SRC_TYPE*>(stack[sp].v_handle)[index]); \
    pc += 2;                                                                                  \
  }

// Array store: stack[sp-1] holds the DST_TYPE* base pointer and stack[sp] the
// value to write at element [index]; both are popped. Two code slots.
#define STACK_VM_STORE(FIELD, DST_TYPE)                  \
  {                                                      \
    int index = code[pc + 1].v_int;                      \
    static_cast<DST_TYPE*>(stack[sp - 1].v_handle)[index] = \
        static_cast<DST_TYPE>(stack[sp] FIELD);          \
    sp -= 2;                                             \
    pc += 2;                                             \
  }

// Disassembly helper: print a zero-operand opcode (one code slot).
#define STACK_VM_PRINT_CODE0(CODE)                 \
  case CODE: {                                     \
    os << "[" << pc << "]\t" << #CODE << std::endl; \
    return pc + 1;                                 \
  }

// Disassembly helper: print a one-operand opcode; the "[pc+1]" line marks the
// operand slot so addresses in the dump stay dense.
#define STACK_VM_PRINT_CODE1(CODE)                                     \
  case CODE: {                                                         \
    os << "[" << pc << "]\t" << #CODE << " " << code[pc + 1].v_int << "\n" \
       << "[" << pc + 1 << "]" << std::endl;                           \
    return pc + 2;                                                     \
  }

// Disassembly helper: print a two-operand opcode (three code slots).
#define STACK_VM_PRINT_CODE2(CODE)                                                           \
  case CODE: {                                                                               \
    os << "[" << pc << "]\t" << #CODE << " " << code[pc + 1].v_int << " " << code[pc + 2].v_int \
       << "\n"                                                                               \
       << "[" << pc + 1 << "]" << std::endl                                                  \
       << "[" << pc + 2 << "]" << std::endl;                                                 \
    return pc + 3;                                                                           \
  }

// Disassembly helper: heap access opcodes additionally print the symbolic name
// of the heap slot from heap_id_name.
#define STACK_VM_PRINT_HEAP_ACCESS(CODE)                                   \
  case CODE: {                                                             \
    os << "[" << pc << "]\t" << #CODE << " " << code[pc + 1].v_int << " "  \
       << heap_id_name[code[pc + 1].v_int] << "\n"                         \
       << "[" << pc + 1 << "]" << std::endl;                               \
    return pc + 2;                                                         \
  }

// Disassembly helper: jump opcodes print both the relative offset and the
// resolved absolute target pc.
#define STACK_VM_PRINT_JUMP(CODE)                                                \
  case CODE: {                                                                   \
    os << "[" << pc << "]\t" << #CODE << " rel=" << code[pc + 1].v_int << " to " \
       << pc + code[pc + 1].v_int << '\n'                                        \
       << "[" << pc + 1 << "]" << std::endl;                                     \
    return pc + 2;                                                               \
  }
105
/*!
 * \brief Disassemble the single instruction at pc onto os.
 * \param os The output stream to print to.
 * \param pc The program counter of the instruction to print.
 * \return The pc of the following instruction (pc + 1 + number of operand slots).
 */
int64_t StackVM::PrintCode(std::ostream& os, int64_t pc) const {
  switch (code[pc].op_code) {
    // int
    STACK_VM_PRINT_CODE0(ADD_I64);
    STACK_VM_PRINT_CODE0(SUB_I64);
    STACK_VM_PRINT_CODE0(MUL_I64);
    STACK_VM_PRINT_CODE0(MOD_I64);
    STACK_VM_PRINT_CODE0(DIV_I64);
    STACK_VM_PRINT_CODE0(EQ_I64);
    STACK_VM_PRINT_CODE0(LT_I64);
    STACK_VM_PRINT_CODE0(LE_I64);
    // floats
    STACK_VM_PRINT_CODE0(ADD_F64);
    STACK_VM_PRINT_CODE0(SUB_F64);
    STACK_VM_PRINT_CODE0(MUL_F64);
    STACK_VM_PRINT_CODE0(DIV_F64);
    STACK_VM_PRINT_CODE0(EQ_F64);
    STACK_VM_PRINT_CODE0(LT_F64);
    STACK_VM_PRINT_CODE0(LE_F64);
    // handle.
    STACK_VM_PRINT_CODE0(EQ_HANDLE);
    // addressing load
    STACK_VM_PRINT_CODE1(ARRAY_LOAD_UINT32);
    STACK_VM_PRINT_CODE1(ARRAY_LOAD_INT32);
    STACK_VM_PRINT_CODE1(ARRAY_LOAD_INT64);
    STACK_VM_PRINT_CODE1(ARRAY_LOAD_FP64);
    STACK_VM_PRINT_CODE1(ARRAY_LOAD_HANDLE);
    STACK_VM_PRINT_CODE1(ARRAY_LOAD_TVMVALUE);
    STACK_VM_PRINT_CODE1(ARRAY_STORE_UINT32);
    STACK_VM_PRINT_CODE1(ARRAY_STORE_INT32);
    STACK_VM_PRINT_CODE1(ARRAY_STORE_INT64);
    STACK_VM_PRINT_CODE1(ARRAY_STORE_FP64);
    STACK_VM_PRINT_CODE1(ARRAY_STORE_HANDLE);
    STACK_VM_PRINT_CODE1(ARRAY_STORE_TVMVALUE);
    STACK_VM_PRINT_CODE0(NOT);
    STACK_VM_PRINT_CODE0(ADDR_ADD);
    // stack ops
    STACK_VM_PRINT_CODE1(PUSH_I64);
    STACK_VM_PRINT_CODE1(PUSH_VALUE);
    STACK_VM_PRINT_CODE0(POP);
    STACK_VM_PRINT_CODE0(SELECT);
    STACK_VM_PRINT_HEAP_ACCESS(STORE_HEAP);
    STACK_VM_PRINT_HEAP_ACCESS(LOAD_HEAP);
    STACK_VM_PRINT_CODE1(ASSERT);
    STACK_VM_PRINT_JUMP(RJUMP_IF_TRUE);
    STACK_VM_PRINT_JUMP(RJUMP_IF_FALSE);
    STACK_VM_PRINT_JUMP(RJUMP);
    STACK_VM_PRINT_CODE1(ASSERT_SP);
    // Intrinsics
    STACK_VM_PRINT_CODE2(TVM_STRUCT_GET);
    STACK_VM_PRINT_CODE2(TVM_STRUCT_SET);
    // Allocate data by 8 bytes.
    STACK_VM_PRINT_CODE1(TVM_STACK_ALLOCA_BY_8BYTE);
    STACK_VM_PRINT_CODE0(TVM_DEVICE_ALLOCA);
    STACK_VM_PRINT_CODE0(TVM_DEVICE_FREE);
    STACK_VM_PRINT_CODE0(TVM_THROW_LAST_ERROR);
    // packed function: opcode plus three operand slots (fid, begin, end).
    case CALL_PACKED_LOWERED: {
      int call_fid = code[pc + 1].v_int;
      int begin = code[pc + 2].v_int;
      int end = code[pc + 3].v_int;
      os << "[" << pc << "]\tCALL_PACKED_FUNC "
         << " fid=" << call_fid << " begin=" << begin << " end=" << end;
      os << '\n';
      // Mark the three operand slots so the dump covers every code address.
      for (int i = 0; i < 3; ++i) {
        os << "[" << pc + 1 + i << "]" << std::endl;
      }
      return pc + 4;
    }
  }
  // Every valid opcode returns from inside the switch; LOG(FATAL) does not
  // return, so there is no fall-through path here.
  LOG(FATAL) << "unknown op code " << code[pc].op_code;
}
178
179std::ostream& operator<<(std::ostream& os, const StackVM& vm) { // NOLINT(*)
180 int64_t pc = 0;
181 const int64_t code_size = static_cast<int64_t>(vm.code.size());
182 os << "Program dump: code-size=" << code_size << '\n' << "----------begin-----------------\n";
183 while (pc < code_size) {
184 pc = vm.PrintCode(os, pc);
185 }
186 os << "----------end--------------------\n";
187 return os;
188}
189
190void StackVM::Run(const runtime::TVMArgs& args, runtime::ModuleNode* mod_ctx) const {
191 StackVM::State* s = StackVM::ThreadLocalState();
192 if (s->heap.size() < heap_size) {
193 s->heap.resize(heap_size);
194 }
195 s->sp = 0;
196 s->pc = 0;
197 s->mod_ctx = mod_ctx;
198 s->heap[0].v_handle = (void*)args.values; // NOLINT(*)
199 s->heap[1].v_handle = (void*)args.type_codes; // NOLINT(*)
200 s->heap[2].v_int64 = args.num_args;
201 this->Run(s);
202}
203
204void StackVM::InitCache() {
205 extern_func_cache_.clear();
206 extern_func_cache_.resize(extern_func_name.size(), PackedFunc(nullptr));
207}
208
209void StackVM::Save(dmlc::Stream* strm) const {
210 // to be endian invariant.
211 std::vector<int32_t> code_copy(code.size());
212 std::transform(code.begin(), code.end(), code_copy.begin(), [](Code c) { return c.v_int; });
213 strm->Write(code_copy);
214 strm->Write(str_data);
215 strm->Write(extern_func_name);
216 strm->Write(heap_id_name);
217 strm->Write(heap_size);
218 strm->Write(stack_size);
219}
220
221bool StackVM::Load(dmlc::Stream* strm) {
222 // to be endian invariant.
223 std::vector<int32_t> code_copy;
224 if (!strm->Read(&code_copy)) return false;
225 code.resize(code_copy.size());
226 std::transform(code_copy.begin(), code_copy.end(), code.begin(), [](int v) {
227 Code code;
228 code.v_int = v;
229 return code;
230 });
231 if (!strm->Read(&str_data)) return false;
232 if (!strm->Read(&extern_func_name)) return false;
233 if (!strm->Read(&heap_id_name)) return false;
234 if (!strm->Read(&heap_size)) return false;
235 if (!strm->Read(&stack_size)) return false;
236 this->InitCache();
237 return true;
238}
239
240void StackVM::Run(State* s) const {
241 int64_t sp = s->sp;
242 int64_t pc = s->pc;
243 int64_t alloca_sp = s->sp;
244 std::vector<TVMValue>& stack = s->stack;
245 std::vector<TVMValue>& heap = s->heap;
246 if (stack.size() < stack_size) {
247 stack.resize(stack_size);
248 }
249 int64_t stack_cap = static_cast<int64_t>(stack_size - 4);
250 if (heap.size() < heap_size) {
251 heap.resize(heap_size);
252 }
253 const int64_t code_size = static_cast<int64_t>(code.size());
254 while (pc < code_size) {
255 switch (code[pc].op_code) {
256 case ADD_I64:
257 STACK_VM_BINOP(+, v_int64);
258 break;
259 case SUB_I64:
260 STACK_VM_BINOP(-, v_int64);
261 break;
262 case MUL_I64:
263 STACK_VM_BINOP(*, v_int64);
264 break;
265 case DIV_I64:
266 STACK_VM_BINOP(/, v_int64);
267 break;
268 case MOD_I64:
269 STACK_VM_BINOP(%, v_int64);
270 break;
271 case EQ_I64:
272 STACK_VM_CMPOP(==, v_int64);
273 break;
274 case LT_I64:
275 STACK_VM_CMPOP(<, v_int64);
276 break;
277 case LE_I64:
278 STACK_VM_CMPOP(<=, v_int64);
279 break;
280 case ADD_F64:
281 STACK_VM_BINOP(+, v_float64);
282 break;
283 case SUB_F64:
284 STACK_VM_BINOP(-, v_float64);
285 break;
286 case MUL_F64:
287 STACK_VM_BINOP(*, v_float64);
288 break;
289 case DIV_F64:
290 STACK_VM_BINOP(/, v_float64);
291 break;
292 case EQ_F64:
293 STACK_VM_CMPOP(==, v_float64);
294 break;
295 case LT_F64:
296 STACK_VM_CMPOP(<, v_float64);
297 break;
298 case LE_F64:
299 STACK_VM_CMPOP(<=, v_float64);
300 break;
301 case EQ_HANDLE:
302 STACK_VM_CMPOP(==, v_handle);
303 break;
304 // addressing
305 case ARRAY_LOAD_UINT32:
306 STACK_VM_LOAD(.v_int64, int64_t, uint32_t);
307 break;
308 case ARRAY_LOAD_INT32:
309 STACK_VM_LOAD(.v_int64, int64_t, int32_t);
310 break;
311 case ARRAY_LOAD_INT64:
312 STACK_VM_LOAD(.v_int64, int64_t, int64_t);
313 break;
314 case ARRAY_LOAD_FP64:
315 STACK_VM_LOAD(.v_float64, double, double);
316 break;
317 case ARRAY_LOAD_HANDLE:
318 STACK_VM_LOAD(.v_handle, void*, void*);
319 break;
320 case ARRAY_LOAD_TVMVALUE:
321 STACK_VM_LOAD(, TVMValue, TVMValue);
322 break;
323 // store
324 case ARRAY_STORE_UINT32:
325 STACK_VM_STORE(.v_int64, uint32_t);
326 break;
327 case ARRAY_STORE_INT32:
328 STACK_VM_STORE(.v_int64, int32_t);
329 break;
330 case ARRAY_STORE_INT64:
331 STACK_VM_STORE(.v_int64, int64_t);
332 break;
333 case ARRAY_STORE_FP64:
334 STACK_VM_STORE(.v_float64, double);
335 break;
336 case ARRAY_STORE_HANDLE:
337 STACK_VM_STORE(.v_handle, void*);
338 break;
339 case ARRAY_STORE_TVMVALUE:
340 STACK_VM_STORE(, TVMValue);
341 break;
342 // add
343 case ADDR_ADD: {
344 stack[sp - 1].v_handle = (char*)(stack[sp - 1].v_handle) + stack[sp].v_int64; // NOLINT(*)
345 sp = sp - 1;
346 pc = pc + 1;
347 break;
348 }
349 case NOT: {
350 stack[sp].v_int64 = !stack[sp].v_int64;
351 pc += 1;
352 break;
353 }
354 case PUSH_I64: {
355 stack[sp + 1].v_int64 = code[pc + 1].v_int;
356 sp += 1;
357 pc += 2;
358 break;
359 }
360 case PUSH_VALUE: {
361 int relpos = code[pc + 1].v_int;
362 ICHECK_LE(relpos, 0);
363 stack[sp + 1] = stack[sp + relpos];
364 sp += 1;
365 pc += 2;
366 break;
367 }
368 case POP: {
369 sp -= 1;
370 pc += 1;
371 break;
372 }
373 case SELECT: {
374 stack[sp - 2] = (stack[sp].v_int64 ? stack[sp - 2] : stack[sp - 1]);
375 sp -= 2;
376 pc += 1;
377 break;
378 }
379 case LOAD_HEAP: {
380 stack[sp + 1] = heap[code[pc + 1].v_int];
381 sp += 1;
382 pc += 2;
383 break;
384 }
385 case STORE_HEAP: {
386 heap[code[pc + 1].v_int] = stack[sp];
387 sp -= 1;
388 pc += 2;
389 break;
390 }
391 case ASSERT: {
392 ICHECK(stack[sp].v_int64) << str_data[code[pc + 1].v_int];
393 sp -= 1;
394 pc += 2;
395 break;
396 }
397 case RJUMP_IF_TRUE: {
398 if (stack[sp].v_int64) {
399 pc += code[pc + 1].v_int;
400 } else {
401 pc += 2;
402 }
403 break;
404 }
405 case RJUMP_IF_FALSE: {
406 if (!stack[sp].v_int64) {
407 pc += code[pc + 1].v_int;
408 } else {
409 pc += 2;
410 }
411 break;
412 }
413 case RJUMP: {
414 pc += code[pc + 1].v_int;
415 break;
416 }
417 case ASSERT_SP: {
418 int64_t expected = code[pc + 1].v_int;
419 ICHECK_EQ(sp, expected) << "sp assertion failed, expected=" << expected << " now=" << sp
420 << ", pc=" << pc;
421 pc += 2;
422 break;
423 }
424 case CALL_PACKED_LOWERED: {
425 // call packed function.
426 TVMValue* value_stack = static_cast<TVMValue*>(stack[sp - 1].v_handle);
427 int* type_stack = static_cast<int*>(stack[sp].v_handle);
428 int call_fid = code[pc + 1].v_int;
429 int begin = code[pc + 2].v_int;
430 int end = code[pc + 3].v_int;
431 int num_args = end - begin;
432 static_assert(sizeof(Code) == sizeof(int) && alignof(Code) == alignof(int), "asusmption");
433 runtime::TVMRetValue rv;
434 GetExtern(s, call_fid)
435 .CallPacked(runtime::TVMArgs(value_stack + begin, type_stack + begin, num_args), &rv);
436 sp = sp - 1;
437 stack[sp] = rv.value();
438 pc += 4;
439 break;
440 }
441 // intrinsics
442 case TVM_STRUCT_GET: {
443 int index = code[pc + 1].v_int;
444 int kind = code[pc + 2].v_int;
445 DLTensor* arr = static_cast<DLTensor*>(stack[sp].v_handle);
446 switch (kind) {
447 case StackVM::kArrData: {
448 stack[sp].v_handle = arr[index].data;
449 break;
450 }
451 case StackVM::kArrShape: {
452 stack[sp].v_handle = arr[index].shape;
453 break;
454 }
455 case StackVM::kArrStrides: {
456 stack[sp].v_handle = arr[index].strides;
457 break;
458 }
459 case StackVM::kArrNDim: {
460 stack[sp].v_int64 = arr[index].ndim;
461 break;
462 }
463 case StackVM::kArrTypeCode: {
464 stack[sp].v_int64 = static_cast<int64_t>(arr[index].dtype.code);
465 break;
466 }
467 case StackVM::kArrTypeBits: {
468 stack[sp].v_int64 = static_cast<int64_t>(arr[index].dtype.bits);
469 break;
470 }
471 case StackVM::kArrTypeLanes: {
472 stack[sp].v_int64 = static_cast<int64_t>(arr[index].dtype.lanes);
473 break;
474 }
475 case StackVM::kArrByteOffset: {
476 stack[sp].v_int64 = static_cast<int64_t>(arr[index].byte_offset);
477 break;
478 }
479 case StackVM::kArrDeviceId: {
480 stack[sp].v_int64 = arr[index].device.device_id;
481 break;
482 }
483 case StackVM::kArrDeviceType: {
484 stack[sp].v_int64 = static_cast<int64_t>(arr[index].device.device_type);
485 break;
486 }
487 case StackVM::kArrAddr: {
488 stack[sp].v_handle = arr + index;
489 break;
490 }
491 case StackVM::kTVMValueContent: {
492 stack[sp] = static_cast<TVMValue*>(stack[sp].v_handle)[index];
493 break;
494 }
495 default:
496 LOG(FATAL) << "unhandled get " << kind;
497 }
498 pc = pc + 3;
499 break;
500 }
501 case TVM_STRUCT_SET: {
502 int index = code[pc + 1].v_int;
503 int kind = code[pc + 2].v_int;
504 DLTensor* arr = static_cast<DLTensor*>(stack[sp - 1].v_handle);
505 switch (kind) {
506 case StackVM::kArrData: {
507 arr[index].data = stack[sp].v_handle;
508 break;
509 }
510 case StackVM::kArrShape: {
511 arr[index].shape = static_cast<int64_t*>(stack[sp].v_handle);
512 break;
513 }
514 case StackVM::kArrStrides: {
515 arr[index].strides = static_cast<int64_t*>(stack[sp].v_handle);
516 break;
517 }
518 case StackVM::kArrNDim: {
519 arr[index].ndim = static_cast<int>(stack[sp].v_int64);
520 break;
521 }
522 case StackVM::kArrTypeCode: {
523 arr[index].dtype.code = static_cast<uint8_t>(stack[sp].v_int64);
524 break;
525 }
526 case StackVM::kArrTypeBits: {
527 arr[index].dtype.bits = static_cast<uint8_t>(stack[sp].v_int64);
528 break;
529 }
530 case StackVM::kArrTypeLanes: {
531 arr[index].dtype.lanes = static_cast<uint16_t>(stack[sp].v_int64);
532 break;
533 }
534 case StackVM::kArrByteOffset: {
535 arr[index].byte_offset = static_cast<uint64_t>(stack[sp].v_int64);
536 break;
537 }
538 case StackVM::kArrDeviceId: {
539 arr[index].device.device_id = static_cast<int>(stack[sp].v_int64);
540 break;
541 }
542 case StackVM::kArrDeviceType: {
543 arr[index].device.device_type = static_cast<DLDeviceType>(stack[sp].v_int64);
544 break;
545 }
546 case StackVM::kTVMValueContent: {
547 static_cast<TVMValue*>(stack[sp - 1].v_handle)[index] = stack[sp];
548 break;
549 }
550 default:
551 LOG(FATAL) << "unhandled tvm_struct_set " << kind;
552 }
553 sp -= 2;
554 pc += 3;
555 break;
556 }
557 // alloca
558 case TVM_STACK_ALLOCA_BY_8BYTE: {
559 static_assert(sizeof(TVMValue) == 8, "invariance");
560 int num = code[pc + 1].v_int;
561 void* addr = &stack[sp] + 1;
562 sp = sp + num + 1;
563 alloca_sp = sp - 1;
564 stack[sp].v_handle = addr;
565 pc = pc + 2;
566 break;
567 }
568 case TVM_DEVICE_ALLOCA: {
569 int device_type = static_cast<int>(stack[sp - 4].v_int64);
570 int device_id = static_cast<int>(stack[sp - 3].v_int64);
571 size_t nbytes = static_cast<size_t>(stack[sp - 2].v_int64);
572 int dtype_code_hint = static_cast<int>(stack[sp - 1].v_int64);
573 int dtype_bits_hint = static_cast<int>(stack[sp].v_int64);
574 void* ptr = TVMBackendAllocWorkspace(device_type, device_id, nbytes, dtype_code_hint,
575 dtype_bits_hint);
576 stack[sp - 4].v_handle = ptr;
577 sp = sp - 4;
578 pc = pc + 1;
579 break;
580 }
581 case TVM_DEVICE_FREE: {
582 int device_type = static_cast<int>(stack[sp - 2].v_int64);
583 int device_id = static_cast<int>(stack[sp - 1].v_int64);
584 void* ptr = stack[sp].v_handle;
585 int ret = TVMBackendFreeWorkspace(device_type, device_id, ptr);
586 stack[sp - 2].v_int64 = ret;
587 sp = sp - 2;
588 pc = pc + 1;
589 break;
590 }
591 case TVM_THROW_LAST_ERROR: {
592 LOG(FATAL) << TVMGetLastError();
593 break;
594 }
595 }
596 ICHECK_GE(sp, alloca_sp) << "touch allocated space";
597 ICHECK_LT(sp, stack_cap) << "Stack overflow";
598 }
599}
600
601const PackedFunc& StackVM::GetExtern(State* s, int fid) const {
602 ICHECK_LT(static_cast<size_t>(fid), extern_func_cache_.size());
603 // allow race write in this, since write is idempotent
604 PackedFunc& f = extern_func_cache_[fid];
605 if (f == nullptr) {
606 ICHECK(s->mod_ctx != nullptr) << "No local context is set in stackvm";
607 const PackedFunc* pf = s->mod_ctx->GetFuncFromEnv(extern_func_name[fid]);
608 ICHECK(pf != nullptr);
609 f = *pf;
610 }
611 return f;
612}
613
614} // namespace runtime
615} // namespace tvm
616