1 | #pragma once |
2 | |
3 | #ifndef _TRITON_IR_INSTRUCTIONS_H_ |
4 | #define _TRITON_IR_INSTRUCTIONS_H_ |
5 | |
#include <cstddef>
#include <cstdint>
#include <map>
#include <string>
#include <utility>
#include <vector>

#include "triton/ir/enums.h"
#include "triton/ir/constant.h"
#include "triton/ir/value.h"
#include "triton/ir/type.h"
#include "triton/ir/metadata.h"
#include "triton/ir/visitor.h"
14 | |
// Defines the polymorphic copy hook for the concrete instruction class `name`:
// it implements the pure virtual `clone_impl()` declared by ir::instruction and
// returns a heap-allocated copy built with `name`'s copy constructor.
// (The method name was missing in this copy of the header; `clone_impl` is the
// name ir::instruction declares and calls from clone().)
#define _TRITON_DEFINE_CLONE(name) \
  ir::instruction* clone_impl() const { return new name(*this); }
17 | |
// Defines the visitor double-dispatch hook for the concrete class `name`:
// accept(v) forwards `this` to the visitor member `visit_<name>`.
// NOTE(review): the method name was missing in this copy of the header;
// `accept` is assumed to match the virtual hook declared by the value
// hierarchy (value.h) — confirm.
#define _TRITON_DEFINE_ACCEPT(name) \
  void accept(visitor* v) { v->visit_ ## name (this); }
20 | |
21 | namespace triton{ |
22 | namespace ir{ |
23 | |
// Forward declarations of IR entities that the declarations below can name
// without needing their full definitions.
class basic_block;
class constant;
class constant_int;
class context;
class make_range;
class visitor;

//===----------------------------------------------------------------------===//
// instruction classes
//===----------------------------------------------------------------------===//

class result_reference;
36 | |
37 | |
38 | class instruction: public user{ |
39 | public: |
40 | virtual std::string repr_impl() const = 0; |
41 | |
42 | private: |
43 | virtual ir::instruction* clone_impl() const = 0; |
44 | |
45 | protected: |
46 | // constructors |
47 | instruction(type *ty, value_id_t ity, unsigned num_ops, |
48 | const std::string &name = "" , instruction *next = nullptr); |
49 | |
50 | public: |
51 | // parent |
52 | void set_parent(basic_block *block) { parent_ = block; } |
53 | const basic_block *get_parent() const { return parent_; } |
54 | basic_block *get_parent() { return parent_; } |
55 | void erase_from_parent(); |
56 | // helpers |
57 | bool has_tile_result_or_op(); |
58 | // repr |
59 | std::string repr() const { return repr_impl(); } |
60 | // metadata |
61 | void set_metadata(ir::metadata::kind_t kind, |
62 | std::vector<unsigned> value) { metadatas_[kind] = value;} |
63 | std::vector<unsigned> get_metadata(ir::metadata::kind_t kind) { return metadatas_[kind];} |
64 | // cloning |
65 | ir::instruction* clone() { |
66 | ir::instruction* res = clone_impl(); |
67 | // for(auto it = op_begin(); it != op_end(); it++) |
68 | // (*it)->add_use(res); |
69 | res->parent_ = nullptr; |
70 | res->users_.clear(); |
71 | return res; |
72 | } |
73 | // instruction id |
74 | value_id_t get_id() const { return id_; } |
75 | |
76 | void print(std::ostream &os); |
77 | |
78 | private: |
79 | basic_block *parent_; |
80 | std::map<ir::metadata::kind_t, std::vector<unsigned>> metadatas_; |
81 | value_id_t id_; |
82 | }; |
83 | |
84 | //===----------------------------------------------------------------------===// |
85 | // call_inst classes |
86 | //===----------------------------------------------------------------------===// |
87 | |
88 | class call_inst: public instruction { |
89 | private: |
90 | std::string repr_impl() const; |
91 | call_inst(ir::function* fn, const std::vector<ir::value*>& values, const std::string& name, instruction* next); |
92 | |
93 | public: |
94 | static call_inst* create(ir::function* fn, const std::vector<ir::value*>& values, const std::string &name = "" , instruction *next = nullptr); |
95 | ir::function* get_fn() { return fn_; } |
96 | |
97 | _TRITON_DEFINE_CLONE(call_inst) |
98 | _TRITON_DEFINE_ACCEPT(call_inst) |
99 | |
100 | private: |
101 | ir::function* fn_; |
102 | }; |
103 | |
104 | class launch_inst: public instruction { |
105 | private: |
106 | std::string repr_impl() const { return "launch" ; } |
107 | launch_inst(ir::function* fn, const std::vector<ir::value*>& values, const std::vector<ir::value*>& grid, ir::value* num_warps, |
108 | const std::string &name = "" , instruction *next = nullptr); |
109 | |
110 | public: |
111 | static launch_inst* create(ir::function* fn, const std::vector<ir::value*>& values, const std::vector<ir::value*>& grid, ir::value* num_warps, |
112 | const std::string& name = "" , instruction* next = nullptr); |
113 | |
114 | ir::function* get_fn(); |
115 | std::vector<ir::value*> get_values(); |
116 | std::vector<ir::value*> get_grid(); |
117 | ir::value* get_num_warps(); |
118 | |
119 | |
120 | _TRITON_DEFINE_CLONE(launch_inst) |
121 | _TRITON_DEFINE_ACCEPT(launch_inst) |
122 | |
123 | private: |
124 | unsigned val_begin; |
125 | unsigned val_end; |
126 | unsigned grid_begin; |
127 | unsigned grid_end; |
128 | }; |
129 | |
130 | //===----------------------------------------------------------------------===// |
131 | // phi_node classes |
132 | //===----------------------------------------------------------------------===// |
133 | |
134 | class phi_node: public instruction { |
135 | private: |
136 | phi_node(type *ty, unsigned num_reserved, const std::string &name, instruction *next); |
137 | std::string repr_impl() const { return "phi" ; } |
138 | |
139 | public: |
140 | void set_incoming_value(unsigned i, value *v); |
141 | void set_incoming_block(unsigned i, basic_block *block); |
142 | value *get_value_for_block(basic_block *block); |
143 | value *get_incoming_value(unsigned i) { return get_operand(i); } |
144 | basic_block *get_incoming_block(unsigned i) { return blocks_[i]; } |
145 | unsigned get_num_incoming() { return get_num_operands(); } |
146 | void add_incoming(value *v, basic_block *block); |
147 | |
148 | // Type |
149 | void set_type(type *ty) { ty_ = ty; } |
150 | |
151 | // Factory methods |
152 | static phi_node* create(type *ty, unsigned num_reserved, const std::string &name = "" , instruction *next = nullptr); |
153 | |
154 | _TRITON_DEFINE_CLONE(phi_node) |
155 | _TRITON_DEFINE_ACCEPT(phi_node) |
156 | |
157 | private: |
158 | unsigned num_reserved_; |
159 | std::vector<basic_block*> blocks_; |
160 | }; |
161 | |
162 | //===----------------------------------------------------------------------===// |
163 | // binary_operator classes |
164 | //===----------------------------------------------------------------------===// |
165 | |
166 | class binary_operator: public instruction { |
167 | public: |
168 | typedef binary_op_t op_t; |
169 | |
170 | private: |
171 | std::string repr_impl() const; |
172 | |
173 | protected: |
174 | // Constructors |
175 | binary_operator(binary_op_t op, value *lhs, value *rhs, type *ty, const std::string &name, instruction *next); |
176 | |
177 | public: |
178 | // Get operand |
179 | binary_op_t get_op() const { return op_; } |
180 | |
181 | // Bool |
182 | bool is_terminator() const; |
183 | bool is_binary_op() const; |
184 | bool is_int_div_rem() const; |
185 | bool is_shift() const; |
186 | bool is_cast() const; |
187 | bool is_int_mult() const; |
188 | bool is_int_add_sub() const; |
189 | bool is_int_div() const; |
190 | bool is_int_rem() const; |
191 | bool is_shl() const; |
192 | bool is_shr() const; |
193 | |
194 | // Approx |
195 | void set_fdiv_ieee_rounding(bool rnd) { fdiv_ieee_rnd_ = rnd; } |
196 | bool get_fdiv_ieee_rounding() { return fdiv_ieee_rnd_; } |
197 | |
198 | // Wraps |
199 | void set_has_no_unsigned_wrap(bool b = true) { has_no_unsigned_wrap_ = b; } |
200 | void set_has_no_signed_wrap(bool b = true) { has_no_signed_wrap_ = b; } |
201 | |
202 | // Factory methods |
203 | static binary_operator *create(binary_op_t op, value *lhs, value *rhs, |
204 | const std::string &name = "" , instruction *next = nullptr); |
205 | // static binary_operator *create_fneg(value *arg, const std::string &name = "", instruction *next = nullptr); |
206 | // static binary_operator *create_neg(value *arg, const std::string &name = "", instruction *next = nullptr); |
207 | // static binary_operator *create_not(value *arg, const std::string &name = "", instruction *next = nullptr); |
208 | |
209 | _TRITON_DEFINE_CLONE(binary_operator) |
210 | _TRITON_DEFINE_ACCEPT(binary_operator) |
211 | |
212 | public: |
213 | binary_op_t op_; |
214 | bool has_no_unsigned_wrap_; |
215 | bool has_no_signed_wrap_; |
216 | |
217 | bool fdiv_ieee_rnd_; |
218 | }; |
219 | |
220 | |
221 | //===----------------------------------------------------------------------===// |
222 | // cmp_inst classes |
223 | //===----------------------------------------------------------------------===// |
224 | |
225 | class cmp_inst: public instruction{ |
226 | public: |
227 | typedef cmp_pred_t pred_t; |
228 | |
229 | private: |
230 | std::string repr_impl() const; |
231 | |
232 | protected: |
233 | cmp_inst(type *ty, value_id_t id, cmp_pred_t pred, |
234 | value *lhs, value *rhs, const std::string &name, instruction *next); |
235 | static bool is_fp_predicate(cmp_pred_t pred); |
236 | static bool is_int_predicate(cmp_pred_t pred); |
237 | static type* make_cmp_result_type(type *ty); |
238 | |
239 | public: |
240 | cmp_pred_t get_pred() const { return pred_; } |
241 | |
242 | private: |
243 | cmp_pred_t pred_; |
244 | }; |
245 | |
246 | class icmp_inst: public cmp_inst { |
247 | icmp_inst(type *ty, cmp_pred_t pred, |
248 | value *lhs, value *rhs, const std::string &name, instruction *next); |
249 | |
250 | public: |
251 | static icmp_inst* create(cmp_pred_t pred, value *lhs, value *rhs, |
252 | const std::string &name = "" , instruction *next = nullptr); |
253 | _TRITON_DEFINE_CLONE(icmp_inst) |
254 | _TRITON_DEFINE_ACCEPT(icmp_inst) |
255 | }; |
256 | |
257 | class fcmp_inst: public cmp_inst { |
258 | fcmp_inst(type *ty, cmp_pred_t pred, |
259 | value *lhs, value *rhs, const std::string &name, instruction *next); |
260 | |
261 | public: |
262 | static fcmp_inst* create(cmp_pred_t pred, value *lhs, value *rhs, |
263 | const std::string &name = "" , instruction *next = nullptr); |
264 | _TRITON_DEFINE_CLONE(fcmp_inst) |
265 | _TRITON_DEFINE_ACCEPT(fcmp_inst) |
266 | }; |
267 | |
268 | //===----------------------------------------------------------------------===// |
269 | // unary_inst classes |
270 | //===----------------------------------------------------------------------===// |
271 | |
272 | class unary_inst: public instruction { |
273 | protected: |
274 | unary_inst(type *ty, value_id_t id, value *v, const std::string &name, instruction *next); |
275 | }; |
276 | |
277 | //===----------------------------------------------------------------------===// |
278 | // dequantize_inst classes |
279 | //===----------------------------------------------------------------------===// |
280 | |
281 | class dequantize_inst: public instruction{ |
282 | private: |
283 | std::string repr_impl() const override { return "dequantize" ; } |
284 | |
285 | protected: |
286 | dequantize_inst(type *ty, value *v, value *scale, value *shift, const std::string &name, instruction *next); |
287 | |
288 | public: |
289 | static dequantize_inst *create(value *arg, value *scale, value *shift, type *ty, |
290 | const std::string &name = "" , instruction *next = nullptr); |
291 | |
292 | _TRITON_DEFINE_CLONE(dequantize_inst) |
293 | _TRITON_DEFINE_ACCEPT(dequantize_inst) |
294 | }; |
295 | |
296 | //===----------------------------------------------------------------------===// |
297 | // cast_inst classes |
298 | //===----------------------------------------------------------------------===// |
299 | |
300 | class cast_inst: public unary_inst{ |
301 | private: |
302 | std::string repr_impl() const; |
303 | |
304 | protected: |
305 | cast_inst(type *ty, value_id_t id, value *v, const std::string &name, instruction *next, cast_op_t op) |
306 | : unary_inst(ty, id, v, name, next), op_(op) { } |
307 | |
308 | private: |
309 | static bool is_valid(cast_op_t op, value *arg, type *ty); |
310 | |
311 | public: |
312 | // accessors |
313 | cast_op_t get_op() const { return op_; } |
314 | |
315 | // factory methods |
316 | static cast_inst *create(cast_op_t op, value *arg, type *ty, |
317 | const std::string &name = "" , instruction *next = nullptr); |
318 | static cast_inst *create_integer_cast(value *arg, type *ty, bool is_signed, |
319 | const std::string &name = "" , instruction *next = nullptr); |
320 | |
321 | _TRITON_DEFINE_ACCEPT(cast_inst) |
322 | |
323 | private: |
324 | cast_op_t op_; |
325 | }; |
326 | |
// Declares a concrete cast instruction class `name` whose private constructor
// forwards to cast_inst with the fixed value id `id` and cast opcode `op`;
// cast_inst is a friend so it can construct these classes.
// The string parameter is deliberately NOT spelled `name`: inside the macro
// that token would be replaced by the class-name argument.
#define TRITON_IR_DECLARE_CAST_INST_SIMPL(name, id, op) \
class name : public cast_inst { \
  _TRITON_DEFINE_CLONE(name) \
  friend class cast_inst; \
  name(type *ty, value *v, const std::string &inst_name, instruction *next) \
    : cast_inst(ty, id, v, inst_name, next, op){ } \
};
334 | |
335 | TRITON_IR_DECLARE_CAST_INST_SIMPL(trunc_inst, INST_CAST_TRUNC, cast_op_t::Trunc) |
336 | TRITON_IR_DECLARE_CAST_INST_SIMPL(z_ext_inst, INST_CAST_ZEXT, cast_op_t::ZExt) |
337 | TRITON_IR_DECLARE_CAST_INST_SIMPL(s_ext_inst, INST_CAST_SEXT, cast_op_t::SExt) |
338 | TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_trunc_inst, INST_CAST_FP_TRUNC, cast_op_t::FPTrunc) |
339 | TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_ext_inst, INST_CAST_FP_EXT, cast_op_t::FPExt) |
340 | TRITON_IR_DECLARE_CAST_INST_SIMPL(ui_to_fp_inst, INST_CAST_UI_TO_FP, cast_op_t::UIToFP) |
341 | TRITON_IR_DECLARE_CAST_INST_SIMPL(si_to_fp_inst, INST_CAST_SI_TO_FP, cast_op_t::SIToFP) |
342 | TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_to_ui_inst, INST_CAST_FP_TO_UI, cast_op_t::FPToUI) |
343 | TRITON_IR_DECLARE_CAST_INST_SIMPL(fp_to_si_inst, INST_CAST_FP_TO_SI, cast_op_t::FPToSI) |
344 | TRITON_IR_DECLARE_CAST_INST_SIMPL(ptr_to_int_inst, INST_CAST_PTR_TO_INT, cast_op_t::PtrToInt) |
345 | TRITON_IR_DECLARE_CAST_INST_SIMPL(int_to_ptr_inst, INST_CAST_INT_TO_PTR, cast_op_t::IntToPtr) |
346 | TRITON_IR_DECLARE_CAST_INST_SIMPL(bit_cast_inst, INST_CAST_BIT_CAST, cast_op_t::BitCast) |
347 | TRITON_IR_DECLARE_CAST_INST_SIMPL(addr_space_cast_inst, INST_CAST_ADDR_SPACE_CAST, cast_op_t::AddrSpaceCast) |
348 | |
349 | //===----------------------------------------------------------------------===// |
350 | // terminator_inst classes |
351 | //===----------------------------------------------------------------------===// |
352 | |
353 | class terminator_inst: public instruction{ |
354 | using instruction::instruction; |
355 | }; |
356 | |
357 | // return instruction |
358 | class return_inst: public terminator_inst { |
359 | private: |
360 | std::string repr_impl() const { return "ret" ; } |
361 | return_inst(context &ctx, value *ret_val, instruction *next); |
362 | |
363 | public: |
364 | // accessors |
365 | value *get_return_value() |
366 | { return get_num_operands() ? get_operand(0) : nullptr; } |
367 | |
368 | unsigned get_num_successors() const { return 0; } |
369 | |
370 | // factory methods |
371 | static return_inst* create(context &ctx, value *ret_val = nullptr, instruction *next = nullptr); |
372 | |
373 | _TRITON_DEFINE_CLONE(return_inst) |
374 | _TRITON_DEFINE_ACCEPT(return_inst) |
375 | }; |
376 | |
377 | // base branch instruction |
378 | class branch_inst: public terminator_inst{ |
379 | private: |
380 | std::string repr_impl() const { return "br" ; } |
381 | |
382 | protected: |
383 | using terminator_inst::terminator_inst; |
384 | |
385 | public: |
386 | static branch_inst* create(basic_block *dest, |
387 | instruction *next = nullptr); |
388 | static branch_inst* create(value *cond, basic_block *if_dest, basic_block *else_dest, |
389 | instruction *next = nullptr); |
390 | }; |
391 | |
392 | // conditional branch |
393 | class cond_branch_inst: public branch_inst { |
394 | private: |
395 | friend class branch_inst; |
396 | cond_branch_inst(basic_block *if_dst, basic_block *else_dst, value *cond, instruction *next); |
397 | |
398 | public: |
399 | basic_block *get_true_dest() { return (basic_block*)get_operand(0); } |
400 | basic_block *get_false_dest() { return (basic_block*)get_operand(1); } |
401 | value *get_cond() { return get_operand(2); } |
402 | _TRITON_DEFINE_CLONE(cond_branch_inst) |
403 | _TRITON_DEFINE_ACCEPT(cond_branch_inst) |
404 | }; |
405 | |
406 | // unconditional branch |
407 | class uncond_branch_inst: public branch_inst { |
408 | private: |
409 | friend class branch_inst; |
410 | uncond_branch_inst(basic_block *dst, instruction *next); |
411 | |
412 | public: |
413 | basic_block *get_dest() { return (basic_block*)get_operand(0); } |
414 | _TRITON_DEFINE_CLONE(uncond_branch_inst) |
415 | _TRITON_DEFINE_ACCEPT(uncond_branch_inst) |
416 | }; |
417 | |
418 | |
419 | //===----------------------------------------------------------------------===// |
420 | // getelementptr_inst classes |
421 | //===----------------------------------------------------------------------===// |
422 | |
423 | class getelementptr_inst: public instruction { |
424 | private: |
425 | std::string repr_impl() const { return "getelementptr" ; } |
426 | getelementptr_inst(type *pointee_ty, value *ptr, const std::vector<value*> &idx, const std::string &name, instruction *next); |
427 | |
428 | private: |
429 | static type *get_return_type(type *ty, value *ptr, const std::vector<value*> &idx); |
430 | static type *get_indexed_type_impl(type *ty, const std::vector<value *> &idx); |
431 | static type *get_indexed_type(type *ty, const std::vector<value*> &idx); |
432 | |
433 | public: |
434 | // accessors |
435 | type *get_source_elt_ty() { return source_elt_ty; } |
436 | op_iterator idx_begin() { return op_begin() + 1; } |
437 | op_iterator idx_end() { return op_end(); } |
438 | value *get_pointer_operand() { return *op_begin(); } |
439 | |
440 | // factory methods |
441 | static getelementptr_inst* create(value *ptr, const std::vector<value*> &idx, |
442 | const std::string &name = "" , instruction *next = nullptr); |
443 | _TRITON_DEFINE_CLONE(getelementptr_inst) |
444 | _TRITON_DEFINE_ACCEPT(getelementptr_inst) |
445 | |
446 | private: |
447 | type *source_elt_ty; |
448 | type *res_elt_ty; |
449 | }; |
450 | |
451 | //===----------------------------------------------------------------------===// |
452 | // load_inst/store_inst classes |
453 | //===----------------------------------------------------------------------===// |
454 | |
455 | class io_inst: public instruction { |
456 | public: |
457 | |
458 | enum EVICTION_POLICY : uint32_t { |
459 | NORMAL=0, |
460 | EVICT_FIRST, |
461 | EVICT_LAST, |
462 | }; |
463 | |
464 | protected: |
465 | io_inst(type *ty, value_id_t id, unsigned num_ops, EVICTION_POLICY eviction, |
466 | const std::string &name = "" , instruction *next = nullptr); |
467 | |
468 | std::string get_eviction_policy_repr() const { |
469 | if (eviction_ == EVICT_FIRST) return ".L1::evict_first" ; |
470 | if (eviction_ == EVICT_LAST) return ".L2::evict_last" ; |
471 | return "" ; |
472 | } |
473 | |
474 | public: |
475 | // accessors |
476 | value *get_pointer_operand() { return get_operand(0); } |
477 | EVICTION_POLICY get_eviction_policy() const { return eviction_; } |
478 | |
479 | protected: |
480 | EVICTION_POLICY eviction_; |
481 | }; |
482 | |
483 | // load |
484 | class load_inst: public io_inst { |
485 | public: |
486 | enum CACHE_MODIFIER : uint32_t { |
487 | NONE=0, |
488 | CA, |
489 | CG, |
490 | }; |
491 | |
492 | |
493 | CACHE_MODIFIER get_cache_modifier() const { return cache_; } |
494 | bool get_is_volatile() const { return is_volatile_; } |
495 | |
496 | protected: |
497 | load_inst(value *ptr, value_id_t id, unsigned num_ops, CACHE_MODIFIER cache, EVICTION_POLICY eviction, |
498 | bool is_volatile, |
499 | const std::string &name = "" , instruction *next = nullptr); |
500 | std::string get_cache_modifier_repr() const { |
501 | if (cache_ == CA) return ".ca" ; |
502 | if (cache_ == CG) return ".cg" ; |
503 | return "" ; |
504 | } |
505 | CACHE_MODIFIER cache_; |
506 | |
507 | std::string get_volatile_repr() { |
508 | return is_volatile_ ? ".volatile" : "" ; |
509 | } |
510 | bool is_volatile_; |
511 | |
512 | private: |
513 | static type *get_pointee_type(type *ty); |
514 | }; |
515 | |
516 | // unmasked load |
517 | class unmasked_load_inst: public load_inst { |
518 | private: |
519 | std::string repr_impl() const { return "unmasked_load" + get_cache_modifier_repr(); } |
520 | unmasked_load_inst(value *ptr, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile, const std::string &name, instruction *next); |
521 | |
522 | public: |
523 | static unmasked_load_inst* create(value *ptr, |
524 | CACHE_MODIFIER cache, EVICTION_POLICY eviction, |
525 | bool is_volatile, |
526 | const std::string &name = "" , |
527 | instruction *next = nullptr); |
528 | _TRITON_DEFINE_CLONE(unmasked_load_inst) |
529 | _TRITON_DEFINE_ACCEPT(unmasked_load_inst) |
530 | }; |
531 | |
532 | // masked load |
533 | class masked_load_inst: public load_inst { |
534 | private: |
535 | std::string repr_impl() const { return "masked_load" + get_cache_modifier_repr(); } |
536 | masked_load_inst(value *ptr, value *mask, value *false_value, load_inst::CACHE_MODIFIER cache, load_inst::EVICTION_POLICY eviction, bool is_volatile, |
537 | const std::string &name, instruction *next); |
538 | |
539 | public: |
540 | // accessors |
541 | value *get_mask_operand() { return get_operand(1); } |
542 | value *get_false_value_operand() { return get_operand(2); } |
543 | // factory method |
544 | static masked_load_inst* create(value *ptr, value *mask, value *false_value, |
545 | CACHE_MODIFIER cache, EVICTION_POLICY eviction, |
546 | bool is_volatile, |
547 | const std::string &name = "" , |
548 | instruction *next = nullptr); |
549 | _TRITON_DEFINE_CLONE(masked_load_inst) |
550 | _TRITON_DEFINE_ACCEPT(masked_load_inst) |
551 | }; |
552 | |
553 | // masked load async |
554 | class masked_load_async_inst: public load_inst { |
555 | private: |
556 | std::string repr_impl() const { return "masked_load_async" + get_cache_modifier_repr(); } |
557 | masked_load_async_inst(value *ptr, value *mask, value *false_value, |
558 | CACHE_MODIFIER cache, EVICTION_POLICY eviction, |
559 | const std::string &name, instruction *next); |
560 | |
561 | public: |
562 | // accessors |
563 | value *get_mask_operand() { return get_operand(1); } |
564 | value *get_false_value_operand() { return get_operand(2); } |
565 | // factory method |
566 | static masked_load_async_inst* create(value *ptr, value *mask, value *false_value, |
567 | load_inst::CACHE_MODIFIER cache, |
568 | EVICTION_POLICY eviction, |
569 | const std::string &name = "" , |
570 | instruction *next = nullptr); |
571 | _TRITON_DEFINE_CLONE(masked_load_async_inst) |
572 | _TRITON_DEFINE_ACCEPT(masked_load_async_inst) |
573 | }; |
574 | |
575 | |
576 | |
577 | // store |
578 | class store_inst: public io_inst { |
579 | protected: |
580 | store_inst(value *ptr, value_id_t id, unsigned num_ops, EVICTION_POLICY eviction, |
581 | const std::string &name = "" , instruction *next = nullptr); |
582 | |
583 | public: |
584 | value *get_value_operand() { return get_operand(1); } |
585 | }; |
586 | |
587 | // unmasked_store |
588 | class unmasked_store_inst: public store_inst{ |
589 | private: |
590 | std::string repr_impl() const { return "unmasked_store" ; } |
591 | unmasked_store_inst(value *ptr, value *v, EVICTION_POLICY eviction, const std::string &name, instruction *next); |
592 | |
593 | public: |
594 | // factory method |
595 | static unmasked_store_inst* create(value* ptr, value *v, EVICTION_POLICY eviction, |
596 | const std::string &name = "" , |
597 | instruction *next = nullptr); |
598 | _TRITON_DEFINE_CLONE(unmasked_store_inst) |
599 | _TRITON_DEFINE_ACCEPT(unmasked_store_inst) |
600 | }; |
601 | |
602 | class masked_store_inst: public store_inst{ |
603 | private: |
604 | std::string repr_impl() const { return "masked_store" ; } |
605 | masked_store_inst(value *ptr, value *v, value *mask, EVICTION_POLICY eviction, |
606 | const std::string &name, instruction *next); |
607 | |
608 | public: |
609 | // accessors |
610 | value *get_mask_operand() { return get_operand(2); } |
611 | // factory method |
612 | static masked_store_inst* create(value *ptr, value *v, value *mask, EVICTION_POLICY eviction, |
613 | const std::string &name = "" , |
614 | instruction *next = nullptr); |
615 | _TRITON_DEFINE_CLONE(masked_store_inst) |
616 | _TRITON_DEFINE_ACCEPT(masked_store_inst) |
617 | }; |
618 | |
619 | //===----------------------------------------------------------------------===// |
620 | // struct classes |
621 | //===----------------------------------------------------------------------===// |
622 | |
623 | // insert_value |
624 | |
625 | class insert_value_inst: public instruction { |
626 | private: |
627 | std::string repr_impl() const { return "insertvalue" ; } |
628 | insert_value_inst(value *val, value *elt, size_t idx, const std::string &name, instruction *next); |
629 | |
630 | public: |
631 | static insert_value_inst* create(value *val, value* elt, size_t idx, const std::string &name = "" , instruction *next = nullptr); |
632 | size_t get_idx() { return idx_; } |
633 | _TRITON_DEFINE_CLONE(insert_value_inst) |
634 | _TRITON_DEFINE_ACCEPT(insert_value_inst) |
635 | |
636 | private: |
637 | size_t idx_; |
638 | }; |
639 | |
640 | // extract_value |
641 | |
642 | class : public instruction { |
643 | private: |
644 | std::string () const { return "extractvalue" ; } |
645 | (value *val, size_t idx, const std::string &name, instruction *next); |
646 | |
647 | public: |
648 | static extract_value_inst* (value *val, size_t idx, const std::string &name = "" , instruction *next = nullptr); |
649 | size_t () { return idx_; } |
650 | _TRITON_DEFINE_CLONE(extract_value_inst) |
651 | _TRITON_DEFINE_ACCEPT(extract_value_inst) |
652 | |
653 | private: |
654 | size_t ; |
655 | }; |
656 | |
657 | //===----------------------------------------------------------------------===// |
658 | // retile_inst classes |
659 | //===----------------------------------------------------------------------===// |
660 | |
661 | // cat |
662 | |
663 | class cat_inst: public instruction { |
664 | private: |
665 | std::string repr_impl() const { return "cat" ; } |
666 | cat_inst(value *x, value *y, const std::string &name, instruction *next); |
667 | |
668 | public: |
669 | static instruction* create(value *lhs, value *rhs, |
670 | const std::string &name = "" , |
671 | instruction *next = nullptr); |
672 | _TRITON_DEFINE_CLONE(cat_inst) |
673 | _TRITON_DEFINE_ACCEPT(cat_inst) |
674 | }; |
675 | |
676 | // retile |
677 | |
678 | class retile_inst: public unary_inst { |
679 | protected: |
680 | retile_inst(value *arg, value_id_t id, const type::block_shapes_t &shapes, const std::string &name, instruction *next); |
681 | }; |
682 | |
683 | // reshape |
684 | |
685 | class reshape_inst: public retile_inst { |
686 | private: |
687 | using retile_inst::retile_inst; |
688 | std::string repr_impl() const { return "reshape" ; } |
689 | |
690 | public: |
691 | static instruction* create(value *arg, const type::block_shapes_t &shape_suffix, |
692 | const std::string &name = "" , instruction *next = nullptr); |
693 | _TRITON_DEFINE_CLONE(reshape_inst) |
694 | _TRITON_DEFINE_ACCEPT(reshape_inst) |
695 | }; |
696 | |
697 | // splat |
698 | |
699 | class splat_inst: public retile_inst { |
700 | private: |
701 | using retile_inst::retile_inst; |
702 | std::string repr_impl() const { return "splat" ; } |
703 | |
704 | public: |
705 | static instruction* create(value *arg, const type::block_shapes_t &shape_suffix, |
706 | const std::string &name = "" , instruction *next = nullptr); |
707 | _TRITON_DEFINE_CLONE(splat_inst) |
708 | _TRITON_DEFINE_ACCEPT(splat_inst) |
709 | }; |
710 | |
711 | // broadcast |
712 | |
713 | class broadcast_inst: public retile_inst { |
714 | private: |
715 | using retile_inst::retile_inst; |
716 | std::string repr_impl() const { return "broadcast" ; } |
717 | |
718 | public: |
719 | static instruction* create(value *arg, const type::block_shapes_t &shape_suffix, |
720 | const std::string &name = "" , instruction *next = nullptr); |
721 | _TRITON_DEFINE_CLONE(broadcast_inst) |
722 | _TRITON_DEFINE_ACCEPT(broadcast_inst) |
723 | }; |
724 | |
725 | |
726 | // downcast |
727 | |
728 | class downcast_inst: public unary_inst { |
729 | private: |
730 | using unary_inst::unary_inst; |
731 | std::string repr_impl() const { return "downcast" ; } |
732 | |
733 | public: |
734 | static instruction* create(value *arg, const std::string &name = "" , instruction *next = nullptr); |
735 | _TRITON_DEFINE_CLONE(downcast_inst) |
736 | _TRITON_DEFINE_ACCEPT(downcast_inst) |
737 | }; |
738 | |
739 | //===----------------------------------------------------------------------===// |
740 | // builtin_inst classes |
741 | //===----------------------------------------------------------------------===// |
742 | |
743 | class builtin_inst: public instruction{ |
744 | protected: |
745 | using instruction::instruction; |
746 | }; |
747 | |
748 | class get_program_id_inst: public builtin_inst { |
749 | private: |
750 | get_program_id_inst(type *ty, unsigned axis, const std::string &name, instruction *next); |
751 | std::string repr_impl() const { return "get_program_id(" + std::to_string(axis_) + ")" ; } |
752 | |
753 | public: |
754 | static instruction* create(context &ctx, unsigned axis, const std::string &name = "" , instruction *next = nullptr); |
755 | unsigned get_axis() const { return axis_; } |
756 | _TRITON_DEFINE_CLONE(get_program_id_inst) |
757 | _TRITON_DEFINE_ACCEPT(get_program_id_inst) |
758 | |
759 | private: |
760 | unsigned axis_; |
761 | }; |
762 | |
763 | class get_num_programs_inst: public builtin_inst { |
764 | private: |
765 | get_num_programs_inst(type *ty, unsigned axis, const std::string &name, instruction *next); |
766 | std::string repr_impl() const { return "get_num_programs(" + std::to_string(axis_) + ")" ; } |
767 | |
768 | public: |
769 | static instruction* create(context &ctx, unsigned axis, const std::string &name = "" , instruction *next = nullptr); |
770 | unsigned get_axis() const { return axis_; } |
771 | _TRITON_DEFINE_CLONE(get_num_programs_inst) |
772 | _TRITON_DEFINE_ACCEPT(get_num_programs_inst) |
773 | |
774 | private: |
775 | unsigned axis_; |
776 | }; |
777 | |
778 | |
779 | class atomic_inst: public io_inst { |
780 | public: |
781 | using io_inst::io_inst; |
782 | atomic_inst(type *ty, value_id_t id, unsigned num_ops, const std::string &name, instruction *next): |
783 | io_inst(ty, id, num_ops, NORMAL, name, next) {} |
784 | }; |
785 | |
786 | class atomic_rmw_inst: public atomic_inst { |
787 | private: |
788 | atomic_rmw_inst(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name = "" , instruction *next = nullptr); |
789 | std::string repr_impl() const { return "atomic_rmw" ; } |
790 | _TRITON_DEFINE_CLONE(atomic_rmw_inst) |
791 | _TRITON_DEFINE_ACCEPT(atomic_rmw_inst) |
792 | |
793 | public: |
794 | static instruction* create(atomic_rmw_op_t op, value *ptr, value *val, value *msk, const std::string &name = "" , instruction *next = nullptr); |
795 | atomic_rmw_op_t get_op() { return op_; } |
796 | |
797 | private: |
798 | atomic_rmw_op_t op_; |
799 | }; |
800 | |
801 | class atomic_cas_inst: public atomic_inst { |
802 | private: |
803 | atomic_cas_inst(value *ptr, value *cmp, value *val, const std::string &name, instruction *next); |
804 | std::string repr_impl() const { return "atomic_cas" ; } |
805 | _TRITON_DEFINE_CLONE(atomic_cas_inst) |
806 | _TRITON_DEFINE_ACCEPT(atomic_cas_inst) |
807 | |
808 | public: |
809 | static instruction* create(value *ptr, value *cmp, value *val, const std::string &name = "" , instruction *next = nullptr); |
810 | }; |
811 | |
812 | class umulhi_inst: public builtin_inst { |
813 | private: |
814 | umulhi_inst(value *lhs, value *rhs, const std::string &name = "" , instruction *next = nullptr); |
815 | std::string repr_impl() const { return "umulhi" ; } |
816 | _TRITON_DEFINE_CLONE(umulhi_inst) |
817 | _TRITON_DEFINE_ACCEPT(umulhi_inst) |
818 | |
819 | public: |
820 | static instruction* create(value *lhs, value *rhs, const std::string &name = "" , instruction *next = nullptr); |
821 | }; |
822 | |
823 | class exp_inst: public builtin_inst { |
824 | private: |
825 | exp_inst(value *val, const std::string &name = "" , instruction *next = nullptr); |
826 | std::string repr_impl() const { return "exp" ; } |
827 | _TRITON_DEFINE_CLONE(exp_inst) |
828 | _TRITON_DEFINE_ACCEPT(exp_inst) |
829 | |
830 | public: |
831 | static instruction* create(value *val, const std::string &name = "" , instruction *next = nullptr); |
832 | }; |
833 | |
834 | class cos_inst: public builtin_inst { |
835 | private: |
836 | cos_inst(value *val, const std::string &name = "" , instruction *next = nullptr); |
837 | std::string repr_impl() const { return "cos" ; } |
838 | _TRITON_DEFINE_CLONE(cos_inst) |
839 | _TRITON_DEFINE_ACCEPT(cos_inst) |
840 | |
841 | public: |
842 | static instruction* create(value *val, const std::string &name = "" , instruction *next = nullptr); |
843 | }; |
844 | |
845 | class sin_inst: public builtin_inst { |
846 | private: |
847 | sin_inst(value *val, const std::string &name = "" , instruction *next = nullptr); |
848 | std::string repr_impl() const { return "sin" ; } |
849 | _TRITON_DEFINE_CLONE(sin_inst) |
850 | _TRITON_DEFINE_ACCEPT(sin_inst) |
851 | |
852 | public: |
853 | static instruction* create(value *val, const std::string &name = "" , instruction *next = nullptr); |
854 | }; |
855 | |
856 | class log_inst: public builtin_inst { |
857 | private: |
858 | log_inst(value *val, const std::string &name = "" , instruction *next = nullptr); |
859 | std::string repr_impl() const { return "log" ; } |
860 | _TRITON_DEFINE_CLONE(log_inst) |
861 | _TRITON_DEFINE_ACCEPT(log_inst) |
862 | |
863 | public: |
864 | static instruction* create(value *val, const std::string &name = "" , instruction *next = nullptr); |
865 | }; |
866 | |
867 | |
868 | class dot_inst: public builtin_inst { |
869 | public: |
870 | enum TransT { NoTrans, Trans }; |
871 | enum DataType { |
872 | FP8, FP16, BF16, TF32, FP32, |
873 | INT1, INT4, INT8, INT32, |
874 | UNKNOWN, |
875 | }; |
876 | |
877 | private: |
878 | dot_inst(value *A, value *B, value *C, TransT AT, TransT BT, bool allow_tf32, const std::string &name, instruction *next); |
879 | std::string repr_impl() const { return "dot" ; } |
880 | |
881 | public: |
882 | bool is_prefetched() const { return is_prefetched_; } |
883 | void set_prefetched(bool is_prefetched) { is_prefetched_ = is_prefetched; } |
884 | bool allow_tf32() const { return allow_tf32_; } |
885 | bool is_trans_a() const { return AT_ == Trans; } |
886 | bool is_trans_b() const { return BT_ == Trans; } |
887 | |
888 | public: |
889 | static instruction *create(value *A, value *B, value *C, bool AT, bool BT, bool allow_tf32, const std::string &name = "" , instruction *next = nullptr); |
890 | static instruction* create_nn(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "" , instruction *next = nullptr); |
891 | static instruction* create_nt(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "" , instruction *next = nullptr); |
892 | static instruction* create_tn(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "" , instruction *next = nullptr); |
893 | static instruction* create_tt(value *A, value *B, value *C, bool allow_tf32, const std::string &name = "" , instruction *next = nullptr); |
894 | _TRITON_DEFINE_CLONE(dot_inst) |
895 | _TRITON_DEFINE_ACCEPT(dot_inst) |
896 | |
897 | private: |
898 | bool is_prefetched_ = false; |
899 | bool allow_tf32_ = false; |
900 | DataType C_type_ = DataType::FP32; |
901 | DataType A_type_ = DataType::FP16; |
902 | DataType B_type_ = DataType::FP16; |
903 | TransT AT_; |
904 | TransT BT_; |
905 | }; |
906 | |
907 | //class outer_inst: public builtin_inst { |
908 | //private: |
909 | // outer_inst(value *A, value *B, value *C, const std::string &name, instruction *next); |
910 | //public: |
911 | // static instruction* create(value *A, value *B, value *C, const std::string &name = "", instruction *next = nullptr); |
912 | //}; |
913 | |
914 | class trans_inst: public builtin_inst { |
915 | public: |
916 | ir::type* get_res_ty(ir::type* in, std::vector<int> perm); |
917 | std::vector<int> init_perm(ir::type* ty, const std::vector<int>& perm); |
918 | |
919 | private: |
920 | trans_inst(value *arg, const std::vector<int>& perm, const std::string& name, instruction* next); |
921 | std::string repr_impl() const { return "trans" ; } |
922 | |
923 | public: |
924 | static instruction* create(value *arg, const std::vector<int> &perm = {}, const std::string &name = "" , instruction *next = nullptr); |
925 | const std::vector<int> get_perm() const; |
926 | _TRITON_DEFINE_CLONE(trans_inst) |
927 | _TRITON_DEFINE_ACCEPT(trans_inst) |
928 | |
929 | private: |
930 | std::vector<int> perm_; |
931 | }; |
932 | |
933 | class sqrt_inst: public builtin_inst { |
934 | private: |
935 | sqrt_inst(value *arg, const std::string& name, instruction* next); |
936 | std::string repr_impl() const { return "sqrt" ; } |
937 | public: |
938 | static instruction* create(value *arg, const std::string &name = "" , instruction *next = nullptr); |
939 | _TRITON_DEFINE_CLONE(sqrt_inst) |
940 | _TRITON_DEFINE_ACCEPT(sqrt_inst) |
941 | }; |
942 | |
943 | class reduce_inst: public builtin_inst { |
944 | public: |
945 | enum op_t{ |
946 | ADD, SUB, MAX, MIN, UMAX, UMIN, |
947 | ARGMAX, ARGMIN, ARGUMAX, ARGUMIN, |
948 | FADD, FSUB, FMAX, FMIN, |
949 | ARGFMAX, ARGFMIN, |
950 | XOR |
951 | }; |
952 | |
953 | private: |
954 | static type* get_res_type(value *arg, unsigned axis); |
955 | static std::string to_str(op_t op); |
956 | |
957 | private: |
958 | reduce_inst(value* arg, op_t op, unsigned axis, const std::string& name, instruction* next); |
959 | std::string repr_impl() const { return "reduce" ; } |
960 | _TRITON_DEFINE_CLONE(reduce_inst) |
961 | _TRITON_DEFINE_ACCEPT(reduce_inst) |
962 | |
963 | public: |
964 | static instruction* create(value *arg, op_t op, unsigned axis, const std::string &name = "" , instruction *next = nullptr); |
965 | unsigned get_axis() const { return axis_; } |
966 | op_t get_op() const { return op_; } |
967 | bool with_index() const { |
968 | return with_index_ops_.find(op_) != with_index_ops_.end(); |
969 | } |
970 | |
971 | private: |
972 | const static inline std::set<op_t> with_index_ops_ = { |
973 | op_t::ARGMAX, op_t::ARGMIN, op_t::ARGUMAX, |
974 | op_t::ARGUMIN, op_t::ARGFMAX, op_t::ARGFMIN}; |
975 | unsigned axis_; |
976 | op_t op_; |
977 | }; |
978 | |
979 | |
980 | class select_inst: public builtin_inst { |
981 | private: |
982 | select_inst(value *pred, value *if_value, value *else_value, const std::string& name, instruction* next); |
983 | std::string repr_impl() const { return "select" ; } |
984 | _TRITON_DEFINE_CLONE(select_inst) |
985 | _TRITON_DEFINE_ACCEPT(select_inst) |
986 | |
987 | public: |
988 | static instruction* create(value *pred, value *if_value, value *else_value, const std::string &name = "" , instruction *next = nullptr); |
989 | value* get_pred_op() { return get_operand(0); } |
990 | value* get_if_value_op() { return get_operand(1); } |
991 | value* get_else_value_op() { return get_operand(2); } |
992 | }; |
993 | |
994 | //===----------------------------------------------------------------------===// |
995 | // intrinsics classes |
996 | //===----------------------------------------------------------------------===// |
997 | |
998 | |
999 | class copy_to_shared_inst: public unary_inst{ |
1000 | private: |
1001 | using unary_inst::unary_inst; |
1002 | std::string repr_impl() const { return "copy_to_shared" ; } |
1003 | |
1004 | public: |
1005 | static copy_to_shared_inst* create(value *arg, const std::string &name = "" , |
1006 | instruction *next = nullptr); |
1007 | _TRITON_DEFINE_CLONE(copy_to_shared_inst) |
1008 | _TRITON_DEFINE_ACCEPT(copy_to_shared_inst) |
1009 | }; |
1010 | |
1011 | class copy_from_shared_inst: public unary_inst{ |
1012 | private: |
1013 | using unary_inst::unary_inst; |
1014 | std::string repr_impl() const { return "copy_from_shared" ; } |
1015 | |
1016 | public: |
1017 | static copy_from_shared_inst* create(value *arg, const std::string &name = "" , |
1018 | instruction *next = nullptr); |
1019 | _TRITON_DEFINE_CLONE(copy_from_shared_inst) |
1020 | _TRITON_DEFINE_ACCEPT(copy_from_shared_inst) |
1021 | }; |
1022 | |
1023 | class cvt_layout_inst: public unary_inst { |
1024 | private: |
1025 | using unary_inst::unary_inst; |
1026 | std::string repr_impl() const { return "cvt_layout_inst" ; } |
1027 | |
1028 | public: |
1029 | static cvt_layout_inst* create(value *arg, const std::string &name = "" , instruction *next = nullptr); |
1030 | _TRITON_DEFINE_CLONE(cvt_layout_inst) |
1031 | _TRITON_DEFINE_ACCEPT(cvt_layout_inst) |
1032 | }; |
1033 | |
1034 | class barrier_inst: public instruction{ |
1035 | private: |
1036 | barrier_inst(context &ctx, const std::string &name, instruction *next); |
1037 | std::string repr_impl() const { return "barrier" ; } |
1038 | _TRITON_DEFINE_CLONE(barrier_inst) |
1039 | _TRITON_DEFINE_ACCEPT(barrier_inst) |
1040 | |
1041 | public: |
1042 | static barrier_inst* create(context &ctx, const std::string &name = "" , |
1043 | instruction *next = nullptr); |
1044 | }; |
1045 | |
1046 | class async_wait_inst: public instruction{ |
1047 | private: |
1048 | async_wait_inst(context &ctx, int N, const std::string &name, instruction *next); |
1049 | std::string repr_impl() const { return "async_wait_group " + std::to_string(N_) ; } |
1050 | _TRITON_DEFINE_CLONE(async_wait_inst) |
1051 | _TRITON_DEFINE_ACCEPT(async_wait_inst) |
1052 | |
1053 | public: |
1054 | static async_wait_inst* create(context &ctx, int N, |
1055 | const std::string &name = "" , instruction *next = nullptr); |
1056 | int get_N() { return N_; } |
1057 | void set_N(int n) { N_ = n; } |
1058 | |
1059 | private: |
1060 | int N_; |
1061 | }; |
1062 | |
1063 | class prefetch_s_inst : public instruction { |
1064 | std::string repr_impl() const { return "prefetch_s" ; } |
1065 | _TRITON_DEFINE_CLONE(prefetch_s_inst) |
1066 | _TRITON_DEFINE_ACCEPT(prefetch_s_inst) |
1067 | |
1068 | /// inc_: 0->first, 1->latch |
1069 | int inc_ = 0; |
1070 | public: |
1071 | prefetch_s_inst(context &ctx, value *arg, int inc, const std::string &name, instruction *next) |
1072 | : instruction(type::get_void_ty(ctx), INST_PREFETCH_S, 1, name, next), inc_(inc) { |
1073 | set_operand(0, arg); |
1074 | } |
1075 | int get_inc() const { return inc_; } |
1076 | static prefetch_s_inst *create(context &ctx, value *arg, int inc, const std::string &name = "" , |
1077 | instruction *next=nullptr); |
1078 | }; |
1079 | |
1080 | /* constant range */ |
1081 | class make_range: public instruction{ |
1082 | make_range(type *ty, constant_int* first, constant_int* last); |
1083 | std::string repr_impl() const { return "make_range[" + first_->repr() + " : " + last_->repr() + "]" ; } |
1084 | _TRITON_DEFINE_CLONE(make_range) |
1085 | _TRITON_DEFINE_ACCEPT(make_range) |
1086 | |
1087 | public: |
1088 | static make_range *create(constant_int *first, constant_int *last); |
1089 | const constant_int* get_first() const; |
1090 | const constant_int* get_last() const; |
1091 | |
1092 | private: |
1093 | constant_int* first_; |
1094 | constant_int* last_; |
1095 | }; |
1096 | |
1097 | /* timing utilities */ |
1098 | class clock_inst: public instruction{ |
1099 | clock_inst(context &ctx, const std::string &name, instruction *next); |
1100 | std::string repr_impl() const { return "clock" ; } |
1101 | _TRITON_DEFINE_CLONE(clock_inst) |
1102 | _TRITON_DEFINE_ACCEPT(clock_inst) |
1103 | |
1104 | public: |
1105 | static clock_inst* create(context &ctx, const std::string &name = "" , instruction *next = nullptr); |
1106 | }; |
1107 | |
1108 | class globaltimer_inst: public instruction{ |
1109 | globaltimer_inst(context &ctx, const std::string &name, instruction *next); |
1110 | std::string repr_impl() const { return "globaltimer" ; } |
1111 | _TRITON_DEFINE_CLONE(globaltimer_inst) |
1112 | _TRITON_DEFINE_ACCEPT(globaltimer_inst) |
1113 | |
1114 | public: |
1115 | static globaltimer_inst* create(context &ctx, const std::string &name = "" , instruction *next = nullptr); |
1116 | }; |
1117 | |
1118 | class extern_elementwise_inst : public instruction { |
1119 | extern_elementwise_inst(context &ctx, const std::vector<value *> &args, |
1120 | type *dst_ty, const std::string &lib_name, |
1121 | const std::string &extern_lib_path, |
1122 | const std::string &symbol_name, |
1123 | const std::string &name, instruction *next); |
1124 | std::string repr_impl() const { return "extern_elementwise" ; } |
1125 | _TRITON_DEFINE_CLONE(extern_elementwise_inst) |
1126 | _TRITON_DEFINE_ACCEPT(extern_elementwise_inst) |
1127 | |
1128 | public: |
1129 | static extern_elementwise_inst *create( |
1130 | context &ctx, const std::vector<value *> &args, type *dst_ty, |
1131 | const std::string &lib_name = "" , const std::string &lib_path = "" , |
1132 | const std::string &symbol_name = "" , const std::string &name = "" , |
1133 | instruction *next = nullptr); |
1134 | |
1135 | const std::string &get_lib_name() const { return lib_name_; } |
1136 | const std::string &get_lib_path() const { return lib_path_; } |
1137 | const std::string &get_symbol_name() const { return symbol_name_; } |
1138 | |
1139 | private: |
1140 | std::string lib_name_; |
1141 | std::string lib_path_; |
1142 | std::string symbol_name_; |
1143 | }; |
1144 | } |
1145 | } |
1146 | |
1147 | #endif |
1148 | |