#include "triton/codegen/pass.h"
#include "triton/codegen/target.h"
#include "triton/codegen/extern_lib.h"
#include "triton/driver/error.h"
#include "triton/driver/llvm.h"
#include "triton/ir/builder.h"
#include "triton/ir/enums.h"
#include "triton/ir/function.h"
#include "triton/ir/module.h"
#include "triton/ir/print.h"
#include <optional>
#include <pybind11/buffer_info.h>
#include <pybind11/functional.h>
#include <pybind11/pybind11.h>
#include <pybind11/stl_bind.h>
#include <pybind11/stl.h>
#include "Python.h"
#include <cstdint>
#include <cstring>
#include <iostream>
#include <map>
#include <regex>
#include <sstream>
#include <stdexcept>
#include <string>
#include <unordered_set>
#include "llvm/IR/Module.h"
#include "llvm/IR/Verifier.h"

namespace py = pybind11;
namespace ir = triton::ir;
namespace drv = triton::driver;


/*****************************************************************************/
/* Python bindings for triton::driver                                        */
/*****************************************************************************/
// information query
template<CUdevice_attribute attr>
int cuGetInfo(CUdevice device) {
  int res;
  drv::dispatch::cuDeviceGetAttribute(&res, attr, device);
  return res;
}

template<hipDeviceAttribute_t attr>
int hipGetInfo(hipDevice_t device) {
  int res;
  drv::dispatch::hipDeviceGetAttribute(&res, attr, device);
  return res;
}

enum backend_t {
  HOST,
  CUDA,
  ROCM,
};

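// Enable CUDA peer-to-peer access from the current context to the context
// that owns `peer_ptr`. "Already enabled" errors are swallowed so the call
// is idempotent.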
void cu_enable_peer_access(uint64_t peer_ptr){
  CUcontext context;
  drv::dispatch::cuPointerGetAttribute(&context, CU_POINTER_ATTRIBUTE_CONTEXT, peer_ptr);
  try {
    drv::dispatch::cuCtxEnablePeerAccess(context, 0);
  } catch (const drv::exception::cuda::peer_access_already_enabled&) {}
}

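// CPU fallback launch. Not implemented; the commented-out sketch below shows
// the intended thread-pool dispatch, one task per (i, j, k) grid point.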
void host_enqueue(uint64_t stream, uint64_t kernel,
                  uint64_t grid_0, uint64_t grid_1, uint64_t grid_2,
                  uint64_t block_0, uint64_t block_1, uint64_t block_2,
                  void* args_ptr, size_t args_size, int64_t shared_mem){
  throw std::runtime_error("unsupported");
//  auto hst = kernel->module()->hst();
//  hst_->futures->reserve(hst_->futures->size() + grid[0]*grid[1]*grid[2]);
//  char* params = new char[args_size];
//  std::memcpy((void*)params, (void*)args, args_size);
//  for(size_t i = 0; i < grid[0]; i++)
//    for(size_t j = 0; j < grid[1]; j++)
//      for(size_t k = 0; k < grid[2]; k++)
//        hst_->futures->emplace_back(hst_->pool->enqueue(hst->fn, (char**)params, int32_t(i), int32_t(j), int32_t(k)));
}

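// Launch a CUDA kernel. The packed argument buffer is handed to the driver
// in one piece via CU_LAUNCH_PARAM_BUFFER_POINTER/SIZE rather than as
// per-argument pointers.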
void cu_enqueue(uint64_t stream, uint64_t kernel,
                uint64_t grid_0, uint64_t grid_1, uint64_t grid_2,
                uint64_t block_0, uint64_t block_1, uint64_t block_2,
                void* args_ptr, size_t args_size, int64_t shared_mem){
  void *config[] = {
    CU_LAUNCH_PARAM_BUFFER_POINTER, (void*)args_ptr,
    CU_LAUNCH_PARAM_BUFFER_SIZE, &args_size,
    CU_LAUNCH_PARAM_END
  };
  drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
                                block_0, block_1, block_2,
                                shared_mem, (CUstream)stream, nullptr, config);
}

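// Launch a HIP kernel; mirrors cu_enqueue using the HIP_LAUNCH_PARAM_*
// extra-parameter protocol.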
void hip_enqueue(uint64_t stream, uint64_t kernel,
                 uint64_t grid_0, uint64_t grid_1, uint64_t grid_2,
                 uint64_t block_0, uint64_t block_1, uint64_t block_2,
                 void* args_ptr, size_t args_size, int64_t shared_mem) {
  void *config[] = {
    HIP_LAUNCH_PARAM_BUFFER_POINTER, (void*)args_ptr,
    HIP_LAUNCH_PARAM_BUFFER_SIZE, &args_size,
    HIP_LAUNCH_PARAM_END
  };
  drv::dispatch::hipModuleLaunchKernel((hipFunction_t)kernel, grid_0, grid_1, grid_2,
                                       block_0, block_1, block_2,
                                       shared_mem, (hipStream_t)stream, nullptr, config);
}

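// Largest power of two (capped at 16) that divides N. Used to annotate
// integer arguments and pointers with `multipleof` hints in the cache key,
// so the compiler can specialize on alignment. Note that N == 0 yields 16.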
int64_t pow2_divisor(int64_t N){
  if(N % 16 == 0) return 16;
  if(N % 8 == 0) return 8;
  if(N % 4 == 0) return 4;
  if(N % 2 == 0) return 2;
  return 1;
}

// Returns something like "int16", whether dtype is a torch.dtype or
// triton.language.dtype.
std::string dtype_cache_key_part(const py::object& dtype) {
  if (py::hasattr(dtype, "cache_key_part")) {
    // Presumed to be a triton.language.dtype.
    return std::string(py::str(py::getattr(dtype, "cache_key_part")));
  } else {
    // Remove 'torch.' prefix from repr of torch.dtype.
    py::object repr = py::repr(dtype);
    size_t repr_len = PyUnicode_GET_LENGTH(repr.ptr());
    const char* repr_ptr = (const char*)PyUnicode_1BYTE_DATA(repr.ptr());
    if (repr_len <= 6 || strncmp(repr_ptr, "torch.", 6)) {
      throw std::logic_error("invalid dtype: " + std::string(repr_ptr, repr_len));
    }
    return std::string(repr_ptr + 6, repr_len - 6);
  }
}

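// Size in bytes of the allocation containing `addr`, queried from the CUDA
// driver. Throws (via dispatch) if the pointer is not a CUDA allocation.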
size_t get_pointer_range_size(uint64_t addr){
  if(addr == 0)
    return 0;
  size_t size;
  drv::dispatch::cuPointerGetAttribute(&size, CU_POINTER_ATTRIBUTE_RANGE_SIZE, (CUdeviceptr)addr);
  return size;
}

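// Build the per-launch specialization cache key and pack kernel arguments
// into a flat parameter buffer:
//  - ints are specialized on the value 1, on bit-width
//    (int32/uint32/int64/uint64), and on power-of-two divisibility;
//  - tensors contribute their dtype and pointer alignment;
//  - constexprs are recorded in `constants` and keyed by their repr.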
void parse_args(py::list& args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names,
                std::string& cache_key, std::string& params, size_t& params_size, py::dict constants,
                int num_warps, int num_stages, py::dict& extern_libs) {
  size_t len = PyList_Size(args.ptr());
  // 8 bytes max per argument; alignment padding never exceeds the bytes
  // saved by smaller arguments, so this bound is safe. resize() (not
  // reserve()) so that writing through &params[0] is well-defined.
  params.resize(8*len);
  char* params_ptr = &params[0];
  cache_key = func_key;
  cache_key += "-" + std::to_string(num_warps);
  cache_key += "-" + std::to_string(num_stages);
  cache_key += "-";
  for(size_t i = 0; i < len; i++){
    cache_key += "_";
    py::int_ py_i = py::int_(i);
    bool specialize = !do_not_specialize.contains(py_i);
    py::object arg = args[i];
    auto arg_ptr = arg.ptr();

    // argument is `long` (bool is a subtype of int in Python, so it must be
    // excluded here to reach its own branch below)
    if(PyLong_Check(arg_ptr) && !PyBool_Check(arg_ptr)){
      int overflow;
      long long value = PyLong_AsLongLongAndOverflow(arg_ptr, &overflow);
      // values equal to 1 are specialized
      if(specialize && (value == 1)){
        cache_key += "1";
        continue;
      }
      // int32, uint32, int64, and uint64 have different kernels
      if (!overflow && -0x8000'0000LL <= value && value <= 0x7FFF'FFFFLL) {
        cache_key += "int32";
        params_ptr = (char*)(((uintptr_t)params_ptr + 3) & (-4));
        std::memcpy(params_ptr, &value, 4);
        params_ptr += 4;
      } else if (!overflow && 0x8000'0000LL <= value && value <= 0xFFFF'FFFFLL) {
        cache_key += "uint32";
        params_ptr = (char*)(((uintptr_t)params_ptr + 3) & (-4));
        std::memcpy(params_ptr, &value, 4);
        params_ptr += 4;
      } else if (!overflow) {
        cache_key += "int64";
        params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8));
        std::memcpy(params_ptr, &value, 8);
        params_ptr += 8;
      } else {
        if (PyErr_Occurred()) {
          throw std::logic_error("unexpected Python error while reading integer argument");
        }
        unsigned long long unsigned_value = PyLong_AsUnsignedLongLong(arg_ptr);
        if (PyErr_Occurred()) {
          throw std::runtime_error("integer overflow in argument: " + std::string(py::str(arg)));
        }
        cache_key += "uint64";
        params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8));
        std::memcpy(params_ptr, &unsigned_value, 8);
        params_ptr += 8;
      }
      if(!specialize)
        continue;
      // values divisible by small powers of 2 are specialized
      cache_key += "[multipleof(";
      cache_key += std::to_string(pow2_divisor(value));
      cache_key += ")]";
      continue;
    }
    // argument is `float`
    if(PyFloat_Check(arg_ptr)){
      cache_key += "float32";
      float value = PyFloat_AsDouble(arg_ptr);
      params_ptr = (char*)(((uintptr_t)params_ptr + 3) & (-4));
      std::memcpy(params_ptr, &value, 4);
      params_ptr += 4;
      continue;
    }
    // argument is `bool`
    if(PyBool_Check(arg_ptr)){
      cache_key += "bool";
      bool value = (arg_ptr == Py_True);
      std::memcpy(params_ptr, &value, 1);
      params_ptr += 1;
      continue;
    }
    // argument is tensor
    if(py::hasattr(arg, "data_ptr")){
      py::object data_ptr = arg.attr("data_ptr")();
      int64_t value = data_ptr.cast<int64_t>(); // `long` is 32-bit on LLP64 platforms
      params_ptr = (char*)(((uintptr_t)params_ptr + 7) & (-8));
      // copy param
      std::memcpy(params_ptr, &value, 8);
      params_ptr += 8;
      // update cache key
      cache_key += dtype_cache_key_part(arg.attr("dtype"));
      cache_key += "*";
      cache_key += "[multipleof(";
      size_t range_size;
      try {
        range_size = get_pointer_range_size(value);
      } catch (...) {
        throw std::runtime_error("argument tensor #" + std::to_string(i) + " is not on cuda! " + std::string(py::str(arg)));
      }
      cache_key += std::to_string(std::min(pow2_divisor(value), pow2_divisor(range_size)));
      cache_key += ")]";
      continue;
    }
    // argument is `constexpr`
    if (py::hasattr(arg, "value")) {
      py::object value = arg.attr("value");
      // check if value is a callable object using PyCallable_Check
      if (PyCallable_Check(value.ptr())) {
        throw std::runtime_error(
            "constant argument cannot be a callable object: " +
            std::string(py::str(arg)));
      }
      py::object name = arg_names[i];
      constants[name] = value;
      py::object repr = py::repr(value);
      const char* start = (const char*)PyUnicode_1BYTE_DATA(repr.ptr());
      size_t repr_len = PyUnicode_GET_LENGTH(repr.ptr());
      cache_key += std::string(start, repr_len);
      continue;
    }
    std::string ty_str = arg.attr("__class__").attr("__name__").cast<std::string>();
    if(ty_str == "NoneType"){
      cache_key += "None";
      continue;
    }
    std::string err_msg = "Received type '" + ty_str + "' for argument " + std::to_string(i) + "."
                          + " Only int, float, bool, torch.Tensor, and triton.language.constexpr are supported.";
    throw std::runtime_error(err_msg);
  }
  params_size = (std::ptrdiff_t)(params_ptr - &params[0]);

  for (auto item : extern_libs) {
    cache_key += "-" + item.first.cast<std::string>();
    cache_key += "_" + item.second.cast<std::string>();
  }
}


void init_triton_runtime(py::module &&m) {

  // m.def("current_stream", [](uint64_t device){
  //   return (uint64_t)(c10::cuda::getCurrentCUDAStream(device).stream());
  // });

  // wrap backend_t
  py::enum_<backend_t>(m, "backend")
      .value("HOST", HOST)
      .value("CUDA", CUDA)
      .value("ROCM", ROCM)
      .export_values();

  // enable peer-to-peer
  m.def("enable_peer_access", [](backend_t backend, uint64_t peer_ptr) {
    if (backend != CUDA)
      throw std::runtime_error("P2P only supported on CUDA devices!");
    cu_enable_peer_access(peer_ptr);
  });

  // get range size for the given pointer
  m.def("get_pointer_range_size", &get_pointer_range_size);

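  // launch: look up (or JIT-compile via `add_to_cache`) the binary for the
  // current argument specialization, then enqueue it on `stream`. Returns
  // the binary, or None if compilation was a no-op.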
  m.def("launch", [](py::list args, py::list do_not_specialize, const std::string& func_key, py::list& arg_names,
                     py::object device, py::int_ stream, py::dict bin_cache, py::int_ num_warps, py::int_ num_stages,
                     py::dict extern_libs, py::function add_to_cache, py::object grid){
    // parse arguments to compute cache key, compile-time constants and packed kernel arguments
    long _num_warps = PyLong_AsLong(num_warps.ptr());
    long _num_stages = PyLong_AsLong(num_stages.ptr());
    std::string cache_key;
    std::string params;
    size_t params_size;
    py::dict constants;
    parse_args(args, do_not_specialize, func_key, arg_names, cache_key, params,
               params_size, constants, _num_warps, _num_stages, extern_libs);

    // get cached binary
    py::str key(cache_key);
    py::bool_ noop = false;
    if(!bin_cache.contains(key)) {
      noop = add_to_cache(key, args, device, num_warps, num_stages, extern_libs);
    }
    if (noop)
      return (py::object)py::none();
    py::object bin = bin_cache[key];

    // get grid: either a sequence, or a callable that maps constants to one;
    // missing dimensions default to 1
    py::sequence seq;
    if(!PySequence_Check(grid.ptr()))
      seq = grid(constants);
    else
      seq = grid;
    int size = seq.size();
    int grid_0 = py::cast<int>(seq[0]);
    int grid_1 = size < 2 ? 1 : py::cast<int>(seq[1]);
    int grid_2 = size < 3 ? 1 : py::cast<int>(seq[2]);

    // enqueue
    uint64_t kernel = py::cast<uint64_t>(bin.attr("kernel"));
    uint64_t shared_mem = py::cast<uint64_t>(bin.attr("shared_mem"));

    // actually launch
    void *config[] = {
      CU_LAUNCH_PARAM_BUFFER_POINTER, (void*)params.data(),
      CU_LAUNCH_PARAM_BUFFER_SIZE, &params_size,
      CU_LAUNCH_PARAM_END
    };
    // stream handles are pointer-sized; read as unsigned 64-bit
    uint64_t _stream = PyLong_AsUnsignedLongLong(stream.ptr());
    if(grid_0*grid_1*grid_2 > 0) {
      // release the gil in case the enqueue blocks
      // cuda will block if too many ops are enqueued
      py::gil_scoped_release allow_threads;
      drv::dispatch::cuLaunchKernel((CUfunction)kernel, grid_0, grid_1, grid_2,
                                    _num_warps*32, 1, 1, shared_mem, (CUstream)_stream,
                                    nullptr, config);
    }
    return bin;
  });

361 m.def("cc", [](backend_t backend, uint64_t device) -> int {
362 if (backend == CUDA) {
363 CUdevice dev = (CUdevice)device;
364 int major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev);
365 int minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
366 return major*10 + minor;
367 }
368 return -1;
369 });
370
371 // query maximum shared memory
372 m.def("max_shared_memory", [](backend_t backend, uint64_t device) {
373 if (backend == HOST)
374 return 0;
375 if(backend == CUDA)
376 return cuGetInfo<CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN>(device);
377 if(backend == ROCM)
378 return hipGetInfo<hipDeviceAttributeMaxSharedMemoryPerBlock>(device);
379 return -1;
380 });
381
382 // query DRAM & L2 cache
383 m.def("memory_clock_rate", [](backend_t backend, uint64_t device) {
384 if (backend == CUDA) return cuGetInfo<CU_DEVICE_ATTRIBUTE_MEMORY_CLOCK_RATE>(device);
385 return -1;
386 });
387 m.def("global_memory_bus_width", [](backend_t backend, uint64_t device) {
388 if (backend == CUDA) return cuGetInfo<CU_DEVICE_ATTRIBUTE_GLOBAL_MEMORY_BUS_WIDTH>(device);
389 return -1;
390 });
391 m.def("l2_cache_size", [](backend_t backend, uint64_t device) {
392 if (backend == CUDA) return cuGetInfo<CU_DEVICE_ATTRIBUTE_L2_CACHE_SIZE>(device);
393 return -1;
394 });
395
396 // query clock rate (in kilohertz)
397 m.def("clock_rate", [](backend_t backend, uint64_t device) {
398 if (backend == CUDA) return cuGetInfo<CU_DEVICE_ATTRIBUTE_CLOCK_RATE>(device);
399 return -1;
400 });
401
402 m.def("num_sm", [](backend_t backend, uint64_t device) {
403 if (backend == CUDA) return cuGetInfo<CU_DEVICE_ATTRIBUTE_MULTIPROCESSOR_COUNT>(device);
404 return -1;
405 });
406
407 // enqueue
408 m.def("enqueue", [](backend_t backend, uint64_t stream, uint64_t kernel,
409 uint64_t grid_0, uint64_t grid_1, uint64_t grid_2,
410 uint64_t block_0, uint64_t block_1, uint64_t block_2,
411 const std::string &args, int64_t shared_mem){
412 void* args_ptr = (void*)args.data();
413 size_t args_size = args.size();
414 // release the gil in case the enqueue blocks
415 // cuda will block if too many ops are enqueued
416 py::gil_scoped_release allow_threads;
417 if(backend == HOST)
418 host_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem);
419 if(backend == CUDA)
420 cu_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem);
421 if(backend == ROCM)
422 hip_enqueue(stream, kernel, grid_0, grid_1, grid_2, block_0, block_1, block_2, args_ptr, args_size, shared_mem);
423 });
424
425
426}

/*****************************************************************************/
/* Python bindings for triton::codegen                                       */
/*****************************************************************************/
typedef std::map<std::string, py::object> asm_map_t;

// ---------------------------------------
// Compile Triton-IR to assembly
// ---------------------------------------

void init_triton_codegen(py::module &&m) {
  m.def("compile_ttir",
      [](backend_t backend, ir::module &ir, uint64_t device, int num_warps, int num_stages, py::dict& extern_libs, size_t cc) {
        std::ostringstream ttir;
        int n_shared_bytes;
        std::string tmp;
        std::string ptx;
        std::string cubin;
        std::string name;
        // construct extern lib map while the GIL is still held;
        // the py::dict must not be touched once the GIL is released
        triton::codegen::ExternLibMap extern_lib_map;
        for (auto item : extern_libs) {
          auto lib_name = item.first.cast<std::string>();
          auto lib_path = item.second.cast<std::string>();
          extern_lib_map.emplace(
              lib_name, triton::codegen::create_extern_lib(lib_name, lib_path));
        }
        { // Scope where the GIL is released
          py::gil_scoped_release allow_threads;
          name = ir.get_function_list()[0]->get_name();
          ir.print(ttir);
          llvm::LLVMContext ctx;
          // device properties
          if (cc == 0) {
            CUdevice dev = (CUdevice)device;
            size_t major = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MAJOR>(dev);
            size_t minor = cuGetInfo<CU_DEVICE_ATTRIBUTE_COMPUTE_CAPABILITY_MINOR>(dev);
            cc = major*10 + minor;
          }
          int version;
          std::string ptxas_path = drv::path_to_ptxas(version);
          // Triton-IR -> NVPTX LLVM-IR
          triton::codegen::nvidia_cu_target target(cc);
          auto llvm = triton::codegen::add_passes_to_emit_bin(
              ir, ctx, &target, num_warps, num_stages, n_shared_bytes, extern_lib_map);
          llvm::raw_string_ostream llir(tmp);
          llir << *llvm;
          llir.flush();
          // LLVM-IR -> PTX
          ptx = drv::llir_to_ptx(llvm.get(), cc, version);
          // PTX -> Binary
          cubin = drv::ptx_to_cubin(ptx, ptxas_path, cc);
        }
        asm_map_t asm_map;
        asm_map["ttir"] = py::cast(ttir.str());
        asm_map["llir"] = py::cast(tmp);
        asm_map["ptx"] = py::cast(ptx);

        if(!cubin.empty()){
          py::bytes bytes(cubin);
          asm_map["cubin"] = bytes;
        }
        return std::make_tuple(name, asm_map, n_shared_bytes);
      },
      py::return_value_policy::take_ownership);

  // ---------------------------------------
  // Load provided assembly code into driver
  // ---------------------------------------
  m.def("load_binary", [](const std::string& name, const std::string& data, size_t n_shared_bytes, uint64_t device){
    py::gil_scoped_release allow_threads;
    // create driver handles
    CUfunction fun;
    CUmodule mod;
    drv::dispatch::cuModuleLoadData(&mod, data.c_str());
    drv::dispatch::cuModuleGetFunction(&fun, mod, name.c_str());
    // get allocated registers and spilled registers from the function
    int n_regs = 0;
    int n_spills = 0;
    drv::dispatch::cuFuncGetAttribute(&n_regs, CU_FUNC_ATTRIBUTE_NUM_REGS, fun);
    drv::dispatch::cuFuncGetAttribute(&n_spills, CU_FUNC_ATTRIBUTE_LOCAL_SIZE_BYTES, fun);
    n_spills /= 4;
    // opt in to more than the default 48KB (49152 bytes) of dynamic shared
    // memory if the kernel needs it and the device supports it
    int shared_optin;
    drv::dispatch::cuDeviceGetAttribute(&shared_optin, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_BLOCK_OPTIN, device);
    if(n_shared_bytes > 49152 && shared_optin > 49152){
      drv::dispatch::cuFuncSetCacheConfig(fun, CU_FUNC_CACHE_PREFER_SHARED);
      int shared_total, shared_static;
      drv::dispatch::cuDeviceGetAttribute(&shared_total, CU_DEVICE_ATTRIBUTE_MAX_SHARED_MEMORY_PER_MULTIPROCESSOR, device);
      drv::dispatch::cuFuncGetAttribute(&shared_static, CU_FUNC_ATTRIBUTE_SHARED_SIZE_BYTES, fun);
      drv::dispatch::cuFuncSetAttribute(fun, CU_FUNC_ATTRIBUTE_MAX_DYNAMIC_SHARED_SIZE_BYTES, shared_optin - shared_static);
    }
    return std::make_tuple((uint64_t)mod, (uint64_t)fun, (uint64_t)n_regs, (uint64_t)n_spills);
  },
  py::return_value_policy::take_ownership
  );

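  // Argument specialization hints attached to a compiled kernel instance:
  // which argument indices are divisible by 16, and which are equal to 1.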
  struct InstanceDescriptor
  {
    std::unordered_set<int> divisibleBy16;
    std::unordered_set<int> equalTo1;
  };

  py::class_<InstanceDescriptor>(m, "instance_descriptor")
      .def(py::init<>())
      .def(py::init<std::unordered_set<int>, std::unordered_set<int>>())
      .def_readonly("divisible_by_16", &InstanceDescriptor::divisibleBy16)
      .def_readonly("equal_to_1", &InstanceDescriptor::equalTo1);
}


/*****************************************************************************/
/* Python bindings for triton::ir                                            */
/*****************************************************************************/

void init_triton_ir(py::module &&m) {
  using ret = py::return_value_policy;
  using namespace pybind11::literals;

  py::enum_<ir::load_inst::CACHE_MODIFIER>(m, "CACHE_MODIFIER")
      .value("NONE", ir::load_inst::NONE)
      .value("CA", ir::load_inst::CA)
      .value("CG", ir::load_inst::CG)
      .export_values();

  py::enum_<ir::load_inst::EVICTION_POLICY>(m, "EVICTION_POLICY")
      .value("NORMAL", ir::load_inst::NORMAL)
      .value("EVICT_FIRST", ir::load_inst::EVICT_FIRST)
      .value("EVICT_LAST", ir::load_inst::EVICT_LAST)
      .export_values();

  py::enum_<ir::reduce_inst::op_t>(m, "REDUCE_OP")
      .value("ADD", ir::reduce_inst::ADD)
      .value("FADD", ir::reduce_inst::FADD)
      .value("MIN", ir::reduce_inst::MIN)
      .value("MAX", ir::reduce_inst::MAX)
      .value("UMIN", ir::reduce_inst::UMIN)
      .value("UMAX", ir::reduce_inst::UMAX)
      .value("ARGMIN", ir::reduce_inst::ARGMIN)
      .value("ARGMAX", ir::reduce_inst::ARGMAX)
      .value("ARGUMIN", ir::reduce_inst::ARGUMIN)
      .value("ARGUMAX", ir::reduce_inst::ARGUMAX)
      .value("FMIN", ir::reduce_inst::FMIN)
      .value("FMAX", ir::reduce_inst::FMAX)
      .value("ARGFMIN", ir::reduce_inst::ARGFMIN)
      .value("ARGFMAX", ir::reduce_inst::ARGFMAX)
      .value("XOR", ir::reduce_inst::XOR);

  py::enum_<ir::atomic_rmw_op_t>(m, "ATOMIC_OP")
      .value("ADD", ir::atomic_rmw_op_t::Add)
      .value("FADD", ir::atomic_rmw_op_t::FAdd)
      .value("AND", ir::atomic_rmw_op_t::And)
      .value("OR", ir::atomic_rmw_op_t::Or)
      .value("XOR", ir::atomic_rmw_op_t::Xor)
      .value("XCHG", ir::atomic_rmw_op_t::Xchg)
      .value("MAX", ir::atomic_rmw_op_t::Max)
      .value("MIN", ir::atomic_rmw_op_t::Min)
      .value("UMIN", ir::atomic_rmw_op_t::UMin)
      .value("UMAX", ir::atomic_rmw_op_t::UMax);

  py::class_<ir::context>(m, "context")
      .def(py::init<>());

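  // base class for all IR values; the metadata setters only apply to values
  // that are actually instructions, hence the dynamic_casts below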
  py::class_<ir::value>(m, "value")
      .def("multiple_of", [](ir::value *self, std::vector<unsigned> val) {
        if (auto *instr = dynamic_cast<ir::instruction*>(self)) {
          instr->set_metadata(ir::metadata::multiple_of, val);
        } else
          throw std::runtime_error("multiple_of");
      })
      .def("max_contiguous", [](ir::value *self, std::vector<unsigned> val) {
        if (auto *instr = dynamic_cast<ir::instruction*>(self)) {
          instr->set_metadata(ir::metadata::max_contiguous, val);
        } else
          throw std::runtime_error("max_contiguous");
      })
      .def("set_fdiv_ieee_rounding", [](ir::value *self, bool val) {
        if (auto *instr = dynamic_cast<ir::binary_operator*>(self))
          instr->set_fdiv_ieee_rounding(val);
        else
          throw std::runtime_error("set_fdiv_ieee_rounding");
      })
      .def("is_phi", [](ir::value *self) {
        return dynamic_cast<ir::phi_node*>(self) != nullptr;
      })
      .def("ops", [](ir::value *self) {
        if (auto *instr = dynamic_cast<ir::instruction*>(self)) {
          return instr->ops();
        }
        throw std::runtime_error("cannot use ops()");
      })
      .def("replace_all_uses_with", &ir::value::replace_all_uses_with)
      .def("erase_from_parent", [](ir::value *self) {
        if (auto *instr = dynamic_cast<ir::instruction*>(self))
          return instr->erase_from_parent();
        throw std::runtime_error("cannot use erase_from_parent");
      })
      .def_property("name", &ir::value::get_name, &ir::value::set_name)
      .def_property_readonly("type", &ir::value::get_type);

  py::class_<ir::user, ir::value>(m, "user");

  py::class_<ir::constant, ir::user>(m, "constant")
      .def("get_null_value", &ir::constant::get_null_value, ret::reference)
      .def("get_all_ones_value", &ir::constant::get_all_ones_value, ret::reference);

  py::class_<ir::undef_value, ir::constant>(m, "undef")
      .def("get", &ir::undef_value::get, ret::reference);

  py::class_<ir::constant_int, ir::constant>(m, "constant_int")
      .def_property_readonly("value", &ir::constant_int::get_value)
      .def("__int__", [](ir::constant_int *self) { return self->get_value(); })
      // __bool__ must return a bool, not an int
      .def("__bool__", [](ir::constant_int *self) { return self->get_value() != 0; });

  py::class_<ir::constant_fp, ir::constant>(m, "constant_float")
      .def_property_readonly("value", &ir::constant_fp::get_value)
      .def("get", [](ir::type* ty, double val) { return ir::constant_fp::get(ty, val); }, ret::reference);

  py::class_<ir::instruction, ir::user>(m, "instruction")
      .def("get_parent", [](ir::instruction *self) {
        return self->get_parent();
      }, ret::reference);
  py::class_<ir::phi_node, ir::instruction>(m, "phi_node")
      .def("add_incoming", &ir::phi_node::add_incoming);

  py::class_<ir::type>(m, "type")
      .def("make_ptr", &ir::pointer_type::get, ret::reference)
      .def("make_function", &ir::function_type::get, ret::reference)
      .def("make_block", &ir::block_type::get, ret::reference)
      .def("get_void", &ir::type::get_void_ty, ret::reference)
      .def("get_fp8", &ir::type::get_fp8_ty, ret::reference)
      .def("get_fp16", &ir::type::get_fp16_ty, ret::reference)
      .def("get_bf16", &ir::type::get_bf16_ty, ret::reference)
      .def("get_fp32", &ir::type::get_fp32_ty, ret::reference)
      .def("get_fp64", &ir::type::get_fp64_ty, ret::reference)
      .def("get_int1", &ir::type::get_int1_ty, ret::reference)
      .def("get_int8", &ir::type::get_int8_ty, ret::reference)
      .def("get_int16", &ir::type::get_int16_ty, ret::reference)
      .def("get_int32", &ir::type::get_int32_ty, ret::reference)
      .def("get_int64", &ir::type::get_int64_ty, ret::reference)
      .def("get_fp_mantissa_width", &ir::type::get_fp_mantissa_width, ret::reference)

      .def("get_block_shapes", &ir::type::get_block_shapes)

      .def("is_ptr", &ir::type::is_pointer_ty)
      .def("is_int", static_cast<bool (ir::type::*)() const>(&ir::type::is_integer_ty))
      .def("is_floating", &ir::type::is_floating_point_ty)
      .def("is_block", &ir::type::is_block_ty)
      .def("is_struct", &ir::type::is_struct_ty)
      .def("is_void", &ir::type::is_void_ty)
      .def("is_bool", &ir::type::is_bool_ty)
      .def("is_fp8", &ir::type::is_fp8_ty)
      .def("is_fp16", &ir::type::is_fp16_ty)
      .def("is_bf16", &ir::type::is_bf16_ty)
      .def("is_fp32", &ir::type::is_fp32_ty)
      .def("is_fp64", &ir::type::is_fp64_ty)
      .def("is_int1", [](ir::type *self) { return self->is_integer_ty(1); })
      .def("is_int8", [](ir::type *self) { return self->is_integer_ty(8); })
      .def("is_int16", [](ir::type *self) { return self->is_integer_ty(16); })
      .def("is_int32", [](ir::type *self) { return self->is_integer_ty(32); })
      .def("is_int64", [](ir::type *self) { return self->is_integer_ty(64); })
      .def("is_int_or_tileint", &ir::type::is_int_or_tileint_ty)

      .def("repr", &ir::type::repr)
      .def_property_readonly("fp_mantissa_width", &ir::type::get_fp_mantissa_width)
      .def_property_readonly("scalar", &ir::type::get_scalar_ty)
      .def_property_readonly("context", &ir::type::get_context, ret::reference)
      .def_property_readonly("int_bitwidth", &ir::type::get_integer_bitwidth)
      .def_property_readonly("primitive_bitwidth", &ir::type::get_primitive_size_in_bits);

  py::class_<ir::pointer_type, ir::type>(m, "pointer_type")
      .def_property_readonly("element", &ir::pointer_type::get_element_ty, ret::reference)
      .def_property_readonly("address_space", &ir::pointer_type::get_pointer_address_space, ret::reference);

  py::class_<ir::function_type, ir::type>(m, "function_type")
      .def_property_readonly("ret_ty", &ir::function_type::get_return_ty)
      .def_property_readonly("arg_tys", [](ir::function_type* self){
        return std::vector<ir::type*>(self->params_begin(), self->params_end());
      });

  py::class_<ir::integer_type, ir::type>(m, "integer_type");

  py::class_<ir::block_type, ir::type>(m, "block_type")
      .def_property_readonly("shape", &ir::block_type::get_shapes)
      .def_property_readonly("numel", &ir::type::get_tile_num_elements);

  py::class_<ir::struct_type, ir::type>(m, "struct_type")
      .def("get", &ir::struct_type::get, ret::reference)
      .def_property_readonly("num_types", &ir::struct_type::get_num_types);

  py::class_<ir::module>(m, "module", py::dynamic_attr())
      .def(py::init<std::string, ir::builder &>())
      .def("has_function", &ir::module::has_function)
      .def("get_function", &ir::module::get_function, ret::reference)
      .def("get_functions", &ir::module::get_function_list, ret::reference)
      .def("get_or_insert_function", &ir::module::get_or_insert_function, ret::reference)
      .def("print", [](ir::module *self) {
        self->print(std::cout);
      })
      .def("reset_ret_ty", &ir::module::reset_ret_ty)
      .def("set_instr_metadata", [](ir::module *self, const std::string &name, ir::value *value) {
        const auto metadatas = self->get_metadatas();
        auto it = metadatas.find(name);
        if (it != metadatas.end())
          if (auto *instr = dynamic_cast<ir::instruction*>(value)) {
            instr->set_metadata(it->second.first, it->second.second);
          }
      })
      .def_property_readonly("builder", &ir::module::get_builder, ret::reference);

  using eattr = ir::attribute_kind_t;
  py::enum_<eattr>(m, "attribute_kind")
      .value("readonly", eattr::readonly)
      .value("writeonly", eattr::writeonly)
      .value("noalias", eattr::noalias)
      .value("aligned", eattr::aligned)
      .value("multiple_of", eattr::multiple_of)
      .value("retune", eattr::retune)
      .value("not_implemented", eattr::not_implemented);

  py::class_<ir::attribute>(m, "attribute")
      .def(py::init<eattr, int>())
      .def_property_readonly("value", &ir::attribute::get_value);

  py::class_<ir::function>(m, "function")
      .def_property_readonly("args", &ir::function::args)
      .def_property_readonly("attrs", &ir::function::attrs)
      .def("set_is_kernel", &ir::function::set_is_kernel)
      .def("add_attr", &ir::function::add_attr)
      .def("has_attr", &ir::function::has_attr)
      .def("get_attrs", &ir::function::get_attributes);

  py::class_<ir::argument, ir::value>(m, "argument")
      .def_property_readonly("parent", &ir::argument::get_parent, ret::reference)
      .def_property_readonly("arg_no", &ir::argument::get_arg_no);

  py::class_<ir::basic_block, ir::value>(m, "basic_block")
      .def("create", &ir::basic_block::create, ret::reference, py::arg(), py::arg(), py::arg() = nullptr)
      .def("get_predecessors", &ir::basic_block::get_predecessors, ret::reference)
      .def("get_first_non_phi", [](ir::basic_block *self) -> ir::instruction* {
        ir::basic_block::iterator it = self->get_first_non_phi();
        if (it == self->end())
          return nullptr;
        return *it;
      }, ret::reference)
      .def_property_readonly("parent", &ir::basic_block::get_parent, ret::reference);

  py::class_<ir::builder::iterator>(m, "bb_iterator");

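  // ir::builder is the main interface for constructing Triton-IR; each
  // create_* method appends an instruction at the current insert point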
  py::class_<ir::builder>(m, "builder", py::dynamic_attr())
      .def(py::init<ir::context &>())
      // getters
      .def_property_readonly("context", &ir::builder::get_context, ret::reference)
      // control flow
      .def("call", &ir::builder::create_call, ret::reference)
      .def("launch", &ir::builder::create_launch, ret::reference)
      .def("br", &ir::builder::create_br, ret::reference)
      .def("cond_br", &ir::builder::create_cond_br, ret::reference)
      .def("ret_void", &ir::builder::create_ret_void, ret::reference)
      .def("ret", &ir::builder::create_ret, ret::reference)
      // insertion block/point, insert points are represented as (*bb, *instr)
      .def("get_insert_block", &ir::builder::get_insert_block, ret::reference)
      .def("set_insert_block", (void (ir::builder::*)(ir::basic_block *))&ir::builder::set_insert_point)
      .def("get_insert_point", [](ir::builder *self) {
        ir::basic_block *bb = self->get_insert_block();
        ir::basic_block::iterator it = self->get_insert_point();
        ir::instruction *instr = it == bb->end() ? nullptr : *it;
        return std::make_pair(bb, instr);
      }, ret::reference)
      .def("set_insert_point", [](ir::builder *self, std::pair<ir::basic_block*, ir::instruction*> pt) {
        ir::basic_block *bb = pt.first;
        ir::instruction *instr = pt.second;
        if (instr) {
          if (bb != instr->get_parent())
            throw std::runtime_error("invalid insertion point, instr not in bb");
          self->set_insert_point(instr);
        } else {
          assert(bb);
          self->set_insert_point(bb);
        }
      })
      // Constants; the unsigned variants reuse the signed constructors,
      // since only the bit pattern matters
      .def("get_int1", &ir::builder::get_int1, ret::reference)
      .def("get_int32", [](ir::builder *self, int32_t v) { return self->get_int32((uint32_t)v); }, ret::reference)
      .def("get_uint32", &ir::builder::get_int32, ret::reference)
      .def("get_int64", [](ir::builder *self, int64_t v) { return self->get_int64((uint64_t)v); }, ret::reference)
      .def("get_uint64", &ir::builder::get_int64, ret::reference)
      .def("get_float16", &ir::builder::get_float16, ret::reference)
      .def("get_float32", &ir::builder::get_float32, ret::reference)
      .def("get_range", &ir::builder::get_range, ret::reference)
      // Types
      .def("get_void_ty", &ir::builder::get_void_ty, ret::reference)
      .def("get_int1_ty", &ir::builder::get_int1_ty, ret::reference)
      .def("get_int8_ty", &ir::builder::get_int8_ty, ret::reference)
      .def("get_int16_ty", &ir::builder::get_int16_ty, ret::reference)
      .def("get_int32_ty", &ir::builder::get_int32_ty, ret::reference)
      .def("get_int64_ty", &ir::builder::get_int64_ty, ret::reference)
      .def("get_fp8_ty", &ir::builder::get_fp8_ty, ret::reference)
      .def("get_half_ty", &ir::builder::get_half_ty, ret::reference)
      .def("get_bf16_ty", &ir::builder::get_bf16_ty, ret::reference)
      .def("get_float_ty", &ir::builder::get_float_ty, ret::reference)
      .def("get_double_ty", &ir::builder::get_double_ty, ret::reference)
      // terminator instructions
      .def("create_br", &ir::builder::create_br, ret::reference)
      .def("create_cond_br", &ir::builder::create_cond_br, ret::reference)
      .def("create_ret_void", &ir::builder::create_ret_void, ret::reference)
      // Dequantize instructions
      .def("create_dequantize", &ir::builder::create_dequantize, ret::reference)
      // Cast instructions
      .def("create_bitcast", &ir::builder::create_bitcast, ret::reference)
      .def("create_cast", &ir::builder::create_cast, ret::reference)
      .def("create_ptr_to_int", &ir::builder::create_ptr_to_int, ret::reference)
      .def("create_si_to_fp", &ir::builder::create_si_to_fp, ret::reference)
      .def("create_ui_to_fp", &ir::builder::create_ui_to_fp, ret::reference)
      .def("create_fp_to_si", &ir::builder::create_fp_to_si, ret::reference)
      .def("create_fp_to_ui", &ir::builder::create_fp_to_ui, ret::reference)
      .def("create_fp_ext", &ir::builder::create_fp_ext, ret::reference)
      .def("create_fp_trunc", &ir::builder::create_fp_trunc, ret::reference)
      .def("create_int_cast", &ir::builder::create_int_cast, ret::reference)
      .def("create_downcast", &ir::builder::create_downcast, ret::reference)
      .def("create_int_to_ptr", &ir::builder::create_int_to_ptr, ret::reference)
      // phi
      .def("create_phi", &ir::builder::create_phi, ret::reference)
      // Binary instructions
      .def("create_insert_nuwnswb_binop", &ir::builder::create_insert_nuwnswb_binop, ret::reference)
      .def("create_fmul", &ir::builder::create_fmul, ret::reference)
      .def("create_fdiv", &ir::builder::create_fdiv, ret::reference)
      .def("create_frem", &ir::builder::create_frem, ret::reference)
      .def("create_fadd", &ir::builder::create_fadd, ret::reference)
      .def("create_fsub", &ir::builder::create_fsub, ret::reference)
      .def("create_mul", &ir::builder::create_mul, ret::reference,
           py::arg("lhs"), py::arg("rhs"),
           py::arg("has_nuw")=false, py::arg("has_nsw")=false)
      .def("create_sdiv", &ir::builder::create_sdiv, ret::reference)
      .def("create_udiv", &ir::builder::create_udiv, ret::reference)
      .def("create_srem", &ir::builder::create_srem, ret::reference)
      .def("create_urem", &ir::builder::create_urem, ret::reference)
      .def("create_add", &ir::builder::create_add, ret::reference,
           py::arg("lhs"), py::arg("rhs"),
           py::arg("has_nuw")=false, py::arg("has_nsw")=false)
      .def("create_sub", &ir::builder::create_sub, ret::reference,
           py::arg("lhs"), py::arg("rhs"),
           py::arg("has_nuw")=false, py::arg("has_nsw")=false)
      .def("create_shl", &ir::builder::create_shl, ret::reference,
           py::arg("lhs"), py::arg("rhs"),
           py::arg("has_nuw")=false, py::arg("has_nsw")=false)
      .def("create_lshr", &ir::builder::create_lshr, ret::reference,
           py::arg("lhs"), py::arg("rhs"),
           py::arg("has_nuw")=false, py::arg("has_nsw")=false)
      .def("create_ashr", &ir::builder::create_ashr, ret::reference,
           py::arg("lhs"), py::arg("rhs"),
           py::arg("has_nuw")=false, py::arg("has_nsw")=false)
      // GEP
      .def("create_gep", &ir::builder::create_gep, ret::reference)
      // Comparison (int)
      .def("create_icmp", &ir::builder::create_icmp, ret::reference)
      .def("create_icmpSLE", &ir::builder::create_icmpSLE, ret::reference)
      .def("create_icmpSLT", &ir::builder::create_icmpSLT, ret::reference)
      .def("create_icmpSGE", &ir::builder::create_icmpSGE, ret::reference)
      .def("create_icmpSGT", &ir::builder::create_icmpSGT, ret::reference)
      .def("create_icmpULE", &ir::builder::create_icmpULE, ret::reference)
      .def("create_icmpULT", &ir::builder::create_icmpULT, ret::reference)
      .def("create_icmpUGE", &ir::builder::create_icmpUGE, ret::reference)
      .def("create_icmpUGT", &ir::builder::create_icmpUGT, ret::reference)
      .def("create_icmpEQ", &ir::builder::create_icmpEQ, ret::reference)
      .def("create_icmpNE", &ir::builder::create_icmpNE, ret::reference)
      // Comparison (float)
      .def("create_fcmp", &ir::builder::create_fcmp, ret::reference)
      .def("create_fcmpOLT", &ir::builder::create_fcmpOLT, ret::reference)
      .def("create_fcmpOGT", &ir::builder::create_fcmpOGT, ret::reference)
      .def("create_fcmpOLE", &ir::builder::create_fcmpOLE, ret::reference)
      .def("create_fcmpOGE", &ir::builder::create_fcmpOGE, ret::reference)
      .def("create_fcmpOEQ", &ir::builder::create_fcmpOEQ, ret::reference)
      .def("create_fcmpONE", &ir::builder::create_fcmpONE, ret::reference)
      .def("create_fcmpULT", &ir::builder::create_fcmpULT, ret::reference)
      .def("create_fcmpUGT", &ir::builder::create_fcmpUGT, ret::reference)
      .def("create_fcmpULE", &ir::builder::create_fcmpULE, ret::reference)
      .def("create_fcmpUGE", &ir::builder::create_fcmpUGE, ret::reference)
      .def("create_fcmpUEQ", &ir::builder::create_fcmpUEQ, ret::reference)
      .def("create_fcmpUNE", &ir::builder::create_fcmpUNE, ret::reference)
      // Logical
      .def("create_and", &ir::builder::create_and, ret::reference)
      .def("create_xor", &ir::builder::create_xor, ret::reference)
      .def("create_or", &ir::builder::create_or, ret::reference)
      // Input/Output
      .def("create_load", &ir::builder::create_load, ret::reference)
      .def("create_store", &ir::builder::create_store, ret::reference)
      .def("create_masked_load", &ir::builder::create_masked_load, ret::reference)
      .def("create_masked_store", &ir::builder::create_masked_store, ret::reference)
      // Block instruction
      .def("create_splat", &ir::builder::create_splat, ret::reference)
      .def("create_reshape", &ir::builder::create_reshape, ret::reference)
      .def("create_cat", &ir::builder::create_cat, ret::reference)
      .def("create_broadcast", &ir::builder::create_broadcast, ret::reference)
      // atomic
      .def("create_atomic_cas", &ir::builder::create_atomic_cas, ret::reference)
      .def("create_atomic_rmw", &ir::builder::create_atomic_rmw, ret::reference)
      // Utilities
      .def("create_clock", &ir::builder::create_clock, ret::reference)
      .def("create_globaltimer", &ir::builder::create_globaltimer, ret::reference)
      // Extern instruction
      .def("create_extern_elementwise", &ir::builder::create_extern_elementwise, ret::reference)
      // Built-in instruction
      .def("create_get_program_id", &ir::builder::create_get_program_id, ret::reference)
      .def("create_get_num_programs", &ir::builder::create_get_num_programs, ret::reference)
      .def("create_exp", &ir::builder::create_exp, ret::reference)
      .def("create_cos", &ir::builder::create_cos, ret::reference)
      .def("create_sin", &ir::builder::create_sin, ret::reference)
      .def("create_log", &ir::builder::create_log, ret::reference)
      .def("create_dot", &ir::builder::create_dot, ret::reference)
      .def("create_trans", &ir::builder::create_trans, ret::reference)
      .def("create_sqrt", &ir::builder::create_sqrt, ret::reference)
      .def("create_reduce", &ir::builder::create_reduce, ret::reference)
      .def("create_select", &ir::builder::create_select, ret::reference)
      // struct
      .def("insert_value", &ir::builder::create_insert_value, ret::reference)
      .def("extract_value", &ir::builder::create_extract_value, ret::reference)
      // Intrinsics
      // These have no place in the IR, and hopefully they can be removed at some point
      .def("create_umulhi", &ir::builder::create_umulhi, ret::reference)
      .def("create_copy_to_shared", &ir::builder::create_copy_to_shared, ret::reference)
      .def("create_masked_load_async", &ir::builder::create_masked_load_async, ret::reference)
      .def("create_copy_from_shared", &ir::builder::create_copy_from_shared, ret::reference)
      .def("create_barrier", &ir::builder::create_barrier, ret::reference)
      .def("create_async_wait", &ir::builder::create_async_wait, ret::reference)
      .def("create_prefetch_s", &ir::builder::create_prefetch_s, ret::reference);
}

void init_triton(py::module &m) {
  py::module subm = m.def_submodule("triton");
  init_triton_codegen(std::move(subm.def_submodule("code_gen")));
  init_triton_runtime(std::move(subm.def_submodule("runtime")));
  init_triton_ir(std::move(subm.def_submodule("ir")));
}