1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include <tvm/ir/module.h> |
20 | #include <tvm/relay/base.h> |
21 | #include <tvm/relay/error.h> |
22 | |
23 | // clang-format off |
24 | #include <string> |
25 | #include <vector> |
26 | #include <rang.hpp> |
27 | // clang-format on |
28 | |
29 | namespace tvm { |
30 | namespace relay { |
31 | |
32 | template <typename T, typename U> |
33 | using NodeMap = std::unordered_map<T, U, ObjectPtrHash, ObjectPtrEqual>; |
34 | |
35 | void ErrorReporter::RenderErrors(const IRModule& module, bool use_color) { |
36 | // First we pick an error reporting strategy for each error. |
37 | // TODO(@jroesch): Spanned errors are currently not supported. |
38 | for (auto err : this->errors_) { |
39 | ICHECK(!err.span.defined()) << "attempting to use spanned errors, currently not supported" ; |
40 | } |
41 | |
42 | NodeMap<GlobalVar, NodeMap<ObjectRef, std::string>> error_maps; |
43 | |
44 | // Set control mode in order to produce colors; |
45 | if (use_color) { |
46 | rang::setControlMode(rang::control::Force); |
47 | } |
48 | |
49 | for (auto pair : this->node_to_gv_) { |
50 | auto node = pair.first; |
51 | auto global = Downcast<GlobalVar>(pair.second); |
52 | |
53 | auto has_errs = this->node_to_error_.find(node); |
54 | |
55 | ICHECK(has_errs != this->node_to_error_.end()); |
56 | |
57 | const auto& error_indices = has_errs->second; |
58 | |
59 | std::stringstream err_msg; |
60 | |
61 | err_msg << rang::fg::red; |
62 | err_msg << " " ; |
63 | for (auto index : error_indices) { |
64 | err_msg << this->errors_[index].what() << "; " ; |
65 | } |
66 | err_msg << rang::fg::reset; |
67 | |
68 | // Setup error map. |
69 | auto it = error_maps.find(global); |
70 | if (it != error_maps.end()) { |
71 | it->second.insert({node, err_msg.str()}); |
72 | } else { |
73 | error_maps.insert({global, {{node, err_msg.str()}}}); |
74 | } |
75 | } |
76 | |
77 | // Now we will construct the fully-annotated program to display to |
78 | // the user. |
79 | std::stringstream annotated_prog; |
80 | |
81 | // First we output a header for the errors. |
82 | annotated_prog << rang::style::bold << std::endl |
83 | << "Error(s) have occurred. The program has been annotated with them:" << std::endl |
84 | << std::endl |
85 | << rang::style::reset; |
86 | |
87 | // For each global function which contains errors, we will |
88 | // construct an annotated function. |
89 | for (auto pair : error_maps) { |
90 | auto global = pair.first; |
91 | auto err_map = pair.second; |
92 | auto func = module->Lookup(global); |
93 | |
94 | // We output the name of the function before displaying |
95 | // the annotated program. |
96 | annotated_prog << rang::style::bold << "In `" << global->name_hint << "`: " << std::endl |
97 | << rang::style::reset; |
98 | |
99 | // We then call into the Relay printer to generate the program. |
100 | // |
101 | // The annotation callback will annotate the error messages |
102 | // contained in the map. |
103 | annotated_prog << AsText(func, false, [&err_map](const ObjectRef& expr) { |
104 | auto it = err_map.find(expr); |
105 | if (it != err_map.end()) { |
106 | ICHECK_NE(it->second.size(), 0); |
107 | return it->second; |
108 | } else { |
109 | return std::string("" ); |
110 | } |
111 | }); |
112 | } |
113 | |
114 | auto msg = annotated_prog.str(); |
115 | |
116 | if (use_color) { |
117 | rang::setControlMode(rang::control::Auto); |
118 | } |
119 | |
120 | // Finally we report the error, currently we do so to LOG(FATAL), |
121 | // it may be good to instead report it to std::cout. |
122 | LOG(FATAL) << annotated_prog.str() << std::endl; |
123 | } |
124 | |
125 | void ErrorReporter::ReportAt(const GlobalVar& global, const ObjectRef& node, |
126 | const CompileError& err) { |
127 | size_t index_to_insert = this->errors_.size(); |
128 | this->errors_.push_back(err); |
129 | auto it = this->node_to_error_.find(node); |
130 | if (it != this->node_to_error_.end()) { |
131 | it->second.push_back(index_to_insert); |
132 | } else { |
133 | this->node_to_error_.insert({node, {index_to_insert}}); |
134 | } |
135 | this->node_to_gv_.insert({node, global}); |
136 | } |
137 | } // namespace relay |
138 | } // namespace tvm |
139 | |