1#pragma once
2
3#include <memory>
4#include <functional>
5#include <tuple>
6
7#include "taichi/inc/constants.h"
8#include "taichi/util/lang_util.h"
9#include "taichi/program/kernel_profiler.h"
10
11namespace taichi::lang {
12
// An architecture-specific JIT module that is initialized with an **LLVM**
// module and allows the user to call its functions.
// TODO: should we generalize this to include the Metal and OpenGL backends as
// well?
17
18class JITModule {
19 public:
20 JITModule() {
21 }
22
23 // Lookup a serial function.
24 // For example, a CPU function, or a serial GPU function
25 // This function returns a function pointer
26 virtual void *lookup_function(const std::string &name) = 0;
27
28 // Unfortunately, this can't be virtual since it's a template function
29 template <typename... Args>
30 std::function<void(Args...)> get_function(const std::string &name) {
31 using FuncT = typename std::function<void(Args...)>;
32 auto ret = FuncT((function_pointer_type<FuncT>)lookup_function(name));
33 TI_ASSERT(ret != nullptr);
34 return ret;
35 }
36
37 inline std::tuple<std::vector<void *>, std::vector<int> > get_arg_pointers() {
38 return std::make_tuple(std::vector<void *>(), std::vector<int>());
39 }
40
41 template <typename... Args, typename T>
42 inline std::tuple<std::vector<void *>, std::vector<int> > get_arg_pointers(
43 T &t,
44 Args &...args) {
45 auto [arg_pointers, arg_sizes] = get_arg_pointers(args...);
46 arg_pointers.insert(arg_pointers.begin(), &t);
47 arg_sizes.insert(arg_sizes.begin(), sizeof(t));
48 return std::make_tuple(arg_pointers, arg_sizes);
49 }
50
51 // Note: **call** is for serial functions
52 // Note: args must pass by value
53 // Note: AMDGPU need to pass args by extra_arg currently
54 template <typename... Args>
55 void call(const std::string &name, Args... args) {
56 if (direct_dispatch()) {
57 get_function<Args...>(name)(args...);
58 } else {
59 auto [arg_pointers, arg_sizes] = JITModule::get_arg_pointers(args...);
60 call(name, arg_pointers, arg_sizes);
61 }
62 }
63
64 virtual void call(const std::string &name,
65 const std::vector<void *> &arg_pointers,
66 const std::vector<int> &arg_sizes) {
67 TI_NOT_IMPLEMENTED
68 }
69
70 // Note: **launch** is for parallel (GPU)_kernels
71 // Note: args must pass by value
72 template <typename... Args>
73 void launch(const std::string &name,
74 std::size_t grid_dim,
75 std::size_t block_dim,
76 std::size_t shared_mem_bytes,
77 Args... args) {
78 auto [arg_pointers, arg_sizes] = JITModule::get_arg_pointers(args...);
79 launch(name, grid_dim, block_dim, shared_mem_bytes, arg_pointers,
80 arg_sizes);
81 }
82
83 virtual void launch(const std::string &name,
84 std::size_t grid_dim,
85 std::size_t block_dim,
86 std::size_t shared_mem_bytes,
87 const std::vector<void *> &arg_pointers,
88 const std::vector<int> &arg_sizes) {
89 TI_NOT_IMPLEMENTED
90 }
91
92 virtual bool direct_dispatch() const = 0;
93
94 virtual ~JITModule() {
95 }
96};
97
98} // namespace taichi::lang
99