1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
 * Implementation of the stack VM.
22 | * \file stackvm.cc |
23 | */ |
24 | #include "stackvm.h" |
25 | |
26 | #include <dmlc/thread_local.h> |
27 | #include <tvm/runtime/c_backend_api.h> |
28 | |
29 | #include <algorithm> |
30 | |
31 | namespace tvm { |
32 | namespace runtime { |
33 | |
34 | typedef dmlc::ThreadLocalStore<StackVM::State> StackVMStateStore; |
35 | |
36 | StackVM::State* StackVM::ThreadLocalState() { return StackVMStateStore::Get(); } |
37 | |
// Binary arithmetic on the top two operand-stack slots:
// stack[sp-1] = stack[sp-1] OP stack[sp]; pops one slot; pc advances by 1.
#define STACK_VM_BINOP(OP, FIELD)                                 \
  {                                                               \
    stack[sp - 1].FIELD = stack[sp - 1].FIELD OP stack[sp].FIELD; \
    sp -= 1;                                                      \
    pc += 1;                                                      \
  }

// Comparison of the top two slots; the boolean result is always written
// into the v_int64 field of the new top. Pops one slot; pc += 1.
#define STACK_VM_CMPOP(OP, FIELD)                                  \
  {                                                                \
    stack[sp - 1].v_int64 = stack[sp - 1].FIELD OP stack[sp].FIELD; \
    sp -= 1;                                                       \
    pc += 1;                                                       \
  }

// Array load: top of stack holds a base handle; the operand word at
// code[pc+1] is the element index. Replaces the top slot in place with
// element [index] converted SRC_TYPE -> DST_TYPE. Two-word instruction (pc += 2).
#define STACK_VM_LOAD(FIELD, DST_TYPE, SRC_TYPE)                                               \
  {                                                                                            \
    int index = code[pc + 1].v_int;                                                            \
    stack[sp] FIELD = static_cast<DST_TYPE>(static_cast<SRC_TYPE*>(stack[sp].v_handle)[index]); \
    pc += 2;                                                                                   \
  }

// Array store: stack[sp-1] is the destination base handle, stack[sp] the
// value to store at element [index] (operand word at code[pc+1]).
// Pops both slots; two-word instruction (pc += 2).
#define STACK_VM_STORE(FIELD, DST_TYPE)                  \
  {                                                      \
    int index = code[pc + 1].v_int;                      \
    static_cast<DST_TYPE*>(stack[sp - 1].v_handle)[index] = \
        static_cast<DST_TYPE>(stack[sp] FIELD);          \
    sp -= 2;                                             \
    pc += 2;                                             \
  }

// Disassembly helper: print a zero-operand opcode; returns next pc.
#define STACK_VM_PRINT_CODE0(CODE)                       \
  case CODE: {                                           \
    os << "[" << pc << "]\t" << #CODE << std::endl;      \
    return pc + 1;                                       \
  }

// Disassembly helper: print a one-operand opcode (operand at pc+1); returns next pc.
#define STACK_VM_PRINT_CODE1(CODE)                                       \
  case CODE: {                                                           \
    os << "[" << pc << "]\t" << #CODE << " " << code[pc + 1].v_int << "\n" \
       << "[" << pc + 1 << "]" << std::endl;                             \
    return pc + 2;                                                       \
  }

// Disassembly helper: print a two-operand opcode (operands at pc+1, pc+2); returns next pc.
#define STACK_VM_PRINT_CODE2(CODE)                                                             \
  case CODE: {                                                                                 \
    os << "[" << pc << "]\t" << #CODE << " " << code[pc + 1].v_int << " " << code[pc + 2].v_int \
       << "\n"                                                                                 \
       << "[" << pc + 1 << "]" << std::endl                                                    \
       << "[" << pc + 2 << "]" << std::endl;                                                   \
    return pc + 3;                                                                             \
  }

// Disassembly helper for heap access opcodes: also prints the symbolic
// heap-slot name looked up from heap_id_name; returns next pc.
#define STACK_VM_PRINT_HEAP_ACCESS(CODE)                                 \
  case CODE: {                                                           \
    os << "[" << pc << "]\t" << #CODE << " " << code[pc + 1].v_int << " " \
       << heap_id_name[code[pc + 1].v_int] << "\n"                       \
       << "[" << pc + 1 << "]" << std::endl;                             \
    return pc + 2;                                                       \
  }

// Disassembly helper for relative jumps: prints both the relative offset
// and the resolved absolute target pc; returns next pc.
#define STACK_VM_PRINT_JUMP(CODE)                                                \
  case CODE: {                                                                   \
    os << "[" << pc << "]\t" << #CODE << " rel=" << code[pc + 1].v_int << " to " \
       << pc + code[pc + 1].v_int << '\n'                                        \
       << "[" << pc + 1 << "]" << std::endl;                                     \
    return pc + 2;                                                               \
  }
105 | |
/*!
 * \brief Disassemble the single instruction at pc into os.
 * \param os The output stream to print to.
 * \param pc Index of the instruction word in code.
 * \return pc of the next instruction (pc + 1 + number of operand words).
 */
int64_t StackVM::PrintCode(std::ostream& os, int64_t pc) const {
  switch (code[pc].op_code) {
    // int
    STACK_VM_PRINT_CODE0(ADD_I64);
    STACK_VM_PRINT_CODE0(SUB_I64);
    STACK_VM_PRINT_CODE0(MUL_I64);
    STACK_VM_PRINT_CODE0(MOD_I64);
    STACK_VM_PRINT_CODE0(DIV_I64);
    STACK_VM_PRINT_CODE0(EQ_I64);
    STACK_VM_PRINT_CODE0(LT_I64);
    STACK_VM_PRINT_CODE0(LE_I64);
    // floats
    STACK_VM_PRINT_CODE0(ADD_F64);
    STACK_VM_PRINT_CODE0(SUB_F64);
    STACK_VM_PRINT_CODE0(MUL_F64);
    STACK_VM_PRINT_CODE0(DIV_F64);
    STACK_VM_PRINT_CODE0(EQ_F64);
    STACK_VM_PRINT_CODE0(LT_F64);
    STACK_VM_PRINT_CODE0(LE_F64);
    // handle.
    STACK_VM_PRINT_CODE0(EQ_HANDLE);
    // addressing load
    STACK_VM_PRINT_CODE1(ARRAY_LOAD_UINT32);
    STACK_VM_PRINT_CODE1(ARRAY_LOAD_INT32);
    STACK_VM_PRINT_CODE1(ARRAY_LOAD_INT64);
    STACK_VM_PRINT_CODE1(ARRAY_LOAD_FP64);
    STACK_VM_PRINT_CODE1(ARRAY_LOAD_HANDLE);
    STACK_VM_PRINT_CODE1(ARRAY_LOAD_TVMVALUE);
    STACK_VM_PRINT_CODE1(ARRAY_STORE_UINT32);
    STACK_VM_PRINT_CODE1(ARRAY_STORE_INT32);
    STACK_VM_PRINT_CODE1(ARRAY_STORE_INT64);
    STACK_VM_PRINT_CODE1(ARRAY_STORE_FP64);
    STACK_VM_PRINT_CODE1(ARRAY_STORE_HANDLE);
    STACK_VM_PRINT_CODE1(ARRAY_STORE_TVMVALUE);
    STACK_VM_PRINT_CODE0(NOT);
    STACK_VM_PRINT_CODE0(ADDR_ADD);
    // stack ops
    STACK_VM_PRINT_CODE1(PUSH_I64);
    STACK_VM_PRINT_CODE1(PUSH_VALUE);
    STACK_VM_PRINT_CODE0(POP);
    STACK_VM_PRINT_CODE0(SELECT);
    STACK_VM_PRINT_HEAP_ACCESS(STORE_HEAP);
    STACK_VM_PRINT_HEAP_ACCESS(LOAD_HEAP);
    STACK_VM_PRINT_CODE1(ASSERT);
    STACK_VM_PRINT_JUMP(RJUMP_IF_TRUE);
    STACK_VM_PRINT_JUMP(RJUMP_IF_FALSE);
    STACK_VM_PRINT_JUMP(RJUMP);
    STACK_VM_PRINT_CODE1(ASSERT_SP);
    // Intrinsics
    STACK_VM_PRINT_CODE2(TVM_STRUCT_GET);
    STACK_VM_PRINT_CODE2(TVM_STRUCT_SET);
    // Allocate data by 8 bytes.
    STACK_VM_PRINT_CODE1(TVM_STACK_ALLOCA_BY_8BYTE);
    STACK_VM_PRINT_CODE0(TVM_DEVICE_ALLOCA);
    STACK_VM_PRINT_CODE0(TVM_DEVICE_FREE);
    STACK_VM_PRINT_CODE0(TVM_THROW_LAST_ERROR);
    // packed function.
    case CALL_PACKED_LOWERED: {
      // Four-word instruction: opcode + fid + [begin, end) of the argument window.
      int call_fid = code[pc + 1].v_int;
      int begin = code[pc + 2].v_int;
      int end = code[pc + 3].v_int;
      os << "[" << pc << "]\tCALL_PACKED_FUNC "
         << " fid=" << call_fid << " begin=" << begin << " end=" << end;
      os << '\n';
      // Print placeholder lines for the three operand words.
      for (int i = 0; i < 3; ++i) {
        os << "[" << pc + 1 + i << "]" << std::endl;
      }
      return pc + 4;
    }
  }
  // Every known opcode returns from its case above; falling through here
  // means a corrupt/unknown instruction. LOG(FATAL) does not return
  // (throws/aborts), so no return statement is needed after it.
  LOG(FATAL) << "unknown op code " << code[pc].op_code;
}
178 | |
179 | std::ostream& operator<<(std::ostream& os, const StackVM& vm) { // NOLINT(*) |
180 | int64_t pc = 0; |
181 | const int64_t code_size = static_cast<int64_t>(vm.code.size()); |
182 | os << "Program dump: code-size=" << code_size << '\n' << "----------begin-----------------\n" ; |
183 | while (pc < code_size) { |
184 | pc = vm.PrintCode(os, pc); |
185 | } |
186 | os << "----------end--------------------\n" ; |
187 | return os; |
188 | } |
189 | |
190 | void StackVM::Run(const runtime::TVMArgs& args, runtime::ModuleNode* mod_ctx) const { |
191 | StackVM::State* s = StackVM::ThreadLocalState(); |
192 | if (s->heap.size() < heap_size) { |
193 | s->heap.resize(heap_size); |
194 | } |
195 | s->sp = 0; |
196 | s->pc = 0; |
197 | s->mod_ctx = mod_ctx; |
198 | s->heap[0].v_handle = (void*)args.values; // NOLINT(*) |
199 | s->heap[1].v_handle = (void*)args.type_codes; // NOLINT(*) |
200 | s->heap[2].v_int64 = args.num_args; |
201 | this->Run(s); |
202 | } |
203 | |
204 | void StackVM::InitCache() { |
205 | extern_func_cache_.clear(); |
206 | extern_func_cache_.resize(extern_func_name.size(), PackedFunc(nullptr)); |
207 | } |
208 | |
209 | void StackVM::Save(dmlc::Stream* strm) const { |
210 | // to be endian invariant. |
211 | std::vector<int32_t> code_copy(code.size()); |
212 | std::transform(code.begin(), code.end(), code_copy.begin(), [](Code c) { return c.v_int; }); |
213 | strm->Write(code_copy); |
214 | strm->Write(str_data); |
215 | strm->Write(extern_func_name); |
216 | strm->Write(heap_id_name); |
217 | strm->Write(heap_size); |
218 | strm->Write(stack_size); |
219 | } |
220 | |
221 | bool StackVM::Load(dmlc::Stream* strm) { |
222 | // to be endian invariant. |
223 | std::vector<int32_t> code_copy; |
224 | if (!strm->Read(&code_copy)) return false; |
225 | code.resize(code_copy.size()); |
226 | std::transform(code_copy.begin(), code_copy.end(), code.begin(), [](int v) { |
227 | Code code; |
228 | code.v_int = v; |
229 | return code; |
230 | }); |
231 | if (!strm->Read(&str_data)) return false; |
232 | if (!strm->Read(&extern_func_name)) return false; |
233 | if (!strm->Read(&heap_id_name)) return false; |
234 | if (!strm->Read(&heap_size)) return false; |
235 | if (!strm->Read(&stack_size)) return false; |
236 | this->InitCache(); |
237 | return true; |
238 | } |
239 | |
/*!
 * \brief Interpreter main loop: execute this VM's code against state s.
 *
 * sp is the index of the current top-of-stack slot; pc indexes into code.
 * The loop runs until pc walks past the end of the code stream.
 * \param s Thread-local execution state (operand stack + heap).
 */
void StackVM::Run(State* s) const {
  int64_t sp = s->sp;         // operand stack pointer (index of the top slot)
  int64_t pc = s->pc;         // program counter into code[]
  int64_t alloca_sp = s->sp;  // floor guarding stack-alloca'd regions from pops
  std::vector<TVMValue>& stack = s->stack;
  std::vector<TVMValue>& heap = s->heap;
  // Grow (never shrink) the reusable thread-local buffers.
  if (stack.size() < stack_size) {
    stack.resize(stack_size);
  }
  // Leave 4 slots of headroom below the physical end of the stack.
  int64_t stack_cap = static_cast<int64_t>(stack_size - 4);
  if (heap.size() < heap_size) {
    heap.resize(heap_size);
  }
  const int64_t code_size = static_cast<int64_t>(code.size());
  while (pc < code_size) {
    switch (code[pc].op_code) {
      // Arithmetic/comparison macros below operate on the top two slots,
      // pop one, and advance pc by 1 (see macro definitions above).
      case ADD_I64:
        STACK_VM_BINOP(+, v_int64);
        break;
      case SUB_I64:
        STACK_VM_BINOP(-, v_int64);
        break;
      case MUL_I64:
        STACK_VM_BINOP(*, v_int64);
        break;
      case DIV_I64:
        STACK_VM_BINOP(/, v_int64);
        break;
      case MOD_I64:
        STACK_VM_BINOP(%, v_int64);
        break;
      case EQ_I64:
        STACK_VM_CMPOP(==, v_int64);
        break;
      case LT_I64:
        STACK_VM_CMPOP(<, v_int64);
        break;
      case LE_I64:
        STACK_VM_CMPOP(<=, v_int64);
        break;
      case ADD_F64:
        STACK_VM_BINOP(+, v_float64);
        break;
      case SUB_F64:
        STACK_VM_BINOP(-, v_float64);
        break;
      case MUL_F64:
        STACK_VM_BINOP(*, v_float64);
        break;
      case DIV_F64:
        STACK_VM_BINOP(/, v_float64);
        break;
      case EQ_F64:
        STACK_VM_CMPOP(==, v_float64);
        break;
      case LT_F64:
        STACK_VM_CMPOP(<, v_float64);
        break;
      case LE_F64:
        STACK_VM_CMPOP(<=, v_float64);
        break;
      case EQ_HANDLE:
        STACK_VM_CMPOP(==, v_handle);
        break;
      // addressing: replace the handle on top of the stack with the
      // element read through it (index is the operand word at pc+1).
      case ARRAY_LOAD_UINT32:
        STACK_VM_LOAD(.v_int64, int64_t, uint32_t);
        break;
      case ARRAY_LOAD_INT32:
        STACK_VM_LOAD(.v_int64, int64_t, int32_t);
        break;
      case ARRAY_LOAD_INT64:
        STACK_VM_LOAD(.v_int64, int64_t, int64_t);
        break;
      case ARRAY_LOAD_FP64:
        STACK_VM_LOAD(.v_float64, double, double);
        break;
      case ARRAY_LOAD_HANDLE:
        STACK_VM_LOAD(.v_handle, void*, void*);
        break;
      case ARRAY_LOAD_TVMVALUE:
        STACK_VM_LOAD(, TVMValue, TVMValue);
        break;
      // store: stack[sp-1] is the destination handle, stack[sp] the value;
      // both are popped.
      case ARRAY_STORE_UINT32:
        STACK_VM_STORE(.v_int64, uint32_t);
        break;
      case ARRAY_STORE_INT32:
        STACK_VM_STORE(.v_int64, int32_t);
        break;
      case ARRAY_STORE_INT64:
        STACK_VM_STORE(.v_int64, int64_t);
        break;
      case ARRAY_STORE_FP64:
        STACK_VM_STORE(.v_float64, double);
        break;
      case ARRAY_STORE_HANDLE:
        STACK_VM_STORE(.v_handle, void*);
        break;
      case ARRAY_STORE_TVMVALUE:
        STACK_VM_STORE(, TVMValue);
        break;
      // add: pointer arithmetic — advance the handle at sp-1 by the byte
      // offset on top of the stack.
      case ADDR_ADD: {
        stack[sp - 1].v_handle = (char*)(stack[sp - 1].v_handle) + stack[sp].v_int64;  // NOLINT(*)
        sp = sp - 1;
        pc = pc + 1;
        break;
      }
      case NOT: {
        // Logical (not bitwise) negation of the top slot, in place.
        stack[sp].v_int64 = !stack[sp].v_int64;
        pc += 1;
        break;
      }
      case PUSH_I64: {
        // Push the immediate operand word as an int64.
        stack[sp + 1].v_int64 = code[pc + 1].v_int;
        sp += 1;
        pc += 2;
        break;
      }
      case PUSH_VALUE: {
        // Duplicate the value at a non-positive offset from the top.
        int relpos = code[pc + 1].v_int;
        ICHECK_LE(relpos, 0);
        stack[sp + 1] = stack[sp + relpos];
        sp += 1;
        pc += 2;
        break;
      }
      case POP: {
        sp -= 1;
        pc += 1;
        break;
      }
      case SELECT: {
        // Ternary select: condition on top, true-value at sp-2,
        // false-value at sp-1; result replaces sp-2, two slots popped.
        stack[sp - 2] = (stack[sp].v_int64 ? stack[sp - 2] : stack[sp - 1]);
        sp -= 2;
        pc += 1;
        break;
      }
      case LOAD_HEAP: {
        // Push heap[slot], slot given by the operand word.
        stack[sp + 1] = heap[code[pc + 1].v_int];
        sp += 1;
        pc += 2;
        break;
      }
      case STORE_HEAP: {
        // Pop the top of stack into heap[slot].
        heap[code[pc + 1].v_int] = stack[sp];
        sp -= 1;
        pc += 2;
        break;
      }
      case ASSERT: {
        // Pop the condition; on failure report the message string whose
        // index is the operand word.
        ICHECK(stack[sp].v_int64) << str_data[code[pc + 1].v_int];
        sp -= 1;
        pc += 2;
        break;
      }
      // Relative jumps: offset is relative to the jump instruction itself.
      // Note: the tested condition slot is NOT popped here.
      case RJUMP_IF_TRUE: {
        if (stack[sp].v_int64) {
          pc += code[pc + 1].v_int;
        } else {
          pc += 2;
        }
        break;
      }
      case RJUMP_IF_FALSE: {
        if (!stack[sp].v_int64) {
          pc += code[pc + 1].v_int;
        } else {
          pc += 2;
        }
        break;
      }
      case RJUMP: {
        pc += code[pc + 1].v_int;
        break;
      }
      case ASSERT_SP: {
        // Compiler-emitted sanity check that sp matches the expected value.
        int64_t expected = code[pc + 1].v_int;
        ICHECK_EQ(sp, expected) << "sp assertion failed, expected=" << expected << " now=" << sp
                                << ", pc=" << pc;
        pc += 2;
        break;
      }
      case CALL_PACKED_LOWERED: {
        // call packed function.
        // stack[sp-1] holds the TVMValue argument buffer, stack[sp] the
        // matching type-code buffer; operand words give fid and the
        // [begin, end) argument window within those buffers.
        TVMValue* value_stack = static_cast<TVMValue*>(stack[sp - 1].v_handle);
        int* type_stack = static_cast<int*>(stack[sp].v_handle);
        int call_fid = code[pc + 1].v_int;
        int begin = code[pc + 2].v_int;
        int end = code[pc + 3].v_int;
        int num_args = end - begin;
        // NOTE(review): message typo "asusmption" left unchanged (runtime string).
        static_assert(sizeof(Code) == sizeof(int) && alignof(Code) == alignof(int), "asusmption");
        runtime::TVMRetValue rv;
        GetExtern(s, call_fid)
            .CallPacked(runtime::TVMArgs(value_stack + begin, type_stack + begin, num_args), &rv);
        // Pop both buffer handles; push the return value.
        sp = sp - 1;
        stack[sp] = rv.value();
        pc += 4;
        break;
      }
      // intrinsics
      case TVM_STRUCT_GET: {
        // Read field `kind` of DLTensor arr[index]; result replaces the
        // handle on top of the stack. Four cases below interpret the handle
        // differently (kArrAddr yields an address, kTVMValueContent reads a
        // TVMValue array instead of DLTensors).
        int index = code[pc + 1].v_int;
        int kind = code[pc + 2].v_int;
        DLTensor* arr = static_cast<DLTensor*>(stack[sp].v_handle);
        switch (kind) {
          case StackVM::kArrData: {
            stack[sp].v_handle = arr[index].data;
            break;
          }
          case StackVM::kArrShape: {
            stack[sp].v_handle = arr[index].shape;
            break;
          }
          case StackVM::kArrStrides: {
            stack[sp].v_handle = arr[index].strides;
            break;
          }
          case StackVM::kArrNDim: {
            stack[sp].v_int64 = arr[index].ndim;
            break;
          }
          case StackVM::kArrTypeCode: {
            stack[sp].v_int64 = static_cast<int64_t>(arr[index].dtype.code);
            break;
          }
          case StackVM::kArrTypeBits: {
            stack[sp].v_int64 = static_cast<int64_t>(arr[index].dtype.bits);
            break;
          }
          case StackVM::kArrTypeLanes: {
            stack[sp].v_int64 = static_cast<int64_t>(arr[index].dtype.lanes);
            break;
          }
          case StackVM::kArrByteOffset: {
            stack[sp].v_int64 = static_cast<int64_t>(arr[index].byte_offset);
            break;
          }
          case StackVM::kArrDeviceId: {
            stack[sp].v_int64 = arr[index].device.device_id;
            break;
          }
          case StackVM::kArrDeviceType: {
            stack[sp].v_int64 = static_cast<int64_t>(arr[index].device.device_type);
            break;
          }
          case StackVM::kArrAddr: {
            // Address of the element itself, not a field read.
            stack[sp].v_handle = arr + index;
            break;
          }
          case StackVM::kTVMValueContent: {
            // Here the handle points at a TVMValue array, not DLTensors.
            stack[sp] = static_cast<TVMValue*>(stack[sp].v_handle)[index];
            break;
          }
          default:
            LOG(FATAL) << "unhandled get " << kind;
        }
        pc = pc + 3;
        break;
      }
      case TVM_STRUCT_SET: {
        // Write field `kind` of DLTensor arr[index]: stack[sp-1] is the
        // target handle, stack[sp] the value; both are popped.
        int index = code[pc + 1].v_int;
        int kind = code[pc + 2].v_int;
        DLTensor* arr = static_cast<DLTensor*>(stack[sp - 1].v_handle);
        switch (kind) {
          case StackVM::kArrData: {
            arr[index].data = stack[sp].v_handle;
            break;
          }
          case StackVM::kArrShape: {
            arr[index].shape = static_cast<int64_t*>(stack[sp].v_handle);
            break;
          }
          case StackVM::kArrStrides: {
            arr[index].strides = static_cast<int64_t*>(stack[sp].v_handle);
            break;
          }
          case StackVM::kArrNDim: {
            arr[index].ndim = static_cast<int>(stack[sp].v_int64);
            break;
          }
          case StackVM::kArrTypeCode: {
            arr[index].dtype.code = static_cast<uint8_t>(stack[sp].v_int64);
            break;
          }
          case StackVM::kArrTypeBits: {
            arr[index].dtype.bits = static_cast<uint8_t>(stack[sp].v_int64);
            break;
          }
          case StackVM::kArrTypeLanes: {
            arr[index].dtype.lanes = static_cast<uint16_t>(stack[sp].v_int64);
            break;
          }
          case StackVM::kArrByteOffset: {
            arr[index].byte_offset = static_cast<uint64_t>(stack[sp].v_int64);
            break;
          }
          case StackVM::kArrDeviceId: {
            arr[index].device.device_id = static_cast<int>(stack[sp].v_int64);
            break;
          }
          case StackVM::kArrDeviceType: {
            arr[index].device.device_type = static_cast<DLDeviceType>(stack[sp].v_int64);
            break;
          }
          case StackVM::kTVMValueContent: {
            // Target is a TVMValue array, not DLTensors.
            static_cast<TVMValue*>(stack[sp - 1].v_handle)[index] = stack[sp];
            break;
          }
          default:
            LOG(FATAL) << "unhandled tvm_struct_set " << kind;
        }
        sp -= 2;
        pc += 3;
        break;
      }
      // alloca
      case TVM_STACK_ALLOCA_BY_8BYTE: {
        static_assert(sizeof(TVMValue) == 8, "invariance");
        // Reserve `num` 8-byte slots directly on the operand stack, then
        // push a handle to the reserved region. alloca_sp is raised so the
        // ICHECK below catches any pop into the reserved region.
        int num = code[pc + 1].v_int;
        void* addr = &stack[sp] + 1;
        sp = sp + num + 1;
        alloca_sp = sp - 1;
        stack[sp].v_handle = addr;
        pc = pc + 2;
        break;
      }
      case TVM_DEVICE_ALLOCA: {
        // Pops 5 operands (device_type, device_id, nbytes, dtype hints);
        // pushes the allocated workspace pointer in their place.
        int device_type = static_cast<int>(stack[sp - 4].v_int64);
        int device_id = static_cast<int>(stack[sp - 3].v_int64);
        size_t nbytes = static_cast<size_t>(stack[sp - 2].v_int64);
        int dtype_code_hint = static_cast<int>(stack[sp - 1].v_int64);
        int dtype_bits_hint = static_cast<int>(stack[sp].v_int64);
        void* ptr = TVMBackendAllocWorkspace(device_type, device_id, nbytes, dtype_code_hint,
                                             dtype_bits_hint);
        stack[sp - 4].v_handle = ptr;
        sp = sp - 4;
        pc = pc + 1;
        break;
      }
      case TVM_DEVICE_FREE: {
        // Pops 3 operands (device_type, device_id, ptr); pushes the
        // backend's return code in their place.
        int device_type = static_cast<int>(stack[sp - 2].v_int64);
        int device_id = static_cast<int>(stack[sp - 1].v_int64);
        void* ptr = stack[sp].v_handle;
        int ret = TVMBackendFreeWorkspace(device_type, device_id, ptr);
        stack[sp - 2].v_int64 = ret;
        sp = sp - 2;
        pc = pc + 1;
        break;
      }
      case TVM_THROW_LAST_ERROR: {
        // Surfaces the last TVM backend error; LOG(FATAL) does not return.
        LOG(FATAL) << TVMGetLastError();
        break;
      }
    }
    // Post-instruction invariants: sp must not drop into a stack-alloca'd
    // region and must stay within the reserved capacity headroom.
    ICHECK_GE(sp, alloca_sp) << "touch allocated space";
    ICHECK_LT(sp, stack_cap) << "Stack overflow";
  }
}
600 | |
601 | const PackedFunc& StackVM::GetExtern(State* s, int fid) const { |
602 | ICHECK_LT(static_cast<size_t>(fid), extern_func_cache_.size()); |
603 | // allow race write in this, since write is idempotent |
604 | PackedFunc& f = extern_func_cache_[fid]; |
605 | if (f == nullptr) { |
606 | ICHECK(s->mod_ctx != nullptr) << "No local context is set in stackvm" ; |
607 | const PackedFunc* pf = s->mod_ctx->GetFuncFromEnv(extern_func_name[fid]); |
608 | ICHECK(pf != nullptr); |
609 | f = *pf; |
610 | } |
611 | return f; |
612 | } |
613 | |
614 | } // namespace runtime |
615 | } // namespace tvm |
616 | |