1/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14*/
15
16// We extract stack traces in Python using the logic in tf_stack.cc, which
17// stores a list of PyCodeObject*. Such stack trace extraction is really fast.
18//
// We store the retrieved stack trace within the Node object directly. Then
// whenever the graph is instantiated/copied, we copy the stack trace with it.
// Since graph instantiation goes through a protobuf roundtrip, we store the
// original stack trace mapping attached in FunctionLibraryDefinition.
23
24// clang-format off
25// These headers must be at the top, before including Python.h header
26// Otherwise, we get C2039 on MSVC due to 'copysign'
27#include "pybind11/complex.h"
28#include "pybind11/pybind11.h"
29#include "pybind11/stl.h"
30#include "pybind11/stl_bind.h"
31// clang-format on
32
33#include <frameobject.h>
34
35#include <algorithm>
36#include <vector>
37
38#include "Python.h"
39#include "absl/algorithm/container.h"
40#include "absl/container/flat_hash_set.h"
41#include "absl/hash/hash.h"
42#include "absl/strings/str_format.h"
43#include "absl/strings/str_join.h"
44#include "absl/types/span.h"
45#include "tensorflow/c/c_api_internal.h"
46#include "tensorflow/core/graph/graph.h"
47#include "tensorflow/core/platform/mutex.h"
48#include "tensorflow/core/platform/path.h"
49#include "tensorflow/python/util/stack_trace.h"
50
// Forward declarations so the opaque-type registrations below can name these
// types; the full definitions come from the includes above (presumably
// stack_trace.h — confirm).
struct StackFrame;  // Forward declaration.
struct StackTrace;

