1 | #pragma once |
2 | |
3 | #include "taichi/util/lang_util.h" |
4 | #include "taichi/ir/snode.h" |
5 | #include "taichi/ir/ir.h" |
6 | #include "taichi/rhi/arch.h" |
7 | #include "taichi/program/callable.h" |
8 | #include "taichi/program/ndarray.h" |
9 | #include "taichi/program/texture.h" |
10 | #include "taichi/aot/graph_data.h" |
11 | |
12 | namespace taichi::lang { |
13 | |
14 | class Program; |
15 | |
16 | class TI_DLL_EXPORT Kernel : public Callable { |
17 | public: |
18 | std::string name; |
19 | std::vector<SNode *> no_activate; |
20 | |
21 | bool is_accessor{false}; |
22 | bool is_evaluator{false}; |
23 | AutodiffMode autodiff_mode{AutodiffMode::kNone}; |
24 | |
25 | class LaunchContextBuilder { |
26 | public: |
27 | LaunchContextBuilder(Kernel *kernel, RuntimeContext *ctx); |
28 | explicit LaunchContextBuilder(Kernel *kernel); |
29 | |
30 | LaunchContextBuilder(LaunchContextBuilder &&) = default; |
31 | LaunchContextBuilder &operator=(LaunchContextBuilder &&) = default; |
32 | LaunchContextBuilder(const LaunchContextBuilder &) = delete; |
33 | LaunchContextBuilder &operator=(const LaunchContextBuilder &) = delete; |
34 | |
35 | void set_arg_float(int arg_id, float64 d); |
36 | |
37 | // Created signed and unsigned version for argument range check of pybind |
38 | void set_arg_int(int arg_id, int64 d); |
39 | void set_arg_uint(int arg_id, uint64 d); |
40 | |
41 | void (int i, int j, int32 d); |
42 | |
43 | void set_arg_external_array_with_shape(int arg_id, |
44 | uintptr_t ptr, |
45 | uint64 size, |
46 | const std::vector<int64> &shape); |
47 | |
48 | void set_arg_ndarray(int arg_id, const Ndarray &arr); |
49 | void set_arg_ndarray_with_grad(int arg_id, |
50 | const Ndarray &arr, |
51 | const Ndarray &arr_grad); |
52 | |
53 | void set_arg_texture(int arg_id, const Texture &tex); |
54 | void set_arg_rw_texture(int arg_id, const Texture &tex); |
55 | |
56 | // Sets the |arg_id|-th arg in the context to the bits stored in |d|. |
57 | // This ignores the underlying kernel's |arg_id|-th arg type. |
58 | void set_arg_raw(int arg_id, uint64 d); |
59 | |
60 | RuntimeContext &get_context(); |
61 | |
62 | private: |
63 | Kernel *kernel_; |
64 | std::unique_ptr<RuntimeContext> owned_ctx_; |
65 | // |ctx_| *almost* always points to |owned_ctx_|. However, it is possible |
66 | // that the caller passes a RuntimeContext pointer externally. In that case, |
67 | // |owned_ctx_| will be nullptr. |
68 | // Invariant: |ctx_| will never be nullptr. |
69 | RuntimeContext *ctx_; |
70 | }; |
71 | |
72 | Kernel(Program &program, |
73 | const std::function<void()> &func, |
74 | const std::string &name = "" , |
75 | AutodiffMode autodiff_mode = AutodiffMode::kNone); |
76 | |
77 | Kernel(Program &program, |
78 | const std::function<void(Kernel *)> &func, |
79 | const std::string &name = "" , |
80 | AutodiffMode autodiff_mode = AutodiffMode::kNone); |
81 | |
82 | Kernel(Program &program, |
83 | std::unique_ptr<IRNode> &&ir, |
84 | const std::string &name = "" , |
85 | AutodiffMode autodiff_mode = AutodiffMode::kNone); |
86 | |
87 | bool ir_is_ast() const { |
88 | return ir_is_ast_; |
89 | } |
90 | |
91 | bool lowered() const { |
92 | return lowered_; |
93 | } |
94 | |
95 | void set_lowered(bool lowered) { |
96 | lowered_ = lowered; |
97 | } |
98 | |
99 | void compile(const CompileConfig &compile_config); |
100 | |
101 | void operator()(const CompileConfig &compile_config, |
102 | LaunchContextBuilder &ctx_builder); |
103 | |
104 | LaunchContextBuilder make_launch_context(); |
105 | |
106 | template <typename T> |
107 | T fetch_ret(DataType dt, int i); |
108 | |
109 | float64 get_ret_float(int i); |
110 | int64 get_ret_int(int i); |
111 | uint64 get_ret_uint(int i); |
112 | std::vector<int64> get_ret_int_tensor(int i); |
113 | std::vector<uint64> get_ret_uint_tensor(int i); |
114 | std::vector<float64> get_ret_float_tensor(int i); |
115 | |
116 | TypedConstant fetch_ret(const std::vector<int> &index); |
117 | |
118 | float64 get_struct_ret_float(const std::vector<int> &index); |
119 | int64 get_struct_ret_int(const std::vector<int> &index); |
120 | uint64 get_struct_ret_uint(const std::vector<int> &index); |
121 | |
122 | uint64 get_next_task_id() { |
123 | return task_counter_++; |
124 | } |
125 | |
126 | [[nodiscard]] std::string get_name() const override; |
127 | |
128 | void set_kernel_key_for_cache(const std::string &kernel_key) { |
129 | kernel_key_ = kernel_key; |
130 | } |
131 | |
132 | const std::string &get_cached_kernel_key() { |
133 | return kernel_key_; |
134 | } |
135 | |
136 | private: |
137 | void init(Program &program, |
138 | const std::function<void()> &func, |
139 | const std::string &name = "" , |
140 | AutodiffMode autodiff_mode = AutodiffMode::kNone); |
141 | |
142 | // True if |ir| is a frontend AST. False if it's already offloaded to CHI IR. |
143 | bool ir_is_ast_{false}; |
144 | // The closure that, if invoked, launches the backend kernel (shader) |
145 | FunctionType compiled_{nullptr}; |
146 | // A flag to record whether |ir| has been fully lowered. |
147 | // lower initial AST all the way down to a bunch of |
148 | // OffloadedStmt for async execution TODO(Lin): Check this comment |
149 | bool lowered_{false}; |
150 | std::atomic<uint64> task_counter_{0}; |
151 | std::string kernel_key_; |
152 | }; |
153 | |
154 | } // namespace taichi::lang |
155 | |