1 | #include "taichi/program/texture.h" |
2 | #include "taichi/program/ndarray.h" |
3 | #include "taichi/program/program.h" |
4 | #include "taichi/rhi/device.h" |
5 | #include "taichi/ir/snode.h" |
6 | |
7 | namespace taichi::lang { |
8 | |
9 | // FIXME: (penguinliong) We might have to differentiate buffer formats and |
10 | // texture formats at some point because formats like `rgb10a2` are not easily |
11 | // represented by primitive types. |
12 | std::pair<DataType, uint32_t> buffer_format2type_channels(BufferFormat format) { |
13 | switch (format) { |
14 | case BufferFormat::r8: |
15 | return std::make_pair(PrimitiveType::u8, 1); |
16 | case BufferFormat::rg8: |
17 | return std::make_pair(PrimitiveType::u8, 2); |
18 | case BufferFormat::rgba8: |
19 | return std::make_pair(PrimitiveType::u8, 4); |
20 | case BufferFormat::rgba8srgb: |
21 | return std::make_pair(PrimitiveType::u8, 4); |
22 | case BufferFormat::bgra8: |
23 | return std::make_pair(PrimitiveType::u8, 4); |
24 | case BufferFormat::bgra8srgb: |
25 | return std::make_pair(PrimitiveType::u8, 4); |
26 | case BufferFormat::r8u: |
27 | return std::make_pair(PrimitiveType::u8, 1); |
28 | case BufferFormat::rg8u: |
29 | return std::make_pair(PrimitiveType::u8, 2); |
30 | case BufferFormat::rgba8u: |
31 | return std::make_pair(PrimitiveType::u8, 4); |
32 | case BufferFormat::r8i: |
33 | return std::make_pair(PrimitiveType::i8, 1); |
34 | case BufferFormat::rg8i: |
35 | return std::make_pair(PrimitiveType::i8, 2); |
36 | case BufferFormat::rgba8i: |
37 | return std::make_pair(PrimitiveType::i8, 4); |
38 | case BufferFormat::r16: |
39 | return std::make_pair(PrimitiveType::u16, 1); |
40 | case BufferFormat::rg16: |
41 | return std::make_pair(PrimitiveType::u16, 2); |
42 | case BufferFormat::rgb16: |
43 | return std::make_pair(PrimitiveType::u16, 3); |
44 | case BufferFormat::rgba16: |
45 | return std::make_pair(PrimitiveType::u16, 4); |
46 | case BufferFormat::r16u: |
47 | return std::make_pair(PrimitiveType::u16, 1); |
48 | case BufferFormat::rg16u: |
49 | return std::make_pair(PrimitiveType::u16, 2); |
50 | case BufferFormat::rgb16u: |
51 | return std::make_pair(PrimitiveType::u16, 3); |
52 | case BufferFormat::rgba16u: |
53 | return std::make_pair(PrimitiveType::u16, 4); |
54 | case BufferFormat::r16i: |
55 | return std::make_pair(PrimitiveType::i16, 1); |
56 | case BufferFormat::rg16i: |
57 | return std::make_pair(PrimitiveType::i16, 2); |
58 | case BufferFormat::rgb16i: |
59 | return std::make_pair(PrimitiveType::i16, 3); |
60 | case BufferFormat::rgba16i: |
61 | return std::make_pair(PrimitiveType::i16, 4); |
62 | case BufferFormat::r16f: |
63 | return std::make_pair(PrimitiveType::f16, 1); |
64 | case BufferFormat::rg16f: |
65 | return std::make_pair(PrimitiveType::f16, 2); |
66 | case BufferFormat::rgb16f: |
67 | return std::make_pair(PrimitiveType::f16, 3); |
68 | case BufferFormat::rgba16f: |
69 | return std::make_pair(PrimitiveType::f16, 4); |
70 | case BufferFormat::r32u: |
71 | return std::make_pair(PrimitiveType::u32, 1); |
72 | case BufferFormat::rg32u: |
73 | return std::make_pair(PrimitiveType::u32, 2); |
74 | case BufferFormat::rgb32u: |
75 | return std::make_pair(PrimitiveType::u32, 3); |
76 | case BufferFormat::rgba32u: |
77 | return std::make_pair(PrimitiveType::u32, 4); |
78 | case BufferFormat::r32i: |
79 | return std::make_pair(PrimitiveType::i32, 1); |
80 | case BufferFormat::rg32i: |
81 | return std::make_pair(PrimitiveType::i32, 2); |
82 | case BufferFormat::rgb32i: |
83 | return std::make_pair(PrimitiveType::i32, 3); |
84 | case BufferFormat::rgba32i: |
85 | return std::make_pair(PrimitiveType::i32, 4); |
86 | case BufferFormat::r32f: |
87 | return std::make_pair(PrimitiveType::f32, 1); |
88 | case BufferFormat::rg32f: |
89 | return std::make_pair(PrimitiveType::f32, 2); |
90 | case BufferFormat::rgb32f: |
91 | return std::make_pair(PrimitiveType::f32, 3); |
92 | case BufferFormat::rgba32f: |
93 | return std::make_pair(PrimitiveType::f32, 4); |
94 | default: |
95 | TI_ERROR("Invalid buffer format" ); |
96 | return {}; |
97 | } |
98 | } |
99 | |
100 | BufferFormat type_channels2buffer_format(const DataType &type, |
101 | uint32_t num_channels) { |
102 | BufferFormat format; |
103 | if (type == PrimitiveType::f16) { |
104 | if (num_channels == 1) { |
105 | format = BufferFormat::r16f; |
106 | } else if (num_channels == 2) { |
107 | format = BufferFormat::rg16f; |
108 | } else if (num_channels == 4) { |
109 | format = BufferFormat::rgba16f; |
110 | } else { |
111 | TI_ERROR("Invalid texture channels" ); |
112 | } |
113 | } else if (type == PrimitiveType::u16) { |
114 | if (num_channels == 1) { |
115 | format = BufferFormat::r16; |
116 | } else if (num_channels == 2) { |
117 | format = BufferFormat::rg16; |
118 | } else if (num_channels == 4) { |
119 | format = BufferFormat::rgba16; |
120 | } else { |
121 | TI_ERROR("Invalid texture channels" ); |
122 | } |
123 | } else if (type == PrimitiveType::u8) { |
124 | if (num_channels == 1) { |
125 | format = BufferFormat::r8; |
126 | } else if (num_channels == 2) { |
127 | format = BufferFormat::rg8; |
128 | } else if (num_channels == 4) { |
129 | format = BufferFormat::rgba8; |
130 | } else { |
131 | TI_ERROR("Invalid texture channels" ); |
132 | } |
133 | } else if (type == PrimitiveType::f32) { |
134 | if (num_channels == 1) { |
135 | format = BufferFormat::r32f; |
136 | } else if (num_channels == 2) { |
137 | format = BufferFormat::rg32f; |
138 | } else if (num_channels == 3) { |
139 | format = BufferFormat::rgb32f; |
140 | } else if (num_channels == 4) { |
141 | format = BufferFormat::rgba32f; |
142 | } else { |
143 | TI_ERROR("Invalid texture channels" ); |
144 | } |
145 | } else { |
146 | TI_ERROR("Invalid texture dtype" ); |
147 | } |
148 | return format; |
149 | } |
150 | |
151 | Texture::Texture(Program *prog, |
152 | BufferFormat format, |
153 | int width, |
154 | int height, |
155 | int depth) |
156 | : format_(format), |
157 | width_(width), |
158 | height_(height), |
159 | depth_(depth), |
160 | prog_(prog) { |
161 | GraphicsDevice *device = |
162 | static_cast<GraphicsDevice *>(prog_->get_graphics_device()); |
163 | |
164 | auto [type, num_channels] = buffer_format2type_channels(format); |
165 | TI_TRACE("Create image, gfx device {}, format={}, w={}, h={}, d={}" , |
166 | (void *)device, type.to_string(), num_channels, width, height, |
167 | depth); |
168 | |
169 | TI_ASSERT(num_channels > 0 && num_channels <= 4); |
170 | |
171 | ImageParams img_params{}; |
172 | img_params.dimension = depth > 1 ? ImageDimension::d3D : ImageDimension::d2D; |
173 | img_params.format = format; |
174 | img_params.x = width; |
175 | img_params.y = height; |
176 | img_params.z = depth; |
177 | img_params.initial_layout = ImageLayout::undefined; |
178 | texture_alloc_ = prog_->allocate_texture(img_params); |
179 | |
180 | format_ = img_params.format; |
181 | |
182 | TI_TRACE("image created, gfx device {}" , (void *)device); |
183 | } |
184 | |
185 | Texture::Texture(DeviceAllocation &devalloc, |
186 | BufferFormat format, |
187 | int width, |
188 | int height, |
189 | int depth) |
190 | : texture_alloc_(devalloc), |
191 | format_(format), |
192 | width_(width), |
193 | height_(height), |
194 | depth_(depth) { |
195 | format_ = format; |
196 | } |
197 | |
198 | intptr_t Texture::get_device_allocation_ptr_as_int() const { |
199 | return reinterpret_cast<intptr_t>(&texture_alloc_); |
200 | } |
201 | |
202 | void Texture::from_ndarray(Ndarray *ndarray) { |
203 | auto semaphore = prog_->flush(); |
204 | |
205 | GraphicsDevice *device = |
206 | static_cast<GraphicsDevice *>(prog_->get_graphics_device()); |
207 | |
208 | device->image_transition(texture_alloc_, ImageLayout::undefined, |
209 | ImageLayout::transfer_dst); |
210 | |
211 | Stream *stream = device->get_compute_stream(); |
212 | auto [cmdlist, res] = stream->new_command_list_unique(); |
213 | TI_ASSERT(res == RhiResult::success); |
214 | |
215 | BufferImageCopyParams params; |
216 | params.buffer_row_length = ndarray->shape[0]; |
217 | params.buffer_image_height = ndarray->shape[1]; |
218 | params.image_mip_level = 0; |
219 | params.image_extent.x = width_; |
220 | params.image_extent.y = height_; |
221 | params.image_extent.z = depth_; |
222 | |
223 | cmdlist->buffer_barrier(ndarray->ndarray_alloc_); |
224 | cmdlist->buffer_to_image(texture_alloc_, ndarray->ndarray_alloc_.get_ptr(0), |
225 | ImageLayout::transfer_dst, params); |
226 | |
227 | stream->submit_synced(cmdlist.get(), {semaphore}); |
228 | } |
229 | |
230 | DevicePtr get_device_ptr(taichi::lang::Program *program, SNode *snode) { |
231 | SNode *dense_parent = snode->parent; |
232 | SNode *root = dense_parent->parent; |
233 | |
234 | int tree_id = root->get_snode_tree_id(); |
235 | DevicePtr root_ptr = program->get_snode_tree_device_ptr(tree_id); |
236 | |
237 | return root_ptr.get_ptr(program->get_field_in_tree_offset(tree_id, snode)); |
238 | } |
239 | |
240 | void Texture::from_snode(SNode *snode) { |
241 | auto semaphore = prog_->flush(); |
242 | |
243 | TI_ASSERT(snode->is_path_all_dense); |
244 | |
245 | GraphicsDevice *device = |
246 | static_cast<GraphicsDevice *>(prog_->get_graphics_device()); |
247 | |
248 | device->image_transition(texture_alloc_, ImageLayout::undefined, |
249 | ImageLayout::transfer_dst); |
250 | |
251 | DevicePtr devptr = get_device_ptr(prog_, snode); |
252 | |
253 | Stream *stream = device->get_compute_stream(); |
254 | auto [cmdlist, res] = stream->new_command_list_unique(); |
255 | TI_ASSERT(res == RhiResult::success); |
256 | |
257 | BufferImageCopyParams params; |
258 | params.buffer_row_length = snode->shape_along_axis(0); |
259 | params.buffer_image_height = snode->shape_along_axis(1); |
260 | params.image_mip_level = 0; |
261 | params.image_extent.x = width_; |
262 | params.image_extent.y = height_; |
263 | params.image_extent.z = depth_; |
264 | |
265 | cmdlist->buffer_barrier(devptr); |
266 | cmdlist->buffer_to_image(texture_alloc_, devptr, ImageLayout::transfer_dst, |
267 | params); |
268 | |
269 | stream->submit_synced(cmdlist.get(), {semaphore}); |
270 | } |
271 | |
272 | Texture::~Texture() { |
273 | if (prog_) { |
274 | GraphicsDevice *device = |
275 | static_cast<GraphicsDevice *>(prog_->get_graphics_device()); |
276 | device->destroy_image(texture_alloc_); |
277 | } |
278 | } |
279 | |
280 | } // namespace taichi::lang |
281 | |