1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file stackvm.h |
22 | * \brief A simple stack-based virtual machine. |
23 | * |
 * This can be used to interpret host side code
 * to set up calls into device functions
 * when only runtime compilation for device is available (via NVRTC or OpenCL).
27 | */ |
28 | #ifndef TVM_RUNTIME_STACKVM_STACKVM_H_ |
29 | #define TVM_RUNTIME_STACKVM_STACKVM_H_ |
30 | |
31 | #include <tvm/runtime/c_runtime_api.h> |
32 | #include <tvm/runtime/module.h> |
33 | #include <tvm/runtime/packed_func.h> |
34 | |
35 | #include <string> |
36 | #include <vector> |
37 | |
38 | namespace tvm { |
39 | namespace runtime { |
40 | |
// Make the tvm::runtime stream-insertion overloads visible in this scope.
// The inline helpers of StackVM below stream an OpCode and a DLDataType
// into LOG(FATAL).
// NOTE(review): presumably required for the DLDataType printer - confirm.
using runtime::operator<<;
42 | |
43 | /*! |
44 | * \brief A simple stack-based virtual machine program. |
45 | */ |
46 | class StackVM { |
47 | public: |
48 | /*! |
49 | * \brief Invoke the StackVM program. |
50 | * \param args The arguments to the StackVM. |
51 | * \param mod_ctx The module context used in running. |
52 | */ |
53 | void Run(const TVMArgs& args, runtime::ModuleNode* mod_ctx) const; |
54 | /*! |
55 | * \brief The opcode of stack vm |
56 | * \note Notation |
57 | * - sp Stack pointer |
58 | * - pc Program pointer |
59 | */ |
60 | enum OpCode { |
61 | // integer ops |
62 | ADD_I64, |
63 | SUB_I64, |
64 | MUL_I64, |
65 | DIV_I64, |
66 | MOD_I64, |
67 | EQ_I64, |
68 | LT_I64, |
69 | LE_I64, |
70 | // floating ops |
71 | ADD_F64, |
72 | SUB_F64, |
73 | MUL_F64, |
74 | DIV_F64, |
75 | EQ_F64, |
76 | LT_F64, |
77 | LE_F64, |
78 | // Pointer comparison |
79 | EQ_HANDLE, |
80 | /*! |
81 | * \brief Routine to load data from address with const offset. |
82 | * \code |
83 | * stack[sp].v_int64 = ((DType*)stack[sp].v_handle)[code[pc + 1].v_int]; |
84 | * pc = pc + 2; |
85 | * \endcode |
86 | */ |
87 | ARRAY_LOAD_UINT32, |
88 | ARRAY_LOAD_INT32, |
89 | ARRAY_LOAD_INT64, |
90 | ARRAY_LOAD_FP64, |
91 | ARRAY_LOAD_HANDLE, |
92 | ARRAY_LOAD_TVMVALUE, |
93 | /*! |
94 | * \brief Routine to store data from constant offset. |
95 | * \code |
96 | * ((DType*)stack[sp - 1].v_handle)[code[pc + 1].v_int] = stack[sp]; |
97 | * pc = pc + 2; |
98 | * sp = sp - 2; |
99 | * \endcode |
100 | */ |
101 | ARRAY_STORE_UINT32, |
102 | ARRAY_STORE_INT32, |
103 | ARRAY_STORE_INT64, |
104 | ARRAY_STORE_FP64, |
105 | ARRAY_STORE_HANDLE, |
106 | ARRAY_STORE_TVMVALUE, |
107 | // logical ops |
108 | NOT, |
109 | /*! |
110 | * \brief Add address by an offset. |
111 | * \code |
112 | * stack[sp - 1].v_handle = ((char*)stack[sp - 1].v_handle + stack[sp].v_int64); |
113 | * sp = sp - 1; |
114 | * \endcode |
115 | */ |
116 | ADDR_ADD, |
117 | /*! |
118 | * \brief push integer fetched from next pc position into stack |
119 | * \code |
120 | * stack[sp + 1].v_int64 = code[pc + 1].v_int; |
121 | * pc = pc + 2; |
122 | * sp = sp + 1; |
123 | * \endcode |
124 | */ |
125 | PUSH_I64, |
126 | /*! |
127 | * \brief push a value given relative index on the stack |
128 | * \code |
129 | * stack[sp + 1] = stack[sp + code[pc + 1].v_int]; |
130 | * pc = pc + 2; |
131 | * sp = sp + 1; |
132 | * \endcode |
133 | */ |
134 | PUSH_VALUE, |
135 | /*! |
136 | * \brief Load data from heap to top of stack |
137 | * \code |
138 | * stack[sp + 1] = heap[code[pc + 1].v_int]; |
139 | * pc = pc + 2; |
140 | * sp = sp + 1; |
141 | * \endcode |
142 | */ |
143 | LOAD_HEAP, |
144 | /*! |
145 | * \brief Store data to heap |
146 | * \code |
147 | * heap[code[pc + 1].v_int] = stack[sp]; |
148 | * sp = sp - 1; |
149 | * \endcode |
150 | */ |
151 | STORE_HEAP, |
152 | /*! \brief pop value from top of the stack */ |
153 | POP, |
154 | /*! |
155 | * \brief select based on operands. |
156 | * \code |
157 | * stack[sp - 2] = stack[sp].v_int64 ? stack[sp - 2] : stack[sp - 1] |
158 | * sp = sp - 2; |
159 | * \endcode |
160 | */ |
161 | SELECT, |
162 | /*! |
163 | * \brief Assert condition is true. |
164 | * \code |
165 | * ICHECK(stack[sp]) << str_data[code[pc + 1].v_int]; |
166 | * sp = sp - 1; |
167 | * \endcode |
168 | */ |
169 | ASSERT, |
170 | /*! |
171 | * \brief Relative Jump if the condition is true, |
172 | * Does not change the stack status. |
173 | * \code |
174 | * if (stack[sp]) { |
175 | * pc += code[pc + 1].v_int |
176 | * } else { |
177 | * pc = pc + 2; |
178 | * } |
179 | * \endcode |
180 | */ |
181 | RJUMP_IF_TRUE, |
182 | /*! |
183 | * \brief Relative Jump if the condition is true, |
184 | * Does not change the stack status. |
185 | * \code |
186 | * if (stack[sp]) { |
187 | * pc += code[pc + 1].v_int |
188 | * } else { |
189 | * pc = pc + 2; |
190 | * } |
191 | * \endcode |
192 | */ |
193 | RJUMP_IF_FALSE, |
194 | /*! |
195 | * \brief Relative jump to a location. |
196 | * \code |
197 | * pc += code[pc + 1].v_int; |
198 | * \endcode |
199 | */ |
200 | RJUMP, |
201 | /*! |
202 | * \brief debug instruction. |
203 | * \code |
204 | * ICHECK_EQ(sp, code[pc + 1]).v_int; |
205 | * pc += 2; |
206 | * \code |
207 | */ |
208 | ASSERT_SP, |
209 | /*! |
210 | * \brief call an extern packed function |
211 | * \code |
212 | * value_stack = stack[sp - 1].v_handle; |
213 | * type_stack = stack[sp - 0].v_handle; |
214 | * call_fid = code[pc + 1].v_int; |
215 | * begin = code[pc + 2].v_int; |
216 | * end = code[pc + 3].v_int; |
217 | * num_args = end - begin - 1; |
218 | * f = extern_func[call_fid]; |
219 | * stack[sp - 1] = f(&value_stack[begin:end-1], type_stack[begin:end-1], num_args); |
220 | * sp = sp - 1; |
221 | * // The type codes are hidden in the code space. |
222 | * pc = pc + 4 |
223 | * \endcode |
224 | */ |
225 | CALL_PACKED_LOWERED, |
226 | // Allocate things on stack |
227 | /*! |
228 | * \brief allocate data from stack. |
229 | * \code |
230 | * num = code[pc + 1].v_int; |
231 | * void* addr = &stack[sp]; |
232 | * sp = sp + num; |
233 | * stack[sp].v_handle = addr; |
234 | * pc = pc + 1; |
235 | * \endcode |
236 | */ |
237 | TVM_STACK_ALLOCA_BY_8BYTE, |
238 | /*! |
239 | * \brief allocate data from device. |
240 | * \code |
241 | * device_type = stack[sp - 2].v_int64; |
242 | * device_id = stack[sp - 1].v_int64; |
243 | * nbytes = stack[sp].v_int64; |
244 | * stack[sp - 2].v_handle = device_alloca(device_type, device_id, nbytes); |
245 | * sp = sp - 2; |
246 | * pc = pc + 1; |
247 | * \endcode |
248 | */ |
249 | TVM_DEVICE_ALLOCA, |
250 | /*! |
251 | * \brief free data into device. |
252 | * \code |
253 | * device_type = stack[sp - 2].v_int64; |
254 | * device_id = stack[sp - 1].v_int64; |
255 | * ptr = stack[sp].v_handle; |
256 | * stack[sp - 2].v_int64 = device_free(device_type, device_id, ptr); |
257 | * sp = sp - 2; |
258 | * pc = pc + 1; |
259 | * \endcode |
260 | */ |
261 | TVM_DEVICE_FREE, |
262 | /*! |
263 | * \brief throw last error |
264 | */ |
265 | TVM_THROW_LAST_ERROR, |
266 | /*! |
267 | * \brief get data from structure. |
268 | * \code |
269 | * index = code[pc + 1].v_int; |
270 | * field = code[pc + 2].v_int; |
271 | * stack[sp] = ((StructType*)stack[sp].v_handle)[index]->field; |
272 | * pc = pc + 3 |
273 | * \endcode |
274 | */ |
275 | TVM_STRUCT_GET, |
276 | /*! |
277 | * \brief set data into structure. |
278 | * \code |
279 | * index = code[pc + 1].v_int; |
280 | * field = code[pc + 2].v_int; |
281 | * ((StructType*)stack[sp - 1].v_handle)[index]->field = stack[sp]; |
282 | * pc = pc + 3 |
283 | * sp = sp - 1 |
284 | * \endcode |
285 | */ |
286 | TVM_STRUCT_SET |
287 | }; |
288 | /*! \brief The kind of structure field info */ |
289 | enum StructFieldKind : int { |
290 | // array head address |
291 | kArrAddr, |
292 | kArrData, |
293 | kArrShape, |
294 | kArrStrides, |
295 | kArrNDim, |
296 | kArrTypeCode, |
297 | kArrTypeBits, |
298 | kArrTypeLanes, |
299 | kArrByteOffset, |
300 | kArrDeviceId, |
301 | kArrDeviceType, |
302 | kArrKindBound_, |
303 | // TVMValue field |
304 | kTVMValueContent, |
305 | kTVMValueKindBound_ |
306 | }; |
307 | /*! \brief The code structure */ |
308 | union Code { |
309 | OpCode op_code; |
310 | int v_int; |
311 | }; |
312 | /*! \brief The state object of StackVM */ |
313 | struct State { |
314 | /*! \brief The execution stack */ |
315 | std::vector<TVMValue> stack; |
316 | /*! \brief The global heap space */ |
317 | std::vector<TVMValue> heap; |
318 | /*! \brief stack pointer */ |
319 | int64_t sp{0}; |
320 | /*! \brief program counter */ |
321 | int64_t pc{0}; |
322 | /*! \brief The current module context of stackvm */ |
323 | runtime::ModuleNode* mod_ctx{nullptr}; |
324 | }; |
325 | /*! \brief Initialize local cache*/ |
326 | void InitCache(); |
327 | /*! |
328 | * \brief Save stackvm program to an output stream |
329 | * \param strm The output stream |
330 | */ |
331 | void Save(dmlc::Stream* strm) const; |
332 | /*! |
333 | * \brief Load stackvm program from output stream |
334 | * \param strm The output stream |
335 | */ |
336 | bool Load(dmlc::Stream* strm); |
337 | /*! |
338 | * \brief Print instruction at location pc |
339 | * \param os The ostream |
340 | * \param pc The pc |
341 | * \return the pc to next instruction. |
342 | */ |
343 | int64_t PrintCode(std::ostream& os, int64_t pc) const; // NOLINT(*) |
344 | /*! \brief Get thread local state of the stack VM */ |
345 | static State* ThreadLocalState(); |
346 | // The code below are programs |
347 | /*! \brief The instructions */ |
348 | std::vector<Code> code; |
349 | /*! \brief constant error messages */ |
350 | std::vector<std::string> str_data; |
351 | /*! \brief Extern functions */ |
352 | std::vector<std::string> extern_func_name; |
353 | /*! \brief name of each heap id */ |
354 | std::vector<std::string> heap_id_name; |
355 | /*! \brief The memory size needed */ |
356 | size_t heap_size{0}; |
357 | /*! \brief The stack size required */ |
358 | size_t stack_size{1024}; |
359 | /*! |
360 | * \brief Convert I64 opcode to F64 Ones |
361 | * \param code The op code. |
362 | * \return the F64 op code. |
363 | */ |
364 | static OpCode CodeI64ToF64(OpCode code) { |
365 | switch (code) { |
366 | case ADD_I64: |
367 | return ADD_F64; |
368 | case SUB_I64: |
369 | return SUB_F64; |
370 | case MUL_I64: |
371 | return MUL_F64; |
372 | case DIV_I64: |
373 | return DIV_F64; |
374 | case EQ_I64: |
375 | return EQ_F64; |
376 | case LT_I64: |
377 | return LT_F64; |
378 | case LE_I64: |
379 | return LE_F64; |
380 | case MOD_I64: |
381 | LOG(FATAL) << "cannot handle mod for float" ; |
382 | default: |
383 | LOG(FATAL) << "cannot handle op " << code; |
384 | } |
385 | } |
386 | /*! |
387 | * \brief Get load opcode for type t |
388 | * \param t the type code. |
389 | * \return The load opcode |
390 | */ |
391 | static OpCode GetLoad(DLDataType t) { |
392 | ICHECK_EQ(t.lanes, 1U); |
393 | if (t.code == kTVMOpaqueHandle) return ARRAY_LOAD_HANDLE; |
394 | if (t.code == kDLInt) { |
395 | switch (t.bits) { |
396 | case 32: |
397 | return ARRAY_LOAD_INT32; |
398 | case 64: |
399 | return ARRAY_LOAD_INT64; |
400 | } |
401 | } else if (t.code == kDLUInt) { |
402 | switch (t.bits) { |
403 | case 32: |
404 | return ARRAY_LOAD_UINT32; |
405 | } |
406 | } else if (t.code == kDLFloat) { |
407 | switch (t.bits) { |
408 | case 64: |
409 | return ARRAY_LOAD_FP64; |
410 | } |
411 | } |
412 | LOG(FATAL) << "Cannot load type " << t; |
413 | } |
414 | /*! |
415 | * \brief Get store opcode for type t |
416 | * \param t the type code. |
417 | * \return The load opcode |
418 | */ |
419 | static OpCode GetStore(DLDataType t) { |
420 | ICHECK_EQ(t.lanes, 1U); |
421 | if (t.code == kTVMOpaqueHandle) return ARRAY_STORE_HANDLE; |
422 | if (t.code == kDLInt) { |
423 | switch (t.bits) { |
424 | case 32: |
425 | return ARRAY_STORE_INT32; |
426 | case 64: |
427 | return ARRAY_STORE_INT64; |
428 | } |
429 | } else if (t.code == kDLUInt) { |
430 | switch (t.bits) { |
431 | case 32: |
432 | return ARRAY_STORE_UINT32; |
433 | } |
434 | } else if (t.code == kDLFloat) { |
435 | switch (t.bits) { |
436 | case 64: |
437 | return ARRAY_STORE_FP64; |
438 | } |
439 | } |
440 | LOG(FATAL) << "Cannot store type " << t; |
441 | } |
442 | friend std::ostream& operator<<(std::ostream& os, const StackVM& vm); // NOLINT(*) |
443 | |
444 | private: |
445 | // execute the stack vm with given state |
446 | void Run(State* state) const; |
447 | // get extern function. |
448 | const PackedFunc& GetExtern(State* s, int fid) const; |
449 | // cached extern function |
450 | mutable std::vector<PackedFunc> extern_func_cache_; |
451 | }; |
452 | |
453 | } // namespace runtime |
454 | } // namespace tvm |
455 | |
namespace dmlc {
// Declare the dmlc `has_saveload` trait as true for StackVM, which exposes
// Save(dmlc::Stream*) and Load(dmlc::Stream*) members.
// NOTE(review): presumably this lets dmlc serialization dispatch to those
// members - confirm against the dmlc-core serializer.
DMLC_DECLARE_TRAITS(has_saveload, ::tvm::runtime::StackVM, true);
}  // namespace dmlc
459 | #endif // TVM_RUNTIME_STACKVM_STACKVM_H_ |
460 | |