1#pragma once
2
#include <utility>
#include <vector>

#include "lib_tiny_ir.h"
#include "taichi/ir/type.h"
5
6namespace taichi::lang {
7namespace spirv {
8
9class STD140LayoutContext : public tinyir::LayoutContext {};
10class STD430LayoutContext : public tinyir::LayoutContext {};
11
12class IntType : public tinyir::Type, public tinyir::MemRefElementTypeInterface {
13 public:
14 IntType(int num_bits, bool is_signed)
15 : num_bits_(num_bits), is_signed_(is_signed) {
16 }
17
18 int num_bits() const {
19 return num_bits_;
20 }
21
22 bool is_signed() const {
23 return is_signed_;
24 }
25
26 size_t memory_size(tinyir::LayoutContext &ctx) const override {
27 return tinyir::ceil_div(num_bits(), 8);
28 }
29
30 size_t memory_alignment_size(tinyir::LayoutContext &ctx) const override {
31 return tinyir::ceil_div(num_bits(), 8);
32 }
33
34 private:
35 bool is_equal(const Polymorphic &other) const override {
36 const IntType &t = (const IntType &)other;
37 return t.num_bits_ == num_bits_ && t.is_signed_ == is_signed_;
38 }
39
40 int num_bits_{0};
41 bool is_signed_{false};
42};
43
44class FloatType : public tinyir::Type,
45 public tinyir::MemRefElementTypeInterface {
46 public:
47 explicit FloatType(int num_bits) : num_bits_(num_bits) {
48 }
49
50 int num_bits() const {
51 return num_bits_;
52 }
53
54 size_t memory_size(tinyir::LayoutContext &ctx) const override {
55 return tinyir::ceil_div(num_bits(), 8);
56 }
57
58 size_t memory_alignment_size(tinyir::LayoutContext &ctx) const override {
59 return tinyir::ceil_div(num_bits(), 8);
60 }
61
62 private:
63 int num_bits_{0};
64
65 bool is_equal(const Polymorphic &other) const override {
66 const FloatType &t = (const FloatType &)other;
67 return t.num_bits_ == num_bits_;
68 }
69};
70
71class PhysicalPointerType : public IntType,
72 public tinyir::PointerTypeInterface {
73 public:
74 explicit PhysicalPointerType(const tinyir::Type *pointed_type)
75 : IntType(/*num_bits=*/64, /*is_signed=*/false),
76 pointed_type_(pointed_type) {
77 }
78
79 const tinyir::Type *get_pointed_type() const override {
80 return pointed_type_;
81 }
82
83 private:
84 const tinyir::Type *pointed_type_;
85
86 bool is_equal(const Polymorphic &other) const override {
87 const PhysicalPointerType &pt = (const PhysicalPointerType &)other;
88 return IntType::operator==((const IntType &)other) &&
89 pointed_type_->equals(pt.pointed_type_);
90 }
91};
92
93class StructType : public tinyir::Type,
94 public tinyir::AggregateTypeInterface,
95 public tinyir::MemRefAggregateTypeInterface {
96 public:
97 explicit StructType(std::vector<const tinyir::Type *> &elements)
98 : elements_(elements) {
99 }
100
101 const tinyir::Type *nth_element_type(int n) const override {
102 return elements_[n];
103 }
104
105 int get_num_elements() const override {
106 return elements_.size();
107 }
108
109 size_t memory_size(tinyir::LayoutContext &ctx) const override;
110
111 size_t memory_alignment_size(tinyir::LayoutContext &ctx) const override;
112
113 size_t nth_element_offset(int n, tinyir::LayoutContext &ctx) const override;
114
115 private:
116 std::vector<const tinyir::Type *> elements_;
117
118 bool is_equal(const Polymorphic &other) const override {
119 const StructType &t = (const StructType &)other;
120 if (t.get_num_elements() != get_num_elements()) {
121 return false;
122 }
123 for (int i = 0; i < get_num_elements(); i++) {
124 if (!elements_[i]->equals(t.elements_[i])) {
125 return false;
126 }
127 }
128 return true;
129 }
130};
131
132class SmallVectorType : public tinyir::Type,
133 public tinyir::ShapedTypeInterface,
134 public tinyir::MemRefElementTypeInterface {
135 public:
136 SmallVectorType(const tinyir::Type *element_type, int num_elements);
137
138 const tinyir::Type *element_type() const override {
139 return element_type_;
140 }
141
142 bool is_constant_shape() const override {
143 return true;
144 }
145
146 std::vector<size_t> get_constant_shape() const override {
147 return {size_t(num_elements_)};
148 }
149
150 size_t memory_size(tinyir::LayoutContext &ctx) const override;
151
152 size_t memory_alignment_size(tinyir::LayoutContext &ctx) const override;
153
154 private:
155 bool is_equal(const Polymorphic &other) const override {
156 const SmallVectorType &t = (const SmallVectorType &)other;
157 return num_elements_ == t.num_elements_ &&
158 element_type_->equals(t.element_type_);
159 }
160
161 const tinyir::Type *element_type_{nullptr};
162 int num_elements_{0};
163};
164
165class ArrayType : public tinyir::Type,
166 public tinyir::ShapedTypeInterface,
167 public tinyir::MemRefAggregateTypeInterface {
168 public:
169 ArrayType(const tinyir::Type *element_type, size_t size)
170 : element_type_(element_type), size_(size) {
171 }
172
173 const tinyir::Type *element_type() const override {
174 return element_type_;
175 }
176
177 bool is_constant_shape() const override {
178 return true;
179 }
180
181 std::vector<size_t> get_constant_shape() const override {
182 return {size_};
183 }
184
185 size_t memory_size(tinyir::LayoutContext &ctx) const override;
186
187 size_t memory_alignment_size(tinyir::LayoutContext &ctx) const override;
188
189 size_t nth_element_offset(int n, tinyir::LayoutContext &ctx) const override;
190
191 private:
192 bool is_equal(const Polymorphic &other) const override {
193 const ArrayType &t = (const ArrayType &)other;
194 return size_ == t.size_ && element_type_->equals(t.element_type_);
195 }
196
197 const tinyir::Type *element_type_{nullptr};
198 size_t size_{0};
199};
200
// Declared here, defined out of line.
// NOTE(review): judging by the name, this reports whether a value of type `a`
// can be bitcast to type `b` -- confirm against the definition.
// `_inverted` defaults to false. It looks like an internal flag (for example,
// for retrying with the arguments swapped) rather than something ordinary
// callers should set -- confirm.
bool bitcast_possible(tinyir::Type *a, tinyir::Type *b, bool _inverted = false);
202
203class TypeVisitor : public tinyir::Visitor {
204 public:
205 void visit_type(const tinyir::Type *type) override;
206
207 virtual void visit_int_type(const IntType *type) {
208 }
209
210 virtual void visit_float_type(const FloatType *type) {
211 }
212
213 virtual void visit_physical_pointer_type(const PhysicalPointerType *type) {
214 }
215
216 virtual void visit_struct_type(const StructType *type) {
217 }
218
219 virtual void visit_small_vector_type(const SmallVectorType *type) {
220 }
221
222 virtual void visit_array_type(const ArrayType *type) {
223 }
224};
225
// Declared here, defined out of line. Returns a tinyir type for the Taichi
// DataType `t`. NOTE(review): `ir_module` is a mutable Block, so the returned
// type is presumably created in (and owned by) that block -- confirm.
const tinyir::Type *translate_ti_primitive(tinyir::Block &ir_module,
                                           const DataType t);

// Returns a textual rendering of the types in `block` (presumably for
// debugging/logging -- confirm).
std::string ir_print_types(const tinyir::Block *block);

// Returns a newly allocated Block derived from `blk`. `old2new` is a non-const
// reference, so it is written by this function. Presumably it records, for
// each type in `blk`, the type that replaces it in the returned block (for
// example after merging structurally equal types) -- TODO confirm.
std::unique_ptr<tinyir::Block> ir_reduce_types(
    tinyir::Block *blk,
    std::unordered_map<const tinyir::Type *, const tinyir::Type *> &old2new);

// Forward declaration: only a pointer to IRBuilder is needed below.
class IRBuilder;

// Translates the nodes of `blk` through `spir_builder`, using `layout_ctx` for
// layout decisions. Returns a map from each tinyir node to a uint32_t.
// NOTE(review): the values are presumably the SPIR-V result ids assigned to
// the nodes -- confirm against the definition.
std::unordered_map<const tinyir::Node *, uint32_t> ir_translate_to_spirv(
    const tinyir::Block *blk,
    tinyir::LayoutContext &layout_ctx,
    IRBuilder *spir_builder);
241
242} // namespace spirv
243} // namespace taichi::lang
244