#pragma once

#include "taichi/ir/type.h"
#include "taichi/util/hash.h"

#include <memory>
#include <mutex>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>

8namespace taichi::lang {
9
10class TypeFactory {
11 public:
12 static TypeFactory &get_instance();
13
14 // TODO(type): maybe it makes sense to let each get_X function return X*
15 // instead of generic Type*
16
17 Type *get_primitive_type(PrimitiveTypeID id);
18
19 PrimitiveType *get_primitive_int_type(int bits, bool is_signed = true);
20
21 PrimitiveType *get_primitive_real_type(int bits);
22
23 Type *get_tensor_type(std::vector<int> shape, Type *element);
24
25 const Type *get_struct_type(const std::vector<StructMember> &elements,
26 const std::string &layout = "none");
27
28 Type *get_pointer_type(Type *element, bool is_bit_pointer = false);
29
30 Type *get_quant_int_type(int num_bits, bool is_signed, Type *compute_type);
31
32 Type *get_quant_fixed_type(Type *digits_type,
33 Type *compute_type,
34 float64 scale);
35
36 Type *get_quant_float_type(Type *digits_type,
37 Type *exponent_type,
38 Type *compute_type);
39
40 BitStructType *get_bit_struct_type(
41 PrimitiveType *physical_type,
42 const std::vector<Type *> &member_types,
43 const std::vector<int> &member_bit_offsets,
44 const std::vector<int> &member_exponents,
45 const std::vector<std::vector<int>> &member_exponent_users);
46
47 Type *get_quant_array_type(PrimitiveType *physical_type,
48 Type *element_type,
49 int num_elements);
50
51 static DataType create_tensor_type(std::vector<int> shape, DataType element);
52
53 private:
54 TypeFactory();
55
56 std::unordered_map<PrimitiveTypeID, std::unique_ptr<Type>> primitive_types_;
57 std::mutex primitive_mut_;
58
59 std::unordered_map<std::pair<std::string, Type *>,
60 std::unique_ptr<Type>,
61 hashing::Hasher<std::pair<std::string, Type *>>>
62 tensor_types_;
63 std::mutex tensor_mut_;
64
65 std::unordered_map<
66 std::pair<std::vector<StructMember>, std::string>,
67 std::unique_ptr<Type>,
68 hashing::Hasher<std::pair<std::vector<StructMember>, std::string>>>
69 struct_types_;
70 std::mutex struct_mut_;
71
72 // TODO: is_bit_ptr?
73 std::unordered_map<std::pair<Type *, bool>,
74 std::unique_ptr<Type>,
75 hashing::Hasher<std::pair<Type *, bool>>>
76 pointer_types_;
77 std::mutex pointer_mut_;
78
79 std::unordered_map<std::tuple<int, bool, Type *>,
80 std::unique_ptr<Type>,
81 hashing::Hasher<std::tuple<int, bool, Type *>>>
82 quant_int_types_;
83 std::mutex quant_int_mut_;
84
85 std::unordered_map<std::tuple<Type *, Type *, float64>,
86 std::unique_ptr<Type>,
87 hashing::Hasher<std::tuple<Type *, Type *, float64>>>
88 quant_fixed_types_;
89 std::mutex quant_fixed_mut_;
90
91 std::unordered_map<std::tuple<Type *, Type *, Type *>,
92 std::unique_ptr<Type>,
93 hashing::Hasher<std::tuple<Type *, Type *, Type *>>>
94 quant_float_types_;
95 std::mutex quant_float_mut_;
96
97 // TODO: avoid duplication
98 std::vector<std::unique_ptr<BitStructType>> bit_struct_types_;
99 std::mutex bit_struct_mut_;
100
101 // TODO: avoid duplication
102 std::vector<std::unique_ptr<Type>> quant_array_types_;
103 std::mutex quant_array_mut_;
104};
105
106DataType promoted_type(DataType a, DataType b);
107
108} // namespace taichi::lang
109