1 | // Copyright (c) 2020 The Khronos Group Inc. |
2 | // Copyright (c) 2020 Valve Corporation |
3 | // Copyright (c) 2020 LunarG Inc. |
4 | // |
5 | // Licensed under the Apache License, Version 2.0 (the "License"); |
6 | // you may not use this file except in compliance with the License. |
7 | // You may obtain a copy of the License at |
8 | // |
9 | // http://www.apache.org/licenses/LICENSE-2.0 |
10 | // |
11 | // Unless required by applicable law or agreed to in writing, software |
12 | // distributed under the License is distributed on an "AS IS" BASIS, |
13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
14 | // See the License for the specific language governing permissions and |
15 | // limitations under the License. |
16 | |
17 | #include "inst_debug_printf_pass.h" |
18 | |
19 | #include "source/util/string_utils.h" |
20 | #include "spirv/unified1/NonSemanticDebugPrintf.h" |
21 | |
22 | namespace spvtools { |
23 | namespace opt { |
24 | |
25 | void InstDebugPrintfPass::GenOutputValues(Instruction* val_inst, |
26 | std::vector<uint32_t>* val_ids, |
27 | InstructionBuilder* builder) { |
28 | uint32_t val_ty_id = val_inst->type_id(); |
29 | analysis::TypeManager* type_mgr = context()->get_type_mgr(); |
30 | analysis::Type* val_ty = type_mgr->GetType(val_ty_id); |
31 | switch (val_ty->kind()) { |
32 | case analysis::Type::kVector: { |
33 | analysis::Vector* v_ty = val_ty->AsVector(); |
34 | const analysis::Type* c_ty = v_ty->element_type(); |
35 | uint32_t c_ty_id = type_mgr->GetId(c_ty); |
36 | for (uint32_t c = 0; c < v_ty->element_count(); ++c) { |
37 | Instruction* c_inst = builder->AddIdLiteralOp( |
38 | c_ty_id, SpvOpCompositeExtract, val_inst->result_id(), c); |
39 | GenOutputValues(c_inst, val_ids, builder); |
40 | } |
41 | return; |
42 | } |
43 | case analysis::Type::kBool: { |
44 | // Select between uint32 zero or one |
45 | uint32_t zero_id = builder->GetUintConstantId(0); |
46 | uint32_t one_id = builder->GetUintConstantId(1); |
47 | Instruction* sel_inst = builder->AddTernaryOp( |
48 | GetUintId(), SpvOpSelect, val_inst->result_id(), one_id, zero_id); |
49 | val_ids->push_back(sel_inst->result_id()); |
50 | return; |
51 | } |
52 | case analysis::Type::kFloat: { |
53 | analysis::Float* f_ty = val_ty->AsFloat(); |
54 | switch (f_ty->width()) { |
55 | case 16: { |
56 | // Convert float16 to float32 and recurse |
57 | Instruction* f32_inst = builder->AddUnaryOp( |
58 | GetFloatId(), SpvOpFConvert, val_inst->result_id()); |
59 | GenOutputValues(f32_inst, val_ids, builder); |
60 | return; |
61 | } |
62 | case 64: { |
63 | // Bitcast float64 to uint64 and recurse |
64 | Instruction* ui64_inst = builder->AddUnaryOp( |
65 | GetUint64Id(), SpvOpBitcast, val_inst->result_id()); |
66 | GenOutputValues(ui64_inst, val_ids, builder); |
67 | return; |
68 | } |
69 | case 32: { |
70 | // Bitcase float32 to uint32 |
71 | Instruction* bc_inst = builder->AddUnaryOp(GetUintId(), SpvOpBitcast, |
72 | val_inst->result_id()); |
73 | val_ids->push_back(bc_inst->result_id()); |
74 | return; |
75 | } |
76 | default: |
77 | assert(false && "unsupported float width" ); |
78 | return; |
79 | } |
80 | } |
81 | case analysis::Type::kInteger: { |
82 | analysis::Integer* i_ty = val_ty->AsInteger(); |
83 | switch (i_ty->width()) { |
84 | case 64: { |
85 | Instruction* ui64_inst = val_inst; |
86 | if (i_ty->IsSigned()) { |
87 | // Bitcast sint64 to uint64 |
88 | ui64_inst = builder->AddUnaryOp(GetUint64Id(), SpvOpBitcast, |
89 | val_inst->result_id()); |
90 | } |
91 | // Break uint64 into 2x uint32 |
92 | Instruction* lo_ui64_inst = builder->AddUnaryOp( |
93 | GetUintId(), SpvOpUConvert, ui64_inst->result_id()); |
94 | Instruction* rshift_ui64_inst = builder->AddBinaryOp( |
95 | GetUint64Id(), SpvOpShiftRightLogical, ui64_inst->result_id(), |
96 | builder->GetUintConstantId(32)); |
97 | Instruction* hi_ui64_inst = builder->AddUnaryOp( |
98 | GetUintId(), SpvOpUConvert, rshift_ui64_inst->result_id()); |
99 | val_ids->push_back(lo_ui64_inst->result_id()); |
100 | val_ids->push_back(hi_ui64_inst->result_id()); |
101 | return; |
102 | } |
103 | case 8: { |
104 | Instruction* ui8_inst = val_inst; |
105 | if (i_ty->IsSigned()) { |
106 | // Bitcast sint8 to uint8 |
107 | ui8_inst = builder->AddUnaryOp(GetUint8Id(), SpvOpBitcast, |
108 | val_inst->result_id()); |
109 | } |
110 | // Convert uint8 to uint32 |
111 | Instruction* ui32_inst = builder->AddUnaryOp( |
112 | GetUintId(), SpvOpUConvert, ui8_inst->result_id()); |
113 | val_ids->push_back(ui32_inst->result_id()); |
114 | return; |
115 | } |
116 | case 32: { |
117 | Instruction* ui32_inst = val_inst; |
118 | if (i_ty->IsSigned()) { |
119 | // Bitcast sint32 to uint32 |
120 | ui32_inst = builder->AddUnaryOp(GetUintId(), SpvOpBitcast, |
121 | val_inst->result_id()); |
122 | } |
123 | // uint32 needs no further processing |
124 | val_ids->push_back(ui32_inst->result_id()); |
125 | return; |
126 | } |
127 | default: |
128 | // TODO(greg-lunarg): Support non-32-bit int |
129 | assert(false && "unsupported int width" ); |
130 | return; |
131 | } |
132 | } |
133 | default: |
134 | assert(false && "unsupported type" ); |
135 | return; |
136 | } |
137 | } |
138 | |
139 | void InstDebugPrintfPass::GenOutputCode( |
140 | Instruction* printf_inst, uint32_t stage_idx, |
141 | std::vector<std::unique_ptr<BasicBlock>>* new_blocks) { |
142 | BasicBlock* back_blk_ptr = &*new_blocks->back(); |
143 | InstructionBuilder builder( |
144 | context(), back_blk_ptr, |
145 | IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); |
146 | // Gen debug printf record validation-specific values. The format string |
147 | // will have its id written. Vectors will need to be broken down into |
148 | // component values. float16 will need to be converted to float32. Pointer |
149 | // and uint64 will need to be converted to two uint32 values. float32 will |
150 | // need to be bitcast to uint32. int32 will need to be bitcast to uint32. |
151 | std::vector<uint32_t> val_ids; |
152 | bool is_first_operand = false; |
153 | printf_inst->ForEachInId( |
154 | [&is_first_operand, &val_ids, &builder, this](const uint32_t* iid) { |
155 | // skip set operand |
156 | if (!is_first_operand) { |
157 | is_first_operand = true; |
158 | return; |
159 | } |
160 | Instruction* opnd_inst = get_def_use_mgr()->GetDef(*iid); |
161 | if (opnd_inst->opcode() == SpvOpString) { |
162 | uint32_t string_id_id = builder.GetUintConstantId(*iid); |
163 | val_ids.push_back(string_id_id); |
164 | } else { |
165 | GenOutputValues(opnd_inst, &val_ids, &builder); |
166 | } |
167 | }); |
168 | GenDebugStreamWrite(uid2offset_[printf_inst->unique_id()], stage_idx, val_ids, |
169 | &builder); |
170 | context()->KillInst(printf_inst); |
171 | } |
172 | |
173 | void InstDebugPrintfPass::GenDebugPrintfCode( |
174 | BasicBlock::iterator ref_inst_itr, |
175 | UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx, |
176 | std::vector<std::unique_ptr<BasicBlock>>* new_blocks) { |
177 | // If not DebugPrintf OpExtInst, return. |
178 | Instruction* printf_inst = &*ref_inst_itr; |
179 | if (printf_inst->opcode() != SpvOpExtInst) return; |
180 | if (printf_inst->GetSingleWordInOperand(0) != ext_inst_printf_id_) return; |
181 | if (printf_inst->GetSingleWordInOperand(1) != |
182 | NonSemanticDebugPrintfDebugPrintf) |
183 | return; |
184 | // Initialize DefUse manager before dismantling module |
185 | (void)get_def_use_mgr(); |
186 | // Move original block's preceding instructions into first new block |
187 | std::unique_ptr<BasicBlock> new_blk_ptr; |
188 | MovePreludeCode(ref_inst_itr, ref_block_itr, &new_blk_ptr); |
189 | new_blocks->push_back(std::move(new_blk_ptr)); |
190 | // Generate instructions to output printf args to printf buffer |
191 | GenOutputCode(printf_inst, stage_idx, new_blocks); |
192 | // Caller expects at least two blocks with last block containing remaining |
193 | // code, so end block after instrumentation, create remainder block, and |
194 | // branch to it |
195 | uint32_t rem_blk_id = TakeNextId(); |
196 | std::unique_ptr<Instruction> rem_label(NewLabel(rem_blk_id)); |
197 | BasicBlock* back_blk_ptr = &*new_blocks->back(); |
198 | InstructionBuilder builder( |
199 | context(), back_blk_ptr, |
200 | IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping); |
201 | (void)builder.AddBranch(rem_blk_id); |
202 | // Gen remainder block |
203 | new_blk_ptr.reset(new BasicBlock(std::move(rem_label))); |
204 | builder.SetInsertPoint(&*new_blk_ptr); |
205 | // Move original block's remaining code into remainder block and add |
206 | // to new blocks |
207 | MovePostludeCode(ref_block_itr, &*new_blk_ptr); |
208 | new_blocks->push_back(std::move(new_blk_ptr)); |
209 | } |
210 | |
211 | void InstDebugPrintfPass::InitializeInstDebugPrintf() { |
212 | // Initialize base class |
213 | InitializeInstrument(); |
214 | } |
215 | |
216 | Pass::Status InstDebugPrintfPass::ProcessImpl() { |
217 | // Perform printf instrumentation on each entry point function in module |
218 | InstProcessFunction pfn = |
219 | [this](BasicBlock::iterator ref_inst_itr, |
220 | UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx, |
221 | std::vector<std::unique_ptr<BasicBlock>>* new_blocks) { |
222 | return GenDebugPrintfCode(ref_inst_itr, ref_block_itr, stage_idx, |
223 | new_blocks); |
224 | }; |
225 | (void)InstProcessEntryPointCallTree(pfn); |
226 | // Remove DebugPrintf OpExtInstImport instruction |
227 | Instruction* ext_inst_import_inst = |
228 | get_def_use_mgr()->GetDef(ext_inst_printf_id_); |
229 | context()->KillInst(ext_inst_import_inst); |
230 | // If no remaining non-semantic instruction sets, remove non-semantic debug |
231 | // info extension from module and feature manager |
232 | bool non_sem_set_seen = false; |
233 | for (auto c_itr = context()->module()->ext_inst_import_begin(); |
234 | c_itr != context()->module()->ext_inst_import_end(); ++c_itr) { |
235 | const std::string set_name = c_itr->GetInOperand(0).AsString(); |
236 | if (spvtools::utils::starts_with(set_name, "NonSemantic." )) { |
237 | non_sem_set_seen = true; |
238 | break; |
239 | } |
240 | } |
241 | if (!non_sem_set_seen) { |
242 | for (auto c_itr = context()->module()->extension_begin(); |
243 | c_itr != context()->module()->extension_end(); ++c_itr) { |
244 | const std::string ext_name = c_itr->GetInOperand(0).AsString(); |
245 | if (ext_name == "SPV_KHR_non_semantic_info" ) { |
246 | context()->KillInst(&*c_itr); |
247 | break; |
248 | } |
249 | } |
250 | context()->get_feature_mgr()->RemoveExtension(kSPV_KHR_non_semantic_info); |
251 | } |
252 | return Status::SuccessWithChange; |
253 | } |
254 | |
255 | Pass::Status InstDebugPrintfPass::Process() { |
256 | ext_inst_printf_id_ = |
257 | get_module()->GetExtInstImportId("NonSemantic.DebugPrintf" ); |
258 | if (ext_inst_printf_id_ == 0) return Status::SuccessWithoutChange; |
259 | InitializeInstDebugPrintf(); |
260 | return ProcessImpl(); |
261 | } |
262 | |
263 | } // namespace opt |
264 | } // namespace spvtools |
265 | |