1#pragma once
2
#include <cstddef>
#include <cstdint>
#include <map>
#include <string>
#include <unordered_map>
#include <vector>

#include "taichi/rhi/device.h"
#include "taichi/common/core.h"
#include "taichi/common/serialization.h"
9
10namespace taichi::lang {
11namespace aot {
12
13struct CompiledFieldData {
14 std::string field_name;
15 uint32_t dtype{0};
16 std::string dtype_name;
17 size_t mem_offset_in_parent{0};
18 std::vector<int> shape;
19 bool is_scalar{false};
20 std::vector<int> element_shape;
21
22 TI_IO_DEF(field_name,
23 dtype,
24 dtype_name,
25 mem_offset_in_parent,
26 shape,
27 is_scalar,
28 element_shape);
29};
30
// Kind of buffer referenced by a BufferInfo inside a task's buffer bindings.
// Enumerator order is kept as-is since the underlying values may be
// serialized.
enum class BufferType {
  Root,        // root buffer; BufferInfo::id is only meaningful for this kind
  GlobalTmps,  // presumably the global-temporaries buffer — confirm
  Args,        // presumably the kernel argument buffer — confirm
  Rets,        // presumably the kernel return-value buffer — confirm
};
32
33struct BufferInfo {
34 BufferType type;
35 int id{-1}; // only used if type==Root
36
37 TI_IO_DEF(type, id);
38};
39
40struct BufferBind {
41 BufferInfo buffer;
42 int binding{0};
43
44 TI_IO_DEF(buffer, binding);
45};
46
47struct TextureBind {
48 int arg_id;
49 int binding;
50 bool is_storage;
51
52 TI_IO_DEF(arg_id, binding, is_storage);
53};
54
55struct CompiledOffloadedTask {
56 std::string type;
57 std::string range_hint;
58 std::string name;
59 // Do we need to inline the source code?
60 std::string source_path;
61 int gpu_block_size{0};
62
63 std::vector<BufferBind> buffer_binds;
64 std::vector<TextureBind> texture_binds;
65
66 TI_IO_DEF(type,
67 range_hint,
68 name,
69 source_path,
70 gpu_block_size,
71 buffer_binds,
72 texture_binds);
73};
74
75struct ScalarArg {
76 std::string dtype_name;
77 // Unit: byte
78 size_t offset_in_args_buf{0};
79
80 TI_IO_DEF(dtype_name, offset_in_args_buf);
81};
82
83struct ArrayArg {
84 std::string dtype_name;
85 std::size_t field_dim{0};
86 // If |element_shape| is empty, it means this is a scalar
87 std::vector<int> element_shape;
88 // Unit: byte
89 std::size_t shape_offset_in_args_buf{0};
90 // For Vulkan/OpenGL/Metal, this is the binding index
91 int bind_index{0};
92
93 TI_IO_DEF(dtype_name,
94 field_dim,
95 element_shape,
96 shape_offset_in_args_buf,
97 bind_index);
98};
99
100struct CompiledTaichiKernel {
101 std::vector<CompiledOffloadedTask> tasks;
102 int args_count{0};
103 int rets_count{0};
104 size_t args_buffer_size{0};
105 size_t rets_buffer_size{0};
106
107 std::unordered_map<int, ScalarArg> scalar_args;
108 std::unordered_map<int, ArrayArg> arr_args;
109
110 TI_IO_DEF(tasks,
111 args_count,
112 rets_count,
113 args_buffer_size,
114 rets_buffer_size,
115 scalar_args,
116 arr_args);
117};
118
119struct ModuleData {
120 std::unordered_map<std::string, CompiledTaichiKernel> kernels;
121 std::unordered_map<std::string, CompiledTaichiKernel> kernel_tmpls;
122 std::vector<aot::CompiledFieldData> fields;
123 std::map<std::string, uint32_t> required_caps;
124
125 size_t root_buffer_size;
126
127 void dump_json(std::string path) {
128 TextSerializer ts;
129 ts.serialize_to_json("aot_data", *this);
130 ts.write_to_file(path);
131 }
132
133 TI_IO_DEF(kernels, kernel_tmpls, fields, required_caps, root_buffer_size);
134};
135
136} // namespace aot
137} // namespace taichi::lang
138