1// Copyright (c) 2015-2016 The Khronos Group Inc.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "source/assembly_grammar.h"
16
17#include <algorithm>
18#include <cassert>
19#include <cstring>
20
21#include "source/ext_inst.h"
22#include "source/opcode.h"
23#include "source/operand.h"
24#include "source/table.h"
25
26namespace spvtools {
27namespace {
28
29/// @brief Parses a mask expression string for the given operand type.
30///
31/// A mask expression is a sequence of one or more terms separated by '|',
32/// where each term a named enum value for the given type. No whitespace
33/// is permitted.
34///
35/// On success, the value is written to pValue.
36///
37/// @param[in] operandTable operand lookup table
38/// @param[in] type of the operand
39/// @param[in] textValue word of text to be parsed
40/// @param[out] pValue where the resulting value is written
41///
42/// @return result code
43spv_result_t spvTextParseMaskOperand(spv_target_env env,
44 const spv_operand_table operandTable,
45 const spv_operand_type_t type,
46 const char* textValue, uint32_t* pValue) {
47 if (textValue == nullptr) return SPV_ERROR_INVALID_TEXT;
48 size_t text_length = strlen(textValue);
49 if (text_length == 0) return SPV_ERROR_INVALID_TEXT;
50 const char* text_end = textValue + text_length;
51
52 // We only support mask expressions in ASCII, so the separator value is a
53 // char.
54 const char separator = '|';
55
56 // Accumulate the result by interpreting one word at a time, scanning
57 // from left to right.
58 uint32_t value = 0;
59 const char* begin = textValue; // The left end of the current word.
60 const char* end = nullptr; // One character past the end of the current word.
61 do {
62 end = std::find(begin, text_end, separator);
63
64 spv_operand_desc entry = nullptr;
65 if (auto error = spvOperandTableNameLookup(env, operandTable, type, begin,
66 end - begin, &entry)) {
67 return error;
68 }
69 value |= entry->value;
70
71 // Advance to the next word by skipping over the separator.
72 begin = end + 1;
73 } while (end != text_end);
74
75 *pValue = value;
76 return SPV_SUCCESS;
77}
78
79// Associates an opcode with its name.
80struct SpecConstantOpcodeEntry {
81 SpvOp opcode;
82 const char* name;
83};
84
// All the opcodes allowed as the operation for OpSpecConstantOp.
// The name does not have the usual "Op" prefix. For example opcode SpvOpIAdd
// is associated with the name "IAdd".
//
// Both lookupSpecConstantOpcode() overloads search this table linearly, so
// entry order does not affect results as long as names stay unique.
//
// clang-format off
// CASE(Foo) expands to the initializer { SpvOpFoo, "Foo" }, keeping the enum
// value and its string name in sync from a single token.
#define CASE(NAME) { SpvOp##NAME, #NAME }
const SpecConstantOpcodeEntry kOpSpecConstantOpcodes[] = {
    // Conversion
    CASE(SConvert),
    CASE(FConvert),
    CASE(ConvertFToS),
    CASE(ConvertSToF),
    CASE(ConvertFToU),
    CASE(ConvertUToF),
    CASE(UConvert),
    CASE(ConvertPtrToU),
    CASE(ConvertUToPtr),
    CASE(GenericCastToPtr),
    CASE(PtrCastToGeneric),
    CASE(Bitcast),
    CASE(QuantizeToF16),
    // Arithmetic
    CASE(SNegate),
    CASE(Not),
    CASE(IAdd),
    CASE(ISub),
    CASE(IMul),
    CASE(UDiv),
    CASE(SDiv),
    CASE(UMod),
    CASE(SRem),
    CASE(SMod),
    CASE(ShiftRightLogical),
    CASE(ShiftRightArithmetic),
    CASE(ShiftLeftLogical),
    CASE(BitwiseOr),
    CASE(BitwiseAnd),
    CASE(BitwiseXor),
    CASE(FNegate),
    CASE(FAdd),
    CASE(FSub),
    CASE(FMul),
    CASE(FDiv),
    CASE(FRem),
    CASE(FMod),
    // Composite
    CASE(VectorShuffle),
    CASE(CompositeExtract),
    CASE(CompositeInsert),
    // Logical
    CASE(LogicalOr),
    CASE(LogicalAnd),
    CASE(LogicalNot),
    CASE(LogicalEqual),
    CASE(LogicalNotEqual),
    CASE(Select),
    // Comparison
    CASE(IEqual),
    CASE(INotEqual),
    CASE(ULessThan),
    CASE(SLessThan),
    CASE(UGreaterThan),
    CASE(SGreaterThan),
    CASE(ULessThanEqual),
    CASE(SLessThanEqual),
    CASE(UGreaterThanEqual),
    CASE(SGreaterThanEqual),
    // Memory
    CASE(AccessChain),
    CASE(InBoundsAccessChain),
    CASE(PtrAccessChain),
    CASE(InBoundsPtrAccessChain),
    // NOTE(review): the "NV" suffix suggests this entry comes from a vendor
    // extension rather than the core opcode list; confirm when updating.
    CASE(CooperativeMatrixLengthNV)
};

// The 60 is determined by counting the opcodes listed in the spec, plus any
// extension opcodes appended above. Update it whenever entries are added or
// removed. Because the check is an equality, it fires on both missing and
// extra entries, not only when the table is incomplete.
static_assert(60 == sizeof(kOpSpecConstantOpcodes)/sizeof(kOpSpecConstantOpcodes[0]),
              "OpSpecConstantOp opcode table is incomplete");
// CASE is only needed to build the table above; undefine it so it cannot
// collide with other uses of the name later in this translation unit.
#undef CASE
// clang-format on
165
166const size_t kNumOpSpecConstantOpcodes =
167 sizeof(kOpSpecConstantOpcodes) / sizeof(kOpSpecConstantOpcodes[0]);
168
169} // namespace
170
171bool AssemblyGrammar::isValid() const {
172 return operandTable_ && opcodeTable_ && extInstTable_;
173}
174
175CapabilitySet AssemblyGrammar::filterCapsAgainstTargetEnv(
176 const SpvCapability* cap_array, uint32_t count) const {
177 CapabilitySet cap_set;
178 for (uint32_t i = 0; i < count; ++i) {
179 spv_operand_desc cap_desc = {};
180 if (SPV_SUCCESS == lookupOperand(SPV_OPERAND_TYPE_CAPABILITY,
181 static_cast<uint32_t>(cap_array[i]),
182 &cap_desc)) {
183 // spvOperandTableValueLookup() filters capabilities internally
184 // according to the current target environment by itself. So we
185 // should be safe to add this capability if the lookup succeeds.
186 cap_set.Add(cap_array[i]);
187 }
188 }
189 return cap_set;
190}
191
192spv_result_t AssemblyGrammar::lookupOpcode(const char* name,
193 spv_opcode_desc* desc) const {
194 return spvOpcodeTableNameLookup(target_env_, opcodeTable_, name, desc);
195}
196
197spv_result_t AssemblyGrammar::lookupOpcode(SpvOp opcode,
198 spv_opcode_desc* desc) const {
199 return spvOpcodeTableValueLookup(target_env_, opcodeTable_, opcode, desc);
200}
201
202spv_result_t AssemblyGrammar::lookupOperand(spv_operand_type_t type,
203 const char* name, size_t name_len,
204 spv_operand_desc* desc) const {
205 return spvOperandTableNameLookup(target_env_, operandTable_, type, name,
206 name_len, desc);
207}
208
209spv_result_t AssemblyGrammar::lookupOperand(spv_operand_type_t type,
210 uint32_t operand,
211 spv_operand_desc* desc) const {
212 return spvOperandTableValueLookup(target_env_, operandTable_, type, operand,
213 desc);
214}
215
216spv_result_t AssemblyGrammar::lookupSpecConstantOpcode(const char* name,
217 SpvOp* opcode) const {
218 const auto* last = kOpSpecConstantOpcodes + kNumOpSpecConstantOpcodes;
219 const auto* found =
220 std::find_if(kOpSpecConstantOpcodes, last,
221 [name](const SpecConstantOpcodeEntry& entry) {
222 return 0 == strcmp(name, entry.name);
223 });
224 if (found == last) return SPV_ERROR_INVALID_LOOKUP;
225 *opcode = found->opcode;
226 return SPV_SUCCESS;
227}
228
229spv_result_t AssemblyGrammar::lookupSpecConstantOpcode(SpvOp opcode) const {
230 const auto* last = kOpSpecConstantOpcodes + kNumOpSpecConstantOpcodes;
231 const auto* found =
232 std::find_if(kOpSpecConstantOpcodes, last,
233 [opcode](const SpecConstantOpcodeEntry& entry) {
234 return opcode == entry.opcode;
235 });
236 if (found == last) return SPV_ERROR_INVALID_LOOKUP;
237 return SPV_SUCCESS;
238}
239
240spv_result_t AssemblyGrammar::parseMaskOperand(const spv_operand_type_t type,
241 const char* textValue,
242 uint32_t* pValue) const {
243 return spvTextParseMaskOperand(target_env_, operandTable_, type, textValue,
244 pValue);
245}
246spv_result_t AssemblyGrammar::lookupExtInst(spv_ext_inst_type_t type,
247 const char* textValue,
248 spv_ext_inst_desc* extInst) const {
249 return spvExtInstTableNameLookup(extInstTable_, type, textValue, extInst);
250}
251
252spv_result_t AssemblyGrammar::lookupExtInst(spv_ext_inst_type_t type,
253 uint32_t firstWord,
254 spv_ext_inst_desc* extInst) const {
255 return spvExtInstTableValueLookup(extInstTable_, type, firstWord, extInst);
256}
257
258void AssemblyGrammar::pushOperandTypesForMask(
259 const spv_operand_type_t type, const uint32_t mask,
260 spv_operand_pattern_t* pattern) const {
261 spvPushOperandTypesForMask(target_env_, operandTable_, type, mask, pattern);
262}
263
264} // namespace spvtools
265