1#pragma once
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <iterator>
#include <memory>
#include <numeric>
#include <sstream>
#include <string>
#include <tuple>
#include <unordered_map>
#include <vector>
11
12namespace torch {
13namespace jit {
14
// Forward declarations: SourceRangeUnpickler is only held through a
// shared_ptr by Source, and SourceRange is referenced by Source before
// SourceRange itself is defined further down in this header.
class SourceRangeUnpickler;
struct SourceRange;
17
18// A stringlike class backed by a vector of string_view
19// the string represented are logically the concatenation of the string_views
20// This has advantage of not needing continues memory.
21struct TORCH_API StringCordView {
22 StringCordView();
23 StringCordView(const StringCordView&) = default;
24 StringCordView(
25 std::vector<c10::string_view> inputs,
26 std::vector<std::shared_ptr<std::string>> ownerships);
27
28 StringCordView& operator=(const StringCordView&) = default;
29
30 size_t size() const {
31 return accumulated_sizes_.back();
32 }
33
34 size_t find(const std::string& tok, size_t start) const;
35 StringCordView substr(size_t start, size_t size) const;
36
37 char at(size_t index) const {
38 return *iter_for_pos(index);
39 }
40 char operator[](size_t index) const {
41 return at(index);
42 }
43
44 std::string str() const {
45 std::stringstream ss;
46 for (auto s : pieces_) {
47 ss << std::string(s);
48 }
49 return ss.str();
50 }
51
52 bool operator==(const std::string& rhs) const;
53
54 bool operator==(const StringCordView& rhs) const;
55
56 c10::string_view piece(size_t index) const {
57 return pieces_[index];
58 }
59
60 struct Iterator {
61 Iterator(
62 const StringCordView* str,
63 size_t start_line,
64 size_t start_pos,
65 size_t size)
66 : line_(start_line), pos_(start_pos), str_(str), size_(size) {}
67 explicit Iterator(const StringCordView* str)
68 : Iterator(str, 0, 0, str->size()) {}
69
70 Iterator() : Iterator(nullptr, 0, 0, 0) {}
71
72 Iterator(const Iterator&) = default;
73 Iterator(Iterator&&) = default;
74 Iterator& operator=(const Iterator&) = default;
75 Iterator& operator=(Iterator&&) = default;
76
77 Iterator operator++() {
78 if (size_ == 0) {
79 return *this;
80 }
81 if ((pos_ + 1) < str_->pieces_[line_].size()) {
82 pos_++;
83 } else {
84 line_++;
85 pos_ = 0;
86 }
87 return *this;
88 }
89
90 Iterator operator++(int) {
91 Iterator prev(*this);
92 ++(*this);
93 return prev;
94 }
95
96 Iterator next_iter() const {
97 Iterator next(*this);
98 ++next;
99 return next;
100 }
101
102 Iterator& operator+=(size_t num) {
103 if (!has_next()) {
104 return *this;
105 }
106 size_t target_pos = pos_ + num;
107 if (target_pos >= str_->accumulated_sizes_[line_] &&
108 (line_ + 1) < str_->accumulated_sizes_.size() &&
109 target_pos < str_->accumulated_sizes_[line_ + 1]) {
110 pos_ = target_pos;
111 return *this;
112 }
113
114 size_t target_abs_pos = pos() + num;
115 *this = str_->iter_for_pos(target_abs_pos);
116 return *this;
117 }
118
119 bool operator==(const Iterator& rhs) const {
120 if (!has_next() && !rhs.has_next()) {
121 return true;
122 }
123 return (str_ == rhs.str_) && (line_ == rhs.line_) && (pos_ == rhs.pos_);
124 }
125 bool operator!=(const Iterator& rhs) {
126 return !((*this) == rhs);
127 }
128 bool has_next() const {
129 return size_ > 0 && (line_ < str_->pieces_.size());
130 }
131
132 char operator*() const {
133 TORCH_INTERNAL_ASSERT(line_ < str_->pieces_.size());
134 TORCH_INTERNAL_ASSERT(pos_ < str_->pieces_[line_].size());
135 return str_->pieces_[line_].at(pos_);
136 }
137
138 // returns rest of the line of the current iterator
139 c10::string_view rest_line() const {
140 if (line_ >= str_->pieces_.size()) {
141 return "";
142 }
143
144 c10::string_view cur_line = str_->pieces_[line_];
145 return cur_line.substr(pos_, std::string::npos);
146 }
147
148 size_t pos() const {
149 if (size_ == 0) {
150 return 0;
151 }
152 return str_->accumulated_sizes_[line_] + pos_;
153 }
154
155 private:
156 size_t line_;
157 size_t pos_;
158 const StringCordView* str_;
159 size_t size_;
160 friend struct StringCordView;
161 };
162
163 Iterator begin() const {
164 return Iterator(this, 0, 0, size());
165 }
166 Iterator end() const {
167 return Iterator(this, pieces_.size(), 0, 0);
168 }
169 Iterator iter_for_pos(size_t pos) const;
170
171 private:
172 std::vector<c10::string_view> pieces_;
173 std::vector<size_t> accumulated_sizes_;
174 std::vector<std::shared_ptr<std::string>> owned_strings_;
175};
176
177// Source represents a code segment. It keeps track of:
178// - text_view : the view into text of the code segment
179// - filename (optional) : if present, represents the name of the file from
180// which the code segment originated.
181// - starting_line_no : represents the line in the original file where the
182// code segment started.
183struct TORCH_API Source {
184 // Whether or not Source should copy the string passed in the constructor.
185 enum CopiesString { COPIES_STRING, DONT_COPY };
186
187 explicit Source(
188 c10::string_view text_view,
189 c10::optional<std::string> filename = c10::nullopt,
190 size_t starting_line_no = 0,
191 std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr,
192 CopiesString copies_str = COPIES_STRING)
193 : filename_(std::move(filename)),
194 starting_line_no_(starting_line_no),
195 gen_ranges_(std::move(gen_ranges)) {
196 if (copies_str == COPIES_STRING) {
197 std::shared_ptr<std::string> allocated_str =
198 std::make_shared<std::string>(text_view.data(), text_view.size());
199 text_view_ = StringCordView({*allocated_str}, {allocated_str});
200 } else {
201 text_view_ = StringCordView({text_view}, {});
202 }
203
204 calc_line_start_offsets();
205 }
206
207 explicit Source(
208 StringCordView str,
209 c10::optional<std::string> filename = c10::nullopt,
210 size_t starting_line_no = 0,
211 std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr)
212 : text_view_(str),
213 filename_(std::move(filename)),
214 starting_line_no_(starting_line_no),
215 gen_ranges_(std::move(gen_ranges)) {
216 calc_line_start_offsets();
217 }
218 // Given a line number (within source_), return the byte offset of the
219 // beginning of that line.
220 size_t offset_for_line(size_t line) const {
221 return line_starting_offsets_.at(line);
222 }
223
224 // Returns number of lines present.
225 size_t num_lines() const {
226 return line_starting_offsets_.size();
227 }
228
229 // Calculate the line (within the code segment) on which `offset` resides.
230 size_t lineno_for_offset(size_t offset) const {
231 auto iter = std::upper_bound(
232 line_starting_offsets_.begin(), line_starting_offsets_.end(), offset);
233 return iter - line_starting_offsets_.begin() - 1;
234 }
235
236 // Calculate the line (within the original source file, if present) on which
237 // `lineno` resides.
238 size_t lineno_to_source_lineno(size_t lineno) const {
239 if (filename_) {
240 return lineno + starting_line_no_;
241 } else {
242 return lineno;
243 }
244 }
245
246 StringCordView get_line(size_t lineno) const {
247 auto start = offset_for_line(lineno);
248 auto size = (lineno + 1) < num_lines() ? offset_for_line(lineno + 1) - start
249 : text_view_.size() - start;
250 return text_view_.substr(start, size);
251 }
252
253 const StringCordView& text_str() const {
254 return text_view_;
255 }
256
257 char char_at(size_t index) const {
258 return text_view_.at(index);
259 }
260
261 size_t size() const {
262 return text_view_.size();
263 }
264
265 c10::optional<std::string>& filename() {
266 return filename_;
267 }
268
269 size_t starting_line_no() const {
270 return starting_line_no_;
271 }
272
273 c10::optional<SourceRange> findSourceRangeThatGenerated(
274 const SourceRange& range);
275
276 ~Source() = default;
277
278 private:
279 void calc_line_start_offsets() {
280 line_starting_offsets_.clear();
281 line_starting_offsets_.push_back(0);
282 size_t pos = 0;
283 while ((pos = text_view_.find("\n", pos)) != std::string::npos) {
284 line_starting_offsets_.push_back(++pos);
285 }
286 }
287
288 StringCordView text_view_;
289
290 c10::optional<std::string> filename_;
291 // If filename_ is not present, starting_line_no_ is don't care
292 size_t starting_line_no_;
293 // Starting offsets for lines into the source. e.g. line 0 starts at
294 // line_starting_offsets_[0], etc.
295 std::vector<size_t> line_starting_offsets_;
296
297 std::shared_ptr<SourceRangeUnpickler> gen_ranges_;
298};
299
300// A SourceRange is a reference to subset of a Source, specified by `start` and
301// `end` byte offsets into the source text.
302struct TORCH_API SourceRange {
303 SourceRange(std::shared_ptr<Source> source_view, size_t start_, size_t end_)
304 : source_view_(std::move(source_view)), start_(start_), end_(end_) {
305 if (source_view_) {
306 start_iter_ = source_view_->text_str().iter_for_pos(start_);
307 }
308 }
309
310 SourceRange() : source_view_(nullptr), start_(0), end_(0) {}
311
312 SourceRange(
313 std::shared_ptr<Source> source_view_,
314 StringCordView::Iterator start_iter,
315 size_t end_)
316 : source_view_(std::move(source_view_)),
317 start_(start_iter.pos()),
318 end_(end_),
319 start_iter_(start_iter) {}
320
321 const c10::string_view token_text() const {
322 size_t size = end() - start();
323 return start_iter_.rest_line().substr(0, size);
324 }
325
326 const StringCordView text() const {
327 return source_view_->text_str().substr(start(), end() - start());
328 }
329 size_t size() const {
330 return end() - start();
331 }
332 static const size_t CONTEXT = 3;
333 void highlight(std::ostream& out) const;
334
335 // Customizable version of 'highlight' method.
336 void print_with_context(
337 std::ostream& out,
338 size_t context,
339 bool highlight,
340 const std::string& funcname) const;
341
342 const std::shared_ptr<Source>& source() const {
343 return source_view_;
344 }
345 size_t start() const {
346 return start_;
347 }
348 size_t end() const {
349 return end_;
350 }
351 std::string str() const {
352 std::stringstream ss;
353 highlight(ss);
354 return ss.str();
355 }
356
357 c10::optional<std::tuple<std::string, size_t, size_t>> file_line_col() const {
358 if (!source_view_ || !source()->filename()) {
359 return c10::nullopt;
360 }
361
362 auto lineno = source_view_->lineno_for_offset(start_);
363 auto col_offset = (int)start_ - (int)source_view_->offset_for_line(lineno);
364 // TODO: c10::optional<>::value returns an rvalue ref so can't use it here??
365 return std::make_tuple<std::string, size_t, size_t>(
366 source_view_->filename().value_or(""),
367 source_view_->lineno_to_source_lineno(lineno),
368 (size_t)col_offset);
369 }
370
371 bool operator==(const SourceRange& rhs) const {
372 return start() == rhs.start() && end() == rhs.end() &&
373 source() == rhs.source();
374 }
375
376 bool operator!=(const SourceRange& rhs) const {
377 return !(*this == rhs);
378 }
379
380 c10::optional<SourceRange> findSourceRangeThatGenerated() const {
381 if (!source_view_) {
382 return c10::nullopt;
383 }
384 return source_view_->findSourceRangeThatGenerated(*this);
385 }
386
387 protected:
388 std::shared_ptr<Source> source_view_;
389
390 private:
391 size_t start_;
392 size_t end_;
393 StringCordView::Iterator start_iter_;
394};
395
396// OwnedSourceRange is just like a SourceRange except that it owns a `Source`
397// instead of `Source`. Thus OwnedSourceRange owns a copy of source text.
398struct OwnedSourceRange : public SourceRange {
399 explicit OwnedSourceRange(const SourceRange& source_range)
400 : SourceRange(source_range) {
401 const auto& source = source_range.source();
402 if (source) {
403 source_view_ = std::make_shared<Source>(
404 source->text_str().str(),
405 source->filename(),
406 source->starting_line_no());
407 }
408 }
409};
410
411struct TORCH_API SourceRangeHasher {
412 public:
413 size_t operator()(const torch::jit::SourceRange& key) const;
414};
415
// One entry of a stack trace, as consumed by format_stack_trace: a
// `filename` string paired with the SourceRange for that frame.
// NOTE(review): confirm whether `filename` holds a file or a function name.
struct StackEntry {
  std::string filename;
  SourceRange range;
};
420
// Writes a stack trace built from `entries` to `out`. Only declared here; the
// exact output format is defined out of line.
TORCH_API void format_stack_trace(
    std::ostream& out,
    const std::vector<StackEntry>& entries);
424
425inline std::ostream& operator<<(std::ostream& out, const SourceRange& range) {
426 range.highlight(out);
427 return out;
428}
429
430// A pair of (byte offset, SourceRange) describing a specific segment
431// of the output stream
432struct TaggedRange {
433 TaggedRange(size_t bytes, SourceRange range)
434 : bytes(bytes), range(std::move(range)) {}
435 size_t bytes;
436 SourceRange range;
437};
438using SourceRangeRecords = std::vector<TaggedRange>;
439using SourceRangeTagMap =
440 std::unordered_map<SourceRange, int64_t, SourceRangeHasher>;
441
442} // namespace jit
443} // namespace torch
444
445namespace std {
446template <>
447struct iterator_traits<torch::jit::StringCordView::Iterator> {
448 using value_type = char;
449 using difference_type = ptrdiff_t;
450 using pointer = char*;
451 using reference = char&;
452 using iterator_category = std::forward_iterator_tag;
453};
454} // namespace std
455