1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file install_debug_spans.cc |
 * \brief Prints TIR code to a "<name>.tir" file and replaces the spans of all
 *        supported stmts/exprs in the module with the line on which each was printed
24 | */ |
25 | |
#include "./install_debug_spans.h"

#include <tvm/tir/transform.h>

#include <fstream>
#include <string>
#include <utility>

#include "../../relay/printer/tir_text_printer_debug.h"
34 | |
35 | namespace tvm { |
36 | namespace tir { |
37 | |
38 | Stmt DebugInfoInstaller::InstallInfo(const std::string& name, const Stmt& stmt) { |
39 | DebugInfoInstaller installer(stmt, name + ".tir" ); |
40 | return installer.VisitStmt(stmt); |
41 | } |
42 | |
43 | DebugInfoInstaller::DebugInfoInstaller(const Stmt& stmt, const std::string& filename) { |
44 | // Determine the line that each stmt/expr will be printed on |
45 | tvm::relay::TIRTextPrinterDebug printer(false); |
46 | |
47 | // Fill in the stmts and exprs' line info |
48 | auto result = printer.Print(stmt).str(); |
49 | |
50 | // Create map of the stmt/expr -> its line number in the output to later |
51 | // create new spans for each stmt/expr |
52 | const auto& stmts = printer.GetStmtsByLine(); |
53 | VLOG(0) << "Debug printer found " << stmts.size() << " stmts after printing" ; |
54 | for (const auto& line : stmts) { |
55 | stmt_lines_[std::get<0>(line)] = std::get<1>(line); |
56 | } |
57 | |
58 | const auto& exprs = printer.GetExprsByLine(); |
59 | VLOG(0) << "Debug printer found " << exprs.size() << " exprs after printing" ; |
60 | for (const auto& line : exprs) { |
61 | expr_lines_[std::get<0>(line)] = std::get<1>(line); |
62 | } |
63 | |
64 | // Output the printed TIR to the specified file |
65 | VLOG(0) << "Outputting TIR to " << filename; |
66 | filename_ = std::move(filename); |
67 | std::ofstream out(filename_); |
68 | out << result; |
69 | out.close(); |
70 | } |
71 | |
72 | PrimExpr DebugInfoInstaller::VisitExpr(const PrimExpr& expr) { |
73 | PrimExpr result = expr; |
74 | result = StmtExprMutator::VisitExpr(result); |
75 | return result; |
76 | } |
77 | |
78 | Stmt DebugInfoInstaller::VisitStmt(const Stmt& stmt) { |
79 | Stmt result = stmt; |
80 | result = StmtExprMutator::VisitStmt(result); |
81 | return result; |
82 | } |
83 | |
84 | Span DebugInfoInstaller::MaybeSpan(const StmtNode* op) { |
85 | auto entry = stmt_lines_.find(op); |
86 | if (entry == stmt_lines_.end()) { |
87 | return Span(); |
88 | } else { |
89 | size_t column = 0; |
90 | size_t line = entry->second; |
91 | return Span(SourceName::Get(filename_), line, line, column, column); |
92 | } |
93 | } |
94 | |
95 | Span DebugInfoInstaller::MaybeSpan(const PrimExprNode* op) { |
96 | auto entry = expr_lines_.find(op); |
97 | if (entry == expr_lines_.end()) { |
98 | return Span(); |
99 | } else { |
100 | size_t column = 0; |
101 | size_t line = entry->second; |
102 | return Span(SourceName::Get(filename_), line, line, column, column); |
103 | } |
104 | } |
105 | |
106 | #define X(TypeName) \ |
107 | PrimExpr DebugInfoInstaller::VisitExpr_(const TypeName##Node* op) { \ |
108 | auto new_expr = StmtExprMutator::VisitExpr_(op); \ |
109 | auto new_type = Downcast<TypeName>(new_expr); \ |
110 | auto new_node = new_type.CopyOnWrite(); \ |
111 | new_node->span = MaybeSpan(op); \ |
112 | return new_type; \ |
113 | } |
114 | TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_SUPPORTED_EXPRS |
115 | #undef X |
116 | |
117 | #define X(TypeName) \ |
118 | Stmt DebugInfoInstaller::VisitStmt_(const TypeName##Node* op) { \ |
119 | Stmt new_stmt = StmtExprMutator::VisitStmt_(op); \ |
120 | auto new_type = Downcast<TypeName>(new_stmt); \ |
121 | auto new_node = new_type.CopyOnWrite(); \ |
122 | new_node->span = MaybeSpan(op); \ |
123 | return new_type; \ |
124 | } |
125 | TVM_TIR_TRANSFORMS_INSTALL_DEBUG_SPANS_SUPPORTED_STMTS |
126 | #undef X |
127 | |
128 | namespace transform { |
129 | |
130 | Pass InstallDebugSpans() { |
131 | auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) { |
132 | ICHECK(m->functions.size() == 1) |
133 | << "Debug info can only be added to IRModules with a single function" ; |
134 | // There is known to be only 1 function in the module at this point |
135 | auto entry = m->functions.begin(); |
136 | auto name = std::get<0>(*entry)->name_hint; |
137 | auto* n = f.CopyOnWrite(); |
138 | |
139 | n->body = DebugInfoInstaller::InstallInfo(std::move(name), std::move(f->body)); |
140 | |
141 | return f; |
142 | }; |
143 | return CreatePrimFuncPass(pass_func, 0, "tir.InstallDebugSpans" , {}); |
144 | } |
145 | |
146 | TVM_REGISTER_GLOBAL("tir.transform.InstallDebugSpans" ).set_body_typed(InstallDebugSpans); |
147 | |
148 | } // namespace transform |
149 | } // namespace tir |
150 | } // namespace tvm |
151 | |