1#pragma once
2
3#include "taichi/util/lang_util.h"
4#include "taichi/ir/snode.h"
5#include "taichi/ir/ir.h"
6#include "taichi/rhi/arch.h"
7#include "taichi/program/callable.h"
8#include "taichi/program/ndarray.h"
9#include "taichi/program/texture.h"
10#include "taichi/aot/graph_data.h"
11
12namespace taichi::lang {
13
14class Program;
15
16class TI_DLL_EXPORT Kernel : public Callable {
17 public:
18 std::string name;
19 std::vector<SNode *> no_activate;
20
21 bool is_accessor{false};
22 bool is_evaluator{false};
23 AutodiffMode autodiff_mode{AutodiffMode::kNone};
24
25 class LaunchContextBuilder {
26 public:
27 LaunchContextBuilder(Kernel *kernel, RuntimeContext *ctx);
28 explicit LaunchContextBuilder(Kernel *kernel);
29
30 LaunchContextBuilder(LaunchContextBuilder &&) = default;
31 LaunchContextBuilder &operator=(LaunchContextBuilder &&) = default;
32 LaunchContextBuilder(const LaunchContextBuilder &) = delete;
33 LaunchContextBuilder &operator=(const LaunchContextBuilder &) = delete;
34
35 void set_arg_float(int arg_id, float64 d);
36
37 // Created signed and unsigned version for argument range check of pybind
38 void set_arg_int(int arg_id, int64 d);
39 void set_arg_uint(int arg_id, uint64 d);
40
41 void set_extra_arg_int(int i, int j, int32 d);
42
43 void set_arg_external_array_with_shape(int arg_id,
44 uintptr_t ptr,
45 uint64 size,
46 const std::vector<int64> &shape);
47
48 void set_arg_ndarray(int arg_id, const Ndarray &arr);
49 void set_arg_ndarray_with_grad(int arg_id,
50 const Ndarray &arr,
51 const Ndarray &arr_grad);
52
53 void set_arg_texture(int arg_id, const Texture &tex);
54 void set_arg_rw_texture(int arg_id, const Texture &tex);
55
56 // Sets the |arg_id|-th arg in the context to the bits stored in |d|.
57 // This ignores the underlying kernel's |arg_id|-th arg type.
58 void set_arg_raw(int arg_id, uint64 d);
59
60 RuntimeContext &get_context();
61
62 private:
63 Kernel *kernel_;
64 std::unique_ptr<RuntimeContext> owned_ctx_;
65 // |ctx_| *almost* always points to |owned_ctx_|. However, it is possible
66 // that the caller passes a RuntimeContext pointer externally. In that case,
67 // |owned_ctx_| will be nullptr.
68 // Invariant: |ctx_| will never be nullptr.
69 RuntimeContext *ctx_;
70 };
71
72 Kernel(Program &program,
73 const std::function<void()> &func,
74 const std::string &name = "",
75 AutodiffMode autodiff_mode = AutodiffMode::kNone);
76
77 Kernel(Program &program,
78 const std::function<void(Kernel *)> &func,
79 const std::string &name = "",
80 AutodiffMode autodiff_mode = AutodiffMode::kNone);
81
82 Kernel(Program &program,
83 std::unique_ptr<IRNode> &&ir,
84 const std::string &name = "",
85 AutodiffMode autodiff_mode = AutodiffMode::kNone);
86
87 bool ir_is_ast() const {
88 return ir_is_ast_;
89 }
90
91 bool lowered() const {
92 return lowered_;
93 }
94
95 void set_lowered(bool lowered) {
96 lowered_ = lowered;
97 }
98
99 void compile(const CompileConfig &compile_config);
100
101 void operator()(const CompileConfig &compile_config,
102 LaunchContextBuilder &ctx_builder);
103
104 LaunchContextBuilder make_launch_context();
105
106 template <typename T>
107 T fetch_ret(DataType dt, int i);
108
109 float64 get_ret_float(int i);
110 int64 get_ret_int(int i);
111 uint64 get_ret_uint(int i);
112 std::vector<int64> get_ret_int_tensor(int i);
113 std::vector<uint64> get_ret_uint_tensor(int i);
114 std::vector<float64> get_ret_float_tensor(int i);
115
116 TypedConstant fetch_ret(const std::vector<int> &index);
117
118 float64 get_struct_ret_float(const std::vector<int> &index);
119 int64 get_struct_ret_int(const std::vector<int> &index);
120 uint64 get_struct_ret_uint(const std::vector<int> &index);
121
122 uint64 get_next_task_id() {
123 return task_counter_++;
124 }
125
126 [[nodiscard]] std::string get_name() const override;
127
128 void set_kernel_key_for_cache(const std::string &kernel_key) {
129 kernel_key_ = kernel_key;
130 }
131
132 const std::string &get_cached_kernel_key() {
133 return kernel_key_;
134 }
135
136 private:
137 void init(Program &program,
138 const std::function<void()> &func,
139 const std::string &name = "",
140 AutodiffMode autodiff_mode = AutodiffMode::kNone);
141
142 // True if |ir| is a frontend AST. False if it's already offloaded to CHI IR.
143 bool ir_is_ast_{false};
144 // The closure that, if invoked, launches the backend kernel (shader)
145 FunctionType compiled_{nullptr};
146 // A flag to record whether |ir| has been fully lowered.
147 // lower initial AST all the way down to a bunch of
148 // OffloadedStmt for async execution TODO(Lin): Check this comment
149 bool lowered_{false};
150 std::atomic<uint64> task_counter_{0};
151 std::string kernel_key_;
152};
153
154} // namespace taichi::lang
155