1// Bindings for the python frontend
2
#include <cstring>
#include <optional>
#include <stdexcept>
#include <string>
5#include "taichi/ir/snode.h"
6
7#if TI_WITH_LLVM
8#include "llvm/Config/llvm-config.h"
9#endif
10
11#include "pybind11/functional.h"
12#include "pybind11/pybind11.h"
13#include "pybind11/eigen.h"
14#include "pybind11/numpy.h"
15
16#include "taichi/ir/expression_ops.h"
17#include "taichi/ir/frontend_ir.h"
18#include "taichi/ir/statements.h"
19#include "taichi/program/graph_builder.h"
20#include "taichi/program/extension.h"
21#include "taichi/program/ndarray.h"
22#include "taichi/python/export.h"
23#include "taichi/math/svd.h"
24#include "taichi/util/action_recorder.h"
25#include "taichi/system/timeline.h"
26#include "taichi/python/snode_registry.h"
27#include "taichi/program/sparse_matrix.h"
28#include "taichi/program/sparse_solver.h"
29#include "taichi/aot/graph_data.h"
30#include "taichi/ir/mesh.h"
31
32#include "taichi/program/kernel_profiler.h"
33
34#if defined(TI_WITH_CUDA)
35#include "taichi/rhi/cuda/cuda_context.h"
36#endif
37
38namespace taichi {
39bool test_threading();
40
41} // namespace taichi
42
43namespace taichi::lang {
44
45std::string libdevice_path();
46
47} // namespace taichi::lang
48
49namespace taichi {
50void export_lang(py::module &m) {
51 using namespace taichi::lang;
52 using namespace std::placeholders;
53
54 py::register_exception<TaichiTypeError>(m, "TaichiTypeError",
55 PyExc_TypeError);
56 py::register_exception<TaichiSyntaxError>(m, "TaichiSyntaxError",
57 PyExc_SyntaxError);
58 py::register_exception<TaichiIndexError>(m, "TaichiIndexError",
59 PyExc_IndexError);
60 py::register_exception<TaichiRuntimeError>(m, "TaichiRuntimeError",
61 PyExc_RuntimeError);
62 py::register_exception<TaichiAssertionError>(m, "TaichiAssertionError",
63 PyExc_AssertionError);
64 py::enum_<Arch>(m, "Arch", py::arithmetic())
65#define PER_ARCH(x) .value(#x, Arch::x)
66#include "taichi/inc/archs.inc.h"
67#undef PER_ARCH
68 .export_values();
69
70 m.def("arch_name", arch_name);
71 m.def("arch_from_name", arch_from_name);
72
73 py::enum_<SNodeType>(m, "SNodeType", py::arithmetic())
74#define PER_SNODE(x) .value(#x, SNodeType::x)
75#include "taichi/inc/snodes.inc.h"
76#undef PER_SNODE
77 .export_values();
78
79 py::enum_<Extension>(m, "Extension", py::arithmetic())
80#define PER_EXTENSION(x) .value(#x, Extension::x)
81#include "taichi/inc/extensions.inc.h"
82#undef PER_EXTENSION
83 .export_values();
84
85 py::enum_<ExternalArrayLayout>(m, "Layout", py::arithmetic())
86 .value("AOS", ExternalArrayLayout::kAOS)
87 .value("SOA", ExternalArrayLayout::kSOA)
88 .value("NULL", ExternalArrayLayout::kNull)
89 .export_values();
90
91 py::enum_<AutodiffMode>(m, "AutodiffMode", py::arithmetic())
92 .value("NONE", AutodiffMode::kNone)
93 .value("VALIDATION", AutodiffMode::kCheckAutodiffValid)
94 .value("FORWARD", AutodiffMode::kForward)
95 .value("REVERSE", AutodiffMode::kReverse)
96 .export_values();
97
98 py::enum_<SNodeGradType>(m, "SNodeGradType", py::arithmetic())
99 .value("PRIMAL", SNodeGradType::kPrimal)
100 .value("ADJOINT", SNodeGradType::kAdjoint)
101 .value("DUAL", SNodeGradType::kDual)
102 .value("ADJOINT_CHECKBIT", SNodeGradType::kAdjointCheckbit)
103 .export_values();
104
105 // TODO(type): This should be removed
106 py::class_<DataType>(m, "DataType")
107 .def(py::init<Type *>())
108 .def(py::self == py::self)
109 .def("__hash__", &DataType::hash)
110 .def("to_string", &DataType::to_string)
111 .def("__str__", &DataType::to_string)
112 .def("shape", &DataType::get_shape)
113 .def("element_type", &DataType::get_element_type)
114 .def(
115 "get_ptr", [](DataType *dtype) -> Type * { return *dtype; },
116 py::return_value_policy::reference)
117 .def(py::pickle(
118 [](const DataType &dt) {
119 // Note: this only works for primitive types, which is fine for now.
120 auto primitive =
121 dynamic_cast<const PrimitiveType *>((const Type *)dt);
122 TI_ASSERT(primitive);
123 return py::make_tuple((std::size_t)primitive->type);
124 },
125 [](py::tuple t) {
126 if (t.size() != 1)
127 throw std::runtime_error("Invalid state!");
128
129 DataType dt =
130 PrimitiveType::get((PrimitiveTypeID)(t[0].cast<std::size_t>()));
131
132 return dt;
133 }));
134
135 py::class_<CompileConfig>(m, "CompileConfig")
136 .def(py::init<>())
137 .def_readwrite("arch", &CompileConfig::arch)
138 .def_readwrite("opt_level", &CompileConfig::opt_level)
139 .def_readwrite("print_ir", &CompileConfig::print_ir)
140 .def_readwrite("print_preprocessed_ir",
141 &CompileConfig::print_preprocessed_ir)
142 .def_readwrite("debug", &CompileConfig::debug)
143 .def_readwrite("cfg_optimization", &CompileConfig::cfg_optimization)
144 .def_readwrite("check_out_of_bound", &CompileConfig::check_out_of_bound)
145 .def_readwrite("print_accessor_ir", &CompileConfig::print_accessor_ir)
146 .def_readwrite("print_evaluator_ir", &CompileConfig::print_evaluator_ir)
147 .def_readwrite("use_llvm", &CompileConfig::use_llvm)
148 .def_readwrite("print_struct_llvm_ir",
149 &CompileConfig::print_struct_llvm_ir)
150 .def_readwrite("print_kernel_llvm_ir",
151 &CompileConfig::print_kernel_llvm_ir)
152 .def_readwrite("print_kernel_llvm_ir_optimized",
153 &CompileConfig::print_kernel_llvm_ir_optimized)
154 .def_readwrite("print_kernel_nvptx", &CompileConfig::print_kernel_nvptx)
155 .def_readwrite("simplify_before_lower_access",
156 &CompileConfig::simplify_before_lower_access)
157 .def_readwrite("simplify_after_lower_access",
158 &CompileConfig::simplify_after_lower_access)
159 .def_readwrite("lower_access", &CompileConfig::lower_access)
160 .def_readwrite("move_loop_invariant_outside_if",
161 &CompileConfig::move_loop_invariant_outside_if)
162 .def_readwrite("cache_loop_invariant_global_vars",
163 &CompileConfig::cache_loop_invariant_global_vars)
164 .def_readwrite("default_cpu_block_dim",
165 &CompileConfig::default_cpu_block_dim)
166 .def_readwrite("cpu_block_dim_adaptive",
167 &CompileConfig::cpu_block_dim_adaptive)
168 .def_readwrite("default_gpu_block_dim",
169 &CompileConfig::default_gpu_block_dim)
170 .def_readwrite("gpu_max_reg", &CompileConfig::gpu_max_reg)
171 .def_readwrite("saturating_grid_dim", &CompileConfig::saturating_grid_dim)
172 .def_readwrite("max_block_dim", &CompileConfig::max_block_dim)
173 .def_readwrite("cpu_max_num_threads", &CompileConfig::cpu_max_num_threads)
174 .def_readwrite("random_seed", &CompileConfig::random_seed)
175 .def_readwrite("verbose_kernel_launches",
176 &CompileConfig::verbose_kernel_launches)
177 .def_readwrite("verbose", &CompileConfig::verbose)
178 .def_readwrite("demote_dense_struct_fors",
179 &CompileConfig::demote_dense_struct_fors)
180 .def_readwrite("kernel_profiler", &CompileConfig::kernel_profiler)
181 .def_readwrite("timeline", &CompileConfig::timeline)
182 .def_readwrite("default_fp", &CompileConfig::default_fp)
183 .def_readwrite("default_ip", &CompileConfig::default_ip)
184 .def_readwrite("default_up", &CompileConfig::default_up)
185 .def_readwrite("device_memory_GB", &CompileConfig::device_memory_GB)
186 .def_readwrite("device_memory_fraction",
187 &CompileConfig::device_memory_fraction)
188 .def_readwrite("fast_math", &CompileConfig::fast_math)
189 .def_readwrite("advanced_optimization",
190 &CompileConfig::advanced_optimization)
191 .def_readwrite("ad_stack_size", &CompileConfig::ad_stack_size)
192 .def_readwrite("flatten_if", &CompileConfig::flatten_if)
193 .def_readwrite("make_thread_local", &CompileConfig::make_thread_local)
194 .def_readwrite("make_block_local", &CompileConfig::make_block_local)
195 .def_readwrite("detect_read_only", &CompileConfig::detect_read_only)
196 .def_readwrite("ndarray_use_cached_allocator",
197 &CompileConfig::ndarray_use_cached_allocator)
198 .def_readwrite("real_matrix_scalarize",
199 &CompileConfig::real_matrix_scalarize)
200 .def_readwrite("cc_compile_cmd", &CompileConfig::cc_compile_cmd)
201 .def_readwrite("cc_link_cmd", &CompileConfig::cc_link_cmd)
202 .def_readwrite("quant_opt_store_fusion",
203 &CompileConfig::quant_opt_store_fusion)
204 .def_readwrite("quant_opt_atomic_demotion",
205 &CompileConfig::quant_opt_atomic_demotion)
206 .def_readwrite("allow_nv_shader_extension",
207 &CompileConfig::allow_nv_shader_extension)
208 .def_readwrite("make_mesh_block_local",
209 &CompileConfig::make_mesh_block_local)
210 .def_readwrite("mesh_localize_to_end_mapping",
211 &CompileConfig::mesh_localize_to_end_mapping)
212 .def_readwrite("mesh_localize_from_end_mapping",
213 &CompileConfig::mesh_localize_from_end_mapping)
214 .def_readwrite("optimize_mesh_reordered_mapping",
215 &CompileConfig::optimize_mesh_reordered_mapping)
216 .def_readwrite("mesh_localize_all_attr_mappings",
217 &CompileConfig::mesh_localize_all_attr_mappings)
218 .def_readwrite("demote_no_access_mesh_fors",
219 &CompileConfig::demote_no_access_mesh_fors)
220 .def_readwrite("experimental_auto_mesh_local",
221 &CompileConfig::experimental_auto_mesh_local)
222 .def_readwrite("auto_mesh_local_default_occupacy",
223 &CompileConfig::auto_mesh_local_default_occupacy)
224 .def_readwrite("offline_cache", &CompileConfig::offline_cache)
225 .def_readwrite("offline_cache_file_path",
226 &CompileConfig::offline_cache_file_path)
227 .def_readwrite("offline_cache_cleaning_policy",
228 &CompileConfig::offline_cache_cleaning_policy)
229 .def_readwrite("offline_cache_max_size_of_files",
230 &CompileConfig::offline_cache_max_size_of_files)
231 .def_readwrite("offline_cache_cleaning_factor",
232 &CompileConfig::offline_cache_cleaning_factor)
233 .def_readwrite("num_compile_threads", &CompileConfig::num_compile_threads)
234 .def_readwrite("vk_api_version", &CompileConfig::vk_api_version)
235 .def_readwrite("cuda_stack_limit", &CompileConfig::cuda_stack_limit);
236
237 m.def("reset_default_compile_config",
238 [&]() { default_compile_config = CompileConfig(); });
239
240 m.def(
241 "default_compile_config",
242 [&]() -> CompileConfig & { return default_compile_config; },
243 py::return_value_policy::reference);
244
245 py::class_<Program::KernelProfilerQueryResult>(m, "KernelProfilerQueryResult")
246 .def_readwrite("counter", &Program::KernelProfilerQueryResult::counter)
247 .def_readwrite("min", &Program::KernelProfilerQueryResult::min)
248 .def_readwrite("max", &Program::KernelProfilerQueryResult::max)
249 .def_readwrite("avg", &Program::KernelProfilerQueryResult::avg);
250
251 py::class_<KernelProfileTracedRecord>(m, "KernelProfileTracedRecord")
252 .def_readwrite("register_per_thread",
253 &KernelProfileTracedRecord::register_per_thread)
254 .def_readwrite("shared_mem_per_block",
255 &KernelProfileTracedRecord::shared_mem_per_block)
256 .def_readwrite("grid_size", &KernelProfileTracedRecord::grid_size)
257 .def_readwrite("block_size", &KernelProfileTracedRecord::block_size)
258 .def_readwrite(
259 "active_blocks_per_multiprocessor",
260 &KernelProfileTracedRecord::active_blocks_per_multiprocessor)
261 .def_readwrite("kernel_time",
262 &KernelProfileTracedRecord::kernel_elapsed_time_in_ms)
263 .def_readwrite("base_time", &KernelProfileTracedRecord::time_since_base)
264 .def_readwrite("name", &KernelProfileTracedRecord::name)
265 .def_readwrite("metric_values",
266 &KernelProfileTracedRecord::metric_values);
267
268 py::enum_<SNodeAccessFlag>(m, "SNodeAccessFlag", py::arithmetic())
269 .value("block_local", SNodeAccessFlag::block_local)
270 .value("read_only", SNodeAccessFlag::read_only)
271 .value("mesh_local", SNodeAccessFlag::mesh_local)
272 .export_values();
273
274 // Export ASTBuilder
275 py::class_<ASTBuilder>(m, "ASTBuilder")
276 .def("make_id_expr", &ASTBuilder::make_id_expr)
277 .def("create_kernel_exprgroup_return",
278 &ASTBuilder::create_kernel_exprgroup_return)
279 .def("create_print", &ASTBuilder::create_print)
280 .def("begin_func", &ASTBuilder::begin_func)
281 .def("end_func", &ASTBuilder::end_func)
282 .def("stop_grad", &ASTBuilder::stop_gradient)
283 .def("begin_frontend_if", &ASTBuilder::begin_frontend_if)
284 .def("begin_frontend_if_true", &ASTBuilder::begin_frontend_if_true)
285 .def("pop_scope", &ASTBuilder::pop_scope)
286 .def("begin_frontend_if_false", &ASTBuilder::begin_frontend_if_false)
287 .def("insert_deactivate", &ASTBuilder::insert_snode_deactivate)
288 .def("insert_activate", &ASTBuilder::insert_snode_activate)
289 .def("expr_snode_get_addr", &ASTBuilder::snode_get_addr)
290 .def("expr_snode_append", &ASTBuilder::snode_append)
291 .def("expr_snode_is_active", &ASTBuilder::snode_is_active)
292 .def("expr_snode_length", &ASTBuilder::snode_length)
293 .def("insert_external_func_call", &ASTBuilder::insert_external_func_call)
294 .def("make_matrix_expr", &ASTBuilder::make_matrix_expr)
295 .def("expr_alloca", &ASTBuilder::expr_alloca)
296 .def("expr_alloca_shared_array", &ASTBuilder::expr_alloca_shared_array)
297 .def("create_assert_stmt", &ASTBuilder::create_assert_stmt)
298 .def("expr_assign", &ASTBuilder::expr_assign)
299 .def("begin_frontend_range_for", &ASTBuilder::begin_frontend_range_for)
300 .def("end_frontend_range_for", &ASTBuilder::pop_scope)
301 .def("begin_frontend_struct_for_on_snode",
302 &ASTBuilder::begin_frontend_struct_for_on_snode)
303 .def("begin_frontend_struct_for_on_external_tensor",
304 &ASTBuilder::begin_frontend_struct_for_on_external_tensor)
305 .def("end_frontend_struct_for", &ASTBuilder::pop_scope)
306 .def("begin_frontend_mesh_for", &ASTBuilder::begin_frontend_mesh_for)
307 .def("end_frontend_mesh_for", &ASTBuilder::pop_scope)
308 .def("begin_frontend_while", &ASTBuilder::begin_frontend_while)
309 .def("insert_break_stmt", &ASTBuilder::insert_break_stmt)
310 .def("insert_continue_stmt", &ASTBuilder::insert_continue_stmt)
311 .def("insert_expr_stmt", &ASTBuilder::insert_expr_stmt)
312 .def("insert_thread_idx_expr", &ASTBuilder::insert_thread_idx_expr)
313 .def("insert_patch_idx_expr", &ASTBuilder::insert_patch_idx_expr)
314 .def("make_texture_op_expr", &ASTBuilder::make_texture_op_expr)
315 .def("expand_exprs", &ASTBuilder::expand_exprs)
316 .def("mesh_index_conversion", &ASTBuilder::mesh_index_conversion)
317 .def("expr_subscript", &ASTBuilder::expr_subscript)
318 .def("insert_func_call", &ASTBuilder::insert_func_call)
319 .def("sifakis_svd_f32", sifakis_svd_export<float32, int32>)
320 .def("sifakis_svd_f64", sifakis_svd_export<float64, int64>)
321 .def("expr_var", &ASTBuilder::make_var)
322 .def("bit_vectorize", &ASTBuilder::bit_vectorize)
323 .def("parallelize", &ASTBuilder::parallelize)
324 .def("strictly_serialize", &ASTBuilder::strictly_serialize)
325 .def("block_dim", &ASTBuilder::block_dim)
326 .def("insert_snode_access_flag", &ASTBuilder::insert_snode_access_flag)
327 .def("reset_snode_access_flag", &ASTBuilder::reset_snode_access_flag);
328
329 py::class_<Program>(m, "Program")
330 .def(py::init<>())
331 .def("config", &Program::compile_config,
332 py::return_value_policy::reference)
333 .def("sync_kernel_profiler",
334 [](Program *program) { program->profiler->sync(); })
335 .def("update_kernel_profiler",
336 [](Program *program) { program->profiler->update(); })
337 .def("clear_kernel_profiler",
338 [](Program *program) { program->profiler->clear(); })
339 .def("query_kernel_profile_info",
340 [](Program *program, const std::string &name) {
341 return program->query_kernel_profile_info(name);
342 })
343 .def("get_kernel_profiler_records",
344 [](Program *program) {
345 return program->profiler->get_traced_records();
346 })
347 .def(
348 "get_kernel_profiler_device_name",
349 [](Program *program) { return program->profiler->get_device_name(); })
350 .def("reinit_kernel_profiler_with_metrics",
351 [](Program *program, const std::vector<std::string> metrics) {
352 return program->profiler->reinit_with_metrics(metrics);
353 })
354 .def("kernel_profiler_total_time",
355 [](Program *program) { return program->profiler->get_total_time(); })
356 .def("set_kernel_profiler_toolkit",
357 [](Program *program, const std::string toolkit_name) {
358 return program->profiler->set_profiler_toolkit(toolkit_name);
359 })
360 .def("timeline_clear",
361 [](Program *) { Timelines::get_instance().clear(); })
362 .def("timeline_save",
363 [](Program *, const std::string &fn) {
364 Timelines::get_instance().save(fn);
365 })
366 .def("print_memory_profiler_info", &Program::print_memory_profiler_info)
367 .def("finalize", &Program::finalize)
368 .def("get_total_compilation_time", &Program::get_total_compilation_time)
369 .def("get_snode_num_dynamically_allocated",
370 &Program::get_snode_num_dynamically_allocated)
371 .def("synchronize", &Program::synchronize)
372 .def("materialize_runtime", &Program::materialize_runtime)
373 .def("make_aot_module_builder", &Program::make_aot_module_builder)
374 .def("get_snode_tree_size", &Program::get_snode_tree_size)
375 .def("get_snode_root", &Program::get_snode_root,
376 py::return_value_policy::reference)
377 .def(
378 "create_kernel",
379 [](Program *program, const std::function<void(Kernel *)> &body,
380 const std::string &name, AutodiffMode autodiff_mode) -> Kernel * {
381 py::gil_scoped_release release;
382 return &program->kernel(body, name, autodiff_mode);
383 },
384 py::return_value_policy::reference)
385 .def("create_function", &Program::create_function,
386 py::return_value_policy::reference)
387 .def("create_sparse_matrix_builder",
388 [](Program *program, int n, int m, uint64 max_num_entries,
389 DataType dtype, const std::string &storage_format) {
390 TI_ERROR_IF(!arch_is_cpu(program->compile_config().arch) &&
391 !arch_is_cuda(program->compile_config().arch),
392 "SparseMatrix only supports CPU and CUDA for now.");
393 return SparseMatrixBuilder(n, m, max_num_entries, dtype,
394 storage_format, program);
395 })
396 .def("create_sparse_matrix",
397 [](Program *program, int n, int m, DataType dtype,
398 std::string storage_format) {
399 TI_ERROR_IF(!arch_is_cpu(program->compile_config().arch) &&
400 !arch_is_cuda(program->compile_config().arch),
401 "SparseMatrix only supports CPU and CUDA for now.");
402 if (arch_is_cpu(program->compile_config().arch))
403 return make_sparse_matrix(n, m, dtype, storage_format);
404 else
405 return make_cu_sparse_matrix(n, m, dtype);
406 })
407 .def("make_sparse_matrix_from_ndarray",
408 [](Program *program, SparseMatrix &sm, const Ndarray &ndarray) {
409 TI_ERROR_IF(!arch_is_cpu(program->compile_config().arch) &&
410 !arch_is_cuda(program->compile_config().arch),
411 "SparseMatrix only supports CPU and CUDA for now.");
412 return make_sparse_matrix_from_ndarray(program, sm, ndarray);
413 })
414 .def("make_id_expr",
415 [](Program *program, const std::string &name) {
416 return Expr::make<IdExpression>(program->get_next_global_id(name));
417 })
418 .def(
419 "create_ndarray",
420 [&](Program *program, const DataType &dt,
421 const std::vector<int> &shape, ExternalArrayLayout layout,
422 bool zero_fill) -> Ndarray * {
423 return program->create_ndarray(dt, shape, layout, zero_fill);
424 },
425 py::arg("dt"), py::arg("shape"),
426 py::arg("layout") = ExternalArrayLayout::kNull,
427 py::arg("zero_fill") = false, py::return_value_policy::reference)
428 .def("delete_ndarray", &Program::delete_ndarray)
429 .def(
430 "create_texture",
431 [&](Program *program, const DataType &dt, int num_channels,
432 const std::vector<int> &shape) -> Texture * {
433 return program->create_texture(dt, num_channels, shape);
434 },
435 py::arg("dt"), py::arg("num_channels"),
436 py::arg("shape") = py::tuple(), py::return_value_policy::reference)
437 .def("get_ndarray_data_ptr_as_int",
438 [](Program *program, Ndarray *ndarray) {
439 return program->get_ndarray_data_ptr_as_int(ndarray);
440 })
441 .def("fill_float",
442 [](Program *program, Ndarray *ndarray, float val) {
443 program->fill_ndarray_fast_u32(ndarray,
444 reinterpret_cast<uint32_t &>(val));
445 })
446 .def("fill_int",
447 [](Program *program, Ndarray *ndarray, int32_t val) {
448 program->fill_ndarray_fast_u32(ndarray,
449 reinterpret_cast<int32_t &>(val));
450 })
451 .def("fill_uint", [](Program *program, Ndarray *ndarray, uint32_t val) {
452 program->fill_ndarray_fast_u32(ndarray, val);
453 });
454
455 py::class_<AotModuleBuilder>(m, "AotModuleBuilder")
456 .def("add_field", &AotModuleBuilder::add_field)
457 .def("add", &AotModuleBuilder::add)
458 .def("add_kernel_template", &AotModuleBuilder::add_kernel_template)
459 .def("add_graph", &AotModuleBuilder::add_graph)
460 .def("dump", &AotModuleBuilder::dump);
461
462 py::class_<Axis>(m, "Axis").def(py::init<int>());
463 py::class_<SNode>(m, "SNode")
464 .def(py::init<>())
465 .def_readwrite("parent", &SNode::parent)
466 .def_readonly("type", &SNode::type)
467 .def_readonly("id", &SNode::id)
468 .def("dense",
469 (SNode & (SNode::*)(const std::vector<Axis> &,
470 const std::vector<int> &,
471 const std::string &))(&SNode::dense),
472 py::return_value_policy::reference)
473 .def("pointer",
474 (SNode & (SNode::*)(const std::vector<Axis> &,
475 const std::vector<int> &,
476 const std::string &))(&SNode::pointer),
477 py::return_value_policy::reference)
478 .def("hash",
479 (SNode & (SNode::*)(const std::vector<Axis> &,
480 const std::vector<int> &,
481 const std::string &))(&SNode::hash),
482 py::return_value_policy::reference)
483 .def("dynamic", &SNode::dynamic, py::return_value_policy::reference)
484 .def("bitmasked",
485 (SNode & (SNode::*)(const std::vector<Axis> &,
486 const std::vector<int> &,
487 const std::string &))(&SNode::bitmasked),
488 py::return_value_policy::reference)
489 .def("bit_struct", &SNode::bit_struct, py::return_value_policy::reference)
490 .def("quant_array", &SNode::quant_array,
491 py::return_value_policy::reference)
492 .def("place", &SNode::place)
493 .def("data_type", [](SNode *snode) { return snode->dt; })
494 .def("name", [](SNode *snode) { return snode->name; })
495 .def("get_num_ch",
496 [](SNode *snode) -> int { return (int)snode->ch.size(); })
497 .def(
498 "get_ch",
499 [](SNode *snode, int i) -> SNode * { return snode->ch[i].get(); },
500 py::return_value_policy::reference)
501 .def("lazy_grad", &SNode::lazy_grad)
502 .def("lazy_dual", &SNode::lazy_dual)
503 .def("allocate_adjoint_checkbit", &SNode::allocate_adjoint_checkbit)
504 .def("read_int", &SNode::read_int)
505 .def("read_uint", &SNode::read_uint)
506 .def("read_float", &SNode::read_float)
507 .def("has_adjoint", &SNode::has_adjoint)
508 .def("has_adjoint_checkbit", &SNode::has_adjoint_checkbit)
509 .def("get_snode_grad_type", &SNode::get_snode_grad_type)
510 .def("has_dual", &SNode::has_dual)
511 .def("is_primal", &SNode::is_primal)
512 .def("is_place", &SNode::is_place)
513 .def("get_expr", &SNode::get_expr)
514 .def("write_int", &SNode::write_int)
515 .def("write_uint", &SNode::write_uint)
516 .def("write_float", &SNode::write_float)
517 .def("get_shape_along_axis", &SNode::shape_along_axis)
518 .def("get_physical_index_position",
519 [](SNode *snode) {
520 return std::vector<int>(
521 snode->physical_index_position,
522 snode->physical_index_position + taichi_max_num_indices);
523 })
524 .def("num_active_indices",
525 [](SNode *snode) { return snode->num_active_indices; })
526 .def_readonly("cell_size_bytes", &SNode::cell_size_bytes)
527 .def_readonly("offset_bytes_in_parent_cell",
528 &SNode::offset_bytes_in_parent_cell);
529
530 py::class_<SNodeTree>(m, "SNodeTree")
531 .def("id", &SNodeTree::id)
532 .def("destroy_snode_tree", [](SNodeTree *snode_tree, Program *program) {
533 program->destroy_snode_tree(snode_tree);
534 });
535
536 py::class_<DeviceAllocation>(m, "DeviceAllocation")
537 .def(py::init([](uint64_t device, uint64_t alloc_id) -> DeviceAllocation {
538 DeviceAllocation alloc;
539 alloc.device = (Device *)device;
540 alloc.alloc_id = (DeviceAllocationId)alloc_id;
541 return alloc;
542 }),
543 py::arg("device"), py::arg("alloc_id"))
544 .def_readonly("device", &DeviceAllocation::device)
545 .def_readonly("alloc_id", &DeviceAllocation::alloc_id);
546
547 py::class_<Ndarray>(m, "Ndarray")
548 .def("device_allocation_ptr", &Ndarray::get_device_allocation_ptr_as_int)
549 .def("device_allocation", &Ndarray::get_device_allocation)
550 .def("element_size", &Ndarray::get_element_size)
551 .def("nelement", &Ndarray::get_nelement)
552 .def("read_int", &Ndarray::read_int)
553 .def("read_uint", &Ndarray::read_uint)
554 .def("read_float", &Ndarray::read_float)
555 .def("write_int", &Ndarray::write_int)
556 .def("write_float", &Ndarray::write_float)
557 .def("total_shape", &Ndarray::total_shape)
558 .def("element_shape", &Ndarray::get_element_shape)
559 .def("element_data_type", &Ndarray::get_element_data_type)
560 .def_readonly("dtype", &Ndarray::dtype)
561 .def_readonly("shape", &Ndarray::shape);
562
563 py::enum_<BufferFormat>(m, "Format")
564#define PER_BUFFER_FORMAT(x) .value(#x, BufferFormat::x)
565#include "taichi/inc/rhi_constants.inc.h"
566#undef PER_EXTENSION
567 ;
568
569 py::class_<Texture>(m, "Texture")
570 .def("device_allocation_ptr", &Texture::get_device_allocation_ptr_as_int)
571 .def("from_ndarray", &Texture::from_ndarray)
572 .def("from_snode", &Texture::from_snode);
573
574 py::enum_<aot::ArgKind>(m, "ArgKind")
575 .value("SCALAR", aot::ArgKind::kScalar)
576 .value("NDARRAY", aot::ArgKind::kNdarray)
577 // Using this MATRIX as Scalar alias, we can move to native matrix type
578 // when supported
579 .value("MATRIX", aot::ArgKind::kMatrix)
580 .value("TEXTURE", aot::ArgKind::kTexture)
581 .value("RWTEXTURE", aot::ArgKind::kRWTexture)
582 .export_values();
583
584 py::class_<aot::Arg>(m, "Arg")
585 .def(py::init<aot::ArgKind, std::string, DataType &, size_t,
586 std::vector<int>>(),
587 py::arg("tag"), py::arg("name"), py::arg("dtype"),
588 py::arg("field_dim"), py::arg("element_shape"))
589 .def(py::init<aot::ArgKind, std::string, DataType &, size_t,
590 std::vector<int>>(),
591 py::arg("tag"), py::arg("name"), py::arg("channel_format"),
592 py::arg("num_channels"), py::arg("shape"))
593 .def_readonly("name", &aot::Arg::name)
594 .def_readonly("element_shape", &aot::Arg::element_shape)
595 .def_readonly("texture_shape", &aot::Arg::element_shape)
596 .def_readonly("field_dim", &aot::Arg::field_dim)
597 .def_readonly("num_channels", &aot::Arg::num_channels)
598 .def("dtype", &aot::Arg::dtype)
599 .def("channel_format", &aot::Arg::dtype);
600
601 py::class_<Node>(m, "Node"); // NOLINT(bugprone-unused-raii)
602
603 py::class_<Sequential, Node>(m, "Sequential")
604 .def(py::init<GraphBuilder *>())
605 .def("append", &Sequential::append)
606 .def("dispatch", &Sequential::dispatch);
607
608 py::class_<GraphBuilder>(m, "GraphBuilder")
609 .def(py::init<>())
610 .def("dispatch", &GraphBuilder::dispatch)
611 .def("compile", &GraphBuilder::compile)
612 .def("create_sequential", &GraphBuilder::new_sequential_node,
613 py::return_value_policy::reference)
614 .def("seq", &GraphBuilder::seq, py::return_value_policy::reference);
615
616 py::class_<aot::CompiledGraph>(m, "CompiledGraph")
617 .def("run", [](aot::CompiledGraph *self, const py::dict &pyargs) {
618 std::unordered_map<std::string, aot::IValue> args;
619 for (auto it : pyargs) {
620 std::string arg_name = py::cast<std::string>(it.first);
621 auto tag = self->args[arg_name].tag;
622 if (tag == aot::ArgKind::kNdarray) {
623 auto &val = it.second.cast<Ndarray &>();
624 args.insert(
625 {py::cast<std::string>(it.first), aot::IValue::create(val)});
626 } else if (tag == aot::ArgKind::kTexture ||
627 tag == aot::ArgKind::kRWTexture) {
628 auto &val = it.second.cast<Texture &>();
629 args.insert(
630 {py::cast<std::string>(it.first), aot::IValue::create(val)});
631
632 } else if (tag == aot::ArgKind::kScalar ||
633 tag == aot::ArgKind::kMatrix) {
634 std::string arg_name = py::cast<std::string>(it.first);
635 auto expected_dtype = self->args[arg_name].dtype();
636 if (expected_dtype == PrimitiveType::i32) {
637 args.insert(
638 {arg_name, aot::IValue::create(py::cast<int>(it.second))});
639 } else if (expected_dtype == PrimitiveType::i64) {
640 args.insert(
641 {arg_name, aot::IValue::create(py::cast<int64>(it.second))});
642 } else if (expected_dtype == PrimitiveType::f32) {
643 args.insert(
644 {arg_name, aot::IValue::create(py::cast<float>(it.second))});
645 } else if (expected_dtype == PrimitiveType::f64) {
646 args.insert(
647 {arg_name, aot::IValue::create(py::cast<double>(it.second))});
648 } else if (expected_dtype == PrimitiveType::i16) {
649 args.insert(
650 {arg_name, aot::IValue::create(py::cast<int16>(it.second))});
651 } else if (expected_dtype == PrimitiveType::u32) {
652 args.insert(
653 {arg_name, aot::IValue::create(py::cast<uint32>(it.second))});
654 } else if (expected_dtype == PrimitiveType::u64) {
655 args.insert(
656 {arg_name, aot::IValue::create(py::cast<uint64>(it.second))});
657 } else if (expected_dtype == PrimitiveType::u16) {
658 args.insert(
659 {arg_name, aot::IValue::create(py::cast<uint16>(it.second))});
660 } else if (expected_dtype == PrimitiveType::u8) {
661 args.insert({arg_name,
662 aot::IValue::create(py::cast<uint8_t>(it.second))});
663 } else if (expected_dtype == PrimitiveType::i8) {
664 args.insert(
665 {arg_name, aot::IValue::create(py::cast<int8_t>(it.second))});
666 } else {
667 TI_NOT_IMPLEMENTED;
668 }
669 } else {
670 TI_NOT_IMPLEMENTED;
671 }
672 }
673 self->run(args);
674 });
675
676 py::class_<Kernel>(m, "Kernel")
677 .def("no_activate",
678 [](Kernel *self, SNode *snode) {
679 // TODO(#2193): Also apply to @ti.func?
680 self->no_activate.push_back(snode);
681 })
682 .def("insert_scalar_param", &Kernel::insert_scalar_param)
683 .def("insert_arr_param", &Kernel::insert_arr_param)
684 .def("insert_texture_param", &Kernel::insert_texture_param)
685 .def("insert_ret", &Kernel::insert_ret)
686 .def("finalize_rets", &Kernel::finalize_rets)
687 .def("get_ret_int", &Kernel::get_ret_int)
688 .def("get_ret_uint", &Kernel::get_ret_uint)
689 .def("get_ret_float", &Kernel::get_ret_float)
690 .def("get_ret_int_tensor", &Kernel::get_ret_int_tensor)
691 .def("get_ret_uint_tensor", &Kernel::get_ret_uint_tensor)
692 .def("get_ret_float_tensor", &Kernel::get_ret_float_tensor)
693 .def("get_struct_ret_int", &Kernel::get_struct_ret_int)
694 .def("get_struct_ret_uint", &Kernel::get_struct_ret_uint)
695 .def("get_struct_ret_float", &Kernel::get_struct_ret_float)
696 .def("make_launch_context", &Kernel::make_launch_context)
697 .def(
698 "ast_builder",
699 [](Kernel *self) -> ASTBuilder * {
700 return &self->context->builder();
701 },
702 py::return_value_policy::reference)
703 .def("__call__",
704 [](Kernel *kernel, Kernel::LaunchContextBuilder &launch_ctx) {
705 py::gil_scoped_release release;
706 kernel->operator()(kernel->program->compile_config(), launch_ctx);
707 });
708
709 py::class_<Kernel::LaunchContextBuilder>(m, "KernelLaunchContext")
710 .def("set_arg_int", &Kernel::LaunchContextBuilder::set_arg_int)
711 .def("set_arg_uint", &Kernel::LaunchContextBuilder::set_arg_uint)
712 .def("set_arg_float", &Kernel::LaunchContextBuilder::set_arg_float)
713 .def("set_arg_external_array_with_shape",
714 &Kernel::LaunchContextBuilder::set_arg_external_array_with_shape)
715 .def("set_arg_ndarray", &Kernel::LaunchContextBuilder::set_arg_ndarray)
716 .def("set_arg_ndarray_with_grad",
717 &Kernel::LaunchContextBuilder::set_arg_ndarray_with_grad)
718 .def("set_arg_texture", &Kernel::LaunchContextBuilder::set_arg_texture)
719 .def("set_arg_rw_texture",
720 &Kernel::LaunchContextBuilder::set_arg_rw_texture)
721 .def("set_extra_arg_int",
722 &Kernel::LaunchContextBuilder::set_extra_arg_int);
723
724 py::class_<Function>(m, "Function")
725 .def("insert_scalar_param", &Function::insert_scalar_param)
726 .def("insert_arr_param", &Function::insert_arr_param)
727 .def("insert_texture_param", &Function::insert_texture_param)
728 .def("insert_ret", &Function::insert_ret)
729 .def("set_function_body",
730 py::overload_cast<const std::function<void()> &>(
731 &Function::set_function_body))
732 .def("finalize_rets", &Function::finalize_rets)
733 .def(
734 "ast_builder",
735 [](Function *self) -> ASTBuilder * {
736 return &self->context->builder();
737 },
738 py::return_value_policy::reference);
739
740 py::class_<Expr> expr(m, "Expr");
741 expr.def("snode", &Expr::snode, py::return_value_policy::reference)
742 .def("is_external_tensor_expr",
743 [](Expr *expr) { return expr->is<ExternalTensorExpression>(); })
744 .def("is_index_expr",
745 [](Expr *expr) { return expr->is<IndexExpression>(); })
746 .def("is_primal",
747 [](Expr *expr) {
748 return expr->cast<FieldExpression>()->snode_grad_type ==
749 SNodeGradType::kPrimal;
750 })
751 .def("set_tb", &Expr::set_tb)
752 .def("set_name",
753 [&](Expr *expr, std::string na) {
754 expr->cast<FieldExpression>()->name = na;
755 })
756 .def("set_grad_type",
757 [&](Expr *expr, SNodeGradType t) {
758 expr->cast<FieldExpression>()->snode_grad_type = t;
759 })
760 .def("set_adjoint", &Expr::set_adjoint)
761 .def("set_adjoint_checkbit", &Expr::set_adjoint_checkbit)
762 .def("set_dual", &Expr::set_dual)
763 .def("set_dynamic_index_stride",
764 [&](Expr *expr, int dynamic_index_stride) {
765 auto matrix_field = expr->cast<MatrixFieldExpression>();
766 matrix_field->dynamic_indexable = true;
767 matrix_field->dynamic_index_stride = dynamic_index_stride;
768 })
769 .def("get_dynamic_indexable",
770 [&](Expr *expr) -> bool {
771 return expr->cast<MatrixFieldExpression>()->dynamic_indexable;
772 })
773 .def("get_dynamic_index_stride",
774 [&](Expr *expr) -> int {
775 return expr->cast<MatrixFieldExpression>()->dynamic_index_stride;
776 })
777 .def(
778 "get_dt",
779 [&](Expr *expr) -> const Type * {
780 return expr->cast<FieldExpression>()->dt;
781 },
782 py::return_value_policy::reference)
783 .def("get_ret_type", &Expr::get_ret_type)
784 .def("is_tensor",
785 [](Expr *expr) { return expr->expr->ret_type->is<TensorType>(); })
786 .def("get_shape",
787 [](Expr *expr) -> std::optional<std::vector<int>> {
788 if (expr->expr->ret_type->is<TensorType>()) {
789 return std::optional<std::vector<int>>(
790 expr->expr->ret_type->cast<TensorType>()->get_shape());
791 }
792 return std::nullopt;
793 })
794 .def("type_check", &Expr::type_check)
795 .def("get_expr_name",
796 [](Expr *expr) { return expr->cast<FieldExpression>()->name; })
797 .def("get_raw_address", [](Expr *expr) { return (uint64)expr; })
798 .def("get_underlying_ptr_address", [](Expr *e) {
799 // The reason that there are both get_raw_address() and
800 // get_underlying_ptr_address() is that Expr itself is mostly wrapper
801 // around its underlying |expr| (of type Expression). Expr |e| can be
802 // temporary, while the underlying |expr| is mostly persistent.
803 //
804 // Same get_raw_address() implies that get_underlying_ptr_address() are
805 // also the same. The reverse is not true.
806 return (uint64)e->expr.get();
807 });
808
809 py::class_<ExprGroup>(m, "ExprGroup")
810 .def(py::init<>())
811 .def("size", [](ExprGroup *eg) { return eg->exprs.size(); })
812 .def("push_back", &ExprGroup::push_back);
813
814 py::class_<Stmt>(m, "Stmt"); // NOLINT(bugprone-unused-raii)
815
816 m.def("insert_internal_func_call",
817 [&](const std::string &func_name, const ExprGroup &args,
818 bool with_runtime_context) {
819 return Expr::make<InternalFuncCallExpression>(func_name, args.exprs,
820 with_runtime_context);
821 });
822
823 m.def("make_get_element_expr",
824 Expr::make<GetElementExpression, const Expr &, std::vector<int>>);
825
826 m.def("value_cast", static_cast<Expr (*)(const Expr &expr, DataType)>(cast));
827 m.def("bits_cast",
828 static_cast<Expr (*)(const Expr &expr, DataType)>(bit_cast));
829
830 m.def("expr_atomic_add", [&](const Expr &a, const Expr &b) {
831 return Expr::make<AtomicOpExpression>(AtomicOpType::add, a, b);
832 });
833
834 m.def("expr_atomic_sub", [&](const Expr &a, const Expr &b) {
835 return Expr::make<AtomicOpExpression>(AtomicOpType::sub, a, b);
836 });
837
838 m.def("expr_atomic_min", [&](const Expr &a, const Expr &b) {
839 return Expr::make<AtomicOpExpression>(AtomicOpType::min, a, b);
840 });
841
842 m.def("expr_atomic_max", [&](const Expr &a, const Expr &b) {
843 return Expr::make<AtomicOpExpression>(AtomicOpType::max, a, b);
844 });
845
846 m.def("expr_atomic_bit_and", [&](const Expr &a, const Expr &b) {
847 return Expr::make<AtomicOpExpression>(AtomicOpType::bit_and, a, b);
848 });
849
850 m.def("expr_atomic_bit_or", [&](const Expr &a, const Expr &b) {
851 return Expr::make<AtomicOpExpression>(AtomicOpType::bit_or, a, b);
852 });
853
854 m.def("expr_atomic_bit_xor", [&](const Expr &a, const Expr &b) {
855 return Expr::make<AtomicOpExpression>(AtomicOpType::bit_xor, a, b);
856 });
857
  // Free functions registered under an "expr_" prefix. The callables are
  // passed by name; their declarations live outside this section, so the
  // Python signatures are whatever those declarations specify.

  // Python name differs from the C++ name: expr_assume_in_range -> assume_range.
  m.def("expr_assume_in_range", assume_range);

  m.def("expr_loop_unique", loop_unique);

  m.def("expr_field", expr_field);

  m.def("expr_matrix_field", expr_matrix_field);
865
866#define DEFINE_EXPRESSION_OP(x) m.def("expr_" #x, expr_##x);
867
868 DEFINE_EXPRESSION_OP(neg)
869 DEFINE_EXPRESSION_OP(sqrt)
870 DEFINE_EXPRESSION_OP(round)
871 DEFINE_EXPRESSION_OP(floor)
872 DEFINE_EXPRESSION_OP(ceil)
873 DEFINE_EXPRESSION_OP(abs)
874 DEFINE_EXPRESSION_OP(sin)
875 DEFINE_EXPRESSION_OP(asin)
876 DEFINE_EXPRESSION_OP(cos)
877 DEFINE_EXPRESSION_OP(acos)
878 DEFINE_EXPRESSION_OP(tan)
879 DEFINE_EXPRESSION_OP(tanh)
880 DEFINE_EXPRESSION_OP(inv)
881 DEFINE_EXPRESSION_OP(rcp)
882 DEFINE_EXPRESSION_OP(rsqrt)
883 DEFINE_EXPRESSION_OP(exp)
884 DEFINE_EXPRESSION_OP(log)
885
886 DEFINE_EXPRESSION_OP(select)
887 DEFINE_EXPRESSION_OP(ifte)
888
889 DEFINE_EXPRESSION_OP(cmp_le)
890 DEFINE_EXPRESSION_OP(cmp_lt)
891 DEFINE_EXPRESSION_OP(cmp_ge)
892 DEFINE_EXPRESSION_OP(cmp_gt)
893 DEFINE_EXPRESSION_OP(cmp_ne)
894 DEFINE_EXPRESSION_OP(cmp_eq)
895
896 DEFINE_EXPRESSION_OP(bit_and)
897 DEFINE_EXPRESSION_OP(bit_or)
898 DEFINE_EXPRESSION_OP(bit_xor)
899 DEFINE_EXPRESSION_OP(bit_shl)
900 DEFINE_EXPRESSION_OP(bit_shr)
901 DEFINE_EXPRESSION_OP(bit_sar)
902 DEFINE_EXPRESSION_OP(bit_not)
903
904 DEFINE_EXPRESSION_OP(logic_not)
905 DEFINE_EXPRESSION_OP(logical_and)
906 DEFINE_EXPRESSION_OP(logical_or)
907
908 DEFINE_EXPRESSION_OP(add)
909 DEFINE_EXPRESSION_OP(sub)
910 DEFINE_EXPRESSION_OP(mul)
911 DEFINE_EXPRESSION_OP(div)
912 DEFINE_EXPRESSION_OP(truediv)
913 DEFINE_EXPRESSION_OP(floordiv)
914 DEFINE_EXPRESSION_OP(mod)
915 DEFINE_EXPRESSION_OP(max)
916 DEFINE_EXPRESSION_OP(min)
917 DEFINE_EXPRESSION_OP(atan2)
918 DEFINE_EXPRESSION_OP(pow)
919
920#undef DEFINE_EXPRESSION_OP
921
  // Statement / expression factories. Each binding is an explicit
  // instantiation of Stmt::make or Expr::make; the template arguments after
  // the node type fix the argument types accepted from Python.

  // Statements.
  m.def("make_global_load_stmt", Stmt::make<GlobalLoadStmt, Stmt *>);
  m.def("make_global_store_stmt", Stmt::make<GlobalStoreStmt, Stmt *, Stmt *>);
  m.def("make_frontend_assign_stmt",
        Stmt::make<FrontendAssignStmt, const Expr &, const Expr &>);

  // Argument load: (int, DataType, bool).
  m.def("make_arg_load_expr",
        Expr::make<ArgLoadExpression, int, const DataType &, bool>);

  m.def("make_reference", Expr::make<ReferenceExpression, const Expr &>);

  // External tensor: (DataType, int, int, int, vector<int>).
  m.def("make_external_tensor_expr",
        Expr::make<ExternalTensorExpression, const DataType &, int, int, int,
                   const std::vector<int> &>);

  // Same node type as above, but constructed from a pointer to an Expr.
  m.def("make_external_grad_tensor_expr",
        Expr::make<ExternalTensorExpression, Expr *>);

  m.def("make_rand_expr", Expr::make<RandExpression, const DataType &>);

  // Constants: one binding per value category (int64 vs float64).
  m.def("make_const_expr_int",
        Expr::make<ConstExpression, const DataType &, int64>);

  m.def("make_const_expr_fp",
        Expr::make<ConstExpression, const DataType &, float64>);

  // Texture pointers: the two bindings use the two-int and the
  // (int, int, int, DataType, int) constructor argument lists respectively.
  m.def("make_texture_ptr_expr", Expr::make<TexturePtrExpression, int, int>);
  m.def("make_rw_texture_ptr_expr",
        Expr::make<TexturePtrExpression, int, int, int, const DataType &, int>);
950
951 auto &&texture =
952 py::enum_<TextureOpType>(m, "TextureOpType", py::arithmetic());
953 for (int t = 0; t <= (int)TextureOpType::kStore; t++)
954 texture.value(texture_op_type_name(TextureOpType(t)).c_str(),
955 TextureOpType(t));
956 texture.export_values();
957
958 auto &&bin = py::enum_<BinaryOpType>(m, "BinaryOpType", py::arithmetic());
959 for (int t = 0; t <= (int)BinaryOpType::undefined; t++)
960 bin.value(binary_op_type_name(BinaryOpType(t)).c_str(), BinaryOpType(t));
961 bin.export_values();
962 m.def("make_binary_op_expr",
963 Expr::make<BinaryOpExpression, const BinaryOpType &, const Expr &,
964 const Expr &>);
965
966 auto &&unary = py::enum_<UnaryOpType>(m, "UnaryOpType", py::arithmetic());
967 for (int t = 0; t <= (int)UnaryOpType::undefined; t++)
968 unary.value(unary_op_type_name(UnaryOpType(t)).c_str(), UnaryOpType(t));
969 unary.export_values();
970 m.def("make_unary_op_expr",
971 Expr::make<UnaryOpExpression, const UnaryOpType &, const Expr &>);
// X-macro: for every entry listed in data_type.inc.h, publish a module
// attribute named "DataType_" + data_type_name(PrimitiveType::<entry>) whose
// value is PrimitiveType::<entry>.
#define PER_TYPE(x)                                                      \
  m.attr(("DataType_" + data_type_name(PrimitiveType::x)).c_str()) = \
      PrimitiveType::x;
#include "taichi/inc/data_type.inc.h"
#undef PER_TYPE

  // Data type queries, bound directly to free functions declared outside
  // this section.
  m.def("data_type_size", data_type_size);
  m.def("is_quant", is_quant);
  m.def("is_integral", is_integral);
  m.def("is_signed", is_signed);
  m.def("is_real", is_real);
  m.def("is_unsigned", is_unsigned);
  m.def("is_tensor", is_tensor);

  m.def("data_type_name", data_type_name);

  // Factory for IndexExpression: (expression, list of index groups,
  // vector<int>, string).
  m.def(
      "subscript_with_multiple_indices",
      Expr::make<IndexExpression, const Expr &, const std::vector<ExprGroup> &,
                 const std::vector<int> &, std::string>);
992
993 m.def("get_external_tensor_element_dim", [](const Expr &expr) {
994 TI_ASSERT(expr.is<ExternalTensorExpression>());
995 return expr.cast<ExternalTensorExpression>()->element_dim;
996 });
997
998 m.def("get_external_tensor_element_shape", [](const Expr &expr) {
999 TI_ASSERT(expr.is<ExternalTensorExpression>());
1000 auto external_tensor_expr = expr.cast<ExternalTensorExpression>();
1001 return external_tensor_expr->dt.get_shape();
1002 });
1003
1004 m.def("get_external_tensor_dim", [](const Expr &expr) {
1005 if (expr.is<ExternalTensorExpression>()) {
1006 return expr.cast<ExternalTensorExpression>()->dim;
1007 } else if (expr.is<TexturePtrExpression>()) {
1008 return expr.cast<TexturePtrExpression>()->num_dims;
1009 } else {
1010 TI_ASSERT(false);
1011 return 0;
1012 }
1013 });
1014
1015 m.def("get_external_tensor_shape_along_axis",
1016 Expr::make<ExternalTensorShapeAlongAxisExpression, const Expr &, int>);
1017
1018 // Mesh related.
1019 m.def("get_relation_size", [](mesh::MeshPtr mesh_ptr, const Expr &mesh_idx,
1020 mesh::MeshElementType to_type) {
1021 return Expr::make<MeshRelationAccessExpression>(mesh_ptr.ptr.get(),
1022 mesh_idx, to_type);
1023 });
1024
1025 m.def("get_relation_access",
1026 [](mesh::MeshPtr mesh_ptr, const Expr &mesh_idx,
1027 mesh::MeshElementType to_type, const Expr &neighbor_idx) {
1028 return Expr::make<MeshRelationAccessExpression>(
1029 mesh_ptr.ptr.get(), mesh_idx, to_type, neighbor_idx);
1030 });
1031
1032 py::class_<FunctionKey>(m, "FunctionKey")
1033 .def(py::init<const std::string &, int, int>())
1034 .def_readonly("instance_id", &FunctionKey::instance_id);
1035
1036 m.def("test_throw", [] {
1037 try {
1038 throw IRModified();
1039 } catch (IRModified) {
1040 TI_INFO("caught");
1041 }
1042 });
1043
1044 m.def("test_throw", [] { throw IRModified(); });
1045
#if TI_WITH_LLVM
  // Only registered when TI_WITH_LLVM evaluates to non-zero. libdevice_path()
  // is forward-declared near the top of this file.
  m.def("libdevice_path", libdevice_path);
#endif

  m.def("host_arch", host_arch);
  m.def("arch_uses_llvm", arch_uses_llvm);

  // Store the given directory in compiled_lib_dir / runtime_tmp_dir.
  // NOTE(review): both lambdas use a [&] default capture; the assigned
  // variables are presumably globals declared elsewhere -- confirm, since a
  // by-reference capture of a local would dangle once export_lang() returns.
  m.def("set_lib_dir", [&](const std::string &dir) { compiled_lib_dir = dir; });
  m.def("set_tmp_dir", [&](const std::string &dir) { runtime_tmp_dir = dir; });

  // Build / version information.
  m.def("get_commit_hash", get_commit_hash);
  m.def("get_version_string", get_version_string);
  m.def("get_version_major", get_version_major);
  m.def("get_version_minor", get_version_minor);
  m.def("get_version_patch", get_version_patch);
  // Returns LLVM_VERSION_STRING when TI_WITH_LLVM is defined, otherwise the
  // literal "targets unsupported".
  // NOTE(review): this guard tests `defined(TI_WITH_LLVM)` while the other
  // LLVM guards in this file test the macro's value; the two differ if the
  // macro is ever defined as 0 -- confirm which is intended.
  m.def("get_llvm_target_support", [] {
#if defined(TI_WITH_LLVM)
    return LLVM_VERSION_STRING;
#else
    return "targets unsupported";
#endif
  });
  // Diagnostics helpers.
  m.def("test_printf", [] { printf("test_printf\n"); });
  m.def("test_logging", [] { TI_INFO("test_logging"); });
  // Deliberately writes through the invalid address 1 to provoke a crash.
  m.def("trigger_crash", [] { *(int *)(1) = 0; });
  m.def("get_max_num_indices", [] { return taichi_max_num_indices; });
  m.def("get_max_num_args", [] { return taichi_max_num_args; });
  m.def("test_threading", test_threading);
  m.def("is_extension_supported", is_extension_supported);
1075
1076 m.def("record_action_entry",
1077 [](std::string name,
1078 std::vector<std::pair<std::string,
1079 std::variant<std::string, int, float>>> args) {
1080 std::vector<ActionArg> acts;
1081 for (auto const &[k, v] : args) {
1082 if (std::holds_alternative<int>(v)) {
1083 acts.push_back(ActionArg(k, std::get<int>(v)));
1084 } else if (std::holds_alternative<float>(v)) {
1085 acts.push_back(ActionArg(k, std::get<float>(v)));
1086 } else {
1087 acts.push_back(ActionArg(k, std::get<std::string>(v)));
1088 }
1089 }
1090 ActionRecorder::get_instance().record(name, acts);
1091 });
1092
1093 m.def("start_recording", [](const std::string &fn) {
1094 ActionRecorder::get_instance().start_recording(fn);
1095 });
1096
1097 m.def("stop_recording",
1098 []() { ActionRecorder::get_instance().stop_recording(); });
1099
  // Keyed integer query. The only recognised key is
  // "cuda_compute_capability":
  //   - with TI_WITH_CUDA defined it returns
  //     CUDAContext::get_instance().get_compute_capability();
  //   - without it, the branch expands to TI_NOT_IMPLEMENTED.
  // Any other key is reported through TI_ERROR.
  // NOTE(review): the lambda has no explicit return type, so what it returns
  // is deduced from the single `return` in the CUDA branch (and there is no
  // `return` at all in non-CUDA builds) -- confirm this is the intended
  // contract for a function named query_int64.
  m.def("query_int64", [](const std::string &key) {
    if (key == "cuda_compute_capability") {
#if defined(TI_WITH_CUDA)
      return CUDAContext::get_instance().get_compute_capability();
#else
    TI_NOT_IMPLEMENTED
#endif
    } else {
      TI_ERROR("Key {} not supported in query_int64", key);
    }
  });
1111
1112 // Type system
1113
1114 py::class_<Type>(m, "Type").def("to_string", &Type::to_string);
1115
1116 m.def("promoted_type", promoted_type);
1117
1118 // Note that it is important to specify py::return_value_policy::reference for
1119 // the factory methods, otherwise pybind11 will delete the Types owned by
1120 // TypeFactory on Python-scope pointer destruction.
1121 py::class_<TypeFactory>(m, "TypeFactory")
1122 .def("get_quant_int_type", &TypeFactory::get_quant_int_type,
1123 py::arg("num_bits"), py::arg("is_signed"), py::arg("compute_type"),
1124 py::return_value_policy::reference)
1125 .def("get_quant_fixed_type", &TypeFactory::get_quant_fixed_type,
1126 py::arg("digits_type"), py::arg("compute_type"), py::arg("scale"),
1127 py::return_value_policy::reference)
1128 .def("get_quant_float_type", &TypeFactory::get_quant_float_type,
1129 py::arg("digits_type"), py::arg("exponent_type"),
1130 py::arg("compute_type"), py::return_value_policy::reference)
1131 .def(
1132 "get_tensor_type",
1133 [&](TypeFactory *factory, std::vector<int> shape,
1134 const DataType &element_type) {
1135 return factory->create_tensor_type(shape, element_type);
1136 },
1137 py::return_value_policy::reference)
1138 .def(
1139 "get_struct_type",
1140 [&](TypeFactory *factory,
1141 std::vector<std::pair<DataType, std::string>> elements) {
1142 std::vector<StructMember> members;
1143 for (auto &[type, name] : elements) {
1144 members.push_back({type, name});
1145 }
1146 return DataType(factory->get_struct_type(members));
1147 },
1148 py::return_value_policy::reference);
1149
1150 m.def("get_type_factory_instance", TypeFactory::get_instance,
1151 py::return_value_policy::reference);
1152
1153 // NOLINTNEXTLINE(bugprone-unused-raii)
1154 py::class_<BitStructType>(m, "BitStructType");
1155 py::class_<BitStructTypeBuilder>(m, "BitStructTypeBuilder")
1156 .def(py::init<int>())
1157 .def("begin_placing_shared_exponent",
1158 &BitStructTypeBuilder::begin_placing_shared_exponent)
1159 .def("end_placing_shared_exponent",
1160 &BitStructTypeBuilder::end_placing_shared_exponent)
1161 .def("add_member", &BitStructTypeBuilder::add_member)
1162 .def("build", &BitStructTypeBuilder::build,
1163 py::return_value_policy::reference);
1164
1165 py::class_<SNodeRegistry>(m, "SNodeRegistry")
1166 .def(py::init<>())
1167 .def("create_root", &SNodeRegistry::create_root,
1168 py::return_value_policy::reference);
1169
1170 m.def(
1171 "finalize_snode_tree",
1172 [](SNodeRegistry *registry, const SNode *root, Program *program,
1173 bool compile_only) -> SNodeTree * {
1174 return program->add_snode_tree(registry->finalize(root), compile_only);
1175 },
1176 py::return_value_policy::reference);
1177
1178 // Sparse Matrix
1179 py::class_<SparseMatrixBuilder>(m, "SparseMatrixBuilder")
1180 .def("print_triplets_eigen", &SparseMatrixBuilder::print_triplets_eigen)
1181 .def("print_triplets_cuda", &SparseMatrixBuilder::print_triplets_cuda)
1182 .def("get_ndarray_data_ptr", &SparseMatrixBuilder::get_ndarray_data_ptr)
1183 .def("build", &SparseMatrixBuilder::build)
1184 .def("build_cuda", &SparseMatrixBuilder::build_cuda)
1185 .def("get_addr", [](SparseMatrixBuilder *mat) { return uint64(mat); });
1186
1187 py::class_<SparseMatrix>(m, "SparseMatrix")
1188 .def(py::init<>())
1189 .def(py::init<int, int, DataType>(), py::arg("rows"), py::arg("cols"),
1190 py::arg("dt") = PrimitiveType::f32)
1191 .def(py::init<SparseMatrix &>())
1192 .def("to_string", &SparseMatrix::to_string)
1193 .def("get_element", &SparseMatrix::get_element<float32>)
1194 .def("set_element", &SparseMatrix::set_element<float32>)
1195 .def("num_rows", &SparseMatrix::num_rows)
1196 .def("num_cols", &SparseMatrix::num_cols);
1197
// Registers one EigenSparseMatrix specialization as a Python subclass of
// SparseMatrix.
//   TYPE    - bit width pasted onto `float` (32 or 64) to form the scalar type.
//   STORAGE - Eigen storage order (ColMajor or RowMajor).
//   VTYPE   - Eigen vector suffix (f or d); also the first part of the Python
//             class name, which is "<VTYPE><STORAGE>_EigenSparseMatrix".
// Besides constructors and named methods, the macro binds the in-place and
// binary +, -, and * operators (matrix-matrix and matrix-scalar).
// The macro also introduces a `using` alias per invocation and already ends
// in a semicolon.
#define MAKE_SPARSE_MATRIX(TYPE, STORAGE, VTYPE)                             \
  using STORAGE##TYPE##EigenMatrix =                                         \
      Eigen::SparseMatrix<float##TYPE, Eigen::STORAGE>;                      \
  py::class_<EigenSparseMatrix<STORAGE##TYPE##EigenMatrix>, SparseMatrix>(   \
      m, #VTYPE #STORAGE "_EigenSparseMatrix")                               \
      .def(py::init<int, int, DataType>())                                   \
      .def(py::init<EigenSparseMatrix<STORAGE##TYPE##EigenMatrix> &>())      \
      .def(py::init<const STORAGE##TYPE##EigenMatrix &>())                   \
      .def(py::self += py::self)                                             \
      .def(py::self + py::self)                                              \
      .def(py::self -= py::self)                                             \
      .def(py::self - py::self)                                              \
      .def(py::self *= float##TYPE())                                        \
      .def(py::self *float##TYPE())                                          \
      .def(float##TYPE() * py::self)                                         \
      .def(py::self *py::self)                                               \
      .def("matmul", &EigenSparseMatrix<STORAGE##TYPE##EigenMatrix>::matmul) \
      .def("spmv", &EigenSparseMatrix<STORAGE##TYPE##EigenMatrix>::spmv)     \
      .def("transpose",                                                      \
           &EigenSparseMatrix<STORAGE##TYPE##EigenMatrix>::transpose)        \
      .def("get_element",                                                    \
           &EigenSparseMatrix<STORAGE##TYPE##EigenMatrix>::get_element<      \
               float##TYPE>)                                                 \
      .def("set_element",                                                    \
           &EigenSparseMatrix<STORAGE##TYPE##EigenMatrix>::set_element<      \
               float##TYPE>)                                                 \
      .def("mat_vec_mul",                                                    \
           &EigenSparseMatrix<STORAGE##TYPE##EigenMatrix>::mat_vec_mul<      \
               Eigen::VectorX##VTYPE>);

  // One registration per (precision, storage order) combination.
  MAKE_SPARSE_MATRIX(32, ColMajor, f);
  MAKE_SPARSE_MATRIX(32, RowMajor, f);
  MAKE_SPARSE_MATRIX(64, ColMajor, d);
  MAKE_SPARSE_MATRIX(64, RowMajor, d);
1232
1233 py::class_<CuSparseMatrix, SparseMatrix>(m, "CuSparseMatrix")
1234 .def(py::init<int, int, DataType>())
1235 .def(py::init<const CuSparseMatrix &>())
1236 .def("spmv", &CuSparseMatrix::spmv)
1237 .def(py::self + py::self)
1238 .def(py::self - py::self)
1239 .def(py::self * float32())
1240 .def(float32() * py::self)
1241 .def("matmul", &CuSparseMatrix::matmul)
1242 .def("transpose", &CuSparseMatrix::transpose)
1243 .def("get_element", &CuSparseMatrix::get_element)
1244 .def("to_string", &CuSparseMatrix::to_string);
1245
1246 py::class_<SparseSolver>(m, "SparseSolver")
1247 .def("compute", &SparseSolver::compute)
1248 .def("analyze_pattern", &SparseSolver::analyze_pattern)
1249 .def("factorize", &SparseSolver::factorize)
1250 .def("info", &SparseSolver::info);
1251
// Registers the class EigenSparseSolver<dt><type><order> as a Python subclass
// of SparseSolver, under the Python name formed from the same three pieces.
//   dt    - scalar type name (float32 / float64); also the second template
//           argument of solve_rf.
//   type  - solver kind (LLT / LDLT / LU).
//   order - ordering name (AMD / COLAMD).
//   fd    - Eigen vector suffix (f / d) used for the solve / solve_rf
//           instantiations.
// The macro already ends in a semicolon, so invocations carry none.
#define REGISTER_EIGEN_SOLVER(dt, type, order, fd)                           \
  py::class_<EigenSparseSolver##dt##type##order, SparseSolver>(              \
      m, "EigenSparseSolver" #dt #type #order)                               \
      .def("compute", &EigenSparseSolver##dt##type##order::compute)          \
      .def("analyze_pattern",                                                \
           &EigenSparseSolver##dt##type##order::analyze_pattern)             \
      .def("factorize", &EigenSparseSolver##dt##type##order::factorize)      \
      .def("solve",                                                          \
           &EigenSparseSolver##dt##type##order::solve<Eigen::VectorX##fd>)   \
      .def("solve_rf",                                                       \
           &EigenSparseSolver##dt##type##order::solve_rf<Eigen::VectorX##fd, \
                                                         dt>)                \
      .def("info", &EigenSparseSolver##dt##type##order::info);

  // Single-precision solvers.
  REGISTER_EIGEN_SOLVER(float32, LLT, AMD, f)
  REGISTER_EIGEN_SOLVER(float32, LLT, COLAMD, f)
  REGISTER_EIGEN_SOLVER(float32, LDLT, AMD, f)
  REGISTER_EIGEN_SOLVER(float32, LDLT, COLAMD, f)
  REGISTER_EIGEN_SOLVER(float32, LU, AMD, f)
  REGISTER_EIGEN_SOLVER(float32, LU, COLAMD, f)
  // Double-precision solvers.
  REGISTER_EIGEN_SOLVER(float64, LLT, AMD, d)
  REGISTER_EIGEN_SOLVER(float64, LLT, COLAMD, d)
  REGISTER_EIGEN_SOLVER(float64, LDLT, AMD, d)
  REGISTER_EIGEN_SOLVER(float64, LDLT, COLAMD, d)
  REGISTER_EIGEN_SOLVER(float64, LU, AMD, d)
  REGISTER_EIGEN_SOLVER(float64, LU, COLAMD, d)
1278
1279 py::class_<CuSparseSolver, SparseSolver>(m, "CuSparseSolver")
1280 .def("compute", &CuSparseSolver::compute)
1281 .def("analyze_pattern", &CuSparseSolver::analyze_pattern)
1282 .def("factorize", &CuSparseSolver::factorize)
1283 .def("solve_rf", &CuSparseSolver::solve_rf)
1284 .def("info", &CuSparseSolver::info);
1285
1286 m.def("make_sparse_solver", &make_sparse_solver);
1287 m.def("make_cusparse_solver", &make_cusparse_solver);
1288
1289 // Mesh Class
1290 // Mesh related.
1291 py::enum_<mesh::MeshTopology>(m, "MeshTopology", py::arithmetic())
1292 .value("Triangle", mesh::MeshTopology::Triangle)
1293 .value("Tetrahedron", mesh::MeshTopology::Tetrahedron)
1294 .export_values();
1295
1296 py::enum_<mesh::MeshElementType>(m, "MeshElementType", py::arithmetic())
1297 .value("Vertex", mesh::MeshElementType::Vertex)
1298 .value("Edge", mesh::MeshElementType::Edge)
1299 .value("Face", mesh::MeshElementType::Face)
1300 .value("Cell", mesh::MeshElementType::Cell)
1301 .export_values();
1302
1303 py::enum_<mesh::MeshRelationType>(m, "MeshRelationType", py::arithmetic())
1304 .value("VV", mesh::MeshRelationType::VV)
1305 .value("VE", mesh::MeshRelationType::VE)
1306 .value("VF", mesh::MeshRelationType::VF)
1307 .value("VC", mesh::MeshRelationType::VC)
1308 .value("EV", mesh::MeshRelationType::EV)
1309 .value("EE", mesh::MeshRelationType::EE)
1310 .value("EF", mesh::MeshRelationType::EF)
1311 .value("EC", mesh::MeshRelationType::EC)
1312 .value("FV", mesh::MeshRelationType::FV)
1313 .value("FE", mesh::MeshRelationType::FE)
1314 .value("FF", mesh::MeshRelationType::FF)
1315 .value("FC", mesh::MeshRelationType::FC)
1316 .value("CV", mesh::MeshRelationType::CV)
1317 .value("CE", mesh::MeshRelationType::CE)
1318 .value("CF", mesh::MeshRelationType::CF)
1319 .value("CC", mesh::MeshRelationType::CC)
1320 .export_values();
1321
1322 py::enum_<mesh::ConvType>(m, "ConvType", py::arithmetic())
1323 .value("l2g", mesh::ConvType::l2g)
1324 .value("l2r", mesh::ConvType::l2r)
1325 .value("g2r", mesh::ConvType::g2r)
1326 .export_values();
1327
  // Mesh and MeshPtr are registered as bare types: no constructors or
  // methods are bound here, so Python only passes them around.
  py::class_<mesh::Mesh>(m, "Mesh");        // NOLINT(bugprone-unused-raii)
  py::class_<mesh::MeshPtr>(m, "MeshPtr");  // NOLINT(bugprone-unused-raii)

  // Helper functions from the mesh namespace, bound by name. Their
  // declarations are outside this section.
  m.def("element_order", mesh::element_order);
  m.def("from_end_element_order", mesh::from_end_element_order);
  m.def("to_end_element_order", mesh::to_end_element_order);
  m.def("relation_by_orders", mesh::relation_by_orders);
  m.def("inverse_relation", mesh::inverse_relation);
  m.def("element_type_name", mesh::element_type_name);
1337
1338 m.def(
1339 "create_mesh",
1340 []() {
1341 auto mesh_shared = std::make_shared<mesh::Mesh>();
1342 mesh::MeshPtr mesh_ptr = mesh::MeshPtr{mesh_shared};
1343 return mesh_ptr;
1344 },
1345 py::return_value_policy::reference);
1346
1347 // ad-hoc setters
1348 m.def("set_owned_offset",
1349 [](mesh::MeshPtr &mesh_ptr, mesh::MeshElementType type, SNode *snode) {
1350 mesh_ptr.ptr->owned_offset.insert(std::pair(type, snode));
1351 });
1352 m.def("set_total_offset",
1353 [](mesh::MeshPtr &mesh_ptr, mesh::MeshElementType type, SNode *snode) {
1354 mesh_ptr.ptr->total_offset.insert(std::pair(type, snode));
1355 });
1356 m.def("set_num_patches", [](mesh::MeshPtr &mesh_ptr, int num_patches) {
1357 mesh_ptr.ptr->num_patches = num_patches;
1358 });
1359
1360 m.def("set_num_elements", [](mesh::MeshPtr &mesh_ptr,
1361 mesh::MeshElementType type, int num_elements) {
1362 mesh_ptr.ptr->num_elements.insert(std::pair(type, num_elements));
1363 });
1364
1365 m.def("get_num_elements",
1366 [](mesh::MeshPtr &mesh_ptr, mesh::MeshElementType type) {
1367 return mesh_ptr.ptr->num_elements.find(type)->second;
1368 });
1369
1370 m.def("set_patch_max_element_num",
1371 [](mesh::MeshPtr &mesh_ptr, mesh::MeshElementType type,
1372 int max_element_num) {
1373 mesh_ptr.ptr->patch_max_element_num.insert(
1374 std::pair(type, max_element_num));
1375 });
1376
1377 m.def("set_index_mapping",
1378 [](mesh::MeshPtr &mesh_ptr, mesh::MeshElementType element_type,
1379 mesh::ConvType conv_type, SNode *snode) {
1380 mesh_ptr.ptr->index_mapping.insert(
1381 std::make_pair(std::make_pair(element_type, conv_type), snode));
1382 });
1383
1384 m.def("set_relation_fixed",
1385 [](mesh::MeshPtr &mesh_ptr, mesh::MeshRelationType type, SNode *value) {
1386 mesh_ptr.ptr->relations.insert(
1387 std::pair(type, mesh::MeshLocalRelation(value)));
1388 });
1389
1390 m.def("set_relation_dynamic",
1391 [](mesh::MeshPtr &mesh_ptr, mesh::MeshRelationType type, SNode *value,
1392 SNode *patch_offset, SNode *offset) {
1393 mesh_ptr.ptr->relations.insert(std::pair(
1394 type, mesh::MeshLocalRelation(value, patch_offset, offset)));
1395 });
1396
  // Blocks the calling thread until a debugger is attached. When WIN32 is
  // defined the body polls ::IsDebuggerPresent(), calling ::Sleep(100)
  // between polls; otherwise the function is an empty no-op.
  // NOTE(review): the guard tests WIN32 rather than the compiler-provided
  // _WIN32 -- confirm that the build defines WIN32 on Windows.
  m.def("wait_for_debugger", []() {
#ifdef WIN32
    while (!::IsDebuggerPresent())
      ::Sleep(100);
#endif
  });
1403}
1404
1405} // namespace taichi
1406