1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/ir/diagnostic.cc
22 * \brief Implementation of DiagnosticContext and friends.
23 */
24#include <tvm/ir/diagnostic.h>
25#include <tvm/ir/source_map.h>
26
27#include <rang.hpp>
28
29namespace tvm {
30
// failed to check argument arg0.dims[0] != 0
32
33/* Diagnostic */
// Register DiagnosticNode with TVM's node type registry.
TVM_REGISTER_NODE_TYPE(DiagnosticNode);

// FFI constructor: diagnostics.Diagnostic(level, span, message).
// `level` arrives as a plain int and is cast directly to DiagnosticLevel; no range
// check is performed here.
TVM_REGISTER_GLOBAL("diagnostics.Diagnostic")
    .set_body_typed([](int level, Span span, String message) {
      return Diagnostic(static_cast<DiagnosticLevel>(level), span, message);
    });
40
41Diagnostic::Diagnostic(DiagnosticLevel level, Span span, const std::string& message) {
42 auto n = make_object<DiagnosticNode>();
43 n->level = level;
44 n->span = span;
45 n->message = message;
46 data_ = std::move(n);
47}
48
49DiagnosticBuilder Diagnostic::Bug(Span span) {
50 return DiagnosticBuilder(DiagnosticLevel::kBug, span);
51}
52
53DiagnosticBuilder Diagnostic::Error(Span span) {
54 return DiagnosticBuilder(DiagnosticLevel::kError, span);
55}
56
57DiagnosticBuilder Diagnostic::Warning(Span span) {
58 return DiagnosticBuilder(DiagnosticLevel::kWarning, span);
59}
60
61DiagnosticBuilder Diagnostic::Note(Span span) {
62 return DiagnosticBuilder(DiagnosticLevel::kNote, span);
63}
64
65DiagnosticBuilder Diagnostic::Help(Span span) {
66 return DiagnosticBuilder(DiagnosticLevel::kHelp, span);
67}
68
69/* Diagnostic Renderer */
// Register DiagnosticRendererNode with TVM's node type registry.
TVM_REGISTER_NODE_TYPE(DiagnosticRendererNode);
71
72void DiagnosticRenderer::Render(const DiagnosticContext& ctx) { (*this)->renderer(ctx); }
73
74TVM_DLL DiagnosticRenderer::DiagnosticRenderer(
75 TypedPackedFunc<void(DiagnosticContext ctx)> renderer) {
76 auto n = make_object<DiagnosticRendererNode>();
77 n->renderer = renderer;
78 data_ = std::move(n);
79}
80
// FFI constructor: diagnostics.DiagnosticRenderer(callback) wraps a packed function that
// receives a DiagnosticContext.
TVM_REGISTER_GLOBAL("diagnostics.DiagnosticRenderer")
    .set_body_typed([](TypedPackedFunc<void(DiagnosticContext ctx)> renderer) {
      return DiagnosticRenderer(renderer);
    });
85
86/* Diagnostic Context */
// Register DiagnosticContextNode with TVM's node type registry.
TVM_REGISTER_NODE_TYPE(DiagnosticContextNode);
88
89void DiagnosticContext::Render() {
90 (*this)->renderer.Render(*this);
91
92 int errs = 0;
93 if ((*this)->diagnostics.size()) {
94 for (auto diagnostic : (*this)->diagnostics) {
95 if (diagnostic->level == DiagnosticLevel::kError) {
96 errs += 1;
97 }
98 }
99 }
100
101 if (errs) {
102 (*this)->renderer = DiagnosticRenderer();
103 LOG(FATAL) << "DiagnosticError: one or more error diagnostics were "
104 << "emitted, please check diagnostic render for output.";
105 }
106}
107
// FFI: diagnostics.DiagnosticRendererRender(renderer, ctx) invokes `renderer` on `ctx`.
TVM_REGISTER_GLOBAL("diagnostics.DiagnosticRendererRender")
    .set_body_typed([](DiagnosticRenderer renderer, DiagnosticContext ctx) {
      renderer.Render(ctx);
    });
112
113DiagnosticContext::DiagnosticContext(const IRModule& module, const DiagnosticRenderer& renderer) {
114 CHECK(renderer.defined()) << "can not initialize a diagnostic renderer with a null function";
115 auto n = make_object<DiagnosticContextNode>();
116 n->module = module;
117 n->renderer = renderer;
118 data_ = std::move(n);
119}
120
// FFI constructor: diagnostics.DiagnosticContext(module, renderer).
TVM_REGISTER_GLOBAL("diagnostics.DiagnosticContext")
    .set_body_typed([](const IRModule& module, const DiagnosticRenderer& renderer) {
      return DiagnosticContext(module, renderer);
    });
125
126/*! \brief Emit a diagnostic. */
127void DiagnosticContext::Emit(const Diagnostic& diagnostic) {
128 (*this)->diagnostics.push_back(diagnostic);
129}
130
// FFI: diagnostics.Emit(ctx, diagnostic) records `diagnostic` in `ctx` without rendering.
TVM_REGISTER_GLOBAL("diagnostics.Emit")
    .set_body_typed([](DiagnosticContext ctx, const Diagnostic& diagnostic) {
      return ctx.Emit(diagnostic);
    });

// FFI: diagnostics.DiagnosticContextRender(ctx) calls DiagnosticContext::Render on `ctx`.
TVM_REGISTER_GLOBAL("diagnostics.DiagnosticContextRender")
    .set_body_typed([](DiagnosticContext context) { return context.Render(); });
138
139/*! \brief Emit a diagnostic. */
140void DiagnosticContext::EmitFatal(const Diagnostic& diagnostic) {
141 Emit(diagnostic);
142 Render();
143}
144
145/* Default Terminal Renderer. */
// Names of the global functions that produce the diagnostic renderer. The pointers themselves
// are constant so these registry keys cannot be reassigned at runtime.
static constexpr const char* DEFAULT_RENDERER = "diagnostics.DefaultRenderer";
static constexpr const char* OVERRIDE_RENDERER = "diagnostics.OverrideRenderer";
148
149DiagnosticRenderer GetRenderer() {
150 auto override_pf = tvm::runtime::Registry::Get(OVERRIDE_RENDERER);
151 tvm::runtime::TypedPackedFunc<ObjectRef()> pf;
152 if (override_pf) {
153 pf = tvm::runtime::TypedPackedFunc<ObjectRef()>(*override_pf);
154 } else {
155 auto default_pf = tvm::runtime::Registry::Get(DEFAULT_RENDERER);
156 ICHECK(default_pf != nullptr)
157 << "Can not find registered function for " << DEFAULT_RENDERER << "." << std::endl
158 << "Either this is an internal error or the default function was overloaded incorrectly.";
159 pf = tvm::runtime::TypedPackedFunc<ObjectRef()>(*default_pf);
160 }
161 return Downcast<DiagnosticRenderer>(pf());
162}
163
164DiagnosticContext DiagnosticContext::Default(const IRModule& module) {
165 auto renderer = GetRenderer();
166 return DiagnosticContext(module, renderer);
167}
168
// FFI: diagnostics.Default(module) builds a context using the current renderer (GetRenderer).
TVM_REGISTER_GLOBAL("diagnostics.Default").set_body_typed([](const IRModule& module) {
  return DiagnosticContext::Default(module);
});
172
173std::ostream& EmitDiagnosticHeader(std::ostream& out, const Span& span, DiagnosticLevel level,
174 std::string msg) {
175 rang::fg diagnostic_color = rang::fg::reset;
176 std::string diagnostic_type;
177
178 switch (level) {
179 case DiagnosticLevel::kWarning: {
180 diagnostic_color = rang::fg::yellow;
181 diagnostic_type = "warning";
182 break;
183 }
184 case DiagnosticLevel::kError: {
185 diagnostic_color = rang::fg::red;
186 diagnostic_type = "error";
187 break;
188 }
189 case DiagnosticLevel::kBug: {
190 diagnostic_color = rang::fg::blue;
191 diagnostic_type = "bug";
192 break;
193 }
194 case DiagnosticLevel::kNote: {
195 diagnostic_color = rang::fg::reset;
196 diagnostic_type = "note";
197 break;
198 }
199 case DiagnosticLevel::kHelp: {
200 diagnostic_color = rang::fg::reset;
201 diagnostic_type = "help";
202 break;
203 }
204 }
205
206 out << rang::style::bold << diagnostic_color << diagnostic_type << ": " << rang::fg::reset << msg
207 << std::endl
208 << rang::fg::blue << " --> " << rang::fg::reset << rang::style::reset
209 << span->source_name->name << ":" << span->line << ":" << span->column << std::endl;
210
211 return out;
212}
213
214/*! \brief Generate an error message at a specific line and column with the
215 * annotated message.
216 *
217 * The error is written directly to the `out` std::ostream.
218 *
219 * \param out The output ostream.
220 * \param line The line at which to report a diagnostic.
221 * \param line The column at which to report a diagnostic.
222 * \param msg The message to attach.
223 */
224void ReportAt(const DiagnosticContext& context, std::ostream& out, const Span& span,
225 const Diagnostic& diagnostic) {
226 if (!span.defined()) {
227 out << diagnostic->message << std::endl;
228 return;
229 }
230
231 ICHECK(context->module->source_map.defined());
232 auto it = context->module->source_map->source_map.find(span->source_name);
233
234 // If the source name is not in the current source map, sources were not annotated.
235 if (it == context->module->source_map->source_map.end()) {
236 LOG(FATAL) << "The source maps are not populated for this module. "
237 << "Please use `tvm.relay.transform.AnnotateSpans` to attach source maps for error "
238 "reporting.\n"
239 << "Error: " << diagnostic->message;
240 }
241
242 auto source = (*it).second;
243 VLOG(1) << "Source: " << std::endl << source->source;
244
245 VLOG(1) << "ReportAt "
246 << "span = " << span << " msg = " << diagnostic->message;
247
248 auto line_text = source.GetLine(span->line);
249
250 std::stringstream line_header_s;
251 line_header_s << " " << span->line << " ";
252 auto line_header = line_header_s.str();
253
254 std::stringstream no_line_header_s;
255 for (size_t i = 0; i < line_header.size(); i++) {
256 no_line_header_s << " ";
257 }
258 auto no_line_header = no_line_header_s.str();
259
260 EmitDiagnosticHeader(out, span, diagnostic->level, diagnostic->message)
261 << no_line_header << "| " << std::endl
262 << line_header << "| " << line_text << std::endl
263 << no_line_header << "| ";
264
265 std::stringstream marker;
266 for (size_t i = 1; i <= line_text.size(); i++) {
267 if (static_cast<int>(i) >= span->column && static_cast<int>(i) < span->end_column) {
268 marker << "^";
269 } else {
270 marker << " ";
271 }
272 }
273 out << marker.str();
274 out << std::endl;
275}
276
277// TODO(@jroesch): eventually modularize the rendering interface to provide control of how to
278// format errors.
279DiagnosticRenderer TerminalRenderer(std::ostream& out) {
280 return DiagnosticRenderer([&](const DiagnosticContext& ctx) {
281 for (auto diagnostic : ctx->diagnostics) {
282 ReportAt(ctx, out, diagnostic->span, diagnostic);
283 }
284 });
285}
286
// Built-in renderer factory: returns a terminal renderer writing to std::cout. GetRenderer
// falls back to this when no override is registered.
TVM_REGISTER_GLOBAL(DEFAULT_RENDERER).set_body_typed([]() { return TerminalRenderer(std::cout); });

// FFI: diagnostics.GetRenderer() returns the renderer new default contexts would use.
TVM_REGISTER_GLOBAL("diagnostics.GetRenderer").set_body_typed([]() { return GetRenderer(); });

// FFI: diagnostics.ClearRenderer() removes any registered override renderer factory, so
// GetRenderer falls back to the default one. The result of Registry::Remove is ignored.
TVM_REGISTER_GLOBAL("diagnostics.ClearRenderer").set_body_typed([]() {
  tvm::runtime::Registry::Remove(OVERRIDE_RENDERER);
});
294
295} // namespace tvm
296