1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include <tvm/ir/module.h>
20#include <tvm/relay/base.h>
21#include <tvm/relay/error.h>
22
23// clang-format off
24#include <string>
25#include <vector>
26#include <rang.hpp>
27// clang-format on
28
29namespace tvm {
30namespace relay {
31
32template <typename T, typename U>
33using NodeMap = std::unordered_map<T, U, ObjectPtrHash, ObjectPtrEqual>;
34
35void ErrorReporter::RenderErrors(const IRModule& module, bool use_color) {
36 // First we pick an error reporting strategy for each error.
37 // TODO(@jroesch): Spanned errors are currently not supported.
38 for (auto err : this->errors_) {
39 ICHECK(!err.span.defined()) << "attempting to use spanned errors, currently not supported";
40 }
41
42 NodeMap<GlobalVar, NodeMap<ObjectRef, std::string>> error_maps;
43
44 // Set control mode in order to produce colors;
45 if (use_color) {
46 rang::setControlMode(rang::control::Force);
47 }
48
49 for (auto pair : this->node_to_gv_) {
50 auto node = pair.first;
51 auto global = Downcast<GlobalVar>(pair.second);
52
53 auto has_errs = this->node_to_error_.find(node);
54
55 ICHECK(has_errs != this->node_to_error_.end());
56
57 const auto& error_indices = has_errs->second;
58
59 std::stringstream err_msg;
60
61 err_msg << rang::fg::red;
62 err_msg << " ";
63 for (auto index : error_indices) {
64 err_msg << this->errors_[index].what() << "; ";
65 }
66 err_msg << rang::fg::reset;
67
68 // Setup error map.
69 auto it = error_maps.find(global);
70 if (it != error_maps.end()) {
71 it->second.insert({node, err_msg.str()});
72 } else {
73 error_maps.insert({global, {{node, err_msg.str()}}});
74 }
75 }
76
77 // Now we will construct the fully-annotated program to display to
78 // the user.
79 std::stringstream annotated_prog;
80
81 // First we output a header for the errors.
82 annotated_prog << rang::style::bold << std::endl
83 << "Error(s) have occurred. The program has been annotated with them:" << std::endl
84 << std::endl
85 << rang::style::reset;
86
87 // For each global function which contains errors, we will
88 // construct an annotated function.
89 for (auto pair : error_maps) {
90 auto global = pair.first;
91 auto err_map = pair.second;
92 auto func = module->Lookup(global);
93
94 // We output the name of the function before displaying
95 // the annotated program.
96 annotated_prog << rang::style::bold << "In `" << global->name_hint << "`: " << std::endl
97 << rang::style::reset;
98
99 // We then call into the Relay printer to generate the program.
100 //
101 // The annotation callback will annotate the error messages
102 // contained in the map.
103 annotated_prog << AsText(func, false, [&err_map](const ObjectRef& expr) {
104 auto it = err_map.find(expr);
105 if (it != err_map.end()) {
106 ICHECK_NE(it->second.size(), 0);
107 return it->second;
108 } else {
109 return std::string("");
110 }
111 });
112 }
113
114 auto msg = annotated_prog.str();
115
116 if (use_color) {
117 rang::setControlMode(rang::control::Auto);
118 }
119
120 // Finally we report the error, currently we do so to LOG(FATAL),
121 // it may be good to instead report it to std::cout.
122 LOG(FATAL) << annotated_prog.str() << std::endl;
123}
124
125void ErrorReporter::ReportAt(const GlobalVar& global, const ObjectRef& node,
126 const CompileError& err) {
127 size_t index_to_insert = this->errors_.size();
128 this->errors_.push_back(err);
129 auto it = this->node_to_error_.find(node);
130 if (it != this->node_to_error_.end()) {
131 it->second.push_back(index_to_insert);
132 } else {
133 this->node_to_error_.insert({node, {index_to_insert}});
134 }
135 this->node_to_gv_.insert({node, global});
136}
137} // namespace relay
138} // namespace tvm
139