1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
#include "./base_doc_printer.h"

#include <algorithm>
#include <cctype>
#include <limits>
#include <sstream>
#include <string>
#include <utility>
#include <vector>
21
22namespace tvm {
23namespace script {
24namespace printer {
25
26namespace {
27
28std::vector<ByteSpan> MergeAndExemptSpans(const std::vector<ByteSpan>& spans,
29 const std::vector<ByteSpan>& spans_exempted) {
30 // use prefix sum to merge and exempt spans
31 std::vector<ByteSpan> res;
32 std::vector<std::pair<size_t, int>> prefix_stamp;
33 for (ByteSpan span : spans) {
34 prefix_stamp.push_back({span.first, 1});
35 prefix_stamp.push_back({span.second, -1});
36 }
37 // at most spans.size() spans accumulated in prefix sum
38 // use spans.size() + 1 as stamp unit to exempt all positive spans
39 // with only one negative span
40 int max_n = spans.size() + 1;
41 for (ByteSpan span : spans_exempted) {
42 prefix_stamp.push_back({span.first, -max_n});
43 prefix_stamp.push_back({span.second, max_n});
44 }
45 std::sort(prefix_stamp.begin(), prefix_stamp.end());
46 int prefix_sum = 0;
47 int n = prefix_stamp.size();
48 for (int i = 0; i < n - 1; ++i) {
49 prefix_sum += prefix_stamp[i].second;
50 // positive prefix sum leads to spans without exemption
51 // different stamp positions guarantee the stamps in same position accumulated
52 if (prefix_sum > 0 && prefix_stamp[i].first < prefix_stamp[i + 1].first) {
53 if (res.size() && res.back().second == prefix_stamp[i].first) {
54 // merge to the last spans if it is successive
55 res.back().second = prefix_stamp[i + 1].first;
56 } else {
57 // add a new independent span
58 res.push_back({prefix_stamp[i].first, prefix_stamp[i + 1].first});
59 }
60 }
61 }
62 return res;
63}
64
65size_t GetTextWidth(const std::string& text, const ByteSpan& span) {
66 // FIXME: this only works for ASCII characters.
67 // To do this "correctly", we need to parse UTF-8 into codepoints
68 // and call wcwidth() or equivalent for every codepoint.
69 size_t ret = 0;
70 for (size_t i = span.first; i != span.second; ++i) {
71 if (isprint(text[i])) {
72 ret += 1;
73 }
74 }
75 return ret;
76}
77
/*! \brief Return pos - distance, saturating at 0 instead of wrapping around. */
size_t MoveBack(size_t pos, size_t distance) {
  if (pos < distance) {
    return 0;
  }
  return pos - distance;
}
79
/*!
 * \brief Return pos + distance, saturating at `max`.
 *
 * The comparison is made against the remaining room (max - pos) so that pos + distance is
 * never computed when it could overflow.
 */
size_t MoveForward(size_t pos, size_t distance, size_t max) {
  const size_t room = max - pos;
  if (distance <= room) {
    return pos + distance;
  }
  return max;
}
83
/*!
 * \brief Return the 0-based index of the line containing byte offset `byte_pos`.
 *
 * That is the last line whose start is <= byte_pos, i.e. the element just before the first
 * start strictly greater than byte_pos. `line_starts` must be sorted ascending.
 */
size_t GetLineIndex(size_t byte_pos, const std::vector<size_t>& line_starts) {
  const auto first_after = std::upper_bound(line_starts.cbegin(), line_starts.cend(), byte_pos);
  const auto starts_not_after = first_after - line_starts.cbegin();
  return starts_not_after - 1;
}
88
89using UnderlineIter = typename std::vector<ByteSpan>::const_iterator;
90
91ByteSpan PopNextUnderline(UnderlineIter* next_underline, UnderlineIter end_underline) {
92 if (*next_underline == end_underline) {
93 return {std::numeric_limits<size_t>::max(), std::numeric_limits<size_t>::max()};
94 } else {
95 return *(*next_underline)++;
96 }
97}
98
99void PrintChunk(const std::pair<size_t, size_t>& lines_range,
100 const std::pair<UnderlineIter, UnderlineIter>& underlines, const std::string& text,
101 const std::vector<size_t>& line_starts, const PrinterConfig& options,
102 size_t line_number_width, std::string* out) {
103 UnderlineIter next_underline = underlines.first;
104 ByteSpan current_underline = PopNextUnderline(&next_underline, underlines.second);
105
106 for (size_t line_idx = lines_range.first; line_idx < lines_range.second; ++line_idx) {
107 if (options->print_line_numbers) {
108 std::string line_num_str = std::to_string(line_idx + 1);
109 line_num_str.push_back(' ');
110 for (size_t i = line_num_str.size(); i < line_number_width; ++i) {
111 out->push_back(' ');
112 }
113 *out += line_num_str;
114 }
115
116 size_t line_start = line_starts.at(line_idx);
117 size_t line_end =
118 line_idx + 1 == line_starts.size() ? text.size() : line_starts.at(line_idx + 1);
119 out->append(text.begin() + line_start, text.begin() + line_end);
120
121 bool printed_underline = false;
122 size_t line_pos = line_start;
123 bool printed_extra_caret = 0;
124 while (current_underline.first < line_end) {
125 if (!printed_underline) {
126 *out += std::string(line_number_width, ' ');
127 printed_underline = true;
128 }
129
130 size_t underline_end_for_line = std::min(line_end, current_underline.second);
131 size_t num_spaces = GetTextWidth(text, {line_pos, current_underline.first});
132 if (num_spaces > 0 && printed_extra_caret) {
133 num_spaces -= 1;
134 printed_extra_caret = false;
135 }
136 *out += std::string(num_spaces, ' ');
137
138 size_t num_carets = GetTextWidth(text, {current_underline.first, underline_end_for_line});
139 if (num_carets == 0 && !printed_extra_caret) {
140 // Special case: when underlineing an empty or unprintable string, make sure to print
141 // at least one caret still.
142 num_carets = 1;
143 printed_extra_caret = true;
144 } else if (num_carets > 0 && printed_extra_caret) {
145 num_carets -= 1;
146 printed_extra_caret = false;
147 }
148 *out += std::string(num_carets, '^');
149
150 line_pos = current_underline.first = underline_end_for_line;
151 if (current_underline.first == current_underline.second) {
152 current_underline = PopNextUnderline(&next_underline, underlines.second);
153 }
154 }
155
156 if (printed_underline) {
157 out->push_back('\n');
158 }
159 }
160}
161
/*!
 * \brief Append a "(... N lines skipped ...)" marker to `out`.
 *
 * Nothing is appended when no lines were skipped.
 */
void PrintCut(size_t num_lines_skipped, std::string* out) {
  if (num_lines_skipped == 0) {
    return;
  }
  std::ostringstream marker;
  marker << "(... " << num_lines_skipped << " lines skipped ...)\n";
  out->append(marker.str());
}
169
170std::pair<size_t, size_t> GetLinesForUnderline(const ByteSpan& underline,
171 const std::vector<size_t>& line_starts,
172 size_t num_lines, const PrinterConfig& options) {
173 size_t first_line_of_underline = GetLineIndex(underline.first, line_starts);
174 size_t first_line_of_chunk = MoveBack(first_line_of_underline, options->num_context_lines);
175 size_t end_line_of_underline = GetLineIndex(underline.second - 1, line_starts) + 1;
176 size_t end_line_of_chunk =
177 MoveForward(end_line_of_underline, options->num_context_lines, num_lines);
178
179 return {first_line_of_chunk, end_line_of_chunk};
180}
181
// If there is only one line between the chunks, it is better to print it as is,
// rather than something like "(... 1 line skipped ...)".
constexpr const size_t kMinLinesToCutOut = 2;

/*!
 * \brief Try to extend `cur_chunk` (a half-open line range) with `new_chunk`.
 *
 * The chunks are merged when the gap between them is shorter than kMinLinesToCutOut lines;
 * overlapping or adjacent chunks always merge. The merged chunk never shrinks: its end is the
 * later of the two ends, so a `new_chunk` nested inside `cur_chunk` cannot truncate it.
 * Callers are expected to pass chunks in ascending order of their start line.
 *
 * \return true if the chunks were merged (cur_chunk updated); false if they stay separate
 *         (cur_chunk untouched).
 */
bool TryMergeChunks(std::pair<size_t, size_t>* cur_chunk,
                    const std::pair<size_t, size_t>& new_chunk) {
  if (new_chunk.first >= cur_chunk->second + kMinLinesToCutOut) {
    return false;
  }
  cur_chunk->second = std::max(cur_chunk->second, new_chunk.second);
  return true;
}
195
/*!
 * \brief Count the lines of `text` given the byte offsets at which its lines start.
 *
 * A line start sitting exactly at end-of-text marks an empty trailing line (the text ends with
 * a line break); that line is not counted. `line_starts` must be non-empty.
 */
size_t GetNumLines(const std::string& text, const std::vector<size_t>& line_starts) {
  const bool has_empty_final_line = line_starts.back() == text.size();
  return has_empty_final_line ? line_starts.size() - 1 : line_starts.size();
}
204
205size_t GetLineNumberWidth(size_t num_lines, const PrinterConfig& options) {
206 if (options->print_line_numbers) {
207 return std::to_string(num_lines).size() + 1;
208 } else {
209 return 0;
210 }
211}
212
/*!
 * \brief Render `text` for display, optionally with line numbers and '^' underline rows.
 *
 * With no underlines, every line is printed. Otherwise only "chunks" of lines around each
 * underline are printed; each chunk is widened by options->num_context_lines, see
 * GetLinesForUnderline. Chunks separated by fewer than kMinLinesToCutOut lines are merged.
 * Every remaining gap, including one before the first chunk or after the last chunk, is
 * replaced by a "(... N lines skipped ...)" marker.
 *
 * \param text The text to render.
 * \param line_starts Byte offset at which each line of `text` begins.
 * \param options Printer options (line numbers, context lines).
 * \param underlines Byte spans to underline, consumed in order. They are presumably sorted and
 *        non-overlapping; GetString passes the output of MergeAndExemptSpans.
 * \return The rendered text.
 */
std::string DecorateText(const std::string& text, const std::vector<size_t>& line_starts,
                         const PrinterConfig& options, const std::vector<ByteSpan>& underlines) {
  size_t num_lines = GetNumLines(text, line_starts);
  size_t line_number_width = GetLineNumberWidth(num_lines, options);

  std::string ret;
  if (underlines.empty()) {
    // Nothing to highlight: print every line, with no cuts.
    PrintChunk({0, num_lines}, {underlines.begin(), underlines.begin()}, text, line_starts, options,
               line_number_width, &ret);
    return ret;
  }

  // End line (exclusive) of the most recently printed chunk; used to size the next cut.
  size_t last_end_line = 0;
  std::pair<size_t, size_t> cur_chunk =
      GetLinesForUnderline(underlines[0], line_starts, num_lines, options);
  // A leading gap too small to be worth a cut marker is printed instead.
  if (cur_chunk.first < kMinLinesToCutOut) {
    cur_chunk.first = 0;
  }

  // Underlines in [first_underline_in_cur_chunk, underline_it) belong to cur_chunk.
  auto first_underline_in_cur_chunk = underlines.begin();
  for (auto underline_it = underlines.begin() + 1; underline_it != underlines.end();
       ++underline_it) {
    std::pair<size_t, size_t> new_chunk =
        GetLinesForUnderline(*underline_it, line_starts, num_lines, options);

    if (!TryMergeChunks(&cur_chunk, new_chunk)) {
      // The new chunk is far enough away: flush the current chunk (preceded by its cut
      // marker, if any) and start a new chunk.
      PrintCut(cur_chunk.first - last_end_line, &ret);
      PrintChunk(cur_chunk, {first_underline_in_cur_chunk, underline_it}, text, line_starts,
                 options, line_number_width, &ret);
      last_end_line = cur_chunk.second;
      cur_chunk = new_chunk;
      first_underline_in_cur_chunk = underline_it;
    }
  }

  // Flush the final chunk.
  PrintCut(cur_chunk.first - last_end_line, &ret);
  // A trailing gap too small to be worth a cut marker is printed instead.
  if (num_lines - cur_chunk.second < kMinLinesToCutOut) {
    cur_chunk.second = num_lines;
  }
  PrintChunk(cur_chunk, {first_underline_in_cur_chunk, underlines.end()}, text, line_starts,
             options, line_number_width, &ret);
  PrintCut(num_lines - cur_chunk.second, &ret);
  return ret;
}
257
258} // namespace
259
260DocPrinter::DocPrinter(const PrinterConfig& options) : options_(options) {
261 line_starts_.push_back(0);
262}
263
264void DocPrinter::Append(const Doc& doc) { Append(doc, PrinterConfig()); }
265
266void DocPrinter::Append(const Doc& doc, const PrinterConfig& cfg) {
267 for (const ObjectPath& p : cfg->path_to_underline) {
268 path_to_underline_.push_back(p);
269 current_max_path_length_.push_back(0);
270 current_underline_candidates_.push_back(std::vector<ByteSpan>());
271 }
272 PrintDoc(doc);
273 for (const auto& c : current_underline_candidates_) {
274 underlines_.insert(underlines_.end(), c.begin(), c.end());
275 }
276}
277
278String DocPrinter::GetString() const {
279 std::string text = output_.str();
280
281 // Remove any trailing indentation
282 while (!text.empty() && text.back() == ' ') {
283 text.pop_back();
284 }
285
286 if (!text.empty() && text.back() != '\n') {
287 text.push_back('\n');
288 }
289
290 return DecorateText(text, line_starts_, options_,
291 MergeAndExemptSpans(underlines_, underlines_exempted_));
292}
293
/*!
 * \brief Print `doc` by dispatching on its concrete node type, then record its byte span.
 *
 * The span [start_pos, end_pos) is measured with output_.tellp() before and after dispatch,
 * so it covers whatever the PrintTypedDoc overload wrote to output_ (presumably including
 * nested docs). The span is passed to MarkSpan once per entry in doc->source_paths.
 *
 * NOTE(review): the branches test `doc.as<...>()` in order. If any of these node types are
 * related by inheritance, the order decides which overload wins; keep that in mind when
 * reordering or adding branches.
 */
void DocPrinter::PrintDoc(const Doc& doc) {
  size_t start_pos = output_.tellp();

  if (const auto* doc_node = doc.as<LiteralDocNode>()) {
    PrintTypedDoc(GetRef<LiteralDoc>(doc_node));
  } else if (const auto* doc_node = doc.as<IdDocNode>()) {
    PrintTypedDoc(GetRef<IdDoc>(doc_node));
  } else if (const auto* doc_node = doc.as<AttrAccessDocNode>()) {
    PrintTypedDoc(GetRef<AttrAccessDoc>(doc_node));
  } else if (const auto* doc_node = doc.as<IndexDocNode>()) {
    PrintTypedDoc(GetRef<IndexDoc>(doc_node));
  } else if (const auto* doc_node = doc.as<OperationDocNode>()) {
    PrintTypedDoc(GetRef<OperationDoc>(doc_node));
  } else if (const auto* doc_node = doc.as<CallDocNode>()) {
    PrintTypedDoc(GetRef<CallDoc>(doc_node));
  } else if (const auto* doc_node = doc.as<LambdaDocNode>()) {
    PrintTypedDoc(GetRef<LambdaDoc>(doc_node));
  } else if (const auto* doc_node = doc.as<ListDocNode>()) {
    PrintTypedDoc(GetRef<ListDoc>(doc_node));
  } else if (const auto* doc_node = doc.as<TupleDocNode>()) {
    PrintTypedDoc(GetRef<TupleDoc>(doc_node));
  } else if (const auto* doc_node = doc.as<DictDocNode>()) {
    PrintTypedDoc(GetRef<DictDoc>(doc_node));
  } else if (const auto* doc_node = doc.as<SliceDocNode>()) {
    PrintTypedDoc(GetRef<SliceDoc>(doc_node));
  } else if (const auto* doc_node = doc.as<StmtBlockDocNode>()) {
    PrintTypedDoc(GetRef<StmtBlockDoc>(doc_node));
  } else if (const auto* doc_node = doc.as<AssignDocNode>()) {
    PrintTypedDoc(GetRef<AssignDoc>(doc_node));
  } else if (const auto* doc_node = doc.as<IfDocNode>()) {
    PrintTypedDoc(GetRef<IfDoc>(doc_node));
  } else if (const auto* doc_node = doc.as<WhileDocNode>()) {
    PrintTypedDoc(GetRef<WhileDoc>(doc_node));
  } else if (const auto* doc_node = doc.as<ForDocNode>()) {
    PrintTypedDoc(GetRef<ForDoc>(doc_node));
  } else if (const auto* doc_node = doc.as<ScopeDocNode>()) {
    PrintTypedDoc(GetRef<ScopeDoc>(doc_node));
  } else if (const auto* doc_node = doc.as<ExprStmtDocNode>()) {
    PrintTypedDoc(GetRef<ExprStmtDoc>(doc_node));
  } else if (const auto* doc_node = doc.as<AssertDocNode>()) {
    PrintTypedDoc(GetRef<AssertDoc>(doc_node));
  } else if (const auto* doc_node = doc.as<ReturnDocNode>()) {
    PrintTypedDoc(GetRef<ReturnDoc>(doc_node));
  } else if (const auto* doc_node = doc.as<FunctionDocNode>()) {
    PrintTypedDoc(GetRef<FunctionDoc>(doc_node));
  } else if (const auto* doc_node = doc.as<ClassDocNode>()) {
    PrintTypedDoc(GetRef<ClassDoc>(doc_node));
  } else if (const auto* doc_node = doc.as<CommentDocNode>()) {
    PrintTypedDoc(GetRef<CommentDoc>(doc_node));
  } else if (const auto* doc_node = doc.as<DocStringDocNode>()) {
    PrintTypedDoc(GetRef<DocStringDoc>(doc_node));
  } else {
    LOG(FATAL) << "Do not know how to print " << doc->GetTypeKey();
    // Presumably here only to tell the compiler this branch does not fall through;
    // LOG(FATAL) is expected not to return.
    throw;
  }

  size_t end_pos = output_.tellp();
  // Attribute the bytes just written to every source path this doc was generated from.
  for (const ObjectPath& path : doc->source_paths) {
    MarkSpan({start_pos, end_pos}, path);
  }
}
355
356void DocPrinter::MarkSpan(const ByteSpan& span, const ObjectPath& path) {
357 int n = path_to_underline_.size();
358 for (int i = 0; i < n; ++i) {
359 ObjectPath p = path_to_underline_[i];
360 if (path->Length() >= current_max_path_length_[i] && path->IsPrefixOf(p)) {
361 if (path->Length() > current_max_path_length_[i]) {
362 current_max_path_length_[i] = path->Length();
363 current_underline_candidates_[i].clear();
364 }
365 current_underline_candidates_[i].push_back(span);
366 }
367 }
368}
369
370} // namespace printer
371} // namespace script
372} // namespace tvm
373