#include <algorithm>
#include <cassert>
#include <stdexcept>
#include "triton/ir/type.h"
#include "triton/ir/context.h"
#include "triton/ir/context_impl.h"
#include "triton/ir/value.h"
#include "triton/ir/constant.h"
8 | |
9 | namespace triton{ |
10 | namespace ir{ |
11 | |
12 | //===----------------------------------------------------------------------===// |
13 | // type class |
14 | //===----------------------------------------------------------------------===// |
15 | |
16 | // attributes |
17 | type *type::get_scalar_ty() const { |
18 | if(is_block_ty()) |
19 | return get_tile_element_ty(); |
20 | return const_cast<type*>(this); |
21 | } |
22 | |
23 | unsigned type::get_primitive_size_in_bits() const { |
24 | switch (id_) { |
25 | case FP8TyID: return 8; |
26 | case FP16TyID: return 16; |
27 | case BF16TyID: return 16; |
28 | case FP32TyID: return 32; |
29 | case FP64TyID: return 64; |
30 | case IntegerTyID: return std::max<int>(8, ((integer_type*)(this))->get_bitwidth()); |
31 | case BlockTyID: return ((block_type*)(this))->get_bitwidth(); |
32 | default: return 0; |
33 | } |
34 | } |
35 | |
36 | unsigned type::get_integer_bitwidth() const |
37 | { assert(id_ == IntegerTyID); return ((integer_type*)(this))->get_bitwidth(); } |
38 | |
39 | unsigned type::get_tile_bitwidth() const |
40 | { return ((block_type*)(this))->get_bitwidth(); } |
41 | |
42 | unsigned type::get_fp_mantissa_width() const { |
43 | id_t id = get_scalar_ty()->id_; |
44 | assert(is_floating_point_ty() && "Not a floating point type!" ); |
45 | if (id == FP8TyID) return 3; |
46 | if (id == FP16TyID) return 10; |
47 | if (id == BF16TyID) return 7; |
48 | if (id == FP32TyID) return 23; |
49 | if (id == FP64TyID) return 53; |
50 | throw std::runtime_error("unreachable" ); |
51 | } |
52 | |
53 | type* type::get_tile_element_ty() const { |
54 | assert(is_block_ty()); |
55 | return contained_tys_[0]; |
56 | } |
57 | |
58 | unsigned type::get_pointer_address_space() const { |
59 | assert(is_pointer_ty()); |
60 | return ((pointer_type*)this)->get_address_space(); |
61 | } |
62 | |
63 | type * type::get_pointer_element_ty() const { |
64 | type *ptr_ty = get_scalar_ty(); |
65 | assert(ptr_ty->is_pointer_ty()); |
66 | type *scalar_ty = ((pointer_type*)ptr_ty)->get_element_ty(); |
67 | if(is_block_ty()) |
68 | return block_type::get_same_shapes(scalar_ty, (type*)this); |
69 | return scalar_ty; |
70 | } |
71 | |
72 | |
73 | type::block_shapes_t type::get_block_shapes() const { |
74 | assert(is_block_ty()); |
75 | return ((block_type*)this)->get_shapes(); |
76 | } |
77 | |
78 | const size_t type::get_tile_rank() const { |
79 | return get_block_shapes().size(); |
80 | } |
81 | |
82 | const size_t type::get_tile_ranks1() const { |
83 | int ret = 0; |
84 | for(int s: get_block_shapes()) |
85 | ret += s > 1; |
86 | return ret; |
87 | } |
88 | |
89 | |
90 | unsigned type::get_tile_num_elements() const { |
91 | const block_shapes_t& shapes = get_block_shapes(); |
92 | unsigned result = 1; |
93 | for(auto shape: shapes) |
94 | result *= shape; |
95 | return result; |
96 | } |
97 | |
98 | |
99 | // composite predicates |
100 | bool type::is_int_or_tileint_ty() |
101 | { return get_scalar_ty()->is_integer_ty(); } |
102 | |
103 | bool type::is_integer_ty(unsigned width) const |
104 | { return is_integer_ty() && get_integer_bitwidth()== width; } |
105 | |
106 | |
107 | bool type::is_floating_point_ty() const |
108 | { return is_fp8_ty() || is_fp16_ty() || is_bf16_ty() || is_fp32_ty() || is_fp64_ty(); } |
109 | |
110 | bool type::is_sized() const { |
111 | // primitive types are sized |
112 | if(is_integer_ty() || is_floating_point_ty() || |
113 | is_pointer_ty()){ |
114 | return true; |
115 | } |
116 | // tile types are sizes |
117 | if(is_block_ty()) |
118 | return get_scalar_ty()->is_sized(); |
119 | return false; |
120 | } |
121 | |
122 | // primitive types |
123 | type *type::get_void_ty(context &ctx) { return &ctx.p_impl->void_ty; } |
124 | type *type::get_label_ty(context &ctx) { return &ctx.p_impl->label_ty; } |
125 | // floating point |
126 | type *type::get_fp8_ty(context &ctx) { return &ctx.p_impl->fp8_ty; } |
127 | type *type::get_fp16_ty(context &ctx) { return &ctx.p_impl->fp16_ty; } |
128 | type *type::get_bf16_ty(context &ctx) { return &ctx.p_impl->bf16_ty; } |
129 | type *type::get_fp32_ty(context &ctx) { return &ctx.p_impl->fp32_ty; } |
130 | type *type::get_fp64_ty(context &ctx) { return &ctx.p_impl->fp64_ty; } |
131 | // integer types |
132 | integer_type *type::get_int1_ty(context &ctx) { return &ctx.p_impl->int1_ty; } |
133 | integer_type *type::get_int8_ty(context &ctx) { return &ctx.p_impl->int8_ty; } |
134 | integer_type *type::get_int16_ty(context &ctx) { return &ctx.p_impl->int16_ty; } |
135 | integer_type *type::get_int32_ty(context &ctx) { return &ctx.p_impl->int32_ty; } |
136 | integer_type *type::get_int64_ty(context &ctx) { return &ctx.p_impl->int64_ty; } |
137 | integer_type *type::get_int128_ty(context &ctx) { return &ctx.p_impl->int128_ty; } |
138 | |
139 | |
140 | |
141 | pointer_type::pointer_type(type *ty, unsigned address_space) |
142 | : type(ty->get_context(), PointerTyID), address_space_(address_space){ |
143 | contained_tys_.push_back(ty); |
144 | } |
145 | |
146 | bool pointer_type::is_valid_elt_ty(type *ty){ |
147 | return !ty->is_void_ty() && !ty->is_label_ty() && |
148 | !ty->is_metadata_ty() && !ty->is_token_ty(); |
149 | } |
150 | |
151 | pointer_type* pointer_type::get(type *elt_ty, unsigned address_space){ |
152 | assert(elt_ty && "Can't get a pointer to <null> type!" ); |
153 | assert(is_valid_elt_ty(elt_ty) && "Invalid type for pointer element!" ); |
154 | // look-up |
155 | context_impl *impl = elt_ty->get_context().p_impl.get(); |
156 | std::unique_ptr<pointer_type> &entry = impl->ptr_tys[std::make_pair(elt_ty, address_space)]; |
157 | if(!entry) |
158 | entry.reset(new pointer_type(elt_ty, address_space)); |
159 | return entry.get(); |
160 | } |
161 | |
162 | //===----------------------------------------------------------------------===// |
163 | // composite_type class |
164 | //===----------------------------------------------------------------------===// |
165 | |
166 | type* composite_type::get_type_at_index(value *) const{ |
167 | assert(is_block_ty()); |
168 | return get_scalar_ty(); |
169 | } |
170 | |
171 | bool composite_type::index_valid(value *idx) const{ |
172 | assert(is_block_ty()); |
173 | return idx->get_type()->is_int_or_tileint_ty(); |
174 | } |
175 | |
176 | //===----------------------------------------------------------------------===// |
177 | // struct_type class |
178 | //===----------------------------------------------------------------------===// |
179 | |
180 | struct_type::struct_type(const contained_tys_vec_t& tys, bool is_packed) |
181 | : composite_type(tys[0]->get_context(), StructTyID), is_packed_(is_packed) { |
182 | contained_tys_ = tys; |
183 | } |
184 | |
185 | struct_type* struct_type::get(const contained_tys_vec_t& tys, bool is_packed) { |
186 | assert(tys.size()); |
187 | context_impl* impl = tys[0]->get_context().p_impl.get(); |
188 | struct_type *& entry = impl->struct_tys[tys]; |
189 | if(!entry) |
190 | entry = new struct_type(tys, is_packed); |
191 | return entry; |
192 | } |
193 | |
194 | |
195 | //===----------------------------------------------------------------------===// |
196 | // block_type class |
197 | //===----------------------------------------------------------------------===// |
198 | |
199 | block_type::block_type(type *ty, const block_shapes_t &shapes) |
200 | : composite_type(ty->get_context(), BlockTyID), shapes_(shapes) { |
201 | contained_tys_.push_back(ty); |
202 | } |
203 | |
204 | bool block_type::is_valid_elt_ty(type *ty) { |
205 | return ty->is_pointer_ty() || ty->is_floating_point_ty() || ty->is_integer_ty(); |
206 | } |
207 | |
208 | unsigned block_type::get_num_elements() const { |
209 | unsigned res = 1; |
210 | for(auto shape: shapes_) |
211 | res *= shape; |
212 | return res; |
213 | } |
214 | |
215 | unsigned block_type::get_bitwidth() const { |
216 | return get_num_elements() * get_tile_element_ty()->get_primitive_size_in_bits(); |
217 | } |
218 | |
219 | block_type* block_type::get(type *elt_ty, const block_shapes_t &shapes) { |
220 | assert(elt_ty && "Can't get a tile of <null> type!" ); |
221 | assert(shapes.size() && "Can't create a tile with empty shapes!" ); |
222 | assert(is_valid_elt_ty(elt_ty) && "Invalid type for tile element!" ); |
223 | // look-up |
224 | context_impl *impl = elt_ty->get_context().p_impl.get(); |
225 | std::unique_ptr<block_type> &entry = impl->block_tys[std::make_pair(elt_ty, shapes)]; |
226 | if(!entry) |
227 | entry.reset(new block_type(elt_ty, shapes)); |
228 | return entry.get(); |
229 | } |
230 | |
231 | block_type* block_type::get_same_shapes(type *ty, type *ref){ |
232 | assert(ref->is_block_ty()); |
233 | return get(ty, ref->get_block_shapes()); |
234 | } |
235 | |
236 | //===----------------------------------------------------------------------===// |
237 | // function_type class |
238 | //===----------------------------------------------------------------------===// |
239 | |
240 | function_type::function_type(type *ret_ty, const std::vector<type*> ¶m_tys): |
241 | type(ret_ty->get_context(), FunctionTyID) { |
242 | contained_tys_.push_back(ret_ty); |
243 | for(type *ty: param_tys) |
244 | contained_tys_.push_back(ty); |
245 | } |
246 | |
247 | function_type* function_type::get(type *ret_ty, const std::vector<type *> ¶m_tys) { |
248 | return new function_type(ret_ty, param_tys); |
249 | } |
250 | |
251 | } |
252 | } |
253 | |