1 | #pragma once |
#include <c10/util/Exception.h>
#include <c10/util/Optional.h>

#include <algorithm>
#include <cstddef>
#include <iostream>
#include <iterator>
#include <memory>
#include <numeric>
#include <sstream>
#include <string>
#include <tuple>
#include <unordered_map>
#include <vector>
11 | |
12 | namespace torch { |
13 | namespace jit { |
14 | |
// Forward declarations.
// SourceRange is defined further down in this header, but Source needs to name
// it in the signature of findSourceRangeThatGenerated().
struct SourceRange;
// Only ever held through std::shared_ptr in this header, so an incomplete type
// is sufficient here.
class SourceRangeUnpickler;
17 | |
18 | // A stringlike class backed by a vector of string_view |
19 | // the string represented are logically the concatenation of the string_views |
20 | // This has advantage of not needing continues memory. |
21 | struct TORCH_API StringCordView { |
22 | StringCordView(); |
23 | StringCordView(const StringCordView&) = default; |
24 | StringCordView( |
25 | std::vector<c10::string_view> inputs, |
26 | std::vector<std::shared_ptr<std::string>> ownerships); |
27 | |
28 | StringCordView& operator=(const StringCordView&) = default; |
29 | |
30 | size_t size() const { |
31 | return accumulated_sizes_.back(); |
32 | } |
33 | |
34 | size_t find(const std::string& tok, size_t start) const; |
35 | StringCordView substr(size_t start, size_t size) const; |
36 | |
37 | char at(size_t index) const { |
38 | return *iter_for_pos(index); |
39 | } |
40 | char operator[](size_t index) const { |
41 | return at(index); |
42 | } |
43 | |
44 | std::string str() const { |
45 | std::stringstream ss; |
46 | for (auto s : pieces_) { |
47 | ss << std::string(s); |
48 | } |
49 | return ss.str(); |
50 | } |
51 | |
52 | bool operator==(const std::string& rhs) const; |
53 | |
54 | bool operator==(const StringCordView& rhs) const; |
55 | |
56 | c10::string_view piece(size_t index) const { |
57 | return pieces_[index]; |
58 | } |
59 | |
60 | struct Iterator { |
61 | Iterator( |
62 | const StringCordView* str, |
63 | size_t start_line, |
64 | size_t start_pos, |
65 | size_t size) |
66 | : line_(start_line), pos_(start_pos), str_(str), size_(size) {} |
67 | explicit Iterator(const StringCordView* str) |
68 | : Iterator(str, 0, 0, str->size()) {} |
69 | |
70 | Iterator() : Iterator(nullptr, 0, 0, 0) {} |
71 | |
72 | Iterator(const Iterator&) = default; |
73 | Iterator(Iterator&&) = default; |
74 | Iterator& operator=(const Iterator&) = default; |
75 | Iterator& operator=(Iterator&&) = default; |
76 | |
77 | Iterator operator++() { |
78 | if (size_ == 0) { |
79 | return *this; |
80 | } |
81 | if ((pos_ + 1) < str_->pieces_[line_].size()) { |
82 | pos_++; |
83 | } else { |
84 | line_++; |
85 | pos_ = 0; |
86 | } |
87 | return *this; |
88 | } |
89 | |
90 | Iterator operator++(int) { |
91 | Iterator prev(*this); |
92 | ++(*this); |
93 | return prev; |
94 | } |
95 | |
96 | Iterator next_iter() const { |
97 | Iterator next(*this); |
98 | ++next; |
99 | return next; |
100 | } |
101 | |
102 | Iterator& operator+=(size_t num) { |
103 | if (!has_next()) { |
104 | return *this; |
105 | } |
106 | size_t target_pos = pos_ + num; |
107 | if (target_pos >= str_->accumulated_sizes_[line_] && |
108 | (line_ + 1) < str_->accumulated_sizes_.size() && |
109 | target_pos < str_->accumulated_sizes_[line_ + 1]) { |
110 | pos_ = target_pos; |
111 | return *this; |
112 | } |
113 | |
114 | size_t target_abs_pos = pos() + num; |
115 | *this = str_->iter_for_pos(target_abs_pos); |
116 | return *this; |
117 | } |
118 | |
119 | bool operator==(const Iterator& rhs) const { |
120 | if (!has_next() && !rhs.has_next()) { |
121 | return true; |
122 | } |
123 | return (str_ == rhs.str_) && (line_ == rhs.line_) && (pos_ == rhs.pos_); |
124 | } |
125 | bool operator!=(const Iterator& rhs) { |
126 | return !((*this) == rhs); |
127 | } |
128 | bool has_next() const { |
129 | return size_ > 0 && (line_ < str_->pieces_.size()); |
130 | } |
131 | |
132 | char operator*() const { |
133 | TORCH_INTERNAL_ASSERT(line_ < str_->pieces_.size()); |
134 | TORCH_INTERNAL_ASSERT(pos_ < str_->pieces_[line_].size()); |
135 | return str_->pieces_[line_].at(pos_); |
136 | } |
137 | |
138 | // returns rest of the line of the current iterator |
139 | c10::string_view rest_line() const { |
140 | if (line_ >= str_->pieces_.size()) { |
141 | return "" ; |
142 | } |
143 | |
144 | c10::string_view cur_line = str_->pieces_[line_]; |
145 | return cur_line.substr(pos_, std::string::npos); |
146 | } |
147 | |
148 | size_t pos() const { |
149 | if (size_ == 0) { |
150 | return 0; |
151 | } |
152 | return str_->accumulated_sizes_[line_] + pos_; |
153 | } |
154 | |
155 | private: |
156 | size_t line_; |
157 | size_t pos_; |
158 | const StringCordView* str_; |
159 | size_t size_; |
160 | friend struct StringCordView; |
161 | }; |
162 | |
163 | Iterator begin() const { |
164 | return Iterator(this, 0, 0, size()); |
165 | } |
166 | Iterator end() const { |
167 | return Iterator(this, pieces_.size(), 0, 0); |
168 | } |
169 | Iterator iter_for_pos(size_t pos) const; |
170 | |
171 | private: |
172 | std::vector<c10::string_view> pieces_; |
173 | std::vector<size_t> accumulated_sizes_; |
174 | std::vector<std::shared_ptr<std::string>> owned_strings_; |
175 | }; |
176 | |
177 | // Source represents a code segment. It keeps track of: |
178 | // - text_view : the view into text of the code segment |
179 | // - filename (optional) : if present, represents the name of the file from |
180 | // which the code segment originated. |
181 | // - starting_line_no : represents the line in the original file where the |
182 | // code segment started. |
183 | struct TORCH_API Source { |
184 | // Whether or not Source should copy the string passed in the constructor. |
185 | enum CopiesString { COPIES_STRING, DONT_COPY }; |
186 | |
187 | explicit Source( |
188 | c10::string_view text_view, |
189 | c10::optional<std::string> filename = c10::nullopt, |
190 | size_t starting_line_no = 0, |
191 | std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr, |
192 | CopiesString copies_str = COPIES_STRING) |
193 | : filename_(std::move(filename)), |
194 | starting_line_no_(starting_line_no), |
195 | gen_ranges_(std::move(gen_ranges)) { |
196 | if (copies_str == COPIES_STRING) { |
197 | std::shared_ptr<std::string> allocated_str = |
198 | std::make_shared<std::string>(text_view.data(), text_view.size()); |
199 | text_view_ = StringCordView({*allocated_str}, {allocated_str}); |
200 | } else { |
201 | text_view_ = StringCordView({text_view}, {}); |
202 | } |
203 | |
204 | calc_line_start_offsets(); |
205 | } |
206 | |
207 | explicit Source( |
208 | StringCordView str, |
209 | c10::optional<std::string> filename = c10::nullopt, |
210 | size_t starting_line_no = 0, |
211 | std::shared_ptr<SourceRangeUnpickler> gen_ranges = nullptr) |
212 | : text_view_(str), |
213 | filename_(std::move(filename)), |
214 | starting_line_no_(starting_line_no), |
215 | gen_ranges_(std::move(gen_ranges)) { |
216 | calc_line_start_offsets(); |
217 | } |
218 | // Given a line number (within source_), return the byte offset of the |
219 | // beginning of that line. |
220 | size_t offset_for_line(size_t line) const { |
221 | return line_starting_offsets_.at(line); |
222 | } |
223 | |
224 | // Returns number of lines present. |
225 | size_t num_lines() const { |
226 | return line_starting_offsets_.size(); |
227 | } |
228 | |
229 | // Calculate the line (within the code segment) on which `offset` resides. |
230 | size_t lineno_for_offset(size_t offset) const { |
231 | auto iter = std::upper_bound( |
232 | line_starting_offsets_.begin(), line_starting_offsets_.end(), offset); |
233 | return iter - line_starting_offsets_.begin() - 1; |
234 | } |
235 | |
236 | // Calculate the line (within the original source file, if present) on which |
237 | // `lineno` resides. |
238 | size_t lineno_to_source_lineno(size_t lineno) const { |
239 | if (filename_) { |
240 | return lineno + starting_line_no_; |
241 | } else { |
242 | return lineno; |
243 | } |
244 | } |
245 | |
246 | StringCordView get_line(size_t lineno) const { |
247 | auto start = offset_for_line(lineno); |
248 | auto size = (lineno + 1) < num_lines() ? offset_for_line(lineno + 1) - start |
249 | : text_view_.size() - start; |
250 | return text_view_.substr(start, size); |
251 | } |
252 | |
253 | const StringCordView& text_str() const { |
254 | return text_view_; |
255 | } |
256 | |
257 | char char_at(size_t index) const { |
258 | return text_view_.at(index); |
259 | } |
260 | |
261 | size_t size() const { |
262 | return text_view_.size(); |
263 | } |
264 | |
265 | c10::optional<std::string>& filename() { |
266 | return filename_; |
267 | } |
268 | |
269 | size_t starting_line_no() const { |
270 | return starting_line_no_; |
271 | } |
272 | |
273 | c10::optional<SourceRange> findSourceRangeThatGenerated( |
274 | const SourceRange& range); |
275 | |
276 | ~Source() = default; |
277 | |
278 | private: |
279 | void calc_line_start_offsets() { |
280 | line_starting_offsets_.clear(); |
281 | line_starting_offsets_.push_back(0); |
282 | size_t pos = 0; |
283 | while ((pos = text_view_.find("\n" , pos)) != std::string::npos) { |
284 | line_starting_offsets_.push_back(++pos); |
285 | } |
286 | } |
287 | |
288 | StringCordView text_view_; |
289 | |
290 | c10::optional<std::string> filename_; |
291 | // If filename_ is not present, starting_line_no_ is don't care |
292 | size_t starting_line_no_; |
293 | // Starting offsets for lines into the source. e.g. line 0 starts at |
294 | // line_starting_offsets_[0], etc. |
295 | std::vector<size_t> line_starting_offsets_; |
296 | |
297 | std::shared_ptr<SourceRangeUnpickler> gen_ranges_; |
298 | }; |
299 | |
300 | // A SourceRange is a reference to subset of a Source, specified by `start` and |
301 | // `end` byte offsets into the source text. |
302 | struct TORCH_API SourceRange { |
303 | SourceRange(std::shared_ptr<Source> source_view, size_t start_, size_t end_) |
304 | : source_view_(std::move(source_view)), start_(start_), end_(end_) { |
305 | if (source_view_) { |
306 | start_iter_ = source_view_->text_str().iter_for_pos(start_); |
307 | } |
308 | } |
309 | |
310 | SourceRange() : source_view_(nullptr), start_(0), end_(0) {} |
311 | |
312 | SourceRange( |
313 | std::shared_ptr<Source> source_view_, |
314 | StringCordView::Iterator start_iter, |
315 | size_t end_) |
316 | : source_view_(std::move(source_view_)), |
317 | start_(start_iter.pos()), |
318 | end_(end_), |
319 | start_iter_(start_iter) {} |
320 | |
321 | const c10::string_view token_text() const { |
322 | size_t size = end() - start(); |
323 | return start_iter_.rest_line().substr(0, size); |
324 | } |
325 | |
326 | const StringCordView text() const { |
327 | return source_view_->text_str().substr(start(), end() - start()); |
328 | } |
329 | size_t size() const { |
330 | return end() - start(); |
331 | } |
332 | static const size_t CONTEXT = 3; |
333 | void highlight(std::ostream& out) const; |
334 | |
335 | // Customizable version of 'highlight' method. |
336 | void print_with_context( |
337 | std::ostream& out, |
338 | size_t context, |
339 | bool highlight, |
340 | const std::string& funcname) const; |
341 | |
342 | const std::shared_ptr<Source>& source() const { |
343 | return source_view_; |
344 | } |
345 | size_t start() const { |
346 | return start_; |
347 | } |
348 | size_t end() const { |
349 | return end_; |
350 | } |
351 | std::string str() const { |
352 | std::stringstream ss; |
353 | highlight(ss); |
354 | return ss.str(); |
355 | } |
356 | |
357 | c10::optional<std::tuple<std::string, size_t, size_t>> file_line_col() const { |
358 | if (!source_view_ || !source()->filename()) { |
359 | return c10::nullopt; |
360 | } |
361 | |
362 | auto lineno = source_view_->lineno_for_offset(start_); |
363 | auto col_offset = (int)start_ - (int)source_view_->offset_for_line(lineno); |
364 | // TODO: c10::optional<>::value returns an rvalue ref so can't use it here?? |
365 | return std::make_tuple<std::string, size_t, size_t>( |
366 | source_view_->filename().value_or("" ), |
367 | source_view_->lineno_to_source_lineno(lineno), |
368 | (size_t)col_offset); |
369 | } |
370 | |
371 | bool operator==(const SourceRange& rhs) const { |
372 | return start() == rhs.start() && end() == rhs.end() && |
373 | source() == rhs.source(); |
374 | } |
375 | |
376 | bool operator!=(const SourceRange& rhs) const { |
377 | return !(*this == rhs); |
378 | } |
379 | |
380 | c10::optional<SourceRange> findSourceRangeThatGenerated() const { |
381 | if (!source_view_) { |
382 | return c10::nullopt; |
383 | } |
384 | return source_view_->findSourceRangeThatGenerated(*this); |
385 | } |
386 | |
387 | protected: |
388 | std::shared_ptr<Source> source_view_; |
389 | |
390 | private: |
391 | size_t start_; |
392 | size_t end_; |
393 | StringCordView::Iterator start_iter_; |
394 | }; |
395 | |
396 | // OwnedSourceRange is just like a SourceRange except that it owns a `Source` |
397 | // instead of `Source`. Thus OwnedSourceRange owns a copy of source text. |
398 | struct OwnedSourceRange : public SourceRange { |
399 | explicit OwnedSourceRange(const SourceRange& source_range) |
400 | : SourceRange(source_range) { |
401 | const auto& source = source_range.source(); |
402 | if (source) { |
403 | source_view_ = std::make_shared<Source>( |
404 | source->text_str().str(), |
405 | source->filename(), |
406 | source->starting_line_no()); |
407 | } |
408 | } |
409 | }; |
410 | |
411 | struct TORCH_API SourceRangeHasher { |
412 | public: |
413 | size_t operator()(const torch::jit::SourceRange& key) const; |
414 | }; |
415 | |
416 | struct StackEntry { |
417 | std::string filename; |
418 | SourceRange range; |
419 | }; |
420 | |
421 | TORCH_API void format_stack_trace( |
422 | std::ostream& out, |
423 | const std::vector<StackEntry>& entries); |
424 | |
425 | inline std::ostream& operator<<(std::ostream& out, const SourceRange& range) { |
426 | range.highlight(out); |
427 | return out; |
428 | } |
429 | |
430 | // A pair of (byte offset, SourceRange) describing a specific segment |
431 | // of the output stream |
432 | struct TaggedRange { |
433 | TaggedRange(size_t bytes, SourceRange range) |
434 | : bytes(bytes), range(std::move(range)) {} |
435 | size_t bytes; |
436 | SourceRange range; |
437 | }; |
438 | using SourceRangeRecords = std::vector<TaggedRange>; |
439 | using SourceRangeTagMap = |
440 | std::unordered_map<SourceRange, int64_t, SourceRangeHasher>; |
441 | |
442 | } // namespace jit |
443 | } // namespace torch |
444 | |
445 | namespace std { |
446 | template <> |
447 | struct iterator_traits<torch::jit::StringCordView::Iterator> { |
448 | using value_type = char; |
449 | using difference_type = ptrdiff_t; |
450 | using pointer = char*; |
451 | using reference = char&; |
452 | using iterator_category = std::forward_iterator_tag; |
453 | }; |
454 | } // namespace std |
455 | |