1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file codegen_params.cc |
22 | */ |
23 | |
#include "codegen_params.h"

#include <dlpack/dlpack.h>

#include <cmath>
#include <cstdint>
#include <iomanip>
#include <limits>
#include <memory>
#include <sstream>
#include <string>
#include <type_traits>
32 | |
33 | namespace tvm { |
34 | namespace codegen { |
35 | |
/*! \brief Maximum length of one line of generated parameter data, including its indent. */
static constexpr int kMaxLineLength = 80;

/*!
 * \brief Decide how many array elements to place on each generated line.
 *
 * \param one_element_size_bytes Number of characters one printed element occupies, including
 *        its trailing separator.
 * \param indent_chars Number of indent characters written at the start of each line.
 * \return 1 if a single element does not fit in the space left after the indent. Otherwise, the
 *         largest power of two that is less than or equal to the number of elements that fit.
 *         Power-of-two rows make it easy to find an element's index in the generated code.
 */
static int ComputeNumElementsPerRow(int one_element_size_bytes, int indent_chars) {
  const int available_chars = kMaxLineLength - indent_chars;
  if (one_element_size_bytes > available_chars) {
    return 1;
  }

  int per_row = available_chars / one_element_size_bytes;
  // Round down to a power of two by dropping the lowest set bit until only one bit remains.
  for (int without_lowest = per_row & (per_row - 1); without_lowest > 0;
       without_lowest = per_row & (per_row - 1)) {
    per_row = without_lowest;
  }
  return per_row;
}
54 | |
55 | template <typename T, typename Enable = std::enable_if<std::is_integral<T>::value>> |
56 | void PrintIntegralArray(void* data, size_t num_elements, int indent_chars, std::ostream& os, |
57 | const std::string& eol) { |
58 | int one_element_size_bytes = (sizeof(T) / 4) + (2 /* "0x" */) + (2 /* ", " */); |
59 | if (std::is_signed<T>::value) { |
60 | one_element_size_bytes += 1; // sign character |
61 | if (sizeof(T) == 64 / 8) { |
62 | one_element_size_bytes += 2; // "LL" |
63 | } |
64 | } else { |
65 | if (sizeof(T) == 64 / 8) { |
66 | one_element_size_bytes += 3; // "ULL" |
67 | } |
68 | } |
69 | |
70 | size_t elements_per_row = ComputeNumElementsPerRow(one_element_size_bytes, indent_chars); |
71 | std::string indent_str(indent_chars, ' '); |
72 | |
73 | for (size_t i = 0; i < num_elements; i++) { |
74 | if ((i % elements_per_row) == 0) { |
75 | os << indent_str; |
76 | } |
77 | int64_t elem = static_cast<T*>(data)[i]; |
78 | if (std::is_signed<T>::value) { |
79 | uint64_t to_print; |
80 | if (elem < 0) { |
81 | os << "-" ; |
82 | to_print = -elem; |
83 | } else { |
84 | os << "+" ; |
85 | to_print = elem; |
86 | } |
87 | os << "0x" << std::setw(sizeof(T) * 8 / 4) << static_cast<std::uint64_t>(to_print); |
88 | if (sizeof(T) == 64 / 8) { |
89 | os << "LL" ; |
90 | } |
91 | } else { |
92 | os << "0x" << std::setw(sizeof(T) * 8 / 4) << static_cast<std::uint64_t>(elem); |
93 | if (sizeof(T) == 64 / 8) { |
94 | os << "ULL" ; |
95 | } |
96 | } |
97 | if (i < num_elements - 1) { |
98 | os << ", " ; |
99 | } |
100 | if ((i % elements_per_row) == elements_per_row - 1) { |
101 | os << eol; |
102 | } |
103 | } |
104 | |
105 | if ((num_elements % elements_per_row) != 0) { |
106 | os << eol; |
107 | } |
108 | } |
109 | |
110 | template <typename T, typename Enable = std::enable_if<std::is_floating_point<T>::value>> |
111 | void PrintFloatingPointArray(void* data, size_t num_elements, int indent_chars, std::ostream& os, |
112 | const std::string& eol) { |
113 | // Floats and doubles are printed as hex but casted. |
114 | int one_element_size_bytes = (sizeof(T) / 4) + (2 /* "0x" */) + (2 /* ", " */) + 1 /* sign */ + |
115 | 1 /* decimal point */ + 1 /* exponent sign */; |
116 | if (sizeof(T) == 64 / 8) { |
117 | one_element_size_bytes += 2; /* 4 decimal digits in exponent, relative to bits / 4 */ |
118 | } else if (sizeof(T) == 32 / 8) { |
119 | one_element_size_bytes += 1; /* extra decimal digit in exponent, relative to bits / 4 */ |
120 | } |
121 | |
122 | size_t elements_per_row = ComputeNumElementsPerRow(one_element_size_bytes, indent_chars); |
123 | std::string indent_str(indent_chars, ' '); |
124 | |
125 | std::stringstream ss; |
126 | if (std::is_signed<T>::value) { |
127 | ss.setf(std::ios::hex | std::ios::showbase | std::ios::fixed | std::ios::scientific, |
128 | std::ios::basefield | std::ios::showbase | std::ios::floatfield); |
129 | } else { |
130 | ss.setf(std::ios::hex | std::ios::fixed | std::ios::scientific, |
131 | std::ios::basefield | std::ios::showbase | std::ios::floatfield); |
132 | } |
133 | for (size_t i = 0; i < num_elements; i++) { |
134 | if ((i % elements_per_row) == 0) { |
135 | os << indent_str; |
136 | } |
137 | |
138 | T elem = static_cast<T*>(data)[i]; |
139 | if (std::isinf(elem)) { |
140 | // C99 standard. |
141 | os << (elem < 0 ? "-" : " " ) << std::setw(one_element_size_bytes - 1) << "INFINITY" ; |
142 | } else if (std::isnan(elem)) { |
143 | // GNU extension, implemenatation-dependent. |
144 | os << std::setw(one_element_size_bytes) << "NAN" ; |
145 | } else { |
146 | ss << elem; |
147 | os << std::setw(one_element_size_bytes) << ss.str(); |
148 | ss.str("" ); |
149 | } |
150 | if (i < num_elements - 1) { |
151 | os << ", " ; |
152 | } |
153 | if ((i % elements_per_row) == elements_per_row - 1) { |
154 | os << eol; |
155 | } |
156 | } |
157 | |
158 | if ((num_elements % elements_per_row) != 0) { |
159 | os << eol; |
160 | } |
161 | } |
162 | |
163 | void NDArrayDataToC(::tvm::runtime::NDArray arr, int indent_chars, std::ostream& os, |
164 | const std::string& eol) { |
165 | auto arr_type = arr.DataType(); |
166 | CHECK_EQ(arr_type.lanes(), 1) << "CodegenParams: only support generating 1-lane parameters; saw " |
167 | << arr_type.lanes(); |
168 | |
169 | auto shape = arr.Shape(); |
170 | int num_elements = 1; |
171 | for (auto shape_elem : shape) { |
172 | num_elements *= shape_elem; |
173 | } |
174 | |
175 | auto old_fmtflags = os.flags(); |
176 | os.setf(std::ios::internal | std::ios::hex, |
177 | std::ios::adjustfield | std::ios::basefield | std::ios::showbase); |
178 | os.fill('0'); |
179 | switch (arr_type.code()) { |
180 | case runtime::DataType::kInt: |
181 | CHECK(arr_type.bits() == 8 || arr_type.bits() == 16 || arr_type.bits() == 32 || |
182 | arr_type.bits() == 64) |
183 | << "CodegenParams: only support generating 8-, 16-, 32-, or 64-bit integer params; saw " |
184 | << arr_type.bits() << "-bit array" ; |
185 | if (arr_type.bits() == 8) { |
186 | PrintIntegralArray<int8_t>(arr->data, num_elements, indent_chars, os, eol); |
187 | } else if (arr_type.bits() == 16) { |
188 | PrintIntegralArray<int16_t>(arr->data, num_elements, indent_chars, os, eol); |
189 | } else if (arr_type.bits() == 32) { |
190 | PrintIntegralArray<int32_t>(arr->data, num_elements, indent_chars, os, eol); |
191 | } else if (arr_type.bits() == 64) { |
192 | PrintIntegralArray<int64_t>(arr->data, num_elements, indent_chars, os, eol); |
193 | } else { |
194 | CHECK(false) << "should not get here" ; |
195 | } |
196 | break; |
197 | |
198 | case runtime::DataType::TypeCode::kUInt: |
199 | CHECK(arr_type.bits() == 8 || arr_type.bits() == 16 || arr_type.bits() == 32 || |
200 | arr_type.bits() == 64) |
201 | << "CodegenParams: only support generating 8-, 16-, 32-, or 64-bit integer params; saw " |
202 | << arr_type.bits() << "-bit array" ; |
203 | |
204 | if (arr_type.bits() == 8) { |
205 | PrintIntegralArray<uint8_t>(arr->data, num_elements, indent_chars, os, eol); |
206 | } else if (arr_type.bits() == 16) { |
207 | PrintIntegralArray<uint16_t>(arr->data, num_elements, indent_chars, os, eol); |
208 | } else if (arr_type.bits() == 32) { |
209 | PrintIntegralArray<uint32_t>(arr->data, num_elements, indent_chars, os, eol); |
210 | } else if (arr_type.bits() == 64) { |
211 | PrintIntegralArray<uint64_t>(arr->data, num_elements, indent_chars, os, eol); |
212 | } else { |
213 | CHECK(false) << "should not get here" ; |
214 | } |
215 | break; |
216 | |
217 | case runtime::DataType::TypeCode::kFloat: { |
218 | os.fill(' '); |
219 | os.setf(std::ios::left, std::ios::adjustfield); |
220 | if (arr_type.bits() == 16) { |
221 | // NOTE: print types not widely supported by C as uint16_t. |
222 | PrintIntegralArray<uint16_t>(arr->data, num_elements, indent_chars, os, eol); |
223 | } else if (arr_type.bits() == 32) { |
224 | PrintFloatingPointArray<float>(arr->data, num_elements, indent_chars, os, eol); |
225 | } else if (arr_type.bits() == 64) { |
226 | PrintFloatingPointArray<double>(arr->data, num_elements, indent_chars, os, eol); |
227 | } else { |
228 | CHECK(false) << "CodegenParams: only support 32- or 64-bit floating point; saw " |
229 | << arr_type.bits() << "-bit array" ; |
230 | } |
231 | break; |
232 | } |
233 | |
234 | case runtime::DataType::TypeCode::kBFloat: { |
235 | // NOTE: print types not widely supported by C as uint16_t. |
236 | CHECK(arr_type.bits() == 16) |
237 | << "CodegenParams: only support generating 16-bit bfloat params; saw " << arr_type.bits() |
238 | << "-bit array" ; |
239 | PrintIntegralArray<uint16_t>(arr->data, num_elements, indent_chars, os, eol); |
240 | break; |
241 | } |
242 | |
243 | default: |
244 | CHECK(false) << "Data type '" << arr_type << "' not supported" ; |
245 | } |
246 | |
247 | os.flags(old_fmtflags); |
248 | } |
249 | |
250 | } // namespace codegen |
251 | } // namespace tvm |
252 | |