1 | #include "triton/codegen/pass.h" |
2 | #include "triton/codegen/target.h" |
3 | #include "triton/codegen/extern_lib.h" |
4 | #include "triton/driver/error.h" |
5 | #include "triton/driver/llvm.h" |
6 | #include "triton/ir/builder.h" |
7 | #include "triton/ir/enums.h" |
8 | #include "triton/ir/function.h" |
9 | #include "triton/ir/module.h" |
10 | #include "triton/ir/print.h" |
11 | #include <optional> |
12 | #include <pybind11/buffer_info.h> |
13 | #include <pybind11/functional.h> |
14 | #include <pybind11/pybind11.h> |
15 | #include <pybind11/stl_bind.h> |
16 | #include <pybind11/stl.h> |
17 | #include "Python.h" |
18 | #include <regex> |
19 | #include <sstream> |
20 | #include <stdexcept> |
21 | #include <string> |
22 | #include "llvm/IR/Module.h" |
23 | #include "llvm/IR/Verifier.h" |
24 | |
25 | namespace py = pybind11; |
26 | namespace ir = triton::ir; |
27 | namespace drv = triton::driver; |
28 | |
29 | |
30 | /*****************************************************************************/ |
31 | /* Python bindings for triton::driver */ |
32 | /*****************************************************************************/ |
33 | // information query |
34 | template<CUdevice_attribute attr> |
35 | int cuGetInfo(CUdevice device) { |
36 | int res; |
37 | drv::dispatch::cuDeviceGetAttribute(&res, attr, device); |
38 | return res; |
39 | } |
40 | |
41 | template<hipDeviceAttribute_t attr> |
42 | int hipGetInfo(hipDevice_t device) { |
43 | int res; |
44 | drv::dispatch::hipDeviceGetAttribute(&res, attr, device); |
45 | return res; |
46 | } |
47 | |
// Device backend a kernel is compiled for / launched on.
enum backend_t {
  HOST, // run on the CPU host (the launch path currently throws "unsupported")
  CUDA, // NVIDIA GPUs via the CUDA driver API
  ROCM, // AMD GPUs via HIP
};
53 | |
54 | void cu_enable_peer_access(uint64_t peer_ptr){ |
55 | CUcontext context; |
56 | drv::dispatch::cuPointerGetAttribute(&context, CU_POINTER_ATTRIBUTE_CONTEXT, peer_ptr); |
57 | try { |
58 | drv::dispatch::cuCtxEnablePeerAccess(context, 0); |
59 | } catch (drv::exception::cuda::peer_access_already_enabled) {} |
60 | } |
61 | |
// CPU-backend launch stub. A thread-pool based implementation is sketched
// below but disabled, so any HOST launch is rejected at runtime.
void host_enqueue(uint64_t stream, uint64_t kernel,
                  uint64_t grid_0, uint64_t grid_1, uint64_t grid_2,
                  uint64_t block_0, uint64_t block_1, uint64_t block_2,
                  void* args_ptr, size_t args_size, int64_t shared_mem){
  throw std::runtime_error("unsupported");
  // Disabled prototype of the host path, kept for reference:
  // auto hst = kernel->module()->hst();
  // hst_->futures->reserve(hst_->futures->size() + grid[0]*grid[1]*grid[2]);
  // char* params = new char[args_size];
  // std::memcpy((void*)params, (void*)args, args_size);
  // for(size_t i = 0; i < grid[0]; i++)
  //   for(size_t j = 0; j < grid[1]; j++)
  //     for(size_t k = 0; k < grid[2]; k++)
  //       hst_->futures->emplace_back(hst_->pool->enqueue(hst->fn, (char**)params, int32_t(i), int32_t(j), int32_t(k)));
}
76 | |
77 | void cu_enqueue(uint64_t stream, uint64_t kernel, |
78 | uint64_t grid_0, uint64_t grid_1, uint64_t grid_2, |
79 | uint64_t block_0, uint64_t block_1, uint64_t block_2, |
80 | void* args_ptr, size_t args_size, int64_t shared_mem){ |
81 | void *config[] = { |
82 | CU_LAUNCH_PARAM_BUFFER_POINTER, (void*)args_ptr, |
83 | CU_LAUNCH_PARAM_BUFFER_SIZE, &args_size, |
84 | CU_LAUNCH_PARAM_END |
85 | }; |
86 | drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2, |
87 | block_0, block_1, block_2, |
88 | shared_mem, (CUstream)stream, nullptr, config); |
89 | } |
90 | |
91 | void hip_enqueue(uint64_t stream, uint64_t kernel, |
92 | uint64_t grid_0, uint64_t grid_1, uint64_t grid_2, |
93 | uint64_t block_0, uint64_t block_1, uint64_t block_2, |
94 | void* args_ptr, size_t args_size, int64_t shared_mem) { |
95 | void *config[] = { |
96 | HIP_LAUNCH_PARAM_BUFFER_POINTER, (void*)args_ptr, |
97 | HIP_LAUNCH_PARAM_BUFFER_SIZE, &args_size, |
98 | HIP_LAUNCH_PARAM_END |
99 | }; |
100 | drv::dispatch::hipModuleLaunchKernel((hipFunction_t)kernel, grid_0, grid_1, grid_2, |
101 | block_0, block_1, block_2, |
102 | shared_mem, (hipStream_t)stream, nullptr, config); |
103 | |
104 | } |
105 | |
// Largest power of two (capped at 16) that evenly divides N.
// Used to specialize kernels on argument alignment/divisibility.
long pow2_divisor(long N){
  for (long candidate : {16L, 8L, 4L, 2L})
    if (N % candidate == 0)
      return candidate;
  return 1;
}
113 | |
// Returns something like "int16", whether dtype is a torch.dtype or
// triton.language.dtype. The result becomes part of the kernel cache key.
std::string dtype_cache_key_part(const py::object& dtype) {
  if (py::hasattr(dtype, "cache_key_part" )) {
    // Presumed to be a triton.language.dtype.
    return std::string(py::str(py::getattr(dtype, "cache_key_part" )));
  } else {
    // Remove 'torch.' prefix from repr of torch.dtype.
    // NOTE(review): PyUnicode_1BYTE_DATA assumes repr is a compact 1-byte
    // (latin-1/ASCII) unicode object — true for torch dtype reprs, but it
    // would misread wider unicode; confirm no other inputs reach here.
    py::object repr = py::repr(dtype);
    size_t repr_len = PyUnicode_GET_LENGTH(repr.ptr());
    const char* repr_ptr = (const char*)PyUnicode_1BYTE_DATA(repr.ptr());
    // strncmp != 0 means the repr does not start with "torch." — reject.
    if (repr_len <= 6 || strncmp(repr_ptr, "torch." , 6)) {
      throw std::logic_error("invalid dtype: " + std::string(repr_ptr, repr_len));
    }
    return std::string(repr_ptr + 6, repr_len - 6);
  }
}
131 | |
132 | size_t get_pointer_range_size(uint64_t addr){ |
133 | if(addr == 0) |
134 | return 0; |
135 | size_t size; |
136 | drv::dispatch::cuPointerGetAttribute(&size, CU_POINTER_ATTRIBUTE_RANGE_SIZE, (CUdeviceptr)addr); |
137 | return size; |
138 | } |
139 | |
140 | // Launch |
141 | void parse_args(py::list& args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names, |
142 | std::string& cache_key, std::string& params, size_t& params_size, py::dict constants, |
143 | int num_warps, int num_stages, py::dict& extern_libs) { |
144 | size_t len = PyList_Size(args.ptr()); |
145 | params.reserve(8*len); // 8 max bytes by argument |
146 | char* params_ptr = ¶ms[0]; |
147 | cache_key = func_key; |
148 | cache_key += "-" + std::to_string(num_warps); |
149 | cache_key += "-" + std::to_string(num_stages); |
150 | cache_key += "-" ; |
151 | for(int i = 0; i < len; i++){ |
152 | cache_key += "_" ; |
153 | py::int_ py_i = py::int_(i); |
154 | bool specialize = !do_not_specialize.contains(py_i); |
155 | py::object arg = args[i]; |
156 | auto arg_ptr = arg.ptr(); |
157 | |
158 | // argument is `long` |
159 | if(PyLong_Check(arg_ptr)){ |
160 | int overflow; |
161 | long long value = PyLong_AsLongLongAndOverflow(arg_ptr, &overflow); |
162 | // values equal to 1 are specialized |
163 | if(specialize && (value == 1)){ |
164 | cache_key += "1" ; |
165 | continue; |
166 | } |
167 | // int32, uint32, int64, and uint64 have different kernels |
168 | if (!overflow && -0x8000'0000LL <= value && value <= 0x7FFF'FFFFLL) { |
169 | cache_key += "int32" ; |
170 | params_ptr = (char*)(((uintptr_t)params_ptr + 3) & (-4)); |
171 | std::memcpy(params_ptr, &value, 4); |
172 | params_ptr += 4; |
173 | } else if (!overflow && 0x8000'0000LL <= value && value <= 0xFFFF'FFFFLL) { |
174 | cache_key += "uint32" ; |
175 | params_ptr = (char*)(((uintptr_t)params_ptr + 3) & (-4)); |
176 | std::memcpy(params_ptr, &value, 4); |
177 | params_ptr += 4; |
178 | } else if (!overflow) { |
179 | cache_key += "int64" ; |
180 | params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8)); |
181 | std::memcpy(params_ptr, &value, 8); |
182 | params_ptr += 8; |
183 | } else { |
184 | if (PyErr_Occurred()) { |
185 | throw std::logic_error("An error occurred?" ); |
186 | } |
187 | unsigned long long unsigned_value = PyLong_AsUnsignedLongLong(arg_ptr); |
188 | if (PyErr_Occurred()) { |
189 | throw std::runtime_error("integer overflow in argument: " + std::string(py::str(arg))); |
190 | } |
191 | cache_key += "uint64" ; |
192 | params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8)); |
193 | std::memcpy(params_ptr, &unsigned_value, 8); |
194 | params_ptr += 8; |
195 | } |
196 | if(!specialize) |
197 | continue; |
198 | // values divisible by small powers of 2 are specialized |
199 | cache_key += "[multipleof(" ; |
200 | cache_key += std::to_string(pow2_divisor(value)); |
201 | cache_key += ")]" ; |
202 | continue; |
203 | } |
204 | // argument is `float` |
205 | if(PyFloat_Check(arg_ptr)){ |
206 | cache_key += "float32" ; |
207 | float value = PyFloat_AsDouble(arg_ptr); |
208 | params_ptr = (char*)(((uintptr_t)params_ptr + 3) & (-4)); |
209 | std::memcpy(params_ptr, &value, 4); |
210 | params_ptr += 4; |
211 | continue; |
212 | } |
213 | // argument is `bool` |
214 | if(PyBool_Check(arg_ptr)){ |
215 | cache_key += "bool" ; |
216 | bool value = arg_ptr == Py_True ? true : false; |
217 | std::memcpy(params_ptr, &value, 1); |
218 | params_ptr += 1; |
219 | continue; |
220 | } |
221 | // argument is tensor |
222 | if(py::hasattr(arg, "data_ptr" )){ |
223 | py::object data_ptr = arg.attr("data_ptr" )(); |
224 | long value = data_ptr.cast<long>(); |
225 | params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8)); |
226 | // copy param |
227 | std::memcpy(params_ptr, &value, 8); |
228 | params_ptr += 8; |
229 | // update cache key |
230 | cache_key += dtype_cache_key_part(arg.attr("dtype" )); |
231 | cache_key += "*" ; |
232 | cache_key += "[multipleof(" ; |
233 | size_t range_size; |
234 | try { |
235 | range_size = get_pointer_range_size(value); |
236 | } catch (...) { |
237 | throw std::runtime_error("argument tensor #" + std::to_string(i) + " is not on cuda! " + std::string(py::str(arg))); |
238 | } |
239 | cache_key += std::to_string(std::min(pow2_divisor(value), pow2_divisor(range_size))); |
240 | cache_key += ")]" ; |
241 | continue; |
242 | } |
243 | // argument is `constexpr` |
244 | if (py::hasattr(arg, "value" )) { |
245 | py::object value = arg.attr("value" ); |
246 | // check if value is a callable object using PyCallable_Check |
247 | if (PyCallable_Check(value.ptr())) { |
248 | throw std::runtime_error( |
249 | "constant argument cannot be a callable object: " + |
250 | std::string(py::str(arg))); |
251 | } |
252 | py::object name = arg_names[i]; |
253 | constants[name] = value; |
254 | py::object repr = py::repr(value); |
255 | const char* start = (const char*)PyUnicode_1BYTE_DATA(repr.ptr()); |
256 | size_t len = PyUnicode_GET_LENGTH(repr.ptr()); |
257 | cache_key += std::string(start, len); |
258 | continue; |
259 | } |
260 | std::string ty_str = arg.attr("__class__" ).attr("__name__" ).cast<std::string>(); |
261 | if(ty_str == "NoneType" ){ |
262 | cache_key += "None" ; |
263 | continue; |
264 | } |
265 | std::string err_msg = "Received type '" + ty_str + "' for argument " + std::to_string(i) + "." |
266 | + " Only int, float, bool, torch.Tensor, and triton.language.constexpr are supported." ; |
267 | throw std::runtime_error(err_msg); |
268 | } |
269 | params_size = (std::ptrdiff_t)(params_ptr - ¶ms[0]); |
270 | |
271 | for (auto item : extern_libs) { |
272 | cache_key += "-" + item.first.cast<std::string>(); |
273 | cache_key += "_" + item.second.cast<std::string>(); |
274 | } |
275 | } |
276 | |
277 | // |
278 | |
// Registers the runtime bindings on module `m`:
//   - the `backend` enum (HOST / CUDA / ROCM)
//   - device-property queries (cc, max_shared_memory, clock rates, num_sm, ...)
//   - kernel-launch entry points (`launch` and `enqueue`)
void init_triton_runtime(py::module &&m) {

  // m.def("current_stream", [](uint64_t device){
  //   return (uint64_t)(c10::cuda::getCurrentCUDAStream(device).stream());
  // });

  // wrap backend_t
  py::enum_<backend_t>(m, "backend")
      .value("HOST", HOST)
      .value("CUDA", CUDA)
      .value("ROCM", ROCM)
      .export_values();

  // enable peer-to-peer
  m.def("enable_peer_access", [](backend_t backend, uint64_t peer_ptr) {
      if (backend != CUDA)
        throw std::runtime_error("P2P only supported on CUDA devices!");
      cu_enable_peer_access(peer_ptr);
    }
  );

  // get range size for the given pointer
  m.def("get_pointer_range_size", &get_pointer_range_size);


  // cache key
  // launch(...):
  //   1. packs the arguments and computes the specialization cache key
  //      (parse_args);
  //   2. on a cache miss, calls back into python (`add_to_cache`) to compile
  //      the binary; if that reports a no-op kernel, returns None;
  //   3. resolves the grid — either a sequence of up to 3 ints, or a
  //      callable mapping the constexpr dict to such a sequence;
  //   4. launches the kernel with num_warps*32 threads per block.
  // Returns the cached binary object (or None for a no-op).
  m.def("launch", [](py::list args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names,
                     py::object device, py::int_ stream, py::dict bin_cache, py::int_ num_warps, py::int_ num_stages,
                     py::dict extern_libs, py::function add_to_cache, py::object grid){
    // parse arguments to compute cache key, compile-time constants and packed kernel arguments
    long _num_warps = PyLong_AsLong(num_warps.ptr());
    long _num_stages = PyLong_AsLong(num_stages.ptr());
    std::string cache_key;
    std::string params;
    size_t params_size;
    py::dict constants;
    parse_args(args, do_not_specialize, func_key, arg_names, cache_key, params,
               params_size, constants, _num_warps, _num_stages, extern_libs);

    // get cached binary
    py::str key(cache_key);
    py::bool_ noop = false;
    if(!bin_cache.contains(key)) {
      noop = add_to_cache(key, args, device, num_warps, num_stages, extern_libs);
    }
    if (noop)
      return (py::object)py::none();
    py::object bin = bin_cache[key];

    // get grid
    py::sequence seq;
    if(!PySequence_Check(grid.ptr()))
      seq = grid(constants); // callable grid: evaluate with the constexprs
    else
      seq = grid;
    // missing trailing dimensions default to 1
    int size = seq.size();
    int grid_0 = py::cast<int>(seq[0]);
    int grid_1 = size < 2 ? 1 : py::cast<int>(seq[1]);
    int grid_2 = size < 3 ? 1 : py::cast<int>(seq[2]);

    // enqueue
    uint64_t kernel = py::cast<uint64_t>(bin.attr("kernel"));
    uint64_t shared_mem = py::cast<uint64_t>(bin.attr("shared_mem"));

    // actually launch
    void *config[] = {
      CU_LAUNCH_PARAM_BUFFER_POINTER, params.data(),
      CU_LAUNCH_PARAM_BUFFER_SIZE, &params_size,
      CU_LAUNCH_PARAM_END
    };
    // NOTE(review): the stream handle is read as a signed long then cast to
    // uint64_t — presumably handles fit in a long on this platform; confirm.
    uint64_t _stream = PyLong_AsLong(stream.ptr());
    if(grid_0*grid_1*grid_2 > 0) {
      // release the gil in case the enqueue blocks
      // cuda will block if too many ops are enqueued
      py::gil_scoped_release allow_threads;
      drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
                                    _num_warps*32, 1, 1, shared_mem, (CUstream)_stream,
                                    nullptr, config);
    }
    return bin;
  });

  // compute capability as a single int, e.g. 80 for sm_80; -1 for non-CUDA
  m.def("cc", [](backend_t backend, uint64_t device) -> int {
    if (backend == CUDA) {
      CUdevice dev = (CUdevice)device;
      int major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev);
      int minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
      return major*10 + minor;
    }
    return -1;
  });

  // query maximum shared memory
  m.def("max_shared_memory", [](backend_t backend, uint64_t device) {
    if (backend == HOST)
      return 0;
    if(backend == CUDA)
      return cuGetInfo<CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN>(device);
    if(backend == ROCM)
      return hipGetInfo<hipDeviceAttributeMaxSharedMemoryPerBlock>(device);
    return -1;
  });

  // query DRAM & L2 cache
  m.def("memory_clock_rate", [](backend_t backend, uint64_t device) {
    if (backend == CUDA) return cuGetInfo<CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE>(device);
    return -1;
  });
  m.def("global_memory_bus_width", [](backend_t backend, uint64_t device) {
    if (backend == CUDA) return cuGetInfo<CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH>(device);
    return -1;
  });
  m.def("l2_cache_size", [](backend_t backend, uint64_t device) {
    if (backend == CUDA) return cuGetInfo<CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE>(device);
    return -1;
  });

  // query clock rate (in kilohertz)
  m.def("clock_rate", [](backend_t backend, uint64_t device) {
    if (backend == CUDA) return cuGetInfo<CU_DEVICE_ATTRIBUTE_CLOCK_RATE>(device);
    return -1;
  });

  // query number of streaming multiprocessors
  m.def("num_sm", [](backend_t backend, uint64_t device) {
    if (backend == CUDA) return cuGetInfo<CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT>(device);
    return -1;
  });

  // enqueue
  // Generic launch entry point: dispatches to the backend-specific helper.
  // `args` is a pre-packed argument buffer (see parse_args for its layout).
  m.def("enqueue", [](backend_t backend, uint64_t stream, uint64_t kernel,
                      uint64_t grid_0, uint64_t grid_1, uint64_t grid_2,
                      uint64_t block_0, uint64_t block_1, uint64_t block_2,
                      const std::string &args, int64_t shared_mem){
    void* args_ptr = (void*)args.data();
    size_t args_size = args.size();
    // release the gil in case the enqueue blocks
    // cuda will block if too many ops are enqueued
    py::gil_scoped_release allow_threads;
    if(backend == HOST)
      host_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem);
    if(backend == CUDA)
      cu_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem);
    if(backend == ROCM)
      hip_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem);
  });


}
427 | |
428 | /*****************************************************************************/ |
429 | /* Python bindings for triton::codegen */ |
430 | /*****************************************************************************/ |
431 | typedef std::map<std::string, py::object> asm_map_t; |
432 | |
433 | // --------------------------------------- |
434 | // Compile Triton-IR to assembly |
435 | // --------------------------------------- |
436 | |
// Registers the codegen bindings on module `m`:
//   - compile_ttir: Triton-IR -> LLVM-IR -> PTX -> cubin pipeline
//   - load_binary:  load a cubin into the driver and configure shared memory
//   - instance_descriptor: per-argument specialization metadata
void init_triton_codegen(py::module &&m) {
  // compile_ttir(backend, ir, device, num_warps, num_stages, extern_libs, cc)
  // Returns (kernel_name, {stage -> asm}, shared_memory_bytes).
  // `cc` == 0 means "query the device's compute capability"; a nonzero value
  // is trusted as-is.
  m.def("compile_ttir",
      [](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, py::dict& extern_libs, size_t cc) {
        std::ostringstream ttir;    // Triton-IR text
        int n_shared_bytes;
        std::string tmp;            // LLVM-IR text
        std::string ptx;
        std::string cubin;
        std::string name;           // name of the first (kernel) function
        { // Scope where the GIL is released
          py::gil_scoped_release allow_threads;
          name = ir.get_function_list()[0]->get_name();
          ir.print(ttir);
          llvm::LLVMContext ctx;
          // construct extern lib map
          triton::codegen::ExternLibMap extern_lib_map;
          for (auto item : extern_libs) {
            // NOTE: this `name` shadows the outer kernel-name variable.
            auto name = item.first.cast<std::string>();
            auto path = item.second.cast<std::string>();
            extern_lib_map.emplace(
                name, triton::codegen::create_extern_lib(name, path));
          }
          // device properties
          if (cc == 0) {
            CUdevice dev = (CUdevice)device;
            size_t major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev);
            size_t minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
            cc = major*10 + minor;
          }
          int version;
          std::string ptxas_path = drv::path_to_ptxas(version);
          // Triton-IR -> NVPTX LLVM-IR
          triton::codegen::nvidia_cu_target target(cc);
          auto llvm = triton::codegen::add_passes_to_emit_bin(
              ir, ctx, &target, num_warps, num_stages, n_shared_bytes, extern_lib_map);
          llvm::raw_string_ostream llir(tmp);
          llir << *llvm;
          llir.flush();
          // LLVM-IR -> PTX
          ptx = drv::llir_to_ptx(llvm.get(), cc, version);
          // PTX -> Binary
          cubin = drv::ptx_to_cubin(ptx, ptxas_path, cc);
        }
        // Expose every intermediate representation to python.
        asm_map_t asm_map;
        asm_map["ttir"] = py::cast(ttir.str());
        asm_map["llir"] = py::cast(tmp);
        asm_map["ptx"] = py::cast(ptx);

        if(!cubin.empty()){
          py::bytes bytes(cubin);
          asm_map["cubin"] = bytes;
        }
        return std::make_tuple(name, asm_map, n_shared_bytes);
      },
      py::return_value_policy::take_ownership);


  // ---------------------------------------
  // Load provided assembly code into driver
  // ---------------------------------------
  // load_binary(name, data, n_shared_bytes, device)
  // Returns (module_handle, function_handle, n_registers, n_spill_slots).
  m.def("load_binary", [](const std::string& name, const std::string& data, size_t n_shared_bytes, uint64_t device){
    py::gil_scoped_release allow_threads;
    // create driver handles
    CUfunction fun;
    CUmodule mod;
    drv::dispatch::cuModuleLoadData(&mod, data.c_str());
    drv::dispatch::cuModuleGetFunction(&fun, mod, name.c_str());
    // get allocated registers and spilled registers from the function
    int n_regs = 0;
    int n_spills = 0;
    drv::dispatch::cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun);
    drv::dispatch::cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun);
    n_spills /= 4; // local bytes -> number of 4-byte spill slots
    // set dynamic shared memory if necessary
    // 49152 (48KB) is the default per-block limit; opt in to more only when
    // the kernel needs it and the device supports it.
    int shared_optin;
    drv::dispatch::cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device);
    if(n_shared_bytes > 49152 && shared_optin > 49152){
      drv::dispatch::cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED);
      int shared_total, shared_static;
      drv::dispatch::cuDeviceGetAttribute(&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, device);
      drv::dispatch::cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun);
      // dynamic limit = opt-in maximum minus the kernel's static usage
      drv::dispatch::cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static);
    }
    return std::make_tuple((uint64_t)mod, (uint64_t)fun, (uint64_t)n_regs, (uint64_t)n_spills);
  },
  py::return_value_policy::take_ownership
  );


  // Per-kernel-instance specialization metadata: which argument indices are
  // divisible by 16, and which are the constant 1.
  struct InstanceDescriptor
  {
    std::unordered_set<int> divisibleBy16;
    std::unordered_set<int> equalTo1;
  };

  py::class_<InstanceDescriptor>(m, "instance_descriptor")
      .def(py::init<>())
      .def(py::init<std::unordered_set<int>, std::unordered_set<int>>())
      .def_readonly("divisible_by_16", &InstanceDescriptor::divisibleBy16)
      .def_readonly("equal_to_1", &InstanceDescriptor::equalTo1);
}
538 | |
539 | |
540 | /*****************************************************************************/ |
541 | /* Python bindings for triton::ir */ |
542 | /*****************************************************************************/ |
543 | |
544 | void init_triton_ir(py::module &&m) { |
545 | using ret = py::return_value_policy; |
546 | using namespace pybind11::literals; |
547 | |
548 | py::enum_<ir::load_inst::CACHE_MODIFIER>(m, "CACHE_MODIFIER" ) |
549 | .value("NONE" , ir::load_inst::NONE) |
550 | .value("CA" , ir::load_inst::CA) |
551 | .value("CG" , ir::load_inst::CG) |
552 | .export_values(); |
553 | |
554 | py::enum_<ir::load_inst::EVICTION_POLICY>(m, "EVICTION_POLICY" ) |
555 | .value("NORMAL" , ir::load_inst::NORMAL) |
556 | .value("EVICT_FIRST" , ir::load_inst::EVICT_FIRST) |
557 | .value("EVICT_LAST" , ir::load_inst::EVICT_LAST) |
558 | .export_values(); |
559 | |
560 | py::enum_<ir::reduce_inst::op_t>(m, "REDUCE_OP" ) |
561 | .value("ADD" , ir::reduce_inst::ADD) |
562 | .value("FADD" , ir::reduce_inst::FADD) |
563 | .value("MIN" , ir::reduce_inst::MIN) |
564 | .value("MAX" , ir::reduce_inst::MAX) |
565 | .value("UMIN" , ir::reduce_inst::UMIN) |
566 | .value("UMAX" , ir::reduce_inst::UMAX) |
567 | .value("ARGMIN" , ir::reduce_inst::ARGMIN) |
568 | .value("ARGMAX" , ir::reduce_inst::ARGMAX) |
569 | .value("ARGUMIN" , ir::reduce_inst::ARGUMIN) |
570 | .value("ARGUMAX" , ir::reduce_inst::ARGUMAX) |
571 | .value("FMIN" , ir::reduce_inst::FMIN) |
572 | .value("FMAX" , ir::reduce_inst::FMAX) |
573 | .value("ARGFMIN" , ir::reduce_inst::ARGFMIN) |
574 | .value("ARGFMAX" , ir::reduce_inst::ARGFMAX) |
575 | .value("XOR" , ir::reduce_inst::XOR); |
576 | |
577 | py::enum_<ir::atomic_rmw_op_t>(m, "ATOMIC_OP" ) |
578 | .value("ADD" , ir::atomic_rmw_op_t::Add) |
579 | .value("FADD" , ir::atomic_rmw_op_t::FAdd) |
580 | .value("AND" , ir::atomic_rmw_op_t::And) |
581 | .value("OR" , ir::atomic_rmw_op_t::Or) |
582 | .value("XOR" , ir::atomic_rmw_op_t::Xor) |
583 | .value("XCHG" , ir::atomic_rmw_op_t::Xchg) |
584 | .value("MAX" , ir::atomic_rmw_op_t::Max) |
585 | .value("MIN" , ir::atomic_rmw_op_t::Min) |
586 | .value("UMIN" , ir::atomic_rmw_op_t::UMin) |
587 | .value("UMAX" , ir::atomic_rmw_op_t::UMax); |
588 | |
589 | py::class_<ir::context>(m, "context" ) |
590 | .def(py::init<>()); |
591 | |
592 | py::class_<ir::value>(m, "value" ) |
593 | .def("multiple_of" , [](ir::value *self, std::vector<unsigned> val) { |
594 | if (auto *instr = dynamic_cast<ir::instruction*>(self)) { |
595 | instr->set_metadata(ir::metadata::multiple_of, val); |
596 | } else |
597 | throw std::runtime_error("multiple_of" ); |
598 | }) |
599 | .def("max_contiguous" , [](ir::value *self, std::vector<unsigned> val) { |
600 | if (auto *instr = dynamic_cast<ir::instruction*>(self)) { |
601 | instr->set_metadata(ir::metadata::max_contiguous, val); |
602 | } else |
603 | throw std::runtime_error("max_contiguous" ); |
604 | }) |
605 | .def("set_fdiv_ieee_rounding" , [](ir::value *self, bool val) { |
606 | if (auto *instr = dynamic_cast<ir::binary_operator*>(self)) |
607 | instr->set_fdiv_ieee_rounding(val); |
608 | else |
609 | throw std::runtime_error("set_fdiv_ieee_rounding" ); |
610 | }) |
611 | .def("is_phi" , [](ir::value *self) { |
612 | if (auto *pn = dynamic_cast<ir::phi_node*>(self)) |
613 | return true; |
614 | return false; |
615 | }) |
616 | .def("ops" , [](ir::value *self) { |
617 | if (auto *instr = dynamic_cast<ir::instruction*>(self)) { |
618 | return instr->ops(); |
619 | } |
620 | throw std::runtime_error("cannot use ops()" ); |
621 | }) |
622 | .def("replace_all_uses_with" , &ir::value::replace_all_uses_with) |
623 | .def("erase_from_parent" , [](ir::value *self) { |
624 | if (auto *instr = dynamic_cast<ir::instruction*>(self)) |
625 | return instr->erase_from_parent(); |
626 | throw std::runtime_error("cannot use erase_from_parent" ); |
627 | }) |
628 | .def_property("name" , &ir::value::get_name, &ir::value::set_name) |
629 | .def_property_readonly("type" , &ir::value::get_type); |
630 | |
631 | py::class_<ir::user, ir::value>(m, "user" ); |
632 | |
633 | py::class_<ir::constant, ir::user>(m, "constant" ) |
634 | .def("get_null_value" , &ir::constant::get_null_value, ret::reference) |
635 | .def("get_all_ones_value" , &ir::constant::get_all_ones_value, ret::reference); |
636 | |
637 | py::class_<ir::undef_value, ir::constant>(m, "undef" ) |
638 | .def("get" , &ir::undef_value::get, ret::reference); |
639 | |
640 | py::class_<ir::constant_int, ir::constant>(m, "constant_int" ) |
641 | .def_property_readonly("value" , &ir::constant_int::get_value) |
642 | .def("__int__" , [](ir::constant_int *self) { return self->get_value(); }) |
643 | .def("__bool__" , [](ir::constant_int *self) { return self->get_value(); }); |
644 | |
645 | py::class_<ir::constant_fp, ir::constant>(m, "constant_float" ) |
646 | .def_property_readonly("value" , &ir::constant_fp::get_value) |
647 | .def("get" , [](ir::type* ty, double val) { return ir::constant_fp::get(ty, val); }, ret::reference); |
648 | |
649 | py::class_<ir::instruction, ir::user>(m, "instruction" ) |
650 | .def("get_parent" , [](ir::instruction *self) { |
651 | return self->get_parent(); |
652 | }, ret::reference); |
653 | py::class_<ir::phi_node, ir::instruction>(m, "phi_node" ) |
654 | .def("add_incoming" , &ir::phi_node::add_incoming); |
655 | |
656 | py::class_<ir::type>(m, "type" ) |
657 | .def("make_ptr" , &ir::pointer_type::get, ret::reference) |
658 | .def("make_function" , &ir::function_type::get, ret::reference) |
659 | .def("make_block" , &ir::block_type::get, ret::reference) |
660 | .def("get_void" , &ir::type::get_void_ty, ret::reference) |
661 | .def("get_fp8" , &ir::type::get_fp8_ty, ret::reference) |
662 | .def("get_fp16" , &ir::type::get_fp16_ty, ret::reference) |
663 | .def("get_bf16" , &ir::type::get_bf16_ty, ret::reference) |
664 | .def("get_fp32" , &ir::type::get_fp32_ty, ret::reference) |
665 | .def("get_fp64" , &ir::type::get_fp64_ty, ret::reference) |
666 | .def("get_int1" , &ir::type::get_int1_ty, ret::reference) |
667 | .def("get_int8" , &ir::type::get_int8_ty, ret::reference) |
668 | .def("get_int16" , &ir::type::get_int16_ty, ret::reference) |
669 | .def("get_int32" , &ir::type::get_int32_ty, ret::reference) |
670 | .def("get_int64" , &ir::type::get_int64_ty, ret::reference) |
671 | .def("get_fp_mantissa_width" , &ir::type::get_fp_mantissa_width, ret::reference) |
672 | |
673 | .def("get_block_shapes" , &ir::type::get_block_shapes) |
674 | |
675 | .def("is_ptr" , &ir::type::is_pointer_ty) |
676 | .def("is_int" , static_cast<bool (ir::type::*)() const>(&ir::type::is_integer_ty)) |
677 | .def("is_floating" , &ir::type::is_floating_point_ty) |
678 | .def("is_block" , &ir::type::is_block_ty) |
679 | .def("is_struct" , &ir::type::is_struct_ty) |
680 | .def("is_void" , &ir::type::is_void_ty) |
681 | .def("is_bool" , &ir::type::is_bool_ty) |
682 | .def("is_fp8" , &ir::type::is_fp8_ty) |
683 | .def("is_fp16" , &ir::type::is_fp16_ty) |
684 | .def("is_bf16" , &ir::type::is_bf16_ty) |
685 | .def("is_fp32" , &ir::type::is_fp32_ty) |
686 | .def("is_fp64" , &ir::type::is_fp64_ty) |
687 | .def("is_int1" , [](ir::type *self) { return self->is_integer_ty(1); }) |
688 | .def("is_int8" , [](ir::type *self) { return self->is_integer_ty(8); }) |
689 | .def("is_int16" , [](ir::type *self) { return self->is_integer_ty(16); }) |
690 | .def("is_int32" , [](ir::type *self) { return self->is_integer_ty(32); }) |
691 | .def("is_int64" , [](ir::type *self) { return self->is_integer_ty(64); }) |
692 | .def("is_int_or_tileint" , &ir::type::is_int_or_tileint_ty) |
693 | |
694 | .def("repr" , &ir::type::repr) |
695 | .def_property_readonly("fp_mantissa_width" , &ir::type::get_fp_mantissa_width) |
696 | .def_property_readonly("scalar" , &ir::type::get_scalar_ty) |
697 | .def_property_readonly("context" , &ir::type::get_context, ret::reference) |
698 | .def_property_readonly("int_bitwidth" , &ir::type::get_integer_bitwidth) |
699 | .def_property_readonly("primitive_bitwidth" , &ir::type::get_primitive_size_in_bits); |
700 | |
701 | py::class_<ir::pointer_type, ir::type>(m, "pointer_type" ) |
702 | .def_property_readonly("element" , &ir::pointer_type::get_element_ty, ret::reference) |
703 | .def_property_readonly("address_space" , &ir::pointer_type::get_pointer_address_space, ret::reference); |
704 | |
705 | py::class_<ir::function_type, ir::type>(m, "function_type" ) |
706 | .def_property_readonly("ret_ty" , &ir::function_type::get_return_ty) |
707 | .def_property_readonly("arg_tys" , [](ir::function_type* self){ |
708 | return std::vector<ir::type*>(self->params_begin(), self->params_end()); |
709 | }); |
710 | |
711 | py::class_<ir::integer_type, ir::type>(m, "integer_type" ); |
712 | |
713 | py::class_<ir::block_type, ir::type>(m, "block_type" ) |
714 | .def_property_readonly("shape" , &ir::block_type::get_shapes) |
715 | .def_property_readonly("numel" , &ir::type::get_tile_num_elements); |
716 | |
717 | py::class_<ir::struct_type, ir::type>(m, "struct_type" ) |
718 | .def("get" , &ir::struct_type::get, ret::reference) |
719 | .def_property_readonly("num_types" , &ir::struct_type::get_num_types); |
720 | |
721 | py::class_<ir::module>(m, "module" , py::dynamic_attr()) |
722 | .def(py::init<std::string, ir::builder &>()) |
723 | .def("has_function" , &ir::module::has_function) |
724 | .def("get_function" , &ir::module::get_function, ret::reference) |
725 | .def("get_functions" , &ir::module::get_function_list, ret::reference) |
726 | .def("get_or_insert_function" , &ir::module::get_or_insert_function, ret::reference) |
727 | .def("print" , [](ir::module *self) { |
728 | self->print(std::cout); |
729 | }) |
730 | .def("reset_ret_ty" , &ir::module::reset_ret_ty) |
731 | .def("set_instr_metadata" , [](ir::module *self, const std::string &name, ir::value *value) { |
732 | const auto metadatas = self->get_metadatas(); |
733 | auto it = metadatas.find(name); |
734 | if (it != metadatas.end()) |
735 | if (auto *instr = dynamic_cast<ir::instruction*>(value)) { |
736 | instr->set_metadata(it->second.first, it->second.second); |
737 | } |
738 | }) |
739 | .def_property_readonly("builder" , &ir::module::get_builder, ret::reference); |
740 | |
  // Mirror ir::attribute_kind_t as a Python enum so the frontend can tag
  // kernel parameters (readonly/writeonly/noalias, alignment hints, ...).
  using eattr = ir::attribute_kind_t;
  py::enum_<eattr>(m, "attribute_kind" )
      .value("readonly" , eattr::readonly)
      .value("writeonly" , eattr::writeonly)
      .value("noalias" , eattr::noalias)
      .value("aligned" , eattr::aligned)
      .value("multiple_of" , eattr::multiple_of)
      .value("retune" , eattr::retune)
      .value("not_implemented" , eattr::not_implemented);

  // Python wrapper for ir::attribute: a (kind, integer value) pair,
  // e.g. (aligned, 16).
  py::class_<ir::attribute>(m, "attribute" )
      .def(py::init<eattr, int>())
      .def_property_readonly("value" , &ir::attribute::get_value);

  // Python wrapper for ir::function: argument list, attribute management,
  // and the flag marking a function as a kernel entry point.
  py::class_<ir::function>(m, "function" )
      .def_property_readonly("args" , &ir::function::args)
      .def_property_readonly("attrs" , &ir::function::attrs)
      .def("set_is_kernel" , &ir::function::set_is_kernel)
      .def("add_attr" , &ir::function::add_attr)
      .def("has_attr" , &ir::function::has_attr)
      .def("get_attrs" , &ir::function::get_attributes);

  // Python wrapper for ir::argument: one formal parameter of an
  // ir::function, with its owning function and positional index.
  py::class_<ir::argument, ir::value>(m, "argument" )
      .def_property_readonly("parent" , &ir::argument::get_parent, ret::reference)
      .def_property_readonly("arg_no" , &ir::argument::get_arg_no);
766 | |
  // Python wrapper for ir::basic_block.
  py::class_<ir::basic_block, ir::value>(m, "basic_block" )
      .def("create" , &ir::basic_block::create, ret::reference, py::arg(), py::arg(), py::arg() = nullptr)
      .def("get_predecessors" , &ir::basic_block::get_predecessors, ret::reference)
      // First non-phi instruction of the block, or None when the block has
      // no such instruction (empty or phi-only).
      .def("get_first_non_phi" , [](ir::basic_block *self) -> ir::instruction* {
        ir::basic_block::iterator it = self->get_first_non_phi();
        if (it == self->end())
          return nullptr;
        return *it;
      }, ret::reference)
      .def_property_readonly("parent" , &ir::basic_block::get_parent, ret::reference);

  // Opaque handle for ir::builder::iterator so insert points can be passed
  // around from Python without being inspected.
  py::class_<ir::builder::iterator>(m, "bb_iterator" );
779 | |
780 | py::class_<ir::builder>(m, "builder" , py::dynamic_attr()) |
781 | .def(py::init<ir::context &>()) |
782 | // getters |
783 | .def_property_readonly("context" , &ir::builder::get_context, ret::reference) |
784 | // control flow |
785 | .def("call" , &ir::builder::create_call, ret::reference) |
786 | .def("launch" , &ir::builder::create_launch, ret::reference) |
787 | .def("br" , &ir::builder::create_br, ret::reference) |
788 | .def("cond_br" , &ir::builder::create_cond_br, ret::reference) |
789 | .def("ret_void" , &ir::builder::create_ret_void, ret::reference) |
790 | .def("ret" , &ir::builder::create_ret, ret::reference) |
791 | // insertion block/point, insert points are represented as (*bb, *instr) |
792 | .def("get_insert_block" , &ir::builder::get_insert_block, ret::reference) |
793 | .def("set_insert_block" , (void (ir::builder::*)(ir::basic_block *)) & ir::builder::set_insert_point) |
794 | .def("get_insert_point" , [](ir::builder *self) { |
795 | ir::basic_block *bb = self->get_insert_block(); |
796 | ir::basic_block::iterator it = self->get_insert_point(); |
797 | ir::instruction *instr = it == bb->end() ? nullptr : *it; |
798 | return std::make_pair(bb, instr); |
799 | }, ret::reference) |
800 | .def("set_insert_point" , [](ir::builder *self, std::pair<ir::basic_block*, ir::instruction*> pt) { |
801 | ir::basic_block *bb = pt.first; |
802 | ir::instruction *instr = pt.second; |
803 | if (instr) { |
804 | if (bb != instr->get_parent()) |
805 | throw std::runtime_error("invalid insertion point, instr not in bb" ); |
806 | self->set_insert_point(instr); |
807 | } else { |
808 | assert(bb); |
809 | self->set_insert_point(bb); |
810 | } |
811 | }) |
812 | // Constants |
813 | .def("get_int1" , &ir::builder::get_int1, ret::reference) |
814 | .def("get_int32" , [](ir::builder *self, int32_t v) { return self->get_int32((uint32_t)v); }, ret::reference) |
815 | .def("get_uint32" , &ir::builder::get_int32, ret::reference) |
816 | .def("get_int64" , [](ir::builder *self, int64_t v) { return self->get_int64((uint64_t)v); }, ret::reference) |
817 | .def("get_uint64" , &ir::builder::get_int64, ret::reference) |
818 | .def("get_float16" , &ir::builder::get_float16, ret::reference) |
819 | .def("get_float32" , &ir::builder::get_float32, ret::reference) |
820 | .def("get_range" , &ir::builder::get_range, ret::reference) |
821 | // Types |
822 | .def("get_void_ty" , &ir::builder::get_void_ty, ret::reference) |
823 | .def("get_int1_ty" , &ir::builder::get_int1_ty, ret::reference) |
824 | .def("get_int8_ty" , &ir::builder::get_int8_ty, ret::reference) |
825 | .def("get_int16_ty" , &ir::builder::get_int16_ty, ret::reference) |
826 | .def("get_int32_ty" , &ir::builder::get_int32_ty, ret::reference) |
827 | .def("get_int64_ty" , &ir::builder::get_int64_ty, ret::reference) |
828 | .def("get_fp8_ty" , &ir::builder::get_fp8_ty, ret::reference) |
829 | .def("get_half_ty" , &ir::builder::get_half_ty, ret::reference) |
830 | .def("get_bf16_ty" , &ir::builder::get_bf16_ty, ret::reference) |
831 | .def("get_float_ty" , &ir::builder::get_float_ty, ret::reference) |
832 | .def("get_double_ty" , &ir::builder::get_double_ty, ret::reference) |
833 | // terminator instructions |
834 | .def("create_br" , &ir::builder::create_br, ret::reference) |
835 | .def("create_cond_br" , &ir::builder::create_cond_br, ret::reference) |
836 | .def("create_ret_void" , &ir::builder::create_ret_void, ret::reference) |
837 | // Dequantize instructions |
838 | .def("create_dequantize" , &ir::builder::create_dequantize, ret::reference) |
839 | // Cast instructions |
840 | .def("create_bitcast" , &ir::builder::create_bitcast, ret::reference) |
841 | .def("create_cast" , &ir::builder::create_cast, ret::reference) |
842 | .def("create_ptr_to_int" , &ir::builder::create_ptr_to_int, ret::reference) |
843 | .def("create_si_to_fp" , &ir::builder::create_si_to_fp, ret::reference) |
844 | .def("create_ui_to_fp" , &ir::builder::create_ui_to_fp, ret::reference) |
845 | .def("create_fp_to_si" , &ir::builder::create_fp_to_si, ret::reference) |
846 | .def("create_fp_to_ui" , &ir::builder::create_fp_to_ui, ret::reference) |
847 | .def("create_fp_ext" , &ir::builder::create_fp_ext, ret::reference) |
848 | .def("create_fp_trunc" , &ir::builder::create_fp_trunc, ret::reference) |
849 | .def("create_int_cast" , &ir::builder::create_int_cast, ret::reference) |
850 | .def("create_downcast" , &ir::builder::create_downcast, ret::reference) |
851 | .def("create_int_to_ptr" , &ir::builder::create_int_to_ptr, ret::reference) |
852 | .def("create_ptr_to_int" , &ir::builder::create_ptr_to_int, ret::reference) |
853 | // phi |
854 | .def("create_phi" , &ir::builder::create_phi, ret::reference) |
855 | // Binary instructions |
856 | .def("create_insert_nuwnswb_binop" , &ir::builder::create_insert_nuwnswb_binop, ret::reference) |
857 | .def("create_fmul" , &ir::builder::create_fmul, ret::reference) |
858 | .def("create_fdiv" , &ir::builder::create_fdiv, ret::reference) |
859 | .def("create_frem" , &ir::builder::create_frem, ret::reference) |
860 | .def("create_fadd" , &ir::builder::create_fadd, ret::reference) |
861 | .def("create_fsub" , &ir::builder::create_fsub, ret::reference) |
862 | .def("create_mul" , &ir::builder::create_mul, ret::reference, |
863 | py::arg("lhs" ), py::arg("rhs" ), |
864 | py::arg("has_nuw" )=false, py::arg("has_nsw" )=false) |
865 | .def("create_sdiv" , &ir::builder::create_sdiv, ret::reference) |
866 | .def("create_udiv" , &ir::builder::create_udiv, ret::reference) |
867 | .def("create_srem" , &ir::builder::create_srem, ret::reference) |
868 | .def("create_urem" , &ir::builder::create_urem, ret::reference) |
869 | .def("create_add" , &ir::builder::create_add, ret::reference, |
870 | py::arg("lhs" ), py::arg("rhs" ), |
871 | py::arg("has_nuw" )=false, py::arg("has_nsw" )=false) |
872 | .def("create_sub" , &ir::builder::create_sub, ret::reference, |
873 | py::arg("lhs" ), py::arg("rhs" ), |
874 | py::arg("has_nuw" )=false, py::arg("has_nsw" )=false) |
875 | .def("create_shl" , &ir::builder::create_shl, ret::reference, |
876 | py::arg("lhs" ), py::arg("rhs" ), |
877 | py::arg("has_nuw" )=false, py::arg("has_nsw" )=false) |
878 | .def("create_lshr" , &ir::builder::create_lshr, ret::reference, |
879 | py::arg("lhs" ), py::arg("rhs" ), |
880 | py::arg("has_nuw" )=false, py::arg("has_nsw" )=false) |
881 | .def("create_ashr" , &ir::builder::create_ashr, ret::reference, |
882 | py::arg("lhs" ), py::arg("rhs" ), |
883 | py::arg("has_nuw" )=false, py::arg("has_nsw" )=false) |
884 | // GEP |
885 | .def("create_gep" , &ir::builder::create_gep, ret::reference) |
886 | // Comparison (int) |
887 | .def("create_icmp" , &ir::builder::create_icmp, ret::reference) |
888 | .def("create_icmpSLE" , &ir::builder::create_icmpSLE, ret::reference) |
889 | .def("create_icmpSLT" , &ir::builder::create_icmpSLT, ret::reference) |
890 | .def("create_icmpSGE" , &ir::builder::create_icmpSGE, ret::reference) |
891 | .def("create_icmpSGT" , &ir::builder::create_icmpSGT, ret::reference) |
892 | .def("create_icmpULE" , &ir::builder::create_icmpULE, ret::reference) |
893 | .def("create_icmpULT" , &ir::builder::create_icmpULT, ret::reference) |
894 | .def("create_icmpUGE" , &ir::builder::create_icmpUGE, ret::reference) |
895 | .def("create_icmpUGT" , &ir::builder::create_icmpUGT, ret::reference) |
896 | .def("create_icmpEQ" , &ir::builder::create_icmpEQ, ret::reference) |
897 | .def("create_icmpNE" , &ir::builder::create_icmpNE, ret::reference) |
898 | // Comparison (float) |
899 | .def("create_fcmp" , &ir::builder::create_fcmp, ret::reference) |
900 | .def("create_fcmpOLT" , &ir::builder::create_fcmpOLT, ret::reference) |
901 | .def("create_fcmpOGT" , &ir::builder::create_fcmpOGT, ret::reference) |
902 | .def("create_fcmpOLE" , &ir::builder::create_fcmpOLE, ret::reference) |
903 | .def("create_fcmpOGE" , &ir::builder::create_fcmpOGE, ret::reference) |
904 | .def("create_fcmpOEQ" , &ir::builder::create_fcmpOEQ, ret::reference) |
905 | .def("create_fcmpONE" , &ir::builder::create_fcmpONE, ret::reference) |
906 | .def("create_fcmpULT" , &ir::builder::create_fcmpULT, ret::reference) |
907 | .def("create_fcmpUGT" , &ir::builder::create_fcmpUGT, ret::reference) |
908 | .def("create_fcmpULE" , &ir::builder::create_fcmpULE, ret::reference) |
909 | .def("create_fcmpUGE" , &ir::builder::create_fcmpUGE, ret::reference) |
910 | .def("create_fcmpUEQ" , &ir::builder::create_fcmpUEQ, ret::reference) |
911 | .def("create_fcmpUNE" , &ir::builder::create_fcmpUNE, ret::reference) |
912 | // Logical |
913 | .def("create_and" , &ir::builder::create_and, ret::reference) |
914 | .def("create_xor" , &ir::builder::create_xor, ret::reference) |
915 | .def("create_or" , &ir::builder::create_or, ret::reference) |
916 | // Input/Output |
917 | .def("create_load" , &ir::builder::create_load, ret::reference) |
918 | .def("create_store" , &ir::builder::create_store, ret::reference) |
919 | .def("create_masked_load" , &ir::builder::create_masked_load, ret::reference) |
920 | .def("create_masked_store" , &ir::builder::create_masked_store, ret::reference) |
921 | // Block instruction |
922 | .def("create_splat" , &ir::builder::create_splat, ret::reference) |
923 | .def("create_reshape" , &ir::builder::create_reshape, ret::reference) |
924 | .def("create_cat" , &ir::builder::create_cat, ret::reference) |
925 | .def("create_broadcast" , &ir::builder::create_broadcast, ret::reference) |
926 | // atomic |
927 | .def("create_atomic_cas" , &ir::builder::create_atomic_cas, ret::reference) |
928 | .def("create_atomic_rmw" , &ir::builder::create_atomic_rmw, ret::reference) |
929 | // Utilities |
930 | .def("create_clock" , &ir::builder::create_clock, ret::reference) |
931 | .def("create_globaltimer" , &ir::builder::create_globaltimer, ret::reference) |
932 | // Extern instruction |
933 | .def("create_extern_elementwise" , &ir::builder::create_extern_elementwise, ret::reference) |
934 | // Built-in instruction |
935 | .def("create_get_program_id" , &ir::builder::create_get_program_id, ret::reference) |
936 | .def("create_get_num_programs" , &ir::builder::create_get_num_programs, ret::reference) |
937 | .def("create_exp" , &ir::builder::create_exp, ret::reference) |
938 | .def("create_cos" , &ir::builder::create_cos, ret::reference) |
939 | .def("create_sin" , &ir::builder::create_sin, ret::reference) |
940 | .def("create_log" , &ir::builder::create_log, ret::reference) |
941 | .def("create_dot" , &ir::builder::create_dot, ret::reference) |
942 | .def("create_trans" , &ir::builder::create_trans, ret::reference) |
943 | .def("create_sqrt" , &ir::builder::create_sqrt, ret::reference) |
944 | .def("create_reduce" , &ir::builder::create_reduce, ret::reference) |
945 | .def("create_select" , &ir::builder::create_select, ret::reference) |
946 | // struct |
947 | .def("insert_value" , &ir::builder::create_insert_value, ret::reference) |
948 | .def("extract_value" , &ir::builder::create_extract_value, ret::reference) |
949 | // Intrinsics |
950 | // These have no place in the IR, and hopefully they can be removed at some point |
951 | .def("create_umulhi" , &ir::builder::create_umulhi, ret::reference) |
952 | .def("create_copy_to_shared" , &ir::builder::create_copy_to_shared, ret::reference) |
953 | .def("create_masked_load_async" , &ir::builder::create_masked_load_async, ret::reference) |
954 | .def("create_copy_from_shared" , &ir::builder::create_copy_from_shared, ret::reference) |
955 | .def("create_barrier" , &ir::builder::create_barrier, ret::reference) |
956 | .def("create_async_wait" , &ir::builder::create_async_wait, ret::reference) |
957 | .def("create_prefetch_s" , &ir::builder::create_prefetch_s, ret::reference); |
958 | } |
959 | |
960 | void init_triton(py::module &m) { |
961 | py::module subm = m.def_submodule("triton" ); |
962 | init_triton_codegen(std::move(subm.def_submodule("code_gen" ))); |
963 | init_triton_runtime(std::move(subm.def_submodule("runtime" ))); |
964 | init_triton_ir(std::move(subm.def_submodule("ir" ))); |
965 | } |
966 | |