1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file stackvm.h
22 * \brief A simple stack-based virtual machine.
23 *
24 * This can be used to interpret host side code
25 * to setup calls into device functions
26 * when only runtime compilation for device is available (via NVRTC or OpenCL).
27 */
28#ifndef TVM_RUNTIME_STACKVM_STACKVM_H_
29#define TVM_RUNTIME_STACKVM_STACKVM_H_
30
31#include <tvm/runtime/c_runtime_api.h>
32#include <tvm/runtime/module.h>
33#include <tvm/runtime/packed_func.h>
34
35#include <string>
36#include <vector>
37
38namespace tvm {
39namespace runtime {
40
41using runtime::operator<<;
42
43/*!
44 * \brief A simple stack-based virtual machine program.
45 */
46class StackVM {
47 public:
48 /*!
49 * \brief Invoke the StackVM program.
50 * \param args The arguments to the StackVM.
51 * \param mod_ctx The module context used in running.
52 */
53 void Run(const TVMArgs& args, runtime::ModuleNode* mod_ctx) const;
54 /*!
55 * \brief The opcode of stack vm
56 * \note Notation
57 * - sp Stack pointer
58 * - pc Program pointer
59 */
60 enum OpCode {
61 // integer ops
62 ADD_I64,
63 SUB_I64,
64 MUL_I64,
65 DIV_I64,
66 MOD_I64,
67 EQ_I64,
68 LT_I64,
69 LE_I64,
70 // floating ops (note: there is no F64 counterpart of MOD_I64)
71 ADD_F64,
72 SUB_F64,
73 MUL_F64,
74 DIV_F64,
75 EQ_F64,
76 LT_F64,
77 LE_F64,
78 // Pointer comparison
79 EQ_HANDLE,
80 /*!
81 * \brief Routine to load data from address with const offset.
82 * \code
83 * stack[sp].v_int64 = ((DType*)stack[sp].v_handle)[code[pc + 1].v_int];
84 * pc = pc + 2;
85 * \endcode
86 */
87 ARRAY_LOAD_UINT32,
88 ARRAY_LOAD_INT32,
89 ARRAY_LOAD_INT64,
90 ARRAY_LOAD_FP64,
91 ARRAY_LOAD_HANDLE,
92 ARRAY_LOAD_TVMVALUE,
93 /*!
94 * \brief Routine to store data to address with const offset.
95 * \code
96 * ((DType*)stack[sp - 1].v_handle)[code[pc + 1].v_int] = stack[sp];
97 * pc = pc + 2;
98 * sp = sp - 2;
99 * \endcode
100 */
101 ARRAY_STORE_UINT32,
102 ARRAY_STORE_INT32,
103 ARRAY_STORE_INT64,
104 ARRAY_STORE_FP64,
105 ARRAY_STORE_HANDLE,
106 ARRAY_STORE_TVMVALUE,
107 // logical ops
108 NOT,
109 /*!
110 * \brief Add address by an offset.
111 * \code
112 * stack[sp - 1].v_handle = ((char*)stack[sp - 1].v_handle + stack[sp].v_int64);
113 * sp = sp - 1;
114 * \endcode
115 */
116 ADDR_ADD,
117 /*!
118 * \brief push integer fetched from next pc position into stack
119 * \code
120 * stack[sp + 1].v_int64 = code[pc + 1].v_int;
121 * pc = pc + 2;
122 * sp = sp + 1;
123 * \endcode
124 */
125 PUSH_I64,
126 /*!
127 * \brief push a value given relative index on the stack
128 * \code
129 * stack[sp + 1] = stack[sp + code[pc + 1].v_int];
130 * pc = pc + 2;
131 * sp = sp + 1;
132 * \endcode
133 */
134 PUSH_VALUE,
135 /*!
136 * \brief Load data from heap to top of stack
137 * \code
138 * stack[sp + 1] = heap[code[pc + 1].v_int];
139 * pc = pc + 2;
140 * sp = sp + 1;
141 * \endcode
142 */
143 LOAD_HEAP,
144 /*!
145 * \brief Store data to heap
146 * \code
147 * heap[code[pc + 1].v_int] = stack[sp];
148 * sp = sp - 1;
149 * \endcode
150 */
151 STORE_HEAP,
152 /*! \brief pop value from top of the stack */
153 POP,
154 /*!
155 * \brief select based on operands.
156 * \code
157 * stack[sp - 2] = stack[sp].v_int64 ? stack[sp - 2] : stack[sp - 1]
158 * sp = sp - 2;
159 * \endcode
160 */
161 SELECT,
162 /*!
163 * \brief Assert condition is true.
164 * \code
165 * ICHECK(stack[sp]) << str_data[code[pc + 1].v_int];
166 * sp = sp - 1;
167 * \endcode
168 */
169 ASSERT,
170 /*!
171 * \brief Relative Jump if the condition is true,
172 * Does not change the stack status.
173 * \code
174 * if (stack[sp]) {
175 * pc += code[pc + 1].v_int
176 * } else {
177 * pc = pc + 2;
178 * }
179 * \endcode
180 */
181 RJUMP_IF_TRUE,
182 /*!
183 * \brief Relative Jump if the condition is false,
184 * Does not change the stack status.
185 * \code
186 * if (!stack[sp]) {
187 * pc += code[pc + 1].v_int
188 * } else {
189 * pc = pc + 2;
190 * }
191 * \endcode
192 */
193 RJUMP_IF_FALSE,
194 /*!
195 * \brief Relative jump to a location.
196 * \code
197 * pc += code[pc + 1].v_int;
198 * \endcode
199 */
200 RJUMP,
201 /*!
202 * \brief debug instruction, checks that the stack pointer has the expected value.
203 * \code
204 * ICHECK_EQ(sp, code[pc + 1].v_int);
205 * pc += 2;
206 * \endcode
207 */
208 ASSERT_SP,
209 /*!
210 * \brief call an extern packed function
211 * \code
212 * value_stack = stack[sp - 1].v_handle;
213 * type_stack = stack[sp - 0].v_handle;
214 * call_fid = code[pc + 1].v_int;
215 * begin = code[pc + 2].v_int;
216 * end = code[pc + 3].v_int;
217 * num_args = end - begin - 1;
218 * f = extern_func[call_fid];
219 * stack[sp - 1] = f(&value_stack[begin:end-1], type_stack[begin:end-1], num_args);
220 * sp = sp - 1;
221 * // The type codes are hidden in the code space.
222 * pc = pc + 4
223 * \endcode
224 */
225 CALL_PACKED_LOWERED,
226 // Allocate things on stack
227 /*!
228 * \brief allocate data from stack.
229 * \code
230 * num = code[pc + 1].v_int;
231 * void* addr = &stack[sp];
232 * sp = sp + num;
233 * stack[sp].v_handle = addr;
234 * pc = pc + 1;
235 * \endcode
236 */
237 TVM_STACK_ALLOCA_BY_8BYTE,
238 /*!
239 * \brief allocate data from device.
240 * \code
241 * device_type = stack[sp - 2].v_int64;
242 * device_id = stack[sp - 1].v_int64;
243 * nbytes = stack[sp].v_int64;
244 * stack[sp - 2].v_handle = device_alloca(device_type, device_id, nbytes);
245 * sp = sp - 2;
246 * pc = pc + 1;
247 * \endcode
248 */
249 TVM_DEVICE_ALLOCA,
250 /*!
251 * \brief free data on device.
252 * \code
253 * device_type = stack[sp - 2].v_int64;
254 * device_id = stack[sp - 1].v_int64;
255 * ptr = stack[sp].v_handle;
256 * stack[sp - 2].v_int64 = device_free(device_type, device_id, ptr);
257 * sp = sp - 2;
258 * pc = pc + 1;
259 * \endcode
260 */
261 TVM_DEVICE_FREE,
262 /*!
263 * \brief throw last error
264 */
265 TVM_THROW_LAST_ERROR,
266 /*!
267 * \brief get data from structure.
268 * \code
269 * index = code[pc + 1].v_int;
270 * field = code[pc + 2].v_int;
271 * stack[sp] = ((StructType*)stack[sp].v_handle)[index]->field;
272 * pc = pc + 3
273 * \endcode
274 */
275 TVM_STRUCT_GET,
276 /*!
277 * \brief set data into structure.
278 * \code
279 * index = code[pc + 1].v_int;
280 * field = code[pc + 2].v_int;
281 * ((StructType*)stack[sp - 1].v_handle)[index]->field = stack[sp];
282 * pc = pc + 3
283 * sp = sp - 1
284 * \endcode
285 */
286 TVM_STRUCT_SET
287 };
288 /*! \brief The kind of structure field info */
289 enum StructFieldKind : int {
290 // array head address (first of the kArr* field kinds)
291 kArrAddr,
292 kArrData,
293 kArrShape,
294 kArrStrides,
295 kArrNDim,
296 kArrTypeCode,
297 kArrTypeBits,
298 kArrTypeLanes,
299 kArrByteOffset,
300 kArrDeviceId,
301 kArrDeviceType,
302 kArrKindBound_,  // NOTE(review): presumably an end marker for the kArr* kinds, not a real field - confirm
303 // TVMValue field
304 kTVMValueContent,
305 kTVMValueKindBound_  // NOTE(review): presumably an end marker for the TVMValue kinds - confirm
306 };
307 /*! \brief The code structure */
308 union Code {
309 OpCode op_code;  // the slot viewed as an instruction opcode
310 int v_int;  // the slot viewed as an integer operand (the OpCode notes read operands as code[pc + k].v_int)
311 };
312 /*! \brief The state object of StackVM */
313 struct State {
314 /*! \brief The execution stack */
315 std::vector<TVMValue> stack;
316 /*! \brief The global heap space */
317 std::vector<TVMValue> heap;
318 /*! \brief stack pointer */
319 int64_t sp{0};
320 /*! \brief program counter */
321 int64_t pc{0};
322 /*! \brief The current module context of stackvm */
323 runtime::ModuleNode* mod_ctx{nullptr};
324 };
325 /*! \brief Initialize local cache*/
326 void InitCache();
327 /*!
328 * \brief Save stackvm program to an output stream
329 * \param strm The output stream
330 */
331 void Save(dmlc::Stream* strm) const;
332 /*!
333 * \brief Load stackvm program from an input stream
334 * \param strm The input stream
335 */
336 bool Load(dmlc::Stream* strm);
337 /*!
338 * \brief Print instruction at location pc
339 * \param os The ostream
340 * \param pc The pc
341 * \return the pc to next instruction.
342 */
343 int64_t PrintCode(std::ostream& os, int64_t pc) const; // NOLINT(*)
344 /*! \brief Get thread local state of the stack VM */
345 static State* ThreadLocalState();
346 // The code below are programs
347 /*! \brief The instructions */
348 std::vector<Code> code;
349 /*! \brief constant error messages */
350 std::vector<std::string> str_data;
351 /*! \brief Extern functions */
352 std::vector<std::string> extern_func_name;
353 /*! \brief name of each heap id */
354 std::vector<std::string> heap_id_name;
355 /*! \brief The memory size needed */
356 size_t heap_size{0};
357 /*! \brief The stack size required */
358 size_t stack_size{1024};
359 /*!
360 * \brief Convert I64 opcode to F64 Ones
361 * \param code The op code.
362 * \return the F64 op code.
363 */
364 static OpCode CodeI64ToF64(OpCode code) {
365 switch (code) {
366 case ADD_I64:
367 return ADD_F64;
368 case SUB_I64:
369 return SUB_F64;
370 case MUL_I64:
371 return MUL_F64;
372 case DIV_I64:
373 return DIV_F64;
374 case EQ_I64:
375 return EQ_F64;
376 case LT_I64:
377 return LT_F64;
378 case LE_I64:
379 return LE_F64;
380 case MOD_I64:
381 LOG(FATAL) << "cannot handle mod for float";
382 default:
383 LOG(FATAL) << "cannot handle op " << code;
384 }
385 }
386 /*!
387 * \brief Get load opcode for type t
388 * \param t the type code.
389 * \return The load opcode
390 */
391 static OpCode GetLoad(DLDataType t) {
392 ICHECK_EQ(t.lanes, 1U);
393 if (t.code == kTVMOpaqueHandle) return ARRAY_LOAD_HANDLE;
394 if (t.code == kDLInt) {
395 switch (t.bits) {
396 case 32:
397 return ARRAY_LOAD_INT32;
398 case 64:
399 return ARRAY_LOAD_INT64;
400 }
401 } else if (t.code == kDLUInt) {
402 switch (t.bits) {
403 case 32:
404 return ARRAY_LOAD_UINT32;
405 }
406 } else if (t.code == kDLFloat) {
407 switch (t.bits) {
408 case 64:
409 return ARRAY_LOAD_FP64;
410 }
411 }
412 LOG(FATAL) << "Cannot load type " << t;
413 }
414 /*!
415 * \brief Get store opcode for type t
416 * \param t the data type.
417 * \return The store opcode
418 */
419 static OpCode GetStore(DLDataType t) {
420 ICHECK_EQ(t.lanes, 1U);
421 if (t.code == kTVMOpaqueHandle) return ARRAY_STORE_HANDLE;
422 if (t.code == kDLInt) {
423 switch (t.bits) {
424 case 32:
425 return ARRAY_STORE_INT32;
426 case 64:
427 return ARRAY_STORE_INT64;
428 }
429 } else if (t.code == kDLUInt) {
430 switch (t.bits) {
431 case 32:
432 return ARRAY_STORE_UINT32;
433 }
434 } else if (t.code == kDLFloat) {
435 switch (t.bits) {
436 case 64:
437 return ARRAY_STORE_FP64;
438 }
439 }
440 LOG(FATAL) << "Cannot store type " << t;
441 }
442 friend std::ostream& operator<<(std::ostream& os, const StackVM& vm); // NOLINT(*)
443
444 private:
445 // execute the stack vm with given state
446 void Run(State* state) const;
447 // get extern function.
448 const PackedFunc& GetExtern(State* s, int fid) const;
449 // cached extern function
450 mutable std::vector<PackedFunc> extern_func_cache_;
451};
452
453} // namespace runtime
454} // namespace tvm
455
456namespace dmlc {
457DMLC_DECLARE_TRAITS(has_saveload, ::tvm::runtime::StackVM, true);
458}
459#endif // TVM_RUNTIME_STACKVM_STACKVM_H_
460