1 | // Bindings for the python frontend |
2 | |
3 | #include <optional> |
4 | #include <string> |
5 | #include "taichi/ir/snode.h" |
6 | |
7 | #if TI_WITH_LLVM |
8 | #include "llvm/Config/llvm-config.h" |
9 | #endif |
10 | |
11 | #include "pybind11/functional.h" |
12 | #include "pybind11/pybind11.h" |
13 | #include "pybind11/eigen.h" |
14 | #include "pybind11/numpy.h" |
15 | |
16 | #include "taichi/ir/expression_ops.h" |
17 | #include "taichi/ir/frontend_ir.h" |
18 | #include "taichi/ir/statements.h" |
19 | #include "taichi/program/graph_builder.h" |
20 | #include "taichi/program/extension.h" |
21 | #include "taichi/program/ndarray.h" |
22 | #include "taichi/python/export.h" |
23 | #include "taichi/math/svd.h" |
24 | #include "taichi/util/action_recorder.h" |
25 | #include "taichi/system/timeline.h" |
26 | #include "taichi/python/snode_registry.h" |
27 | #include "taichi/program/sparse_matrix.h" |
28 | #include "taichi/program/sparse_solver.h" |
29 | #include "taichi/aot/graph_data.h" |
30 | #include "taichi/ir/mesh.h" |
31 | |
32 | #include "taichi/program/kernel_profiler.h" |
33 | |
34 | #if defined(TI_WITH_CUDA) |
35 | #include "taichi/rhi/cuda/cuda_context.h" |
36 | #endif |
37 | |
namespace taichi {
// Forward declaration; the definition lives elsewhere in the project
// (presumably a threading self-test — confirm against its definition).
// Declared here so it can be bound to the Python frontend in this file.
bool test_threading();

}  // namespace taichi
42 | |
namespace taichi::lang {

// Forward declaration; returns a filesystem path as a string. NOTE(review):
// presumably the path to CUDA's libdevice bitcode used by the LLVM backend —
// confirm against the defining translation unit.
std::string libdevice_path();

}  // namespace taichi::lang
48 | |
49 | namespace taichi { |
50 | void export_lang(py::module &m) { |
51 | using namespace taichi::lang; |
52 | using namespace std::placeholders; |
53 | |
54 | py::register_exception<TaichiTypeError>(m, "TaichiTypeError" , |
55 | PyExc_TypeError); |
56 | py::register_exception<TaichiSyntaxError>(m, "TaichiSyntaxError" , |
57 | PyExc_SyntaxError); |
58 | py::register_exception<TaichiIndexError>(m, "TaichiIndexError" , |
59 | PyExc_IndexError); |
60 | py::register_exception<TaichiRuntimeError>(m, "TaichiRuntimeError" , |
61 | PyExc_RuntimeError); |
62 | py::register_exception<TaichiAssertionError>(m, "TaichiAssertionError" , |
63 | PyExc_AssertionError); |
64 | py::enum_<Arch>(m, "Arch" , py::arithmetic()) |
65 | #define PER_ARCH(x) .value(#x, Arch::x) |
66 | #include "taichi/inc/archs.inc.h" |
67 | #undef PER_ARCH |
68 | .export_values(); |
69 | |
70 | m.def("arch_name" , arch_name); |
71 | m.def("arch_from_name" , arch_from_name); |
72 | |
73 | py::enum_<SNodeType>(m, "SNodeType" , py::arithmetic()) |
74 | #define PER_SNODE(x) .value(#x, SNodeType::x) |
75 | #include "taichi/inc/snodes.inc.h" |
76 | #undef PER_SNODE |
77 | .export_values(); |
78 | |
79 | py::enum_<Extension>(m, "Extension" , py::arithmetic()) |
80 | #define PER_EXTENSION(x) .value(#x, Extension::x) |
81 | #include "taichi/inc/extensions.inc.h" |
82 | #undef PER_EXTENSION |
83 | .export_values(); |
84 | |
85 | py::enum_<ExternalArrayLayout>(m, "Layout" , py::arithmetic()) |
86 | .value("AOS" , ExternalArrayLayout::kAOS) |
87 | .value("SOA" , ExternalArrayLayout::kSOA) |
88 | .value("NULL" , ExternalArrayLayout::kNull) |
89 | .export_values(); |
90 | |
91 | py::enum_<AutodiffMode>(m, "AutodiffMode" , py::arithmetic()) |
92 | .value("NONE" , AutodiffMode::kNone) |
93 | .value("VALIDATION" , AutodiffMode::kCheckAutodiffValid) |
94 | .value("FORWARD" , AutodiffMode::kForward) |
95 | .value("REVERSE" , AutodiffMode::kReverse) |
96 | .export_values(); |
97 | |
98 | py::enum_<SNodeGradType>(m, "SNodeGradType" , py::arithmetic()) |
99 | .value("PRIMAL" , SNodeGradType::kPrimal) |
100 | .value("ADJOINT" , SNodeGradType::kAdjoint) |
101 | .value("DUAL" , SNodeGradType::kDual) |
102 | .value("ADJOINT_CHECKBIT" , SNodeGradType::kAdjointCheckbit) |
103 | .export_values(); |
104 | |
105 | // TODO(type): This should be removed |
106 | py::class_<DataType>(m, "DataType" ) |
107 | .def(py::init<Type *>()) |
108 | .def(py::self == py::self) |
109 | .def("__hash__" , &DataType::hash) |
110 | .def("to_string" , &DataType::to_string) |
111 | .def("__str__" , &DataType::to_string) |
112 | .def("shape" , &DataType::get_shape) |
113 | .def("element_type" , &DataType::get_element_type) |
114 | .def( |
115 | "get_ptr" , [](DataType *dtype) -> Type * { return *dtype; }, |
116 | py::return_value_policy::reference) |
117 | .def(py::pickle( |
118 | [](const DataType &dt) { |
119 | // Note: this only works for primitive types, which is fine for now. |
120 | auto primitive = |
121 | dynamic_cast<const PrimitiveType *>((const Type *)dt); |
122 | TI_ASSERT(primitive); |
123 | return py::make_tuple((std::size_t)primitive->type); |
124 | }, |
125 | [](py::tuple t) { |
126 | if (t.size() != 1) |
127 | throw std::runtime_error("Invalid state!" ); |
128 | |
129 | DataType dt = |
130 | PrimitiveType::get((PrimitiveTypeID)(t[0].cast<std::size_t>())); |
131 | |
132 | return dt; |
133 | })); |
134 | |
135 | py::class_<CompileConfig>(m, "CompileConfig" ) |
136 | .def(py::init<>()) |
137 | .def_readwrite("arch" , &CompileConfig::arch) |
138 | .def_readwrite("opt_level" , &CompileConfig::opt_level) |
139 | .def_readwrite("print_ir" , &CompileConfig::print_ir) |
140 | .def_readwrite("print_preprocessed_ir" , |
141 | &CompileConfig::print_preprocessed_ir) |
142 | .def_readwrite("debug" , &CompileConfig::debug) |
143 | .def_readwrite("cfg_optimization" , &CompileConfig::cfg_optimization) |
144 | .def_readwrite("check_out_of_bound" , &CompileConfig::check_out_of_bound) |
145 | .def_readwrite("print_accessor_ir" , &CompileConfig::print_accessor_ir) |
146 | .def_readwrite("print_evaluator_ir" , &CompileConfig::print_evaluator_ir) |
147 | .def_readwrite("use_llvm" , &CompileConfig::use_llvm) |
148 | .def_readwrite("print_struct_llvm_ir" , |
149 | &CompileConfig::print_struct_llvm_ir) |
150 | .def_readwrite("print_kernel_llvm_ir" , |
151 | &CompileConfig::print_kernel_llvm_ir) |
152 | .def_readwrite("print_kernel_llvm_ir_optimized" , |
153 | &CompileConfig::print_kernel_llvm_ir_optimized) |
154 | .def_readwrite("print_kernel_nvptx" , &CompileConfig::print_kernel_nvptx) |
155 | .def_readwrite("simplify_before_lower_access" , |
156 | &CompileConfig::simplify_before_lower_access) |
157 | .def_readwrite("simplify_after_lower_access" , |
158 | &CompileConfig::simplify_after_lower_access) |
159 | .def_readwrite("lower_access" , &CompileConfig::lower_access) |
160 | .def_readwrite("move_loop_invariant_outside_if" , |
161 | &CompileConfig::move_loop_invariant_outside_if) |
162 | .def_readwrite("cache_loop_invariant_global_vars" , |
163 | &CompileConfig::cache_loop_invariant_global_vars) |
164 | .def_readwrite("default_cpu_block_dim" , |
165 | &CompileConfig::default_cpu_block_dim) |
166 | .def_readwrite("cpu_block_dim_adaptive" , |
167 | &CompileConfig::cpu_block_dim_adaptive) |
168 | .def_readwrite("default_gpu_block_dim" , |
169 | &CompileConfig::default_gpu_block_dim) |
170 | .def_readwrite("gpu_max_reg" , &CompileConfig::gpu_max_reg) |
171 | .def_readwrite("saturating_grid_dim" , &CompileConfig::saturating_grid_dim) |
172 | .def_readwrite("max_block_dim" , &CompileConfig::max_block_dim) |
173 | .def_readwrite("cpu_max_num_threads" , &CompileConfig::cpu_max_num_threads) |
174 | .def_readwrite("random_seed" , &CompileConfig::random_seed) |
175 | .def_readwrite("verbose_kernel_launches" , |
176 | &CompileConfig::verbose_kernel_launches) |
177 | .def_readwrite("verbose" , &CompileConfig::verbose) |
178 | .def_readwrite("demote_dense_struct_fors" , |
179 | &CompileConfig::demote_dense_struct_fors) |
180 | .def_readwrite("kernel_profiler" , &CompileConfig::kernel_profiler) |
181 | .def_readwrite("timeline" , &CompileConfig::timeline) |
182 | .def_readwrite("default_fp" , &CompileConfig::default_fp) |
183 | .def_readwrite("default_ip" , &CompileConfig::default_ip) |
184 | .def_readwrite("default_up" , &CompileConfig::default_up) |
185 | .def_readwrite("device_memory_GB" , &CompileConfig::device_memory_GB) |
186 | .def_readwrite("device_memory_fraction" , |
187 | &CompileConfig::device_memory_fraction) |
188 | .def_readwrite("fast_math" , &CompileConfig::fast_math) |
189 | .def_readwrite("advanced_optimization" , |
190 | &CompileConfig::advanced_optimization) |
191 | .def_readwrite("ad_stack_size" , &CompileConfig::ad_stack_size) |
192 | .def_readwrite("flatten_if" , &CompileConfig::flatten_if) |
193 | .def_readwrite("make_thread_local" , &CompileConfig::make_thread_local) |
194 | .def_readwrite("make_block_local" , &CompileConfig::make_block_local) |
195 | .def_readwrite("detect_read_only" , &CompileConfig::detect_read_only) |
196 | .def_readwrite("ndarray_use_cached_allocator" , |
197 | &CompileConfig::ndarray_use_cached_allocator) |
198 | .def_readwrite("real_matrix_scalarize" , |
199 | &CompileConfig::real_matrix_scalarize) |
200 | .def_readwrite("cc_compile_cmd" , &CompileConfig::cc_compile_cmd) |
201 | .def_readwrite("cc_link_cmd" , &CompileConfig::cc_link_cmd) |
202 | .def_readwrite("quant_opt_store_fusion" , |
203 | &CompileConfig::quant_opt_store_fusion) |
204 | .def_readwrite("quant_opt_atomic_demotion" , |
205 | &CompileConfig::quant_opt_atomic_demotion) |
206 | .def_readwrite("allow_nv_shader_extension" , |
207 | &CompileConfig::allow_nv_shader_extension) |
208 | .def_readwrite("make_mesh_block_local" , |
209 | &CompileConfig::make_mesh_block_local) |
210 | .def_readwrite("mesh_localize_to_end_mapping" , |
211 | &CompileConfig::mesh_localize_to_end_mapping) |
212 | .def_readwrite("mesh_localize_from_end_mapping" , |
213 | &CompileConfig::mesh_localize_from_end_mapping) |
214 | .def_readwrite("optimize_mesh_reordered_mapping" , |
215 | &CompileConfig::optimize_mesh_reordered_mapping) |
216 | .def_readwrite("mesh_localize_all_attr_mappings" , |
217 | &CompileConfig::mesh_localize_all_attr_mappings) |
218 | .def_readwrite("demote_no_access_mesh_fors" , |
219 | &CompileConfig::demote_no_access_mesh_fors) |
220 | .def_readwrite("experimental_auto_mesh_local" , |
221 | &CompileConfig::experimental_auto_mesh_local) |
222 | .def_readwrite("auto_mesh_local_default_occupacy" , |
223 | &CompileConfig::auto_mesh_local_default_occupacy) |
224 | .def_readwrite("offline_cache" , &CompileConfig::offline_cache) |
225 | .def_readwrite("offline_cache_file_path" , |
226 | &CompileConfig::offline_cache_file_path) |
227 | .def_readwrite("offline_cache_cleaning_policy" , |
228 | &CompileConfig::offline_cache_cleaning_policy) |
229 | .def_readwrite("offline_cache_max_size_of_files" , |
230 | &CompileConfig::offline_cache_max_size_of_files) |
231 | .def_readwrite("offline_cache_cleaning_factor" , |
232 | &CompileConfig::offline_cache_cleaning_factor) |
233 | .def_readwrite("num_compile_threads" , &CompileConfig::num_compile_threads) |
234 | .def_readwrite("vk_api_version" , &CompileConfig::vk_api_version) |
235 | .def_readwrite("cuda_stack_limit" , &CompileConfig::cuda_stack_limit); |
236 | |
237 | m.def("reset_default_compile_config" , |
238 | [&]() { default_compile_config = CompileConfig(); }); |
239 | |
240 | m.def( |
241 | "default_compile_config" , |
242 | [&]() -> CompileConfig & { return default_compile_config; }, |
243 | py::return_value_policy::reference); |
244 | |
245 | py::class_<Program::KernelProfilerQueryResult>(m, "KernelProfilerQueryResult" ) |
246 | .def_readwrite("counter" , &Program::KernelProfilerQueryResult::counter) |
247 | .def_readwrite("min" , &Program::KernelProfilerQueryResult::min) |
248 | .def_readwrite("max" , &Program::KernelProfilerQueryResult::max) |
249 | .def_readwrite("avg" , &Program::KernelProfilerQueryResult::avg); |
250 | |
251 | py::class_<KernelProfileTracedRecord>(m, "KernelProfileTracedRecord" ) |
252 | .def_readwrite("register_per_thread" , |
253 | &KernelProfileTracedRecord::register_per_thread) |
254 | .def_readwrite("shared_mem_per_block" , |
255 | &KernelProfileTracedRecord::shared_mem_per_block) |
256 | .def_readwrite("grid_size" , &KernelProfileTracedRecord::grid_size) |
257 | .def_readwrite("block_size" , &KernelProfileTracedRecord::block_size) |
258 | .def_readwrite( |
259 | "active_blocks_per_multiprocessor" , |
260 | &KernelProfileTracedRecord::active_blocks_per_multiprocessor) |
261 | .def_readwrite("kernel_time" , |
262 | &KernelProfileTracedRecord::kernel_elapsed_time_in_ms) |
263 | .def_readwrite("base_time" , &KernelProfileTracedRecord::time_since_base) |
264 | .def_readwrite("name" , &KernelProfileTracedRecord::name) |
265 | .def_readwrite("metric_values" , |
266 | &KernelProfileTracedRecord::metric_values); |
267 | |
268 | py::enum_<SNodeAccessFlag>(m, "SNodeAccessFlag" , py::arithmetic()) |
269 | .value("block_local" , SNodeAccessFlag::block_local) |
270 | .value("read_only" , SNodeAccessFlag::read_only) |
271 | .value("mesh_local" , SNodeAccessFlag::mesh_local) |
272 | .export_values(); |
273 | |
274 | // Export ASTBuilder |
275 | py::class_<ASTBuilder>(m, "ASTBuilder" ) |
276 | .def("make_id_expr" , &ASTBuilder::make_id_expr) |
277 | .def("create_kernel_exprgroup_return" , |
278 | &ASTBuilder::create_kernel_exprgroup_return) |
279 | .def("create_print" , &ASTBuilder::create_print) |
280 | .def("begin_func" , &ASTBuilder::begin_func) |
281 | .def("end_func" , &ASTBuilder::end_func) |
282 | .def("stop_grad" , &ASTBuilder::stop_gradient) |
283 | .def("begin_frontend_if" , &ASTBuilder::begin_frontend_if) |
284 | .def("begin_frontend_if_true" , &ASTBuilder::begin_frontend_if_true) |
285 | .def("pop_scope" , &ASTBuilder::pop_scope) |
286 | .def("begin_frontend_if_false" , &ASTBuilder::begin_frontend_if_false) |
287 | .def("insert_deactivate" , &ASTBuilder::insert_snode_deactivate) |
288 | .def("insert_activate" , &ASTBuilder::insert_snode_activate) |
289 | .def("expr_snode_get_addr" , &ASTBuilder::snode_get_addr) |
290 | .def("expr_snode_append" , &ASTBuilder::snode_append) |
291 | .def("expr_snode_is_active" , &ASTBuilder::snode_is_active) |
292 | .def("expr_snode_length" , &ASTBuilder::snode_length) |
293 | .def("insert_external_func_call" , &ASTBuilder::insert_external_func_call) |
294 | .def("make_matrix_expr" , &ASTBuilder::make_matrix_expr) |
295 | .def("expr_alloca" , &ASTBuilder::expr_alloca) |
296 | .def("expr_alloca_shared_array" , &ASTBuilder::expr_alloca_shared_array) |
297 | .def("create_assert_stmt" , &ASTBuilder::create_assert_stmt) |
298 | .def("expr_assign" , &ASTBuilder::expr_assign) |
299 | .def("begin_frontend_range_for" , &ASTBuilder::begin_frontend_range_for) |
300 | .def("end_frontend_range_for" , &ASTBuilder::pop_scope) |
301 | .def("begin_frontend_struct_for_on_snode" , |
302 | &ASTBuilder::begin_frontend_struct_for_on_snode) |
303 | .def("begin_frontend_struct_for_on_external_tensor" , |
304 | &ASTBuilder::begin_frontend_struct_for_on_external_tensor) |
305 | .def("end_frontend_struct_for" , &ASTBuilder::pop_scope) |
306 | .def("begin_frontend_mesh_for" , &ASTBuilder::begin_frontend_mesh_for) |
307 | .def("end_frontend_mesh_for" , &ASTBuilder::pop_scope) |
308 | .def("begin_frontend_while" , &ASTBuilder::begin_frontend_while) |
309 | .def("insert_break_stmt" , &ASTBuilder::insert_break_stmt) |
310 | .def("insert_continue_stmt" , &ASTBuilder::insert_continue_stmt) |
311 | .def("insert_expr_stmt" , &ASTBuilder::insert_expr_stmt) |
312 | .def("insert_thread_idx_expr" , &ASTBuilder::insert_thread_idx_expr) |
313 | .def("insert_patch_idx_expr" , &ASTBuilder::insert_patch_idx_expr) |
314 | .def("make_texture_op_expr" , &ASTBuilder::make_texture_op_expr) |
315 | .def("expand_exprs" , &ASTBuilder::expand_exprs) |
316 | .def("mesh_index_conversion" , &ASTBuilder::mesh_index_conversion) |
317 | .def("expr_subscript" , &ASTBuilder::expr_subscript) |
318 | .def("insert_func_call" , &ASTBuilder::insert_func_call) |
319 | .def("sifakis_svd_f32" , sifakis_svd_export<float32, int32>) |
320 | .def("sifakis_svd_f64" , sifakis_svd_export<float64, int64>) |
321 | .def("expr_var" , &ASTBuilder::make_var) |
322 | .def("bit_vectorize" , &ASTBuilder::bit_vectorize) |
323 | .def("parallelize" , &ASTBuilder::parallelize) |
324 | .def("strictly_serialize" , &ASTBuilder::strictly_serialize) |
325 | .def("block_dim" , &ASTBuilder::block_dim) |
326 | .def("insert_snode_access_flag" , &ASTBuilder::insert_snode_access_flag) |
327 | .def("reset_snode_access_flag" , &ASTBuilder::reset_snode_access_flag); |
328 | |
329 | py::class_<Program>(m, "Program" ) |
330 | .def(py::init<>()) |
331 | .def("config" , &Program::compile_config, |
332 | py::return_value_policy::reference) |
333 | .def("sync_kernel_profiler" , |
334 | [](Program *program) { program->profiler->sync(); }) |
335 | .def("update_kernel_profiler" , |
336 | [](Program *program) { program->profiler->update(); }) |
337 | .def("clear_kernel_profiler" , |
338 | [](Program *program) { program->profiler->clear(); }) |
339 | .def("query_kernel_profile_info" , |
340 | [](Program *program, const std::string &name) { |
341 | return program->query_kernel_profile_info(name); |
342 | }) |
343 | .def("get_kernel_profiler_records" , |
344 | [](Program *program) { |
345 | return program->profiler->get_traced_records(); |
346 | }) |
347 | .def( |
348 | "get_kernel_profiler_device_name" , |
349 | [](Program *program) { return program->profiler->get_device_name(); }) |
350 | .def("reinit_kernel_profiler_with_metrics" , |
351 | [](Program *program, const std::vector<std::string> metrics) { |
352 | return program->profiler->reinit_with_metrics(metrics); |
353 | }) |
354 | .def("kernel_profiler_total_time" , |
355 | [](Program *program) { return program->profiler->get_total_time(); }) |
356 | .def("set_kernel_profiler_toolkit" , |
357 | [](Program *program, const std::string toolkit_name) { |
358 | return program->profiler->set_profiler_toolkit(toolkit_name); |
359 | }) |
360 | .def("timeline_clear" , |
361 | [](Program *) { Timelines::get_instance().clear(); }) |
362 | .def("timeline_save" , |
363 | [](Program *, const std::string &fn) { |
364 | Timelines::get_instance().save(fn); |
365 | }) |
366 | .def("print_memory_profiler_info" , &Program::print_memory_profiler_info) |
367 | .def("finalize" , &Program::finalize) |
368 | .def("get_total_compilation_time" , &Program::get_total_compilation_time) |
369 | .def("get_snode_num_dynamically_allocated" , |
370 | &Program::get_snode_num_dynamically_allocated) |
371 | .def("synchronize" , &Program::synchronize) |
372 | .def("materialize_runtime" , &Program::materialize_runtime) |
373 | .def("make_aot_module_builder" , &Program::make_aot_module_builder) |
374 | .def("get_snode_tree_size" , &Program::get_snode_tree_size) |
375 | .def("get_snode_root" , &Program::get_snode_root, |
376 | py::return_value_policy::reference) |
377 | .def( |
378 | "create_kernel" , |
379 | [](Program *program, const std::function<void(Kernel *)> &body, |
380 | const std::string &name, AutodiffMode autodiff_mode) -> Kernel * { |
381 | py::gil_scoped_release release; |
382 | return &program->kernel(body, name, autodiff_mode); |
383 | }, |
384 | py::return_value_policy::reference) |
385 | .def("create_function" , &Program::create_function, |
386 | py::return_value_policy::reference) |
387 | .def("create_sparse_matrix_builder" , |
388 | [](Program *program, int n, int m, uint64 max_num_entries, |
389 | DataType dtype, const std::string &storage_format) { |
390 | TI_ERROR_IF(!arch_is_cpu(program->compile_config().arch) && |
391 | !arch_is_cuda(program->compile_config().arch), |
392 | "SparseMatrix only supports CPU and CUDA for now." ); |
393 | return SparseMatrixBuilder(n, m, max_num_entries, dtype, |
394 | storage_format, program); |
395 | }) |
396 | .def("create_sparse_matrix" , |
397 | [](Program *program, int n, int m, DataType dtype, |
398 | std::string storage_format) { |
399 | TI_ERROR_IF(!arch_is_cpu(program->compile_config().arch) && |
400 | !arch_is_cuda(program->compile_config().arch), |
401 | "SparseMatrix only supports CPU and CUDA for now." ); |
402 | if (arch_is_cpu(program->compile_config().arch)) |
403 | return make_sparse_matrix(n, m, dtype, storage_format); |
404 | else |
405 | return make_cu_sparse_matrix(n, m, dtype); |
406 | }) |
407 | .def("make_sparse_matrix_from_ndarray" , |
408 | [](Program *program, SparseMatrix &sm, const Ndarray &ndarray) { |
409 | TI_ERROR_IF(!arch_is_cpu(program->compile_config().arch) && |
410 | !arch_is_cuda(program->compile_config().arch), |
411 | "SparseMatrix only supports CPU and CUDA for now." ); |
412 | return make_sparse_matrix_from_ndarray(program, sm, ndarray); |
413 | }) |
414 | .def("make_id_expr" , |
415 | [](Program *program, const std::string &name) { |
416 | return Expr::make<IdExpression>(program->get_next_global_id(name)); |
417 | }) |
418 | .def( |
419 | "create_ndarray" , |
420 | [&](Program *program, const DataType &dt, |
421 | const std::vector<int> &shape, ExternalArrayLayout layout, |
422 | bool zero_fill) -> Ndarray * { |
423 | return program->create_ndarray(dt, shape, layout, zero_fill); |
424 | }, |
425 | py::arg("dt" ), py::arg("shape" ), |
426 | py::arg("layout" ) = ExternalArrayLayout::kNull, |
427 | py::arg("zero_fill" ) = false, py::return_value_policy::reference) |
428 | .def("delete_ndarray" , &Program::delete_ndarray) |
429 | .def( |
430 | "create_texture" , |
431 | [&](Program *program, const DataType &dt, int num_channels, |
432 | const std::vector<int> &shape) -> Texture * { |
433 | return program->create_texture(dt, num_channels, shape); |
434 | }, |
435 | py::arg("dt" ), py::arg("num_channels" ), |
436 | py::arg("shape" ) = py::tuple(), py::return_value_policy::reference) |
437 | .def("get_ndarray_data_ptr_as_int" , |
438 | [](Program *program, Ndarray *ndarray) { |
439 | return program->get_ndarray_data_ptr_as_int(ndarray); |
440 | }) |
441 | .def("fill_float" , |
442 | [](Program *program, Ndarray *ndarray, float val) { |
443 | program->fill_ndarray_fast_u32(ndarray, |
444 | reinterpret_cast<uint32_t &>(val)); |
445 | }) |
446 | .def("fill_int" , |
447 | [](Program *program, Ndarray *ndarray, int32_t val) { |
448 | program->fill_ndarray_fast_u32(ndarray, |
449 | reinterpret_cast<int32_t &>(val)); |
450 | }) |
451 | .def("fill_uint" , [](Program *program, Ndarray *ndarray, uint32_t val) { |
452 | program->fill_ndarray_fast_u32(ndarray, val); |
453 | }); |
454 | |
455 | py::class_<AotModuleBuilder>(m, "AotModuleBuilder" ) |
456 | .def("add_field" , &AotModuleBuilder::add_field) |
457 | .def("add" , &AotModuleBuilder::add) |
458 | .def("add_kernel_template" , &AotModuleBuilder::add_kernel_template) |
459 | .def("add_graph" , &AotModuleBuilder::add_graph) |
460 | .def("dump" , &AotModuleBuilder::dump); |
461 | |
462 | py::class_<Axis>(m, "Axis" ).def(py::init<int>()); |
463 | py::class_<SNode>(m, "SNode" ) |
464 | .def(py::init<>()) |
465 | .def_readwrite("parent" , &SNode::parent) |
466 | .def_readonly("type" , &SNode::type) |
467 | .def_readonly("id" , &SNode::id) |
468 | .def("dense" , |
469 | (SNode & (SNode::*)(const std::vector<Axis> &, |
470 | const std::vector<int> &, |
471 | const std::string &))(&SNode::dense), |
472 | py::return_value_policy::reference) |
473 | .def("pointer" , |
474 | (SNode & (SNode::*)(const std::vector<Axis> &, |
475 | const std::vector<int> &, |
476 | const std::string &))(&SNode::pointer), |
477 | py::return_value_policy::reference) |
478 | .def("hash" , |
479 | (SNode & (SNode::*)(const std::vector<Axis> &, |
480 | const std::vector<int> &, |
481 | const std::string &))(&SNode::hash), |
482 | py::return_value_policy::reference) |
483 | .def("dynamic" , &SNode::dynamic, py::return_value_policy::reference) |
484 | .def("bitmasked" , |
485 | (SNode & (SNode::*)(const std::vector<Axis> &, |
486 | const std::vector<int> &, |
487 | const std::string &))(&SNode::bitmasked), |
488 | py::return_value_policy::reference) |
489 | .def("bit_struct" , &SNode::bit_struct, py::return_value_policy::reference) |
490 | .def("quant_array" , &SNode::quant_array, |
491 | py::return_value_policy::reference) |
492 | .def("place" , &SNode::place) |
493 | .def("data_type" , [](SNode *snode) { return snode->dt; }) |
494 | .def("name" , [](SNode *snode) { return snode->name; }) |
495 | .def("get_num_ch" , |
496 | [](SNode *snode) -> int { return (int)snode->ch.size(); }) |
497 | .def( |
498 | "get_ch" , |
499 | [](SNode *snode, int i) -> SNode * { return snode->ch[i].get(); }, |
500 | py::return_value_policy::reference) |
501 | .def("lazy_grad" , &SNode::lazy_grad) |
502 | .def("lazy_dual" , &SNode::lazy_dual) |
503 | .def("allocate_adjoint_checkbit" , &SNode::allocate_adjoint_checkbit) |
504 | .def("read_int" , &SNode::read_int) |
505 | .def("read_uint" , &SNode::read_uint) |
506 | .def("read_float" , &SNode::read_float) |
507 | .def("has_adjoint" , &SNode::has_adjoint) |
508 | .def("has_adjoint_checkbit" , &SNode::has_adjoint_checkbit) |
509 | .def("get_snode_grad_type" , &SNode::get_snode_grad_type) |
510 | .def("has_dual" , &SNode::has_dual) |
511 | .def("is_primal" , &SNode::is_primal) |
512 | .def("is_place" , &SNode::is_place) |
513 | .def("get_expr" , &SNode::get_expr) |
514 | .def("write_int" , &SNode::write_int) |
515 | .def("write_uint" , &SNode::write_uint) |
516 | .def("write_float" , &SNode::write_float) |
517 | .def("get_shape_along_axis" , &SNode::shape_along_axis) |
518 | .def("get_physical_index_position" , |
519 | [](SNode *snode) { |
520 | return std::vector<int>( |
521 | snode->physical_index_position, |
522 | snode->physical_index_position + taichi_max_num_indices); |
523 | }) |
524 | .def("num_active_indices" , |
525 | [](SNode *snode) { return snode->num_active_indices; }) |
526 | .def_readonly("cell_size_bytes" , &SNode::cell_size_bytes) |
527 | .def_readonly("offset_bytes_in_parent_cell" , |
528 | &SNode::offset_bytes_in_parent_cell); |
529 | |
530 | py::class_<SNodeTree>(m, "SNodeTree" ) |
531 | .def("id" , &SNodeTree::id) |
532 | .def("destroy_snode_tree" , [](SNodeTree *snode_tree, Program *program) { |
533 | program->destroy_snode_tree(snode_tree); |
534 | }); |
535 | |
536 | py::class_<DeviceAllocation>(m, "DeviceAllocation" ) |
537 | .def(py::init([](uint64_t device, uint64_t alloc_id) -> DeviceAllocation { |
538 | DeviceAllocation alloc; |
539 | alloc.device = (Device *)device; |
540 | alloc.alloc_id = (DeviceAllocationId)alloc_id; |
541 | return alloc; |
542 | }), |
543 | py::arg("device" ), py::arg("alloc_id" )) |
544 | .def_readonly("device" , &DeviceAllocation::device) |
545 | .def_readonly("alloc_id" , &DeviceAllocation::alloc_id); |
546 | |
547 | py::class_<Ndarray>(m, "Ndarray" ) |
548 | .def("device_allocation_ptr" , &Ndarray::get_device_allocation_ptr_as_int) |
549 | .def("device_allocation" , &Ndarray::get_device_allocation) |
550 | .def("element_size" , &Ndarray::get_element_size) |
551 | .def("nelement" , &Ndarray::get_nelement) |
552 | .def("read_int" , &Ndarray::read_int) |
553 | .def("read_uint" , &Ndarray::read_uint) |
554 | .def("read_float" , &Ndarray::read_float) |
555 | .def("write_int" , &Ndarray::write_int) |
556 | .def("write_float" , &Ndarray::write_float) |
557 | .def("total_shape" , &Ndarray::total_shape) |
558 | .def("element_shape" , &Ndarray::get_element_shape) |
559 | .def("element_data_type" , &Ndarray::get_element_data_type) |
560 | .def_readonly("dtype" , &Ndarray::dtype) |
561 | .def_readonly("shape" , &Ndarray::shape); |
562 | |
563 | py::enum_<BufferFormat>(m, "Format" ) |
564 | #define PER_BUFFER_FORMAT(x) .value(#x, BufferFormat::x) |
565 | #include "taichi/inc/rhi_constants.inc.h" |
566 | #undef PER_EXTENSION |
567 | ; |
568 | |
569 | py::class_<Texture>(m, "Texture" ) |
570 | .def("device_allocation_ptr" , &Texture::get_device_allocation_ptr_as_int) |
571 | .def("from_ndarray" , &Texture::from_ndarray) |
572 | .def("from_snode" , &Texture::from_snode); |
573 | |
574 | py::enum_<aot::ArgKind>(m, "ArgKind" ) |
575 | .value("SCALAR" , aot::ArgKind::kScalar) |
576 | .value("NDARRAY" , aot::ArgKind::kNdarray) |
577 | // Using this MATRIX as Scalar alias, we can move to native matrix type |
578 | // when supported |
579 | .value("MATRIX" , aot::ArgKind::kMatrix) |
580 | .value("TEXTURE" , aot::ArgKind::kTexture) |
581 | .value("RWTEXTURE" , aot::ArgKind::kRWTexture) |
582 | .export_values(); |
583 | |
584 | py::class_<aot::Arg>(m, "Arg" ) |
585 | .def(py::init<aot::ArgKind, std::string, DataType &, size_t, |
586 | std::vector<int>>(), |
587 | py::arg("tag" ), py::arg("name" ), py::arg("dtype" ), |
588 | py::arg("field_dim" ), py::arg("element_shape" )) |
589 | .def(py::init<aot::ArgKind, std::string, DataType &, size_t, |
590 | std::vector<int>>(), |
591 | py::arg("tag" ), py::arg("name" ), py::arg("channel_format" ), |
592 | py::arg("num_channels" ), py::arg("shape" )) |
593 | .def_readonly("name" , &aot::Arg::name) |
594 | .def_readonly("element_shape" , &aot::Arg::element_shape) |
595 | .def_readonly("texture_shape" , &aot::Arg::element_shape) |
596 | .def_readonly("field_dim" , &aot::Arg::field_dim) |
597 | .def_readonly("num_channels" , &aot::Arg::num_channels) |
598 | .def("dtype" , &aot::Arg::dtype) |
599 | .def("channel_format" , &aot::Arg::dtype); |
600 | |
601 | py::class_<Node>(m, "Node" ); // NOLINT(bugprone-unused-raii) |
602 | |
603 | py::class_<Sequential, Node>(m, "Sequential" ) |
604 | .def(py::init<GraphBuilder *>()) |
605 | .def("append" , &Sequential::append) |
606 | .def("dispatch" , &Sequential::dispatch); |
607 | |
608 | py::class_<GraphBuilder>(m, "GraphBuilder" ) |
609 | .def(py::init<>()) |
610 | .def("dispatch" , &GraphBuilder::dispatch) |
611 | .def("compile" , &GraphBuilder::compile) |
612 | .def("create_sequential" , &GraphBuilder::new_sequential_node, |
613 | py::return_value_policy::reference) |
614 | .def("seq" , &GraphBuilder::seq, py::return_value_policy::reference); |
615 | |
616 | py::class_<aot::CompiledGraph>(m, "CompiledGraph" ) |
617 | .def("run" , [](aot::CompiledGraph *self, const py::dict &pyargs) { |
618 | std::unordered_map<std::string, aot::IValue> args; |
619 | for (auto it : pyargs) { |
620 | std::string arg_name = py::cast<std::string>(it.first); |
621 | auto tag = self->args[arg_name].tag; |
622 | if (tag == aot::ArgKind::kNdarray) { |
623 | auto &val = it.second.cast<Ndarray &>(); |
624 | args.insert( |
625 | {py::cast<std::string>(it.first), aot::IValue::create(val)}); |
626 | } else if (tag == aot::ArgKind::kTexture || |
627 | tag == aot::ArgKind::kRWTexture) { |
628 | auto &val = it.second.cast<Texture &>(); |
629 | args.insert( |
630 | {py::cast<std::string>(it.first), aot::IValue::create(val)}); |
631 | |
632 | } else if (tag == aot::ArgKind::kScalar || |
633 | tag == aot::ArgKind::kMatrix) { |
634 | std::string arg_name = py::cast<std::string>(it.first); |
635 | auto expected_dtype = self->args[arg_name].dtype(); |
636 | if (expected_dtype == PrimitiveType::i32) { |
637 | args.insert( |
638 | {arg_name, aot::IValue::create(py::cast<int>(it.second))}); |
639 | } else if (expected_dtype == PrimitiveType::i64) { |
640 | args.insert( |
641 | {arg_name, aot::IValue::create(py::cast<int64>(it.second))}); |
642 | } else if (expected_dtype == PrimitiveType::f32) { |
643 | args.insert( |
644 | {arg_name, aot::IValue::create(py::cast<float>(it.second))}); |
645 | } else if (expected_dtype == PrimitiveType::f64) { |
646 | args.insert( |
647 | {arg_name, aot::IValue::create(py::cast<double>(it.second))}); |
648 | } else if (expected_dtype == PrimitiveType::i16) { |
649 | args.insert( |
650 | {arg_name, aot::IValue::create(py::cast<int16>(it.second))}); |
651 | } else if (expected_dtype == PrimitiveType::u32) { |
652 | args.insert( |
653 | {arg_name, aot::IValue::create(py::cast<uint32>(it.second))}); |
654 | } else if (expected_dtype == PrimitiveType::u64) { |
655 | args.insert( |
656 | {arg_name, aot::IValue::create(py::cast<uint64>(it.second))}); |
657 | } else if (expected_dtype == PrimitiveType::u16) { |
658 | args.insert( |
659 | {arg_name, aot::IValue::create(py::cast<uint16>(it.second))}); |
660 | } else if (expected_dtype == PrimitiveType::u8) { |
661 | args.insert({arg_name, |
662 | aot::IValue::create(py::cast<uint8_t>(it.second))}); |
663 | } else if (expected_dtype == PrimitiveType::i8) { |
664 | args.insert( |
665 | {arg_name, aot::IValue::create(py::cast<int8_t>(it.second))}); |
666 | } else { |
667 | TI_NOT_IMPLEMENTED; |
668 | } |
669 | } else { |
670 | TI_NOT_IMPLEMENTED; |
671 | } |
672 | } |
673 | self->run(args); |
674 | }); |
675 | |
676 | py::class_<Kernel>(m, "Kernel" ) |
677 | .def("no_activate" , |
678 | [](Kernel *self, SNode *snode) { |
679 | // TODO(#2193): Also apply to @ti.func? |
680 | self->no_activate.push_back(snode); |
681 | }) |
682 | .def("insert_scalar_param" , &Kernel::insert_scalar_param) |
683 | .def("insert_arr_param" , &Kernel::insert_arr_param) |
684 | .def("insert_texture_param" , &Kernel::insert_texture_param) |
685 | .def("insert_ret" , &Kernel::insert_ret) |
686 | .def("finalize_rets" , &Kernel::finalize_rets) |
687 | .def("get_ret_int" , &Kernel::get_ret_int) |
688 | .def("get_ret_uint" , &Kernel::get_ret_uint) |
689 | .def("get_ret_float" , &Kernel::get_ret_float) |
690 | .def("get_ret_int_tensor" , &Kernel::get_ret_int_tensor) |
691 | .def("get_ret_uint_tensor" , &Kernel::get_ret_uint_tensor) |
692 | .def("get_ret_float_tensor" , &Kernel::get_ret_float_tensor) |
693 | .def("get_struct_ret_int" , &Kernel::get_struct_ret_int) |
694 | .def("get_struct_ret_uint" , &Kernel::get_struct_ret_uint) |
695 | .def("get_struct_ret_float" , &Kernel::get_struct_ret_float) |
696 | .def("make_launch_context" , &Kernel::make_launch_context) |
697 | .def( |
698 | "ast_builder" , |
699 | [](Kernel *self) -> ASTBuilder * { |
700 | return &self->context->builder(); |
701 | }, |
702 | py::return_value_policy::reference) |
703 | .def("__call__" , |
704 | [](Kernel *kernel, Kernel::LaunchContextBuilder &launch_ctx) { |
705 | py::gil_scoped_release release; |
706 | kernel->operator()(kernel->program->compile_config(), launch_ctx); |
707 | }); |
708 | |
  // Per-launch argument binding for a Kernel (obtained via
  // Kernel.make_launch_context).
  py::class_<Kernel::LaunchContextBuilder>(m, "KernelLaunchContext" )
      .def("set_arg_int" , &Kernel::LaunchContextBuilder::set_arg_int)
      .def("set_arg_uint" , &Kernel::LaunchContextBuilder::set_arg_uint)
      .def("set_arg_float" , &Kernel::LaunchContextBuilder::set_arg_float)
      .def("set_arg_external_array_with_shape" ,
           &Kernel::LaunchContextBuilder::set_arg_external_array_with_shape)
      .def("set_arg_ndarray" , &Kernel::LaunchContextBuilder::set_arg_ndarray)
      .def("set_arg_ndarray_with_grad" ,
           &Kernel::LaunchContextBuilder::set_arg_ndarray_with_grad)
      .def("set_arg_texture" , &Kernel::LaunchContextBuilder::set_arg_texture)
      .def("set_arg_rw_texture" ,
           &Kernel::LaunchContextBuilder::set_arg_rw_texture)
      .def("set_extra_arg_int" ,
           &Kernel::LaunchContextBuilder::set_extra_arg_int);
723 | |
  // Function: a Taichi function definition, with the same parameter/return
  // registration surface as Kernel.
  py::class_<Function>(m, "Function" )
      .def("insert_scalar_param" , &Function::insert_scalar_param)
      .def("insert_arr_param" , &Function::insert_arr_param)
      .def("insert_texture_param" , &Function::insert_texture_param)
      .def("insert_ret" , &Function::insert_ret)
      // Disambiguate the std::function overload of set_function_body.
      .def("set_function_body" ,
           py::overload_cast<const std::function<void()> &>(
               &Function::set_function_body))
      .def("finalize_rets" , &Function::finalize_rets)
      .def(
          "ast_builder" ,
          [](Function *self) -> ASTBuilder * {
            // Borrowed pointer; the builder is owned by the function context.
            return &self->context->builder();
          },
          py::return_value_policy::reference);
739 | |
740 | py::class_<Expr> expr(m, "Expr" ); |
741 | expr.def("snode" , &Expr::snode, py::return_value_policy::reference) |
742 | .def("is_external_tensor_expr" , |
743 | [](Expr *expr) { return expr->is<ExternalTensorExpression>(); }) |
744 | .def("is_index_expr" , |
745 | [](Expr *expr) { return expr->is<IndexExpression>(); }) |
746 | .def("is_primal" , |
747 | [](Expr *expr) { |
748 | return expr->cast<FieldExpression>()->snode_grad_type == |
749 | SNodeGradType::kPrimal; |
750 | }) |
751 | .def("set_tb" , &Expr::set_tb) |
752 | .def("set_name" , |
753 | [&](Expr *expr, std::string na) { |
754 | expr->cast<FieldExpression>()->name = na; |
755 | }) |
756 | .def("set_grad_type" , |
757 | [&](Expr *expr, SNodeGradType t) { |
758 | expr->cast<FieldExpression>()->snode_grad_type = t; |
759 | }) |
760 | .def("set_adjoint" , &Expr::set_adjoint) |
761 | .def("set_adjoint_checkbit" , &Expr::set_adjoint_checkbit) |
762 | .def("set_dual" , &Expr::set_dual) |
763 | .def("set_dynamic_index_stride" , |
764 | [&](Expr *expr, int dynamic_index_stride) { |
765 | auto matrix_field = expr->cast<MatrixFieldExpression>(); |
766 | matrix_field->dynamic_indexable = true; |
767 | matrix_field->dynamic_index_stride = dynamic_index_stride; |
768 | }) |
769 | .def("get_dynamic_indexable" , |
770 | [&](Expr *expr) -> bool { |
771 | return expr->cast<MatrixFieldExpression>()->dynamic_indexable; |
772 | }) |
773 | .def("get_dynamic_index_stride" , |
774 | [&](Expr *expr) -> int { |
775 | return expr->cast<MatrixFieldExpression>()->dynamic_index_stride; |
776 | }) |
777 | .def( |
778 | "get_dt" , |
779 | [&](Expr *expr) -> const Type * { |
780 | return expr->cast<FieldExpression>()->dt; |
781 | }, |
782 | py::return_value_policy::reference) |
783 | .def("get_ret_type" , &Expr::get_ret_type) |
784 | .def("is_tensor" , |
785 | [](Expr *expr) { return expr->expr->ret_type->is<TensorType>(); }) |
786 | .def("get_shape" , |
787 | [](Expr *expr) -> std::optional<std::vector<int>> { |
788 | if (expr->expr->ret_type->is<TensorType>()) { |
789 | return std::optional<std::vector<int>>( |
790 | expr->expr->ret_type->cast<TensorType>()->get_shape()); |
791 | } |
792 | return std::nullopt; |
793 | }) |
794 | .def("type_check" , &Expr::type_check) |
795 | .def("get_expr_name" , |
796 | [](Expr *expr) { return expr->cast<FieldExpression>()->name; }) |
797 | .def("get_raw_address" , [](Expr *expr) { return (uint64)expr; }) |
798 | .def("get_underlying_ptr_address" , [](Expr *e) { |
799 | // The reason that there are both get_raw_address() and |
800 | // get_underlying_ptr_address() is that Expr itself is mostly wrapper |
801 | // around its underlying |expr| (of type Expression). Expr |e| can be |
802 | // temporary, while the underlying |expr| is mostly persistent. |
803 | // |
804 | // Same get_raw_address() implies that get_underlying_ptr_address() are |
805 | // also the same. The reverse is not true. |
806 | return (uint64)e->expr.get(); |
807 | }); |
808 | |
809 | py::class_<ExprGroup>(m, "ExprGroup" ) |
810 | .def(py::init<>()) |
811 | .def("size" , [](ExprGroup *eg) { return eg->exprs.size(); }) |
812 | .def("push_back" , &ExprGroup::push_back); |
813 | |
814 | py::class_<Stmt>(m, "Stmt" ); // NOLINT(bugprone-unused-raii) |
815 | |
816 | m.def("insert_internal_func_call" , |
817 | [&](const std::string &func_name, const ExprGroup &args, |
818 | bool with_runtime_context) { |
819 | return Expr::make<InternalFuncCallExpression>(func_name, args.exprs, |
820 | with_runtime_context); |
821 | }); |
822 | |
823 | m.def("make_get_element_expr" , |
824 | Expr::make<GetElementExpression, const Expr &, std::vector<int>>); |
825 | |
826 | m.def("value_cast" , static_cast<Expr (*)(const Expr &expr, DataType)>(cast)); |
827 | m.def("bits_cast" , |
828 | static_cast<Expr (*)(const Expr &expr, DataType)>(bit_cast)); |
829 | |
830 | m.def("expr_atomic_add" , [&](const Expr &a, const Expr &b) { |
831 | return Expr::make<AtomicOpExpression>(AtomicOpType::add, a, b); |
832 | }); |
833 | |
834 | m.def("expr_atomic_sub" , [&](const Expr &a, const Expr &b) { |
835 | return Expr::make<AtomicOpExpression>(AtomicOpType::sub, a, b); |
836 | }); |
837 | |
838 | m.def("expr_atomic_min" , [&](const Expr &a, const Expr &b) { |
839 | return Expr::make<AtomicOpExpression>(AtomicOpType::min, a, b); |
840 | }); |
841 | |
842 | m.def("expr_atomic_max" , [&](const Expr &a, const Expr &b) { |
843 | return Expr::make<AtomicOpExpression>(AtomicOpType::max, a, b); |
844 | }); |
845 | |
846 | m.def("expr_atomic_bit_and" , [&](const Expr &a, const Expr &b) { |
847 | return Expr::make<AtomicOpExpression>(AtomicOpType::bit_and, a, b); |
848 | }); |
849 | |
850 | m.def("expr_atomic_bit_or" , [&](const Expr &a, const Expr &b) { |
851 | return Expr::make<AtomicOpExpression>(AtomicOpType::bit_or, a, b); |
852 | }); |
853 | |
854 | m.def("expr_atomic_bit_xor" , [&](const Expr &a, const Expr &b) { |
855 | return Expr::make<AtomicOpExpression>(AtomicOpType::bit_xor, a, b); |
856 | }); |
857 | |
  // Optimizer hints: value-range and loop-uniqueness assumptions.
  m.def("expr_assume_in_range" , assume_range);

  m.def("expr_loop_unique" , loop_unique);

  // Field expression factories.
  m.def("expr_field" , expr_field);

  m.def("expr_matrix_field" , expr_matrix_field);
865 | |
// Registers a Python-visible "expr_<x>" function for each expr_<x> helper
// declared in taichi/ir/expression_ops.h.
#define DEFINE_EXPRESSION_OP(x) m.def("expr_" #x, expr_##x);

  // Unary arithmetic / transcendental ops.
  DEFINE_EXPRESSION_OP(neg)
  DEFINE_EXPRESSION_OP(sqrt)
  DEFINE_EXPRESSION_OP(round)
  DEFINE_EXPRESSION_OP(floor)
  DEFINE_EXPRESSION_OP(ceil)
  DEFINE_EXPRESSION_OP(abs)
  DEFINE_EXPRESSION_OP(sin)
  DEFINE_EXPRESSION_OP(asin)
  DEFINE_EXPRESSION_OP(cos)
  DEFINE_EXPRESSION_OP(acos)
  DEFINE_EXPRESSION_OP(tan)
  DEFINE_EXPRESSION_OP(tanh)
  DEFINE_EXPRESSION_OP(inv)
  DEFINE_EXPRESSION_OP(rcp)
  DEFINE_EXPRESSION_OP(rsqrt)
  DEFINE_EXPRESSION_OP(exp)
  DEFINE_EXPRESSION_OP(log)

  // Ternary ops.
  DEFINE_EXPRESSION_OP(select)
  DEFINE_EXPRESSION_OP(ifte)

  // Comparisons.
  DEFINE_EXPRESSION_OP(cmp_le)
  DEFINE_EXPRESSION_OP(cmp_lt)
  DEFINE_EXPRESSION_OP(cmp_ge)
  DEFINE_EXPRESSION_OP(cmp_gt)
  DEFINE_EXPRESSION_OP(cmp_ne)
  DEFINE_EXPRESSION_OP(cmp_eq)

  // Bitwise ops.
  DEFINE_EXPRESSION_OP(bit_and)
  DEFINE_EXPRESSION_OP(bit_or)
  DEFINE_EXPRESSION_OP(bit_xor)
  DEFINE_EXPRESSION_OP(bit_shl)
  DEFINE_EXPRESSION_OP(bit_shr)
  DEFINE_EXPRESSION_OP(bit_sar)
  DEFINE_EXPRESSION_OP(bit_not)

  // Logical ops.
  DEFINE_EXPRESSION_OP(logic_not)
  DEFINE_EXPRESSION_OP(logical_and)
  DEFINE_EXPRESSION_OP(logical_or)

  // Binary arithmetic ops.
  DEFINE_EXPRESSION_OP(add)
  DEFINE_EXPRESSION_OP(sub)
  DEFINE_EXPRESSION_OP(mul)
  DEFINE_EXPRESSION_OP(div)
  DEFINE_EXPRESSION_OP(truediv)
  DEFINE_EXPRESSION_OP(floordiv)
  DEFINE_EXPRESSION_OP(mod)
  DEFINE_EXPRESSION_OP(max)
  DEFINE_EXPRESSION_OP(min)
  DEFINE_EXPRESSION_OP(atan2)
  DEFINE_EXPRESSION_OP(pow)

#undef DEFINE_EXPRESSION_OP
921 | |
  // Statement / expression factory functions exposed to the Python frontend.
  m.def("make_global_load_stmt" , Stmt::make<GlobalLoadStmt, Stmt *>);
  m.def("make_global_store_stmt" , Stmt::make<GlobalStoreStmt, Stmt *, Stmt *>);
  m.def("make_frontend_assign_stmt" ,
        Stmt::make<FrontendAssignStmt, const Expr &, const Expr &>);

  m.def("make_arg_load_expr" ,
        Expr::make<ArgLoadExpression, int, const DataType &, bool>);

  m.def("make_reference" , Expr::make<ReferenceExpression, const Expr &>);

  m.def("make_external_tensor_expr" ,
        Expr::make<ExternalTensorExpression, const DataType &, int, int, int,
                   const std::vector<int> &>);

  m.def("make_external_grad_tensor_expr" ,
        Expr::make<ExternalTensorExpression, Expr *>);

  m.def("make_rand_expr" , Expr::make<RandExpression, const DataType &>);

  // Integer and floating-point constant expressions.
  m.def("make_const_expr_int" ,
        Expr::make<ConstExpression, const DataType &, int64>);

  m.def("make_const_expr_fp" ,
        Expr::make<ConstExpression, const DataType &, float64>);

  m.def("make_texture_ptr_expr" , Expr::make<TexturePtrExpression, int, int>);
  m.def("make_rw_texture_ptr_expr" ,
        Expr::make<TexturePtrExpression, int, int, int, const DataType &, int>);
950 | |
951 | auto &&texture = |
952 | py::enum_<TextureOpType>(m, "TextureOpType" , py::arithmetic()); |
953 | for (int t = 0; t <= (int)TextureOpType::kStore; t++) |
954 | texture.value(texture_op_type_name(TextureOpType(t)).c_str(), |
955 | TextureOpType(t)); |
956 | texture.export_values(); |
957 | |
958 | auto &&bin = py::enum_<BinaryOpType>(m, "BinaryOpType" , py::arithmetic()); |
959 | for (int t = 0; t <= (int)BinaryOpType::undefined; t++) |
960 | bin.value(binary_op_type_name(BinaryOpType(t)).c_str(), BinaryOpType(t)); |
961 | bin.export_values(); |
962 | m.def("make_binary_op_expr" , |
963 | Expr::make<BinaryOpExpression, const BinaryOpType &, const Expr &, |
964 | const Expr &>); |
965 | |
966 | auto &&unary = py::enum_<UnaryOpType>(m, "UnaryOpType" , py::arithmetic()); |
967 | for (int t = 0; t <= (int)UnaryOpType::undefined; t++) |
968 | unary.value(unary_op_type_name(UnaryOpType(t)).c_str(), UnaryOpType(t)); |
969 | unary.export_values(); |
970 | m.def("make_unary_op_expr" , |
971 | Expr::make<UnaryOpExpression, const UnaryOpType &, const Expr &>); |
972 | #define PER_TYPE(x) \ |
973 | m.attr(("DataType_" + data_type_name(PrimitiveType::x)).c_str()) = \ |
974 | PrimitiveType::x; |
975 | #include "taichi/inc/data_type.inc.h" |
976 | #undef PER_TYPE |
977 | |
  // DataType predicates and queries.
  m.def("data_type_size" , data_type_size);
  m.def("is_quant" , is_quant);
  m.def("is_integral" , is_integral);
  m.def("is_signed" , is_signed);
  m.def("is_real" , is_real);
  m.def("is_unsigned" , is_unsigned);
  m.def("is_tensor" , is_tensor);

  m.def("data_type_name" , data_type_name);

  m.def(
      "subscript_with_multiple_indices" ,
      Expr::make<IndexExpression, const Expr &, const std::vector<ExprGroup> &,
                 const std::vector<int> &, std::string>);

  // Accessors for properties of external-tensor expressions.
  m.def("get_external_tensor_element_dim" , [](const Expr &expr) {
    TI_ASSERT(expr.is<ExternalTensorExpression>());
    return expr.cast<ExternalTensorExpression>()->element_dim;
  });

  m.def("get_external_tensor_element_shape" , [](const Expr &expr) {
    TI_ASSERT(expr.is<ExternalTensorExpression>());
    auto external_tensor_expr = expr.cast<ExternalTensorExpression>();
    return external_tensor_expr->dt.get_shape();
  });

  // Also accepts texture pointers, returning their dimensionality.
  m.def("get_external_tensor_dim" , [](const Expr &expr) {
    if (expr.is<ExternalTensorExpression>()) {
      return expr.cast<ExternalTensorExpression>()->dim;
    } else if (expr.is<TexturePtrExpression>()) {
      return expr.cast<TexturePtrExpression>()->num_dims;
    } else {
      // Unsupported expression kind; the 0 only matters if assertions are
      // compiled out.
      TI_ASSERT(false);
      return 0;
    }
  });

  m.def("get_external_tensor_shape_along_axis" ,
        Expr::make<ExternalTensorShapeAlongAxisExpression, const Expr &, int>);
1017 | |
  // Mesh related.
  // Builds a MeshRelationAccessExpression without a neighbor index
  // (size-query form, per the exported name).
  m.def("get_relation_size" , [](mesh::MeshPtr mesh_ptr, const Expr &mesh_idx,
                                mesh::MeshElementType to_type) {
    return Expr::make<MeshRelationAccessExpression>(mesh_ptr.ptr.get(),
                                                    mesh_idx, to_type);
  });

  // Same expression with a neighbor index (element-access form).
  m.def("get_relation_access" ,
        [](mesh::MeshPtr mesh_ptr, const Expr &mesh_idx,
           mesh::MeshElementType to_type, const Expr &neighbor_idx) {
          return Expr::make<MeshRelationAccessExpression>(
              mesh_ptr.ptr.get(), mesh_idx, to_type, neighbor_idx);
        });

  // FunctionKey: (name, func_id, instance_id) identifier for Functions.
  py::class_<FunctionKey>(m, "FunctionKey" )
      .def(py::init<const std::string &, int, int>())
      .def_readonly("instance_id" , &FunctionKey::instance_id);
1035 | |
1036 | m.def("test_throw" , [] { |
1037 | try { |
1038 | throw IRModified(); |
1039 | } catch (IRModified) { |
1040 | TI_INFO("caught" ); |
1041 | } |
1042 | }); |
1043 | |
1044 | m.def("test_throw" , [] { throw IRModified(); }); |
1045 | |
#if TI_WITH_LLVM
  m.def("libdevice_path" , libdevice_path);
#endif

  m.def("host_arch" , host_arch);
  m.def("arch_uses_llvm" , arch_uses_llvm);

  // Runtime directory configuration, set from the Python side.
  m.def("set_lib_dir" , [&](const std::string &dir) { compiled_lib_dir = dir; });
  m.def("set_tmp_dir" , [&](const std::string &dir) { runtime_tmp_dir = dir; });

  // Build / version information.
  m.def("get_commit_hash" , get_commit_hash);
  m.def("get_version_string" , get_version_string);
  m.def("get_version_major" , get_version_major);
  m.def("get_version_minor" , get_version_minor);
  m.def("get_version_patch" , get_version_patch);
  m.def("get_llvm_target_support" , [] {
#if defined(TI_WITH_LLVM)
    return LLVM_VERSION_STRING;
#else
    return "targets unsupported" ;
#endif
  });
  // Self-test and debugging helpers.
  m.def("test_printf" , [] { printf("test_printf\n" ); });
  m.def("test_logging" , [] { TI_INFO("test_logging" ); });
  // Deliberately dereferences an invalid pointer (crash-handling test hook).
  m.def("trigger_crash" , [] { *(int *)(1) = 0; });
  m.def("get_max_num_indices" , [] { return taichi_max_num_indices; });
  m.def("get_max_num_args" , [] { return taichi_max_num_args; });
  m.def("test_threading" , test_threading);
  m.def("is_extension_supported" , is_extension_supported);
1075 | |
1076 | m.def("record_action_entry" , |
1077 | [](std::string name, |
1078 | std::vector<std::pair<std::string, |
1079 | std::variant<std::string, int, float>>> args) { |
1080 | std::vector<ActionArg> acts; |
1081 | for (auto const &[k, v] : args) { |
1082 | if (std::holds_alternative<int>(v)) { |
1083 | acts.push_back(ActionArg(k, std::get<int>(v))); |
1084 | } else if (std::holds_alternative<float>(v)) { |
1085 | acts.push_back(ActionArg(k, std::get<float>(v))); |
1086 | } else { |
1087 | acts.push_back(ActionArg(k, std::get<std::string>(v))); |
1088 | } |
1089 | } |
1090 | ActionRecorder::get_instance().record(name, acts); |
1091 | }); |
1092 | |
1093 | m.def("start_recording" , [](const std::string &fn) { |
1094 | ActionRecorder::get_instance().start_recording(fn); |
1095 | }); |
1096 | |
1097 | m.def("stop_recording" , |
1098 | []() { ActionRecorder::get_instance().stop_recording(); }); |
1099 | |
  // Query a named int64 runtime property. Currently only
  // "cuda_compute_capability" is recognized; other keys raise an error.
  m.def("query_int64" , [](const std::string &key) {
    if (key == "cuda_compute_capability" ) {
#if defined(TI_WITH_CUDA)
      return CUDAContext::get_instance().get_compute_capability();
#else
      TI_NOT_IMPLEMENTED
#endif
    } else {
      TI_ERROR("Key {} not supported in query_int64" , key);
    }
  });

  // Type system

  py::class_<Type>(m, "Type" ).def("to_string" , &Type::to_string);

  m.def("promoted_type" , promoted_type);
1117 | |
1118 | // Note that it is important to specify py::return_value_policy::reference for |
1119 | // the factory methods, otherwise pybind11 will delete the Types owned by |
1120 | // TypeFactory on Python-scope pointer destruction. |
1121 | py::class_<TypeFactory>(m, "TypeFactory" ) |
1122 | .def("get_quant_int_type" , &TypeFactory::get_quant_int_type, |
1123 | py::arg("num_bits" ), py::arg("is_signed" ), py::arg("compute_type" ), |
1124 | py::return_value_policy::reference) |
1125 | .def("get_quant_fixed_type" , &TypeFactory::get_quant_fixed_type, |
1126 | py::arg("digits_type" ), py::arg("compute_type" ), py::arg("scale" ), |
1127 | py::return_value_policy::reference) |
1128 | .def("get_quant_float_type" , &TypeFactory::get_quant_float_type, |
1129 | py::arg("digits_type" ), py::arg("exponent_type" ), |
1130 | py::arg("compute_type" ), py::return_value_policy::reference) |
1131 | .def( |
1132 | "get_tensor_type" , |
1133 | [&](TypeFactory *factory, std::vector<int> shape, |
1134 | const DataType &element_type) { |
1135 | return factory->create_tensor_type(shape, element_type); |
1136 | }, |
1137 | py::return_value_policy::reference) |
1138 | .def( |
1139 | "get_struct_type" , |
1140 | [&](TypeFactory *factory, |
1141 | std::vector<std::pair<DataType, std::string>> elements) { |
1142 | std::vector<StructMember> members; |
1143 | for (auto &[type, name] : elements) { |
1144 | members.push_back({type, name}); |
1145 | } |
1146 | return DataType(factory->get_struct_type(members)); |
1147 | }, |
1148 | py::return_value_policy::reference); |
1149 | |
1150 | m.def("get_type_factory_instance" , TypeFactory::get_instance, |
1151 | py::return_value_policy::reference); |
1152 | |
  // NOLINTNEXTLINE(bugprone-unused-raii)
  py::class_<BitStructType>(m, "BitStructType" );
  // Builder for quantized bit-struct types.
  py::class_<BitStructTypeBuilder>(m, "BitStructTypeBuilder" )
      .def(py::init<int>())
      .def("begin_placing_shared_exponent" ,
           &BitStructTypeBuilder::begin_placing_shared_exponent)
      .def("end_placing_shared_exponent" ,
           &BitStructTypeBuilder::end_placing_shared_exponent)
      .def("add_member" , &BitStructTypeBuilder::add_member)
      .def("build" , &BitStructTypeBuilder::build,
           py::return_value_policy::reference);

  py::class_<SNodeRegistry>(m, "SNodeRegistry" )
      .def(py::init<>())
      .def("create_root" , &SNodeRegistry::create_root,
           py::return_value_policy::reference);

  // Finalizes a registry's SNode tree and hands it to the Program; the
  // reference policy means Python does not take ownership of the result.
  m.def(
      "finalize_snode_tree" ,
      [](SNodeRegistry *registry, const SNode *root, Program *program,
         bool compile_only) -> SNodeTree * {
        return program->add_snode_tree(registry->finalize(root), compile_only);
      },
      py::return_value_policy::reference);

  // Sparse Matrix
  py::class_<SparseMatrixBuilder>(m, "SparseMatrixBuilder" )
      .def("print_triplets_eigen" , &SparseMatrixBuilder::print_triplets_eigen)
      .def("print_triplets_cuda" , &SparseMatrixBuilder::print_triplets_cuda)
      .def("get_ndarray_data_ptr" , &SparseMatrixBuilder::get_ndarray_data_ptr)
      .def("build" , &SparseMatrixBuilder::build)
      .def("build_cuda" , &SparseMatrixBuilder::build_cuda)
      // Raw address of the builder object, exposed as an integer.
      .def("get_addr" , [](SparseMatrixBuilder *mat) { return uint64(mat); });

  // Base sparse-matrix class; element accessors are bound for float32.
  py::class_<SparseMatrix>(m, "SparseMatrix" )
      .def(py::init<>())
      .def(py::init<int, int, DataType>(), py::arg("rows" ), py::arg("cols" ),
           py::arg("dt" ) = PrimitiveType::f32)
      .def(py::init<SparseMatrix &>())
      .def("to_string" , &SparseMatrix::to_string)
      .def("get_element" , &SparseMatrix::get_element<float32>)
      .def("set_element" , &SparseMatrix::set_element<float32>)
      .def("num_rows" , &SparseMatrix::num_rows)
      .def("num_cols" , &SparseMatrix::num_cols);
1197 | |
// Binds one Eigen-backed sparse matrix specialization, named
// "<VTYPE><STORAGE>_EigenSparseMatrix", for a scalar width (TYPE: 32/64),
// a storage order (STORAGE: ColMajor/RowMajor) and the matching Eigen dense
// vector suffix (VTYPE: f/d). Arithmetic operators are bound via py::self.
#define MAKE_SPARSE_MATRIX(TYPE, STORAGE, VTYPE)                           \
  using STORAGE##TYPE##EigenMatrix =                                       \
      Eigen::SparseMatrix<float##TYPE, Eigen::STORAGE>;                    \
  py::class_<EigenSparseMatrix<STORAGE##TYPE##EigenMatrix>, SparseMatrix>( \
      m, #VTYPE #STORAGE "_EigenSparseMatrix")                             \
      .def(py::init<int, int, DataType>())                                 \
      .def(py::init<EigenSparseMatrix<STORAGE##TYPE##EigenMatrix> &>())    \
      .def(py::init<const STORAGE##TYPE##EigenMatrix &>())                 \
      .def(py::self += py::self)                                           \
      .def(py::self + py::self)                                            \
      .def(py::self -= py::self)                                           \
      .def(py::self - py::self)                                            \
      .def(py::self *= float##TYPE())                                      \
      .def(py::self *float##TYPE())                                        \
      .def(float##TYPE() * py::self)                                       \
      .def(py::self *py::self)                                             \
      .def("matmul", &EigenSparseMatrix<STORAGE##TYPE##EigenMatrix>::matmul) \
      .def("spmv", &EigenSparseMatrix<STORAGE##TYPE##EigenMatrix>::spmv)   \
      .def("transpose",                                                    \
           &EigenSparseMatrix<STORAGE##TYPE##EigenMatrix>::transpose)      \
      .def("get_element",                                                  \
           &EigenSparseMatrix<STORAGE##TYPE##EigenMatrix>::get_element<    \
               float##TYPE>)                                               \
      .def("set_element",                                                  \
           &EigenSparseMatrix<STORAGE##TYPE##EigenMatrix>::set_element<    \
               float##TYPE>)                                               \
      .def("mat_vec_mul",                                                  \
           &EigenSparseMatrix<STORAGE##TYPE##EigenMatrix>::mat_vec_mul<    \
               Eigen::VectorX##VTYPE>);

  MAKE_SPARSE_MATRIX(32, ColMajor, f);
  MAKE_SPARSE_MATRIX(32, RowMajor, f);
  MAKE_SPARSE_MATRIX(64, ColMajor, d);
  MAKE_SPARSE_MATRIX(64, RowMajor, d);
1232 | |
  // CUDA sparse matrix.
  py::class_<CuSparseMatrix, SparseMatrix>(m, "CuSparseMatrix" )
      .def(py::init<int, int, DataType>())
      .def(py::init<const CuSparseMatrix &>())
      .def("spmv" , &CuSparseMatrix::spmv)
      .def(py::self + py::self)
      .def(py::self - py::self)
      .def(py::self * float32())
      .def(float32() * py::self)
      .def("matmul" , &CuSparseMatrix::matmul)
      .def("transpose" , &CuSparseMatrix::transpose)
      .def("get_element" , &CuSparseMatrix::get_element)
      .def("to_string" , &CuSparseMatrix::to_string);

  // Common sparse-solver interface.
  py::class_<SparseSolver>(m, "SparseSolver" )
      .def("compute" , &SparseSolver::compute)
      .def("analyze_pattern" , &SparseSolver::analyze_pattern)
      .def("factorize" , &SparseSolver::factorize)
      .def("info" , &SparseSolver::info);

// Binds one "EigenSparseSolver<dt><type><order>" class (dt: scalar type,
// type: factorization, order: ordering); fd selects the Eigen dense vector
// type used by solve()/solve_rf().
#define REGISTER_EIGEN_SOLVER(dt, type, order, fd)                        \
  py::class_<EigenSparseSolver##dt##type##order, SparseSolver>(           \
      m, "EigenSparseSolver" #dt #type #order)                            \
      .def("compute", &EigenSparseSolver##dt##type##order::compute)       \
      .def("analyze_pattern",                                             \
           &EigenSparseSolver##dt##type##order::analyze_pattern)          \
      .def("factorize", &EigenSparseSolver##dt##type##order::factorize)   \
      .def("solve",                                                       \
           &EigenSparseSolver##dt##type##order::solve<Eigen::VectorX##fd>) \
      .def("solve_rf",                                                    \
           &EigenSparseSolver##dt##type##order::solve_rf<Eigen::VectorX##fd, \
                                                         dt>)             \
      .def("info", &EigenSparseSolver##dt##type##order::info);

  REGISTER_EIGEN_SOLVER(float32, LLT, AMD, f)
  REGISTER_EIGEN_SOLVER(float32, LLT, COLAMD, f)
  REGISTER_EIGEN_SOLVER(float32, LDLT, AMD, f)
  REGISTER_EIGEN_SOLVER(float32, LDLT, COLAMD, f)
  REGISTER_EIGEN_SOLVER(float32, LU, AMD, f)
  REGISTER_EIGEN_SOLVER(float32, LU, COLAMD, f)
  REGISTER_EIGEN_SOLVER(float64, LLT, AMD, d)
  REGISTER_EIGEN_SOLVER(float64, LLT, COLAMD, d)
  REGISTER_EIGEN_SOLVER(float64, LDLT, AMD, d)
  REGISTER_EIGEN_SOLVER(float64, LDLT, COLAMD, d)
  REGISTER_EIGEN_SOLVER(float64, LU, AMD, d)
  REGISTER_EIGEN_SOLVER(float64, LU, COLAMD, d)

  // CUDA sparse solver.
  py::class_<CuSparseSolver, SparseSolver>(m, "CuSparseSolver" )
      .def("compute" , &CuSparseSolver::compute)
      .def("analyze_pattern" , &CuSparseSolver::analyze_pattern)
      .def("factorize" , &CuSparseSolver::factorize)
      .def("solve_rf" , &CuSparseSolver::solve_rf)
      .def("info" , &CuSparseSolver::info);

  // Factory functions for the CPU (Eigen) and CUDA solver variants.
  m.def("make_sparse_solver" , &make_sparse_solver);
  m.def("make_cusparse_solver" , &make_cusparse_solver);
1288 | |
  // Mesh Class
  // Mesh related.
  py::enum_<mesh::MeshTopology>(m, "MeshTopology" , py::arithmetic())
      .value("Triangle" , mesh::MeshTopology::Triangle)
      .value("Tetrahedron" , mesh::MeshTopology::Tetrahedron)
      .export_values();

  py::enum_<mesh::MeshElementType>(m, "MeshElementType" , py::arithmetic())
      .value("Vertex" , mesh::MeshElementType::Vertex)
      .value("Edge" , mesh::MeshElementType::Edge)
      .value("Face" , mesh::MeshElementType::Face)
      .value("Cell" , mesh::MeshElementType::Cell)
      .export_values();

  // Pairwise element relations; the two letters presumably match the
  // MeshElementType initials (e.g. VE: Vertex-Edge) -- see taichi/ir/mesh.h.
  py::enum_<mesh::MeshRelationType>(m, "MeshRelationType" , py::arithmetic())
      .value("VV" , mesh::MeshRelationType::VV)
      .value("VE" , mesh::MeshRelationType::VE)
      .value("VF" , mesh::MeshRelationType::VF)
      .value("VC" , mesh::MeshRelationType::VC)
      .value("EV" , mesh::MeshRelationType::EV)
      .value("EE" , mesh::MeshRelationType::EE)
      .value("EF" , mesh::MeshRelationType::EF)
      .value("EC" , mesh::MeshRelationType::EC)
      .value("FV" , mesh::MeshRelationType::FV)
      .value("FE" , mesh::MeshRelationType::FE)
      .value("FF" , mesh::MeshRelationType::FF)
      .value("FC" , mesh::MeshRelationType::FC)
      .value("CV" , mesh::MeshRelationType::CV)
      .value("CE" , mesh::MeshRelationType::CE)
      .value("CF" , mesh::MeshRelationType::CF)
      .value("CC" , mesh::MeshRelationType::CC)
      .export_values();

  // Index conversion kinds (l2g/l2r/g2r; see taichi/ir/mesh.h for semantics).
  py::enum_<mesh::ConvType>(m, "ConvType" , py::arithmetic())
      .value("l2g" , mesh::ConvType::l2g)
      .value("l2r" , mesh::ConvType::l2r)
      .value("g2r" , mesh::ConvType::g2r)
      .export_values();

  py::class_<mesh::Mesh>(m, "Mesh" );        // NOLINT(bugprone-unused-raii)
  py::class_<mesh::MeshPtr>(m, "MeshPtr" );  // NOLINT(bugprone-unused-raii)

  m.def("element_order" , mesh::element_order);
  m.def("from_end_element_order" , mesh::from_end_element_order);
  m.def("to_end_element_order" , mesh::to_end_element_order);
  m.def("relation_by_orders" , mesh::relation_by_orders);
  m.def("inverse_relation" , mesh::inverse_relation);
  m.def("element_type_name" , mesh::element_type_name);

  // Creates a Mesh owned by a shared_ptr and wraps it in a MeshPtr handle.
  m.def(
      "create_mesh" ,
      []() {
        auto mesh_shared = std::make_shared<mesh::Mesh>();
        mesh::MeshPtr mesh_ptr = mesh::MeshPtr{mesh_shared};
        return mesh_ptr;
      },
      // NOTE(review): the lambda returns MeshPtr by value, so the reference
      // return policy looks unnecessary here -- confirm before changing.
      py::return_value_policy::reference);
1346 | |
1347 | // ad-hoc setters |
1348 | m.def("set_owned_offset" , |
1349 | [](mesh::MeshPtr &mesh_ptr, mesh::MeshElementType type, SNode *snode) { |
1350 | mesh_ptr.ptr->owned_offset.insert(std::pair(type, snode)); |
1351 | }); |
1352 | m.def("set_total_offset" , |
1353 | [](mesh::MeshPtr &mesh_ptr, mesh::MeshElementType type, SNode *snode) { |
1354 | mesh_ptr.ptr->total_offset.insert(std::pair(type, snode)); |
1355 | }); |
1356 | m.def("set_num_patches" , [](mesh::MeshPtr &mesh_ptr, int num_patches) { |
1357 | mesh_ptr.ptr->num_patches = num_patches; |
1358 | }); |
1359 | |
1360 | m.def("set_num_elements" , [](mesh::MeshPtr &mesh_ptr, |
1361 | mesh::MeshElementType type, int num_elements) { |
1362 | mesh_ptr.ptr->num_elements.insert(std::pair(type, num_elements)); |
1363 | }); |
1364 | |
1365 | m.def("get_num_elements" , |
1366 | [](mesh::MeshPtr &mesh_ptr, mesh::MeshElementType type) { |
1367 | return mesh_ptr.ptr->num_elements.find(type)->second; |
1368 | }); |
1369 | |
1370 | m.def("set_patch_max_element_num" , |
1371 | [](mesh::MeshPtr &mesh_ptr, mesh::MeshElementType type, |
1372 | int max_element_num) { |
1373 | mesh_ptr.ptr->patch_max_element_num.insert( |
1374 | std::pair(type, max_element_num)); |
1375 | }); |
1376 | |
1377 | m.def("set_index_mapping" , |
1378 | [](mesh::MeshPtr &mesh_ptr, mesh::MeshElementType element_type, |
1379 | mesh::ConvType conv_type, SNode *snode) { |
1380 | mesh_ptr.ptr->index_mapping.insert( |
1381 | std::make_pair(std::make_pair(element_type, conv_type), snode)); |
1382 | }); |
1383 | |
1384 | m.def("set_relation_fixed" , |
1385 | [](mesh::MeshPtr &mesh_ptr, mesh::MeshRelationType type, SNode *value) { |
1386 | mesh_ptr.ptr->relations.insert( |
1387 | std::pair(type, mesh::MeshLocalRelation(value))); |
1388 | }); |
1389 | |
1390 | m.def("set_relation_dynamic" , |
1391 | [](mesh::MeshPtr &mesh_ptr, mesh::MeshRelationType type, SNode *value, |
1392 | SNode *patch_offset, SNode *offset) { |
1393 | mesh_ptr.ptr->relations.insert(std::pair( |
1394 | type, mesh::MeshLocalRelation(value, patch_offset, offset))); |
1395 | }); |
1396 | |
  // Spins until a native debugger attaches (Windows-only; no-op elsewhere).
  m.def("wait_for_debugger" , []() {
#ifdef WIN32
    while (!::IsDebuggerPresent())
      ::Sleep(100);
#endif
  });
1403 | } |
1404 | |
1405 | } // namespace taichi |
1406 | |