1#include "taichi/ir/type_utils.h"
2
3namespace taichi::lang {
4
5std::string data_type_name(DataType t) {
6 if (!t->is<PrimitiveType>()) {
7 return t->to_string();
8 }
9
10 // Handle primitive types below.
11
12 if (false) {
13 }
14#define PER_TYPE(i) else if (t->is_primitive(PrimitiveTypeID::i)) return #i;
15#include "taichi/inc/data_type.inc.h"
16#undef PER_TYPE
17 else
18 TI_NOT_IMPLEMENTED
19}
20
21std::vector<int> data_type_shape(DataType t) {
22 if (t->is<TensorType>()) {
23 auto tensor_type = t->cast<TensorType>();
24 return tensor_type->get_shape();
25 }
26
27 return {};
28}
29
30int data_type_size(DataType t) {
31 // TODO:
32 // 1. Ensure in the old code, pointer attributes of t are correct (by
33 // setting a loud failure on pointers);
34 // 2. Support pointer types here.
35 t.set_is_pointer(false);
36 if (false) {
37 } else if (t->is_primitive(PrimitiveTypeID::f16))
38 return 2;
39 else if (t->is_primitive(PrimitiveTypeID::gen))
40 return 0;
41 else if (t->is_primitive(PrimitiveTypeID::unknown))
42 return -1;
43
44 if (t->is<TensorType>()) {
45 auto tensor_type = t->cast<TensorType>();
46 TI_ASSERT(tensor_type->get_element_type());
47 return tensor_type->get_num_elements() *
48 data_type_size(tensor_type->get_element_type());
49 }
50
51#define REGISTER_DATA_TYPE(i, j) \
52 else if (t->is_primitive(PrimitiveTypeID::i)) return sizeof(j)
53
54 REGISTER_DATA_TYPE(f32, float32);
55 REGISTER_DATA_TYPE(f64, float64);
56 REGISTER_DATA_TYPE(i8, int8);
57 REGISTER_DATA_TYPE(i16, int16);
58 REGISTER_DATA_TYPE(i32, int32);
59 REGISTER_DATA_TYPE(i64, int64);
60 REGISTER_DATA_TYPE(u8, uint8);
61 REGISTER_DATA_TYPE(u16, uint16);
62 REGISTER_DATA_TYPE(u32, uint32);
63 REGISTER_DATA_TYPE(u64, uint64);
64
65#undef REGISTER_DATA_TYPE
66 else {
67 TI_NOT_IMPLEMENTED
68 }
69}
70
// Builds a printf-style format string for a (possibly nested) tensor.
//
// Starting at dimension `dim` of `shape`, emits one bracketed group per
// dimension with `format_str` (the element's conversion specifier) at every
// leaf position, separated by ", ". When the tensor has more than one
// dimension, rows of the outermost dimension (dim == 0) are additionally
// separated by a newline. For example, shape {2, 3} with "%d" yields
//   "[[%d, %d, %d], \n[%d, %d, %d]]".
//
// @param shape      Extent of each tensor dimension.
// @param format_str Conversion specifier used for every element.
// @param dim        Dimension to start formatting from (callers pass 0).
// @return The format string. A dimension with extent <= 0 yields "[]". If
//         `dim` is out of range (e.g. an empty shape, i.e. a 0-d tensor),
//         `format_str` itself is returned instead of indexing past the end.
std::string tensor_type_format_helper(const std::vector<int> &shape,
                                      std::string format_str,
                                      int dim) {
  // Use a signed dimension count so comparisons with `dim` don't mix
  // signed/unsigned (and don't wrap for an empty shape).
  const int num_dims = static_cast<int>(shape.size());
  // Guard against shape[dim] being out of bounds (previously UB for an empty
  // shape); format as a single element instead.
  if (dim < 0 || dim >= num_dims) {
    return format_str;
  }
  const int extent = shape[dim];
  if (extent <= 0) {
    return "[]";
  }

  const bool is_last_dim = (dim == num_dims - 1);
  // Every entry at this level is identical, so build it only once.
  const std::string entry =
      is_last_dim ? format_str
                  : tensor_type_format_helper(shape, format_str, dim + 1);
  // Only the outermost dimension of a multi-dimensional tensor breaks lines.
  const std::string separator = (dim == 0 && !is_last_dim) ? ", \n" : ", ";

  std::string fmt = "[";
  for (int i = 0; i < extent; ++i) {
    if (i > 0) {
      fmt += separator;
    }
    fmt += entry;
  }
  fmt += "]";
  return fmt;
}
91
92std::string tensor_type_format(DataType t, Arch arch) {
93 TI_ASSERT(t->is<TensorType>());
94 auto tensor_type = t->as<TensorType>();
95 auto shape = tensor_type->get_shape();
96 auto element_type = tensor_type->get_element_type();
97 auto element_type_format = data_type_format(element_type, arch);
98 return tensor_type_format_helper(shape, element_type_format, 0);
99}
100
101std::string data_type_format(DataType dt, Arch arch) {
102 if (dt->is_primitive(PrimitiveTypeID::i8)) {
103 // i8/u8 is converted to i16/u16 before printing, because CUDA doesn't
104 // support the "%hhd"/"%hhu" specifiers.
105 return "%hd";
106 } else if (dt->is_primitive(PrimitiveTypeID::u8)) {
107 return "%hu";
108 } else if (dt->is_primitive(PrimitiveTypeID::i16)) {
109 return "%hd";
110 } else if (dt->is_primitive(PrimitiveTypeID::u16)) {
111 return "%hu";
112 } else if (dt->is_primitive(PrimitiveTypeID::i32)) {
113 return "%d";
114 } else if (dt->is_primitive(PrimitiveTypeID::u32)) {
115 return "%u";
116 } else if (dt->is_primitive(PrimitiveTypeID::i64)) {
117 // Use %lld on Windows.
118 // Discussion: https://github.com/taichi-dev/taichi/issues/2522
119 // Vulkan does not support printing 64-bit signed integer
120 return "%lld";
121 } else if (dt->is_primitive(PrimitiveTypeID::u64)) {
122 // Vulkan requires %lu to print 64-bit unsigned integer
123 return arch == Arch::vulkan ? "%lu" : "%llu";
124 } else if (dt->is_primitive(PrimitiveTypeID::f32)) {
125 return "%f";
126 } else if (dt->is_primitive(PrimitiveTypeID::f64)) {
127 return "%.12f";
128 } else if (dt->is<QuantIntType>()) {
129 return "%d";
130 } else if (dt->is_primitive(PrimitiveTypeID::f16)) {
131 // f16 (and f32) is converted to f64 before printing, see
132 // TaskCodeGenLLVM::visit(PrintStmt *stmt) and
133 // TaskCodeGenCUDA::visit(PrintStmt *stmt) for more details.
134 return "%f";
135 } else if (dt->is<TensorType>()) {
136 return tensor_type_format(dt, arch);
137 } else {
138 TI_NOT_IMPLEMENTED
139 }
140}
141
142} // namespace taichi::lang
143