1 | #pragma once |
2 | |
#include <functional>
#include <memory>
#include <string>
#include <tuple>
#include <utility>
#include <vector>
6 | |
7 | #include "taichi/inc/constants.h" |
8 | #include "taichi/util/lang_util.h" |
9 | #include "taichi/program/kernel_profiler.h" |
10 | |
11 | namespace taichi::lang { |
12 | |
// An architecture-specific JIT module that is initialized with an **LLVM**
// module and lets the user call the functions it contains.
// TODO: should we generalize this to include the Metal and OpenGL backends as
// well?
17 | |
18 | class JITModule { |
19 | public: |
20 | JITModule() { |
21 | } |
22 | |
23 | // Lookup a serial function. |
24 | // For example, a CPU function, or a serial GPU function |
25 | // This function returns a function pointer |
26 | virtual void *lookup_function(const std::string &name) = 0; |
27 | |
28 | // Unfortunately, this can't be virtual since it's a template function |
29 | template <typename... Args> |
30 | std::function<void(Args...)> get_function(const std::string &name) { |
31 | using FuncT = typename std::function<void(Args...)>; |
32 | auto ret = FuncT((function_pointer_type<FuncT>)lookup_function(name)); |
33 | TI_ASSERT(ret != nullptr); |
34 | return ret; |
35 | } |
36 | |
37 | inline std::tuple<std::vector<void *>, std::vector<int> > get_arg_pointers() { |
38 | return std::make_tuple(std::vector<void *>(), std::vector<int>()); |
39 | } |
40 | |
41 | template <typename... Args, typename T> |
42 | inline std::tuple<std::vector<void *>, std::vector<int> > get_arg_pointers( |
43 | T &t, |
44 | Args &...args) { |
45 | auto [arg_pointers, arg_sizes] = get_arg_pointers(args...); |
46 | arg_pointers.insert(arg_pointers.begin(), &t); |
47 | arg_sizes.insert(arg_sizes.begin(), sizeof(t)); |
48 | return std::make_tuple(arg_pointers, arg_sizes); |
49 | } |
50 | |
51 | // Note: **call** is for serial functions |
52 | // Note: args must pass by value |
53 | // Note: AMDGPU need to pass args by extra_arg currently |
54 | template <typename... Args> |
55 | void call(const std::string &name, Args... args) { |
56 | if (direct_dispatch()) { |
57 | get_function<Args...>(name)(args...); |
58 | } else { |
59 | auto [arg_pointers, arg_sizes] = JITModule::get_arg_pointers(args...); |
60 | call(name, arg_pointers, arg_sizes); |
61 | } |
62 | } |
63 | |
64 | virtual void call(const std::string &name, |
65 | const std::vector<void *> &arg_pointers, |
66 | const std::vector<int> &arg_sizes) { |
67 | TI_NOT_IMPLEMENTED |
68 | } |
69 | |
70 | // Note: **launch** is for parallel (GPU)_kernels |
71 | // Note: args must pass by value |
72 | template <typename... Args> |
73 | void launch(const std::string &name, |
74 | std::size_t grid_dim, |
75 | std::size_t block_dim, |
76 | std::size_t shared_mem_bytes, |
77 | Args... args) { |
78 | auto [arg_pointers, arg_sizes] = JITModule::get_arg_pointers(args...); |
79 | launch(name, grid_dim, block_dim, shared_mem_bytes, arg_pointers, |
80 | arg_sizes); |
81 | } |
82 | |
83 | virtual void launch(const std::string &name, |
84 | std::size_t grid_dim, |
85 | std::size_t block_dim, |
86 | std::size_t shared_mem_bytes, |
87 | const std::vector<void *> &arg_pointers, |
88 | const std::vector<int> &arg_sizes) { |
89 | TI_NOT_IMPLEMENTED |
90 | } |
91 | |
92 | virtual bool direct_dispatch() const = 0; |
93 | |
94 | virtual ~JITModule() { |
95 | } |
96 | }; |
97 | |
98 | } // namespace taichi::lang |
99 | |