1#include <cstdlib>
2#include <iomanip>
3#include <sstream>
4#include <string>
5#include <unordered_map>
6#include <vector>
7
8#include <ATen/core/function.h>
9#include <c10/util/Exception.h>
10#include <c10/util/StringUtil.h>
11#include <torch/csrc/jit/api/function_impl.h>
12#include <torch/csrc/jit/frontend/error_report.h>
13#include <torch/csrc/jit/ir/ir.h>
14#include <torch/csrc/jit/jit_log.h>
15#include <torch/csrc/jit/serialization/python_print.h>
16
17namespace torch {
18namespace jit {
19
// Holder for the JIT logging configuration: the raw level specification
// string, the per-file level table parsed from it, and the stream that log
// output is directed to. The constructor is private and copying is deleted;
// the instance is obtained through getInstance().
class JitLoggingConfig {
 public:
  JitLoggingConfig(JitLoggingConfig const&) = delete;
  void operator=(JitLoggingConfig const&) = delete;

  // Returns the shared instance, a function-local static that is constructed
  // the first time this is called.
  static JitLoggingConfig& getInstance() {
    static JitLoggingConfig instance;
    return instance;
  }

  // Returns a copy of the raw level specification currently in effect.
  std::string getLoggingLevels() const {
    return logging_levels;
  }

  // Replaces the raw level specification and re-parses it so the per-file
  // table reflects the new value.
  void setLoggingLevels(std::string levels) {
    logging_levels = std::move(levels);
    parse();
  }

  // Read-only view of the table built by parse().
  const std::unordered_map<std::string, size_t>& getFilesToLevels() const {
    return files_to_levels;
  }

  // Stores the address of `out_stream`; the stream itself is not copied, so
  // the caller has to keep it alive for as long as it is installed here.
  void setOutputStream(std::ostream& out_stream) {
    out = &out_stream;
  }

  // Returns the stream most recently installed (std::cerr until changed).
  std::ostream& getOutputStream() {
    return *out;
  }

 private:
  // Takes the initial specification from the PYTORCH_JIT_LOG_LEVEL
  // environment variable (left empty when the variable is unset), points the
  // output at std::cerr and builds the per-file table.
  JitLoggingConfig() : out(&std::cerr) {
    if (const char* env_levels = std::getenv("PYTORCH_JIT_LOG_LEVEL")) {
      logging_levels.assign(env_levels);
    }
    parse();
  }

  // Rebuilds files_to_levels from logging_levels (defined out of line).
  void parse();

  std::string logging_levels;
  std::unordered_map<std::string, size_t> files_to_levels;
  std::ostream* out;
};
64
65std::string get_jit_logging_levels() {
66 return JitLoggingConfig::getInstance().getLoggingLevels();
67}
68
69void set_jit_logging_levels(std::string level) {
70 JitLoggingConfig::getInstance().setLoggingLevels(std::move(level));
71}
72
73void set_jit_logging_output_stream(std::ostream& stream) {
74 JitLoggingConfig::getInstance().setOutputStream(stream);
75}
76
77std::ostream& get_jit_logging_output_stream() {
78 return JitLoggingConfig::getInstance().getOutputStream();
79}
80
81// gets a string representation of a node header
82// (e.g. outputs, a node kind and outputs)
83std::string getHeader(const Node* node) {
84 std::stringstream ss;
85 node->print(ss, 0, {}, false, false, false, false);
86 return ss.str();
87}
88
89void JitLoggingConfig::parse() {
90 std::stringstream in_ss;
91 in_ss << "function:" << this->logging_levels;
92
93 files_to_levels.clear();
94 std::string line;
95 while (std::getline(in_ss, line, ':')) {
96 if (line.empty()) {
97 continue;
98 }
99
100 auto index_at = line.find_last_of('>');
101 auto begin_index = index_at == std::string::npos ? 0 : index_at + 1;
102 size_t logging_level = index_at == std::string::npos ? 0 : index_at + 1;
103 auto end_index = line.find_last_of('.') == std::string::npos
104 ? line.size()
105 : line.find_last_of('.');
106 auto filename = line.substr(begin_index, end_index - begin_index);
107 files_to_levels.insert({filename, logging_level});
108 }
109}
110
111bool is_enabled(const char* cfname, JitLoggingLevels level) {
112 const auto& files_to_levels =
113 JitLoggingConfig::getInstance().getFilesToLevels();
114 std::string fname{cfname};
115 fname = c10::detail::StripBasename(fname);
116 const auto end_index = fname.find_last_of('.') == std::string::npos
117 ? fname.size()
118 : fname.find_last_of('.');
119 const auto fname_no_ext = fname.substr(0, end_index);
120
121 const auto it = files_to_levels.find(fname_no_ext);
122 if (it == files_to_levels.end()) {
123 return false;
124 }
125
126 return level <= static_cast<JitLoggingLevels>(it->second);
127}
128
129// Unfortunately, in `GraphExecutor` where `log_function` is invoked
130// we won't have access to an original function, so we have to construct
131// a dummy function to give to PythonPrint
132std::string log_function(const std::shared_ptr<torch::jit::Graph>& graph) {
133 torch::jit::GraphFunction func("source_dump", graph, nullptr);
134 std::vector<at::IValue> constants;
135 PrintDepsTable deps;
136 PythonPrint pp(constants, deps);
137 pp.printFunction(func);
138 return pp.str();
139}
140
// Prepends `prefix` to every line of `in_str`.
//
// `in_str` is split on '\n'; each resulting line is emitted as
// prefix + line + '\n'. Consequently the result always ends with a newline
// when `in_str` is non-empty (even if `in_str` did not), and an empty
// `in_str` yields an empty string.
std::string jit_log_prefix(
    const std::string& prefix,
    const std::string& in_str) {
  std::istringstream lines(in_str);
  std::ostringstream prefixed;
  std::string line;
  while (std::getline(lines, line)) {
    // '\n' rather than std::endl: the text is identical, but no flush is
    // requested for every line.
    prefixed << prefix << line << '\n';
  }

  return prefixed.str();
}
153
154std::string jit_log_prefix(
155 JitLoggingLevels level,
156 const char* fn,
157 int l,
158 const std::string& in_str) {
159 std::stringstream prefix_ss;
160 prefix_ss << "[";
161 prefix_ss << level << " ";
162 prefix_ss << c10::detail::StripBasename(std::string(fn)) << ":";
163 prefix_ss << std::setfill('0') << std::setw(3) << l;
164 prefix_ss << "] ";
165
166 return jit_log_prefix(prefix_ss.str(), in_str);
167}
168
169std::ostream& operator<<(std::ostream& out, JitLoggingLevels level) {
170 switch (level) {
171 case JitLoggingLevels::GRAPH_DUMP:
172 out << "DUMP";
173 break;
174 case JitLoggingLevels::GRAPH_UPDATE:
175 out << "UPDATE";
176 break;
177 case JitLoggingLevels::GRAPH_DEBUG:
178 out << "DEBUG";
179 break;
180 default:
181 TORCH_INTERNAL_ASSERT(false, "Invalid level");
182 }
183
184 return out;
185}
186
187} // namespace jit
188} // namespace torch
189