1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file codegen_params.cc
22 */
23
24#include "codegen_params.h"
25
26#include <dlpack/dlpack.h>
27
28#include <cmath>
29#include <iomanip>
30#include <memory>
31#include <string>
32
33namespace tvm {
34namespace codegen {
35
/*! \brief Maximum length of a generated line of parameter data, including indentation. */
static constexpr int kMaxLineLength = 80;

/*!
 * \brief Choose how many array elements to place on each generated line.
 *
 * When at least one element fits, the result is the largest power of two that is less than or
 * equal to the number of \p one_element_size_bytes -wide elements fitting after
 * \p indent_chars characters of indentation within kMaxLineLength. Power-of-two rows make it
 * easy to locate a given element index when reading the generated code.
 *
 * \param one_element_size_bytes Characters budgeted for one element, including its separator.
 * \param indent_chars Number of indentation characters at the start of each line.
 * \return Elements per line; 1 when a single element is already wider than the usable space.
 */
static int ComputeNumElementsPerRow(int one_element_size_bytes, int indent_chars) {
  const int usable_chars = kMaxLineLength - indent_chars;
  if (one_element_size_bytes > usable_chars) {
    // A lone element already overflows the line; emit one per row.
    return 1;
  }

  int row_width = usable_chars / one_element_size_bytes;
  // Round down to a power of two by dropping the lowest set bit until only the highest remains.
  for (int without_lowest = row_width & (row_width - 1); without_lowest > 0;
       without_lowest = row_width & (row_width - 1)) {
    row_width = without_lowest;
  }
  return row_width;
}
54
55template <typename T, typename Enable = std::enable_if<std::is_integral<T>::value>>
56void PrintIntegralArray(void* data, size_t num_elements, int indent_chars, std::ostream& os,
57 const std::string& eol) {
58 int one_element_size_bytes = (sizeof(T) / 4) + (2 /* "0x" */) + (2 /* ", " */);
59 if (std::is_signed<T>::value) {
60 one_element_size_bytes += 1; // sign character
61 if (sizeof(T) == 64 / 8) {
62 one_element_size_bytes += 2; // "LL"
63 }
64 } else {
65 if (sizeof(T) == 64 / 8) {
66 one_element_size_bytes += 3; // "ULL"
67 }
68 }
69
70 size_t elements_per_row = ComputeNumElementsPerRow(one_element_size_bytes, indent_chars);
71 std::string indent_str(indent_chars, ' ');
72
73 for (size_t i = 0; i < num_elements; i++) {
74 if ((i % elements_per_row) == 0) {
75 os << indent_str;
76 }
77 int64_t elem = static_cast<T*>(data)[i];
78 if (std::is_signed<T>::value) {
79 uint64_t to_print;
80 if (elem < 0) {
81 os << "-";
82 to_print = -elem;
83 } else {
84 os << "+";
85 to_print = elem;
86 }
87 os << "0x" << std::setw(sizeof(T) * 8 / 4) << static_cast<std::uint64_t>(to_print);
88 if (sizeof(T) == 64 / 8) {
89 os << "LL";
90 }
91 } else {
92 os << "0x" << std::setw(sizeof(T) * 8 / 4) << static_cast<std::uint64_t>(elem);
93 if (sizeof(T) == 64 / 8) {
94 os << "ULL";
95 }
96 }
97 if (i < num_elements - 1) {
98 os << ", ";
99 }
100 if ((i % elements_per_row) == elements_per_row - 1) {
101 os << eol;
102 }
103 }
104
105 if ((num_elements % elements_per_row) != 0) {
106 os << eol;
107 }
108}
109
110template <typename T, typename Enable = std::enable_if<std::is_floating_point<T>::value>>
111void PrintFloatingPointArray(void* data, size_t num_elements, int indent_chars, std::ostream& os,
112 const std::string& eol) {
113 // Floats and doubles are printed as hex but casted.
114 int one_element_size_bytes = (sizeof(T) / 4) + (2 /* "0x" */) + (2 /* ", " */) + 1 /* sign */ +
115 1 /* decimal point */ + 1 /* exponent sign */;
116 if (sizeof(T) == 64 / 8) {
117 one_element_size_bytes += 2; /* 4 decimal digits in exponent, relative to bits / 4 */
118 } else if (sizeof(T) == 32 / 8) {
119 one_element_size_bytes += 1; /* extra decimal digit in exponent, relative to bits / 4 */
120 }
121
122 size_t elements_per_row = ComputeNumElementsPerRow(one_element_size_bytes, indent_chars);
123 std::string indent_str(indent_chars, ' ');
124
125 std::stringstream ss;
126 if (std::is_signed<T>::value) {
127 ss.setf(std::ios::hex | std::ios::showbase | std::ios::fixed | std::ios::scientific,
128 std::ios::basefield | std::ios::showbase | std::ios::floatfield);
129 } else {
130 ss.setf(std::ios::hex | std::ios::fixed | std::ios::scientific,
131 std::ios::basefield | std::ios::showbase | std::ios::floatfield);
132 }
133 for (size_t i = 0; i < num_elements; i++) {
134 if ((i % elements_per_row) == 0) {
135 os << indent_str;
136 }
137
138 T elem = static_cast<T*>(data)[i];
139 if (std::isinf(elem)) {
140 // C99 standard.
141 os << (elem < 0 ? "-" : " ") << std::setw(one_element_size_bytes - 1) << "INFINITY";
142 } else if (std::isnan(elem)) {
143 // GNU extension, implemenatation-dependent.
144 os << std::setw(one_element_size_bytes) << "NAN";
145 } else {
146 ss << elem;
147 os << std::setw(one_element_size_bytes) << ss.str();
148 ss.str("");
149 }
150 if (i < num_elements - 1) {
151 os << ", ";
152 }
153 if ((i % elements_per_row) == elements_per_row - 1) {
154 os << eol;
155 }
156 }
157
158 if ((num_elements % elements_per_row) != 0) {
159 os << eol;
160 }
161}
162
163void NDArrayDataToC(::tvm::runtime::NDArray arr, int indent_chars, std::ostream& os,
164 const std::string& eol) {
165 auto arr_type = arr.DataType();
166 CHECK_EQ(arr_type.lanes(), 1) << "CodegenParams: only support generating 1-lane parameters; saw "
167 << arr_type.lanes();
168
169 auto shape = arr.Shape();
170 int num_elements = 1;
171 for (auto shape_elem : shape) {
172 num_elements *= shape_elem;
173 }
174
175 auto old_fmtflags = os.flags();
176 os.setf(std::ios::internal | std::ios::hex,
177 std::ios::adjustfield | std::ios::basefield | std::ios::showbase);
178 os.fill('0');
179 switch (arr_type.code()) {
180 case runtime::DataType::kInt:
181 CHECK(arr_type.bits() == 8 || arr_type.bits() == 16 || arr_type.bits() == 32 ||
182 arr_type.bits() == 64)
183 << "CodegenParams: only support generating 8-, 16-, 32-, or 64-bit integer params; saw "
184 << arr_type.bits() << "-bit array";
185 if (arr_type.bits() == 8) {
186 PrintIntegralArray<int8_t>(arr->data, num_elements, indent_chars, os, eol);
187 } else if (arr_type.bits() == 16) {
188 PrintIntegralArray<int16_t>(arr->data, num_elements, indent_chars, os, eol);
189 } else if (arr_type.bits() == 32) {
190 PrintIntegralArray<int32_t>(arr->data, num_elements, indent_chars, os, eol);
191 } else if (arr_type.bits() == 64) {
192 PrintIntegralArray<int64_t>(arr->data, num_elements, indent_chars, os, eol);
193 } else {
194 CHECK(false) << "should not get here";
195 }
196 break;
197
198 case runtime::DataType::TypeCode::kUInt:
199 CHECK(arr_type.bits() == 8 || arr_type.bits() == 16 || arr_type.bits() == 32 ||
200 arr_type.bits() == 64)
201 << "CodegenParams: only support generating 8-, 16-, 32-, or 64-bit integer params; saw "
202 << arr_type.bits() << "-bit array";
203
204 if (arr_type.bits() == 8) {
205 PrintIntegralArray<uint8_t>(arr->data, num_elements, indent_chars, os, eol);
206 } else if (arr_type.bits() == 16) {
207 PrintIntegralArray<uint16_t>(arr->data, num_elements, indent_chars, os, eol);
208 } else if (arr_type.bits() == 32) {
209 PrintIntegralArray<uint32_t>(arr->data, num_elements, indent_chars, os, eol);
210 } else if (arr_type.bits() == 64) {
211 PrintIntegralArray<uint64_t>(arr->data, num_elements, indent_chars, os, eol);
212 } else {
213 CHECK(false) << "should not get here";
214 }
215 break;
216
217 case runtime::DataType::TypeCode::kFloat: {
218 os.fill(' ');
219 os.setf(std::ios::left, std::ios::adjustfield);
220 if (arr_type.bits() == 16) {
221 // NOTE: print types not widely supported by C as uint16_t.
222 PrintIntegralArray<uint16_t>(arr->data, num_elements, indent_chars, os, eol);
223 } else if (arr_type.bits() == 32) {
224 PrintFloatingPointArray<float>(arr->data, num_elements, indent_chars, os, eol);
225 } else if (arr_type.bits() == 64) {
226 PrintFloatingPointArray<double>(arr->data, num_elements, indent_chars, os, eol);
227 } else {
228 CHECK(false) << "CodegenParams: only support 32- or 64-bit floating point; saw "
229 << arr_type.bits() << "-bit array";
230 }
231 break;
232 }
233
234 case runtime::DataType::TypeCode::kBFloat: {
235 // NOTE: print types not widely supported by C as uint16_t.
236 CHECK(arr_type.bits() == 16)
237 << "CodegenParams: only support generating 16-bit bfloat params; saw " << arr_type.bits()
238 << "-bit array";
239 PrintIntegralArray<uint16_t>(arr->data, num_elements, indent_chars, os, eol);
240 break;
241 }
242
243 default:
244 CHECK(false) << "Data type '" << arr_type << "' not supported";
245 }
246
247 os.flags(old_fmtflags);
248}
249
250} // namespace codegen
251} // namespace tvm
252