1 | #include "taichi/ir/type_utils.h" |
2 | |
3 | namespace taichi::lang { |
4 | |
5 | std::string data_type_name(DataType t) { |
6 | if (!t->is<PrimitiveType>()) { |
7 | return t->to_string(); |
8 | } |
9 | |
10 | // Handle primitive types below. |
11 | |
12 | if (false) { |
13 | } |
14 | #define PER_TYPE(i) else if (t->is_primitive(PrimitiveTypeID::i)) return #i; |
15 | #include "taichi/inc/data_type.inc.h" |
16 | #undef PER_TYPE |
17 | else |
18 | TI_NOT_IMPLEMENTED |
19 | } |
20 | |
21 | std::vector<int> data_type_shape(DataType t) { |
22 | if (t->is<TensorType>()) { |
23 | auto tensor_type = t->cast<TensorType>(); |
24 | return tensor_type->get_shape(); |
25 | } |
26 | |
27 | return {}; |
28 | } |
29 | |
30 | int data_type_size(DataType t) { |
31 | // TODO: |
32 | // 1. Ensure in the old code, pointer attributes of t are correct (by |
33 | // setting a loud failure on pointers); |
34 | // 2. Support pointer types here. |
35 | t.set_is_pointer(false); |
36 | if (false) { |
37 | } else if (t->is_primitive(PrimitiveTypeID::f16)) |
38 | return 2; |
39 | else if (t->is_primitive(PrimitiveTypeID::gen)) |
40 | return 0; |
41 | else if (t->is_primitive(PrimitiveTypeID::unknown)) |
42 | return -1; |
43 | |
44 | if (t->is<TensorType>()) { |
45 | auto tensor_type = t->cast<TensorType>(); |
46 | TI_ASSERT(tensor_type->get_element_type()); |
47 | return tensor_type->get_num_elements() * |
48 | data_type_size(tensor_type->get_element_type()); |
49 | } |
50 | |
51 | #define REGISTER_DATA_TYPE(i, j) \ |
52 | else if (t->is_primitive(PrimitiveTypeID::i)) return sizeof(j) |
53 | |
54 | REGISTER_DATA_TYPE(f32, float32); |
55 | REGISTER_DATA_TYPE(f64, float64); |
56 | REGISTER_DATA_TYPE(i8, int8); |
57 | REGISTER_DATA_TYPE(i16, int16); |
58 | REGISTER_DATA_TYPE(i32, int32); |
59 | REGISTER_DATA_TYPE(i64, int64); |
60 | REGISTER_DATA_TYPE(u8, uint8); |
61 | REGISTER_DATA_TYPE(u16, uint16); |
62 | REGISTER_DATA_TYPE(u32, uint32); |
63 | REGISTER_DATA_TYPE(u64, uint64); |
64 | |
65 | #undef REGISTER_DATA_TYPE |
66 | else { |
67 | TI_NOT_IMPLEMENTED |
68 | } |
69 | } |
70 | |
// Builds a printf-style format string for a (possibly nested) tensor.
//
// Emits one bracketed, ", "-separated group per axis starting at `dim`, with
// `format_str` (the single-element specifier) at the innermost axis. On the
// outermost axis (dim == 0), when it is not also the innermost one, each
// separator is followed by a newline so rows of a matrix land on separate
// lines, e.g. shape {2, 3} with "%d" gives
//   "[[%d, %d, %d], \n[%d, %d, %d]]".
//
// @param shape      Extent of each axis, outermost first.
// @param format_str printf specifier for one element, e.g. "%d".
// @param dim        Axis to format; top-level callers pass 0.
// @return The format string for the sub-tensor rooted at axis `dim`. If `dim`
//         is not a valid axis (e.g. an empty shape), `format_str` itself is
//         returned instead of indexing out of bounds.
std::string tensor_type_format_helper(const std::vector<int> &shape,
                                      std::string format_str,
                                      int dim) {
  // Previously an empty shape read shape[0] out of bounds (UB); treat any
  // out-of-range axis as "no more nesting".
  if (dim < 0 || static_cast<std::size_t>(dim) >= shape.size()) {
    return format_str;
  }

  // Compare in size_t to avoid the signed/unsigned mismatch (and the wrap of
  // shape.size() - 1) of the previous formulation.
  const bool is_last_dim = static_cast<std::size_t>(dim) + 1 == shape.size();

  // Every slot along this axis has the same sub-format, so build it once
  // instead of recursing on every iteration.
  const std::string item =
      is_last_dim ? format_str
                  : tensor_type_format_helper(shape, format_str, dim + 1);
  const std::string separator = (dim == 0 && !is_last_dim) ? ", \n" : ", ";

  std::string fmt = "[";
  for (int i = 0; i < shape[dim]; ++i) {
    if (i != 0) {
      fmt += separator;
    }
    fmt += item;
  }
  fmt += "]";
  return fmt;
}
91 | |
92 | std::string tensor_type_format(DataType t, Arch arch) { |
93 | TI_ASSERT(t->is<TensorType>()); |
94 | auto tensor_type = t->as<TensorType>(); |
95 | auto shape = tensor_type->get_shape(); |
96 | auto element_type = tensor_type->get_element_type(); |
97 | auto element_type_format = data_type_format(element_type, arch); |
98 | return tensor_type_format_helper(shape, element_type_format, 0); |
99 | } |
100 | |
101 | std::string data_type_format(DataType dt, Arch arch) { |
102 | if (dt->is_primitive(PrimitiveTypeID::i8)) { |
103 | // i8/u8 is converted to i16/u16 before printing, because CUDA doesn't |
104 | // support the "%hhd"/"%hhu" specifiers. |
105 | return "%hd" ; |
106 | } else if (dt->is_primitive(PrimitiveTypeID::u8)) { |
107 | return "%hu" ; |
108 | } else if (dt->is_primitive(PrimitiveTypeID::i16)) { |
109 | return "%hd" ; |
110 | } else if (dt->is_primitive(PrimitiveTypeID::u16)) { |
111 | return "%hu" ; |
112 | } else if (dt->is_primitive(PrimitiveTypeID::i32)) { |
113 | return "%d" ; |
114 | } else if (dt->is_primitive(PrimitiveTypeID::u32)) { |
115 | return "%u" ; |
116 | } else if (dt->is_primitive(PrimitiveTypeID::i64)) { |
117 | // Use %lld on Windows. |
118 | // Discussion: https://github.com/taichi-dev/taichi/issues/2522 |
119 | // Vulkan does not support printing 64-bit signed integer |
120 | return "%lld" ; |
121 | } else if (dt->is_primitive(PrimitiveTypeID::u64)) { |
122 | // Vulkan requires %lu to print 64-bit unsigned integer |
123 | return arch == Arch::vulkan ? "%lu" : "%llu" ; |
124 | } else if (dt->is_primitive(PrimitiveTypeID::f32)) { |
125 | return "%f" ; |
126 | } else if (dt->is_primitive(PrimitiveTypeID::f64)) { |
127 | return "%.12f" ; |
128 | } else if (dt->is<QuantIntType>()) { |
129 | return "%d" ; |
130 | } else if (dt->is_primitive(PrimitiveTypeID::f16)) { |
131 | // f16 (and f32) is converted to f64 before printing, see |
132 | // TaskCodeGenLLVM::visit(PrintStmt *stmt) and |
133 | // TaskCodeGenCUDA::visit(PrintStmt *stmt) for more details. |
134 | return "%f" ; |
135 | } else if (dt->is<TensorType>()) { |
136 | return tensor_type_format(dt, arch); |
137 | } else { |
138 | TI_NOT_IMPLEMENTED |
139 | } |
140 | } |
141 | |
142 | } // namespace taichi::lang |
143 | |