1 | #pragma once |
2 | |
3 | #ifndef _TRITON_IR_TYPE_H_ |
4 | #define _TRITON_IR_TYPE_H_ |
5 | |
6 | #include <cassert> |
7 | #include <vector> |
8 | #include <string> |
9 | #include <stdexcept> |
10 | |
11 | namespace triton{ |
12 | namespace ir{ |
13 | |
14 | class context; |
15 | class value; |
16 | class integer_type; |
17 | class constant_int; |
18 | |
19 | /* Type */ |
20 | class type { |
21 | public: |
22 | typedef std::vector<unsigned> block_shapes_t; |
23 | |
24 | typedef std::vector<type*> contained_tys_vec_t; |
25 | typedef contained_tys_vec_t::iterator ty_iterator; |
26 | typedef contained_tys_vec_t::const_iterator const_ty_iterator; |
27 | |
28 | public: |
29 | enum id_t { |
30 | // primitive types |
31 | VoidTyID = 0, ///< type with no size |
32 | FP8TyID, ///< 8-bit floating point type (3 bits mantissa) |
33 | FP16TyID, ///< 16-bit floating point type (10 bits mantissa) |
34 | BF16TyID, ///< 16-bit floating point type (7 bits mantissa) |
35 | FP32TyID, ///< 32-bit floating point type |
36 | FP64TyID, ///< 64-bit floating point type |
37 | LabelTyID, ///< Labels |
38 | MetadataTyID, ///< Metadata |
39 | TokenTyID, ///< Token |
40 | // derived types |
41 | IntegerTyID, ///< Arbitrary bit width integers |
42 | FunctionTyID, ///< Functions |
43 | PointerTyID, ///< Pointers |
44 | StructTyID, ///< Struct |
45 | BlockTyID, ///< Block |
46 | }; |
47 | |
48 | public: |
49 | //constructors |
50 | type(context &ctx, id_t id) : ctx_(ctx), id_(id) { } |
51 | |
52 | //destructor |
53 | virtual ~type(){} |
54 | |
55 | // accessors |
56 | context &get_context() const { return ctx_; } |
57 | id_t get_type_id() const { return id_; } |
58 | // type attributes |
59 | unsigned get_fp_mantissa_width() const; |
60 | unsigned get_integer_bitwidth() const; |
61 | unsigned get_tile_bitwidth() const; |
62 | unsigned get_primitive_size_in_bits() const; |
63 | type *get_scalar_ty() const; |
64 | block_shapes_t get_block_shapes() const; |
65 | const size_t get_tile_rank() const; |
66 | const size_t get_tile_ranks1() const; |
67 | unsigned get_tile_num_elements() const; |
68 | type *get_tile_element_ty() const; |
69 | unsigned get_pointer_address_space() const; |
70 | type *get_pointer_element_ty() const; |
71 | unsigned get_struct_numel() const { return contained_tys_.size(); } |
72 | type *get_struct_type(unsigned int i) const { return contained_tys_[i]; } |
73 | |
74 | // primitive predicates |
75 | bool is_void_ty() const { return id_ == VoidTyID; } |
76 | bool is_fp8_ty() const { return id_ == FP8TyID; } |
77 | bool is_fp16_ty() const { return id_ == FP16TyID; } |
78 | bool is_bf16_ty() const { return id_ == BF16TyID; } |
79 | bool is_fp32_ty() const { return id_ == FP32TyID; } |
80 | bool is_fp64_ty() const { return id_ == FP64TyID; } |
81 | bool is_label_ty() const { return id_ == LabelTyID;} |
82 | bool is_metadata_ty() const { return id_ == MetadataTyID; } |
83 | bool is_token_ty() const { return id_ == TokenTyID; } |
84 | bool is_integer_ty() const { return id_ == IntegerTyID; } |
85 | bool is_bool_ty() const { return is_integer_ty(1); } |
86 | bool is_pointer_ty() const { return id_ == PointerTyID; } |
87 | bool is_block_ty() const { return id_ == BlockTyID; } |
88 | bool is_struct_ty() const { return id_ == StructTyID; } |
89 | |
90 | // Composite predicates |
91 | bool is_int_or_tileint_ty(); |
92 | bool is_integer_ty(unsigned width) const; |
93 | bool is_floating_point_ty() const; |
94 | bool is_sized() const ; |
95 | |
96 | // Factory methods |
97 | // primitive types |
98 | static type *get_void_ty(context &ctx); |
99 | static type *get_label_ty(context &ctx); |
100 | // half |
101 | static type *get_fp8_ty(context &ctx); |
102 | static type *get_fp16_ty(context &ctx); |
103 | static type *get_bf16_ty(context &ctx); |
104 | static type *get_fp32_ty(context &ctx); |
105 | static type *get_fp64_ty(context &ctx); |
106 | // integer types |
107 | static integer_type *get_int1_ty(context &ctx); |
108 | static integer_type *get_int8_ty(context &ctx); |
109 | static integer_type *get_int16_ty(context &ctx); |
110 | static integer_type *get_int32_ty(context &ctx); |
111 | static integer_type *get_int64_ty(context &ctx); |
112 | static integer_type *get_int128_ty(context &ctx); |
113 | |
114 | // repr |
115 | std::string tile_repr() const { |
116 | std::string res = get_tile_element_ty()->repr(); |
117 | auto shapes = get_block_shapes(); |
118 | res += "<" ; |
119 | for(size_t i = 0; i < shapes.size(); i++){ |
120 | if(i > 0) |
121 | res += ", " ; |
122 | res += std::to_string(shapes[i]); |
123 | } |
124 | res+= ">" ; |
125 | return res; |
126 | } |
127 | |
128 | std::string repr() const { |
129 | switch(id_) { |
130 | case VoidTyID: return "void" ; |
131 | case FP8TyID: return "fp8" ; |
132 | case BF16TyID: return "bf16" ; |
133 | case FP16TyID: return "f16" ; |
134 | case FP32TyID: return "f32" ; |
135 | case FP64TyID: return "f64" ; |
136 | case LabelTyID: return "label" ; |
137 | case MetadataTyID: return "md" ; |
138 | case TokenTyID: return "tok" ; |
139 | case IntegerTyID: return ("i" ) + std::to_string(get_integer_bitwidth()); |
140 | case FunctionTyID: return "fn" ; |
141 | case PointerTyID: return get_pointer_element_ty()->repr() + "*" ; |
142 | case StructTyID: return "struct" ; |
143 | case BlockTyID: return tile_repr(); |
144 | default: break; |
145 | } |
146 | throw std::logic_error("unknown type id '" + std::to_string(id_) + "'" ); |
147 | }; |
148 | |
149 | private: |
150 | context &ctx_; |
151 | id_t id_; |
152 | |
153 | protected: |
154 | contained_tys_vec_t contained_tys_; |
155 | }; |
156 | |
157 | class integer_type: public type { |
158 | friend class context_impl; |
159 | |
160 | private: |
161 | // constructors |
162 | integer_type(context &ctx, unsigned bitwidth) |
163 | : type(ctx, IntegerTyID), bitwidth_(bitwidth) {} |
164 | |
165 | public: |
166 | // accessors |
167 | unsigned get_bitwidth() const { return bitwidth_; } |
168 | |
169 | // factory methods |
170 | static integer_type* get(context &ctx, unsigned width); |
171 | |
172 | private: |
173 | unsigned bitwidth_; |
174 | }; |
175 | |
176 | class composite_type: public type{ |
177 | protected: |
178 | using type::type; |
179 | |
180 | public: |
181 | bool index_valid(value *idx) const; |
182 | type* get_type_at_index(value *idx) const; |
183 | }; |
184 | |
185 | class struct_type: public composite_type { |
186 | public: |
187 | struct_type(const contained_tys_vec_t& tys, bool is_packed); |
188 | unsigned get_num_types() const { return contained_tys_.size(); } |
189 | static struct_type* get(const contained_tys_vec_t& tys, bool is_packed); |
190 | |
191 | private: |
192 | bool is_packed_; |
193 | }; |
194 | |
195 | class block_type: public composite_type { |
196 | private: |
197 | block_type(type *ty, const block_shapes_t &shapes); |
198 | static bool is_valid_elt_ty(type *ty); |
199 | |
200 | public: |
201 | // accessors |
202 | const block_shapes_t& get_shapes() const { return shapes_; } |
203 | unsigned get_num_elements() const; |
204 | unsigned get_bitwidth() const; |
205 | |
206 | // factory methods |
207 | static block_type* get(type *ty, const block_shapes_t &shapes); |
208 | static block_type* get_same_shapes(type *ty, type *ref); |
209 | |
210 | private: |
211 | block_shapes_t shapes_; |
212 | }; |
213 | |
214 | class pointer_type: public type { |
215 | private: |
216 | pointer_type(type *ty, unsigned address_space); |
217 | static bool is_valid_elt_ty(type *ty); |
218 | |
219 | public: |
220 | // accessors |
221 | unsigned get_address_space() const { return address_space_; } |
222 | type *get_element_ty() const { return contained_tys_[0]; } |
223 | // factory methods |
224 | static pointer_type* get(type *ty, unsigned address_space); |
225 | |
226 | private: |
227 | unsigned address_space_; |
228 | }; |
229 | |
230 | class function_type: public type { |
231 | private: |
232 | function_type(type *ret_ty, const std::vector<type *> ¶m_tys); |
233 | |
234 | public: |
235 | // accessors |
236 | unsigned get_num_params() const { return contained_tys_.size() - 1; } |
237 | const_ty_iterator params_begin() const { return contained_tys_.begin() + 1; } |
238 | const_ty_iterator params_end() const { return contained_tys_.end(); } |
239 | ty_iterator params_begin() { return contained_tys_.begin() + 1; } |
240 | ty_iterator params_end() { return contained_tys_.end(); } |
241 | type* get_param_ty(unsigned i) const { return contained_tys_.at(1 + i); } |
242 | type* get_return_ty() const { return contained_tys_.at(0); } |
243 | void reset_ret_ty(type* ty) { contained_tys_[0] = ty;} |
244 | // factory methods |
245 | static function_type* get(type *ret_ty, const std::vector<type*>& param_tys); |
246 | }; |
247 | |
248 | |
249 | } |
250 | } |
251 | |
252 | #endif |
253 | |