1#pragma once
2
3#ifndef _TRITON_IR_TYPE_H_
4#define _TRITON_IR_TYPE_H_
5
6#include <cassert>
7#include <vector>
8#include <string>
9#include <stdexcept>
10
11namespace triton{
12namespace ir{
13
14class context;
15class value;
16class integer_type;
17class constant_int;
18
19/* Type */
20class type {
21public:
22 typedef std::vector<unsigned> block_shapes_t;
23
24 typedef std::vector<type*> contained_tys_vec_t;
25 typedef contained_tys_vec_t::iterator ty_iterator;
26 typedef contained_tys_vec_t::const_iterator const_ty_iterator;
27
28public:
29 enum id_t {
30 // primitive types
31 VoidTyID = 0, ///< type with no size
32 FP8TyID, ///< 8-bit floating point type (3 bits mantissa)
33 FP16TyID, ///< 16-bit floating point type (10 bits mantissa)
34 BF16TyID, ///< 16-bit floating point type (7 bits mantissa)
35 FP32TyID, ///< 32-bit floating point type
36 FP64TyID, ///< 64-bit floating point type
37 LabelTyID, ///< Labels
38 MetadataTyID, ///< Metadata
39 TokenTyID, ///< Token
40 // derived types
41 IntegerTyID, ///< Arbitrary bit width integers
42 FunctionTyID, ///< Functions
43 PointerTyID, ///< Pointers
44 StructTyID, ///< Struct
45 BlockTyID, ///< Block
46 };
47
48public:
49 //constructors
50 type(context &ctx, id_t id) : ctx_(ctx), id_(id) { }
51
52 //destructor
53 virtual ~type(){}
54
55 // accessors
56 context &get_context() const { return ctx_; }
57 id_t get_type_id() const { return id_; }
58 // type attributes
59 unsigned get_fp_mantissa_width() const;
60 unsigned get_integer_bitwidth() const;
61 unsigned get_tile_bitwidth() const;
62 unsigned get_primitive_size_in_bits() const;
63 type *get_scalar_ty() const;
64 block_shapes_t get_block_shapes() const;
65 const size_t get_tile_rank() const;
66 const size_t get_tile_ranks1() const;
67 unsigned get_tile_num_elements() const;
68 type *get_tile_element_ty() const;
69 unsigned get_pointer_address_space() const;
70 type *get_pointer_element_ty() const;
71 unsigned get_struct_numel() const { return contained_tys_.size(); }
72 type *get_struct_type(unsigned int i) const { return contained_tys_[i]; }
73
74 // primitive predicates
75 bool is_void_ty() const { return id_ == VoidTyID; }
76 bool is_fp8_ty() const { return id_ == FP8TyID; }
77 bool is_fp16_ty() const { return id_ == FP16TyID; }
78 bool is_bf16_ty() const { return id_ == BF16TyID; }
79 bool is_fp32_ty() const { return id_ == FP32TyID; }
80 bool is_fp64_ty() const { return id_ == FP64TyID; }
81 bool is_label_ty() const { return id_ == LabelTyID;}
82 bool is_metadata_ty() const { return id_ == MetadataTyID; }
83 bool is_token_ty() const { return id_ == TokenTyID; }
84 bool is_integer_ty() const { return id_ == IntegerTyID; }
85 bool is_bool_ty() const { return is_integer_ty(1); }
86 bool is_pointer_ty() const { return id_ == PointerTyID; }
87 bool is_block_ty() const { return id_ == BlockTyID; }
88 bool is_struct_ty() const { return id_ == StructTyID; }
89
90 // Composite predicates
91 bool is_int_or_tileint_ty();
92 bool is_integer_ty(unsigned width) const;
93 bool is_floating_point_ty() const;
94 bool is_sized() const ;
95
96 // Factory methods
97 // primitive types
98 static type *get_void_ty(context &ctx);
99 static type *get_label_ty(context &ctx);
100 // half
101 static type *get_fp8_ty(context &ctx);
102 static type *get_fp16_ty(context &ctx);
103 static type *get_bf16_ty(context &ctx);
104 static type *get_fp32_ty(context &ctx);
105 static type *get_fp64_ty(context &ctx);
106 // integer types
107 static integer_type *get_int1_ty(context &ctx);
108 static integer_type *get_int8_ty(context &ctx);
109 static integer_type *get_int16_ty(context &ctx);
110 static integer_type *get_int32_ty(context &ctx);
111 static integer_type *get_int64_ty(context &ctx);
112 static integer_type *get_int128_ty(context &ctx);
113
114 // repr
115 std::string tile_repr() const {
116 std::string res = get_tile_element_ty()->repr();
117 auto shapes = get_block_shapes();
118 res += "<";
119 for(size_t i = 0; i < shapes.size(); i++){
120 if(i > 0)
121 res += ", ";
122 res += std::to_string(shapes[i]);
123 }
124 res+= ">";
125 return res;
126 }
127
128 std::string repr() const {
129 switch(id_) {
130 case VoidTyID: return "void";
131 case FP8TyID: return "fp8";
132 case BF16TyID: return "bf16";
133 case FP16TyID: return "f16";
134 case FP32TyID: return "f32";
135 case FP64TyID: return "f64";
136 case LabelTyID: return "label";
137 case MetadataTyID: return "md";
138 case TokenTyID: return "tok";
139 case IntegerTyID: return ("i") + std::to_string(get_integer_bitwidth());
140 case FunctionTyID: return "fn";
141 case PointerTyID: return get_pointer_element_ty()->repr() + "*";
142 case StructTyID: return "struct";
143 case BlockTyID: return tile_repr();
144 default: break;
145 }
146 throw std::logic_error("unknown type id '" + std::to_string(id_) + "'");
147 };
148
149private:
150 context &ctx_;
151 id_t id_;
152
153protected:
154 contained_tys_vec_t contained_tys_;
155};
156
157class integer_type: public type {
158 friend class context_impl;
159
160private:
161 // constructors
162 integer_type(context &ctx, unsigned bitwidth)
163 : type(ctx, IntegerTyID), bitwidth_(bitwidth) {}
164
165public:
166 // accessors
167 unsigned get_bitwidth() const { return bitwidth_; }
168
169 // factory methods
170 static integer_type* get(context &ctx, unsigned width);
171
172private:
173 unsigned bitwidth_;
174};
175
176class composite_type: public type{
177protected:
178 using type::type;
179
180public:
181 bool index_valid(value *idx) const;
182 type* get_type_at_index(value *idx) const;
183};
184
185class struct_type: public composite_type {
186public:
187 struct_type(const contained_tys_vec_t& tys, bool is_packed);
188 unsigned get_num_types() const { return contained_tys_.size(); }
189 static struct_type* get(const contained_tys_vec_t& tys, bool is_packed);
190
191private:
192 bool is_packed_;
193};
194
195class block_type: public composite_type {
196private:
197 block_type(type *ty, const block_shapes_t &shapes);
198 static bool is_valid_elt_ty(type *ty);
199
200public:
201 // accessors
202 const block_shapes_t& get_shapes() const { return shapes_; }
203 unsigned get_num_elements() const;
204 unsigned get_bitwidth() const;
205
206 // factory methods
207 static block_type* get(type *ty, const block_shapes_t &shapes);
208 static block_type* get_same_shapes(type *ty, type *ref);
209
210private:
211 block_shapes_t shapes_;
212};
213
214class pointer_type: public type {
215private:
216 pointer_type(type *ty, unsigned address_space);
217 static bool is_valid_elt_ty(type *ty);
218
219public:
220 // accessors
221 unsigned get_address_space() const { return address_space_; }
222 type *get_element_ty() const { return contained_tys_[0]; }
223 // factory methods
224 static pointer_type* get(type *ty, unsigned address_space);
225
226private:
227 unsigned address_space_;
228};
229
230class function_type: public type {
231private:
232 function_type(type *ret_ty, const std::vector<type *> &param_tys);
233
234public:
235 // accessors
236 unsigned get_num_params() const { return contained_tys_.size() - 1; }
237 const_ty_iterator params_begin() const { return contained_tys_.begin() + 1; }
238 const_ty_iterator params_end() const { return contained_tys_.end(); }
239 ty_iterator params_begin() { return contained_tys_.begin() + 1; }
240 ty_iterator params_end() { return contained_tys_.end(); }
241 type* get_param_ty(unsigned i) const { return contained_tys_.at(1 + i); }
242 type* get_return_ty() const { return contained_tys_.at(0); }
243 void reset_ret_ty(type* ty) { contained_tys_[0] = ty;}
244 // factory methods
245 static function_type* get(type *ret_ty, const std::vector<type*>& param_tys);
246};
247
248
249}
250}
251
252#endif
253