1#pragma once
2
3#ifndef _TRITON_IR_INSTRUCTIONS_H_
4#define _TRITON_IR_INSTRUCTIONS_H_
5
#include <cstdint>
#include <iosfwd>
#include <map>
#include <set>
#include <string>
#include <utility>
#include <vector>

#include "triton/ir/enums.h"
#include "triton/ir/constant.h"
#include "triton/ir/value.h"
#include "triton/ir/type.h"
#include "triton/ir/metadata.h"
#include "triton/ir/visitor.h"
14
// Expands to the `clone_impl` member for instruction class `name`: a
// heap-allocated copy produced by that class's copy constructor.
#define _TRITON_DEFINE_CLONE(name)                                             \
  ir::instruction *clone_impl() const { return new name(*this); }

// Expands to an `accept` member that dispatches to the visitor method
// `visit_<name>`, passing `this`.
#define _TRITON_DEFINE_ACCEPT(name)                                            \
  void accept(visitor *v) { v->visit_##name(this); }
20
21namespace triton{
22namespace ir{
23
// Forward declarations of IR classes referenced by pointer in this header.
// `function` is needed by call_inst / launch_inst (`ir::function*`), which
// previously relied on it being declared by some transitively included header.
class basic_block;
class constant;
class constant_int;
class context;
class function;
class make_range;
class visitor;

//===----------------------------------------------------------------------===//
// instruction classes
//===----------------------------------------------------------------------===//

class result_reference;
36
37
38class instruction: public user{
39public:
40 virtual std::string repr_impl() const = 0;
41
42private:
43 virtual ir::instruction* clone_impl() const = 0;
44
45protected:
46 // constructors
47 instruction(type *ty, value_id_t ity, unsigned num_ops,
48 const std::string &name = "", instruction *next = nullptr);
49
50public:
51 // parent
52 void set_parent(basic_block *block) { parent_ = block; }
53 const basic_block *get_parent() const { return parent_; }
54 basic_block *get_parent() { return parent_; }
55 void erase_from_parent();
56 // helpers
57 bool has_tile_result_or_op();
58 // repr
59 std::string repr() const { return repr_impl(); }
60 // metadata
61 void set_metadata(ir::metadata::kind_t kind,
62 std::vector<unsigned> value) { metadatas_[kind] = value;}
63 std::vector<unsigned> get_metadata(ir::metadata::kind_t kind) { return metadatas_[kind];}
64 // cloning
65 ir::instruction* clone() {
66 ir::instruction* res = clone_impl();
67// for(auto it = op_begin(); it != op_end(); it++)
68// (*it)->add_use(res);
69 res->parent_ = nullptr;
70 res->users_.clear();
71 return res;
72 }
73 // instruction id
74 value_id_t get_id() const { return id_; }
75
76 void print(std::ostream &os);
77
78private:
79 basic_block *parent_;
80 std::map<ir::metadata::kind_t, std::vector<unsigned>> metadatas_;
81 value_id_t id_;
82};
83
84//===----------------------------------------------------------------------===//
85// call_inst classes
86//===----------------------------------------------------------------------===//
87
88class call_inst: public instruction {
89private:
90 std::string repr_impl() const;
91 call_inst(ir::function* fn, const std::vector<ir::value*>& values, const std::string& name, instruction* next);
92
93public:
94 static call_inst* create(ir::function* fn, const std::vector<ir::value*>& values, const std::string &name = "", instruction *next = nullptr);
95 ir::function* get_fn() { return fn_; }
96
97 _TRITON_DEFINE_CLONE(call_inst)
98 _TRITON_DEFINE_ACCEPT(call_inst)
99
100private:
101 ir::function* fn_;
102};
103
104class launch_inst: public instruction {
105private:
106 std::string repr_impl() const { return "launch"; }
107 launch_inst(ir::function* fn, const std::vector<ir::value*>& values, const std::vector<ir::value*>& grid, ir::value* num_warps,
108 const std::string &name = "", instruction *next = nullptr);
109
110public:
111 static launch_inst* create(ir::function* fn, const std::vector<ir::value*>& values, const std::vector<ir::value*>& grid, ir::value* num_warps,
112 const std::string& name = "", instruction* next = nullptr);
113
114 ir::function* get_fn();
115 std::vector<ir::value*> get_values();
116 std::vector<ir::value*> get_grid();
117 ir::value* get_num_warps();
118
119
120 _TRITON_DEFINE_CLONE(launch_inst)
121 _TRITON_DEFINE_ACCEPT(launch_inst)
122
123private:
124 unsigned val_begin;
125 unsigned val_end;
126 unsigned grid_begin;
127 unsigned grid_end;
128};
129
130//===----------------------------------------------------------------------===//
131// phi_node classes
132//===----------------------------------------------------------------------===//
133
134class phi_node: public instruction {
135private:
136 phi_node(type *ty, unsigned num_reserved, const std::string &name, instruction *next);
137 std::string repr_impl() const { return "phi"; }
138
139public:
140 void set_incoming_value(unsigned i, value *v);
141 void set_incoming_block(unsigned i, basic_block *block);
142 value *get_value_for_block(basic_block *block);
143 value *get_incoming_value(unsigned i) { return get_operand(i); }
144 basic_block *get_incoming_block(unsigned i) { return blocks_[i]; }
145 unsigned get_num_incoming() { return get_num_operands(); }
146 void add_incoming(value *v, basic_block *block);
147
148 // Type
149 void set_type(type *ty) { ty_ = ty; }
150
151 // Factory methods
152 static phi_node* create(type *ty, unsigned num_reserved, const std::string &name = "", instruction *next = nullptr);
153
154 _TRITON_DEFINE_CLONE(phi_node)
155 _TRITON_DEFINE_ACCEPT(phi_node)
156
157private:
158 unsigned num_reserved_;
159 std::vector<basic_block*> blocks_;
160};
161
162//===----------------------------------------------------------------------===//
163// binary_operator classes
164//===----------------------------------------------------------------------===//
165
166class binary_operator: public instruction {
167public:
168 typedef binary_op_t op_t;
169
170private:
171 std::string repr_impl() const;
172
173protected:
174 // Constructors
175 binary_operator(binary_op_t op, value *lhs, value *rhs, type *ty, const std::string &name, instruction *next);
176
177public:
178 // Get operand
179 binary_op_t get_op() const { return op_; }
180
181 // Bool
182 bool is_terminator() const;
183 bool is_binary_op() const;
184 bool is_int_div_rem() const;
185 bool is_shift() const;
186 bool is_cast() const;
187 bool is_int_mult() const;
188 bool is_int_add_sub() const;
189 bool is_int_div() const;
190 bool is_int_rem() const;
191 bool is_shl() const;
192 bool is_shr() const;
193
194 // Approx
195 void set_fdiv_ieee_rounding(bool rnd) { fdiv_ieee_rnd_ = rnd; }
196 bool get_fdiv_ieee_rounding() { return fdiv_ieee_rnd_; }
197
198 // Wraps
199 void set_has_no_unsigned_wrap(bool b = true) { has_no_unsigned_wrap_ = b; }
200 void set_has_no_signed_wrap(bool b = true) { has_no_signed_wrap_ = b; }
201
202 // Factory methods
203 static binary_operator *create(binary_op_t op, value *lhs, value *rhs,
204 const std::string &name = "", instruction *next = nullptr);
205// static binary_operator *create_fneg(value *arg, const std::string &name = "", instruction *next = nullptr);
206// static binary_operator *create_neg(value *arg, const std::string &name = "", instruction *next = nullptr);
207// static binary_operator *create_not(value *arg, const std::string &name = "", instruction *next = nullptr);
208
209 _TRITON_DEFINE_CLONE(binary_operator)
210 _TRITON_DEFINE_ACCEPT(binary_operator)
211
212public:
213 binary_op_t op_;
214 bool has_no_unsigned_wrap_;
215 bool has_no_signed_wrap_;
216
217 bool fdiv_ieee_rnd_;
218};
219
220
221//===----------------------------------------------------------------------===//
222// cmp_inst classes
223//===----------------------------------------------------------------------===//
224
225class cmp_inst: public instruction{
226public:
227 typedef cmp_pred_t pred_t;
228
229private:
230 std::string repr_impl() const;
231
232protected:
233 cmp_inst(type *ty, value_id_t id, cmp_pred_t pred,
234 value *lhs, value *rhs, const std::string &name, instruction *next);
235 static bool is_fp_predicate(cmp_pred_t pred);
236 static bool is_int_predicate(cmp_pred_t pred);
237 static type* make_cmp_result_type(type *ty);
238
239public:
240 cmp_pred_t get_pred() const { return pred_; }
241
242private:
243 cmp_pred_t pred_;
244};
245
246class icmp_inst: public cmp_inst {
247 icmp_inst(type *ty, cmp_pred_t pred,
248 value *lhs, value *rhs, const std::string &name, instruction *next);
249
250public:
251 static icmp_inst* create(cmp_pred_t pred, value *lhs, value *rhs,
252 const std::string &name = "", instruction *next = nullptr);
253 _TRITON_DEFINE_CLONE(icmp_inst)
254 _TRITON_DEFINE_ACCEPT(icmp_inst)
255};
256
257class fcmp_inst: public cmp_inst {
258 fcmp_inst(type *ty, cmp_pred_t pred,
259 value *lhs, value *rhs, const std::string &name, instruction *next);
260
261public:
262 static fcmp_inst* create(cmp_pred_t pred, value *lhs, value *rhs,
263 const std::string &name = "", instruction *next = nullptr);
264 _TRITON_DEFINE_CLONE(fcmp_inst)
265 _TRITON_DEFINE_ACCEPT(fcmp_inst)
266};
267
268//===----------------------------------------------------------------------===//
269// unary_inst classes
270//===----------------------------------------------------------------------===//
271
272class unary_inst: public instruction {
273protected:
274 unary_inst(type *ty, value_id_t id, value *v, const std::string &name, instruction *next);
275};
276
277//===----------------------------------------------------------------------===//
278// dequantize_inst classes
279//===----------------------------------------------------------------------===//
280
281class dequantize_inst: public instruction{
282private:
283 std::string repr_impl() const override { return "dequantize"; }
284
285protected:
286 dequantize_inst(type *ty, value *v, value *scale, value *shift, const std::string &name, instruction *next);
287
288public:
289 static dequantize_inst *create(value *arg, value *scale, value *shift, type *ty,
290 const std::string &name = "", instruction *next = nullptr);
291
292 _TRITON_DEFINE_CLONE(dequantize_inst)
293 _TRITON_DEFINE_ACCEPT(dequantize_inst)
294};
295
296//===----------------------------------------------------------------------===//
297// cast_inst classes
298//===----------------------------------------------------------------------===//
299
300class cast_inst: public unary_inst{
301private:
302 std::string repr_impl() const;
303
304protected:
305 cast_inst(type *ty, value_id_t id, value *v, const std::string &name, instruction *next, cast_op_t op)
306 : unary_inst(ty, id, v, name, next), op_(op) { }
307
308private:
309 static bool is_valid(cast_op_t op, value *arg, type *ty);
310
311public:
312 // accessors
313 cast_op_t get_op() const { return op_; }
314
315 // factory methods
316 static cast_inst *create(cast_op_t op, value *arg, type *ty,
317 const std::string &name = "", instruction *next = nullptr);
318 static cast_inst *create_integer_cast(value *arg, type *ty, bool is_signed,
319 const std::string &name = "", instruction *next = nullptr);
320
321 _TRITON_DEFINE_ACCEPT(cast_inst)
322
323private:
324 cast_op_t op_;
325};
326
// Declares a concrete cast instruction class `name` with value id `id` and
// cast opcode `op`. Its constructor is private; cast_inst is a friend and can
// construct it.
// The constructor's string parameter is deliberately spelled `inst_name`: a
// parameter spelled like the macro argument `name` would be rewritten to the
// class name during expansion, shadowing the injected-class-name.
#define TRITON_IR_DECLARE_CAST_INST_SIMPL(name, id, op)                        \
class name : public cast_inst {                                               \
  _TRITON_DEFINE_CLONE(name)                                                  \
  friend class cast_inst;                                                     \
  name(type *ty, value *v, const std::string &inst_name, instruction *next)   \
    : cast_inst(ty, id, v, inst_name, next, op){ }                            \
};
334
335TRITON_IR_DECLARE_CAST_INST_SIMPL(trunc_inst, INST_CAST_TRUNC, cast_op_t::Trunc)
336TRITON_IR_DECLARE_CAST_INST_SIMPL(z_ext_inst, INST_CAST_ZEXT, cast_op_t::ZExt)
337TRITON_IR_DECLARE_CAST_INST_SIMPL(s_ext_inst, INST_CAST_SEXT, cast_op_t::SExt)
338TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_trunc_inst, INST_CAST_FP_TRUNC, cast_op_t::FPTrunc)
339TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_ext_inst, INST_CAST_FP_EXT, cast_op_t::FPExt)
340TRITON_IR_DECLARE_CAST_INST_SIMPL(ui_to_fp_inst, INST_CAST_UI_TO_FP, cast_op_t::UIToFP)
341TRITON_IR_DECLARE_CAST_INST_SIMPL(si_to_fp_inst, INST_CAST_SI_TO_FP, cast_op_t::SIToFP)
342TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_to_ui_inst, INST_CAST_FP_TO_UI, cast_op_t::FPToUI)
343TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_to_si_inst, INST_CAST_FP_TO_SI, cast_op_t::FPToSI)
344TRITON_IR_DECLARE_CAST_INST_SIMPL(ptr_to_int_inst, INST_CAST_PTR_TO_INT, cast_op_t::PtrToInt)
345TRITON_IR_DECLARE_CAST_INST_SIMPL(int_to_ptr_inst, INST_CAST_INT_TO_PTR, cast_op_t::IntToPtr)
346TRITON_IR_DECLARE_CAST_INST_SIMPL(bit_cast_inst, INST_CAST_BIT_CAST, cast_op_t::BitCast)
347TRITON_IR_DECLARE_CAST_INST_SIMPL(addr_space_cast_inst, INST_CAST_ADDR_SPACE_CAST, cast_op_t::AddrSpaceCast)
348
349//===----------------------------------------------------------------------===//
350// terminator_inst classes
351//===----------------------------------------------------------------------===//
352
353class terminator_inst: public instruction{
354 using instruction::instruction;
355};
356
357// return instruction
358class return_inst: public terminator_inst {
359private:
360 std::string repr_impl() const { return "ret"; }
361 return_inst(context &ctx, value *ret_val, instruction *next);
362
363public:
364 // accessors
365 value *get_return_value()
366 { return get_num_operands() ? get_operand(0) : nullptr; }
367
368 unsigned get_num_successors() const { return 0; }
369
370 // factory methods
371 static return_inst* create(context &ctx, value *ret_val = nullptr, instruction *next = nullptr);
372
373 _TRITON_DEFINE_CLONE(return_inst)
374 _TRITON_DEFINE_ACCEPT(return_inst)
375};
376
377// base branch instruction
378class branch_inst: public terminator_inst{
379private:
380 std::string repr_impl() const { return "br"; }
381
382protected:
383 using terminator_inst::terminator_inst;
384
385public:
386 static branch_inst* create(basic_block *dest,
387 instruction *next = nullptr);
388 static branch_inst* create(value *cond, basic_block *if_dest, basic_block *else_dest,
389 instruction *next = nullptr);
390};
391
392// conditional branch
393class cond_branch_inst: public branch_inst {
394private:
395 friend class branch_inst;
396 cond_branch_inst(basic_block *if_dst, basic_block *else_dst, value *cond, instruction *next);
397
398public:
399 basic_block *get_true_dest() { return (basic_block*)get_operand(0); }
400 basic_block *get_false_dest() { return (basic_block*)get_operand(1); }
401 value *get_cond() { return get_operand(2); }
402 _TRITON_DEFINE_CLONE(cond_branch_inst)
403 _TRITON_DEFINE_ACCEPT(cond_branch_inst)
404};
405
406// unconditional branch
407class uncond_branch_inst: public branch_inst {
408private:
409 friend class branch_inst;
410 uncond_branch_inst(basic_block *dst, instruction *next);
411
412public:
413 basic_block *get_dest() { return (basic_block*)get_operand(0); }
414 _TRITON_DEFINE_CLONE(uncond_branch_inst)
415 _TRITON_DEFINE_ACCEPT(uncond_branch_inst)
416};
417
418
419//===----------------------------------------------------------------------===//
420// getelementptr_inst classes
421//===----------------------------------------------------------------------===//
422
423class getelementptr_inst: public instruction {
424private:
425 std::string repr_impl() const { return "getelementptr"; }
426 getelementptr_inst(type *pointee_ty, value *ptr, const std::vector<value*> &idx, const std::string &name, instruction *next);
427
428private:
429 static type *get_return_type(type *ty, value *ptr, const std::vector<value*> &idx);
430 static type *get_indexed_type_impl(type *ty, const std::vector<value *> &idx);
431 static type *get_indexed_type(type *ty, const std::vector<value*> &idx);
432
433public:
434 // accessors
435 type *get_source_elt_ty() { return source_elt_ty; }
436 op_iterator idx_begin() { return op_begin() + 1; }
437 op_iterator idx_end() { return op_end(); }
438 value *get_pointer_operand() { return *op_begin(); }
439
440 // factory methods
441 static getelementptr_inst* create(value *ptr, const std::vector<value*> &idx,
442 const std::string &name = "", instruction *next = nullptr);
443 _TRITON_DEFINE_CLONE(getelementptr_inst)
444 _TRITON_DEFINE_ACCEPT(getelementptr_inst)
445
446private:
447 type *source_elt_ty;
448 type *res_elt_ty;
449};
450
451//===----------------------------------------------------------------------===//
452// load_inst/store_inst classes
453//===----------------------------------------------------------------------===//
454
455class io_inst: public instruction {
456public:
457
458 enum EVICTION_POLICY : uint32_t {
459 NORMAL=0,
460 EVICT_FIRST,
461 EVICT_LAST,
462 };
463
464protected:
465 io_inst(type *ty, value_id_t id, unsigned num_ops, EVICTION_POLICY eviction,
466 const std::string &name = "", instruction *next = nullptr);
467
468 std::string get_eviction_policy_repr() const {
469 if (eviction_ == EVICT_FIRST) return ".L1::evict_first";
470 if (eviction_ == EVICT_LAST) return ".L2::evict_last";
471 return "";
472 }
473
474public:
475 // accessors
476 value *get_pointer_operand() { return get_operand(0); }
477 EVICTION_POLICY get_eviction_policy() const { return eviction_; }
478
479protected:
480 EVICTION_POLICY eviction_;
481};
482
483// load
484class load_inst: public io_inst {
485public:
486 enum CACHE_MODIFIER : uint32_t {
487 NONE=0,
488 CA,
489 CG,
490 };
491
492
493 CACHE_MODIFIER get_cache_modifier() const { return cache_; }
494 bool get_is_volatile() const { return is_volatile_; }
495
496protected:
497 load_inst(value *ptr, value_id_t id, unsigned num_ops, CACHE_MODIFIER cache, EVICTION_POLICY eviction,
498 bool is_volatile,
499 const std::string &name = "", instruction *next = nullptr);
500 std::string get_cache_modifier_repr() const {
501 if (cache_ == CA) return ".ca";
502 if (cache_ == CG) return ".cg";
503 return "";
504 }
505 CACHE_MODIFIER cache_;
506
507 std::string get_volatile_repr() {
508 return is_volatile_ ? ".volatile" : "";
509 }
510 bool is_volatile_;
511
512private:
513 static type *get_pointee_type(type *ty);
514};
515
516// unmasked load
517class unmasked_load_inst: public load_inst {
518private:
519 std::string repr_impl() const { return "unmasked_load" + get_cache_modifier_repr(); }
520 unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next);
521
522public:
523 static unmasked_load_inst* create(value *ptr,
524 CACHE_MODIFIER cache, EVICTION_POLICY eviction,
525 bool is_volatile,
526 const std::string &name = "",
527 instruction *next = nullptr);
528 _TRITON_DEFINE_CLONE(unmasked_load_inst)
529 _TRITON_DEFINE_ACCEPT(unmasked_load_inst)
530};
531
532// masked load
533class masked_load_inst: public load_inst {
534private:
535 std::string repr_impl() const { return "masked_load" + get_cache_modifier_repr(); }
536 masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile,
537 const std::string &name, instruction *next);
538
539public:
540 // accessors
541 value *get_mask_operand() { return get_operand(1); }
542 value *get_false_value_operand() { return get_operand(2); }
543 // factory method
544 static masked_load_inst* create(value *ptr, value *mask, value *false_value,
545 CACHE_MODIFIER cache, EVICTION_POLICY eviction,
546 bool is_volatile,
547 const std::string &name = "",
548 instruction *next = nullptr);
549 _TRITON_DEFINE_CLONE(masked_load_inst)
550 _TRITON_DEFINE_ACCEPT(masked_load_inst)
551};
552
553// masked load async
554class masked_load_async_inst: public load_inst {
555private:
556 std::string repr_impl() const { return "masked_load_async" + get_cache_modifier_repr(); }
557 masked_load_async_inst(value *ptr, value *mask, value *false_value,
558 CACHE_MODIFIER cache, EVICTION_POLICY eviction,
559 const std::string &name, instruction *next);
560
561public:
562 // accessors
563 value *get_mask_operand() { return get_operand(1); }
564 value *get_false_value_operand() { return get_operand(2); }
565 // factory method
566 static masked_load_async_inst* create(value *ptr, value *mask, value *false_value,
567 load_inst::CACHE_MODIFIER cache,
568 EVICTION_POLICY eviction,
569 const std::string &name = "",
570 instruction *next = nullptr);
571 _TRITON_DEFINE_CLONE(masked_load_async_inst)
572 _TRITON_DEFINE_ACCEPT(masked_load_async_inst)
573};
574
575
576
577// store
578class store_inst: public io_inst {
579protected:
580 store_inst(value *ptr, value_id_t id, unsigned num_ops, EVICTION_POLICY eviction,
581 const std::string &name = "", instruction *next = nullptr);
582
583public:
584 value *get_value_operand() { return get_operand(1); }
585};
586
587// unmasked_store
588class unmasked_store_inst: public store_inst{
589private:
590 std::string repr_impl() const { return "unmasked_store"; }
591 unmasked_store_inst(value *ptr, value *v, EVICTION_POLICY eviction, const std::string &name, instruction *next);
592
593public:
594 // factory method
595 static unmasked_store_inst* create(value* ptr, value *v, EVICTION_POLICY eviction,
596 const std::string &name = "",
597 instruction *next = nullptr);
598 _TRITON_DEFINE_CLONE(unmasked_store_inst)
599 _TRITON_DEFINE_ACCEPT(unmasked_store_inst)
600};
601
602class masked_store_inst: public store_inst{
603private:
604 std::string repr_impl() const { return "masked_store"; }
605 masked_store_inst(value *ptr, value *v, value *mask, EVICTION_POLICY eviction,
606 const std::string &name, instruction *next);
607
608public:
609 // accessors
610 value *get_mask_operand() { return get_operand(2); }
611 // factory method
612 static masked_store_inst* create(value *ptr, value *v, value *mask, EVICTION_POLICY eviction,
613 const std::string &name = "",
614 instruction *next = nullptr);
615 _TRITON_DEFINE_CLONE(masked_store_inst)
616 _TRITON_DEFINE_ACCEPT(masked_store_inst)
617};
618
619//===----------------------------------------------------------------------===//
620// struct classes
621//===----------------------------------------------------------------------===//
622
623// insert_value
624
625class insert_value_inst: public instruction {
626private:
627 std::string repr_impl() const { return "insertvalue"; }
628 insert_value_inst(value *val, value *elt, size_t idx, const std::string &name, instruction *next);
629
630public:
631 static insert_value_inst* create(value *val, value* elt, size_t idx, const std::string &name = "", instruction *next = nullptr);
632 size_t get_idx() { return idx_; }
633 _TRITON_DEFINE_CLONE(insert_value_inst)
634 _TRITON_DEFINE_ACCEPT(insert_value_inst)
635
636private:
637 size_t idx_;
638};
639
640// extract_value
641
642class extract_value_inst: public instruction {
643private:
644 std::string repr_impl() const { return "extractvalue"; }
645 extract_value_inst(value *val, size_t idx, const std::string &name, instruction *next);
646
647public:
648 static extract_value_inst* create(value *val, size_t idx, const std::string &name = "", instruction *next = nullptr);
649 size_t get_idx() { return idx_; }
650 _TRITON_DEFINE_CLONE(extract_value_inst)
651 _TRITON_DEFINE_ACCEPT(extract_value_inst)
652
653private:
654 size_t idx_;
655};
656
657//===----------------------------------------------------------------------===//
658// retile_inst classes
659//===----------------------------------------------------------------------===//
660
661// cat
662
663class cat_inst: public instruction {
664private:
665 std::string repr_impl() const { return "cat"; }
666 cat_inst(value *x, value *y, const std::string &name, instruction *next);
667
668public:
669 static instruction* create(value *lhs, value *rhs,
670 const std::string &name = "",
671 instruction *next = nullptr);
672 _TRITON_DEFINE_CLONE(cat_inst)
673 _TRITON_DEFINE_ACCEPT(cat_inst)
674};
675
676// retile
677
678class retile_inst: public unary_inst {
679protected:
680 retile_inst(value *arg, value_id_t id, const type::block_shapes_t &shapes, const std::string &name, instruction *next);
681};
682
683// reshape
684
685class reshape_inst: public retile_inst {
686private:
687 using retile_inst::retile_inst;
688 std::string repr_impl() const { return "reshape"; }
689
690public:
691 static instruction* create(value *arg, const type::block_shapes_t &shape_suffix,
692 const std::string &name = "", instruction *next = nullptr);
693 _TRITON_DEFINE_CLONE(reshape_inst)
694 _TRITON_DEFINE_ACCEPT(reshape_inst)
695};
696
697// splat
698
699class splat_inst: public retile_inst {
700private:
701 using retile_inst::retile_inst;
702 std::string repr_impl() const { return "splat"; }
703
704public:
705 static instruction* create(value *arg, const type::block_shapes_t &shape_suffix,
706 const std::string &name = "", instruction *next = nullptr);
707 _TRITON_DEFINE_CLONE(splat_inst)
708 _TRITON_DEFINE_ACCEPT(splat_inst)
709};
710
711// broadcast
712
713class broadcast_inst: public retile_inst {
714private:
715 using retile_inst::retile_inst;
716 std::string repr_impl() const { return "broadcast"; }
717
718public:
719 static instruction* create(value *arg, const type::block_shapes_t &shape_suffix,
720 const std::string &name = "", instruction *next = nullptr);
721 _TRITON_DEFINE_CLONE(broadcast_inst)
722 _TRITON_DEFINE_ACCEPT(broadcast_inst)
723};
724
725
726// downcast
727
728class downcast_inst: public unary_inst {
729private:
730 using unary_inst::unary_inst;
731 std::string repr_impl() const { return "downcast"; }
732
733public:
734 static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr);
735 _TRITON_DEFINE_CLONE(downcast_inst)
736 _TRITON_DEFINE_ACCEPT(downcast_inst)
737};
738
739//===----------------------------------------------------------------------===//
740// builtin_inst classes
741//===----------------------------------------------------------------------===//
742
743class builtin_inst: public instruction{
744protected:
745 using instruction::instruction;
746};
747
748class get_program_id_inst: public builtin_inst {
749private:
750 get_program_id_inst(type *ty, unsigned axis, const std::string &name, instruction *next);
751 std::string repr_impl() const { return "get_program_id(" + std::to_string(axis_) + ")"; }
752
753public:
754 static instruction* create(context &ctx, unsigned axis, const std::string &name = "", instruction *next = nullptr);
755 unsigned get_axis() const { return axis_; }
756 _TRITON_DEFINE_CLONE(get_program_id_inst)
757 _TRITON_DEFINE_ACCEPT(get_program_id_inst)
758
759private:
760 unsigned axis_;
761};
762
763class get_num_programs_inst: public builtin_inst {
764private:
765 get_num_programs_inst(type *ty, unsigned axis, const std::string &name, instruction *next);
766 std::string repr_impl() const { return "get_num_programs(" + std::to_string(axis_) + ")"; }
767
768public:
769 static instruction* create(context &ctx, unsigned axis, const std::string &name = "", instruction *next = nullptr);
770 unsigned get_axis() const { return axis_; }
771 _TRITON_DEFINE_CLONE(get_num_programs_inst)
772 _TRITON_DEFINE_ACCEPT(get_num_programs_inst)
773
774private:
775 unsigned axis_;
776};
777
778
779class atomic_inst: public io_inst {
780public:
781 using io_inst::io_inst;
782 atomic_inst(type *ty, value_id_t id, unsigned num_ops, const std::string &name, instruction *next):
783 io_inst(ty, id, num_ops, NORMAL, name, next) {}
784};
785
786class atomic_rmw_inst: public atomic_inst {
787private:
788 atomic_rmw_inst(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
789 std::string repr_impl() const { return "atomic_rmw"; }
790 _TRITON_DEFINE_CLONE(atomic_rmw_inst)
791 _TRITON_DEFINE_ACCEPT(atomic_rmw_inst)
792
793public:
794 static instruction* create(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name = "", instruction *next = nullptr);
795 atomic_rmw_op_t get_op() { return op_; }
796
797private:
798 atomic_rmw_op_t op_;
799};
800
801class atomic_cas_inst: public atomic_inst {
802private:
803 atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next);
804 std::string repr_impl() const { return "atomic_cas"; }
805 _TRITON_DEFINE_CLONE(atomic_cas_inst)
806 _TRITON_DEFINE_ACCEPT(atomic_cas_inst)
807
808public:
809 static instruction* create(value *ptr, value *cmp, value *val, const std::string &name = "", instruction *next = nullptr);
810};
811
812class umulhi_inst: public builtin_inst {
813private:
814 umulhi_inst(value *lhs, value *rhs, const std::string &name = "", instruction *next = nullptr);
815 std::string repr_impl() const { return "umulhi"; }
816 _TRITON_DEFINE_CLONE(umulhi_inst)
817 _TRITON_DEFINE_ACCEPT(umulhi_inst)
818
819public:
820 static instruction* create(value *lhs, value *rhs, const std::string &name = "", instruction *next = nullptr);
821};
822
823class exp_inst: public builtin_inst {
824private:
825 exp_inst(value *val, const std::string &name = "", instruction *next = nullptr);
826 std::string repr_impl() const { return "exp"; }
827 _TRITON_DEFINE_CLONE(exp_inst)
828 _TRITON_DEFINE_ACCEPT(exp_inst)
829
830public:
831 static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr);
832};
833
834class cos_inst: public builtin_inst {
835private:
836 cos_inst(value *val, const std::string &name = "", instruction *next = nullptr);
837 std::string repr_impl() const { return "cos"; }
838 _TRITON_DEFINE_CLONE(cos_inst)
839 _TRITON_DEFINE_ACCEPT(cos_inst)
840
841public:
842 static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr);
843};
844
845class sin_inst: public builtin_inst {
846private:
847 sin_inst(value *val, const std::string &name = "", instruction *next = nullptr);
848 std::string repr_impl() const { return "sin"; }
849 _TRITON_DEFINE_CLONE(sin_inst)
850 _TRITON_DEFINE_ACCEPT(sin_inst)
851
852public:
853 static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr);
854};
855
856class log_inst: public builtin_inst {
857private:
858 log_inst(value *val, const std::string &name = "", instruction *next = nullptr);
859 std::string repr_impl() const { return "log"; }
860 _TRITON_DEFINE_CLONE(log_inst)
861 _TRITON_DEFINE_ACCEPT(log_inst)
862
863public:
864 static instruction* create(value *val, const std::string &name = "", instruction *next = nullptr);
865};
866
867
868class dot_inst: public builtin_inst {
869public:
870 enum TransT { NoTrans, Trans };
871 enum DataType {
872 FP8, FP16, BF16, TF32, FP32,
873 INT1, INT4, INT8, INT32,
874 UNKNOWN,
875 };
876
877private:
878 dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, bool allow_tf32, const std::string &name, instruction *next);
879 std::string repr_impl() const { return "dot"; }
880
881public:
882 bool is_prefetched() const { return is_prefetched_; }
883 void set_prefetched(bool is_prefetched) { is_prefetched_ = is_prefetched; }
884 bool allow_tf32() const { return allow_tf32_; }
885 bool is_trans_a() const { return AT_ == Trans; }
886 bool is_trans_b() const { return BT_ == Trans; }
887
888public:
889 static instruction *create(value *A, value *B, value *C, bool AT, bool BT, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
890 static instruction* create_nn(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
891 static instruction* create_nt(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
892 static instruction* create_tn(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
893 static instruction* create_tt(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "", instruction *next = nullptr);
894 _TRITON_DEFINE_CLONE(dot_inst)
895 _TRITON_DEFINE_ACCEPT(dot_inst)
896
897private:
898 bool is_prefetched_ = false;
899 bool allow_tf32_ = false;
900 DataType C_type_ = DataType::FP32;
901 DataType A_type_ = DataType::FP16;
902 DataType B_type_ = DataType::FP16;
903 TransT AT_;
904 TransT BT_;
905};
906
907//class outer_inst: public builtin_inst {
908//private:
909// outer_inst(value *A, value *B, value *C, const std::string &name, instruction *next);
910//public:
911// static instruction* create(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr);
912//};
913
914class trans_inst: public builtin_inst {
915public:
916 ir::type* get_res_ty(ir::type* in, std::vector<int> perm);
917 std::vector<int> init_perm(ir::type* ty, const std::vector<int>& perm);
918
919private:
920 trans_inst(value *arg, const std::vector<int>& perm, const std::string& name, instruction* next);
921 std::string repr_impl() const { return "trans"; }
922
923public:
924 static instruction* create(value *arg, const std::vector<int> &perm = {}, const std::string &name = "", instruction *next = nullptr);
925 const std::vector<int> get_perm() const;
926 _TRITON_DEFINE_CLONE(trans_inst)
927 _TRITON_DEFINE_ACCEPT(trans_inst)
928
929private:
930 std::vector<int> perm_;
931};
932
933class sqrt_inst: public builtin_inst {
934private:
935 sqrt_inst(value *arg, const std::string& name, instruction* next);
936 std::string repr_impl() const { return "sqrt"; }
937public:
938 static instruction* create(value *arg, const std::string &name = "", instruction *next = nullptr);
939 _TRITON_DEFINE_CLONE(sqrt_inst)
940 _TRITON_DEFINE_ACCEPT(sqrt_inst)
941};
942
943class reduce_inst: public builtin_inst {
944public:
945 enum op_t{
946 ADD, SUB, MAX, MIN, UMAX, UMIN,
947 ARGMAX, ARGMIN, ARGUMAX, ARGUMIN,
948 FADD, FSUB, FMAX, FMIN,
949 ARGFMAX, ARGFMIN,
950 XOR
951 };
952
953private:
954 static type* get_res_type(value *arg, unsigned axis);
955 static std::string to_str(op_t op);
956
957private:
958 reduce_inst(value* arg, op_t op, unsigned axis, const std::string& name, instruction* next);
959 std::string repr_impl() const { return "reduce"; }
960 _TRITON_DEFINE_CLONE(reduce_inst)
961 _TRITON_DEFINE_ACCEPT(reduce_inst)
962
963public:
964 static instruction* create(value *arg, op_t op, unsigned axis, const std::string &name = "", instruction *next = nullptr);
965 unsigned get_axis() const { return axis_; }
966 op_t get_op() const { return op_; }
967 bool with_index() const {
968 return with_index_ops_.find(op_) != with_index_ops_.end();
969 }
970
971private:
972 const static inline std::set<op_t> with_index_ops_ = {
973 op_t::ARGMAX, op_t::ARGMIN, op_t::ARGUMAX,
974 op_t::ARGUMIN, op_t::ARGFMAX, op_t::ARGFMIN};
975 unsigned axis_;
976 op_t op_;
977};
978
979
980class select_inst: public builtin_inst {
981private:
982 select_inst(value *pred, value *if_value, value *else_value, const std::string& name, instruction* next);
983 std::string repr_impl() const { return "select"; }
984 _TRITON_DEFINE_CLONE(select_inst)
985 _TRITON_DEFINE_ACCEPT(select_inst)
986
987public:
988 static instruction* create(value *pred, value *if_value, value *else_value, const std::string &name = "", instruction *next = nullptr);
989 value* get_pred_op() { return get_operand(0); }
990 value* get_if_value_op() { return get_operand(1); }
991 value* get_else_value_op() { return get_operand(2); }
992};
993
994//===----------------------------------------------------------------------===//
995// intrinsics classes
996//===----------------------------------------------------------------------===//
997
998
999class copy_to_shared_inst: public unary_inst{
1000private:
1001 using unary_inst::unary_inst;
1002 std::string repr_impl() const { return "copy_to_shared"; }
1003
1004public:
1005 static copy_to_shared_inst* create(value *arg, const std::string &name = "",
1006 instruction *next = nullptr);
1007 _TRITON_DEFINE_CLONE(copy_to_shared_inst)
1008 _TRITON_DEFINE_ACCEPT(copy_to_shared_inst)
1009};
1010
1011class copy_from_shared_inst: public unary_inst{
1012private:
1013 using unary_inst::unary_inst;
1014 std::string repr_impl() const { return "copy_from_shared"; }
1015
1016public:
1017 static copy_from_shared_inst* create(value *arg, const std::string &name = "",
1018 instruction *next = nullptr);
1019 _TRITON_DEFINE_CLONE(copy_from_shared_inst)
1020 _TRITON_DEFINE_ACCEPT(copy_from_shared_inst)
1021};
1022
1023class cvt_layout_inst: public unary_inst {
1024private:
1025 using unary_inst::unary_inst;
1026 std::string repr_impl() const { return "cvt_layout_inst"; }
1027
1028public:
1029 static cvt_layout_inst* create(value *arg, const std::string &name = "", instruction *next = nullptr);
1030 _TRITON_DEFINE_CLONE(cvt_layout_inst)
1031 _TRITON_DEFINE_ACCEPT(cvt_layout_inst)
1032};
1033
1034class barrier_inst: public instruction{
1035private:
1036 barrier_inst(context &ctx, const std::string &name, instruction *next);
1037 std::string repr_impl() const { return "barrier"; }
1038 _TRITON_DEFINE_CLONE(barrier_inst)
1039 _TRITON_DEFINE_ACCEPT(barrier_inst)
1040
1041public:
1042 static barrier_inst* create(context &ctx, const std::string &name = "",
1043 instruction *next = nullptr);
1044};
1045
1046class async_wait_inst: public instruction{
1047private:
1048 async_wait_inst(context &ctx, int N, const std::string &name, instruction *next);
1049 std::string repr_impl() const { return "async_wait_group " + std::to_string(N_) ; }
1050 _TRITON_DEFINE_CLONE(async_wait_inst)
1051 _TRITON_DEFINE_ACCEPT(async_wait_inst)
1052
1053public:
1054 static async_wait_inst* create(context &ctx, int N,
1055 const std::string &name = "", instruction *next = nullptr);
1056 int get_N() { return N_; }
1057 void set_N(int n) { N_ = n; }
1058
1059private:
1060 int N_;
1061};
1062
1063class prefetch_s_inst : public instruction {
1064 std::string repr_impl() const { return "prefetch_s"; }
1065 _TRITON_DEFINE_CLONE(prefetch_s_inst)
1066 _TRITON_DEFINE_ACCEPT(prefetch_s_inst)
1067
1068 /// inc_: 0->first, 1->latch
1069 int inc_ = 0;
1070public:
1071 prefetch_s_inst(context &ctx, value *arg, int inc, const std::string &name, instruction *next)
1072 : instruction(type::get_void_ty(ctx), INST_PREFETCH_S, 1, name, next), inc_(inc) {
1073 set_operand(0, arg);
1074 }
1075 int get_inc() const { return inc_; }
1076 static prefetch_s_inst *create(context &ctx, value *arg, int inc, const std::string &name = "",
1077 instruction *next=nullptr);
1078};
1079
1080/* constant range */
1081class make_range: public instruction{
1082 make_range(type *ty, constant_int* first, constant_int* last);
1083 std::string repr_impl() const { return "make_range[" + first_->repr() + " : " + last_->repr() + "]"; }
1084 _TRITON_DEFINE_CLONE(make_range)
1085 _TRITON_DEFINE_ACCEPT(make_range)
1086
1087public:
1088 static make_range *create(constant_int *first, constant_int *last);
1089 const constant_int* get_first() const;
1090 const constant_int* get_last() const;
1091
1092private:
1093 constant_int* first_;
1094 constant_int* last_;
1095};
1096
1097/* timing utilities */
1098class clock_inst: public instruction{
1099 clock_inst(context &ctx, const std::string &name, instruction *next);
1100 std::string repr_impl() const { return "clock"; }
1101 _TRITON_DEFINE_CLONE(clock_inst)
1102 _TRITON_DEFINE_ACCEPT(clock_inst)
1103
1104public:
1105 static clock_inst* create(context &ctx, const std::string &name = "", instruction *next = nullptr);
1106};
1107
1108class globaltimer_inst: public instruction{
1109 globaltimer_inst(context &ctx, const std::string &name, instruction *next);
1110 std::string repr_impl() const { return "globaltimer"; }
1111 _TRITON_DEFINE_CLONE(globaltimer_inst)
1112 _TRITON_DEFINE_ACCEPT(globaltimer_inst)
1113
1114public:
1115 static globaltimer_inst* create(context &ctx, const std::string &name = "", instruction *next = nullptr);
1116};
1117
1118class extern_elementwise_inst : public instruction {
1119 extern_elementwise_inst(context &ctx, const std::vector<value *> &args,
1120 type *dst_ty, const std::string &lib_name,
1121 const std::string &extern_lib_path,
1122 const std::string &symbol_name,
1123 const std::string &name, instruction *next);
1124 std::string repr_impl() const { return "extern_elementwise"; }
1125 _TRITON_DEFINE_CLONE(extern_elementwise_inst)
1126 _TRITON_DEFINE_ACCEPT(extern_elementwise_inst)
1127
1128 public:
1129 static extern_elementwise_inst *create(
1130 context &ctx, const std::vector<value *> &args, type *dst_ty,
1131 const std::string &lib_name = "", const std::string &lib_path = "",
1132 const std::string &symbol_name = "", const std::string &name = "",
1133 instruction *next = nullptr);
1134
1135 const std::string &get_lib_name() const { return lib_name_; }
1136 const std::string &get_lib_path() const { return lib_path_; }
1137 const std::string &get_symbol_name() const { return symbol_name_; }
1138
1139 private:
1140 std::string lib_name_;
1141 std::string lib_path_;
1142 std::string symbol_name_;
1143};
1144}
1145}
1146
1147#endif
1148