1
2#include <vector>
3#include "pybind11/pybind11.h"
4#include <pybind11/numpy.h>
5#include "pybind11/stl.h"
6
7#include "taichi/common/interface.h"
8#include "taichi/common/core.h"
9
10namespace py = pybind11;
11
12#ifdef TI_WITH_GGUI
13
14#include "taichi/ui/utils/utils.h"
15#include "taichi/ui/common/window_base.h"
16#include "taichi/ui/backends/vulkan/window.h"
17#include "taichi/ui/common/canvas_base.h"
18#include "taichi/ui/common/camera.h"
19#include "taichi/ui/backends/vulkan/canvas.h"
20#include "taichi/ui/backends/vulkan/scene.h"
21#include "taichi/rhi/vulkan/vulkan_loader.h"
22#include "taichi/rhi/arch.h"
23#include "taichi/program/field_info.h"
24#include "taichi/ui/common/gui_base.h"
25#include "taichi/program/ndarray.h"
26#include <memory>
27
28namespace taichi::ui {
29
30using namespace taichi::lang;
31
32glm::vec3 tuple_to_vec3(pybind11::tuple t) {
33 return glm::vec3(t[0].cast<float>(), t[1].cast<float>(), t[2].cast<float>());
34}
35
36pybind11::tuple vec3_to_tuple(glm::vec3 v) {
37 return pybind11::make_tuple(v.x, v.y, v.z);
38}
39
40// Here we convert the 2d-array to numpy array using pybind. Refs:
41// https://pybind11.readthedocs.io/en/stable/advanced/pycpp/numpy.html?highlight=array_t#vectorizing-functions
42// https://stackoverflow.com/questions/44659924/returning-numpy-arrays-via-pybind11
43py::array_t<float> mat4_to_nparray(glm::mat4 mat) {
44 // Here we must explicitly pass args using py::detail::any_container<ssize_t>.
45 // Refs:
46 // https://stackoverflow.com/questions/54055530/error-no-matching-function-for-call-to-pybind11buffer-infobuffer-info
47 return py::array_t<float>(
48 py::detail::any_container<ssize_t>({4, 4}), // shape (rows, cols)
49 py::detail::any_container<ssize_t>(
50 {sizeof(float) * 4, sizeof(float)}), // strides in bytes
51 glm::value_ptr(mat), // buffer pointer
52 nullptr);
53}
54
55struct PyGui {
56 GuiBase *gui; // not owned
57 void begin(std::string name, float x, float y, float width, float height) {
58 gui->begin(name, x, y, width, height);
59 }
60 void end() {
61 gui->end();
62 }
63 void text(std::string text) {
64 gui->text(text);
65 }
66 void text_colored(std::string text, py::tuple color) {
67 gui->text(text, tuple_to_vec3(color));
68 }
69 bool checkbox(std::string name, bool old_value) {
70 return gui->checkbox(name, old_value);
71 }
72 int slider_int(std::string name, int old_value, int minimum, int maximum) {
73 return gui->slider_int(name, old_value, minimum, maximum);
74 }
75 float slider_float(std::string name,
76 float old_value,
77 float minimum,
78 float maximum) {
79 return gui->slider_float(name, old_value, minimum, maximum);
80 }
81 py::tuple color_edit_3(std::string name, py::tuple old_value) {
82 glm::vec3 old_color = tuple_to_vec3(old_value);
83 glm::vec3 new_color = gui->color_edit_3(name, old_color);
84 return vec3_to_tuple(new_color);
85 }
86 bool button(std::string name) {
87 return gui->button(name);
88 }
89};
90
91struct PyCamera {
92 Camera camera;
93 void position(float x, float y, float z) {
94 camera.position = glm::vec3(x, y, z);
95 }
96 void lookat(float x, float y, float z) {
97 camera.lookat = glm::vec3(x, y, z);
98 }
99 void up(float x, float y, float z) {
100 camera.up = glm::vec3(x, y, z);
101 }
102 void projection_mode(ProjectionMode mode) {
103 camera.projection_mode = mode;
104 }
105 void fov(float fov_) {
106 camera.fov = fov_;
107 }
108 void left(float left_) {
109 camera.left = left_;
110 }
111 void right(float right_) {
112 camera.right = right_;
113 }
114 void top(float top_) {
115 camera.top = top_;
116 }
117 void bottom(float bottom_) {
118 camera.bottom = bottom_;
119 }
120 void z_near(float z_near_) {
121 camera.z_near = z_near_;
122 }
123 void z_far(float z_far_) {
124 camera.z_far = z_far_;
125 }
126 py::array_t<float> get_view_matrix() {
127 return mat4_to_nparray(camera.get_view_matrix());
128 }
129 py::array_t<float> get_projection_matrix(float aspect_ratio) {
130 return mat4_to_nparray(camera.get_projection_matrix(aspect_ratio));
131 }
132};
133
134struct PyScene {
135 SceneBase *scene; // owned
136
137 PyScene() {
138 // todo: support other ggui backends
139 scene = new vulkan::Scene();
140 }
141
142 void set_camera(PyCamera camera) {
143 scene->set_camera(camera.camera);
144 }
145
146 void lines(FieldInfo vbo,
147 FieldInfo indices,
148 bool has_per_vertex_color,
149 py::tuple color_,
150 float width,
151 float draw_index_count,
152 float draw_first_index,
153 float draw_vertex_count,
154 float draw_first_vertex) {
155 RenderableInfo renderable_info;
156 renderable_info.vbo = vbo;
157 renderable_info.indices = indices;
158 renderable_info.has_per_vertex_color = has_per_vertex_color;
159 renderable_info.has_user_customized_draw = true;
160 renderable_info.draw_index_count = (int)draw_index_count;
161 renderable_info.draw_first_index = (int)draw_first_index;
162 renderable_info.draw_vertex_count = (int)draw_vertex_count;
163 renderable_info.draw_first_vertex = (int)draw_first_vertex;
164
165 SceneLinesInfo info;
166 info.renderable_info = renderable_info;
167 info.color = tuple_to_vec3(color_);
168 info.width = width;
169
170 return scene->lines(info);
171 }
172
173 void mesh(FieldInfo vbo,
174 bool has_per_vertex_color,
175 FieldInfo indices,
176 py::tuple color,
177 bool two_sided,
178 float draw_index_count,
179 float draw_first_index,
180 float draw_vertex_count,
181 float draw_first_vertex,
182 bool show_wireframe) {
183 RenderableInfo renderable_info;
184 renderable_info.vbo = vbo;
185 renderable_info.has_per_vertex_color = has_per_vertex_color;
186 renderable_info.indices = indices;
187 renderable_info.has_user_customized_draw = true;
188 renderable_info.draw_index_count = (int)draw_index_count;
189 renderable_info.draw_first_index = (int)draw_first_index;
190 renderable_info.draw_vertex_count = (int)draw_vertex_count;
191 renderable_info.draw_first_vertex = (int)draw_first_vertex;
192 renderable_info.display_mode = show_wireframe
193 ? taichi::lang::PolygonMode::Line
194 : taichi::lang::PolygonMode::Fill;
195
196 MeshInfo info;
197 info.renderable_info = renderable_info;
198 info.color = tuple_to_vec3(color);
199 info.two_sided = two_sided;
200
201 scene->mesh(info);
202 }
203
204 void particles(FieldInfo vbo,
205 bool has_per_vertex_color,
206 py::tuple color_,
207 float radius,
208 float draw_vertex_count,
209 float draw_first_vertex) {
210 RenderableInfo renderable_info;
211 renderable_info.vbo = vbo;
212 renderable_info.has_user_customized_draw = true;
213 renderable_info.has_per_vertex_color = has_per_vertex_color;
214 renderable_info.draw_vertex_count = (int)draw_vertex_count;
215 renderable_info.draw_first_vertex = (int)draw_first_vertex;
216
217 ParticlesInfo info;
218 info.renderable_info = renderable_info;
219 info.color = tuple_to_vec3(color_);
220 info.radius = radius;
221
222 scene->particles(info);
223 }
224
225 void mesh_instance(FieldInfo vbo,
226 bool has_per_vertex_color,
227 FieldInfo indices,
228 py::tuple color,
229 bool two_sided,
230 FieldInfo transforms,
231 float draw_instance_count,
232 float draw_first_instance,
233 float draw_index_count,
234 float draw_first_index,
235 float draw_vertex_count,
236 float draw_first_vertex,
237 bool show_wireframe) {
238 RenderableInfo renderable_info;
239 renderable_info.vbo = vbo;
240 renderable_info.has_per_vertex_color = has_per_vertex_color;
241 renderable_info.indices = indices;
242 renderable_info.has_user_customized_draw = true;
243 renderable_info.draw_index_count = (int)draw_index_count;
244 renderable_info.draw_first_index = (int)draw_first_index;
245 renderable_info.draw_vertex_count = (int)draw_vertex_count;
246 renderable_info.draw_first_vertex = (int)draw_first_vertex;
247 renderable_info.display_mode = show_wireframe
248 ? taichi::lang::PolygonMode::Line
249 : taichi::lang::PolygonMode::Fill;
250
251 MeshInfo info;
252 info.renderable_info = renderable_info;
253 info.color = tuple_to_vec3(color);
254 info.two_sided = two_sided;
255 if (transforms.valid) {
256 info.start_instance = (int)draw_first_instance;
257 info.num_instances =
258 (draw_instance_count + info.start_instance) > transforms.shape[0]
259 ? (transforms.shape[0] - info.start_instance)
260 : (int)draw_instance_count;
261 }
262 info.mesh_attribute_info.mesh_attribute = transforms;
263 info.mesh_attribute_info.has_attribute = transforms.valid;
264
265 scene->mesh(info);
266 }
267
268 void point_light(py::tuple pos_, py::tuple color_) {
269 glm::vec3 pos = tuple_to_vec3(pos_);
270 glm::vec3 color = tuple_to_vec3(color_);
271 scene->point_light(pos, color);
272 }
273
274 void ambient_light(py::tuple color_) {
275 glm::vec3 color = tuple_to_vec3(color_);
276 scene->ambient_light(color);
277 }
278
279 ~PyScene() {
280 delete scene;
281 }
282};
283
284struct PyCanvas {
285 CanvasBase *canvas; // not owned
286
287 void set_background_color(py::tuple color_) {
288 glm::vec3 color = tuple_to_vec3(color_);
289 return canvas->set_background_color(color);
290 }
291
292 void set_image(FieldInfo img) {
293 canvas->set_image({img});
294 }
295
296 void set_image_texture(Texture *texture) {
297 canvas->set_image(texture);
298 }
299
300 void scene(PyScene &scene) {
301 canvas->scene(scene.scene);
302 }
303
304 void triangles(FieldInfo vbo,
305 FieldInfo indices,
306 bool has_per_vertex_color,
307 py::tuple color_) {
308 RenderableInfo renderable_info;
309 renderable_info.vbo = vbo;
310 renderable_info.indices = indices;
311 renderable_info.has_per_vertex_color = has_per_vertex_color;
312
313 TrianglesInfo info;
314 info.renderable_info = renderable_info;
315 info.color = tuple_to_vec3(color_);
316
317 return canvas->triangles(info);
318 }
319
320 void lines(FieldInfo vbo,
321 FieldInfo indices,
322 bool has_per_vertex_color,
323 py::tuple color_,
324 float width) {
325 RenderableInfo renderable_info;
326 renderable_info.vbo = vbo;
327 renderable_info.indices = indices;
328 renderable_info.has_per_vertex_color = has_per_vertex_color;
329
330 LinesInfo info;
331 info.renderable_info = renderable_info;
332 info.color = tuple_to_vec3(color_);
333 info.width = width;
334
335 return canvas->lines(info);
336 }
337
338 void circles(FieldInfo vbo,
339 bool has_per_vertex_color,
340 py::tuple color_,
341 float radius) {
342 RenderableInfo renderable_info;
343 renderable_info.vbo = vbo;
344 renderable_info.has_per_vertex_color = has_per_vertex_color;
345
346 CirclesInfo info;
347 info.renderable_info = renderable_info;
348 info.color = tuple_to_vec3(color_);
349 info.radius = radius;
350
351 return canvas->circles(info);
352 }
353};
354
355struct PyWindow {
356 std::unique_ptr<WindowBase> window{nullptr};
357
358 PyWindow(Program *prog,
359 std::string name,
360 py::tuple res,
361 py::tuple pos,
362 bool vsync,
363 bool show_window,
364 std::string package_path,
365 Arch ti_arch) {
366 AppConfig config = {name,
367 res[0].cast<int>(),
368 res[1].cast<int>(),
369 pos[0].cast<int>(),
370 pos[1].cast<int>(),
371 vsync,
372 show_window,
373 package_path,
374 ti_arch};
375 if (!lang::vulkan::is_vulkan_api_available()) {
376 throw std::runtime_error("Vulkan must be available for GGUI");
377 }
378 window = std::make_unique<vulkan::Window>(prog, config);
379 }
380
381 py::tuple get_window_shape() {
382 auto [w, h] = window->get_window_shape();
383 return pybind11::make_tuple(w, h);
384 }
385
386 void write_image(const std::string &filename) {
387 window->write_image(filename);
388 }
389
390 void copy_depth_buffer_to_ndarray(Ndarray *depth_arr) {
391 window->copy_depth_buffer_to_ndarray(*depth_arr);
392 }
393
394 py::array_t<float> get_image_buffer() {
395 uint32_t w, h;
396 auto &img_buffer = window->get_image_buffer(w, h);
397
398 float *image = new float[w * h * 4];
399 // Here we must match the numpy 3d array memory layout. Refs:
400 // https://numpy.org/doc/stable/reference/arrays.ndarray.html
401 for (int i = 0; i < w; i++) {
402 for (int j = 0; j < h; j++) {
403 auto pixel = img_buffer[j * w + i];
404 for (int k = 0; k < 4; k++) {
405 // must flip up-down to match the numpy array memory layout
406 image[i * h * 4 + (h - j - 1) * 4 + k] = (pixel & 0xFF) / 255.0;
407 pixel >>= 8;
408 }
409 }
410 }
411 // Here we must pass a deconstructor to free the memory in python scope.
412 // Refs:
413 // https://stackoverflow.com/questions/44659924/returning-numpy-arrays-via-pybind11
414 py::capsule free_imgae(image, [](void *tmp) {
415 float *image = reinterpret_cast<float *>(tmp);
416 delete[] image;
417 });
418
419 return py::array_t<float>(
420 py::detail::any_container<ssize_t>({w, h, 4}),
421 py::detail::any_container<ssize_t>(
422 {sizeof(float) * h * 4, sizeof(float) * 4, sizeof(float)}),
423 image, free_imgae);
424 }
425
426 void show() {
427 window->show();
428 }
429
430 bool is_pressed(std::string button) {
431 return window->is_pressed(button);
432 }
433
434 bool is_running() {
435 return window->is_running();
436 }
437
438 void set_is_running(bool value) {
439 return window->set_is_running(value);
440 }
441
442 py::list get_events(EventType tag) {
443 return py::cast(window->get_events(tag));
444 }
445
446 bool get_event(EventType e) {
447 return window->get_event(e);
448 }
449
450 Event get_current_event() {
451 return window->get_current_event();
452 }
453 void set_current_event(const Event &event) {
454 window->set_current_event(event);
455 }
456
457 PyCanvas get_canvas() {
458 PyCanvas canvas = {window->get_canvas()};
459 return canvas;
460 }
461
462 PyGui gui() {
463 PyGui gui = {window->gui()};
464 return gui;
465 }
466
467 // this is so that the GUI class does not need to use any pybind related stuff
468 py::tuple py_get_cursor_pos() {
469 auto pos = window->get_cursor_pos();
470 float x = std::get<0>(pos);
471 float y = std::get<1>(pos);
472 return py::make_tuple(x, y);
473 }
474
475 void destroy() {
476 if (window) {
477 window.reset();
478 }
479 }
480};
481
482void export_ggui(py::module &m) {
483 m.attr("GGUI_AVAILABLE") = py::bool_(true);
484
485 py::class_<PyWindow>(m, "PyWindow")
486 .def(py::init<Program *, std::string, py::tuple, py::tuple, bool, bool,
487 std::string, Arch>())
488 .def("get_canvas", &PyWindow::get_canvas)
489 .def("show", &PyWindow::show)
490 .def("get_window_shape", &PyWindow::get_window_shape)
491 .def("write_image", &PyWindow::write_image)
492 .def("copy_depth_buffer_to_ndarray",
493 &PyWindow::copy_depth_buffer_to_ndarray)
494 .def("get_image_buffer_as_numpy", &PyWindow::get_image_buffer)
495 .def("is_pressed", &PyWindow::is_pressed)
496 .def("get_cursor_pos", &PyWindow::py_get_cursor_pos)
497 .def("is_running", &PyWindow::is_running)
498 .def("set_is_running", &PyWindow::set_is_running)
499 .def("get_event", &PyWindow::get_event)
500 .def("get_events", &PyWindow::get_events)
501 .def("get_current_event", &PyWindow::get_current_event)
502 .def("set_current_event", &PyWindow::set_current_event)
503 .def("destroy", &PyWindow::destroy)
504 .def("GUI", &PyWindow::gui);
505
506 py::class_<PyCanvas>(m, "PyCanvas")
507 .def("set_background_color", &PyCanvas::set_background_color)
508 .def("set_image", &PyCanvas::set_image)
509 .def("set_image_texture", &PyCanvas::set_image_texture)
510 .def("triangles", &PyCanvas::triangles)
511 .def("lines", &PyCanvas::lines)
512 .def("circles", &PyCanvas::circles)
513 .def("scene", &PyCanvas::scene);
514
515 py::class_<PyGui>(m, "PyGui")
516 .def("begin", &PyGui::begin)
517 .def("end", &PyGui::end)
518 .def("text", &PyGui::text)
519 .def("text_colored", &PyGui::text_colored)
520 .def("checkbox", &PyGui::checkbox)
521 .def("slider_int", &PyGui::slider_int)
522 .def("slider_float", &PyGui::slider_float)
523 .def("color_edit_3", &PyGui::color_edit_3)
524 .def("button", &PyGui::button);
525
526 py::class_<PyScene>(m, "PyScene")
527 .def(py::init<>())
528 .def("set_camera", &PyScene::set_camera)
529 .def("lines", &PyScene::lines)
530 .def("mesh", &PyScene::mesh)
531 .def("particles", &PyScene::particles)
532 .def("mesh_instance", &PyScene::mesh_instance)
533 .def("point_light", &PyScene::point_light)
534 .def("ambient_light", &PyScene::ambient_light);
535
536 py::class_<PyCamera>(m, "PyCamera")
537 .def(py::init<>())
538 .def("lookat", &PyCamera::lookat)
539 .def("position", &PyCamera::position)
540 .def("up", &PyCamera::up)
541 .def("projection_mode", &PyCamera::projection_mode)
542 .def("fov", &PyCamera::fov)
543 .def("left", &PyCamera::left)
544 .def("right", &PyCamera::right)
545 .def("top", &PyCamera::top)
546 .def("bottom", &PyCamera::bottom)
547 .def("z_near", &PyCamera::z_near)
548 .def("z_far", &PyCamera::z_far)
549 .def("get_view_matrix", &PyCamera::get_view_matrix)
550 .def("get_projection_matrix", &PyCamera::get_projection_matrix);
551
552 py::class_<Event>(m, "Event")
553 .def_property("key", &Event::get_key, &Event::set_key);
554
555 py::class_<FieldInfo>(m, "FieldInfo")
556 .def(py::init<>())
557 .def_property("valid", &FieldInfo::get_valid, &FieldInfo::set_valid)
558 .def_property("num_elements", &FieldInfo::get_num_elements,
559 &FieldInfo::set_num_elements)
560 .def_property("shape", &FieldInfo::get_shape, &FieldInfo::set_shape)
561 .def_property("field_source", &FieldInfo::get_field_source,
562 &FieldInfo::set_field_source)
563 .def_property("dtype", &FieldInfo::get_dtype, &FieldInfo::set_dtype)
564 .def_property("dev_alloc", &FieldInfo::get_dev_alloc,
565 &FieldInfo::set_dev_alloc);
566
567 py::enum_<EventType>(m, "EventType")
568 .value("Any", EventType::Any)
569 .value("Press", EventType::Press)
570 .value("Release", EventType::Release)
571 .export_values();
572
573 py::enum_<FieldSource>(m, "FieldSource")
574 .value("TaichiNDarray", FieldSource::TaichiNDarray)
575 .value("HostMappedPtr", FieldSource::HostMappedPtr)
576 .export_values();
577
578 py::enum_<ProjectionMode>(m, "ProjectionMode")
579 .value("Perspective", ProjectionMode::Perspective)
580 .value("Orthogonal", ProjectionMode::Orthogonal)
581 .export_values();
582
583 py::enum_<taichi::lang::PolygonMode>(m, "DisplayMode")
584 .value("Fill", taichi::lang::PolygonMode::Fill)
585 .value("Line", taichi::lang::PolygonMode::Line)
586 .value("Point", taichi::lang::PolygonMode::Point)
587 .export_values();
588}
589
590} // namespace taichi::ui
591
592namespace taichi {
593
594void export_ggui(py::module &m) {
595 ui::export_ggui(m);
596}
597
598} // namespace taichi
599
600#else
601
602namespace taichi {
603
604void export_ggui(py::module &m) {
605 m.attr("GGUI_AVAILABLE") = py::bool_(false);
606}
607
608} // namespace taichi
609
610#endif
611