1 | #include <cstdlib> |
2 | #include <iomanip> |
3 | #include <sstream> |
4 | #include <string> |
5 | #include <unordered_map> |
6 | #include <vector> |
7 | |
8 | #include <ATen/core/function.h> |
9 | #include <c10/util/Exception.h> |
10 | #include <c10/util/StringUtil.h> |
11 | #include <torch/csrc/jit/api/function_impl.h> |
12 | #include <torch/csrc/jit/frontend/error_report.h> |
13 | #include <torch/csrc/jit/ir/ir.h> |
14 | #include <torch/csrc/jit/jit_log.h> |
15 | #include <torch/csrc/jit/serialization/python_print.h> |
16 | |
17 | namespace torch { |
18 | namespace jit { |
19 | |
// Process-wide JIT logging configuration: which source files log at which
// level, and the stream that log output is written to.
//
// The level specification is read from the PYTORCH_JIT_LOG_LEVEL environment
// variable when the singleton is first constructed, and can be replaced later
// through setLoggingLevels().
// NOTE(review): no synchronization is performed; concurrent calls to the
// setters and readers of files_to_levels are not guarded here.
class JitLoggingConfig {
 public:
  // Returns the lazily constructed singleton instance.
  static JitLoggingConfig& getInstance() {
    static JitLoggingConfig instance;
    return instance;
  }
  JitLoggingConfig(JitLoggingConfig const&) = delete;
  JitLoggingConfig& operator=(JitLoggingConfig const&) = delete;

 private:
  // Raw ':'-separated level specification as supplied by the user.
  std::string logging_levels;
  // Parsed form of logging_levels: file name (extension stripped) -> level.
  std::unordered_map<std::string, size_t> files_to_levels;
  // Destination for log output. Only ever set from a reference, so it is
  // never null; defaults to std::cerr.
  std::ostream* out = &std::cerr;

  JitLoggingConfig() {
    const char* jit_log_level = std::getenv("PYTORCH_JIT_LOG_LEVEL");
    logging_levels.assign(jit_log_level == nullptr ? "" : jit_log_level);
    parse();
  }
  // Rebuilds files_to_levels from logging_levels.
  void parse();

 public:
  // Returns a copy of the raw level specification.
  std::string getLoggingLevels() const {
    return this->logging_levels;
  }
  // Replaces the raw level specification and re-parses it.
  void setLoggingLevels(std::string levels) {
    this->logging_levels = std::move(levels);
    parse();
  }

  // Returns the parsed file -> level map.
  const std::unordered_map<std::string, size_t>& getFilesToLevels() const {
    return this->files_to_levels;
  }

  // Redirects log output. Only a pointer is kept, so `out_stream` must
  // outlive its use as the log sink.
  void setOutputStream(std::ostream& out_stream) {
    this->out = &out_stream;
  }

  // Returns the current log output stream.
  std::ostream& getOutputStream() {
    return *(this->out);
  }
};
64 | |
65 | std::string get_jit_logging_levels() { |
66 | return JitLoggingConfig::getInstance().getLoggingLevels(); |
67 | } |
68 | |
69 | void set_jit_logging_levels(std::string level) { |
70 | JitLoggingConfig::getInstance().setLoggingLevels(std::move(level)); |
71 | } |
72 | |
73 | void set_jit_logging_output_stream(std::ostream& stream) { |
74 | JitLoggingConfig::getInstance().setOutputStream(stream); |
75 | } |
76 | |
77 | std::ostream& get_jit_logging_output_stream() { |
78 | return JitLoggingConfig::getInstance().getOutputStream(); |
79 | } |
80 | |
81 | // gets a string representation of a node header |
82 | // (e.g. outputs, a node kind and outputs) |
83 | std::string (const Node* node) { |
84 | std::stringstream ss; |
85 | node->print(ss, 0, {}, false, false, false, false); |
86 | return ss.str(); |
87 | } |
88 | |
89 | void JitLoggingConfig::parse() { |
90 | std::stringstream in_ss; |
91 | in_ss << "function:" << this->logging_levels; |
92 | |
93 | files_to_levels.clear(); |
94 | std::string line; |
95 | while (std::getline(in_ss, line, ':')) { |
96 | if (line.empty()) { |
97 | continue; |
98 | } |
99 | |
100 | auto index_at = line.find_last_of('>'); |
101 | auto begin_index = index_at == std::string::npos ? 0 : index_at + 1; |
102 | size_t logging_level = index_at == std::string::npos ? 0 : index_at + 1; |
103 | auto end_index = line.find_last_of('.') == std::string::npos |
104 | ? line.size() |
105 | : line.find_last_of('.'); |
106 | auto filename = line.substr(begin_index, end_index - begin_index); |
107 | files_to_levels.insert({filename, logging_level}); |
108 | } |
109 | } |
110 | |
111 | bool is_enabled(const char* cfname, JitLoggingLevels level) { |
112 | const auto& files_to_levels = |
113 | JitLoggingConfig::getInstance().getFilesToLevels(); |
114 | std::string fname{cfname}; |
115 | fname = c10::detail::StripBasename(fname); |
116 | const auto end_index = fname.find_last_of('.') == std::string::npos |
117 | ? fname.size() |
118 | : fname.find_last_of('.'); |
119 | const auto fname_no_ext = fname.substr(0, end_index); |
120 | |
121 | const auto it = files_to_levels.find(fname_no_ext); |
122 | if (it == files_to_levels.end()) { |
123 | return false; |
124 | } |
125 | |
126 | return level <= static_cast<JitLoggingLevels>(it->second); |
127 | } |
128 | |
129 | // Unfortunately, in `GraphExecutor` where `log_function` is invoked |
130 | // we won't have access to an original function, so we have to construct |
131 | // a dummy function to give to PythonPrint |
132 | std::string log_function(const std::shared_ptr<torch::jit::Graph>& graph) { |
133 | torch::jit::GraphFunction func("source_dump" , graph, nullptr); |
134 | std::vector<at::IValue> constants; |
135 | PrintDepsTable deps; |
136 | PythonPrint pp(constants, deps); |
137 | pp.printFunction(func); |
138 | return pp.str(); |
139 | } |
140 | |
// Prepends `prefix` to every line of `in_str`.
//
// Every emitted line ends in '\n'. An input without a trailing newline
// therefore gains one, and an empty input yields an empty string.
std::string jit_log_prefix(
    const std::string& prefix,
    const std::string& in_str) {
  std::istringstream in_ss(in_str);
  std::ostringstream out_ss;
  std::string line;
  while (std::getline(in_ss, line)) {
    // '\n' instead of std::endl: flushing a string stream per line is
    // wasted work and does not change the produced text.
    out_ss << prefix << line << '\n';
  }

  return out_ss.str();
}
153 | |
154 | std::string jit_log_prefix( |
155 | JitLoggingLevels level, |
156 | const char* fn, |
157 | int l, |
158 | const std::string& in_str) { |
159 | std::stringstream prefix_ss; |
160 | prefix_ss << "[" ; |
161 | prefix_ss << level << " " ; |
162 | prefix_ss << c10::detail::StripBasename(std::string(fn)) << ":" ; |
163 | prefix_ss << std::setfill('0') << std::setw(3) << l; |
164 | prefix_ss << "] " ; |
165 | |
166 | return jit_log_prefix(prefix_ss.str(), in_str); |
167 | } |
168 | |
169 | std::ostream& operator<<(std::ostream& out, JitLoggingLevels level) { |
170 | switch (level) { |
171 | case JitLoggingLevels::GRAPH_DUMP: |
172 | out << "DUMP" ; |
173 | break; |
174 | case JitLoggingLevels::GRAPH_UPDATE: |
175 | out << "UPDATE" ; |
176 | break; |
177 | case JitLoggingLevels::GRAPH_DEBUG: |
178 | out << "DEBUG" ; |
179 | break; |
180 | default: |
181 | TORCH_INTERNAL_ASSERT(false, "Invalid level" ); |
182 | } |
183 | |
184 | return out; |
185 | } |
186 | |
187 | } // namespace jit |
188 | } // namespace torch |
189 | |