1#include <cassert>
2#include <stdexcept>
3#include "triton/ir/type.h"
4#include "triton/ir/context.h"
5#include "triton/ir/context_impl.h"
6#include "triton/ir/value.h"
7#include "triton/ir/constant.h"
8
9namespace triton{
10namespace ir{
11
12//===----------------------------------------------------------------------===//
13// type class
14//===----------------------------------------------------------------------===//
15
16// attributes
17type *type::get_scalar_ty() const {
18 if(is_block_ty())
19 return get_tile_element_ty();
20 return const_cast<type*>(this);
21}
22
23unsigned type::get_primitive_size_in_bits() const {
24 switch (id_) {
25 case FP8TyID: return 8;
26 case FP16TyID: return 16;
27 case BF16TyID: return 16;
28 case FP32TyID: return 32;
29 case FP64TyID: return 64;
30 case IntegerTyID: return std::max<int>(8, ((integer_type*)(this))->get_bitwidth());
31 case BlockTyID: return ((block_type*)(this))->get_bitwidth();
32 default: return 0;
33 }
34}
35
36unsigned type::get_integer_bitwidth() const
37{ assert(id_ == IntegerTyID); return ((integer_type*)(this))->get_bitwidth(); }
38
39unsigned type::get_tile_bitwidth() const
40{ return ((block_type*)(this))->get_bitwidth(); }
41
42unsigned type::get_fp_mantissa_width() const {
43 id_t id = get_scalar_ty()->id_;
44 assert(is_floating_point_ty() && "Not a floating point type!");
45 if (id == FP8TyID) return 3;
46 if (id == FP16TyID) return 10;
47 if (id == BF16TyID) return 7;
48 if (id == FP32TyID) return 23;
49 if (id == FP64TyID) return 53;
50 throw std::runtime_error("unreachable");
51}
52
53type* type::get_tile_element_ty() const {
54 assert(is_block_ty());
55 return contained_tys_[0];
56}
57
58unsigned type::get_pointer_address_space() const {
59 assert(is_pointer_ty());
60 return ((pointer_type*)this)->get_address_space();
61}
62
63type * type::get_pointer_element_ty() const {
64 type *ptr_ty = get_scalar_ty();
65 assert(ptr_ty->is_pointer_ty());
66 type *scalar_ty = ((pointer_type*)ptr_ty)->get_element_ty();
67 if(is_block_ty())
68 return block_type::get_same_shapes(scalar_ty, (type*)this);
69 return scalar_ty;
70}
71
72
73type::block_shapes_t type::get_block_shapes() const {
74 assert(is_block_ty());
75 return ((block_type*)this)->get_shapes();
76}
77
78const size_t type::get_tile_rank() const {
79 return get_block_shapes().size();
80}
81
82const size_t type::get_tile_ranks1() const {
83 int ret = 0;
84 for(int s: get_block_shapes())
85 ret += s > 1;
86 return ret;
87}
88
89
90unsigned type::get_tile_num_elements() const {
91 const block_shapes_t& shapes = get_block_shapes();
92 unsigned result = 1;
93 for(auto shape: shapes)
94 result *= shape;
95 return result;
96}
97
98
99// composite predicates
100bool type::is_int_or_tileint_ty()
101{ return get_scalar_ty()->is_integer_ty(); }
102
103bool type::is_integer_ty(unsigned width) const
104{ return is_integer_ty() && get_integer_bitwidth()== width; }
105
106
107bool type::is_floating_point_ty() const
108{ return is_fp8_ty() || is_fp16_ty() || is_bf16_ty() || is_fp32_ty() || is_fp64_ty(); }
109
110bool type::is_sized() const {
111 // primitive types are sized
112 if(is_integer_ty() || is_floating_point_ty() ||
113 is_pointer_ty()){
114 return true;
115 }
116 // tile types are sizes
117 if(is_block_ty())
118 return get_scalar_ty()->is_sized();
119 return false;
120}
121
122// primitive types
123type *type::get_void_ty(context &ctx) { return &ctx.p_impl->void_ty; }
124type *type::get_label_ty(context &ctx) { return &ctx.p_impl->label_ty; }
125// floating point
126type *type::get_fp8_ty(context &ctx) { return &ctx.p_impl->fp8_ty; }
127type *type::get_fp16_ty(context &ctx) { return &ctx.p_impl->fp16_ty; }
128type *type::get_bf16_ty(context &ctx) { return &ctx.p_impl->bf16_ty; }
129type *type::get_fp32_ty(context &ctx) { return &ctx.p_impl->fp32_ty; }
130type *type::get_fp64_ty(context &ctx) { return &ctx.p_impl->fp64_ty; }
131// integer types
132integer_type *type::get_int1_ty(context &ctx) { return &ctx.p_impl->int1_ty; }
133integer_type *type::get_int8_ty(context &ctx) { return &ctx.p_impl->int8_ty; }
134integer_type *type::get_int16_ty(context &ctx) { return &ctx.p_impl->int16_ty; }
135integer_type *type::get_int32_ty(context &ctx) { return &ctx.p_impl->int32_ty; }
136integer_type *type::get_int64_ty(context &ctx) { return &ctx.p_impl->int64_ty; }
137integer_type *type::get_int128_ty(context &ctx) { return &ctx.p_impl->int128_ty; }
138
139
140
141pointer_type::pointer_type(type *ty, unsigned address_space)
142 : type(ty->get_context(), PointerTyID), address_space_(address_space){
143 contained_tys_.push_back(ty);
144}
145
146bool pointer_type::is_valid_elt_ty(type *ty){
147 return !ty->is_void_ty() && !ty->is_label_ty() &&
148 !ty->is_metadata_ty() && !ty->is_token_ty();
149}
150
151pointer_type* pointer_type::get(type *elt_ty, unsigned address_space){
152 assert(elt_ty && "Can't get a pointer to <null> type!");
153 assert(is_valid_elt_ty(elt_ty) && "Invalid type for pointer element!");
154 // look-up
155 context_impl *impl = elt_ty->get_context().p_impl.get();
156 std::unique_ptr<pointer_type> &entry = impl->ptr_tys[std::make_pair(elt_ty, address_space)];
157 if(!entry)
158 entry.reset(new pointer_type(elt_ty, address_space));
159 return entry.get();
160}
161
162//===----------------------------------------------------------------------===//
163// composite_type class
164//===----------------------------------------------------------------------===//
165
166type* composite_type::get_type_at_index(value *) const{
167 assert(is_block_ty());
168 return get_scalar_ty();
169}
170
171bool composite_type::index_valid(value *idx) const{
172 assert(is_block_ty());
173 return idx->get_type()->is_int_or_tileint_ty();
174}
175
176//===----------------------------------------------------------------------===//
177// struct_type class
178//===----------------------------------------------------------------------===//
179
180struct_type::struct_type(const contained_tys_vec_t& tys, bool is_packed)
181 : composite_type(tys[0]->get_context(), StructTyID), is_packed_(is_packed) {
182 contained_tys_ = tys;
183}
184
185struct_type* struct_type::get(const contained_tys_vec_t& tys, bool is_packed) {
186 assert(tys.size());
187 context_impl* impl = tys[0]->get_context().p_impl.get();
188 struct_type *& entry = impl->struct_tys[tys];
189 if(!entry)
190 entry = new struct_type(tys, is_packed);
191 return entry;
192}
193
194
195//===----------------------------------------------------------------------===//
196// block_type class
197//===----------------------------------------------------------------------===//
198
199block_type::block_type(type *ty, const block_shapes_t &shapes)
200 : composite_type(ty->get_context(), BlockTyID), shapes_(shapes) {
201 contained_tys_.push_back(ty);
202}
203
204bool block_type::is_valid_elt_ty(type *ty) {
205 return ty->is_pointer_ty() || ty->is_floating_point_ty() || ty->is_integer_ty();
206}
207
208unsigned block_type::get_num_elements() const {
209 unsigned res = 1;
210 for(auto shape: shapes_)
211 res *= shape;
212 return res;
213}
214
215unsigned block_type::get_bitwidth() const {
216 return get_num_elements() * get_tile_element_ty()->get_primitive_size_in_bits();
217}
218
219block_type* block_type::get(type *elt_ty, const block_shapes_t &shapes) {
220 assert(elt_ty && "Can't get a tile of <null> type!");
221 assert(shapes.size() && "Can't create a tile with empty shapes!");
222 assert(is_valid_elt_ty(elt_ty) && "Invalid type for tile element!");
223 // look-up
224 context_impl *impl = elt_ty->get_context().p_impl.get();
225 std::unique_ptr<block_type> &entry = impl->block_tys[std::make_pair(elt_ty, shapes)];
226 if(!entry)
227 entry.reset(new block_type(elt_ty, shapes));
228 return entry.get();
229}
230
231block_type* block_type::get_same_shapes(type *ty, type *ref){
232 assert(ref->is_block_ty());
233 return get(ty, ref->get_block_shapes());
234}
235
236//===----------------------------------------------------------------------===//
237// function_type class
238//===----------------------------------------------------------------------===//
239
240function_type::function_type(type *ret_ty, const std::vector<type*> &param_tys):
241 type(ret_ty->get_context(), FunctionTyID) {
242 contained_tys_.push_back(ret_ty);
243 for(type *ty: param_tys)
244 contained_tys_.push_back(ty);
245}
246
247function_type* function_type::get(type *ret_ty, const std::vector<type *> &param_tys) {
248 return new function_type(ret_ty, param_tys);
249}
250
251}
252}
253