1#include "taichi/program/texture.h"
2#include "taichi/program/ndarray.h"
3#include "taichi/program/program.h"
4#include "taichi/rhi/device.h"
5#include "taichi/ir/snode.h"
6
7namespace taichi::lang {
8
9// FIXME: (penguinliong) We might have to differentiate buffer formats and
10// texture formats at some point because formats like `rgb10a2` are not easily
11// represented by primitive types.
12std::pair<DataType, uint32_t> buffer_format2type_channels(BufferFormat format) {
13 switch (format) {
14 case BufferFormat::r8:
15 return std::make_pair(PrimitiveType::u8, 1);
16 case BufferFormat::rg8:
17 return std::make_pair(PrimitiveType::u8, 2);
18 case BufferFormat::rgba8:
19 return std::make_pair(PrimitiveType::u8, 4);
20 case BufferFormat::rgba8srgb:
21 return std::make_pair(PrimitiveType::u8, 4);
22 case BufferFormat::bgra8:
23 return std::make_pair(PrimitiveType::u8, 4);
24 case BufferFormat::bgra8srgb:
25 return std::make_pair(PrimitiveType::u8, 4);
26 case BufferFormat::r8u:
27 return std::make_pair(PrimitiveType::u8, 1);
28 case BufferFormat::rg8u:
29 return std::make_pair(PrimitiveType::u8, 2);
30 case BufferFormat::rgba8u:
31 return std::make_pair(PrimitiveType::u8, 4);
32 case BufferFormat::r8i:
33 return std::make_pair(PrimitiveType::i8, 1);
34 case BufferFormat::rg8i:
35 return std::make_pair(PrimitiveType::i8, 2);
36 case BufferFormat::rgba8i:
37 return std::make_pair(PrimitiveType::i8, 4);
38 case BufferFormat::r16:
39 return std::make_pair(PrimitiveType::u16, 1);
40 case BufferFormat::rg16:
41 return std::make_pair(PrimitiveType::u16, 2);
42 case BufferFormat::rgb16:
43 return std::make_pair(PrimitiveType::u16, 3);
44 case BufferFormat::rgba16:
45 return std::make_pair(PrimitiveType::u16, 4);
46 case BufferFormat::r16u:
47 return std::make_pair(PrimitiveType::u16, 1);
48 case BufferFormat::rg16u:
49 return std::make_pair(PrimitiveType::u16, 2);
50 case BufferFormat::rgb16u:
51 return std::make_pair(PrimitiveType::u16, 3);
52 case BufferFormat::rgba16u:
53 return std::make_pair(PrimitiveType::u16, 4);
54 case BufferFormat::r16i:
55 return std::make_pair(PrimitiveType::i16, 1);
56 case BufferFormat::rg16i:
57 return std::make_pair(PrimitiveType::i16, 2);
58 case BufferFormat::rgb16i:
59 return std::make_pair(PrimitiveType::i16, 3);
60 case BufferFormat::rgba16i:
61 return std::make_pair(PrimitiveType::i16, 4);
62 case BufferFormat::r16f:
63 return std::make_pair(PrimitiveType::f16, 1);
64 case BufferFormat::rg16f:
65 return std::make_pair(PrimitiveType::f16, 2);
66 case BufferFormat::rgb16f:
67 return std::make_pair(PrimitiveType::f16, 3);
68 case BufferFormat::rgba16f:
69 return std::make_pair(PrimitiveType::f16, 4);
70 case BufferFormat::r32u:
71 return std::make_pair(PrimitiveType::u32, 1);
72 case BufferFormat::rg32u:
73 return std::make_pair(PrimitiveType::u32, 2);
74 case BufferFormat::rgb32u:
75 return std::make_pair(PrimitiveType::u32, 3);
76 case BufferFormat::rgba32u:
77 return std::make_pair(PrimitiveType::u32, 4);
78 case BufferFormat::r32i:
79 return std::make_pair(PrimitiveType::i32, 1);
80 case BufferFormat::rg32i:
81 return std::make_pair(PrimitiveType::i32, 2);
82 case BufferFormat::rgb32i:
83 return std::make_pair(PrimitiveType::i32, 3);
84 case BufferFormat::rgba32i:
85 return std::make_pair(PrimitiveType::i32, 4);
86 case BufferFormat::r32f:
87 return std::make_pair(PrimitiveType::f32, 1);
88 case BufferFormat::rg32f:
89 return std::make_pair(PrimitiveType::f32, 2);
90 case BufferFormat::rgb32f:
91 return std::make_pair(PrimitiveType::f32, 3);
92 case BufferFormat::rgba32f:
93 return std::make_pair(PrimitiveType::f32, 4);
94 default:
95 TI_ERROR("Invalid buffer format");
96 return {};
97 }
98}
99
100BufferFormat type_channels2buffer_format(const DataType &type,
101 uint32_t num_channels) {
102 BufferFormat format;
103 if (type == PrimitiveType::f16) {
104 if (num_channels == 1) {
105 format = BufferFormat::r16f;
106 } else if (num_channels == 2) {
107 format = BufferFormat::rg16f;
108 } else if (num_channels == 4) {
109 format = BufferFormat::rgba16f;
110 } else {
111 TI_ERROR("Invalid texture channels");
112 }
113 } else if (type == PrimitiveType::u16) {
114 if (num_channels == 1) {
115 format = BufferFormat::r16;
116 } else if (num_channels == 2) {
117 format = BufferFormat::rg16;
118 } else if (num_channels == 4) {
119 format = BufferFormat::rgba16;
120 } else {
121 TI_ERROR("Invalid texture channels");
122 }
123 } else if (type == PrimitiveType::u8) {
124 if (num_channels == 1) {
125 format = BufferFormat::r8;
126 } else if (num_channels == 2) {
127 format = BufferFormat::rg8;
128 } else if (num_channels == 4) {
129 format = BufferFormat::rgba8;
130 } else {
131 TI_ERROR("Invalid texture channels");
132 }
133 } else if (type == PrimitiveType::f32) {
134 if (num_channels == 1) {
135 format = BufferFormat::r32f;
136 } else if (num_channels == 2) {
137 format = BufferFormat::rg32f;
138 } else if (num_channels == 3) {
139 format = BufferFormat::rgb32f;
140 } else if (num_channels == 4) {
141 format = BufferFormat::rgba32f;
142 } else {
143 TI_ERROR("Invalid texture channels");
144 }
145 } else {
146 TI_ERROR("Invalid texture dtype");
147 }
148 return format;
149}
150
151Texture::Texture(Program *prog,
152 BufferFormat format,
153 int width,
154 int height,
155 int depth)
156 : format_(format),
157 width_(width),
158 height_(height),
159 depth_(depth),
160 prog_(prog) {
161 GraphicsDevice *device =
162 static_cast<GraphicsDevice *>(prog_->get_graphics_device());
163
164 auto [type, num_channels] = buffer_format2type_channels(format);
165 TI_TRACE("Create image, gfx device {}, format={}, w={}, h={}, d={}",
166 (void *)device, type.to_string(), num_channels, width, height,
167 depth);
168
169 TI_ASSERT(num_channels > 0 && num_channels <= 4);
170
171 ImageParams img_params{};
172 img_params.dimension = depth > 1 ? ImageDimension::d3D : ImageDimension::d2D;
173 img_params.format = format;
174 img_params.x = width;
175 img_params.y = height;
176 img_params.z = depth;
177 img_params.initial_layout = ImageLayout::undefined;
178 texture_alloc_ = prog_->allocate_texture(img_params);
179
180 format_ = img_params.format;
181
182 TI_TRACE("image created, gfx device {}", (void *)device);
183}
184
185Texture::Texture(DeviceAllocation &devalloc,
186 BufferFormat format,
187 int width,
188 int height,
189 int depth)
190 : texture_alloc_(devalloc),
191 format_(format),
192 width_(width),
193 height_(height),
194 depth_(depth) {
195 format_ = format;
196}
197
198intptr_t Texture::get_device_allocation_ptr_as_int() const {
199 return reinterpret_cast<intptr_t>(&texture_alloc_);
200}
201
202void Texture::from_ndarray(Ndarray *ndarray) {
203 auto semaphore = prog_->flush();
204
205 GraphicsDevice *device =
206 static_cast<GraphicsDevice *>(prog_->get_graphics_device());
207
208 device->image_transition(texture_alloc_, ImageLayout::undefined,
209 ImageLayout::transfer_dst);
210
211 Stream *stream = device->get_compute_stream();
212 auto [cmdlist, res] = stream->new_command_list_unique();
213 TI_ASSERT(res == RhiResult::success);
214
215 BufferImageCopyParams params;
216 params.buffer_row_length = ndarray->shape[0];
217 params.buffer_image_height = ndarray->shape[1];
218 params.image_mip_level = 0;
219 params.image_extent.x = width_;
220 params.image_extent.y = height_;
221 params.image_extent.z = depth_;
222
223 cmdlist->buffer_barrier(ndarray->ndarray_alloc_);
224 cmdlist->buffer_to_image(texture_alloc_, ndarray->ndarray_alloc_.get_ptr(0),
225 ImageLayout::transfer_dst, params);
226
227 stream->submit_synced(cmdlist.get(), {semaphore});
228}
229
230DevicePtr get_device_ptr(taichi::lang::Program *program, SNode *snode) {
231 SNode *dense_parent = snode->parent;
232 SNode *root = dense_parent->parent;
233
234 int tree_id = root->get_snode_tree_id();
235 DevicePtr root_ptr = program->get_snode_tree_device_ptr(tree_id);
236
237 return root_ptr.get_ptr(program->get_field_in_tree_offset(tree_id, snode));
238}
239
240void Texture::from_snode(SNode *snode) {
241 auto semaphore = prog_->flush();
242
243 TI_ASSERT(snode->is_path_all_dense);
244
245 GraphicsDevice *device =
246 static_cast<GraphicsDevice *>(prog_->get_graphics_device());
247
248 device->image_transition(texture_alloc_, ImageLayout::undefined,
249 ImageLayout::transfer_dst);
250
251 DevicePtr devptr = get_device_ptr(prog_, snode);
252
253 Stream *stream = device->get_compute_stream();
254 auto [cmdlist, res] = stream->new_command_list_unique();
255 TI_ASSERT(res == RhiResult::success);
256
257 BufferImageCopyParams params;
258 params.buffer_row_length = snode->shape_along_axis(0);
259 params.buffer_image_height = snode->shape_along_axis(1);
260 params.image_mip_level = 0;
261 params.image_extent.x = width_;
262 params.image_extent.y = height_;
263 params.image_extent.z = depth_;
264
265 cmdlist->buffer_barrier(devptr);
266 cmdlist->buffer_to_image(texture_alloc_, devptr, ImageLayout::transfer_dst,
267 params);
268
269 stream->submit_synced(cmdlist.get(), {semaphore});
270}
271
272Texture::~Texture() {
273 if (prog_) {
274 GraphicsDevice *device =
275 static_cast<GraphicsDevice *>(prog_->get_graphics_device());
276 device->destroy_image(texture_alloc_);
277 }
278}
279
280} // namespace taichi::lang
281