// Register these as opaque so pybind11 passes them by reference through the
// bindings below instead of converting them to/from native Python containers
// on every access.
PYBIND11_MAKE_OPAQUE(std::vector<StackFrame>);
PYBIND11_MAKE_OPAQUE(StackTrace);
56
57namespace tensorflow {
58
59namespace {
60
61namespace py = pybind11;
62
63using StringSet = absl::flat_hash_set<std::string>;
64
65// Python wrapper for a SourceMap.
66class PyBindSourceMap {
67 public:
68 PyBindSourceMap() : source_map_(std::make_shared<SourceMap>()) {}
69
70 // Shares ownership with whoever captures traces in the scope of this map.
71 std::shared_ptr<SourceMap> source_map_;
72};
73
74// Python wrapper for a FileSet.
75class PyBindFileSet {
76 public:
77 PyBindFileSet() : file_set_(std::make_shared<StringSet>()) {}
78
79 // Shares ownership with whoever captures traces in the scope of this set.
80 std::shared_ptr<StringSet> file_set_;
81};
82
// Returns contents of the line corresponding to the given frame, stripped of
// surrounding whitespace. Returns an empty string when the line cannot be
// resolved (linecache.getline returns '' on failure).
//
// Precondition: must be holding Python GIL.
py::str LineContents(const StackFrame& frame) {
  DCheckPyGilStateForStackTrace();
  // Pointers are to avoid static destruction of pybind::object, which
  // occurs in uncontrollable states. The pointed-to objects are deliberately
  // leaked and live for the remainder of the process.
  static const auto* inspect = new py::module(py::module::import("inspect"));
  static const auto* getmodule = new py::function(inspect->attr("getmodule"));
  static const auto* linecache =
      new py::module(py::module::import("linecache"));
  static const auto* checkcache =
      new py::function(linecache->attr("checkcache"));
  static const auto* getline = new py::function(linecache->attr("getline"));
  // Invalidate any stale linecache entry in case the file changed on disk.
  (*checkcache)(py::str(frame.file_name));

  // Here we use the undocumented second argument of inspect.getmodule to look
  // up a module from a filename. It has been unchanged since 2015.
  const auto& module = (*getmodule)(py::none(), py::str(frame.file_name));
  py::object dict = py::none();
  if (!module.is_none()) {
    // module dict is used by getline to resolve import hooks; see the
    // stdlib's inspect module.
    dict = module.attr("__dict__");
  }
  // Fetch the line, then strip it for traceback-style display.
  return py::cast<py::str>(
      (*getline)(py::str(frame.file_name), py::int_(frame.line_number), dict)
          .attr("strip")());
}
112
113// Ignores the frames containing this substring for common prefix calculation.
114static const char* kFilenameToIgnorePrefix = "<embedded";
115
116// Converts the given stack frame to string, according to options defined in
117// `opts`.
118std::string StackFrameToString(
119 const StackFrame& frame,
120 const AbstractStackTrace::TracePrintingOptions& opts,
121 int shared_prefix_size = 0) {
122 std::string out = absl::StrFormat(
123 "File \"%s\", line %d, in %s",
124 absl::StrContains(frame.file_name, kFilenameToIgnorePrefix)
125 ? frame.file_name
126 : frame.file_name.substr(shared_prefix_size),
127 frame.line_number, frame.function_name);
128
129 if (opts.show_line_contents) {
130 PyGILState_STATE state = PyGILState_Ensure();
131 std::string line_contents = std::string(LineContents(frame));
132 PyGILState_Release(state);
133 if (!line_contents.empty()) {
134 absl::StrAppend(&out, "\n ", line_contents);
135 }
136 }
137 return out;
138}
139
140class StackTraceWrapper : public AbstractStackTrace {
141 public:
142 explicit StackTraceWrapper(absl::Span<const StackFrame> stack_frames)
143 : stack_frames_cache_(std::vector<StackFrame>(stack_frames.begin(),
144 stack_frames.end())) {}
145
146 StackTraceWrapper(StackTraceWrapper&& rhs) {
147 captured_ = std::move(rhs.captured_);
148 source_map_ = std::move(rhs.source_map_);
149 filter_ = std::move(rhs.filter_);
150 stacklevel_ = rhs.stacklevel_;
151 tensorflow::mutex_lock lock(rhs.mu_);
152 stack_frames_cache_ = std::move(rhs.stack_frames_cache_);
153 last_stack_frame_cache_ = std::move(rhs.last_stack_frame_cache_);
154 }
155
156 StackTraceWrapper& operator=(StackTraceWrapper&& rhs) {
157 if (&rhs == this) return *this;
158
159 captured_ = std::move(rhs.captured_);
160 source_map_ = std::move(rhs.source_map_);
161 filter_ = std::move(rhs.filter_);
162 stacklevel_ = rhs.stacklevel_;
163
164 tensorflow::mutex_lock self_lock(mu_);
165 tensorflow::mutex_lock rhs_lock(rhs.mu_);
166
167 stack_frames_cache_ = std::move(rhs.stack_frames_cache_);
168 last_stack_frame_cache_ = std::move(rhs.last_stack_frame_cache_);
169 return *this;
170 }
171
172 static StackTraceWrapper ExtractStack(
173 const std::shared_ptr<SourceMap>& source_map,
174 const std::shared_ptr<StringSet>& filter, int stacklevel) {
175 return StackTraceWrapper{StackTrace::Capture(-1), source_map, filter,
176 stacklevel};
177 }
178
179 absl::Span<const StackFrame> ToFrames() const override {
180 tensorflow::mutex_lock lock(mu_);
181 if (stack_frames_cache_) {
182 return *stack_frames_cache_;
183 }
184
185 // Grabbing the GIL solves two purposes: 1) makes the class thread-safe,
186 // and 2) ToStackFrames and LineContents actually need it.
187 PyGILState_STATE state = PyGILState_Ensure();
188
189 stack_frames_cache_ = captured_.ToStackFrames(
190 *source_map_, [&](const char* f) { return StackTraceFiltering(f); });
191
192 // Drop last stack frames.
193 int newsize = stack_frames_cache_->size() - stacklevel_;
194 if (newsize < 0) {
195 newsize = 0;
196 }
197 stack_frames_cache_->resize(newsize);
198
199 PyGILState_Release(state);
200 return *stack_frames_cache_;
201 }
202
203 int get_stacklevel() const { return stacklevel_; }
204
205 void set_stacklevel(int stacklevel) { stacklevel_ = stacklevel; }
206
207 std::vector<StackFrame> GetUserFrames(int limit = -1) const {
208 PyGILState_STATE state = PyGILState_Ensure();
209 std::vector<StackFrame> user_frames = captured_.ToStackFrames(
210 *source_map_,
211 [&](const char* file_name) {
212 return StackTraceFiltering(file_name) ||
213 IsInternalFrameForFilename(file_name);
214 },
215 /*reverse_traversal=*/true,
216 /*limit=*/limit);
217 PyGILState_Release(state);
218 // ensure we use the original (outermost first) ordering.
219 absl::c_reverse(user_frames);
220 return user_frames;
221 }
222
223 StackFrame LastUserFrame() const override {
224 tensorflow::mutex_lock lock(mu_);
225 if (last_stack_frame_cache_) {
226 return *last_stack_frame_cache_;
227 }
228
229 PyGILState_STATE state = PyGILState_Ensure();
230 std::vector<StackFrame> last_frame = GetUserFrames(1);
231
232 if (last_frame.empty()) {
233 last_stack_frame_cache_ = StackFrame{"", -1, ""};
234 } else {
235 DCHECK_EQ(last_frame.size(), 1);
236 last_stack_frame_cache_ = last_frame[0];
237 }
238 PyGILState_Release(state);
239 return *last_stack_frame_cache_;
240 }
241
242 // Erases a section of the stack trace.
243 void Erase(int first, int last) {
244 tensorflow::mutex_lock lock(mu_);
245 if (!stack_frames_cache_) {
246 ToFrames();
247 }
248 DCHECK_GE(first, 0);
249 DCHECK_LT(first, stack_frames_cache_->size());
250 DCHECK_GE(last, 0);
251 DCHECK_LE(last, stack_frames_cache_->size());
252 auto it = stack_frames_cache_->begin();
253 stack_frames_cache_->erase(it + first, it + last);
254 }
255
256 std::string ToString(const TracePrintingOptions& opts) const override {
257 std::vector<std::string> files_to_find_prefix;
258 for (const StackFrame& frame : ToFrames()) {
259 if (!absl::StrContains(frame.file_name, kFilenameToIgnorePrefix)) {
260 files_to_find_prefix.push_back(frame.file_name);
261 }
262 }
263 int shared_prefix_size =
264 opts.filter_common_prefix
265 ? io::CommonPathPrefix(files_to_find_prefix).size()
266 : 0;
267
268 tensorflow::mutex_lock lock(mu_);
269 if (!opts.drop_internal_frames) {
270 return ToStringHelper(*stack_frames_cache_, opts, shared_prefix_size);
271 }
272
273 std::vector<StackFrame> filtered_frames;
274 for (const StackFrame& frame : *stack_frames_cache_) {
275 if (!IsInternalFrameForFilename(frame.file_name)) {
276 filtered_frames.push_back(frame);
277 }
278 }
279 return ToStringHelper(filtered_frames, opts, shared_prefix_size);
280 }
281
282 ~StackTraceWrapper() override {
283 PyGILState_STATE state = PyGILState_Ensure();
284 captured_.Clear();
285 source_map_.reset();
286 filter_.reset();
287 PyGILState_Release(state);
288 }
289
290 private:
291 StackTraceWrapper(StackTrace&& captured,
292 const std::shared_ptr<SourceMap>& source_map,
293 const std::shared_ptr<StringSet>& filter, int stacklevel)
294 : captured_(std::move(captured)),
295 source_map_(source_map),
296 filter_(filter),
297 stacklevel_(stacklevel) {}
298
299 static std::string ToStringHelper(absl::Span<const StackFrame> stack_frames,
300 const TracePrintingOptions& opts,
301 int shared_prefix_size) {
302 return absl::StrJoin(
303 stack_frames, "\n", [&](std::string* out, const StackFrame& frame) {
304 absl::StrAppend(out,
305 StackFrameToString(frame, opts, shared_prefix_size));
306 });
307 }
308
309 bool StackTraceFiltering(const char* file_name) const {
310 return filter_->contains(file_name);
311 }
312
313 // Note: Make sure to update move constructor while adding new member
314 // variables.
315 StackTrace captured_;
316 std::shared_ptr<SourceMap> source_map_;
317 std::shared_ptr<StringSet> filter_;
318 int stacklevel_;
319
320 // Using optional to force destruction while we hold a GIL.
321 mutable absl::optional<std::vector<StackFrame>> stack_frames_cache_
322 TF_GUARDED_BY(mu_);
323 mutable absl::optional<StackFrame> last_stack_frame_cache_ TF_GUARDED_BY(mu_);
324 mutable mutex mu_;
325};
326
327} // namespace
328
329PYBIND11_MODULE(_tf_stack, m) {
330 py::class_<PyBindSourceMap>(m, "PyBindSourceMap")
331 .def(py::init())
332 .def("update_to",
333 [](const PyBindSourceMap& self, const py::tuple& source_map) {
334 self.source_map_->clear();
335 for (const auto& item : source_map) {
336 const auto& tuple_item = py::cast<py::tuple>(item);
337
338 const auto& key = py::cast<py::tuple>(tuple_item[0]);
339 std::string&& k_filename = py::cast<std::string>(key[0]);
340 int k_lineno = py::cast<int>(key[1]);
341
342 const auto& value = py::cast<py::tuple>(tuple_item[1]);
343 std::string&& v_filename = py::cast<std::string>(value[0]);
344 int v_lineno = py::cast<int>(value[1]);
345 const auto& function_name_val = value[2];
346 std::string&& v_function_name =
347 function_name_val.is_none()
348 ? ""
349 : py::cast<std::string>(function_name_val);
350
351 self.source_map_->emplace(
352 SourceLoc{k_filename, k_lineno},
353 StackFrame({v_filename, v_lineno, v_function_name}));
354 }
355 });
356
357 py::class_<PyBindFileSet>(m, "PyBindFileSet")
358 .def(py::init())
359 .def("update_to", [](const PyBindFileSet& self, const py::set& file_set) {
360 self.file_set_->clear();
361 for (const auto& item : file_set) {
362 self.file_set_->insert(py::cast<std::string>(item));
363 }
364 });
365
366 py::class_<StackFrame>(m, "StackFrame")
367 .def_property_readonly(
368 "filename",
369 [](const StackFrame& self) { return py::str(self.file_name); })
370 .def_property_readonly(
371 "lineno",
372 [](const StackFrame& self) { return py::int_(self.line_number); })
373 .def_property_readonly(
374 "name",
375 [](const StackFrame& self) { return py::str(self.function_name); })
376 .def_property_readonly(
377 "line", [](const StackFrame& self) { return LineContents(self); })
378
379 // For compatibility with the traceback module.
380 .def("__eq__", &StackFrame::operator==)
381 .def("__ne__", &StackFrame::operator!=)
382 .def("__hash__",
383 [](const StackFrame& self) {
384 return absl::Hash<std::tuple<std::string, int, std::string>>()(
385 std::make_tuple(self.file_name, self.line_number,
386 self.function_name));
387 })
388 .def("__getitem__",
389 [](const StackFrame& self, const py::object& index) -> py::object {
390 return py::make_tuple(
391 py::str(self.file_name), py::int_(self.line_number),
392 py::str(self.function_name), LineContents(self))[index];
393 })
394 .def("__iter__",
395 [](const StackFrame& self) {
396 return py::iter(py::make_tuple(
397 py::str(self.file_name), py::int_(self.line_number),
398 py::str(self.function_name), LineContents(self))
399
400 );
401 })
402 .def("__repr__",
403 [](const StackFrame& self) { return StackFrameToString(self, {}); })
404 .def("__len__", [](const StackFrame&) { return 4; });
405
406 py::class_<StackTraceWrapper>(m, "StackTraceWrapper")
407 // TODO(slebedev): upstream negative indexing support into pybind11.
408 .def(
409 "__getitem__",
410 [](const StackTraceWrapper& self, py::ssize_t index) {
411 absl::Span<const StackFrame> frames = self.ToFrames();
412 const size_t eff_index =
413 index < 0 ? frames.size() + index : static_cast<size_t>(index);
414 if (eff_index >= frames.size()) {
415 throw py::index_error();
416 }
417 return frames[eff_index];
418 },
419 py::return_value_policy::reference_internal)
420 .def(
421 "__getitem__",
422 [](const StackTraceWrapper& self, py::slice slice) {
423 absl::Span<const StackFrame> frames = self.ToFrames();
424 py::ssize_t start, stop, step, slicelength;
425 if (!slice.compute(frames.size(), &start, &stop, &step,
426 &slicelength)) {
427 throw py::error_already_set();
428 }
429 if (step == 1) {
430 return StackTraceWrapper{frames.subspan(start, slicelength)};
431 }
432 // TODO(cheshire): Cleanup, use Python slicing logic directly
433 // instead.
434 std::vector<StackFrame> out;
435 out.reserve(slicelength);
436 // Python slices allow negative indexing.
437 for (int i = start; i != stop; i += step) {
438 out.push_back(frames[i]);
439 }
440 return StackTraceWrapper{out};
441 },
442 py::return_value_policy::reference_internal)
443 .def("__delitem__",
444 [](StackTraceWrapper& self, py::ssize_t index) {
445 absl::Span<const StackFrame> frames = self.ToFrames();
446 const size_t eff_index =
447 index < 0 ? frames.size() + index : static_cast<size_t>(index);
448 if (eff_index >= frames.size()) {
449 throw py::index_error();
450 }
451 self.Erase(eff_index, eff_index + 1);
452 })
453 .def("__delitem__",
454 [](StackTraceWrapper& self, py::slice slice) {
455 absl::Span<const StackFrame> frames = self.ToFrames();
456 py::ssize_t start, stop, step, slicelength;
457 if (!slice.compute(frames.size(), &start, &stop, &step,
458 &slicelength)) {
459 throw py::error_already_set();
460 }
461 if (step != 1) {
462 throw py::index_error();
463 }
464 if (stop > start) {
465 self.Erase(start, stop);
466 }
467 })
468 .def("__len__",
469 [](const StackTraceWrapper& self) { return self.ToFrames().size(); })
470 .def("__eq__",
471 [](const StackTraceWrapper& self, const StackTraceWrapper& other) {
472 return self.ToFrames() == other.ToFrames();
473 })
474 .def("__hash__",
475 [](const StackTraceWrapper& self) {
476 return py::hash(py::str(self.ToString({})));
477 })
478 // NOTE(feyu): consider remove this and use traceback.format_list(tb)
479 // to format the trace.
480 .def("__repr__",
481 [](const StackTraceWrapper& self) {
482 return py::str(self.ToString({}));
483 })
484 .def_property(
485 "_stacklevel", &StackTraceWrapper::get_stacklevel,
486 &StackTraceWrapper::set_stacklevel,
487 "Adjusts stacklevel; no effects after ToFrames() is called.")
488 .def(
489 "get_user_frames",
490 [](const StackTraceWrapper& self) {
491 return StackTraceWrapper{self.GetUserFrames()};
492 },
493 "Returns the non-framework frames as a new trace object.")
494 .def(
495 "last_user_frame",
496 [](const StackTraceWrapper& self) { return self.LastUserFrame(); },
497 "Returns the last non-framework frame.");
498
499 m.def("extract_stack_for_op", [](const PyBindSourceMap& source_map,
500 const PyBindFileSet& file_set,
501 TF_Operation* op, int stacklevel) {
502 DCHECK(!op->node.GetStackTrace()) << "Should not reset the stack trace";
503 op->node.SetStackTrace(
504 std::make_shared<StackTraceWrapper>(StackTraceWrapper::ExtractStack(
505 source_map.source_map_, file_set.file_set_, stacklevel)));
506 });
507
508 m.def(
509 "extract_stack",
510 [](const PyBindSourceMap& source_map, const PyBindFileSet& file_set) {
511 return StackTraceWrapper::ExtractStack(source_map.source_map_,
512 file_set.file_set_, 1);
513 },
514 py::return_value_policy::move);
515}
516
517} // namespace tensorflow
518