1 | #pragma once |
2 | |
#include "taichi/ir/type.h"
#include "taichi/util/hash.h"

#include <memory>
#include <mutex>
#include <string>
#include <tuple>
#include <unordered_map>
#include <utility>
#include <vector>
7 | |
8 | namespace taichi::lang { |
9 | |
10 | class TypeFactory { |
11 | public: |
12 | static TypeFactory &get_instance(); |
13 | |
14 | // TODO(type): maybe it makes sense to let each get_X function return X* |
15 | // instead of generic Type* |
16 | |
17 | Type *get_primitive_type(PrimitiveTypeID id); |
18 | |
19 | PrimitiveType *get_primitive_int_type(int bits, bool is_signed = true); |
20 | |
21 | PrimitiveType *get_primitive_real_type(int bits); |
22 | |
23 | Type *get_tensor_type(std::vector<int> shape, Type *element); |
24 | |
25 | const Type *get_struct_type(const std::vector<StructMember> &elements, |
26 | const std::string &layout = "none" ); |
27 | |
28 | Type *get_pointer_type(Type *element, bool is_bit_pointer = false); |
29 | |
30 | Type *get_quant_int_type(int num_bits, bool is_signed, Type *compute_type); |
31 | |
32 | Type *get_quant_fixed_type(Type *digits_type, |
33 | Type *compute_type, |
34 | float64 scale); |
35 | |
36 | Type *get_quant_float_type(Type *digits_type, |
37 | Type *exponent_type, |
38 | Type *compute_type); |
39 | |
40 | BitStructType *get_bit_struct_type( |
41 | PrimitiveType *physical_type, |
42 | const std::vector<Type *> &member_types, |
43 | const std::vector<int> &member_bit_offsets, |
44 | const std::vector<int> &member_exponents, |
45 | const std::vector<std::vector<int>> &member_exponent_users); |
46 | |
47 | Type *get_quant_array_type(PrimitiveType *physical_type, |
48 | Type *element_type, |
49 | int num_elements); |
50 | |
51 | static DataType create_tensor_type(std::vector<int> shape, DataType element); |
52 | |
53 | private: |
54 | TypeFactory(); |
55 | |
56 | std::unordered_map<PrimitiveTypeID, std::unique_ptr<Type>> primitive_types_; |
57 | std::mutex primitive_mut_; |
58 | |
59 | std::unordered_map<std::pair<std::string, Type *>, |
60 | std::unique_ptr<Type>, |
61 | hashing::Hasher<std::pair<std::string, Type *>>> |
62 | tensor_types_; |
63 | std::mutex tensor_mut_; |
64 | |
65 | std::unordered_map< |
66 | std::pair<std::vector<StructMember>, std::string>, |
67 | std::unique_ptr<Type>, |
68 | hashing::Hasher<std::pair<std::vector<StructMember>, std::string>>> |
69 | struct_types_; |
70 | std::mutex struct_mut_; |
71 | |
72 | // TODO: is_bit_ptr? |
73 | std::unordered_map<std::pair<Type *, bool>, |
74 | std::unique_ptr<Type>, |
75 | hashing::Hasher<std::pair<Type *, bool>>> |
76 | pointer_types_; |
77 | std::mutex pointer_mut_; |
78 | |
79 | std::unordered_map<std::tuple<int, bool, Type *>, |
80 | std::unique_ptr<Type>, |
81 | hashing::Hasher<std::tuple<int, bool, Type *>>> |
82 | quant_int_types_; |
83 | std::mutex quant_int_mut_; |
84 | |
85 | std::unordered_map<std::tuple<Type *, Type *, float64>, |
86 | std::unique_ptr<Type>, |
87 | hashing::Hasher<std::tuple<Type *, Type *, float64>>> |
88 | quant_fixed_types_; |
89 | std::mutex quant_fixed_mut_; |
90 | |
91 | std::unordered_map<std::tuple<Type *, Type *, Type *>, |
92 | std::unique_ptr<Type>, |
93 | hashing::Hasher<std::tuple<Type *, Type *, Type *>>> |
94 | quant_float_types_; |
95 | std::mutex quant_float_mut_; |
96 | |
97 | // TODO: avoid duplication |
98 | std::vector<std::unique_ptr<BitStructType>> bit_struct_types_; |
99 | std::mutex bit_struct_mut_; |
100 | |
101 | // TODO: avoid duplication |
102 | std::vector<std::unique_ptr<Type>> quant_array_types_; |
103 | std::mutex quant_array_mut_; |
104 | }; |
105 | |
106 | DataType promoted_type(DataType a, DataType b); |
107 | |
108 | } // namespace taichi::lang |
109 | |