1// Copyright (c) 2020 The Khronos Group Inc.
2// Copyright (c) 2020 Valve Corporation
3// Copyright (c) 2020 LunarG Inc.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16
17#include "inst_debug_printf_pass.h"
18
19#include "source/util/string_utils.h"
20#include "spirv/unified1/NonSemanticDebugPrintf.h"
21
22namespace spvtools {
23namespace opt {
24
25void InstDebugPrintfPass::GenOutputValues(Instruction* val_inst,
26 std::vector<uint32_t>* val_ids,
27 InstructionBuilder* builder) {
28 uint32_t val_ty_id = val_inst->type_id();
29 analysis::TypeManager* type_mgr = context()->get_type_mgr();
30 analysis::Type* val_ty = type_mgr->GetType(val_ty_id);
31 switch (val_ty->kind()) {
32 case analysis::Type::kVector: {
33 analysis::Vector* v_ty = val_ty->AsVector();
34 const analysis::Type* c_ty = v_ty->element_type();
35 uint32_t c_ty_id = type_mgr->GetId(c_ty);
36 for (uint32_t c = 0; c < v_ty->element_count(); ++c) {
37 Instruction* c_inst = builder->AddIdLiteralOp(
38 c_ty_id, SpvOpCompositeExtract, val_inst->result_id(), c);
39 GenOutputValues(c_inst, val_ids, builder);
40 }
41 return;
42 }
43 case analysis::Type::kBool: {
44 // Select between uint32 zero or one
45 uint32_t zero_id = builder->GetUintConstantId(0);
46 uint32_t one_id = builder->GetUintConstantId(1);
47 Instruction* sel_inst = builder->AddTernaryOp(
48 GetUintId(), SpvOpSelect, val_inst->result_id(), one_id, zero_id);
49 val_ids->push_back(sel_inst->result_id());
50 return;
51 }
52 case analysis::Type::kFloat: {
53 analysis::Float* f_ty = val_ty->AsFloat();
54 switch (f_ty->width()) {
55 case 16: {
56 // Convert float16 to float32 and recurse
57 Instruction* f32_inst = builder->AddUnaryOp(
58 GetFloatId(), SpvOpFConvert, val_inst->result_id());
59 GenOutputValues(f32_inst, val_ids, builder);
60 return;
61 }
62 case 64: {
63 // Bitcast float64 to uint64 and recurse
64 Instruction* ui64_inst = builder->AddUnaryOp(
65 GetUint64Id(), SpvOpBitcast, val_inst->result_id());
66 GenOutputValues(ui64_inst, val_ids, builder);
67 return;
68 }
69 case 32: {
70 // Bitcase float32 to uint32
71 Instruction* bc_inst = builder->AddUnaryOp(GetUintId(), SpvOpBitcast,
72 val_inst->result_id());
73 val_ids->push_back(bc_inst->result_id());
74 return;
75 }
76 default:
77 assert(false && "unsupported float width");
78 return;
79 }
80 }
81 case analysis::Type::kInteger: {
82 analysis::Integer* i_ty = val_ty->AsInteger();
83 switch (i_ty->width()) {
84 case 64: {
85 Instruction* ui64_inst = val_inst;
86 if (i_ty->IsSigned()) {
87 // Bitcast sint64 to uint64
88 ui64_inst = builder->AddUnaryOp(GetUint64Id(), SpvOpBitcast,
89 val_inst->result_id());
90 }
91 // Break uint64 into 2x uint32
92 Instruction* lo_ui64_inst = builder->AddUnaryOp(
93 GetUintId(), SpvOpUConvert, ui64_inst->result_id());
94 Instruction* rshift_ui64_inst = builder->AddBinaryOp(
95 GetUint64Id(), SpvOpShiftRightLogical, ui64_inst->result_id(),
96 builder->GetUintConstantId(32));
97 Instruction* hi_ui64_inst = builder->AddUnaryOp(
98 GetUintId(), SpvOpUConvert, rshift_ui64_inst->result_id());
99 val_ids->push_back(lo_ui64_inst->result_id());
100 val_ids->push_back(hi_ui64_inst->result_id());
101 return;
102 }
103 case 8: {
104 Instruction* ui8_inst = val_inst;
105 if (i_ty->IsSigned()) {
106 // Bitcast sint8 to uint8
107 ui8_inst = builder->AddUnaryOp(GetUint8Id(), SpvOpBitcast,
108 val_inst->result_id());
109 }
110 // Convert uint8 to uint32
111 Instruction* ui32_inst = builder->AddUnaryOp(
112 GetUintId(), SpvOpUConvert, ui8_inst->result_id());
113 val_ids->push_back(ui32_inst->result_id());
114 return;
115 }
116 case 32: {
117 Instruction* ui32_inst = val_inst;
118 if (i_ty->IsSigned()) {
119 // Bitcast sint32 to uint32
120 ui32_inst = builder->AddUnaryOp(GetUintId(), SpvOpBitcast,
121 val_inst->result_id());
122 }
123 // uint32 needs no further processing
124 val_ids->push_back(ui32_inst->result_id());
125 return;
126 }
127 default:
128 // TODO(greg-lunarg): Support non-32-bit int
129 assert(false && "unsupported int width");
130 return;
131 }
132 }
133 default:
134 assert(false && "unsupported type");
135 return;
136 }
137}
138
139void InstDebugPrintfPass::GenOutputCode(
140 Instruction* printf_inst, uint32_t stage_idx,
141 std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
142 BasicBlock* back_blk_ptr = &*new_blocks->back();
143 InstructionBuilder builder(
144 context(), back_blk_ptr,
145 IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
146 // Gen debug printf record validation-specific values. The format string
147 // will have its id written. Vectors will need to be broken down into
148 // component values. float16 will need to be converted to float32. Pointer
149 // and uint64 will need to be converted to two uint32 values. float32 will
150 // need to be bitcast to uint32. int32 will need to be bitcast to uint32.
151 std::vector<uint32_t> val_ids;
152 bool is_first_operand = false;
153 printf_inst->ForEachInId(
154 [&is_first_operand, &val_ids, &builder, this](const uint32_t* iid) {
155 // skip set operand
156 if (!is_first_operand) {
157 is_first_operand = true;
158 return;
159 }
160 Instruction* opnd_inst = get_def_use_mgr()->GetDef(*iid);
161 if (opnd_inst->opcode() == SpvOpString) {
162 uint32_t string_id_id = builder.GetUintConstantId(*iid);
163 val_ids.push_back(string_id_id);
164 } else {
165 GenOutputValues(opnd_inst, &val_ids, &builder);
166 }
167 });
168 GenDebugStreamWrite(uid2offset_[printf_inst->unique_id()], stage_idx, val_ids,
169 &builder);
170 context()->KillInst(printf_inst);
171}
172
173void InstDebugPrintfPass::GenDebugPrintfCode(
174 BasicBlock::iterator ref_inst_itr,
175 UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
176 std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
177 // If not DebugPrintf OpExtInst, return.
178 Instruction* printf_inst = &*ref_inst_itr;
179 if (printf_inst->opcode() != SpvOpExtInst) return;
180 if (printf_inst->GetSingleWordInOperand(0) != ext_inst_printf_id_) return;
181 if (printf_inst->GetSingleWordInOperand(1) !=
182 NonSemanticDebugPrintfDebugPrintf)
183 return;
184 // Initialize DefUse manager before dismantling module
185 (void)get_def_use_mgr();
186 // Move original block's preceding instructions into first new block
187 std::unique_ptr<BasicBlock> new_blk_ptr;
188 MovePreludeCode(ref_inst_itr, ref_block_itr, &new_blk_ptr);
189 new_blocks->push_back(std::move(new_blk_ptr));
190 // Generate instructions to output printf args to printf buffer
191 GenOutputCode(printf_inst, stage_idx, new_blocks);
192 // Caller expects at least two blocks with last block containing remaining
193 // code, so end block after instrumentation, create remainder block, and
194 // branch to it
195 uint32_t rem_blk_id = TakeNextId();
196 std::unique_ptr<Instruction> rem_label(NewLabel(rem_blk_id));
197 BasicBlock* back_blk_ptr = &*new_blocks->back();
198 InstructionBuilder builder(
199 context(), back_blk_ptr,
200 IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
201 (void)builder.AddBranch(rem_blk_id);
202 // Gen remainder block
203 new_blk_ptr.reset(new BasicBlock(std::move(rem_label)));
204 builder.SetInsertPoint(&*new_blk_ptr);
205 // Move original block's remaining code into remainder block and add
206 // to new blocks
207 MovePostludeCode(ref_block_itr, &*new_blk_ptr);
208 new_blocks->push_back(std::move(new_blk_ptr));
209}
210
211void InstDebugPrintfPass::InitializeInstDebugPrintf() {
212 // Initialize base class
213 InitializeInstrument();
214}
215
216Pass::Status InstDebugPrintfPass::ProcessImpl() {
217 // Perform printf instrumentation on each entry point function in module
218 InstProcessFunction pfn =
219 [this](BasicBlock::iterator ref_inst_itr,
220 UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t stage_idx,
221 std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
222 return GenDebugPrintfCode(ref_inst_itr, ref_block_itr, stage_idx,
223 new_blocks);
224 };
225 (void)InstProcessEntryPointCallTree(pfn);
226 // Remove DebugPrintf OpExtInstImport instruction
227 Instruction* ext_inst_import_inst =
228 get_def_use_mgr()->GetDef(ext_inst_printf_id_);
229 context()->KillInst(ext_inst_import_inst);
230 // If no remaining non-semantic instruction sets, remove non-semantic debug
231 // info extension from module and feature manager
232 bool non_sem_set_seen = false;
233 for (auto c_itr = context()->module()->ext_inst_import_begin();
234 c_itr != context()->module()->ext_inst_import_end(); ++c_itr) {
235 const std::string set_name = c_itr->GetInOperand(0).AsString();
236 if (spvtools::utils::starts_with(set_name, "NonSemantic.")) {
237 non_sem_set_seen = true;
238 break;
239 }
240 }
241 if (!non_sem_set_seen) {
242 for (auto c_itr = context()->module()->extension_begin();
243 c_itr != context()->module()->extension_end(); ++c_itr) {
244 const std::string ext_name = c_itr->GetInOperand(0).AsString();
245 if (ext_name == "SPV_KHR_non_semantic_info") {
246 context()->KillInst(&*c_itr);
247 break;
248 }
249 }
250 context()->get_feature_mgr()->RemoveExtension(kSPV_KHR_non_semantic_info);
251 }
252 return Status::SuccessWithChange;
253}
254
255Pass::Status InstDebugPrintfPass::Process() {
256 ext_inst_printf_id_ =
257 get_module()->GetExtInstImportId("NonSemantic.DebugPrintf");
258 if (ext_inst_printf_id_ == 0) return Status::SuccessWithoutChange;
259 InitializeInstDebugPrintf();
260 return ProcessImpl();
261}
262
263} // namespace opt
264} // namespace spvtools
265