/*
 * Copyright 2016-2021 The Brenwill Workshop Ltd.
 * SPDX-License-Identifier: Apache-2.0 OR MIT
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/*
 * At your option, you may choose to accept this material under either:
 *  1. The Apache License, Version 2.0, found at <http://www.apache.org/licenses/LICENSE-2.0>, or
 *  2. The MIT License, found at <http://opensource.org/licenses/MIT>.
 */

#include "spirv_msl.hpp"
#include "GLSL.std.450.h"

#include <algorithm>
#include <assert.h>
#include <numeric>

using namespace spv;
using namespace SPIRV_CROSS_NAMESPACE;
using namespace std;

static const uint32_t k_unknown_location = ~0u;
static const uint32_t k_unknown_component = ~0u;
static const char *force_inline = "static inline __attribute__((always_inline))";

CompilerMSL::CompilerMSL(std::vector<uint32_t> spirv_)
    : CompilerGLSL(std::move(spirv_))
{
}

CompilerMSL::CompilerMSL(const uint32_t *ir_, size_t word_count)
    : CompilerGLSL(ir_, word_count)
{
}

CompilerMSL::CompilerMSL(const ParsedIR &ir_)
    : CompilerGLSL(ir_)
{
}

CompilerMSL::CompilerMSL(ParsedIR &&ir_)
    : CompilerGLSL(std::move(ir_))
{
}

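// Records an input interface binding supplied by the API client. Inputs are keyed
// by (location, component); if the input also names a builtin, index it by builtin
// as well so it can be found when the SPIR-V refers to it that way.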
void CompilerMSL::add_msl_shader_input(const MSLShaderInput &si)
{
	inputs_by_location[{si.location, si.component}] = si;
	if (si.builtin != BuiltInMax && !inputs_by_builtin.count(si.builtin))
		inputs_by_builtin[si.builtin] = si;
}

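// Records an explicit resource binding supplied by the API client, mapping a SPIR-V
// (stage, descriptor set, binding) tuple to MSL buffer/texture/sampler indices.
// The bool in the map value tracks whether the shader actually used the binding.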
void CompilerMSL::add_msl_resource_binding(const MSLResourceBinding &binding)
{
	StageSetBinding tuple = { binding.stage, binding.desc_set, binding.binding };
	resource_bindings[tuple] = { binding, false };

	// If we might need to pad argument buffer members to positionally align
	// arg buffer indexes, also maintain a lookup by argument buffer index.
	if (msl_options.pad_argument_buffer_resources)
	{
		StageSetBinding arg_idx_tuple = { binding.stage, binding.desc_set, k_unknown_component };

#define ADD_ARG_IDX_TO_BINDING_NUM_LOOKUP(rez) \
	arg_idx_tuple.binding = binding.msl_##rez; \
	resource_arg_buff_idx_to_binding_number[arg_idx_tuple] = binding.binding

		switch (binding.basetype)
		{
		case SPIRType::Void:
		case SPIRType::Boolean:
		case SPIRType::SByte:
		case SPIRType::UByte:
		case SPIRType::Short:
		case SPIRType::UShort:
		case SPIRType::Int:
		case SPIRType::UInt:
		case SPIRType::Int64:
		case SPIRType::UInt64:
		case SPIRType::AtomicCounter:
		case SPIRType::Half:
		case SPIRType::Float:
		case SPIRType::Double:
			ADD_ARG_IDX_TO_BINDING_NUM_LOOKUP(buffer);
			break;
		case SPIRType::Image:
			ADD_ARG_IDX_TO_BINDING_NUM_LOOKUP(texture);
			break;
		case SPIRType::Sampler:
			ADD_ARG_IDX_TO_BINDING_NUM_LOOKUP(sampler);
			break;
		case SPIRType::SampledImage:
			ADD_ARG_IDX_TO_BINDING_NUM_LOOKUP(texture);
			ADD_ARG_IDX_TO_BINDING_NUM_LOOKUP(sampler);
			break;
		default:
			SPIRV_CROSS_THROW("Unexpected argument buffer resource base type. When padding argument buffer elements, "
			                  "all descriptor set resources must be supplied with a base type by the app.");
		}
#undef ADD_ARG_IDX_TO_BINDING_NUM_LOOKUP
	}
}

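// Marks the buffer at (desc_set, binding) as using a dynamic offset, fetched at
// the given index of the internal dynamic-offsets buffer at runtime.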
void CompilerMSL::add_dynamic_buffer(uint32_t desc_set, uint32_t binding, uint32_t index)
{
	SetBindingPair pair = { desc_set, binding };
	buffers_requiring_dynamic_offset[pair] = { index, 0 };
}

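// Marks the resource at (desc_set, binding) as an inline uniform block, whose data
// is embedded directly into the argument buffer instead of being referenced.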
void CompilerMSL::add_inline_uniform_block(uint32_t desc_set, uint32_t binding)
{
	SetBindingPair pair = { desc_set, binding };
	inline_uniform_blocks.insert(pair);
}

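// Excludes the given descriptor set from argument buffer translation; its resources
// will instead be bound as discrete resources.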
void CompilerMSL::add_discrete_descriptor_set(uint32_t desc_set)
{
	if (desc_set < kMaxArgumentBuffers)
		argument_buffer_discrete_mask |= 1u << desc_set;
}

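// Selects whether the argument buffer for desc_set is declared in the device
// (writable) or constant (read-only) address space.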
void CompilerMSL::set_argument_buffer_device_address_space(uint32_t desc_set, bool device_storage)
{
	if (desc_set < kMaxArgumentBuffers)
	{
		if (device_storage)
			argument_buffer_device_storage_mask |= 1u << desc_set;
		else
			argument_buffer_device_storage_mask &= ~(1u << desc_set);
	}
}

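// Returns whether the shader consumed the input at the given location.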
bool CompilerMSL::is_msl_shader_input_used(uint32_t location)
{
	// Don't report internal location allocations to app.
	return location_inputs_in_use.count(location) != 0 &&
	       location_inputs_in_use_fallback.count(location) == 0;
}

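// Returns the input location that was automatically assigned to the given builtin,
// or k_unknown_location if none was assigned.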
uint32_t CompilerMSL::get_automatic_builtin_input_location(spv::BuiltIn builtin) const
{
	auto itr = builtin_to_automatic_input_location.find(builtin);
	if (itr == builtin_to_automatic_input_location.end())
		return k_unknown_location;
	else
		return itr->second;
}

bool CompilerMSL::is_msl_resource_binding_used(ExecutionModel model, uint32_t desc_set, uint32_t binding) const
{
	StageSetBinding tuple = { model, desc_set, binding };
	auto itr = resource_bindings.find(tuple);
	return itr != end(resource_bindings) && itr->second.second;
}

// Returns the size of the array of resources used by the variable with the specified id.
// The returned value is retrieved from the resource binding added using add_msl_resource_binding().
uint32_t CompilerMSL::get_resource_array_size(uint32_t id) const
{
	StageSetBinding tuple = { get_entry_point().model, get_decoration(id, DecorationDescriptorSet),
		                      get_decoration(id, DecorationBinding) };
	auto itr = resource_bindings.find(tuple);
	return itr != end(resource_bindings) ? itr->second.first.count : 0;
}

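// The "automatic" resource bindings report where the compiler placed a resource,
// i.e. its [[buffer(n)]], [[texture(n)]] or [[sampler(n)]] index. The secondary,
// tertiary and quaternary indices are used when a single SPIR-V resource expands
// to multiple MSL resources, e.g. the sampler of a combined image sampler or
// auxiliary swizzle- and buffer-size data.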
uint32_t CompilerMSL::get_automatic_msl_resource_binding(uint32_t id) const
{
	return get_extended_decoration(id, SPIRVCrossDecorationResourceIndexPrimary);
}

uint32_t CompilerMSL::get_automatic_msl_resource_binding_secondary(uint32_t id) const
{
	return get_extended_decoration(id, SPIRVCrossDecorationResourceIndexSecondary);
}

uint32_t CompilerMSL::get_automatic_msl_resource_binding_tertiary(uint32_t id) const
{
	return get_extended_decoration(id, SPIRVCrossDecorationResourceIndexTertiary);
}

uint32_t CompilerMSL::get_automatic_msl_resource_binding_quaternary(uint32_t id) const
{
	return get_extended_decoration(id, SPIRVCrossDecorationResourceIndexQuaternary);
}

void CompilerMSL::set_fragment_output_components(uint32_t location, uint32_t components)
{
	fragment_output_components[location] = components;
}

bool CompilerMSL::builtin_translates_to_nonarray(spv::BuiltIn builtin) const
{
	return (builtin == BuiltInSampleMask);
}

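// Creates any implicit builtins and auxiliary variables the generated MSL needs
// but the SPIR-V did not declare: builtin inputs such as gl_FragCoord or gl_SampleID
// required by Metal features, an output gl_Position for vertex-like stages that
// omit it, and internal buffers for swizzle constants, buffer sizes, view masks
// and dynamic offsets.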
void CompilerMSL::build_implicit_builtins()
{
	bool need_sample_pos = active_input_builtins.get(BuiltInSamplePosition);
	bool need_vertex_params = capture_output_to_buffer && get_execution_model() == ExecutionModelVertex &&
	                          !msl_options.vertex_for_tessellation;
	bool need_tesc_params = get_execution_model() == ExecutionModelTessellationControl;
	bool need_subgroup_mask =
	    active_input_builtins.get(BuiltInSubgroupEqMask) || active_input_builtins.get(BuiltInSubgroupGeMask) ||
	    active_input_builtins.get(BuiltInSubgroupGtMask) || active_input_builtins.get(BuiltInSubgroupLeMask) ||
	    active_input_builtins.get(BuiltInSubgroupLtMask);
	bool need_subgroup_ge_mask = !msl_options.is_ios() && (active_input_builtins.get(BuiltInSubgroupGeMask) ||
	                                                       active_input_builtins.get(BuiltInSubgroupGtMask));
	bool need_multiview = get_execution_model() == ExecutionModelVertex && !msl_options.view_index_from_device_index &&
	                      msl_options.multiview_layered_rendering &&
	                      (msl_options.multiview || active_input_builtins.get(BuiltInViewIndex));
	bool need_dispatch_base =
	    msl_options.dispatch_base && get_execution_model() == ExecutionModelGLCompute &&
	    (active_input_builtins.get(BuiltInWorkgroupId) || active_input_builtins.get(BuiltInGlobalInvocationId));
	bool need_grid_params = get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation;
	bool need_vertex_base_params =
	    need_grid_params &&
	    (active_input_builtins.get(BuiltInVertexId) || active_input_builtins.get(BuiltInVertexIndex) ||
	     active_input_builtins.get(BuiltInBaseVertex) || active_input_builtins.get(BuiltInInstanceId) ||
	     active_input_builtins.get(BuiltInInstanceIndex) || active_input_builtins.get(BuiltInBaseInstance));
	bool need_local_invocation_index = msl_options.emulate_subgroups && active_input_builtins.get(BuiltInSubgroupId);
	bool need_workgroup_size = msl_options.emulate_subgroups && active_input_builtins.get(BuiltInNumSubgroups);

	if (need_subpass_input || need_sample_pos || need_subgroup_mask || need_vertex_params || need_tesc_params ||
	    need_multiview || need_dispatch_base || need_vertex_base_params || need_grid_params || needs_sample_id ||
	    needs_subgroup_invocation_id || needs_subgroup_size || has_additional_fixed_sample_mask() ||
	    need_local_invocation_index || need_workgroup_size)
	{
		bool has_frag_coord = false;
		bool has_sample_id = false;
		bool has_vertex_idx = false;
		bool has_base_vertex = false;
		bool has_instance_idx = false;
		bool has_base_instance = false;
		bool has_invocation_id = false;
		bool has_primitive_id = false;
		bool has_subgroup_invocation_id = false;
		bool has_subgroup_size = false;
		bool has_view_idx = false;
		bool has_layer = false;
		bool has_local_invocation_index = false;
		bool has_workgroup_size = false;
		uint32_t workgroup_id_type = 0;

		ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
			if (var.storage != StorageClassInput && var.storage != StorageClassOutput)
				return;
			if (!interface_variable_exists_in_entry_point(var.self))
				return;
			if (!has_decoration(var.self, DecorationBuiltIn))
				return;

			BuiltIn builtin = ir.meta[var.self].decoration.builtin_type;

			if (var.storage == StorageClassOutput)
			{
				if (has_additional_fixed_sample_mask() && builtin == BuiltInSampleMask)
				{
					builtin_sample_mask_id = var.self;
					mark_implicit_builtin(StorageClassOutput, BuiltInSampleMask, var.self);
					does_shader_write_sample_mask = true;
				}
			}

			if (var.storage != StorageClassInput)
				return;

			// Use Metal's native frame-buffer fetch API for subpass inputs.
			if (need_subpass_input && (!msl_options.use_framebuffer_fetch_subpasses))
			{
				switch (builtin)
				{
				case BuiltInFragCoord:
					mark_implicit_builtin(StorageClassInput, BuiltInFragCoord, var.self);
					builtin_frag_coord_id = var.self;
					has_frag_coord = true;
					break;
				case BuiltInLayer:
					if (!msl_options.arrayed_subpass_input || msl_options.multiview)
						break;
					mark_implicit_builtin(StorageClassInput, BuiltInLayer, var.self);
					builtin_layer_id = var.self;
					has_layer = true;
					break;
				case BuiltInViewIndex:
					if (!msl_options.multiview)
						break;
					mark_implicit_builtin(StorageClassInput, BuiltInViewIndex, var.self);
					builtin_view_idx_id = var.self;
					has_view_idx = true;
					break;
				default:
					break;
				}
			}

			if ((need_sample_pos || needs_sample_id) && builtin == BuiltInSampleId)
			{
				builtin_sample_id_id = var.self;
				mark_implicit_builtin(StorageClassInput, BuiltInSampleId, var.self);
				has_sample_id = true;
			}

			if (need_vertex_params)
			{
				switch (builtin)
				{
				case BuiltInVertexIndex:
					builtin_vertex_idx_id = var.self;
					mark_implicit_builtin(StorageClassInput, BuiltInVertexIndex, var.self);
					has_vertex_idx = true;
					break;
				case BuiltInBaseVertex:
					builtin_base_vertex_id = var.self;
					mark_implicit_builtin(StorageClassInput, BuiltInBaseVertex, var.self);
					has_base_vertex = true;
					break;
				case BuiltInInstanceIndex:
					builtin_instance_idx_id = var.self;
					mark_implicit_builtin(StorageClassInput, BuiltInInstanceIndex, var.self);
					has_instance_idx = true;
					break;
				case BuiltInBaseInstance:
					builtin_base_instance_id = var.self;
					mark_implicit_builtin(StorageClassInput, BuiltInBaseInstance, var.self);
					has_base_instance = true;
					break;
				default:
					break;
				}
			}

			if (need_tesc_params)
			{
				switch (builtin)
				{
				case BuiltInInvocationId:
					builtin_invocation_id_id = var.self;
					mark_implicit_builtin(StorageClassInput, BuiltInInvocationId, var.self);
					has_invocation_id = true;
					break;
				case BuiltInPrimitiveId:
					builtin_primitive_id_id = var.self;
					mark_implicit_builtin(StorageClassInput, BuiltInPrimitiveId, var.self);
					has_primitive_id = true;
					break;
				default:
					break;
				}
			}

			if ((need_subgroup_mask || needs_subgroup_invocation_id) && builtin == BuiltInSubgroupLocalInvocationId)
			{
				builtin_subgroup_invocation_id_id = var.self;
				mark_implicit_builtin(StorageClassInput, BuiltInSubgroupLocalInvocationId, var.self);
				has_subgroup_invocation_id = true;
			}

			if ((need_subgroup_ge_mask || needs_subgroup_size) && builtin == BuiltInSubgroupSize)
			{
				builtin_subgroup_size_id = var.self;
				mark_implicit_builtin(StorageClassInput, BuiltInSubgroupSize, var.self);
				has_subgroup_size = true;
			}

			if (need_multiview)
			{
				switch (builtin)
				{
				case BuiltInInstanceIndex:
					// The view index here is derived from the instance index.
					builtin_instance_idx_id = var.self;
					mark_implicit_builtin(StorageClassInput, BuiltInInstanceIndex, var.self);
					has_instance_idx = true;
					break;
				case BuiltInBaseInstance:
					// If a non-zero base instance is used, we need to adjust for it when calculating the view index.
					builtin_base_instance_id = var.self;
					mark_implicit_builtin(StorageClassInput, BuiltInBaseInstance, var.self);
					has_base_instance = true;
					break;
				case BuiltInViewIndex:
					builtin_view_idx_id = var.self;
					mark_implicit_builtin(StorageClassInput, BuiltInViewIndex, var.self);
					has_view_idx = true;
					break;
				default:
					break;
				}
			}

			if (need_local_invocation_index && builtin == BuiltInLocalInvocationIndex)
			{
				builtin_local_invocation_index_id = var.self;
				mark_implicit_builtin(StorageClassInput, BuiltInLocalInvocationIndex, var.self);
				has_local_invocation_index = true;
			}

			if (need_workgroup_size && builtin == BuiltInWorkgroupSize)
			{
				builtin_workgroup_size_id = var.self;
				mark_implicit_builtin(StorageClassInput, BuiltInWorkgroupSize, var.self);
				has_workgroup_size = true;
			}

			// The base workgroup needs to have the same type and vector size
			// as the workgroup or invocation ID, so keep track of the type that
			// was used.
			if (need_dispatch_base && workgroup_id_type == 0 &&
			    (builtin == BuiltInWorkgroupId || builtin == BuiltInGlobalInvocationId))
				workgroup_id_type = var.basetype;
		});

		// Use Metal's native frame-buffer fetch API for subpass inputs.
		if ((!has_frag_coord || (msl_options.multiview && !has_view_idx) ||
		     (msl_options.arrayed_subpass_input && !msl_options.multiview && !has_layer)) &&
		    (!msl_options.use_framebuffer_fetch_subpasses) && need_subpass_input)
		{
			if (!has_frag_coord)
			{
				uint32_t offset = ir.increase_bound_by(3);
				uint32_t type_id = offset;
				uint32_t type_ptr_id = offset + 1;
				uint32_t var_id = offset + 2;

				// Create gl_FragCoord.
				SPIRType vec4_type;
				vec4_type.basetype = SPIRType::Float;
				vec4_type.width = 32;
				vec4_type.vecsize = 4;
				set<SPIRType>(type_id, vec4_type);

				SPIRType vec4_type_ptr;
				vec4_type_ptr = vec4_type;
				vec4_type_ptr.pointer = true;
				vec4_type_ptr.pointer_depth++;
				vec4_type_ptr.parent_type = type_id;
				vec4_type_ptr.storage = StorageClassInput;
				auto &ptr_type = set<SPIRType>(type_ptr_id, vec4_type_ptr);
				ptr_type.self = type_id;

				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInFragCoord);
				builtin_frag_coord_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInFragCoord, var_id);
			}

			if (!has_layer && msl_options.arrayed_subpass_input && !msl_options.multiview)
			{
				uint32_t offset = ir.increase_bound_by(2);
				uint32_t type_ptr_id = offset;
				uint32_t var_id = offset + 1;

				// Create gl_Layer.
				SPIRType uint_type_ptr;
				uint_type_ptr = get_uint_type();
				uint_type_ptr.pointer = true;
				uint_type_ptr.pointer_depth++;
				uint_type_ptr.parent_type = get_uint_type_id();
				uint_type_ptr.storage = StorageClassInput;
				auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
				ptr_type.self = get_uint_type_id();

				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInLayer);
				builtin_layer_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInLayer, var_id);
			}

			if (!has_view_idx && msl_options.multiview)
			{
				uint32_t offset = ir.increase_bound_by(2);
				uint32_t type_ptr_id = offset;
				uint32_t var_id = offset + 1;

				// Create gl_ViewIndex.
				SPIRType uint_type_ptr;
				uint_type_ptr = get_uint_type();
				uint_type_ptr.pointer = true;
				uint_type_ptr.pointer_depth++;
				uint_type_ptr.parent_type = get_uint_type_id();
				uint_type_ptr.storage = StorageClassInput;
				auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
				ptr_type.self = get_uint_type_id();

				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInViewIndex);
				builtin_view_idx_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInViewIndex, var_id);
			}
		}

		if (!has_sample_id && (need_sample_pos || needs_sample_id))
		{
			uint32_t offset = ir.increase_bound_by(2);
			uint32_t type_ptr_id = offset;
			uint32_t var_id = offset + 1;

			// Create gl_SampleID.
			SPIRType uint_type_ptr;
			uint_type_ptr = get_uint_type();
			uint_type_ptr.pointer = true;
			uint_type_ptr.pointer_depth++;
			uint_type_ptr.parent_type = get_uint_type_id();
			uint_type_ptr.storage = StorageClassInput;
			auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
			ptr_type.self = get_uint_type_id();

			set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
			set_decoration(var_id, DecorationBuiltIn, BuiltInSampleId);
			builtin_sample_id_id = var_id;
			mark_implicit_builtin(StorageClassInput, BuiltInSampleId, var_id);
		}

		if ((need_vertex_params && (!has_vertex_idx || !has_base_vertex || !has_instance_idx || !has_base_instance)) ||
		    (need_multiview && (!has_instance_idx || !has_base_instance || !has_view_idx)))
		{
			uint32_t type_ptr_id = ir.increase_bound_by(1);

			SPIRType uint_type_ptr;
			uint_type_ptr = get_uint_type();
			uint_type_ptr.pointer = true;
			uint_type_ptr.pointer_depth++;
			uint_type_ptr.parent_type = get_uint_type_id();
			uint_type_ptr.storage = StorageClassInput;
			auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
			ptr_type.self = get_uint_type_id();

			if (need_vertex_params && !has_vertex_idx)
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_VertexIndex.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInVertexIndex);
				builtin_vertex_idx_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInVertexIndex, var_id);
			}

			if (need_vertex_params && !has_base_vertex)
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_BaseVertex.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInBaseVertex);
				builtin_base_vertex_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInBaseVertex, var_id);
			}

			if (!has_instance_idx) // Needed by both multiview and tessellation
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_InstanceIndex.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInInstanceIndex);
				builtin_instance_idx_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInInstanceIndex, var_id);
			}

			if (!has_base_instance) // Needed by both multiview and tessellation
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_BaseInstance.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInBaseInstance);
				builtin_base_instance_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInBaseInstance, var_id);
			}

			if (need_multiview)
			{
				// Multiview shaders are not allowed to write to gl_Layer, ostensibly because
				// it is implicitly written from gl_ViewIndex, but we have to do that explicitly.
				// Note that we can't just abuse gl_ViewIndex for this purpose: it's an input, but
				// gl_Layer is an output in vertex-pipeline shaders.
				uint32_t type_ptr_out_id = ir.increase_bound_by(2);
				SPIRType uint_type_ptr_out;
				uint_type_ptr_out = get_uint_type();
				uint_type_ptr_out.pointer = true;
				uint_type_ptr_out.pointer_depth++;
				uint_type_ptr_out.parent_type = get_uint_type_id();
				uint_type_ptr_out.storage = StorageClassOutput;
				auto &ptr_out_type = set<SPIRType>(type_ptr_out_id, uint_type_ptr_out);
				ptr_out_type.self = get_uint_type_id();
				uint32_t var_id = type_ptr_out_id + 1;
				set<SPIRVariable>(var_id, type_ptr_out_id, StorageClassOutput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInLayer);
				builtin_layer_id = var_id;
				mark_implicit_builtin(StorageClassOutput, BuiltInLayer, var_id);
			}

			if (need_multiview && !has_view_idx)
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_ViewIndex.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInViewIndex);
				builtin_view_idx_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInViewIndex, var_id);
			}
		}

		if ((need_tesc_params && (msl_options.multi_patch_workgroup || !has_invocation_id || !has_primitive_id)) ||
		    need_grid_params)
		{
			uint32_t type_ptr_id = ir.increase_bound_by(1);

			SPIRType uint_type_ptr;
			uint_type_ptr = get_uint_type();
			uint_type_ptr.pointer = true;
			uint_type_ptr.pointer_depth++;
			uint_type_ptr.parent_type = get_uint_type_id();
			uint_type_ptr.storage = StorageClassInput;
			auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
			ptr_type.self = get_uint_type_id();

			if (msl_options.multi_patch_workgroup || need_grid_params)
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_GlobalInvocationID.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInGlobalInvocationId);
				builtin_invocation_id_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInGlobalInvocationId, var_id);
			}
			else if (need_tesc_params && !has_invocation_id)
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_InvocationID.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInInvocationId);
				builtin_invocation_id_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInInvocationId, var_id);
			}

			if (need_tesc_params && !has_primitive_id)
			{
				uint32_t var_id = ir.increase_bound_by(1);

				// Create gl_PrimitiveID.
				set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
				set_decoration(var_id, DecorationBuiltIn, BuiltInPrimitiveId);
				builtin_primitive_id_id = var_id;
				mark_implicit_builtin(StorageClassInput, BuiltInPrimitiveId, var_id);
			}

			if (need_grid_params)
			{
				uint32_t var_id = ir.increase_bound_by(1);

				set<SPIRVariable>(var_id, build_extended_vector_type(get_uint_type_id(), 3), StorageClassInput);
				set_extended_decoration(var_id, SPIRVCrossDecorationBuiltInStageInputSize);
				get_entry_point().interface_variables.push_back(var_id);
				set_name(var_id, "spvStageInputSize");
				builtin_stage_input_size_id = var_id;
			}
		}

		if (!has_subgroup_invocation_id && (need_subgroup_mask || needs_subgroup_invocation_id))
		{
			uint32_t offset = ir.increase_bound_by(2);
			uint32_t type_ptr_id = offset;
			uint32_t var_id = offset + 1;

			// Create gl_SubgroupInvocationID.
			SPIRType uint_type_ptr;
			uint_type_ptr = get_uint_type();
			uint_type_ptr.pointer = true;
			uint_type_ptr.pointer_depth++;
			uint_type_ptr.parent_type = get_uint_type_id();
			uint_type_ptr.storage = StorageClassInput;
			auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
			ptr_type.self = get_uint_type_id();

			set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
			set_decoration(var_id, DecorationBuiltIn, BuiltInSubgroupLocalInvocationId);
			builtin_subgroup_invocation_id_id = var_id;
			mark_implicit_builtin(StorageClassInput, BuiltInSubgroupLocalInvocationId, var_id);
		}

		if (!has_subgroup_size && (need_subgroup_ge_mask || needs_subgroup_size))
		{
			uint32_t offset = ir.increase_bound_by(2);
			uint32_t type_ptr_id = offset;
			uint32_t var_id = offset + 1;

			// Create gl_SubgroupSize.
			SPIRType uint_type_ptr;
			uint_type_ptr = get_uint_type();
			uint_type_ptr.pointer = true;
			uint_type_ptr.pointer_depth++;
			uint_type_ptr.parent_type = get_uint_type_id();
			uint_type_ptr.storage = StorageClassInput;
			auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
			ptr_type.self = get_uint_type_id();

			set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
			set_decoration(var_id, DecorationBuiltIn, BuiltInSubgroupSize);
			builtin_subgroup_size_id = var_id;
			mark_implicit_builtin(StorageClassInput, BuiltInSubgroupSize, var_id);
		}

		if (need_dispatch_base || need_vertex_base_params)
		{
			if (workgroup_id_type == 0)
				workgroup_id_type = build_extended_vector_type(get_uint_type_id(), 3);
			uint32_t var_id;
			if (msl_options.supports_msl_version(1, 2))
			{
				// If we have MSL 1.2, we can (ab)use the [[grid_origin]] builtin
				// to convey this information and save a buffer slot.
				uint32_t offset = ir.increase_bound_by(1);
				var_id = offset;

				set<SPIRVariable>(var_id, workgroup_id_type, StorageClassInput);
				set_extended_decoration(var_id, SPIRVCrossDecorationBuiltInDispatchBase);
				get_entry_point().interface_variables.push_back(var_id);
			}
			else
			{
				// Otherwise, we need to fall back to a good ol' fashioned buffer.
				uint32_t offset = ir.increase_bound_by(2);
				var_id = offset;
				uint32_t type_id = offset + 1;

				SPIRType var_type = get<SPIRType>(workgroup_id_type);
				var_type.storage = StorageClassUniform;
				set<SPIRType>(type_id, var_type);

				set<SPIRVariable>(var_id, type_id, StorageClassUniform);
				// This should never match anything.
				set_decoration(var_id, DecorationDescriptorSet, ~(5u));
				set_decoration(var_id, DecorationBinding, msl_options.indirect_params_buffer_index);
				set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary,
				                        msl_options.indirect_params_buffer_index);
			}
			set_name(var_id, "spvDispatchBase");
			builtin_dispatch_base_id = var_id;
		}

		if (has_additional_fixed_sample_mask() && !does_shader_write_sample_mask)
		{
			uint32_t offset = ir.increase_bound_by(2);
			uint32_t var_id = offset + 1;

			// Create gl_SampleMask.
			SPIRType uint_type_ptr_out;
			uint_type_ptr_out = get_uint_type();
			uint_type_ptr_out.pointer = true;
			uint_type_ptr_out.pointer_depth++;
			uint_type_ptr_out.parent_type = get_uint_type_id();
			uint_type_ptr_out.storage = StorageClassOutput;

			auto &ptr_out_type = set<SPIRType>(offset, uint_type_ptr_out);
			ptr_out_type.self = get_uint_type_id();
			set<SPIRVariable>(var_id, offset, StorageClassOutput);
			set_decoration(var_id, DecorationBuiltIn, BuiltInSampleMask);
			builtin_sample_mask_id = var_id;
			mark_implicit_builtin(StorageClassOutput, BuiltInSampleMask, var_id);
		}

		if (need_local_invocation_index && !has_local_invocation_index)
		{
			uint32_t offset = ir.increase_bound_by(2);
			uint32_t type_ptr_id = offset;
			uint32_t var_id = offset + 1;

			// Create gl_LocalInvocationIndex.
			SPIRType uint_type_ptr;
			uint_type_ptr = get_uint_type();
			uint_type_ptr.pointer = true;
			uint_type_ptr.pointer_depth++;
			uint_type_ptr.parent_type = get_uint_type_id();
			uint_type_ptr.storage = StorageClassInput;

			auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
			ptr_type.self = get_uint_type_id();
			set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
			set_decoration(var_id, DecorationBuiltIn, BuiltInLocalInvocationIndex);
			builtin_local_invocation_index_id = var_id;
			mark_implicit_builtin(StorageClassInput, BuiltInLocalInvocationIndex, var_id);
		}

		if (need_workgroup_size && !has_workgroup_size)
		{
			uint32_t offset = ir.increase_bound_by(2);
			uint32_t type_ptr_id = offset;
			uint32_t var_id = offset + 1;

			// Create gl_WorkgroupSize.
			uint32_t type_id = build_extended_vector_type(get_uint_type_id(), 3);
			SPIRType uint_type_ptr = get<SPIRType>(type_id);
			uint_type_ptr.pointer = true;
			uint_type_ptr.pointer_depth++;
			uint_type_ptr.parent_type = type_id;
			uint_type_ptr.storage = StorageClassInput;

			auto &ptr_type = set<SPIRType>(type_ptr_id, uint_type_ptr);
			ptr_type.self = type_id;
			set<SPIRVariable>(var_id, type_ptr_id, StorageClassInput);
			set_decoration(var_id, DecorationBuiltIn, BuiltInWorkgroupSize);
			builtin_workgroup_size_id = var_id;
			mark_implicit_builtin(StorageClassInput, BuiltInWorkgroupSize, var_id);
		}
	}

	if (needs_swizzle_buffer_def)
	{
		uint32_t var_id = build_constant_uint_array_pointer();
		set_name(var_id, "spvSwizzleConstants");
		// This should never match anything.
		set_decoration(var_id, DecorationDescriptorSet, kSwizzleBufferBinding);
		set_decoration(var_id, DecorationBinding, msl_options.swizzle_buffer_index);
		set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary, msl_options.swizzle_buffer_index);
		swizzle_buffer_id = var_id;
	}

	if (!buffers_requiring_array_length.empty())
	{
		uint32_t var_id = build_constant_uint_array_pointer();
		set_name(var_id, "spvBufferSizeConstants");
		// This should never match anything.
		set_decoration(var_id, DecorationDescriptorSet, kBufferSizeBufferBinding);
		set_decoration(var_id, DecorationBinding, msl_options.buffer_size_buffer_index);
		set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary, msl_options.buffer_size_buffer_index);
		buffer_size_buffer_id = var_id;
	}

	if (needs_view_mask_buffer())
	{
		uint32_t var_id = build_constant_uint_array_pointer();
		set_name(var_id, "spvViewMask");
		// This should never match anything.
		set_decoration(var_id, DecorationDescriptorSet, ~(4u));
		set_decoration(var_id, DecorationBinding, msl_options.view_mask_buffer_index);
		set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary, msl_options.view_mask_buffer_index);
		view_mask_buffer_id = var_id;
	}

	if (!buffers_requiring_dynamic_offset.empty())
	{
		uint32_t var_id = build_constant_uint_array_pointer();
		set_name(var_id, "spvDynamicOffsets");
		// This should never match anything.
		set_decoration(var_id, DecorationDescriptorSet, ~(5u));
		set_decoration(var_id, DecorationBinding, msl_options.dynamic_offsets_buffer_index);
		set_extended_decoration(var_id, SPIRVCrossDecorationResourceIndexPrimary,
		                        msl_options.dynamic_offsets_buffer_index);
		dynamic_offsets_buffer_id = var_id;
	}

	// If we're returning a struct from a vertex-like entry point, we must return a position attribute.
	bool need_position =
	    (get_execution_model() == ExecutionModelVertex ||
	     get_execution_model() == ExecutionModelTessellationEvaluation) &&
	    !capture_output_to_buffer && !get_is_rasterization_disabled() &&
	    !active_output_builtins.get(BuiltInPosition);

	if (need_position)
	{
		// If we can get away with returning void from entry point, we don't need to care.
		// If there is at least one other stage output, we need to return [[position]],
		// so we need to create one if it doesn't appear in the SPIR-V. Before adding the
		// implicit variable, check if it actually exists already, but just has not been used
		// or initialized, and if so, mark it as active, and do not create the implicit variable.
		bool has_output = false;
		ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
			if (var.storage == StorageClassOutput && interface_variable_exists_in_entry_point(var.self))
			{
				has_output = true;

				// Check if the var is the Position builtin
				if (has_decoration(var.self, DecorationBuiltIn) && get_decoration(var.self, DecorationBuiltIn) == BuiltInPosition)
					active_output_builtins.set(BuiltInPosition);

				// If the var is a struct, check if any members is the Position builtin
				auto &var_type = get_variable_element_type(var);
				if (var_type.basetype == SPIRType::Struct)
				{
					auto mbr_cnt = var_type.member_types.size();
					for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
					{
						auto builtin = BuiltInMax;
						bool is_builtin = is_member_builtin(var_type, mbr_idx, &builtin);
						if (is_builtin && builtin == BuiltInPosition)
							active_output_builtins.set(BuiltInPosition);
					}
				}
			}
		});
		need_position = has_output && !active_output_builtins.get(BuiltInPosition);
	}

	if (need_position)
	{
		uint32_t offset = ir.increase_bound_by(3);
		uint32_t type_id = offset;
		uint32_t type_ptr_id = offset + 1;
		uint32_t var_id = offset + 2;

		// Create gl_Position.
		SPIRType vec4_type;
		vec4_type.basetype = SPIRType::Float;
		vec4_type.width = 32;
		vec4_type.vecsize = 4;
		set<SPIRType>(type_id, vec4_type);

		SPIRType vec4_type_ptr;
		vec4_type_ptr = vec4_type;
		vec4_type_ptr.pointer = true;
		vec4_type_ptr.pointer_depth++;
		vec4_type_ptr.parent_type = type_id;
		vec4_type_ptr.storage = StorageClassOutput;
		auto &ptr_type = set<SPIRType>(type_ptr_id, vec4_type_ptr);
		ptr_type.self = type_id;

		set<SPIRVariable>(var_id, type_ptr_id, StorageClassOutput);
		set_decoration(var_id, DecorationBuiltIn, BuiltInPosition);
		mark_implicit_builtin(StorageClassOutput, BuiltInPosition, var_id);
	}
}

// Checks if the specified builtin variable (e.g. gl_InstanceIndex) is marked as active.
// If not, it marks it as active and forces a recompilation.
// This might be used when the optimization of inactive builtins was too optimistic (e.g. when "spvOut" is emitted).
void CompilerMSL::ensure_builtin(spv::StorageClass storage, spv::BuiltIn builtin)
{
	Bitset *active_builtins = nullptr;
	switch (storage)
	{
	case StorageClassInput:
		active_builtins = &active_input_builtins;
		break;

	case StorageClassOutput:
		active_builtins = &active_output_builtins;
		break;

	default:
		break;
	}

	// At this point, the specified builtin variable must have already been declared in the entry point.
	// If not, mark as active and force recompile.
	if (active_builtins != nullptr && !active_builtins->get(builtin))
	{
		active_builtins->set(builtin);
		force_recompile();
	}
}

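// Marks the given builtin as active and makes sure the variable implementing it
// is listed among the entry point's interface variables.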
void CompilerMSL::mark_implicit_builtin(StorageClass storage, BuiltIn builtin, uint32_t id)
{
	Bitset *active_builtins = nullptr;
	switch (storage)
	{
	case StorageClassInput:
		active_builtins = &active_input_builtins;
		break;

	case StorageClassOutput:
		active_builtins = &active_output_builtins;
		break;

	default:
		break;
	}

	assert(active_builtins != nullptr);
	active_builtins->set(builtin);

	auto &var = get_entry_point().interface_variables;
	if (find(begin(var), end(var), VariableID(id)) == end(var))
		var.push_back(id);
}

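// Synthesizes a 'constant uint *' variable (together with its pointer types), used
// for the internal swizzle-constant, buffer-size, view-mask and dynamic-offset buffers.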
uint32_t CompilerMSL::build_constant_uint_array_pointer()
{
	uint32_t offset = ir.increase_bound_by(3);
	uint32_t type_ptr_id = offset;
	uint32_t type_ptr_ptr_id = offset + 1;
	uint32_t var_id = offset + 2;

	// Create a buffer to hold extra data, including the swizzle constants.
	SPIRType uint_type_pointer = get_uint_type();
	uint_type_pointer.pointer = true;
	uint_type_pointer.pointer_depth++;
	uint_type_pointer.parent_type = get_uint_type_id();
	uint_type_pointer.storage = StorageClassUniform;
	set<SPIRType>(type_ptr_id, uint_type_pointer);
	set_decoration(type_ptr_id, DecorationArrayStride, 4);

	SPIRType uint_type_pointer2 = uint_type_pointer;
	uint_type_pointer2.pointer_depth++;
	uint_type_pointer2.parent_type = type_ptr_id;
	set<SPIRType>(type_ptr_ptr_id, uint_type_pointer2);

	set<SPIRVariable>(var_id, type_ptr_ptr_id, StorageClassUniformConstant);
	return var_id;
}

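// Maps an MSL sampler addressing mode to the corresponding constexpr sampler
// argument, with an optional per-coordinate prefix ("s_", "t_" or "r_").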
static string create_sampler_address(const char *prefix, MSLSamplerAddress addr)
{
	switch (addr)
	{
	case MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE:
		return join(prefix, "address::clamp_to_edge");
	case MSL_SAMPLER_ADDRESS_CLAMP_TO_ZERO:
		return join(prefix, "address::clamp_to_zero");
	case MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER:
		return join(prefix, "address::clamp_to_border");
	case MSL_SAMPLER_ADDRESS_REPEAT:
		return join(prefix, "address::repeat");
	case MSL_SAMPLER_ADDRESS_MIRRORED_REPEAT:
		return join(prefix, "address::mirrored_repeat");
	default:
		SPIRV_CROSS_THROW("Invalid sampler addressing mode.");
	}
}

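// Accessors for the data types of the synthesized stage-in/stage-out and
// per-patch I/O structs.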
SPIRType &CompilerMSL::get_stage_in_struct_type()
{
	auto &si_var = get<SPIRVariable>(stage_in_var_id);
	return get_variable_data_type(si_var);
}

SPIRType &CompilerMSL::get_stage_out_struct_type()
{
	auto &so_var = get<SPIRVariable>(stage_out_var_id);
	return get_variable_data_type(so_var);
}

SPIRType &CompilerMSL::get_patch_stage_in_struct_type()
{
	auto &si_var = get<SPIRVariable>(patch_stage_in_var_id);
	return get_variable_data_type(si_var);
}

SPIRType &CompilerMSL::get_patch_stage_out_struct_type()
{
	auto &so_var = get<SPIRVariable>(patch_stage_out_var_id);
	return get_variable_data_type(so_var);
}

std::string CompilerMSL::get_tess_factor_struct_name()
{
	if (get_entry_point().flags.get(ExecutionModeTriangles))
		return "MTLTriangleTessellationFactorsHalf";
	return "MTLQuadTessellationFactorsHalf";
}

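// Returns (creating it on first use) the shared 32-bit uint type used when
// synthesizing builtins and internal buffers.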
SPIRType &CompilerMSL::get_uint_type()
{
	return get<SPIRType>(get_uint_type_id());
}

uint32_t CompilerMSL::get_uint_type_id()
{
	if (uint_type_id != 0)
		return uint_type_id;

	uint_type_id = ir.increase_bound_by(1);

	SPIRType type;
	type.basetype = SPIRType::UInt;
	type.width = 32;
	set<SPIRType>(uint_type_id, type);
	return uint_type_id;
}

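// Emits declarations that must appear at the top of the entry function:
// constexpr samplers, buffers with dynamic offsets, arrays of buffers, and
// disabled fragment outputs that still need local declarations.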
void CompilerMSL::emit_entry_point_declarations()
{
	// FIXME: Get test coverage here ...
	// Constant arrays of non-primitive types (i.e. matrices) won't link properly into Metal libraries
	declare_complex_constant_arrays();

	// Emit constexpr samplers here.
	for (auto &samp : constexpr_samplers_by_id)
	{
		auto &var = get<SPIRVariable>(samp.first);
		auto &type = get<SPIRType>(var.basetype);
		if (type.basetype == SPIRType::Sampler)
			add_resource_name(samp.first);

		SmallVector<string> args;
		auto &s = samp.second;

		if (s.coord != MSL_SAMPLER_COORD_NORMALIZED)
			args.push_back("coord::pixel");

		if (s.min_filter == s.mag_filter)
		{
			if (s.min_filter != MSL_SAMPLER_FILTER_NEAREST)
				args.push_back("filter::linear");
		}
		else
		{
			if (s.min_filter != MSL_SAMPLER_FILTER_NEAREST)
				args.push_back("min_filter::linear");
			if (s.mag_filter != MSL_SAMPLER_FILTER_NEAREST)
				args.push_back("mag_filter::linear");
		}

		switch (s.mip_filter)
		{
		case MSL_SAMPLER_MIP_FILTER_NONE:
			// Default
			break;
		case MSL_SAMPLER_MIP_FILTER_NEAREST:
			args.push_back("mip_filter::nearest");
			break;
		case MSL_SAMPLER_MIP_FILTER_LINEAR:
			args.push_back("mip_filter::linear");
			break;
		default:
			SPIRV_CROSS_THROW("Invalid mip filter.");
		}

		if (s.s_address == s.t_address && s.s_address == s.r_address)
		{
			if (s.s_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE)
				args.push_back(create_sampler_address("", s.s_address));
		}
		else
		{
			if (s.s_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE)
				args.push_back(create_sampler_address("s_", s.s_address));
			if (s.t_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE)
				args.push_back(create_sampler_address("t_", s.t_address));
			if (s.r_address != MSL_SAMPLER_ADDRESS_CLAMP_TO_EDGE)
				args.push_back(create_sampler_address("r_", s.r_address));
		}

		if (s.compare_enable)
		{
			switch (s.compare_func)
			{
			case MSL_SAMPLER_COMPARE_FUNC_ALWAYS:
				args.push_back("compare_func::always");
				break;
			case MSL_SAMPLER_COMPARE_FUNC_NEVER:
				args.push_back("compare_func::never");
				break;
			case MSL_SAMPLER_COMPARE_FUNC_EQUAL:
				args.push_back("compare_func::equal");
				break;
			case MSL_SAMPLER_COMPARE_FUNC_NOT_EQUAL:
				args.push_back("compare_func::not_equal");
				break;
			case MSL_SAMPLER_COMPARE_FUNC_LESS:
				args.push_back("compare_func::less");
				break;
			case MSL_SAMPLER_COMPARE_FUNC_LESS_EQUAL:
				args.push_back("compare_func::less_equal");
				break;
			case MSL_SAMPLER_COMPARE_FUNC_GREATER:
				args.push_back("compare_func::greater");
				break;
			case MSL_SAMPLER_COMPARE_FUNC_GREATER_EQUAL:
				args.push_back("compare_func::greater_equal");
				break;
			default:
				SPIRV_CROSS_THROW("Invalid sampler compare function.");
			}
		}

		if (s.s_address == MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER || s.t_address == MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER ||
		    s.r_address == MSL_SAMPLER_ADDRESS_CLAMP_TO_BORDER)
		{
			switch (s.border_color)
			{
			case MSL_SAMPLER_BORDER_COLOR_OPAQUE_BLACK:
				args.push_back("border_color::opaque_black");
				break;
			case MSL_SAMPLER_BORDER_COLOR_OPAQUE_WHITE:
				args.push_back("border_color::opaque_white");
				break;
			case MSL_SAMPLER_BORDER_COLOR_TRANSPARENT_BLACK:
				args.push_back("border_color::transparent_black");
				break;
			default:
				SPIRV_CROSS_THROW("Invalid sampler border color.");
			}
		}

		if (s.anisotropy_enable)
			args.push_back(join("max_anisotropy(", s.max_anisotropy, ")"));
		if (s.lod_clamp_enable)
		{
			args.push_back(join("lod_clamp(", convert_to_string(s.lod_clamp_min, current_locale_radix_character), ", ",
			                    convert_to_string(s.lod_clamp_max, current_locale_radix_character), ")"));
		}

		// If we would emit no arguments, then omit the parentheses entirely. Otherwise,
		// we'll wind up with a "most vexing parse" situation.
		if (args.empty())
			statement("constexpr sampler ",
			          type.basetype == SPIRType::SampledImage ? to_sampler_expression(samp.first) : to_name(samp.first),
			          ";");
		else
			statement("constexpr sampler ",
			          type.basetype == SPIRType::SampledImage ? to_sampler_expression(samp.first) : to_name(samp.first),
			          "(", merge(args), ");");
	}

	// Emit dynamic buffers here.
	for (auto &dynamic_buffer : buffers_requiring_dynamic_offset)
	{
		if (!dynamic_buffer.second.second)
		{
			// Could happen if no buffer was used at requested binding point.
			continue;
		}

		const auto &var = get<SPIRVariable>(dynamic_buffer.second.second);
		uint32_t var_id = var.self;
		const auto &type = get_variable_data_type(var);
		string name = to_name(var.self);
		uint32_t desc_set = get_decoration(var.self, DecorationDescriptorSet);
		uint32_t arg_id = argument_buffer_ids[desc_set];
		uint32_t base_index = dynamic_buffer.second.first;

		if (!type.array.empty())
		{
			// This is complicated, because we need to support arrays of arrays.
			// And it's even worse if the outermost dimension is a runtime array, because now
			// all this complicated goop has to go into the shader itself. (FIXME)
			if (!type.array[type.array.size() - 1])
				SPIRV_CROSS_THROW("Runtime arrays with dynamic offsets are not supported yet.");
			else
			{
				is_using_builtin_array = true;
				statement(get_argument_address_space(var), " ", type_to_glsl(type), "* ", to_restrict(var_id), name,
				          type_to_array_glsl(type), " =");

				uint32_t dim = uint32_t(type.array.size());
				uint32_t j = 0;
				for (SmallVector<uint32_t> indices(type.array.size());
				     indices[type.array.size() - 1] < to_array_size_literal(type); j++)
				{
					while (dim > 0)
					{
						begin_scope();
						--dim;
					}

					string arrays;
					for (uint32_t i = uint32_t(type.array.size()); i; --i)
						arrays += join("[", indices[i - 1], "]");
					statement("(", get_argument_address_space(var), " ", type_to_glsl(type), "* ",
					          to_restrict(var_id, false), ")((", get_argument_address_space(var), " char* ",
					          to_restrict(var_id, false), ")", to_name(arg_id), ".", ensure_valid_name(name, "m"),
					          arrays, " + ", to_name(dynamic_offsets_buffer_id), "[", base_index + j, "]),");

					while (++indices[dim] >= to_array_size_literal(type, dim) && dim < type.array.size() - 1)
					{
						end_scope(",");
						indices[dim++] = 0;
					}
				}
				end_scope_decl();
				statement_no_indent("");
				is_using_builtin_array = false;
			}
		}
		else
		{
			statement(get_argument_address_space(var), " auto& ", to_restrict(var_id), name, " = *(",
			          get_argument_address_space(var), " ", type_to_glsl(type), "* ", to_restrict(var_id, false), ")((",
			          get_argument_address_space(var), " char* ", to_restrict(var_id, false), ")", to_name(arg_id), ".",
			          ensure_valid_name(name, "m"), " + ", to_name(dynamic_offsets_buffer_id), "[", base_index, "]);");
		}
	}

	// Emit buffer arrays here.
	for (uint32_t array_id : buffer_arrays)
	{
		const auto &var = get<SPIRVariable>(array_id);
		const auto &type = get_variable_data_type(var);
		const auto &buffer_type = get_variable_element_type(var);
		string name = to_name(array_id);
		statement(get_argument_address_space(var), " ", type_to_glsl(buffer_type), "* ", to_restrict(array_id), name,
		          "[] =");
		begin_scope();
		for (uint32_t i = 0; i < to_array_size_literal(type); ++i)
			statement(name, "_", i, ",");
		end_scope_decl();
		statement_no_indent("");
	}
	// For some reason, without this, we end up emitting the arrays twice.
	buffer_arrays.clear();

	// Emit disabled fragment outputs.
	std::sort(disabled_frag_outputs.begin(), disabled_frag_outputs.end());
	for (uint32_t var_id : disabled_frag_outputs)
	{
		auto &var = get<SPIRVariable>(var_id);
		add_local_variable_name(var_id);
		statement(variable_decl(var), ";");
		var.deferred_declaration = false;
	}
}

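// Compiles the SPIR-V IR to MSL source text. The loop below re-runs code
// generation whenever a pass discovers something that forces a recompile,
// so everything inside it must be idempotent.
//
// A minimal host-side usage sketch (illustrative only; error handling omitted):
//
//     CompilerMSL msl(std::move(spirv_words));
//     CompilerMSL::Options opts = msl.get_msl_options();
//     opts.argument_buffers = true;
//     msl.set_msl_options(opts);
//     std::string source = msl.compile();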
1322string CompilerMSL::compile()
1323{
1324 replace_illegal_entry_point_names();
1325 ir.fixup_reserved_names();
1326
1327 // Do not deal with GLES-isms like precision, older extensions and such.
1328 options.vulkan_semantics = true;
1329 options.es = false;
1330 options.version = 450;
1331 backend.null_pointer_literal = "nullptr";
1332 backend.float_literal_suffix = false;
1333 backend.uint32_t_literal_suffix = true;
1334 backend.int16_t_literal_suffix = "";
1335 backend.uint16_t_literal_suffix = "";
1336 backend.basic_int_type = "int";
1337 backend.basic_uint_type = "uint";
1338 backend.basic_int8_type = "char";
1339 backend.basic_uint8_type = "uchar";
1340 backend.basic_int16_type = "short";
1341 backend.basic_uint16_type = "ushort";
1342 backend.discard_literal = "discard_fragment()";
1343 backend.demote_literal = "discard_fragment()";
1344 backend.boolean_mix_function = "select";
1345 backend.swizzle_is_function = false;
1346 backend.shared_is_implied = false;
1347 backend.use_initializer_list = true;
1348 backend.use_typed_initializer_list = true;
1349 backend.native_row_major_matrix = false;
1350 backend.unsized_array_supported = false;
1351 backend.can_declare_arrays_inline = false;
1352 backend.allow_truncated_access_chain = true;
1353 backend.comparison_image_samples_scalar = true;
1354 backend.native_pointers = true;
1355 backend.nonuniform_qualifier = "";
1356 backend.support_small_type_sampling_result = true;
1357 backend.supports_empty_struct = true;
1358 backend.support_64bit_switch = true;
1359
1360 // Allow Metal to use the array<T> template unless we force it off.
1361 backend.can_return_array = !msl_options.force_native_arrays;
1362 backend.array_is_value_type = !msl_options.force_native_arrays;
1363 // Arrays which are part of buffer objects are never considered to be value types (just plain C-style).
1364 backend.array_is_value_type_in_buffer_blocks = false;
1365 backend.support_pointer_to_pointer = true;
1366
1367 capture_output_to_buffer = msl_options.capture_output_to_buffer;
1368 is_rasterization_disabled = msl_options.disable_rasterization || capture_output_to_buffer;
1369
1370 // Initialize array here rather than constructor, MSVC 2013 workaround.
1371 for (auto &id : next_metal_resource_ids)
1372 id = 0;
1373
1374 fixup_type_alias();
1375 replace_illegal_names();
1376 sync_entry_point_aliases_and_names();
1377
1378 build_function_control_flow_graphs_and_analyze();
1379 update_active_builtins();
1380 analyze_image_and_sampler_usage();
1381 analyze_sampled_image_usage();
1382 analyze_interlocked_resource_usage();
1383 preprocess_op_codes();
1384 build_implicit_builtins();
1385
1386 fixup_image_load_store_access();
1387
1388 set_enabled_interface_variables(get_active_interface_variables());
1389 if (msl_options.force_active_argument_buffer_resources)
1390 activate_argument_buffer_resources();
1391
1392 if (swizzle_buffer_id)
1393 active_interface_variables.insert(swizzle_buffer_id);
1394 if (buffer_size_buffer_id)
1395 active_interface_variables.insert(buffer_size_buffer_id);
1396 if (view_mask_buffer_id)
1397 active_interface_variables.insert(view_mask_buffer_id);
1398 if (dynamic_offsets_buffer_id)
1399 active_interface_variables.insert(dynamic_offsets_buffer_id);
1400 if (builtin_layer_id)
1401 active_interface_variables.insert(builtin_layer_id);
1402 if (builtin_dispatch_base_id && !msl_options.supports_msl_version(1, 2))
1403 active_interface_variables.insert(builtin_dispatch_base_id);
1404 if (builtin_sample_mask_id)
1405 active_interface_variables.insert(builtin_sample_mask_id);
1406
1407 // Create structs to hold input, output and uniform variables.
1408 // Do output first to ensure out. is declared at top of entry function.
1409 qual_pos_var_name = "";
1410 stage_out_var_id = add_interface_block(StorageClassOutput);
1411 patch_stage_out_var_id = add_interface_block(StorageClassOutput, true);
1412 stage_in_var_id = add_interface_block(StorageClassInput);
1413 if (get_execution_model() == ExecutionModelTessellationEvaluation)
1414 patch_stage_in_var_id = add_interface_block(StorageClassInput, true);
1415
1416 if (get_execution_model() == ExecutionModelTessellationControl)
1417 stage_out_ptr_var_id = add_interface_block_pointer(stage_out_var_id, StorageClassOutput);
1418 if (is_tessellation_shader())
1419 stage_in_ptr_var_id = add_interface_block_pointer(stage_in_var_id, StorageClassInput);
1420
1421 // Metal vertex functions that define no output must disable rasterization and return void.
1422 if (!stage_out_var_id)
1423 is_rasterization_disabled = true;
1424
1425 // Convert the use of global variables to recursively-passed function parameters
1426 localize_global_variables();
1427 extract_global_variables_from_functions();
1428
1429 // Mark any non-stage-in structs to be tightly packed.
1430 mark_packable_structs();
1431 reorder_type_alias();
1432
1433 // Add fixup hooks required by shader inputs and outputs. This needs to happen before
1434 // the loop, so the hooks aren't added multiple times.
1435 fix_up_shader_inputs_outputs();
1436
1437 // If we are using argument buffers, we create argument buffer structures for them here.
1438 // These buffers will be used in the entry point, not the individual resources.
1439 if (msl_options.argument_buffers)
1440 {
1441 if (!msl_options.supports_msl_version(2, 0))
1442 SPIRV_CROSS_THROW("Argument buffers can only be used with MSL 2.0 and up.");
1443 analyze_argument_buffers();
1444 }
1445
1446 uint32_t pass_count = 0;
1447 do
1448 {
1449 reset(pass_count);
1450
1451 // Start bindings at zero.
1452 next_metal_resource_index_buffer = 0;
1453 next_metal_resource_index_texture = 0;
1454 next_metal_resource_index_sampler = 0;
1455 for (auto &id : next_metal_resource_ids)
1456 id = 0;
1457
1458 // Move constructor for this type is broken on GCC 4.9 ...
1459 buffer.reset();
1460
1461 emit_header();
1462 emit_custom_templates();
1463 emit_custom_functions();
1464 emit_specialization_constants_and_structs();
1465 emit_resources();
1466 emit_function(get<SPIRFunction>(ir.default_entry_point), Bitset());
1467
1468 pass_count++;
1469 } while (is_forcing_recompilation());
1470
1471 return buffer.str();
1472}
1473
1474// Register the need to output any custom functions.
1475void CompilerMSL::preprocess_op_codes()
1476{
1477 OpCodePreprocessor preproc(*this);
1478 traverse_all_reachable_opcodes(get<SPIRFunction>(ir.default_entry_point), preproc);
1479
1480 suppress_missing_prototypes = preproc.suppress_missing_prototypes;
1481
1482 if (preproc.uses_atomics)
1483 {
1484 add_header_line("#include <metal_atomic>");
1485 add_pragma_line("#pragma clang diagnostic ignored \"-Wunused-variable\"");
1486 }
1487
1488 // Before MSL 2.1 (2.2 for textures), Metal vertex functions that write to
1489 // resources must disable rasterization and return void.
1490 if (preproc.uses_resource_write)
1491 is_rasterization_disabled = true;
1492
1493 // Tessellation control shaders are run as compute functions in Metal, and so
1494 // must capture their output to a buffer.
1495 if (get_execution_model() == ExecutionModelTessellationControl ||
1496 (get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation))
1497 {
1498 is_rasterization_disabled = true;
1499 capture_output_to_buffer = true;
1500 }
1501
1502 if (preproc.needs_subgroup_invocation_id)
1503 needs_subgroup_invocation_id = true;
1504 if (preproc.needs_subgroup_size)
1505 needs_subgroup_size = true;
1506 // build_implicit_builtins() hasn't run yet, and in fact, this needs to execute
1507 // before then so that gl_SampleID will get added; so we also need to check if
1508 // that function would add gl_FragCoord.
1509 if (preproc.needs_sample_id || msl_options.force_sample_rate_shading ||
1510 (is_sample_rate() && (active_input_builtins.get(BuiltInFragCoord) ||
1511 (need_subpass_input && !msl_options.use_framebuffer_fetch_subpasses))))
1512 needs_sample_id = true;
1513
1514 if (is_intersection_query())
1515 {
1516 add_header_line("#if __METAL_VERSION__ >= 230");
1517 add_header_line("#include <metal_raytracing>");
1518 add_header_line("using namespace metal::raytracing;");
1519 add_header_line("#endif");
1520 }
1521}
1522
1523// Move the Private and Workgroup global variables to the entry function.
1524// Non-constant variables cannot have global scope in Metal.
1525void CompilerMSL::localize_global_variables()
1526{
1527 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
1528 auto iter = global_variables.begin();
1529 while (iter != global_variables.end())
1530 {
1531 uint32_t v_id = *iter;
1532 auto &var = get<SPIRVariable>(v_id);
1533 if (var.storage == StorageClassPrivate || var.storage == StorageClassWorkgroup)
1534 {
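			// Constant lookup tables can stay globally scoped in the emitted MSL;
			// every other Private/Workgroup variable becomes a local of the entry function.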
1535 if (!variable_is_lut(var))
1536 entry_func.add_local_variable(v_id);
1537 iter = global_variables.erase(iter);
1538 }
1539 else
1540 iter++;
1541 }
1542}
1543
1544// For any global variable accessed directly by a function,
1545// extract that variable and add it as an argument to that function.
1546void CompilerMSL::extract_global_variables_from_functions()
1547{
	// Interface variables (stage I/O, uniforms, push constants, storage buffers) are globals in SPIR-V.
1549 unordered_set<uint32_t> global_var_ids;
1550 ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
1551 if (var.storage == StorageClassInput || var.storage == StorageClassOutput ||
1552 var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant ||
1553 var.storage == StorageClassPushConstant || var.storage == StorageClassStorageBuffer)
1554 {
1555 global_var_ids.insert(var.self);
1556 }
1557 });
1558
1559 // Local vars that are declared in the main function and accessed directly by a function
1560 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
1561 for (auto &var : entry_func.local_variables)
1562 if (get<SPIRVariable>(var).storage != StorageClassFunction)
1563 global_var_ids.insert(var);
1564
1565 std::set<uint32_t> added_arg_ids;
1566 unordered_set<uint32_t> processed_func_ids;
1567 extract_global_variables_from_function(ir.default_entry_point, added_arg_ids, global_var_ids, processed_func_ids);
1568}
1569
1570// MSL does not support the use of global variables for shader input content.
// For any global variable accessed directly by the specified function, extract that variable,
// add it as an argument to that function, and add the arg to the added_arg_ids collection.
1573void CompilerMSL::extract_global_variables_from_function(uint32_t func_id, std::set<uint32_t> &added_arg_ids,
1574 unordered_set<uint32_t> &global_var_ids,
1575 unordered_set<uint32_t> &processed_func_ids)
1576{
1577 // Avoid processing a function more than once
1578 if (processed_func_ids.find(func_id) != processed_func_ids.end())
1579 {
		// Return the cached set of global variables for this function.
1581 added_arg_ids = function_global_vars[func_id];
1582 return;
1583 }
1584
1585 processed_func_ids.insert(func_id);
1586
1587 auto &func = get<SPIRFunction>(func_id);
1588
1589 // Recursively establish global args added to functions on which we depend.
1590 for (auto block : func.blocks)
1591 {
1592 auto &b = get<SPIRBlock>(block);
1593 for (auto &i : b.ops)
1594 {
1595 auto ops = stream(i);
1596 auto op = static_cast<Op>(i.op);
1597
1598 switch (op)
1599 {
1600 case OpLoad:
1601 case OpInBoundsAccessChain:
1602 case OpAccessChain:
1603 case OpPtrAccessChain:
1604 case OpArrayLength:
1605 {
1606 uint32_t base_id = ops[2];
1607 if (global_var_ids.find(base_id) != global_var_ids.end())
1608 added_arg_ids.insert(base_id);
1609
				// Unless Metal's native framebuffer fetch is used, subpass inputs
				// are read as textures at gl_FragCoord, so they implicitly depend
				// on extra builtins.
1611 auto &type = get<SPIRType>(ops[0]);
1612 if (type.basetype == SPIRType::Image && type.image.dim == DimSubpassData &&
1613 (!msl_options.use_framebuffer_fetch_subpasses))
1614 {
1615 // Implicitly reads gl_FragCoord.
1616 assert(builtin_frag_coord_id != 0);
1617 added_arg_ids.insert(builtin_frag_coord_id);
1618 if (msl_options.multiview)
1619 {
1620 // Implicitly reads gl_ViewIndex.
1621 assert(builtin_view_idx_id != 0);
1622 added_arg_ids.insert(builtin_view_idx_id);
1623 }
1624 else if (msl_options.arrayed_subpass_input)
1625 {
1626 // Implicitly reads gl_Layer.
1627 assert(builtin_layer_id != 0);
1628 added_arg_ids.insert(builtin_layer_id);
1629 }
1630 }
1631
1632 break;
1633 }
1634
1635 case OpFunctionCall:
1636 {
1637 // First see if any of the function call args are globals
1638 for (uint32_t arg_idx = 3; arg_idx < i.length; arg_idx++)
1639 {
1640 uint32_t arg_id = ops[arg_idx];
1641 if (global_var_ids.find(arg_id) != global_var_ids.end())
1642 added_arg_ids.insert(arg_id);
1643 }
1644
1645 // Then recurse into the function itself to extract globals used internally in the function
1646 uint32_t inner_func_id = ops[2];
1647 std::set<uint32_t> inner_func_args;
1648 extract_global_variables_from_function(inner_func_id, inner_func_args, global_var_ids,
1649 processed_func_ids);
1650 added_arg_ids.insert(inner_func_args.begin(), inner_func_args.end());
1651 break;
1652 }
1653
1654 case OpStore:
1655 {
1656 uint32_t base_id = ops[0];
1657 if (global_var_ids.find(base_id) != global_var_ids.end())
1658 added_arg_ids.insert(base_id);
1659
1660 uint32_t rvalue_id = ops[1];
1661 if (global_var_ids.find(rvalue_id) != global_var_ids.end())
1662 added_arg_ids.insert(rvalue_id);
1663
1664 break;
1665 }
1666
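			// Either operand of OpSelect may be a pointer to a global variable.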
1667 case OpSelect:
1668 {
1669 uint32_t base_id = ops[3];
1670 if (global_var_ids.find(base_id) != global_var_ids.end())
1671 added_arg_ids.insert(base_id);
1672 base_id = ops[4];
1673 if (global_var_ids.find(base_id) != global_var_ids.end())
1674 added_arg_ids.insert(base_id);
1675 break;
1676 }
1677
1678 // Emulate texture2D atomic operations
1679 case OpImageTexelPointer:
1680 {
1681 // When using the pointer, we need to know which variable it is actually loaded from.
1682 uint32_t base_id = ops[2];
1683 auto *var = maybe_get_backing_variable(base_id);
1684 if (var && atomic_image_vars.count(var->self))
1685 {
1686 if (global_var_ids.find(base_id) != global_var_ids.end())
1687 added_arg_ids.insert(base_id);
1688 }
1689 break;
1690 }
1691
1692 case OpExtInst:
1693 {
1694 uint32_t extension_set = ops[2];
1695 if (get<SPIRExtension>(extension_set).ext == SPIRExtension::GLSL)
1696 {
1697 auto op_450 = static_cast<GLSLstd450>(ops[3]);
1698 switch (op_450)
1699 {
1700 case GLSLstd450InterpolateAtCentroid:
1701 case GLSLstd450InterpolateAtSample:
1702 case GLSLstd450InterpolateAtOffset:
1703 {
1704 // For these, we really need the stage-in block. It is theoretically possible to pass the
1705 // interpolant object, but a) doing so would require us to create an entirely new variable
1706 // with Interpolant type, and b) if we have a struct or array, handling all the members and
1707 // elements could get unwieldy fast.
1708 added_arg_ids.insert(stage_in_var_id);
1709 break;
1710 }
1711
1712 case GLSLstd450Modf:
1713 case GLSLstd450Frexp:
1714 {
1715 uint32_t base_id = ops[5];
1716 if (global_var_ids.find(base_id) != global_var_ids.end())
1717 added_arg_ids.insert(base_id);
1718 break;
1719 }
1720
1721 default:
1722 break;
1723 }
1724 }
1725 break;
1726 }
1727
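			// The MSL lowering of inverse ballot reads gl_SubgroupInvocationID.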
1728 case OpGroupNonUniformInverseBallot:
1729 {
1730 added_arg_ids.insert(builtin_subgroup_invocation_id_id);
1731 break;
1732 }
1733
1734 case OpGroupNonUniformBallotFindLSB:
1735 case OpGroupNonUniformBallotFindMSB:
1736 {
1737 added_arg_ids.insert(builtin_subgroup_size_id);
1738 break;
1739 }
1740
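			// Which implicit input is needed depends on the group operation:
			// Reduce uses gl_SubgroupSize, while inclusive and exclusive scans
			// use gl_SubgroupInvocationID.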
1741 case OpGroupNonUniformBallotBitCount:
1742 {
1743 auto operation = static_cast<GroupOperation>(ops[3]);
1744 switch (operation)
1745 {
1746 case GroupOperationReduce:
1747 added_arg_ids.insert(builtin_subgroup_size_id);
1748 break;
1749 case GroupOperationInclusiveScan:
1750 case GroupOperationExclusiveScan:
1751 added_arg_ids.insert(builtin_subgroup_invocation_id_id);
1752 break;
1753 default:
1754 break;
1755 }
1756 break;
1757 }
1758
1759 default:
1760 break;
1761 }
1762
1763 // TODO: Add all other operations which can affect memory.
			// We should consider a more unified system here to reduce boilerplate.
1765 // This kind of analysis is done in several places ...
1766 }
1767 }
1768
1769 function_global_vars[func_id] = added_arg_ids;
1770
1771 // Add the global variables as arguments to the function
1772 if (func_id != ir.default_entry_point)
1773 {
1774 bool control_point_added_in = false;
1775 bool control_point_added_out = false;
1776 bool patch_added_in = false;
1777 bool patch_added_out = false;
1778
1779 for (uint32_t arg_id : added_arg_ids)
1780 {
1781 auto &var = get<SPIRVariable>(arg_id);
1782 uint32_t type_id = var.basetype;
1783 auto *p_type = &get<SPIRType>(type_id);
1784 BuiltIn bi_type = BuiltIn(get_decoration(arg_id, DecorationBuiltIn));
1785
1786 bool is_patch = has_decoration(arg_id, DecorationPatch) || is_patch_block(*p_type);
1787 bool is_block = has_decoration(p_type->self, DecorationBlock);
1788 bool is_control_point_storage =
1789 !is_patch &&
1790 ((is_tessellation_shader() && var.storage == StorageClassInput) ||
1791 (get_execution_model() == ExecutionModelTessellationControl && var.storage == StorageClassOutput));
1792 bool is_patch_block_storage = is_patch && is_block && var.storage == StorageClassOutput;
1793 bool is_builtin = is_builtin_variable(var);
1794 bool variable_is_stage_io =
1795 !is_builtin || bi_type == BuiltInPosition || bi_type == BuiltInPointSize ||
1796 bi_type == BuiltInClipDistance || bi_type == BuiltInCullDistance ||
1797 p_type->basetype == SPIRType::Struct;
1798 bool is_redirected_to_global_stage_io = (is_control_point_storage || is_patch_block_storage) &&
1799 variable_is_stage_io;
1800
			// If output is masked, it is not considered part of the global stage IO interface.
1802 if (is_redirected_to_global_stage_io && var.storage == StorageClassOutput)
1803 is_redirected_to_global_stage_io = !is_stage_output_variable_masked(var);
1804
1805 if (is_redirected_to_global_stage_io)
1806 {
1807 // Tessellation control shaders see inputs and per-vertex outputs as arrays.
1808 // Similarly, tessellation evaluation shaders see per-vertex inputs as arrays.
1809 // We collected them into a structure; we must pass the array of this
1810 // structure to the function.
1811 std::string name;
1812 if (is_patch)
1813 name = var.storage == StorageClassInput ? patch_stage_in_var_name : patch_stage_out_var_name;
1814 else
1815 name = var.storage == StorageClassInput ? "gl_in" : "gl_out";
1816
1817 if (var.storage == StorageClassOutput && has_decoration(p_type->self, DecorationBlock))
1818 {
1819 // If we're redirecting a block, we might still need to access the original block
1820 // variable if we're masking some members.
1821 for (uint32_t mbr_idx = 0; mbr_idx < uint32_t(p_type->member_types.size()); mbr_idx++)
1822 {
1823 if (is_stage_output_block_member_masked(var, mbr_idx, true))
1824 {
1825 func.add_parameter(var.basetype, var.self, true);
1826 break;
1827 }
1828 }
1829 }
1830
1835 if (var.storage == StorageClassInput)
1836 {
1837 auto &added_in = is_patch ? patch_added_in : control_point_added_in;
1838 if (added_in)
1839 continue;
1840 arg_id = is_patch ? patch_stage_in_var_id : stage_in_ptr_var_id;
1841 added_in = true;
1842 }
1843 else if (var.storage == StorageClassOutput)
1844 {
1845 auto &added_out = is_patch ? patch_added_out : control_point_added_out;
1846 if (added_out)
1847 continue;
1848 arg_id = is_patch ? patch_stage_out_var_id : stage_out_ptr_var_id;
1849 added_out = true;
1850 }
1851
1852 type_id = get<SPIRVariable>(arg_id).basetype;
1853 uint32_t next_id = ir.increase_bound_by(1);
1854 func.add_parameter(type_id, next_id, true);
1855 set<SPIRVariable>(next_id, type_id, StorageClassFunction, 0, arg_id);
1856
1857 set_name(next_id, name);
1858 }
1859 else if (is_builtin && has_decoration(p_type->self, DecorationBlock))
1860 {
1861 // Get the pointee type
1862 type_id = get_pointee_type_id(type_id);
1863 p_type = &get<SPIRType>(type_id);
1864
1865 uint32_t mbr_idx = 0;
1866 for (auto &mbr_type_id : p_type->member_types)
1867 {
1868 BuiltIn builtin = BuiltInMax;
1869 is_builtin = is_member_builtin(*p_type, mbr_idx, &builtin);
1870 if (is_builtin && has_active_builtin(builtin, var.storage))
1871 {
						// Add an argument variable with the same type and decorations as the member.
1873 uint32_t next_ids = ir.increase_bound_by(2);
1874 uint32_t ptr_type_id = next_ids + 0;
1875 uint32_t var_id = next_ids + 1;
1876
1877 // Make sure we have an actual pointer type,
1878 // so that we will get the appropriate address space when declaring these builtins.
1879 auto &ptr = set<SPIRType>(ptr_type_id, get<SPIRType>(mbr_type_id));
1880 ptr.self = mbr_type_id;
1881 ptr.storage = var.storage;
1882 ptr.pointer = true;
1883 ptr.pointer_depth++;
1884 ptr.parent_type = mbr_type_id;
1885
1886 func.add_parameter(mbr_type_id, var_id, true);
1887 set<SPIRVariable>(var_id, ptr_type_id, StorageClassFunction);
1888 ir.meta[var_id].decoration = ir.meta[type_id].members[mbr_idx];
1889 }
1890 mbr_idx++;
1891 }
1892 }
1893 else
1894 {
1895 uint32_t next_id = ir.increase_bound_by(1);
1896 func.add_parameter(type_id, next_id, true);
1897 set<SPIRVariable>(next_id, type_id, StorageClassFunction, 0, arg_id);
1898
1899 // Ensure the existing variable has a valid name and the new variable has all the same meta info
1900 set_name(arg_id, ensure_valid_name(to_name(arg_id), "v"));
1901 ir.meta[next_id] = ir.meta[arg_id];
1902 }
1903 }
1904 }
1905}
1906
1907// For all variables that are some form of non-input-output interface block, mark that all the structs
1908// that are recursively contained within the type referenced by that variable should be packed tightly.
1909void CompilerMSL::mark_packable_structs()
1910{
1911 ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
1912 if (var.storage != StorageClassFunction && !is_hidden_variable(var))
1913 {
1914 auto &type = this->get<SPIRType>(var.basetype);
1915 if (type.pointer &&
1916 (type.storage == StorageClassUniform || type.storage == StorageClassUniformConstant ||
1917 type.storage == StorageClassPushConstant || type.storage == StorageClassStorageBuffer) &&
1918 (has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock)))
1919 mark_as_packable(type);
1920 }
1921 });
1922}
1923
// If the specified type is a struct, it and any nested structs
// are marked as packable with the SPIRVCrossDecorationBufferBlockRepacked decoration.
1926void CompilerMSL::mark_as_packable(SPIRType &type)
1927{
	// If this is not the base type (e.g. it's a pointer or array), tunnel down.
1929 if (type.parent_type)
1930 {
1931 mark_as_packable(get<SPIRType>(type.parent_type));
1932 return;
1933 }
1934
1935 if (type.basetype == SPIRType::Struct)
1936 {
1937 set_extended_decoration(type.self, SPIRVCrossDecorationBufferBlockRepacked);
1938
1939 // Recurse
1940 uint32_t mbr_cnt = uint32_t(type.member_types.size());
1941 for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
1942 {
1943 uint32_t mbr_type_id = type.member_types[mbr_idx];
1944 auto &mbr_type = get<SPIRType>(mbr_type_id);
1945 mark_as_packable(mbr_type);
1946 if (mbr_type.type_alias)
1947 {
1948 auto &mbr_type_alias = get<SPIRType>(mbr_type.type_alias);
1949 mark_as_packable(mbr_type_alias);
1950 }
1951 }
1952 }
1953}
1954
1955// If a shader input exists at the location, it is marked as being used by this shader
1956void CompilerMSL::mark_location_as_used_by_shader(uint32_t location, const SPIRType &type,
1957 StorageClass storage, bool fallback)
1958{
1959 if (storage != StorageClassInput)
1960 return;
1961
1962 uint32_t count = type_to_location_count(type);
1963 for (uint32_t i = 0; i < count; i++)
1964 {
1965 location_inputs_in_use.insert(location + i);
1966 if (fallback)
1967 location_inputs_in_use_fallback.insert(location + i);
1968 }
1969}
1970
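// Returns the number of components the API expects at a fragment output location,
// defaulting to 4 when no explicit component count has been registered.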
1971uint32_t CompilerMSL::get_target_components_for_fragment_location(uint32_t location) const
1972{
1973 auto itr = fragment_output_components.find(location);
1974 if (itr == end(fragment_output_components))
1975 return 4;
1976 else
1977 return itr->second;
1978}
1979
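// Builds a variant of the given type widened to 'components' (optionally with a
// new base type), re-wrapping any array and pointer levels around the new vector.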
1980uint32_t CompilerMSL::build_extended_vector_type(uint32_t type_id, uint32_t components, SPIRType::BaseType basetype)
1981{
1982 uint32_t new_type_id = ir.increase_bound_by(1);
1983 auto &old_type = get<SPIRType>(type_id);
1984 auto *type = &set<SPIRType>(new_type_id, old_type);
1985 type->vecsize = components;
1986 if (basetype != SPIRType::Unknown)
1987 type->basetype = basetype;
1988 type->self = new_type_id;
1989 type->parent_type = type_id;
1990 type->array.clear();
1991 type->array_size_literal.clear();
1992 type->pointer = false;
1993
1994 if (is_array(old_type))
1995 {
1996 uint32_t array_type_id = ir.increase_bound_by(1);
1997 type = &set<SPIRType>(array_type_id, *type);
1998 type->parent_type = new_type_id;
1999 type->array = old_type.array;
2000 type->array_size_literal = old_type.array_size_literal;
2001 new_type_id = array_type_id;
2002 }
2003
2004 if (old_type.pointer)
2005 {
2006 uint32_t ptr_type_id = ir.increase_bound_by(1);
2007 type = &set<SPIRType>(ptr_type_id, *type);
2008 type->self = new_type_id;
2009 type->parent_type = new_type_id;
2010 type->storage = old_type.storage;
2011 type->pointer = true;
2012 type->pointer_depth++;
2013 new_type_id = ptr_type_id;
2014 }
2015
2016 return new_type_id;
2017}
2018
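// Wraps the given type for Metal's pull-model interpolation. The member is later
// emitted as (roughly) interpolant<T, interpolation::perspective> or
// interpolant<T, interpolation::no_perspective>, depending on the decoration.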
2019uint32_t CompilerMSL::build_msl_interpolant_type(uint32_t type_id, bool is_noperspective)
2020{
2021 uint32_t new_type_id = ir.increase_bound_by(1);
2022 SPIRType &type = set<SPIRType>(new_type_id, get<SPIRType>(type_id));
2023 type.basetype = SPIRType::Interpolant;
2024 type.parent_type = type_id;
2025 // In Metal, the pull-model interpolant type encodes perspective-vs-no-perspective in the type itself.
2026 // Add this decoration so we know which argument to pass to the template.
2027 if (is_noperspective)
2028 set_decoration(new_type_id, DecorationNoPerspective);
2029 return new_type_id;
2030}
2031
2032bool CompilerMSL::add_component_variable_to_interface_block(spv::StorageClass storage, const std::string &ib_var_ref,
2033 SPIRVariable &var,
2034 const SPIRType &type,
2035 InterfaceBlockMeta &meta)
2036{
2037 // Deal with Component decorations.
2038 const InterfaceBlockMeta::LocationMeta *location_meta = nullptr;
2039 uint32_t location = ~0u;
2040 if (has_decoration(var.self, DecorationLocation))
2041 {
2042 location = get_decoration(var.self, DecorationLocation);
2043 auto location_meta_itr = meta.location_meta.find(location);
2044 if (location_meta_itr != end(meta.location_meta))
2045 location_meta = &location_meta_itr->second;
2046 }
2047
2048 // Check if we need to pad fragment output to match a certain number of components.
2049 if (location_meta)
2050 {
2051 bool pad_fragment_output = has_decoration(var.self, DecorationLocation) &&
2052 msl_options.pad_fragment_output_components &&
2053 get_entry_point().model == ExecutionModelFragment && storage == StorageClassOutput;
2054
2055 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
2056 uint32_t start_component = get_decoration(var.self, DecorationComponent);
2057 uint32_t type_components = type.vecsize;
2058 uint32_t num_components = location_meta->num_components;
2059
2060 if (pad_fragment_output)
2061 {
2062 uint32_t locn = get_decoration(var.self, DecorationLocation);
2063 num_components = std::max(num_components, get_target_components_for_fragment_location(locn));
2064 }
2065
2066 // We have already declared an IO block member as m_location_N.
2067 // Just emit an early-declared variable and fixup as needed.
2068 // Arrays need to be unrolled here since each location might need a different number of components.
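		// For example, an input fixup emitted below reads roughly like:
		//   foo = in.m_location_3.yz;
		// (placeholder names) with one such statement per element when arrayed.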
2069 entry_func.add_local_variable(var.self);
2070 vars_needing_early_declaration.push_back(var.self);
2071
2072 if (var.storage == StorageClassInput)
2073 {
2074 entry_func.fixup_hooks_in.push_back([=, &type, &var]() {
2075 if (!type.array.empty())
2076 {
2077 uint32_t array_size = to_array_size_literal(type);
2078 for (uint32_t loc_off = 0; loc_off < array_size; loc_off++)
2079 {
2080 statement(to_name(var.self), "[", loc_off, "]", " = ", ib_var_ref,
2081 ".m_location_", location + loc_off,
2082 vector_swizzle(type_components, start_component), ";");
2083 }
2084 }
2085 else
2086 {
2087 statement(to_name(var.self), " = ", ib_var_ref, ".m_location_", location,
2088 vector_swizzle(type_components, start_component), ";");
2089 }
2090 });
2091 }
2092 else
2093 {
2094 entry_func.fixup_hooks_out.push_back([=, &type, &var]() {
2095 if (!type.array.empty())
2096 {
2097 uint32_t array_size = to_array_size_literal(type);
2098 for (uint32_t loc_off = 0; loc_off < array_size; loc_off++)
2099 {
2100 statement(ib_var_ref, ".m_location_", location + loc_off,
2101 vector_swizzle(type_components, start_component), " = ",
2102 to_name(var.self), "[", loc_off, "];");
2103 }
2104 }
2105 else
2106 {
2107 statement(ib_var_ref, ".m_location_", location,
2108 vector_swizzle(type_components, start_component), " = ", to_name(var.self), ";");
2109 }
2110 });
2111 }
2112 return true;
2113 }
2114 else
2115 return false;
2116}
2117
2118void CompilerMSL::add_plain_variable_to_interface_block(StorageClass storage, const string &ib_var_ref,
2119 SPIRType &ib_type, SPIRVariable &var, InterfaceBlockMeta &meta)
2120{
2121 bool is_builtin = is_builtin_variable(var);
2122 BuiltIn builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn));
2123 bool is_flat = has_decoration(var.self, DecorationFlat);
2124 bool is_noperspective = has_decoration(var.self, DecorationNoPerspective);
2125 bool is_centroid = has_decoration(var.self, DecorationCentroid);
2126 bool is_sample = has_decoration(var.self, DecorationSample);
2127
2128 // Add a reference to the variable type to the interface struct.
2129 uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
2130 uint32_t type_id = ensure_correct_builtin_type(var.basetype, builtin);
2131 var.basetype = type_id;
2132
2133 type_id = get_pointee_type_id(var.basetype);
2134 if (meta.strip_array && is_array(get<SPIRType>(type_id)))
2135 type_id = get<SPIRType>(type_id).parent_type;
2136 auto &type = get<SPIRType>(type_id);
2137 uint32_t target_components = 0;
2138 uint32_t type_components = type.vecsize;
2139
2140 bool padded_output = false;
2141 bool padded_input = false;
2142 uint32_t start_component = 0;
2143
2144 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
2145
2146 if (add_component_variable_to_interface_block(storage, ib_var_ref, var, type, meta))
2147 return;
2148
2149 bool pad_fragment_output = has_decoration(var.self, DecorationLocation) &&
2150 msl_options.pad_fragment_output_components &&
2151 get_entry_point().model == ExecutionModelFragment && storage == StorageClassOutput;
2152
2153 if (pad_fragment_output)
2154 {
2155 uint32_t locn = get_decoration(var.self, DecorationLocation);
2156 target_components = get_target_components_for_fragment_location(locn);
2157 if (type_components < target_components)
2158 {
2159 // Make a new type here.
2160 type_id = build_extended_vector_type(type_id, target_components);
2161 padded_output = true;
2162 }
2163 }
2164
2165 if (storage == StorageClassInput && pull_model_inputs.count(var.self))
2166 ib_type.member_types.push_back(build_msl_interpolant_type(type_id, is_noperspective));
2167 else
2168 ib_type.member_types.push_back(type_id);
2169
2170 // Give the member a name
2171 string mbr_name = ensure_valid_name(to_expression(var.self), "m");
2172 set_member_name(ib_type.self, ib_mbr_idx, mbr_name);
2173
2174 // Update the original variable reference to include the structure reference
2175 string qual_var_name = ib_var_ref + "." + mbr_name;
2176 // If using pull-model interpolation, need to add a call to the correct interpolation method.
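	// e.g. (placeholder names) "in.m_foo.interpolate_at_sample(gl_SampleID)" for sample interpolation.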
2177 if (storage == StorageClassInput && pull_model_inputs.count(var.self))
2178 {
2179 if (is_centroid)
2180 qual_var_name += ".interpolate_at_centroid()";
2181 else if (is_sample)
2182 qual_var_name += join(".interpolate_at_sample(", to_expression(builtin_sample_id_id), ")");
2183 else
2184 qual_var_name += ".interpolate_at_center()";
2185 }
2186
2187 if (padded_output || padded_input)
2188 {
2189 entry_func.add_local_variable(var.self);
2190 vars_needing_early_declaration.push_back(var.self);
2191
2192 if (padded_output)
2193 {
2194 entry_func.fixup_hooks_out.push_back([=, &var]() {
2195 statement(qual_var_name, vector_swizzle(type_components, start_component), " = ", to_name(var.self),
2196 ";");
2197 });
2198 }
2199 else
2200 {
2201 entry_func.fixup_hooks_in.push_back([=, &var]() {
2202 statement(to_name(var.self), " = ", qual_var_name, vector_swizzle(type_components, start_component),
2203 ";");
2204 });
2205 }
2206 }
2207 else if (!meta.strip_array)
2208 ir.meta[var.self].decoration.qualified_alias = qual_var_name;
2209
2210 if (var.storage == StorageClassOutput && var.initializer != ID(0))
2211 {
2212 if (padded_output || padded_input)
2213 {
2214 entry_func.fixup_hooks_in.push_back(
2215 [=, &var]() { statement(to_name(var.self), " = ", to_expression(var.initializer), ";"); });
2216 }
2217 else
2218 {
2219 if (meta.strip_array)
2220 {
2221 entry_func.fixup_hooks_in.push_back([=, &var]() {
2222 uint32_t index = get_extended_decoration(var.self, SPIRVCrossDecorationInterfaceMemberIndex);
2223 auto invocation = to_tesc_invocation_id();
2224 statement(to_expression(stage_out_ptr_var_id), "[",
2225 invocation, "].",
2226 to_member_name(ib_type, index), " = ", to_expression(var.initializer), "[",
2227 invocation, "];");
2228 });
2229 }
2230 else
2231 {
2232 entry_func.fixup_hooks_in.push_back([=, &var]() {
2233 statement(qual_var_name, " = ", to_expression(var.initializer), ";");
2234 });
2235 }
2236 }
2237 }
2238
2239 // Copy the variable location from the original variable to the member
2240 if (get_decoration_bitset(var.self).get(DecorationLocation))
2241 {
2242 uint32_t locn = get_decoration(var.self, DecorationLocation);
2243 uint32_t comp = get_decoration(var.self, DecorationComponent);
2244 if (storage == StorageClassInput)
2245 {
2246 type_id = ensure_correct_input_type(var.basetype, locn, comp, 0, meta.strip_array);
2247 var.basetype = type_id;
2248
2249 type_id = get_pointee_type_id(type_id);
2250 if (meta.strip_array && is_array(get<SPIRType>(type_id)))
2251 type_id = get<SPIRType>(type_id).parent_type;
2252 if (pull_model_inputs.count(var.self))
2253 ib_type.member_types[ib_mbr_idx] = build_msl_interpolant_type(type_id, is_noperspective);
2254 else
2255 ib_type.member_types[ib_mbr_idx] = type_id;
2256 }
2257 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2258 if (comp)
2259 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationComponent, comp);
2260 mark_location_as_used_by_shader(locn, get<SPIRType>(type_id), storage);
2261 }
2262 else if (is_builtin && is_tessellation_shader() && storage == StorageClassInput && inputs_by_builtin.count(builtin))
2263 {
2264 uint32_t locn = inputs_by_builtin[builtin].location;
2265 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2266 mark_location_as_used_by_shader(locn, type, storage);
2267 }
2268
2269 if (get_decoration_bitset(var.self).get(DecorationComponent))
2270 {
2271 uint32_t component = get_decoration(var.self, DecorationComponent);
2272 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationComponent, component);
2273 }
2274
2275 if (get_decoration_bitset(var.self).get(DecorationIndex))
2276 {
2277 uint32_t index = get_decoration(var.self, DecorationIndex);
2278 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationIndex, index);
2279 }
2280
2281 // Mark the member as builtin if needed
2282 if (is_builtin)
2283 {
2284 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationBuiltIn, builtin);
2285 if (builtin == BuiltInPosition && storage == StorageClassOutput)
2286 qual_pos_var_name = qual_var_name;
2287 }
2288
2289 // Copy interpolation decorations if needed
2290 if (storage != StorageClassInput || !pull_model_inputs.count(var.self))
2291 {
2292 if (is_flat)
2293 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationFlat);
2294 if (is_noperspective)
2295 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationNoPerspective);
2296 if (is_centroid)
2297 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationCentroid);
2298 if (is_sample)
2299 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationSample);
2300 }
2301
2302 set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceOrigID, var.self);
2303}
2304
2305void CompilerMSL::add_composite_variable_to_interface_block(StorageClass storage, const string &ib_var_ref,
2306 SPIRType &ib_type, SPIRVariable &var,
2307 InterfaceBlockMeta &meta)
2308{
2309 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
2310 auto &var_type = meta.strip_array ? get_variable_element_type(var) : get_variable_data_type(var);
2311 uint32_t elem_cnt = 0;
2312
2313 if (add_component_variable_to_interface_block(storage, ib_var_ref, var, var_type, meta))
2314 return;
2315
2316 if (is_matrix(var_type))
2317 {
2318 if (is_array(var_type))
2319 SPIRV_CROSS_THROW("MSL cannot emit arrays-of-matrices in input and output variables.");
2320
2321 elem_cnt = var_type.columns;
2322 }
2323 else if (is_array(var_type))
2324 {
2325 if (var_type.array.size() != 1)
2326 SPIRV_CROSS_THROW("MSL cannot emit arrays-of-arrays in input and output variables.");
2327
2328 elem_cnt = to_array_size_literal(var_type);
2329 }
2330
2331 bool is_builtin = is_builtin_variable(var);
2332 BuiltIn builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn));
2333 bool is_flat = has_decoration(var.self, DecorationFlat);
2334 bool is_noperspective = has_decoration(var.self, DecorationNoPerspective);
2335 bool is_centroid = has_decoration(var.self, DecorationCentroid);
2336 bool is_sample = has_decoration(var.self, DecorationSample);
2337
2338 auto *usable_type = &var_type;
2339 if (usable_type->pointer)
2340 usable_type = &get<SPIRType>(usable_type->parent_type);
2341 while (is_array(*usable_type) || is_matrix(*usable_type))
2342 usable_type = &get<SPIRType>(usable_type->parent_type);
2343
2344 // If a builtin, force it to have the proper name.
2345 if (is_builtin)
2346 set_name(var.self, builtin_to_glsl(builtin, StorageClassFunction));
2347
2348 bool flatten_from_ib_var = false;
2349 string flatten_from_ib_mbr_name;
2350
2351 if (storage == StorageClassOutput && is_builtin && builtin == BuiltInClipDistance)
2352 {
2353 // Also declare [[clip_distance]] attribute here.
2354 uint32_t clip_array_mbr_idx = uint32_t(ib_type.member_types.size());
2355 ib_type.member_types.push_back(get_variable_data_type_id(var));
2356 set_member_decoration(ib_type.self, clip_array_mbr_idx, DecorationBuiltIn, BuiltInClipDistance);
2357
2358 flatten_from_ib_mbr_name = builtin_to_glsl(BuiltInClipDistance, StorageClassOutput);
2359 set_member_name(ib_type.self, clip_array_mbr_idx, flatten_from_ib_mbr_name);
2360
2361 // When we flatten, we flatten directly from the "out" struct,
2362 // not from a function variable.
2363 flatten_from_ib_var = true;
2364
2365 if (!msl_options.enable_clip_distance_user_varying)
2366 return;
2367 }
2368 else if (!meta.strip_array)
2369 {
2370 // Only flatten/unflatten IO composites for non-tessellation cases where arrays are not stripped.
2371 entry_func.add_local_variable(var.self);
2372 // We need to declare the variable early and at entry-point scope.
2373 vars_needing_early_declaration.push_back(var.self);
2374 }
2375
2376 for (uint32_t i = 0; i < elem_cnt; i++)
2377 {
2378 // Add a reference to the variable type to the interface struct.
2379 uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
2380
2381 uint32_t target_components = 0;
2382 bool padded_output = false;
2383 uint32_t type_id = usable_type->self;
2384
2385 // Check if we need to pad fragment output to match a certain number of components.
2386 if (get_decoration_bitset(var.self).get(DecorationLocation) && msl_options.pad_fragment_output_components &&
2387 get_entry_point().model == ExecutionModelFragment && storage == StorageClassOutput)
2388 {
2389 uint32_t locn = get_decoration(var.self, DecorationLocation) + i;
2390 target_components = get_target_components_for_fragment_location(locn);
2391 if (usable_type->vecsize < target_components)
2392 {
2393 // Make a new type here.
2394 type_id = build_extended_vector_type(usable_type->self, target_components);
2395 padded_output = true;
2396 }
2397 }
2398
2399 if (storage == StorageClassInput && pull_model_inputs.count(var.self))
2400 ib_type.member_types.push_back(build_msl_interpolant_type(get_pointee_type_id(type_id), is_noperspective));
2401 else
2402 ib_type.member_types.push_back(get_pointee_type_id(type_id));
2403
2404 // Give the member a name
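		// (e.g. element 1 of a variable "foo" becomes an interface member named "foo_1")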
2405 string mbr_name = ensure_valid_name(join(to_expression(var.self), "_", i), "m");
2406 set_member_name(ib_type.self, ib_mbr_idx, mbr_name);
2407
2408 // There is no qualified alias since we need to flatten the internal array on return.
2409 if (get_decoration_bitset(var.self).get(DecorationLocation))
2410 {
2411 uint32_t locn = get_decoration(var.self, DecorationLocation) + i;
2412 uint32_t comp = get_decoration(var.self, DecorationComponent);
2413 if (storage == StorageClassInput)
2414 {
2415 var.basetype = ensure_correct_input_type(var.basetype, locn, comp, 0, meta.strip_array);
2416 uint32_t mbr_type_id = ensure_correct_input_type(usable_type->self, locn, comp, 0, meta.strip_array);
2417 if (storage == StorageClassInput && pull_model_inputs.count(var.self))
2418 ib_type.member_types[ib_mbr_idx] = build_msl_interpolant_type(mbr_type_id, is_noperspective);
2419 else
2420 ib_type.member_types[ib_mbr_idx] = mbr_type_id;
2421 }
2422 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2423 if (comp)
2424 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationComponent, comp);
2425 mark_location_as_used_by_shader(locn, *usable_type, storage);
2426 }
2427 else if (is_builtin && is_tessellation_shader() && storage == StorageClassInput && inputs_by_builtin.count(builtin))
2428 {
2429 uint32_t locn = inputs_by_builtin[builtin].location + i;
2430 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2431 mark_location_as_used_by_shader(locn, *usable_type, storage);
2432 }
2433 else if (is_builtin && (builtin == BuiltInClipDistance || builtin == BuiltInCullDistance))
2434 {
2435 // Declare the Clip/CullDistance as [[user(clip/cullN)]].
2436 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationBuiltIn, builtin);
2437 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationIndex, i);
2438 }
2439
2440 if (get_decoration_bitset(var.self).get(DecorationIndex))
2441 {
2442 uint32_t index = get_decoration(var.self, DecorationIndex);
2443 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationIndex, index);
2444 }
2445
2446 if (storage != StorageClassInput || !pull_model_inputs.count(var.self))
2447 {
2448 // Copy interpolation decorations if needed
2449 if (is_flat)
2450 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationFlat);
2451 if (is_noperspective)
2452 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationNoPerspective);
2453 if (is_centroid)
2454 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationCentroid);
2455 if (is_sample)
2456 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationSample);
2457 }
2458
2459 set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceOrigID, var.self);
2460
2461 // Only flatten/unflatten IO composites for non-tessellation cases where arrays are not stripped.
2462 if (!meta.strip_array)
2463 {
2464 switch (storage)
2465 {
2466 case StorageClassInput:
2467 entry_func.fixup_hooks_in.push_back([=, &var]() {
2468 if (pull_model_inputs.count(var.self))
2469 {
2470 string lerp_call;
2471 if (is_centroid)
2472 lerp_call = ".interpolate_at_centroid()";
2473 else if (is_sample)
2474 lerp_call = join(".interpolate_at_sample(", to_expression(builtin_sample_id_id), ")");
2475 else
2476 lerp_call = ".interpolate_at_center()";
2477 statement(to_name(var.self), "[", i, "] = ", ib_var_ref, ".", mbr_name, lerp_call, ";");
2478 }
2479 else
2480 {
2481 statement(to_name(var.self), "[", i, "] = ", ib_var_ref, ".", mbr_name, ";");
2482 }
2483 });
2484 break;
2485
2486 case StorageClassOutput:
2487 entry_func.fixup_hooks_out.push_back([=, &var]() {
2488 if (padded_output)
2489 {
2490 auto &padded_type = this->get<SPIRType>(type_id);
2491 statement(
2492 ib_var_ref, ".", mbr_name, " = ",
2493 remap_swizzle(padded_type, usable_type->vecsize, join(to_name(var.self), "[", i, "]")),
2494 ";");
2495 }
2496 else if (flatten_from_ib_var)
2497 statement(ib_var_ref, ".", mbr_name, " = ", ib_var_ref, ".", flatten_from_ib_mbr_name, "[", i,
2498 "];");
2499 else
2500 statement(ib_var_ref, ".", mbr_name, " = ", to_name(var.self), "[", i, "];");
2501 });
2502 break;
2503
2504 default:
2505 break;
2506 }
2507 }
2508 }
2509}
2510
2511void CompilerMSL::add_composite_member_variable_to_interface_block(StorageClass storage, const string &ib_var_ref,
2512 SPIRType &ib_type, SPIRVariable &var,
2513 uint32_t mbr_idx, InterfaceBlockMeta &meta)
2514{
2515 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
2516 auto &var_type = meta.strip_array ? get_variable_element_type(var) : get_variable_data_type(var);
2517
2518 BuiltIn builtin = BuiltInMax;
2519 bool is_builtin = is_member_builtin(var_type, mbr_idx, &builtin);
2520 bool is_flat =
2521 has_member_decoration(var_type.self, mbr_idx, DecorationFlat) || has_decoration(var.self, DecorationFlat);
2522 bool is_noperspective = has_member_decoration(var_type.self, mbr_idx, DecorationNoPerspective) ||
2523 has_decoration(var.self, DecorationNoPerspective);
2524 bool is_centroid = has_member_decoration(var_type.self, mbr_idx, DecorationCentroid) ||
2525 has_decoration(var.self, DecorationCentroid);
2526 bool is_sample =
2527 has_member_decoration(var_type.self, mbr_idx, DecorationSample) || has_decoration(var.self, DecorationSample);
2528
2529 uint32_t mbr_type_id = var_type.member_types[mbr_idx];
2530 auto &mbr_type = get<SPIRType>(mbr_type_id);
2531 uint32_t elem_cnt = 0;
2532
2533 if (is_matrix(mbr_type))
2534 {
2535 if (is_array(mbr_type))
2536 SPIRV_CROSS_THROW("MSL cannot emit arrays-of-matrices in input and output variables.");
2537
2538 elem_cnt = mbr_type.columns;
2539 }
2540 else if (is_array(mbr_type))
2541 {
2542 if (mbr_type.array.size() != 1)
2543 SPIRV_CROSS_THROW("MSL cannot emit arrays-of-arrays in input and output variables.");
2544
2545 elem_cnt = to_array_size_literal(mbr_type);
2546 }
2547
2548 auto *usable_type = &mbr_type;
2549 if (usable_type->pointer)
2550 usable_type = &get<SPIRType>(usable_type->parent_type);
2551 while (is_array(*usable_type) || is_matrix(*usable_type))
2552 usable_type = &get<SPIRType>(usable_type->parent_type);
2553
2554 bool flatten_from_ib_var = false;
2555 string flatten_from_ib_mbr_name;
2556
2557 if (storage == StorageClassOutput && is_builtin && builtin == BuiltInClipDistance)
2558 {
2559 // Also declare [[clip_distance]] attribute here.
2560 uint32_t clip_array_mbr_idx = uint32_t(ib_type.member_types.size());
2561 ib_type.member_types.push_back(mbr_type_id);
2562 set_member_decoration(ib_type.self, clip_array_mbr_idx, DecorationBuiltIn, BuiltInClipDistance);
2563
2564 flatten_from_ib_mbr_name = builtin_to_glsl(BuiltInClipDistance, StorageClassOutput);
2565 set_member_name(ib_type.self, clip_array_mbr_idx, flatten_from_ib_mbr_name);
2566
2567 // When we flatten, we flatten directly from the "out" struct,
2568 // not from a function variable.
2569 flatten_from_ib_var = true;
2570
2571 if (!msl_options.enable_clip_distance_user_varying)
2572 return;
2573 }
2574
2575 for (uint32_t i = 0; i < elem_cnt; i++)
2576 {
2577 // Add a reference to the variable type to the interface struct.
2578 uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
2579 if (storage == StorageClassInput && pull_model_inputs.count(var.self))
2580 ib_type.member_types.push_back(build_msl_interpolant_type(usable_type->self, is_noperspective));
2581 else
2582 ib_type.member_types.push_back(usable_type->self);
2583
2584 // Give the member a name
2585 string mbr_name = ensure_valid_name(join(to_qualified_member_name(var_type, mbr_idx), "_", i), "m");
2586 set_member_name(ib_type.self, ib_mbr_idx, mbr_name);
2587
2588 if (has_member_decoration(var_type.self, mbr_idx, DecorationLocation))
2589 {
2590 uint32_t locn = get_member_decoration(var_type.self, mbr_idx, DecorationLocation) + i;
2591 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2592 mark_location_as_used_by_shader(locn, *usable_type, storage);
2593 }
2594 else if (has_decoration(var.self, DecorationLocation))
2595 {
2596 uint32_t locn = get_accumulated_member_location(var, mbr_idx, meta.strip_array) + i;
2597 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2598 mark_location_as_used_by_shader(locn, *usable_type, storage);
2599 }
2600 else if (is_builtin && is_tessellation_shader() && storage == StorageClassInput && inputs_by_builtin.count(builtin))
2601 {
2602 uint32_t locn = inputs_by_builtin[builtin].location + i;
2603 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2604 mark_location_as_used_by_shader(locn, *usable_type, storage);
2605 }
2606 else if (is_builtin && (builtin == BuiltInClipDistance || builtin == BuiltInCullDistance))
2607 {
2608 // Declare the Clip/CullDistance as [[user(clip/cullN)]].
2609 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationBuiltIn, builtin);
2610 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationIndex, i);
2611 }
2612
2613 if (has_member_decoration(var_type.self, mbr_idx, DecorationComponent))
			SPIRV_CROSS_THROW("DecorationComponent on matrices and arrays makes little sense.");
2615
2616 if (storage != StorageClassInput || !pull_model_inputs.count(var.self))
2617 {
2618 // Copy interpolation decorations if needed
2619 if (is_flat)
2620 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationFlat);
2621 if (is_noperspective)
2622 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationNoPerspective);
2623 if (is_centroid)
2624 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationCentroid);
2625 if (is_sample)
2626 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationSample);
2627 }
2628
2629 set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceOrigID, var.self);
2630 set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceMemberIndex, mbr_idx);
2631
2632 // Unflatten or flatten from [[stage_in]] or [[stage_out]] as appropriate.
2633 if (!meta.strip_array && meta.allow_local_declaration)
2634 {
2635 switch (storage)
2636 {
2637 case StorageClassInput:
2638 entry_func.fixup_hooks_in.push_back([=, &var, &var_type]() {
2639 if (pull_model_inputs.count(var.self))
2640 {
2641 string lerp_call;
2642 if (is_centroid)
2643 lerp_call = ".interpolate_at_centroid()";
2644 else if (is_sample)
2645 lerp_call = join(".interpolate_at_sample(", to_expression(builtin_sample_id_id), ")");
2646 else
2647 lerp_call = ".interpolate_at_center()";
2648 statement(to_name(var.self), ".", to_member_name(var_type, mbr_idx), "[", i, "] = ", ib_var_ref,
2649 ".", mbr_name, lerp_call, ";");
2650 }
2651 else
2652 {
2653 statement(to_name(var.self), ".", to_member_name(var_type, mbr_idx), "[", i, "] = ", ib_var_ref,
2654 ".", mbr_name, ";");
2655 }
2656 });
2657 break;
2658
2659 case StorageClassOutput:
2660 entry_func.fixup_hooks_out.push_back([=, &var, &var_type]() {
2661 if (flatten_from_ib_var)
2662 {
2663 statement(ib_var_ref, ".", mbr_name, " = ", ib_var_ref, ".", flatten_from_ib_mbr_name, "[", i,
2664 "];");
2665 }
2666 else
2667 {
2668 statement(ib_var_ref, ".", mbr_name, " = ", to_name(var.self), ".",
2669 to_member_name(var_type, mbr_idx), "[", i, "];");
2670 }
2671 });
2672 break;
2673
2674 default:
2675 break;
2676 }
2677 }
2678 }
2679}
2680
2681void CompilerMSL::add_plain_member_variable_to_interface_block(StorageClass storage, const string &ib_var_ref,
2682 SPIRType &ib_type, SPIRVariable &var, uint32_t mbr_idx,
2683 InterfaceBlockMeta &meta)
2684{
2685 auto &var_type = meta.strip_array ? get_variable_element_type(var) : get_variable_data_type(var);
2686 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
2687
2688 BuiltIn builtin = BuiltInMax;
2689 bool is_builtin = is_member_builtin(var_type, mbr_idx, &builtin);
2690 bool is_flat =
2691 has_member_decoration(var_type.self, mbr_idx, DecorationFlat) || has_decoration(var.self, DecorationFlat);
2692 bool is_noperspective = has_member_decoration(var_type.self, mbr_idx, DecorationNoPerspective) ||
2693 has_decoration(var.self, DecorationNoPerspective);
2694 bool is_centroid = has_member_decoration(var_type.self, mbr_idx, DecorationCentroid) ||
2695 has_decoration(var.self, DecorationCentroid);
2696 bool is_sample =
2697 has_member_decoration(var_type.self, mbr_idx, DecorationSample) || has_decoration(var.self, DecorationSample);
2698
2699 // Add a reference to the member to the interface struct.
2700 uint32_t mbr_type_id = var_type.member_types[mbr_idx];
2701 uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
2702 mbr_type_id = ensure_correct_builtin_type(mbr_type_id, builtin);
2703 var_type.member_types[mbr_idx] = mbr_type_id;
2704 if (storage == StorageClassInput && pull_model_inputs.count(var.self))
2705 ib_type.member_types.push_back(build_msl_interpolant_type(mbr_type_id, is_noperspective));
2706 else
2707 ib_type.member_types.push_back(mbr_type_id);
2708
2709 // Give the member a name
2710 string mbr_name = ensure_valid_name(to_qualified_member_name(var_type, mbr_idx), "m");
2711 set_member_name(ib_type.self, ib_mbr_idx, mbr_name);
2712
2713 // Update the original variable reference to include the structure reference
2714 string qual_var_name = ib_var_ref + "." + mbr_name;
2715 // If using pull-model interpolation, need to add a call to the correct interpolation method.
2716 if (storage == StorageClassInput && pull_model_inputs.count(var.self))
2717 {
2718 if (is_centroid)
2719 qual_var_name += ".interpolate_at_centroid()";
2720 else if (is_sample)
2721 qual_var_name += join(".interpolate_at_sample(", to_expression(builtin_sample_id_id), ")");
2722 else
2723 qual_var_name += ".interpolate_at_center()";
2724 }
2725
2726 bool flatten_stage_out = false;
2727
2728 if (is_builtin && !meta.strip_array)
2729 {
		// For the builtin gl_PerVertex, we cannot treat it as a block anyway,
		// so redirect to the qualified name.
2732 set_member_qualified_name(var_type.self, mbr_idx, qual_var_name);
2733 }
2734 else if (!meta.strip_array && meta.allow_local_declaration)
2735 {
2736 // Unflatten or flatten from [[stage_in]] or [[stage_out]] as appropriate.
2737 switch (storage)
2738 {
2739 case StorageClassInput:
2740 entry_func.fixup_hooks_in.push_back([=, &var, &var_type]() {
2741 statement(to_name(var.self), ".", to_member_name(var_type, mbr_idx), " = ", qual_var_name, ";");
2742 });
2743 break;
2744
2745 case StorageClassOutput:
2746 flatten_stage_out = true;
2747 entry_func.fixup_hooks_out.push_back([=, &var, &var_type]() {
2748 statement(qual_var_name, " = ", to_name(var.self), ".", to_member_name(var_type, mbr_idx), ";");
2749 });
2750 break;
2751
2752 default:
2753 break;
2754 }
2755 }
2756
2757 // Copy the variable location from the original variable to the member
2758 if (has_member_decoration(var_type.self, mbr_idx, DecorationLocation))
2759 {
2760 uint32_t locn = get_member_decoration(var_type.self, mbr_idx, DecorationLocation);
2761 uint32_t comp = get_member_decoration(var_type.self, mbr_idx, DecorationComponent);
2762 if (storage == StorageClassInput)
2763 {
2764 mbr_type_id = ensure_correct_input_type(mbr_type_id, locn, comp, 0, meta.strip_array);
2765 var_type.member_types[mbr_idx] = mbr_type_id;
2766 if (storage == StorageClassInput && pull_model_inputs.count(var.self))
2767 ib_type.member_types[ib_mbr_idx] = build_msl_interpolant_type(mbr_type_id, is_noperspective);
2768 else
2769 ib_type.member_types[ib_mbr_idx] = mbr_type_id;
2770 }
2771 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2772 mark_location_as_used_by_shader(locn, get<SPIRType>(mbr_type_id), storage);
2773 }
2774 else if (has_decoration(var.self, DecorationLocation))
2775 {
		// The block itself might have a location, in which case all members of
		// the block receive incrementing locations.
2778 uint32_t locn = get_accumulated_member_location(var, mbr_idx, meta.strip_array);
2779 if (storage == StorageClassInput)
2780 {
2781 mbr_type_id = ensure_correct_input_type(mbr_type_id, locn, 0, 0, meta.strip_array);
2782 var_type.member_types[mbr_idx] = mbr_type_id;
2783 if (storage == StorageClassInput && pull_model_inputs.count(var.self))
2784 ib_type.member_types[ib_mbr_idx] = build_msl_interpolant_type(mbr_type_id, is_noperspective);
2785 else
2786 ib_type.member_types[ib_mbr_idx] = mbr_type_id;
2787 }
2788 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2789 mark_location_as_used_by_shader(locn, get<SPIRType>(mbr_type_id), storage);
2790 }
2791 else if (is_builtin && is_tessellation_shader() && storage == StorageClassInput && inputs_by_builtin.count(builtin))
2792 {
2793 uint32_t locn = 0;
2794 auto builtin_itr = inputs_by_builtin.find(builtin);
2795 if (builtin_itr != end(inputs_by_builtin))
2796 locn = builtin_itr->second.location;
2797 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2798 mark_location_as_used_by_shader(locn, get<SPIRType>(mbr_type_id), storage);
2799 }
2800
2801 // Copy the component location, if present.
2802 if (has_member_decoration(var_type.self, mbr_idx, DecorationComponent))
2803 {
2804 uint32_t comp = get_member_decoration(var_type.self, mbr_idx, DecorationComponent);
2805 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationComponent, comp);
2806 }
2807
2808 // Mark the member as builtin if needed
2809 if (is_builtin)
2810 {
2811 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationBuiltIn, builtin);
2812 if (builtin == BuiltInPosition && storage == StorageClassOutput)
2813 qual_pos_var_name = qual_var_name;
2814 }
2815
2816 const SPIRConstant *c = nullptr;
2817 if (!flatten_stage_out && var.storage == StorageClassOutput &&
2818 var.initializer != ID(0) && (c = maybe_get<SPIRConstant>(var.initializer)))
2819 {
2820 if (meta.strip_array)
2821 {
2822 entry_func.fixup_hooks_in.push_back([=, &var]() {
2823 auto &type = this->get<SPIRType>(var.basetype);
2824 uint32_t index = get_extended_member_decoration(var.self, mbr_idx, SPIRVCrossDecorationInterfaceMemberIndex);
2825
2826 auto invocation = to_tesc_invocation_id();
2827 auto constant_chain = join(to_expression(var.initializer), "[", invocation, "]");
2828 statement(to_expression(stage_out_ptr_var_id), "[",
2829 invocation, "].",
2830 to_member_name(ib_type, index), " = ",
2831 constant_chain, ".", to_member_name(type, mbr_idx), ";");
2832 });
2833 }
2834 else
2835 {
2836 entry_func.fixup_hooks_in.push_back([=]() {
2837 statement(qual_var_name, " = ", constant_expression(
2838 this->get<SPIRConstant>(c->subconstants[mbr_idx])), ";");
2839 });
2840 }
2841 }
2842
2843 if (storage != StorageClassInput || !pull_model_inputs.count(var.self))
2844 {
2845 // Copy interpolation decorations if needed
2846 if (is_flat)
2847 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationFlat);
2848 if (is_noperspective)
2849 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationNoPerspective);
2850 if (is_centroid)
2851 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationCentroid);
2852 if (is_sample)
2853 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationSample);
2854 }
2855
2856 set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceOrigID, var.self);
2857 set_extended_member_decoration(ib_type.self, ib_mbr_idx, SPIRVCrossDecorationInterfaceMemberIndex, mbr_idx);
2858}
2859
2860// In Metal, the tessellation levels are stored as tightly packed half-precision floating point values.
2861// But, stage-in attribute offsets and strides must be multiples of four, so we can't pass the levels
2862// individually. Therefore, we must pass them as vectors. Triangles get a single float4, with the outer
2863// levels in 'xyz' and the inner level in 'w'. Quads get a float4 containing the outer levels and a
2864// float2 containing the inner levels.
2865void CompilerMSL::add_tess_level_input_to_interface_block(const std::string &ib_var_ref, SPIRType &ib_type,
2866 SPIRVariable &var)
2867{
2868 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
2869 auto &var_type = get_variable_element_type(var);
2870
2871 BuiltIn builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn));
2872
2873 // Force the variable to have the proper name.
2874 string var_name = builtin_to_glsl(builtin, StorageClassFunction);
2875 set_name(var.self, var_name);
2876
2877 // We need to declare the variable early and at entry-point scope.
2878 entry_func.add_local_variable(var.self);
2879 vars_needing_early_declaration.push_back(var.self);
2880 bool triangles = get_execution_mode_bitset().get(ExecutionModeTriangles);
2881 string mbr_name;
2882
2883 // Add a reference to the variable type to the interface struct.
2884 uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
2885
2886 const auto mark_locations = [&](const SPIRType &new_var_type) {
2887 if (get_decoration_bitset(var.self).get(DecorationLocation))
2888 {
2889 uint32_t locn = get_decoration(var.self, DecorationLocation);
2890 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2891 mark_location_as_used_by_shader(locn, new_var_type, StorageClassInput);
2892 }
2893 else if (inputs_by_builtin.count(builtin))
2894 {
2895 uint32_t locn = inputs_by_builtin[builtin].location;
2896 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, locn);
2897 mark_location_as_used_by_shader(locn, new_var_type, StorageClassInput);
2898 }
2899 };
2900
2901 if (triangles)
2902 {
2903 // Triangles are tricky, because we want only one member in the struct.
2904 mbr_name = "gl_TessLevel";
2905
2906 // If we already added the other one, we can skip this step.
2907 if (!added_builtin_tess_level)
2908 {
2909 uint32_t type_id = build_extended_vector_type(var_type.self, 4);
2910
2911 ib_type.member_types.push_back(type_id);
2912
2913 // Give the member a name
2914 set_member_name(ib_type.self, ib_mbr_idx, mbr_name);
2915
			// We cannot decorate the single member as both inner and outer builtin,
			// but the important part is that it's marked as a builtin so we can get
			// automatic attribute assignment if needed.
2918 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationBuiltIn, builtin);
2919
2920 mark_locations(var_type);
2921 added_builtin_tess_level = true;
2922 }
2923 }
2924 else
2925 {
2926 mbr_name = var_name;
2927
2928 uint32_t type_id = build_extended_vector_type(var_type.self, builtin == BuiltInTessLevelOuter ? 4 : 2);
2929
2930 uint32_t ptr_type_id = ir.increase_bound_by(1);
2931 auto &new_var_type = set<SPIRType>(ptr_type_id, get<SPIRType>(type_id));
2932 new_var_type.pointer = true;
2933 new_var_type.pointer_depth++;
2934 new_var_type.storage = StorageClassInput;
2935 new_var_type.parent_type = type_id;
2936
2937 ib_type.member_types.push_back(type_id);
2938
2939 // Give the member a name
2940 set_member_name(ib_type.self, ib_mbr_idx, mbr_name);
2941 set_member_decoration(ib_type.self, ib_mbr_idx, DecorationBuiltIn, builtin);
2942
2943 mark_locations(new_var_type);
2944 }
2945
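	// Copy the packed vector components back into the array-style GLSL builtin.
	// For triangles, assuming the stage-in reference is named "in", this reads e.g.:
	//   gl_TessLevelOuter[0] = in.gl_TessLevel.x;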
2946 if (builtin == BuiltInTessLevelOuter)
2947 {
2948 entry_func.fixup_hooks_in.push_back([=]() {
2949 statement(var_name, "[0] = ", ib_var_ref, ".", mbr_name, ".x;");
2950 statement(var_name, "[1] = ", ib_var_ref, ".", mbr_name, ".y;");
2951 statement(var_name, "[2] = ", ib_var_ref, ".", mbr_name, ".z;");
2952 if (!triangles)
2953 statement(var_name, "[3] = ", ib_var_ref, ".", mbr_name, ".w;");
2954 });
2955 }
2956 else
2957 {
2958 entry_func.fixup_hooks_in.push_back([=]() {
2959 if (triangles)
2960 {
2961 statement(var_name, "[0] = ", ib_var_ref, ".", mbr_name, ".w;");
2962 }
2963 else
2964 {
2965 statement(var_name, "[0] = ", ib_var_ref, ".", mbr_name, ".x;");
2966 statement(var_name, "[1] = ", ib_var_ref, ".", mbr_name, ".y;");
2967 }
2968 });
2969 }
2970}
2971
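// Returns whether variables in this storage class are exchanged through Metal's
// stage I/O interface ([[stage_in]] / return value) rather than through buffers,
// as happens when output is captured or multi-patch tess control inputs are used.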
2972bool CompilerMSL::variable_storage_requires_stage_io(spv::StorageClass storage) const
2973{
2974 if (storage == StorageClassOutput)
2975 return !capture_output_to_buffer;
2976 else if (storage == StorageClassInput)
2977 return !(get_execution_model() == ExecutionModelTessellationControl && msl_options.multi_patch_workgroup);
2978 else
2979 return false;
2980}
2981
2982string CompilerMSL::to_tesc_invocation_id()
2983{
2984 if (msl_options.multi_patch_workgroup)
2985 {
2986 // n.b. builtin_invocation_id_id here is the dispatch global invocation ID,
2987 // not the TC invocation ID.
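		// The resulting expression reads roughly "gl_GlobalInvocationID.x % N",
		// where N is the patch's output vertex count.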
		return join(to_expression(builtin_invocation_id_id), ".x % ", get_entry_point().output_vertices);
	}
	else
		return builtin_to_glsl(BuiltInInvocationId, StorageClassInput);
}

void CompilerMSL::emit_local_masked_variable(const SPIRVariable &masked_var, bool strip_array)
{
	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
	bool threadgroup_storage = variable_decl_is_remapped_storage(masked_var, StorageClassWorkgroup);

	if (threadgroup_storage && msl_options.multi_patch_workgroup)
	{
		// We need one threadgroup block per patch, so fake this.
		entry_func.fixup_hooks_in.push_back([this, &masked_var]() {
			auto &type = get_variable_data_type(masked_var);
			add_local_variable_name(masked_var.self);

			bool old_is_builtin = is_using_builtin_array;
			is_using_builtin_array = true;

			const uint32_t max_control_points_per_patch = 32u;
			uint32_t max_num_instances =
			    (max_control_points_per_patch + get_entry_point().output_vertices - 1u) /
			    get_entry_point().output_vertices;
			statement("threadgroup ", type_to_glsl(type), " ",
			          "spvStorage", to_name(masked_var.self), "[", max_num_instances, "]",
			          type_to_array_glsl(type), ";");

			// Assign a threadgroup slice to each PrimitiveID.
			// We assume here that workgroup size is rounded to 32,
			// since that's the maximum number of control points per patch.
			// We cannot size the array based on fixed dispatch parameters,
			// since Metal does not allow that. :(
			// FIXME: We will likely need an option to support passing down target workgroup size,
			// so we can emit appropriate size here.
			statement("threadgroup ", type_to_glsl(type), " ",
			          "(&", to_name(masked_var.self), ")",
			          type_to_array_glsl(type), " = spvStorage", to_name(masked_var.self), "[",
			          "(", to_expression(builtin_invocation_id_id), ".x / ",
			          get_entry_point().output_vertices, ") % ",
			          max_num_instances, "];");
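
			// For example, with output_vertices == 4 the two statements above emit MSL
			// along these lines (variable names illustrative):
			//   threadgroup float spvStorageFoo[8][4];
			//   threadgroup float (&foo)[4] = spvStorageFoo[(gl_GlobalInvocationID.x / 4) % 8];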

			is_using_builtin_array = old_is_builtin;
		});
	}
	else
	{
		entry_func.add_local_variable(masked_var.self);
	}

	if (!threadgroup_storage)
	{
		vars_needing_early_declaration.push_back(masked_var.self);
	}
	else if (masked_var.initializer)
	{
		// Cannot directly initialize threadgroup variables. Need fixup hooks.
		ID initializer = masked_var.initializer;
		if (strip_array)
		{
			entry_func.fixup_hooks_in.push_back([this, &masked_var, initializer]() {
				auto invocation = to_tesc_invocation_id();
				statement(to_expression(masked_var.self), "[",
				          invocation, "] = ",
				          to_expression(initializer), "[",
				          invocation, "];");
			});
		}
		else
		{
			entry_func.fixup_hooks_in.push_back([this, &masked_var, initializer]() {
				statement(to_expression(masked_var.self), " = ", to_expression(initializer), ";");
			});
		}
	}
}

void CompilerMSL::add_variable_to_interface_block(StorageClass storage, const string &ib_var_ref, SPIRType &ib_type,
                                                  SPIRVariable &var, InterfaceBlockMeta &meta)
{
	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
	// Tessellation control I/O variables and tessellation evaluation per-point inputs are
	// usually declared as arrays. In these cases, we want to add the element type to the
	// interface block, since in Metal it's the interface block itself which is arrayed.
	auto &var_type = meta.strip_array ? get_variable_element_type(var) : get_variable_data_type(var);
	bool is_builtin = is_builtin_variable(var);
	auto builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn));
	bool is_block = has_decoration(var_type.self, DecorationBlock);

	// If stage variables are masked out, emit them as plain variables instead.
	// For builtins, we query them one by one later.
	// IO blocks are not masked here, we need to mask them per-member instead.
	if (storage == StorageClassOutput && is_stage_output_variable_masked(var))
	{
		// If we ignore an output, we must still emit it, since it might be used by the app.
		// Instead, just emit it as an early declaration.
		emit_local_masked_variable(var, meta.strip_array);
		return;
	}

	if (var_type.basetype == SPIRType::Struct)
	{
		bool block_requires_flattening = variable_storage_requires_stage_io(storage) || is_block;
		bool needs_local_declaration = !is_builtin && block_requires_flattening && meta.allow_local_declaration;

		if (needs_local_declaration)
		{
			// For I/O blocks or structs, we will need to pass the block itself around
			// to functions if they are used globally in leaf functions.
			// Rather than passing down member by member,
			// we unflatten I/O blocks while running the shader,
			// and pass the actual struct type down to leaf functions.
			// We then unflatten inputs, and flatten outputs in the "fixup" stages.
			emit_local_masked_variable(var, meta.strip_array);
		}

		if (!block_requires_flattening)
		{
			// In Metal tessellation shaders, the interface block itself is arrayed. This makes things
			// very complicated, since stage-in structures in MSL don't support nested structures.
			// Luckily, for stage-out when capturing output, we can avoid this and just add
			// composite members directly, because the stage-out structure is stored to a buffer,
			// not returned.
			add_plain_variable_to_interface_block(storage, ib_var_ref, ib_type, var, meta);
		}
		else
		{
			bool masked_block = false;

			// Flatten the struct members into the interface struct
			for (uint32_t mbr_idx = 0; mbr_idx < uint32_t(var_type.member_types.size()); mbr_idx++)
			{
				builtin = BuiltInMax;
				is_builtin = is_member_builtin(var_type, mbr_idx, &builtin);
				auto &mbr_type = get<SPIRType>(var_type.member_types[mbr_idx]);

				if (storage == StorageClassOutput && is_stage_output_block_member_masked(var, mbr_idx, meta.strip_array))
				{
					if (is_block)
						masked_block = true;

					// Non-builtin block output variables are just ignored, since they will still access
					// the block variable as-is. They're just not flattened.
					if (is_builtin && !meta.strip_array)
					{
						// Emit a fake variable instead.
						uint32_t ids = ir.increase_bound_by(2);
						uint32_t ptr_type_id = ids + 0;
						uint32_t var_id = ids + 1;

						auto ptr_type = mbr_type;
						ptr_type.pointer = true;
						ptr_type.pointer_depth++;
						ptr_type.parent_type = var_type.member_types[mbr_idx];
						ptr_type.storage = StorageClassOutput;

						uint32_t initializer = 0;
						if (var.initializer)
							if (auto *c = maybe_get<SPIRConstant>(var.initializer))
								initializer = c->subconstants[mbr_idx];

						set<SPIRType>(ptr_type_id, ptr_type);
						set<SPIRVariable>(var_id, ptr_type_id, StorageClassOutput, initializer);
						entry_func.add_local_variable(var_id);
						vars_needing_early_declaration.push_back(var_id);
						set_name(var_id, builtin_to_glsl(builtin, StorageClassOutput));
						set_decoration(var_id, DecorationBuiltIn, builtin);
					}
				}
				else if (!is_builtin || has_active_builtin(builtin, storage))
				{
					bool is_composite_type = is_matrix(mbr_type) || is_array(mbr_type);
					bool attribute_load_store =
					    storage == StorageClassInput && get_execution_model() != ExecutionModelFragment;
					bool storage_is_stage_io = variable_storage_requires_stage_io(storage);

					// Clip/CullDistance always need to be declared as user attributes.
					if (builtin == BuiltInClipDistance || builtin == BuiltInCullDistance)
						is_builtin = false;

					if ((!is_builtin || attribute_load_store) && storage_is_stage_io && is_composite_type)
					{
						add_composite_member_variable_to_interface_block(storage, ib_var_ref, ib_type, var, mbr_idx,
						                                                 meta);
					}
					else
					{
						add_plain_member_variable_to_interface_block(storage, ib_var_ref, ib_type, var, mbr_idx, meta);
					}
				}
			}

			// If we're redirecting a block, we might still need to access the original block
			// variable if we're masking some members.
			if (masked_block && !needs_local_declaration &&
			    (!is_builtin_variable(var) || get_execution_model() == ExecutionModelTessellationControl))
			{
				if (is_builtin_variable(var))
				{
					// Ensure correct names for the block members if we're actually going to
					// declare gl_PerVertex.
					for (uint32_t mbr_idx = 0; mbr_idx < uint32_t(var_type.member_types.size()); mbr_idx++)
					{
						set_member_name(var_type.self, mbr_idx, builtin_to_glsl(
						    BuiltIn(get_member_decoration(var_type.self, mbr_idx, DecorationBuiltIn)),
						    StorageClassOutput));
					}

					set_name(var_type.self, "gl_PerVertex");
					set_name(var.self, "gl_out_masked");
					stage_out_masked_builtin_type_id = var_type.self;
				}
				emit_local_masked_variable(var, meta.strip_array);
			}
		}
	}
	else if (get_execution_model() == ExecutionModelTessellationEvaluation && storage == StorageClassInput &&
	         !meta.strip_array && is_builtin && (builtin == BuiltInTessLevelOuter || builtin == BuiltInTessLevelInner))
	{
		add_tess_level_input_to_interface_block(ib_var_ref, ib_type, var);
	}
	else if (var_type.basetype == SPIRType::Boolean || var_type.basetype == SPIRType::Char ||
	         type_is_integral(var_type) || type_is_floating_point(var_type))
	{
		if (!is_builtin || has_active_builtin(builtin, storage))
		{
			bool is_composite_type = is_matrix(var_type) || is_array(var_type);
			bool storage_is_stage_io = variable_storage_requires_stage_io(storage);
			bool attribute_load_store = storage == StorageClassInput && get_execution_model() != ExecutionModelFragment;

			// Clip/CullDistance always need to be declared as user attributes.
			if (builtin == BuiltInClipDistance || builtin == BuiltInCullDistance)
				is_builtin = false;

			// MSL does not allow matrices or arrays in input or output variables, so we need to handle them specially.
			if ((!is_builtin || attribute_load_store) && storage_is_stage_io && is_composite_type)
			{
				add_composite_variable_to_interface_block(storage, ib_var_ref, ib_type, var, meta);
			}
			else
			{
				add_plain_variable_to_interface_block(storage, ib_var_ref, ib_type, var, meta);
			}
		}
	}
}

// Fix up the mapping of variables to interface member indices, which is used to compile access chains
// for per-vertex variables in a tessellation control shader.
void CompilerMSL::fix_up_interface_member_indices(StorageClass storage, uint32_t ib_type_id)
{
	// Only needed for tessellation shaders and pull-model interpolants.
	// Need to redirect interface indices back to variables themselves.
	// For structs, each member of the struct needs a separate instance.
	if (get_execution_model() != ExecutionModelTessellationControl &&
	    !(get_execution_model() == ExecutionModelTessellationEvaluation && storage == StorageClassInput) &&
	    !(get_execution_model() == ExecutionModelFragment && storage == StorageClassInput &&
	      !pull_model_inputs.empty()))
		return;

	auto mbr_cnt = uint32_t(ir.meta[ib_type_id].members.size());
	for (uint32_t i = 0; i < mbr_cnt; i++)
	{
		uint32_t var_id = get_extended_member_decoration(ib_type_id, i, SPIRVCrossDecorationInterfaceOrigID);
		if (!var_id)
			continue;
		auto &var = get<SPIRVariable>(var_id);

		auto &type = get_variable_element_type(var);

		bool flatten_composites = variable_storage_requires_stage_io(var.storage);
		bool is_block = has_decoration(type.self, DecorationBlock);

		uint32_t mbr_idx = uint32_t(-1);
		if (type.basetype == SPIRType::Struct && (flatten_composites || is_block))
			mbr_idx = get_extended_member_decoration(ib_type_id, i, SPIRVCrossDecorationInterfaceMemberIndex);

		if (mbr_idx != uint32_t(-1))
		{
			// Only set the lowest InterfaceMemberIndex for each variable member.
			// IB struct members will be emitted in-order w.r.t. interface member index.
			if (!has_extended_member_decoration(var_id, mbr_idx, SPIRVCrossDecorationInterfaceMemberIndex))
				set_extended_member_decoration(var_id, mbr_idx, SPIRVCrossDecorationInterfaceMemberIndex, i);
		}
		else
		{
			// Only set the lowest InterfaceMemberIndex for each variable.
			// IB struct members will be emitted in-order w.r.t. interface member index.
			if (!has_extended_decoration(var_id, SPIRVCrossDecorationInterfaceMemberIndex))
				set_extended_decoration(var_id, SPIRVCrossDecorationInterfaceMemberIndex, i);
		}
	}
}

// Add an interface structure for the type of storage, which is either StorageClassInput or StorageClassOutput.
// Returns the ID of the newly added variable, or zero if no variable was added.
uint32_t CompilerMSL::add_interface_block(StorageClass storage, bool patch)
{
	// Accumulate the variables that should appear in the interface struct.
	SmallVector<SPIRVariable *> vars;
	bool incl_builtins = storage == StorageClassOutput || is_tessellation_shader();
	bool has_seen_barycentric = false;

	InterfaceBlockMeta meta;

	// Varying interfaces between stages which use "user()" attribute can be dealt with
	// without explicit packing and unpacking of components. For any variables which link against the runtime
	// in some way (vertex attributes, fragment output, etc), we'll need to deal with it somehow.
	bool pack_components =
	    (storage == StorageClassInput && get_execution_model() == ExecutionModelVertex) ||
	    (storage == StorageClassOutput && get_execution_model() == ExecutionModelFragment) ||
	    (storage == StorageClassOutput && get_execution_model() == ExecutionModelVertex && capture_output_to_buffer);

	ir.for_each_typed_id<SPIRVariable>([&](uint32_t var_id, SPIRVariable &var) {
		if (var.storage != storage)
			return;

		auto &type = this->get<SPIRType>(var.basetype);

		bool is_builtin = is_builtin_variable(var);
		bool is_block = has_decoration(type.self, DecorationBlock);

		auto bi_type = BuiltInMax;
		bool builtin_is_gl_in_out = false;
		if (is_builtin && !is_block)
		{
			bi_type = BuiltIn(get_decoration(var_id, DecorationBuiltIn));
			builtin_is_gl_in_out = bi_type == BuiltInPosition || bi_type == BuiltInPointSize ||
			                       bi_type == BuiltInClipDistance || bi_type == BuiltInCullDistance;
		}

		if (is_builtin && is_block)
			builtin_is_gl_in_out = true;

		uint32_t location = get_decoration(var_id, DecorationLocation);

		bool builtin_is_stage_in_out = builtin_is_gl_in_out ||
		                               bi_type == BuiltInLayer || bi_type == BuiltInViewportIndex ||
		                               bi_type == BuiltInBaryCoordNV || bi_type == BuiltInBaryCoordNoPerspNV ||
		                               bi_type == BuiltInFragDepth ||
		                               bi_type == BuiltInFragStencilRefEXT || bi_type == BuiltInSampleMask;

		// These builtins are part of the stage in/out structs.
		bool is_interface_block_builtin =
		    builtin_is_stage_in_out ||
		    (get_execution_model() == ExecutionModelTessellationEvaluation &&
		     (bi_type == BuiltInTessLevelOuter || bi_type == BuiltInTessLevelInner));

		bool is_active = interface_variable_exists_in_entry_point(var.self);
		if (is_builtin && is_active)
		{
			// Only emit the builtin if it's active in this entry point. Interface variable list might lie.
			if (is_block)
			{
				// If any builtin is active, the block is active.
				uint32_t mbr_cnt = uint32_t(type.member_types.size());
				for (uint32_t i = 0; !is_active && i < mbr_cnt; i++)
					is_active = has_active_builtin(BuiltIn(get_member_decoration(type.self, i, DecorationBuiltIn)), storage);
			}
			else
			{
				is_active = has_active_builtin(bi_type, storage);
			}
		}

		bool filter_patch_decoration = (has_decoration(var_id, DecorationPatch) || is_patch_block(type)) == patch;

		bool hidden = is_hidden_variable(var, incl_builtins);

		// ClipDistance is never hidden, we need to emulate it when used as an input.
		if (bi_type == BuiltInClipDistance || bi_type == BuiltInCullDistance)
			hidden = false;

		// It's not enough to simply avoid marking fragment outputs if the pipeline won't
		// accept them. We can't put them in the struct at all, or otherwise the compiler
		// complains that the outputs weren't explicitly marked.
		// Frag depth and stencil outputs are incompatible with explicit early fragment tests.
		// In GLSL, depth and stencil outputs are just ignored when explicit early fragment tests are required.
		// In Metal, it's a compilation error, so we need to exclude them from the output struct.
		if (get_execution_model() == ExecutionModelFragment && storage == StorageClassOutput && !patch &&
		    ((is_builtin && ((bi_type == BuiltInFragDepth && (!msl_options.enable_frag_depth_builtin || uses_explicit_early_fragment_test())) ||
		                     (bi_type == BuiltInFragStencilRefEXT && (!msl_options.enable_frag_stencil_ref_builtin || uses_explicit_early_fragment_test())))) ||
		     (!is_builtin && !(msl_options.enable_frag_output_mask & (1 << location)))))
		{
			hidden = true;
			disabled_frag_outputs.push_back(var_id);
			// If a builtin, force it to have the proper name, and mark it as not part of the output struct.
			if (is_builtin)
			{
				set_name(var_id, builtin_to_glsl(bi_type, StorageClassFunction));
				mask_stage_output_by_builtin(bi_type);
			}
		}

		// Barycentric inputs must be emitted in stage-in, because they can have interpolation arguments.
		if (is_active && (bi_type == BuiltInBaryCoordNV || bi_type == BuiltInBaryCoordNoPerspNV))
		{
			if (has_seen_barycentric)
				SPIRV_CROSS_THROW("Cannot declare both BaryCoordNV and BaryCoordNoPerspNV in same shader in MSL.");
			has_seen_barycentric = true;
			hidden = false;
		}

		if (is_active && !hidden && type.pointer && filter_patch_decoration &&
		    (!is_builtin || is_interface_block_builtin))
		{
			vars.push_back(&var);

			if (!is_builtin)
			{
				// Need to deal specially with DecorationComponent.
				// Multiple variables can alias the same Location; try to make sure each location is declared only once.
				// We will swizzle data in and out to make this work.
				// This is only relevant for vertex inputs and fragment outputs.
				// Technically tessellation as well, but it is too complicated to support.
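				// For example (illustrative): a float2 at location 3, component 0 and a
				// float at location 3, component 2 both land in location 3; num_components
				// below becomes max(0 + 2, 2 + 1) = 3, so the shared member is widened to
				// a 3-component vector and each variable is swizzled in and out of it.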
				uint32_t component = get_decoration(var_id, DecorationComponent);
				if (component != 0)
				{
					if (is_tessellation_shader())
						SPIRV_CROSS_THROW("Component decoration is not supported in tessellation shaders.");
					else if (pack_components)
					{
						uint32_t array_size = 1;
						if (!type.array.empty())
							array_size = to_array_size_literal(type);

						for (uint32_t location_offset = 0; location_offset < array_size; location_offset++)
						{
							auto &location_meta = meta.location_meta[location + location_offset];
							location_meta.num_components = std::max(location_meta.num_components, component + type.vecsize);

							// For variables sharing location, decorations and base type must match.
							location_meta.base_type_id = type.self;
							location_meta.flat = has_decoration(var.self, DecorationFlat);
							location_meta.noperspective = has_decoration(var.self, DecorationNoPerspective);
							location_meta.centroid = has_decoration(var.self, DecorationCentroid);
							location_meta.sample = has_decoration(var.self, DecorationSample);
						}
					}
				}
			}
		}
	});

	// If no variables qualify, leave.
	// For patch input in a tessellation evaluation shader, the per-vertex stage inputs
	// are included in a special patch control point array.
	if (vars.empty() && !(storage == StorageClassInput && patch && stage_in_var_id))
		return 0;

	// Add a new typed variable for this interface structure.
	// The initializer expression is allocated here, but populated when the function
	// declaration is emitted, because it is cleared after each compilation pass.
	uint32_t next_id = ir.increase_bound_by(3);
	uint32_t ib_type_id = next_id++;
	auto &ib_type = set<SPIRType>(ib_type_id);
	ib_type.basetype = SPIRType::Struct;
	ib_type.storage = storage;
	set_decoration(ib_type_id, DecorationBlock);

	uint32_t ib_var_id = next_id++;
	auto &var = set<SPIRVariable>(ib_var_id, ib_type_id, storage, 0);
	var.initializer = next_id++;

	string ib_var_ref;
	auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
	switch (storage)
	{
	case StorageClassInput:
		ib_var_ref = patch ? patch_stage_in_var_name : stage_in_var_name;
		if (get_execution_model() == ExecutionModelTessellationControl)
		{
			// Add a hook to populate the shared workgroup memory containing the gl_in array.
			entry_func.fixup_hooks_in.push_back([=]() {
				// Can't use PatchVertices, PrimitiveId, or InvocationId here; the hooks for those may not have run yet.
				if (msl_options.multi_patch_workgroup)
				{
					// n.b. builtin_invocation_id_id here is the dispatch global invocation ID,
					// not the TC invocation ID.
					statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "* gl_in = &",
					          input_buffer_var_name, "[min(", to_expression(builtin_invocation_id_id), ".x / ",
					          get_entry_point().output_vertices,
					          ", spvIndirectParams[1] - 1) * spvIndirectParams[0]];");
				}
				else
				{
					// It's safe to use InvocationId here because it's directly mapped to a
					// Metal builtin, and therefore doesn't need a hook.
					statement("if (", to_expression(builtin_invocation_id_id), " < spvIndirectParams[0])");
					statement("    ", input_wg_var_name, "[", to_expression(builtin_invocation_id_id),
					          "] = ", ib_var_ref, ";");
					statement("threadgroup_barrier(mem_flags::mem_threadgroup);");
					statement("if (", to_expression(builtin_invocation_id_id),
					          " >= ", get_entry_point().output_vertices, ")");
					statement("    return;");
				}
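
				// Depending on the path taken above, the emitted MSL looks roughly like
				// (names illustrative):
				//   device main0_in* gl_in = &spvIn[min(gl_GlobalInvocationID.x / 4, spvIndirectParams[1] - 1) * spvIndirectParams[0]];
				// or, for the single-patch path:
				//   if (gl_InvocationID < spvIndirectParams[0])
				//       gl_in_shared[gl_InvocationID] = in;
				//   threadgroup_barrier(mem_flags::mem_threadgroup);
				//   if (gl_InvocationID >= 4)
				//       return;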
			});
		}
		break;

	case StorageClassOutput:
	{
		ib_var_ref = patch ? patch_stage_out_var_name : stage_out_var_name;

		// Add the output interface struct as a local variable to the entry function.
		// If the entry point should return the output struct, set the entry function
		// to return the output interface struct, otherwise to return nothing.
		// Watch out for the rare case where the terminator of the last entry point block is a
		// Kill, instead of a Return. Based on SPIR-V's block-domination rules, we assume that
		// any block that has a Kill will also have a terminating Return, except the last block.
		// Indicate the output var requires early initialization.
		bool ep_should_return_output = !get_is_rasterization_disabled();
		uint32_t rtn_id = ep_should_return_output ? ib_var_id : 0;
		if (!capture_output_to_buffer)
		{
			entry_func.add_local_variable(ib_var_id);
			for (auto &blk_id : entry_func.blocks)
			{
				auto &blk = get<SPIRBlock>(blk_id);
				if (blk.terminator == SPIRBlock::Return || (blk.terminator == SPIRBlock::Kill && blk_id == entry_func.blocks.back()))
					blk.return_value = rtn_id;
			}
			vars_needing_early_declaration.push_back(ib_var_id);
		}
		else
		{
			switch (get_execution_model())
			{
			case ExecutionModelVertex:
			case ExecutionModelTessellationEvaluation:
				// Instead of declaring a struct variable to hold the output and then
				// copying that to the output buffer, we'll declare the output variable
				// as a reference to the final output element in the buffer. Then we can
				// avoid the extra copy.
				entry_func.fixup_hooks_in.push_back([=]() {
					if (stage_out_var_id)
					{
						// The first member of the indirect buffer is always the number of vertices
						// to draw.
						// We zero-base the InstanceID & VertexID variables for HLSL emulation elsewhere, so don't do it twice
						if (get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation)
						{
							statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "& ", ib_var_ref,
							          " = ", output_buffer_var_name, "[", to_expression(builtin_invocation_id_id),
							          ".y * ", to_expression(builtin_stage_input_size_id), ".x + ",
							          to_expression(builtin_invocation_id_id), ".x];");
						}
						else if (msl_options.enable_base_index_zero)
						{
							statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "& ", ib_var_ref,
							          " = ", output_buffer_var_name, "[", to_expression(builtin_instance_idx_id),
							          " * spvIndirectParams[0] + ", to_expression(builtin_vertex_idx_id), "];");
						}
						else
						{
							statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "& ", ib_var_ref,
							          " = ", output_buffer_var_name, "[(", to_expression(builtin_instance_idx_id),
							          " - ", to_expression(builtin_base_instance_id), ") * spvIndirectParams[0] + ",
							          to_expression(builtin_vertex_idx_id), " - ",
							          to_expression(builtin_base_vertex_id), "];");
						}
					}
				});
				break;
			case ExecutionModelTessellationControl:
				if (msl_options.multi_patch_workgroup)
				{
					// We cannot use PrimitiveId here, because the hook may not have run yet.
					if (patch)
					{
						entry_func.fixup_hooks_in.push_back([=]() {
							statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "& ", ib_var_ref,
							          " = ", patch_output_buffer_var_name, "[", to_expression(builtin_invocation_id_id),
							          ".x / ", get_entry_point().output_vertices, "];");
						});
					}
					else
					{
						entry_func.fixup_hooks_in.push_back([=]() {
							statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "* gl_out = &",
							          output_buffer_var_name, "[", to_expression(builtin_invocation_id_id), ".x - ",
							          to_expression(builtin_invocation_id_id), ".x % ",
							          get_entry_point().output_vertices, "];");
						});
					}
				}
				else
				{
					if (patch)
					{
						entry_func.fixup_hooks_in.push_back([=]() {
							statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "& ", ib_var_ref,
							          " = ", patch_output_buffer_var_name, "[", to_expression(builtin_primitive_id_id),
							          "];");
						});
					}
					else
					{
						entry_func.fixup_hooks_in.push_back([=]() {
							statement("device ", to_name(ir.default_entry_point), "_", ib_var_ref, "* gl_out = &",
							          output_buffer_var_name, "[", to_expression(builtin_primitive_id_id), " * ",
							          get_entry_point().output_vertices, "];");
						});
					}
				}
				break;
			default:
				break;
			}
		}
		break;
	}

	default:
		break;
	}

	set_name(ib_type_id, to_name(ir.default_entry_point) + "_" + ib_var_ref);
	set_name(ib_var_id, ib_var_ref);

	for (auto *p_var : vars)
	{
		bool strip_array =
		    (get_execution_model() == ExecutionModelTessellationControl ||
		     (get_execution_model() == ExecutionModelTessellationEvaluation && storage == StorageClassInput)) &&
		    !patch;

		// Fixing up flattened stores in TESC is impossible since the memory is group shared either via
		// device (not masked) or threadgroup (masked) storage classes and it's race condition city.
		meta.strip_array = strip_array;
		meta.allow_local_declaration = !strip_array && !(get_execution_model() == ExecutionModelTessellationControl &&
		                                                 storage == StorageClassOutput);
		add_variable_to_interface_block(storage, ib_var_ref, ib_type, *p_var, meta);
	}

	if (get_execution_model() == ExecutionModelTessellationControl && msl_options.multi_patch_workgroup &&
	    storage == StorageClassInput)
	{
		// For tessellation control inputs, add all outputs from the vertex shader to ensure
		// the struct containing them is the correct size and layout.
		for (auto &input : inputs_by_location)
		{
			if (location_inputs_in_use.count(input.first.location) != 0)
				continue;

			// Create a fake variable to put at the location.
			uint32_t offset = ir.increase_bound_by(4);
			uint32_t type_id = offset;
			uint32_t array_type_id = offset + 1;
			uint32_t ptr_type_id = offset + 2;
			uint32_t var_id = offset + 3;

			SPIRType type;
			switch (input.second.format)
			{
			case MSL_SHADER_INPUT_FORMAT_UINT16:
			case MSL_SHADER_INPUT_FORMAT_ANY16:
				type.basetype = SPIRType::UShort;
				type.width = 16;
				break;
			case MSL_SHADER_INPUT_FORMAT_ANY32:
			default:
				type.basetype = SPIRType::UInt;
				type.width = 32;
				break;
			}
			type.vecsize = input.second.vecsize;
			set<SPIRType>(type_id, type);

			type.array.push_back(0);
			type.array_size_literal.push_back(true);
			type.parent_type = type_id;
			set<SPIRType>(array_type_id, type);

			type.pointer = true;
			type.pointer_depth++;
			type.parent_type = array_type_id;
			type.storage = storage;
			auto &ptr_type = set<SPIRType>(ptr_type_id, type);
			ptr_type.self = array_type_id;

			auto &fake_var = set<SPIRVariable>(var_id, ptr_type_id, storage);
			set_decoration(var_id, DecorationLocation, input.first.location);
			if (input.first.component)
				set_decoration(var_id, DecorationComponent, input.first.component);

			meta.strip_array = true;
			meta.allow_local_declaration = false;
			add_variable_to_interface_block(storage, ib_var_ref, ib_type, fake_var, meta);
		}
	}

	// When multiple variables need to access the same location,
	// unroll locations one by one and flatten the output or input as necessary.
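	// For example (illustrative), location 3 with num_components == 3 and a float base
	// type yields a single interface member along the lines of:
	//   float3 m_location_3 [[attribute(3)]];
	// (for vertex inputs; fragment outputs get a color attribute instead), with
	// flat/centroid/etc. decorations carried over from the aliasing variables.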
	for (auto &loc : meta.location_meta)
	{
		uint32_t location = loc.first;
		auto &location_meta = loc.second;

		uint32_t ib_mbr_idx = uint32_t(ib_type.member_types.size());
		uint32_t type_id = build_extended_vector_type(location_meta.base_type_id, location_meta.num_components);
		ib_type.member_types.push_back(type_id);

		set_member_name(ib_type.self, ib_mbr_idx, join("m_location_", location));
		set_member_decoration(ib_type.self, ib_mbr_idx, DecorationLocation, location);
		mark_location_as_used_by_shader(location, get<SPIRType>(type_id), storage);

		if (location_meta.flat)
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationFlat);
		if (location_meta.noperspective)
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationNoPerspective);
		if (location_meta.centroid)
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationCentroid);
		if (location_meta.sample)
			set_member_decoration(ib_type.self, ib_mbr_idx, DecorationSample);
	}

	// Sort the members of the structure by their locations.
	MemberSorter member_sorter(ib_type, ir.meta[ib_type_id], MemberSorter::LocationThenBuiltInType);
	member_sorter.sort();

	// The member indices were saved to the original variables, but after the members
	// were sorted, those indices are now likely incorrect. Fix those up now.
	fix_up_interface_member_indices(storage, ib_type_id);

	// For patch inputs, add one more member, holding the array of control point data.
	if (get_execution_model() == ExecutionModelTessellationEvaluation && storage == StorageClassInput && patch &&
	    stage_in_var_id)
	{
		uint32_t pcp_type_id = ir.increase_bound_by(1);
		auto &pcp_type = set<SPIRType>(pcp_type_id, ib_type);
		pcp_type.basetype = SPIRType::ControlPointArray;
		pcp_type.parent_type = pcp_type.type_alias = get_stage_in_struct_type().self;
		pcp_type.storage = storage;
		ir.meta[pcp_type_id] = ir.meta[ib_type.self];
		uint32_t mbr_idx = uint32_t(ib_type.member_types.size());
		ib_type.member_types.push_back(pcp_type_id);
		set_member_name(ib_type.self, mbr_idx, "gl_in");
	}

	return ib_var_id;
}

uint32_t CompilerMSL::add_interface_block_pointer(uint32_t ib_var_id, StorageClass storage)
{
	if (!ib_var_id)
		return 0;

	uint32_t ib_ptr_var_id;
	uint32_t next_id = ir.increase_bound_by(3);
	auto &ib_type = expression_type(ib_var_id);
	if (get_execution_model() == ExecutionModelTessellationControl)
	{
		// Tessellation control per-vertex I/O is presented as an array, so we must
		// do the same with our struct here.
		uint32_t ib_ptr_type_id = next_id++;
		auto &ib_ptr_type = set<SPIRType>(ib_ptr_type_id, ib_type);
		ib_ptr_type.parent_type = ib_ptr_type.type_alias = ib_type.self;
		ib_ptr_type.pointer = true;
		ib_ptr_type.pointer_depth++;
		ib_ptr_type.storage =
		    storage == StorageClassInput ?
		        (msl_options.multi_patch_workgroup ? StorageClassStorageBuffer : StorageClassWorkgroup) :
		        StorageClassStorageBuffer;
		ir.meta[ib_ptr_type_id] = ir.meta[ib_type.self];
		// To ensure that get_variable_data_type() doesn't strip off the pointer,
		// which we need, use another pointer.
		uint32_t ib_ptr_ptr_type_id = next_id++;
		auto &ib_ptr_ptr_type = set<SPIRType>(ib_ptr_ptr_type_id, ib_ptr_type);
		ib_ptr_ptr_type.parent_type = ib_ptr_type_id;
		ib_ptr_ptr_type.type_alias = ib_type.self;
		ib_ptr_ptr_type.storage = StorageClassFunction;
		ir.meta[ib_ptr_ptr_type_id] = ir.meta[ib_type.self];

		ib_ptr_var_id = next_id;
		set<SPIRVariable>(ib_ptr_var_id, ib_ptr_ptr_type_id, StorageClassFunction, 0);
		set_name(ib_ptr_var_id, storage == StorageClassInput ? "gl_in" : "gl_out");
	}
	else
	{
		// Tessellation evaluation per-vertex inputs are also presented as arrays.
		// But, in Metal, this array uses a very special type, 'patch_control_point<T>',
		// which is a container that can be used to access the control point data.
		// To represent this, a special 'ControlPointArray' type has been added to the
		// SPIRV-Cross type system. It should only be generated by and seen in the MSL
		// backend (i.e. this one).
		uint32_t pcp_type_id = next_id++;
		auto &pcp_type = set<SPIRType>(pcp_type_id, ib_type);
		pcp_type.basetype = SPIRType::ControlPointArray;
		pcp_type.parent_type = pcp_type.type_alias = ib_type.self;
		pcp_type.storage = storage;
		ir.meta[pcp_type_id] = ir.meta[ib_type.self];

		ib_ptr_var_id = next_id;
		set<SPIRVariable>(ib_ptr_var_id, pcp_type_id, storage, 0);
		set_name(ib_ptr_var_id, "gl_in");
		ir.meta[ib_ptr_var_id].decoration.qualified_alias = join(patch_stage_in_var_name, ".gl_in");
	}
	return ib_ptr_var_id;
}

// Ensure that the type is compatible with the builtin.
// If it is, simply return the given type ID.
// Otherwise, create a new type, and return its ID.
uint32_t CompilerMSL::ensure_correct_builtin_type(uint32_t type_id, BuiltIn builtin)
{
	auto &type = get<SPIRType>(type_id);

	if ((builtin == BuiltInSampleMask && is_array(type)) ||
	    ((builtin == BuiltInLayer || builtin == BuiltInViewportIndex || builtin == BuiltInFragStencilRefEXT) &&
	     type.basetype != SPIRType::UInt))
	{
		uint32_t next_id = ir.increase_bound_by(type.pointer ? 2 : 1);
		uint32_t base_type_id = next_id++;
		auto &base_type = set<SPIRType>(base_type_id);
		base_type.basetype = SPIRType::UInt;
		base_type.width = 32;

		if (!type.pointer)
			return base_type_id;

		uint32_t ptr_type_id = next_id++;
		auto &ptr_type = set<SPIRType>(ptr_type_id);
		ptr_type = base_type;
		ptr_type.pointer = true;
		ptr_type.pointer_depth++;
		ptr_type.storage = type.storage;
		ptr_type.parent_type = base_type_id;
		return ptr_type_id;
	}

	return type_id;
}

// Ensure that the type is compatible with the shader input.
// If it is, simply return the given type ID.
// Otherwise, create a new type, and return its ID.
uint32_t CompilerMSL::ensure_correct_input_type(uint32_t type_id, uint32_t location, uint32_t component, uint32_t num_components, bool strip_array)
{
	auto &type = get<SPIRType>(type_id);

	uint32_t max_array_dimensions = strip_array ? 1 : 0;

	// Struct and array types must match exactly.
	if (type.basetype == SPIRType::Struct || type.array.size() > max_array_dimensions)
		return type_id;

	auto p_va = inputs_by_location.find({location, component});
	if (p_va == end(inputs_by_location))
	{
		if (num_components > type.vecsize)
			return build_extended_vector_type(type_id, num_components);
		else
			return type_id;
	}

	if (num_components == 0)
		num_components = p_va->second.vecsize;

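	// Example (illustrative): if the host supplies a UINT8 vertex format but the shader
	// declares int2, the UINT8 case below retypes the input to a uint2 so the data is
	// read with the correct width and signedness before the shader consumes it.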
	switch (p_va->second.format)
	{
	case MSL_SHADER_INPUT_FORMAT_UINT8:
	{
		switch (type.basetype)
		{
		case SPIRType::UByte:
		case SPIRType::UShort:
		case SPIRType::UInt:
			if (num_components > type.vecsize)
				return build_extended_vector_type(type_id, num_components);
			else
				return type_id;

		case SPIRType::Short:
			return build_extended_vector_type(type_id, num_components > type.vecsize ? num_components : type.vecsize,
			                                  SPIRType::UShort);
		case SPIRType::Int:
			return build_extended_vector_type(type_id, num_components > type.vecsize ? num_components : type.vecsize,
			                                  SPIRType::UInt);

		default:
			SPIRV_CROSS_THROW("Vertex attribute type mismatch between host and shader");
		}
	}

	case MSL_SHADER_INPUT_FORMAT_UINT16:
	{
		switch (type.basetype)
		{
		case SPIRType::UShort:
		case SPIRType::UInt:
			if (num_components > type.vecsize)
				return build_extended_vector_type(type_id, num_components);
			else
				return type_id;

		case SPIRType::Int:
			return build_extended_vector_type(type_id, num_components > type.vecsize ? num_components : type.vecsize,
			                                  SPIRType::UInt);

		default:
			SPIRV_CROSS_THROW("Vertex attribute type mismatch between host and shader");
		}
	}

	default:
		if (num_components > type.vecsize)
			type_id = build_extended_vector_type(type_id, num_components);
		break;
	}

	return type_id;
}

void CompilerMSL::mark_struct_members_packed(const SPIRType &type)
{
	set_extended_decoration(type.self, SPIRVCrossDecorationPhysicalTypePacked);

	// Problem case! Struct needs to be placed at an awkward alignment.
	// Mark every member of the child struct as packed.
	uint32_t mbr_cnt = uint32_t(type.member_types.size());
	for (uint32_t i = 0; i < mbr_cnt; i++)
	{
		auto &mbr_type = get<SPIRType>(type.member_types[i]);
		if (mbr_type.basetype == SPIRType::Struct)
		{
			// Recursively mark structs as packed.
			auto *struct_type = &mbr_type;
			while (!struct_type->array.empty())
				struct_type = &get<SPIRType>(struct_type->parent_type);
			mark_struct_members_packed(*struct_type);
		}
		else if (!is_scalar(mbr_type))
			set_extended_member_decoration(type.self, i, SPIRVCrossDecorationPhysicalTypePacked);
	}
}

void CompilerMSL::mark_scalar_layout_structs(const SPIRType &type)
{
	uint32_t mbr_cnt = uint32_t(type.member_types.size());
	for (uint32_t i = 0; i < mbr_cnt; i++)
	{
		auto &mbr_type = get<SPIRType>(type.member_types[i]);
		if (mbr_type.basetype == SPIRType::Struct)
		{
			auto *struct_type = &mbr_type;
			while (!struct_type->array.empty())
				struct_type = &get<SPIRType>(struct_type->parent_type);

			if (has_extended_decoration(struct_type->self, SPIRVCrossDecorationPhysicalTypePacked))
				continue;

			uint32_t msl_alignment = get_declared_struct_member_alignment_msl(type, i);
			uint32_t msl_size = get_declared_struct_member_size_msl(type, i);
			uint32_t spirv_offset = type_struct_member_offset(type, i);
			uint32_t spirv_offset_next;
			if (i + 1 < mbr_cnt)
				spirv_offset_next = type_struct_member_offset(type, i + 1);
			else
				spirv_offset_next = spirv_offset + msl_size;

			// Both are complicated cases. In scalar layout, a struct of float3 might just consume 12 bytes,
			// and the next member will be placed at offset 12.
			bool struct_is_misaligned = (spirv_offset % msl_alignment) != 0;
			bool struct_is_too_large = spirv_offset + msl_size > spirv_offset_next;
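			// Example (illustrative): a member of type struct { float3 v; } has MSL size
			// and alignment 16. In scalar layout it may sit at offset 4 (misaligned for
			// MSL), or be followed by another member at offset 12 (too large to fit);
			// either case forces the struct and its members to be packed.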
			uint32_t array_stride = 0;
			bool struct_needs_explicit_padding = false;

			// Verify that if a struct is used as an array that ArrayStride matches the effective size of the struct.
			if (!mbr_type.array.empty())
			{
				array_stride = type_struct_member_array_stride(type, i);
				uint32_t dimensions = uint32_t(mbr_type.array.size() - 1);
				for (uint32_t dim = 0; dim < dimensions; dim++)
				{
					uint32_t array_size = to_array_size_literal(mbr_type, dim);
					array_stride /= max(array_size, 1u);
				}

				// Set expected struct size based on ArrayStride.
				struct_needs_explicit_padding = true;

				// If struct size is larger than array stride, we might be able to fit, if we tightly pack.
				if (get_declared_struct_size_msl(*struct_type) > array_stride)
					struct_is_too_large = true;
			}

			if (struct_is_misaligned || struct_is_too_large)
				mark_struct_members_packed(*struct_type);
			mark_scalar_layout_structs(*struct_type);

			if (struct_needs_explicit_padding)
			{
				msl_size = get_declared_struct_size_msl(*struct_type, true, true);
				if (array_stride < msl_size)
				{
					SPIRV_CROSS_THROW("Cannot express an array stride smaller than size of struct type.");
				}
				else
				{
					if (has_extended_decoration(struct_type->self, SPIRVCrossDecorationPaddingTarget))
					{
						if (array_stride !=
						    get_extended_decoration(struct_type->self, SPIRVCrossDecorationPaddingTarget))
							SPIRV_CROSS_THROW(
							    "A struct is used with different array strides. Cannot express this in MSL.");
					}
					else
						set_extended_decoration(struct_type->self, SPIRVCrossDecorationPaddingTarget, array_stride);
				}
			}
		}
	}
}

// Sort the members of the struct type by offset, and pack and then pad members where needed
// to align MSL members with SPIR-V offsets. The struct members are iterated twice. Packing
// occurs first, followed by padding, because packing a member reduces both its size and its
// natural alignment, possibly requiring a padding member to be added ahead of it.
void CompilerMSL::align_struct(SPIRType &ib_type, unordered_set<uint32_t> &aligned_structs)
{
	// We align structs recursively, so stop any redundant work.
	ID &ib_type_id = ib_type.self;
	if (aligned_structs.count(ib_type_id))
		return;
	aligned_structs.insert(ib_type_id);

	// Sort the members of the interface structure by their offset.
	// They should already be sorted per SPIR-V spec anyway.
	MemberSorter member_sorter(ib_type, ir.meta[ib_type_id], MemberSorter::Offset);
	member_sorter.sort();

	auto mbr_cnt = uint32_t(ib_type.member_types.size());

	for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
	{
		// Pack any dependent struct types before we pack a parent struct.
		auto &mbr_type = get<SPIRType>(ib_type.member_types[mbr_idx]);
		if (mbr_type.basetype == SPIRType::Struct)
			align_struct(mbr_type, aligned_structs);
	}

	// Test the alignment of each member, and if a member should be closer to the previous
	// member than the default spacing expects, it is likely that the previous member is in
	// a packed format. If so, and the previous member is packable, pack it.
	// For example ... this applies to any 3-element vector that is followed by a scalar.
	uint32_t msl_offset = 0;
	for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
	{
		// This checks the member in isolation to see if it needs some kind of type remapping to conform to SPIR-V
		// offsets, array strides and matrix strides.
		ensure_member_packing_rules_msl(ib_type, mbr_idx);

		// Align current offset to the current member's default alignment. If the member was packed, it will observe
		// the updated alignment here.
		uint32_t msl_align_mask = get_declared_struct_member_alignment_msl(ib_type, mbr_idx) - 1;
		uint32_t aligned_msl_offset = (msl_offset + msl_align_mask) & ~msl_align_mask;

		// Fetch the member offset as declared in the SPIRV.
		uint32_t spirv_mbr_offset = get_member_decoration(ib_type_id, mbr_idx, DecorationOffset);
		if (spirv_mbr_offset > aligned_msl_offset)
		{
			// Since MSL and SPIR-V have slightly different struct member alignment and
			// size rules, we'll pad to standard C-packing rules with a char[] array. If the member is farther
			// away than C-packing expects, add an inert padding member before the member.
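			// Example (illustrative): if the SPIR-V offset is 16 but the running MSL
			// offset is only 12, a 4-byte char-array pad member is emitted ahead of this
			// member so the MSL offset catches up to the SPIR-V offset.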
			uint32_t padding_bytes = spirv_mbr_offset - aligned_msl_offset;
			set_extended_member_decoration(ib_type_id, mbr_idx, SPIRVCrossDecorationPaddingTarget, padding_bytes);

			// Re-align as a sanity check that aligning post-padding matches up.
			msl_offset += padding_bytes;
			aligned_msl_offset = (msl_offset + msl_align_mask) & ~msl_align_mask;
		}
		else if (spirv_mbr_offset < aligned_msl_offset)
		{
			// This should not happen, but deal with unexpected scenarios.
			// It *might* happen if a sub-struct has a larger alignment requirement in MSL than SPIR-V.
			SPIRV_CROSS_THROW("Cannot represent buffer block correctly in MSL.");
		}

		assert(aligned_msl_offset == spirv_mbr_offset);

		// Increment the current offset to be positioned immediately after the current member.
		// Don't do this for the last member since it can be unsized, and it is not relevant for padding purposes here.
		if (mbr_idx + 1 < mbr_cnt)
			msl_offset = aligned_msl_offset + get_declared_struct_member_size_msl(ib_type, mbr_idx);
	}
}

bool CompilerMSL::validate_member_packing_rules_msl(const SPIRType &type, uint32_t index) const
{
	auto &mbr_type = get<SPIRType>(type.member_types[index]);
	uint32_t spirv_offset = get_member_decoration(type.self, index, DecorationOffset);

	if (index + 1 < type.member_types.size())
	{
		// First, we will check offsets. If SPIR-V offset + MSL size > SPIR-V offset of next member,
		// we *must* perform some kind of remapping, no way getting around it.
		// We can always pad after this member if necessary, so that case is fine.
		uint32_t spirv_offset_next = get_member_decoration(type.self, index + 1, DecorationOffset);
		assert(spirv_offset_next >= spirv_offset);
		uint32_t maximum_size = spirv_offset_next - spirv_offset;
		uint32_t msl_mbr_size = get_declared_struct_member_size_msl(type, index);
		if (msl_mbr_size > maximum_size)
			return false;
	}

	if (!mbr_type.array.empty())
	{
		// If we have an array type, array stride must match exactly with SPIR-V.

		// An exception to this requirement is if we have one array element.
		// This comes from the DX scalar layout workaround.
		// If the app tries to be cheeky and access the member out of bounds, this will not work, but this is the best we can do.
		// In OpAccessChain with logical memory models, access chains must be in-bounds per the SPIR-V specification.
		bool relax_array_stride = mbr_type.array.back() == 1 && mbr_type.array_size_literal.back();

		if (!relax_array_stride)
		{
			uint32_t spirv_array_stride = type_struct_member_array_stride(type, index);
			uint32_t msl_array_stride = get_declared_struct_member_array_stride_msl(type, index);
			if (spirv_array_stride != msl_array_stride)
				return false;
		}
	}

	if (is_matrix(mbr_type))
	{
		// Need to check MatrixStride as well.
		uint32_t spirv_matrix_stride = type_struct_member_matrix_stride(type, index);
		uint32_t msl_matrix_stride = get_declared_struct_member_matrix_stride_msl(type, index);
		if (spirv_matrix_stride != msl_matrix_stride)
			return false;
	}

	// Now, we check alignment.
	uint32_t msl_alignment = get_declared_struct_member_alignment_msl(type, index);
	if ((spirv_offset % msl_alignment) != 0)
		return false;

	// We're in the clear.
	return true;
}

// Here we need to verify that the member type we declare conforms to Offset, ArrayStride or MatrixStride restrictions.
// If there is a mismatch, we need to emit remapped types, either normal types, or "packed_X" types.
// In odd cases we need to emit packed and remapped types, for e.g. weird matrices or arrays with weird array strides.
void CompilerMSL::ensure_member_packing_rules_msl(SPIRType &ib_type, uint32_t index)
{
	if (validate_member_packing_rules_msl(ib_type, index))
		return;

	// We failed validation.
	// This case will be nightmare-ish to deal with. This could possibly happen if struct alignment does not quite
	// match up with what we want. Scalar block layout comes to mind here where we might have to work around the rule
	// that struct alignment == max alignment of all members and struct size depends on this alignment.
	auto &mbr_type = get<SPIRType>(ib_type.member_types[index]);
	if (mbr_type.basetype == SPIRType::Struct)
		SPIRV_CROSS_THROW("Cannot perform any repacking for structs when it is used as a member of another struct.");

	// Perform remapping here.
	// There is nothing to be gained by using packed scalars, so don't attempt it.
	if (!is_scalar(ib_type))
		set_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypePacked);

	// Try validating again, now with packed.
	if (validate_member_packing_rules_msl(ib_type, index))
		return;

	// We're in deep trouble, and we need to create a new PhysicalType which matches up with what we expect.
	// A lot of work goes here ...
	// We will need remapping on Load and Store to translate the types between Logical and Physical.

	// First, we check if we have small vector std140 array.
	// We detect this if we have an array of vectors, and array stride is greater than number of elements.
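	// Example (illustrative): a float2 array with ArrayStride 16 gives
	// elems_per_stride == 16 / (32 / 8) == 4, so the physical type becomes a float4
	// array and only the .xy components are meaningful on load/store.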
	if (!mbr_type.array.empty() && !is_matrix(mbr_type))
	{
		uint32_t array_stride = type_struct_member_array_stride(ib_type, index);

		// Hack off array-of-arrays until we find the array stride per element we must have to make it work.
		uint32_t dimensions = uint32_t(mbr_type.array.size() - 1);
		for (uint32_t dim = 0; dim < dimensions; dim++)
			array_stride /= max(to_array_size_literal(mbr_type, dim), 1u);

		uint32_t elems_per_stride = array_stride / (mbr_type.width / 8);

		if (elems_per_stride == 3)
			SPIRV_CROSS_THROW("Cannot use ArrayStride of 3 elements in remapping scenarios.");
		else if (elems_per_stride > 4)
			SPIRV_CROSS_THROW("Cannot represent vectors with more than 4 elements in MSL.");

		auto physical_type = mbr_type;
		physical_type.vecsize = elems_per_stride;
		physical_type.parent_type = 0;
		uint32_t type_id = ir.increase_bound_by(1);
		set<SPIRType>(type_id, physical_type);
		set_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypeID, type_id);
		set_decoration(type_id, DecorationArrayStride, array_stride);

		// Remove packed_ for vectors of size 1, 2 and 4.
		unset_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypePacked);
	}
	else if (is_matrix(mbr_type))
	{
		// MatrixStride might be std140-esque.
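		// Example (illustrative): a column-major float2x2 with MatrixStride 16 gives
		// elems_per_stride == 4, so each float2 column is widened to occupy a float4
		// slot in the physical type.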
		uint32_t matrix_stride = type_struct_member_matrix_stride(ib_type, index);

		uint32_t elems_per_stride = matrix_stride / (mbr_type.width / 8);

		if (elems_per_stride == 3)
			SPIRV_CROSS_THROW("Cannot use ArrayStride of 3 elements in remapping scenarios.");
		else if (elems_per_stride > 4)
			SPIRV_CROSS_THROW("Cannot represent vectors with more than 4 elements in MSL.");

		bool row_major = has_member_decoration(ib_type.self, index, DecorationRowMajor);

		auto physical_type = mbr_type;
		physical_type.parent_type = 0;
		if (row_major)
			physical_type.columns = elems_per_stride;
		else
			physical_type.vecsize = elems_per_stride;
		uint32_t type_id = ir.increase_bound_by(1);
		set<SPIRType>(type_id, physical_type);
		set_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypeID, type_id);

		// Remove packed_ for vectors of size 1, 2 and 4.
		unset_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypePacked);
	}
	else
		SPIRV_CROSS_THROW("Found a buffer packing case which we cannot represent in MSL.");

	// Try validating again, now with physical type remapping.
	if (validate_member_packing_rules_msl(ib_type, index))
		return;

	// We might have a particular odd scalar layout case where the last element of an array
	// does not take up as much space as the ArrayStride or MatrixStride. This can happen with DX cbuffers.
	// The "proper" workaround for this is extremely painful and essentially impossible in the edge case of float3[],
	// so we hack around it by declaring the offending array or matrix with one less array size/col/row,
	// and rely on padding to get the correct value. We will technically access arrays out of bounds into the padding region,
	// but it should spill over gracefully without too much trouble. We rely on behavior like this for unsized arrays anyways.

	// E.g. we might observe a physical layout of:
	// { float2 a[2]; float b; } in cbuffer layout where ArrayStride of a is 16, but offset of b is 24, packed right after a[1] ...
	uint32_t type_id = get_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypeID);
	auto &type = get<SPIRType>(type_id);

	// Modify the physical type in-place. This is safe since each physical type workaround is a copy.
	if (is_array(type))
	{
		if (type.array.back() > 1)
		{
			if (!type.array_size_literal.back())
				SPIRV_CROSS_THROW("Cannot apply scalar layout workaround with spec constant array size.");
			type.array.back() -= 1;
		}
		else
		{
			// We have an array of size 1, so we cannot decrement that. Our only option now is to
			// force a packed layout instead, and drop the physical type remap since ArrayStride is meaningless now.
			unset_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypeID);
			set_extended_member_decoration(ib_type.self, index, SPIRVCrossDecorationPhysicalTypePacked);
		}
	}
	else if (is_matrix(type))
	{
		bool row_major = has_member_decoration(ib_type.self, index, DecorationRowMajor);
		if (!row_major)
		{
			// Slice off one column. If we only have 2 columns, this might turn the matrix into a vector with one array element instead.
			if (type.columns > 2)
			{
				type.columns--;
			}
			else if (type.columns == 2)
			{
				type.columns = 1;
				assert(type.array.empty());
				type.array.push_back(1);
				type.array_size_literal.push_back(true);
			}
		}
		else
		{
			// Slice off one row. If we only have 2 rows, this might turn the matrix into a vector with one array element instead.
			if (type.vecsize > 2)
			{
				type.vecsize--;
			}
			else if (type.vecsize == 2)
			{
				type.vecsize = type.columns;
				type.columns = 1;
				assert(type.array.empty());
				type.array.push_back(1);
				type.array_size_literal.push_back(true);
			}
		}
	}

	// This better validate now, or we must fail gracefully.
	if (!validate_member_packing_rules_msl(ib_type, index))
		SPIRV_CROSS_THROW("Found a buffer packing case which we cannot represent in MSL.");
}

void CompilerMSL::emit_store_statement(uint32_t lhs_expression, uint32_t rhs_expression)
{
	auto &type = expression_type(rhs_expression);

	bool lhs_remapped_type = has_extended_decoration(lhs_expression, SPIRVCrossDecorationPhysicalTypeID);
	bool lhs_packed_type = has_extended_decoration(lhs_expression, SPIRVCrossDecorationPhysicalTypePacked);
	auto *lhs_e = maybe_get<SPIRExpression>(lhs_expression);
	auto *rhs_e = maybe_get<SPIRExpression>(rhs_expression);

	bool transpose = lhs_e && lhs_e->need_transpose;

	// No physical type remapping, and no packed type, so can just emit a store directly.
	if (!lhs_remapped_type && !lhs_packed_type)
	{
		// We might not be dealing with remapped physical types or packed types,
		// but we might be doing a clean store to a row-major matrix.
		// In this case, we just flip transpose states and emit the store; any transpose belongs in the RHS expression.
4311 if (is_matrix(type) && lhs_e && lhs_e->need_transpose)
4312 {
4313 lhs_e->need_transpose = false;
4314
4315 if (rhs_e && rhs_e->need_transpose)
4316 {
4317 // Direct copy, but might need to unpack RHS.
4318 // Skip the transpose, as we will transpose when writing to LHS and transpose(transpose(T)) == T.
4319 rhs_e->need_transpose = false;
4320 statement(to_expression(lhs_expression), " = ", to_unpacked_row_major_matrix_expression(rhs_expression),
4321 ";");
4322 rhs_e->need_transpose = true;
4323 }
4324 else
4325 statement(to_expression(lhs_expression), " = transpose(", to_unpacked_expression(rhs_expression), ");");
4326
4327 lhs_e->need_transpose = true;
4328 register_write(lhs_expression);
4329 }
4330 else if (lhs_e && lhs_e->need_transpose)
4331 {
4332 lhs_e->need_transpose = false;
4333
4334 // Storing a column to a row-major matrix. Unroll the write.
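			// E.g. writing column 1 of a row-major float3x3 (hypothetical name m) unrolls to:
			//   m[0][1] = rhs.x; m[1][1] = rhs.y; m[2][1] = rhs.z;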
4335 for (uint32_t c = 0; c < type.vecsize; c++)
4336 {
4337 auto lhs_expr = to_dereferenced_expression(lhs_expression);
4338 auto column_index = lhs_expr.find_last_of('[');
4339 if (column_index != string::npos)
4340 {
4341 statement(lhs_expr.insert(column_index, join('[', c, ']')), " = ",
4342 to_extract_component_expression(rhs_expression, c), ";");
4343 }
4344 }
4345 lhs_e->need_transpose = true;
4346 register_write(lhs_expression);
4347 }
4348 else
4349 CompilerGLSL::emit_store_statement(lhs_expression, rhs_expression);
4350 }
4351 else if (!lhs_remapped_type && !is_matrix(type) && !transpose)
4352 {
		// Even if the target type is packed, we can directly store to it. We cannot store to packed matrices directly,
		// since they are declared as arrays of vectors instead, and we need the fallback path below.
4355 CompilerGLSL::emit_store_statement(lhs_expression, rhs_expression);
4356 }
4357 else
4358 {
4359 // Special handling when storing to a remapped physical type.
4360 // This is mostly to deal with std140 padded matrices or vectors.
4361
4362 TypeID physical_type_id = lhs_remapped_type ?
4363 ID(get_extended_decoration(lhs_expression, SPIRVCrossDecorationPhysicalTypeID)) :
4364 type.self;
4365
4366 auto &physical_type = get<SPIRType>(physical_type_id);
4367
4368 if (is_matrix(type))
4369 {
4370 const char *packed_pfx = lhs_packed_type ? "packed_" : "";
4371
4372 // Packed matrices are stored as arrays of packed vectors, so we need
4373 // to assign the vectors one at a time.
4374 // For row-major matrices, we need to transpose the *right-hand* side,
4375 // not the left-hand side.
4376
4377 // Lots of cases to cover here ...
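			// As a rough sketch (hypothetical member names), a column-major store to a packed
			// physical matrix becomes one assignment per column:
			//   _m.mat[0] = rhs[0]; _m.mat[1] = rhs[1]; ...
			// with a (device packed_floatN&) cast on the LHS when the physical vector width differs.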
4378
4379 bool rhs_transpose = rhs_e && rhs_e->need_transpose;
4380 SPIRType write_type = type;
4381 string cast_expr;
4382
4383 // We're dealing with transpose manually.
4384 if (rhs_transpose)
4385 rhs_e->need_transpose = false;
4386
4387 if (transpose)
4388 {
4389 // We're dealing with transpose manually.
4390 lhs_e->need_transpose = false;
4391 write_type.vecsize = type.columns;
4392 write_type.columns = 1;
4393
4394 if (physical_type.columns != type.columns)
4395 cast_expr = join("(device ", packed_pfx, type_to_glsl(write_type), "&)");
4396
4397 if (rhs_transpose)
4398 {
4399 // If RHS is also transposed, we can just copy row by row.
4400 for (uint32_t i = 0; i < type.vecsize; i++)
4401 {
4402 statement(cast_expr, to_enclosed_expression(lhs_expression), "[", i, "]", " = ",
4403 to_unpacked_row_major_matrix_expression(rhs_expression), "[", i, "];");
4404 }
4405 }
4406 else
4407 {
4408 auto vector_type = expression_type(rhs_expression);
4409 vector_type.vecsize = vector_type.columns;
4410 vector_type.columns = 1;
4411
4412 // Transpose on the fly. Emitting a lot of full transpose() ops and extracting lanes seems very bad,
4413 // so pick out individual components instead.
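				// E.g. for a 3x3 matrix, row i is rebuilt as float3(rhs[0][i], rhs[1][i], rhs[2][i]).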
4414 for (uint32_t i = 0; i < type.vecsize; i++)
4415 {
4416 string rhs_row = type_to_glsl_constructor(vector_type) + "(";
4417 for (uint32_t j = 0; j < vector_type.vecsize; j++)
4418 {
4419 rhs_row += join(to_enclosed_unpacked_expression(rhs_expression), "[", j, "][", i, "]");
4420 if (j + 1 < vector_type.vecsize)
4421 rhs_row += ", ";
4422 }
4423 rhs_row += ")";
4424
4425 statement(cast_expr, to_enclosed_expression(lhs_expression), "[", i, "]", " = ", rhs_row, ";");
4426 }
4427 }
4428
4429 // We're dealing with transpose manually.
4430 lhs_e->need_transpose = true;
4431 }
4432 else
4433 {
4434 write_type.columns = 1;
4435
4436 if (physical_type.vecsize != type.vecsize)
4437 cast_expr = join("(device ", packed_pfx, type_to_glsl(write_type), "&)");
4438
4439 if (rhs_transpose)
4440 {
4441 auto vector_type = expression_type(rhs_expression);
4442 vector_type.columns = 1;
4443
4444 // Transpose on the fly. Emitting a lot of full transpose() ops and extracting lanes seems very bad,
4445 // so pick out individual components instead.
4446 for (uint32_t i = 0; i < type.columns; i++)
4447 {
4448 string rhs_row = type_to_glsl_constructor(vector_type) + "(";
4449 for (uint32_t j = 0; j < vector_type.vecsize; j++)
4450 {
4451 // Need to explicitly unpack expression since we've mucked with transpose state.
4452 auto unpacked_expr = to_unpacked_row_major_matrix_expression(rhs_expression);
4453 rhs_row += join(unpacked_expr, "[", j, "][", i, "]");
4454 if (j + 1 < vector_type.vecsize)
4455 rhs_row += ", ";
4456 }
4457 rhs_row += ")";
4458
4459 statement(cast_expr, to_enclosed_expression(lhs_expression), "[", i, "]", " = ", rhs_row, ";");
4460 }
4461 }
4462 else
4463 {
4464 // Copy column-by-column.
4465 for (uint32_t i = 0; i < type.columns; i++)
4466 {
4467 statement(cast_expr, to_enclosed_expression(lhs_expression), "[", i, "]", " = ",
4468 to_enclosed_unpacked_expression(rhs_expression), "[", i, "];");
4469 }
4470 }
4471 }
4472
4473 // We're dealing with transpose manually.
4474 if (rhs_transpose)
4475 rhs_e->need_transpose = true;
4476 }
4477 else if (transpose)
4478 {
4479 lhs_e->need_transpose = false;
4480
4481 SPIRType write_type = type;
4482 write_type.vecsize = 1;
4483 write_type.columns = 1;
4484
4485 // Storing a column to a row-major matrix. Unroll the write.
4486 for (uint32_t c = 0; c < type.vecsize; c++)
4487 {
4488 auto lhs_expr = to_enclosed_expression(lhs_expression);
4489 auto column_index = lhs_expr.find_last_of('[');
4490 if (column_index != string::npos)
4491 {
4492 statement("((device ", type_to_glsl(write_type), "*)&",
4493 lhs_expr.insert(column_index, join('[', c, ']', ")")), " = ",
4494 to_extract_component_expression(rhs_expression, c), ";");
4495 }
4496 }
4497
4498 lhs_e->need_transpose = true;
4499 }
4500 else if ((is_matrix(physical_type) || is_array(physical_type)) && physical_type.vecsize > type.vecsize)
4501 {
4502 assert(type.vecsize >= 1 && type.vecsize <= 3);
4503
4504 // If we have packed types, we cannot use swizzled stores.
4505 // We could technically unroll the store for each element if needed.
4506 // When remapping to a std140 physical type, we always get float4,
4507 // and the packed decoration should always be removed.
4508 assert(!lhs_packed_type);
4509
4510 string lhs = to_dereferenced_expression(lhs_expression);
4511 string rhs = to_pointer_expression(rhs_expression);
4512
			// Unpack the expression so we can store to it with a float or float2.
			// It's still an l-value, so that's fine. Most other expression unpacking turns expressions into r-values instead.
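			// E.g. a float2 member remapped to a wider std140 slot is stored through roughly
			// (hypothetical member name): (device float2&)_m.v = rhs;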
4515 lhs = join("(device ", type_to_glsl(type), "&)", enclose_expression(lhs));
4516 if (!optimize_read_modify_write(expression_type(rhs_expression), lhs, rhs))
4517 statement(lhs, " = ", rhs, ";");
4518 }
4519 else if (!is_matrix(type))
4520 {
4521 string lhs = to_dereferenced_expression(lhs_expression);
4522 string rhs = to_pointer_expression(rhs_expression);
4523 if (!optimize_read_modify_write(expression_type(rhs_expression), lhs, rhs))
4524 statement(lhs, " = ", rhs, ";");
4525 }
4526
4527 register_write(lhs_expression);
4528 }
4529}
4530
static bool expression_ends_with(const string &expr_str, const string &ending)
{
	if (expr_str.length() >= ending.length())
		return expr_str.compare(expr_str.length() - ending.length(), ending.length(), ending) == 0;
	else
		return false;
}
4538
4539// Converts the format of the current expression from packed to unpacked,
4540// by wrapping the expression in a constructor of the appropriate type.
4541// Also, handle special physical ID remapping scenarios, similar to emit_store_statement().
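// For illustration (hypothetical member names): a packed_float3 member unpacks as float3(_m.v),
// and a packed 2x2 matrix, stored as two packed_float2 columns, unpacks as
// float2x2(float2(_m.m[0]), float2(_m.m[1])).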
4542string CompilerMSL::unpack_expression_type(string expr_str, const SPIRType &type, uint32_t physical_type_id,
4543 bool packed, bool row_major)
4544{
4545 // Trivial case, nothing to do.
4546 if (physical_type_id == 0 && !packed)
4547 return expr_str;
4548
4549 const SPIRType *physical_type = nullptr;
4550 if (physical_type_id)
4551 physical_type = &get<SPIRType>(physical_type_id);
4552
4553 static const char *swizzle_lut[] = {
4554 ".x",
4555 ".xy",
4556 ".xyz",
4557 };
4558
4559 if (physical_type && is_vector(*physical_type) && is_array(*physical_type) &&
4560 physical_type->vecsize > type.vecsize && !expression_ends_with(expr_str, swizzle_lut[type.vecsize - 1]))
4561 {
4562 // std140 array cases for vectors.
4563 assert(type.vecsize >= 1 && type.vecsize <= 3);
4564 return enclose_expression(expr_str) + swizzle_lut[type.vecsize - 1];
4565 }
4566 else if (physical_type && is_matrix(*physical_type) && is_vector(type) && physical_type->vecsize > type.vecsize)
4567 {
4568 // Extract column from padded matrix.
4569 assert(type.vecsize >= 1 && type.vecsize <= 3);
4570 return enclose_expression(expr_str) + swizzle_lut[type.vecsize - 1];
4571 }
4572 else if (is_matrix(type))
4573 {
4574 // Packed matrices are stored as arrays of packed vectors. Unfortunately,
4575 // we can't just pass the array straight to the matrix constructor. We have to
4576 // pass each vector individually, so that they can be unpacked to normal vectors.
4577 if (!physical_type)
4578 physical_type = &type;
4579
4580 uint32_t vecsize = type.vecsize;
4581 uint32_t columns = type.columns;
4582 if (row_major)
4583 swap(vecsize, columns);
4584
4585 uint32_t physical_vecsize = row_major ? physical_type->columns : physical_type->vecsize;
4586
4587 const char *base_type = type.width == 16 ? "half" : "float";
4588 string unpack_expr = join(base_type, columns, "x", vecsize, "(");
4589
4590 const char *load_swiz = "";
4591
4592 if (physical_vecsize != vecsize)
4593 load_swiz = swizzle_lut[vecsize - 1];
4594
4595 for (uint32_t i = 0; i < columns; i++)
4596 {
4597 if (i > 0)
4598 unpack_expr += ", ";
4599
4600 if (packed)
4601 unpack_expr += join(base_type, physical_vecsize, "(", expr_str, "[", i, "]", ")", load_swiz);
4602 else
4603 unpack_expr += join(expr_str, "[", i, "]", load_swiz);
4604 }
4605
4606 unpack_expr += ")";
4607 return unpack_expr;
4608 }
4609 else
4610 {
4611 return join(type_to_glsl(type), "(", expr_str, ")");
4612 }
4613}
4614
4615// Emits the file header info
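// Typical emitted output starts with any pragma lines, then:
//   #include <metal_stdlib>
//   #include <simd/simd.h>
//   using namespace metal;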
4616void CompilerMSL::emit_header()
4617{
4618 // This particular line can be overridden during compilation, so make it a flag and not a pragma line.
4619 if (suppress_missing_prototypes)
4620 statement("#pragma clang diagnostic ignored \"-Wmissing-prototypes\"");
4621
4622 // Disable warning about missing braces for array<T> template to make arrays a value type
4623 if (spv_function_implementations.count(SPVFuncImplUnsafeArray) != 0)
4624 statement("#pragma clang diagnostic ignored \"-Wmissing-braces\"");
4625
4626 for (auto &pragma : pragma_lines)
4627 statement(pragma);
4628
4629 if (!pragma_lines.empty() || suppress_missing_prototypes)
4630 statement("");
4631
4632 statement("#include <metal_stdlib>");
4633 statement("#include <simd/simd.h>");
4634
4635 for (auto &header : header_lines)
4636 statement(header);
4637
4638 statement("");
4639 statement("using namespace metal;");
4640 statement("");
4641
4642 for (auto &td : typedef_lines)
4643 statement(td);
4644
4645 if (!typedef_lines.empty())
4646 statement("");
4647}
4648
4649void CompilerMSL::add_pragma_line(const string &line)
4650{
4651 auto rslt = pragma_lines.insert(line);
4652 if (rslt.second)
4653 force_recompile();
4654}
4655
4656void CompilerMSL::add_typedef_line(const string &line)
4657{
4658 auto rslt = typedef_lines.insert(line);
4659 if (rslt.second)
4660 force_recompile();
4661}
4662
// Template structs like spvUnsafeArray<> need to be declared *before* any resources are declared.
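// spvUnsafeArray wraps a raw C array so arrays behave as value types while remaining indexable
// from every MSL address space through the operator[] overloads emitted below; the
// "Num ? Num : 1" element count keeps a zero-length specialization legal to declare.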
4664void CompilerMSL::emit_custom_templates()
4665{
4666 for (const auto &spv_func : spv_function_implementations)
4667 {
4668 switch (spv_func)
4669 {
4670 case SPVFuncImplUnsafeArray:
4671 statement("template<typename T, size_t Num>");
4672 statement("struct spvUnsafeArray");
4673 begin_scope();
4674 statement("T elements[Num ? Num : 1];");
4675 statement("");
4676 statement("thread T& operator [] (size_t pos) thread");
4677 begin_scope();
4678 statement("return elements[pos];");
4679 end_scope();
4680 statement("constexpr const thread T& operator [] (size_t pos) const thread");
4681 begin_scope();
4682 statement("return elements[pos];");
4683 end_scope();
4684 statement("");
4685 statement("device T& operator [] (size_t pos) device");
4686 begin_scope();
4687 statement("return elements[pos];");
4688 end_scope();
4689 statement("constexpr const device T& operator [] (size_t pos) const device");
4690 begin_scope();
4691 statement("return elements[pos];");
4692 end_scope();
4693 statement("");
4694 statement("constexpr const constant T& operator [] (size_t pos) const constant");
4695 begin_scope();
4696 statement("return elements[pos];");
4697 end_scope();
4698 statement("");
4699 statement("threadgroup T& operator [] (size_t pos) threadgroup");
4700 begin_scope();
4701 statement("return elements[pos];");
4702 end_scope();
4703 statement("constexpr const threadgroup T& operator [] (size_t pos) const threadgroup");
4704 begin_scope();
4705 statement("return elements[pos];");
4706 end_scope();
4707 end_scope_decl();
4708 statement("");
4709 break;
4710
4711 default:
4712 break;
4713 }
4714 }
4715}
4716
4717// Emits any needed custom function bodies.
// Metal helper functions must be static force-inline, i.e. static inline __attribute__((always_inline)),
// otherwise they will cause problems when linked together in a single Metallib.
4720void CompilerMSL::emit_custom_functions()
4721{
4722 for (uint32_t i = kArrayCopyMultidimMax; i >= 2; i--)
4723 if (spv_function_implementations.count(static_cast<SPVFuncImpl>(SPVFuncImplArrayCopyMultidimBase + i)))
4724 spv_function_implementations.insert(static_cast<SPVFuncImpl>(SPVFuncImplArrayCopyMultidimBase + i - 1));
4725
4726 if (spv_function_implementations.count(SPVFuncImplDynamicImageSampler))
4727 {
4728 // Unfortunately, this one needs a lot of the other functions to compile OK.
4729 if (!msl_options.supports_msl_version(2))
4730 SPIRV_CROSS_THROW(
4731 "spvDynamicImageSampler requires default-constructible texture objects, which require MSL 2.0.");
4732 spv_function_implementations.insert(SPVFuncImplForwardArgs);
4733 spv_function_implementations.insert(SPVFuncImplTextureSwizzle);
4734 if (msl_options.swizzle_texture_samples)
4735 spv_function_implementations.insert(SPVFuncImplGatherSwizzle);
4736 for (uint32_t i = SPVFuncImplChromaReconstructNearest2Plane;
4737 i <= SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint3Plane; i++)
4738 spv_function_implementations.insert(static_cast<SPVFuncImpl>(i));
4739 spv_function_implementations.insert(SPVFuncImplExpandITUFullRange);
4740 spv_function_implementations.insert(SPVFuncImplExpandITUNarrowRange);
4741 spv_function_implementations.insert(SPVFuncImplConvertYCbCrBT709);
4742 spv_function_implementations.insert(SPVFuncImplConvertYCbCrBT601);
4743 spv_function_implementations.insert(SPVFuncImplConvertYCbCrBT2020);
4744 }
4745
4746 for (uint32_t i = SPVFuncImplChromaReconstructNearest2Plane;
4747 i <= SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint3Plane; i++)
4748 if (spv_function_implementations.count(static_cast<SPVFuncImpl>(i)))
4749 spv_function_implementations.insert(SPVFuncImplForwardArgs);
4750
4751 if (spv_function_implementations.count(SPVFuncImplTextureSwizzle) ||
4752 spv_function_implementations.count(SPVFuncImplGatherSwizzle) ||
4753 spv_function_implementations.count(SPVFuncImplGatherCompareSwizzle))
4754 {
4755 spv_function_implementations.insert(SPVFuncImplForwardArgs);
4756 spv_function_implementations.insert(SPVFuncImplGetSwizzle);
4757 }
4758
4759 for (const auto &spv_func : spv_function_implementations)
4760 {
4761 switch (spv_func)
4762 {
4763 case SPVFuncImplMod:
			statement("// Implementation of the GLSL mod() function, which is slightly different from Metal fmod()");
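			// GLSL mod() takes the sign of y, whereas fmod() takes the sign of x:
			// e.g. mod(-1.5, 2.0) == 0.5, but fmod(-1.5, 2.0) == -1.5.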
4765 statement("template<typename Tx, typename Ty>");
4766 statement("inline Tx mod(Tx x, Ty y)");
4767 begin_scope();
4768 statement("return x - y * floor(x / y);");
4769 end_scope();
4770 statement("");
4771 break;
4772
4773 case SPVFuncImplRadians:
4774 statement("// Implementation of the GLSL radians() function");
4775 statement("template<typename T>");
4776 statement("inline T radians(T d)");
4777 begin_scope();
4778 statement("return d * T(0.01745329251);");
4779 end_scope();
4780 statement("");
4781 break;
4782
4783 case SPVFuncImplDegrees:
4784 statement("// Implementation of the GLSL degrees() function");
4785 statement("template<typename T>");
4786 statement("inline T degrees(T r)");
4787 begin_scope();
4788 statement("return r * T(57.2957795131);");
4789 end_scope();
4790 statement("");
4791 break;
4792
4793 case SPVFuncImplFindILsb:
4794 statement("// Implementation of the GLSL findLSB() function");
4795 statement("template<typename T>");
4796 statement("inline T spvFindLSB(T x)");
4797 begin_scope();
4798 statement("return select(ctz(x), T(-1), x == T(0));");
4799 end_scope();
4800 statement("");
4801 break;
4802
4803 case SPVFuncImplFindUMsb:
4804 statement("// Implementation of the unsigned GLSL findMSB() function");
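			// clz(T(0)) evaluates to the bit width of T, so this computes (bit width - 1) - clz(x):
			// the index of the highest set bit, or -1 when x == 0.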
4805 statement("template<typename T>");
4806 statement("inline T spvFindUMSB(T x)");
4807 begin_scope();
4808 statement("return select(clz(T(0)) - (clz(x) + T(1)), T(-1), x == T(0));");
4809 end_scope();
4810 statement("");
4811 break;
4812
4813 case SPVFuncImplFindSMsb:
4814 statement("// Implementation of the signed GLSL findMSB() function");
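			// For negative x, T(-1) - x == ~x in two's complement, so the scan effectively finds
			// the highest bit that differs from the sign bit, matching GLSL findMSB() semantics.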
4815 statement("template<typename T>");
4816 statement("inline T spvFindSMSB(T x)");
4817 begin_scope();
4818 statement("T v = select(x, T(-1) - x, x < T(0));");
4819 statement("return select(clz(T(0)) - (clz(v) + T(1)), T(-1), v == T(0));");
4820 end_scope();
4821 statement("");
4822 break;
4823
4824 case SPVFuncImplSSign:
4825 statement("// Implementation of the GLSL sign() function for integer types");
4826 statement("template<typename T, typename E = typename enable_if<is_integral<T>::value>::type>");
4827 statement("inline T sign(T x)");
4828 begin_scope();
4829 statement("return select(select(select(x, T(0), x == T(0)), T(1), x > T(0)), T(-1), x < T(0));");
4830 end_scope();
4831 statement("");
4832 break;
4833
4834 case SPVFuncImplArrayCopy:
4835 case SPVFuncImplArrayOfArrayCopy2Dim:
4836 case SPVFuncImplArrayOfArrayCopy3Dim:
4837 case SPVFuncImplArrayOfArrayCopy4Dim:
4838 case SPVFuncImplArrayOfArrayCopy5Dim:
4839 case SPVFuncImplArrayOfArrayCopy6Dim:
4840 {
4841 // Unfortunately we cannot template on the address space, so combinatorial explosion it is.
4842 static const char *function_name_tags[] = {
4843 "FromConstantToStack", "FromConstantToThreadGroup", "FromStackToStack",
4844 "FromStackToThreadGroup", "FromThreadGroupToStack", "FromThreadGroupToThreadGroup",
4845 "FromDeviceToDevice", "FromConstantToDevice", "FromStackToDevice",
4846 "FromThreadGroupToDevice", "FromDeviceToStack", "FromDeviceToThreadGroup",
4847 };
4848
4849 static const char *src_address_space[] = {
4850 "constant", "constant", "thread const", "thread const",
4851 "threadgroup const", "threadgroup const", "device const", "constant",
4852 "thread const", "threadgroup const", "device const", "device const",
4853 };
4854
4855 static const char *dst_address_space[] = {
4856 "thread", "threadgroup", "thread", "threadgroup", "thread", "threadgroup",
4857 "device", "device", "device", "device", "thread", "threadgroup",
4858 };
4859
4860 for (uint32_t variant = 0; variant < 12; variant++)
4861 {
4862 uint32_t dimensions = spv_func - SPVFuncImplArrayCopyMultidimBase;
4863 string tmp = "template<typename T";
4864 for (uint8_t i = 0; i < dimensions; i++)
4865 {
4866 tmp += ", uint ";
4867 tmp += 'A' + i;
4868 }
4869 tmp += ">";
4870 statement(tmp);
4871
4872 string array_arg;
4873 for (uint8_t i = 0; i < dimensions; i++)
4874 {
4875 array_arg += "[";
4876 array_arg += 'A' + i;
4877 array_arg += "]";
4878 }
4879
4880 statement("inline void spvArrayCopy", function_name_tags[variant], dimensions, "(",
4881 dst_address_space[variant], " T (&dst)", array_arg, ", ", src_address_space[variant],
4882 " T (&src)", array_arg, ")");
4883
4884 begin_scope();
4885 statement("for (uint i = 0; i < A; i++)");
4886 begin_scope();
4887
4888 if (dimensions == 1)
4889 statement("dst[i] = src[i];");
4890 else
4891 statement("spvArrayCopy", function_name_tags[variant], dimensions - 1, "(dst[i], src[i]);");
4892 end_scope();
4893 end_scope();
4894 statement("");
4895 }
4896 break;
4897 }
4898
4899 // Support for Metal 2.1's new texture_buffer type.
4900 case SPVFuncImplTexelBufferCoords:
4901 {
4902 if (msl_options.texel_buffer_texture_width > 0)
4903 {
4904 string tex_width_str = convert_to_string(msl_options.texel_buffer_texture_width);
4905 statement("// Returns 2D texture coords corresponding to 1D texel buffer coords");
4906 statement(force_inline);
4907 statement("uint2 spvTexelBufferCoord(uint tc)");
4908 begin_scope();
4909 statement(join("return uint2(tc % ", tex_width_str, ", tc / ", tex_width_str, ");"));
4910 end_scope();
4911 statement("");
4912 }
4913 else
4914 {
4915 statement("// Returns 2D texture coords corresponding to 1D texel buffer coords");
4916 statement(
4917 "#define spvTexelBufferCoord(tc, tex) uint2((tc) % (tex).get_width(), (tc) / (tex).get_width())");
4918 statement("");
4919 }
4920 break;
4921 }
4922
4923 // Emulate texture2D atomic operations
4924 case SPVFuncImplImage2DAtomicCoords:
4925 {
4926 if (msl_options.supports_msl_version(1, 2))
4927 {
4928 statement("// The required alignment of a linear texture of R32Uint format.");
4929 statement("constant uint spvLinearTextureAlignmentOverride [[function_constant(",
4930 msl_options.r32ui_alignment_constant_id, ")]];");
4931 statement("constant uint spvLinearTextureAlignment = ",
4932 "is_function_constant_defined(spvLinearTextureAlignmentOverride) ? ",
4933 "spvLinearTextureAlignmentOverride : ", msl_options.r32ui_linear_texture_alignment, ";");
4934 }
4935 else
4936 {
4937 statement("// The required alignment of a linear texture of R32Uint format.");
4938 statement("constant uint spvLinearTextureAlignment = ", msl_options.r32ui_linear_texture_alignment,
4939 ";");
4940 }
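			// The macro below rounds the texture width up to a multiple of spvLinearTextureAlignment / 4
			// texels (the alignment is in bytes; R32Uint texels are 4 bytes) to form the row pitch, then
			// linearizes (x, y). E.g. width 98 with a 16-byte alignment yields a pitch of 100 texels.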
4941 statement("// Returns buffer coords corresponding to 2D texture coords for emulating 2D texture atomics");
4942 statement("#define spvImage2DAtomicCoord(tc, tex) (((((tex).get_width() + ",
4943 " spvLinearTextureAlignment / 4 - 1) & ~(",
4944 " spvLinearTextureAlignment / 4 - 1)) * (tc).y) + (tc).x)");
4945 statement("");
4946 break;
4947 }
4948
4949 // "fadd" intrinsic support
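		// fma(T(1), l, r) computes l + r exactly, and [[clang::optnone]] keeps fast-math from
		// folding it back into a contractible add, preserving precise-add semantics to match Vulkan.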
4950 case SPVFuncImplFAdd:
4951 statement("template<typename T>");
4952 statement("[[clang::optnone]] T spvFAdd(T l, T r)");
4953 begin_scope();
4954 statement("return fma(T(1), l, r);");
4955 end_scope();
4956 statement("");
4957 break;
4958
4959 // "fsub" intrinsic support
4960 case SPVFuncImplFSub:
4961 statement("template<typename T>");
4962 statement("[[clang::optnone]] T spvFSub(T l, T r)");
4963 begin_scope();
4964 statement("return fma(T(-1), r, l);");
4965 end_scope();
4966 statement("");
4967 break;
4968
		// "fmul" intrinsic support
4970 case SPVFuncImplFMul:
4971 statement("template<typename T>");
4972 statement("[[clang::optnone]] T spvFMul(T l, T r)");
4973 begin_scope();
4974 statement("return fma(l, r, T(0));");
4975 end_scope();
4976 statement("");
4977
4978 statement("template<typename T, int Cols, int Rows>");
4979 statement("[[clang::optnone]] vec<T, Cols> spvFMulVectorMatrix(vec<T, Rows> v, matrix<T, Cols, Rows> m)");
4980 begin_scope();
4981 statement("vec<T, Cols> res = vec<T, Cols>(0);");
4982 statement("for (uint i = Rows; i > 0; --i)");
4983 begin_scope();
4984 statement("vec<T, Cols> tmp(0);");
4985 statement("for (uint j = 0; j < Cols; ++j)");
4986 begin_scope();
4987 statement("tmp[j] = m[j][i - 1];");
4988 end_scope();
4989 statement("res = fma(tmp, vec<T, Cols>(v[i - 1]), res);");
4990 end_scope();
4991 statement("return res;");
4992 end_scope();
4993 statement("");
4994
4995 statement("template<typename T, int Cols, int Rows>");
4996 statement("[[clang::optnone]] vec<T, Rows> spvFMulMatrixVector(matrix<T, Cols, Rows> m, vec<T, Cols> v)");
4997 begin_scope();
4998 statement("vec<T, Rows> res = vec<T, Rows>(0);");
4999 statement("for (uint i = Cols; i > 0; --i)");
5000 begin_scope();
5001 statement("res = fma(m[i - 1], vec<T, Rows>(v[i - 1]), res);");
5002 end_scope();
5003 statement("return res;");
5004 end_scope();
5005 statement("");
5006
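			// Matrix * matrix: column i of the result accumulates the sum over j of l[j] * r[i][j],
			// with each step expressed as an fma; res[i] and tmp are vec<T, LRows> columns.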
5007 statement("template<typename T, int LCols, int LRows, int RCols, int RRows>");
5008 statement("[[clang::optnone]] matrix<T, RCols, LRows> spvFMulMatrixMatrix(matrix<T, LCols, LRows> l, matrix<T, RCols, RRows> r)");
5009 begin_scope();
5010 statement("matrix<T, RCols, LRows> res;");
5011 statement("for (uint i = 0; i < RCols; i++)");
5012 begin_scope();
			statement("vec<T, LRows> tmp(0);");
			statement("for (uint j = 0; j < LCols; j++)");
			begin_scope();
			statement("tmp = fma(vec<T, LRows>(r[i][j]), l[j], tmp);");
5017 end_scope();
5018 statement("res[i] = tmp;");
5019 end_scope();
5020 statement("return res;");
5021 end_scope();
5022 statement("");
5023 break;
5024
5025 case SPVFuncImplQuantizeToF16:
5026 // Ensure fast-math is disabled to match Vulkan results.
5027 // SpvHalfTypeSelector is used to match the half* template type to the float* template type.
			// Depending on the GPU, MSL does not always flush converted subnormal halves to zero,
5029 // as required by OpQuantizeToF16, so check for subnormals and flush them to zero.
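			// E.g. spvQuantizeToF16(1.0e-7f) returns 0.0f, since 1.0e-7 is a subnormal value in half precision.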
5030 statement("template <typename F> struct SpvHalfTypeSelector;");
5031 statement("template <> struct SpvHalfTypeSelector<float> { public: using H = half; };");
5032 statement("template<uint N> struct SpvHalfTypeSelector<vec<float, N>> { using H = vec<half, N>; };");
5033 statement("template<typename F, typename H = typename SpvHalfTypeSelector<F>::H>");
5034 statement("[[clang::optnone]] F spvQuantizeToF16(F fval)");
5035 begin_scope();
5036 statement("H hval = H(fval);");
5037 statement("hval = select(copysign(H(0), hval), hval, isnormal(hval) || isinf(hval) || isnan(hval));");
5038 statement("return F(hval);");
5039 end_scope();
5040 statement("");
5041 break;
5042
5043 // Emulate texturecube_array with texture2d_array for iOS where this type is not available
5044 case SPVFuncImplCubemapTo2DArrayFace:
5045 statement(force_inline);
5046 statement("float3 spvCubemapTo2DArrayFace(float3 P)");
5047 begin_scope();
5048 statement("float3 Coords = abs(P.xyz);");
5049 statement("float CubeFace = 0;");
5050 statement("float ProjectionAxis = 0;");
5051 statement("float u = 0;");
5052 statement("float v = 0;");
5053 statement("if (Coords.x >= Coords.y && Coords.x >= Coords.z)");
5054 begin_scope();
5055 statement("CubeFace = P.x >= 0 ? 0 : 1;");
5056 statement("ProjectionAxis = Coords.x;");
5057 statement("u = P.x >= 0 ? -P.z : P.z;");
5058 statement("v = -P.y;");
5059 end_scope();
5060 statement("else if (Coords.y >= Coords.x && Coords.y >= Coords.z)");
5061 begin_scope();
5062 statement("CubeFace = P.y >= 0 ? 2 : 3;");
5063 statement("ProjectionAxis = Coords.y;");
5064 statement("u = P.x;");
5065 statement("v = P.y >= 0 ? P.z : -P.z;");
5066 end_scope();
5067 statement("else");
5068 begin_scope();
5069 statement("CubeFace = P.z >= 0 ? 4 : 5;");
5070 statement("ProjectionAxis = Coords.z;");
5071 statement("u = P.z >= 0 ? P.x : -P.x;");
5072 statement("v = -P.y;");
5073 end_scope();
5074 statement("u = 0.5 * (u/ProjectionAxis + 1);");
5075 statement("v = 0.5 * (v/ProjectionAxis + 1);");
5076 statement("return float3(u, v, CubeFace);");
5077 end_scope();
5078 statement("");
5079 break;
5080
5081 case SPVFuncImplInverse4x4:
5082 statement("// Returns the determinant of a 2x2 matrix.");
5083 statement(force_inline);
5084 statement("float spvDet2x2(float a1, float a2, float b1, float b2)");
5085 begin_scope();
5086 statement("return a1 * b2 - b1 * a2;");
5087 end_scope();
5088 statement("");
5089
5090 statement("// Returns the determinant of a 3x3 matrix.");
5091 statement(force_inline);
5092 statement("float spvDet3x3(float a1, float a2, float a3, float b1, float b2, float b3, float c1, "
5093 "float c2, float c3)");
5094 begin_scope();
5095 statement("return a1 * spvDet2x2(b2, b3, c2, c3) - b1 * spvDet2x2(a2, a3, c2, c3) + c1 * spvDet2x2(a2, a3, "
5096 "b2, b3);");
5097 end_scope();
5098 statement("");
5099 statement("// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
5100 statement("// adjoint and dividing by the determinant. The contents of the matrix are changed.");
5101 statement(force_inline);
5102 statement("float4x4 spvInverse4x4(float4x4 m)");
5103 begin_scope();
5104 statement("float4x4 adj; // The adjoint matrix (inverse after dividing by determinant)");
5105 statement_no_indent("");
5106 statement("// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
5107 statement("adj[0][0] = spvDet3x3(m[1][1], m[1][2], m[1][3], m[2][1], m[2][2], m[2][3], m[3][1], m[3][2], "
5108 "m[3][3]);");
5109 statement("adj[0][1] = -spvDet3x3(m[0][1], m[0][2], m[0][3], m[2][1], m[2][2], m[2][3], m[3][1], m[3][2], "
5110 "m[3][3]);");
5111 statement("adj[0][2] = spvDet3x3(m[0][1], m[0][2], m[0][3], m[1][1], m[1][2], m[1][3], m[3][1], m[3][2], "
5112 "m[3][3]);");
5113 statement("adj[0][3] = -spvDet3x3(m[0][1], m[0][2], m[0][3], m[1][1], m[1][2], m[1][3], m[2][1], m[2][2], "
5114 "m[2][3]);");
5115 statement_no_indent("");
5116 statement("adj[1][0] = -spvDet3x3(m[1][0], m[1][2], m[1][3], m[2][0], m[2][2], m[2][3], m[3][0], m[3][2], "
5117 "m[3][3]);");
5118 statement("adj[1][1] = spvDet3x3(m[0][0], m[0][2], m[0][3], m[2][0], m[2][2], m[2][3], m[3][0], m[3][2], "
5119 "m[3][3]);");
5120 statement("adj[1][2] = -spvDet3x3(m[0][0], m[0][2], m[0][3], m[1][0], m[1][2], m[1][3], m[3][0], m[3][2], "
5121 "m[3][3]);");
5122 statement("adj[1][3] = spvDet3x3(m[0][0], m[0][2], m[0][3], m[1][0], m[1][2], m[1][3], m[2][0], m[2][2], "
5123 "m[2][3]);");
5124 statement_no_indent("");
5125 statement("adj[2][0] = spvDet3x3(m[1][0], m[1][1], m[1][3], m[2][0], m[2][1], m[2][3], m[3][0], m[3][1], "
5126 "m[3][3]);");
5127 statement("adj[2][1] = -spvDet3x3(m[0][0], m[0][1], m[0][3], m[2][0], m[2][1], m[2][3], m[3][0], m[3][1], "
5128 "m[3][3]);");
5129 statement("adj[2][2] = spvDet3x3(m[0][0], m[0][1], m[0][3], m[1][0], m[1][1], m[1][3], m[3][0], m[3][1], "
5130 "m[3][3]);");
5131 statement("adj[2][3] = -spvDet3x3(m[0][0], m[0][1], m[0][3], m[1][0], m[1][1], m[1][3], m[2][0], m[2][1], "
5132 "m[2][3]);");
5133 statement_no_indent("");
5134 statement("adj[3][0] = -spvDet3x3(m[1][0], m[1][1], m[1][2], m[2][0], m[2][1], m[2][2], m[3][0], m[3][1], "
5135 "m[3][2]);");
5136 statement("adj[3][1] = spvDet3x3(m[0][0], m[0][1], m[0][2], m[2][0], m[2][1], m[2][2], m[3][0], m[3][1], "
5137 "m[3][2]);");
5138 statement("adj[3][2] = -spvDet3x3(m[0][0], m[0][1], m[0][2], m[1][0], m[1][1], m[1][2], m[3][0], m[3][1], "
5139 "m[3][2]);");
5140 statement("adj[3][3] = spvDet3x3(m[0][0], m[0][1], m[0][2], m[1][0], m[1][1], m[1][2], m[2][0], m[2][1], "
5141 "m[2][2]);");
5142 statement_no_indent("");
5143 statement("// Calculate the determinant as a combination of the cofactors of the first row.");
5144 statement("float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]) + (adj[0][2] * m[2][0]) + (adj[0][3] "
5145 "* m[3][0]);");
5146 statement_no_indent("");
5147 statement("// Divide the classical adjoint matrix by the determinant.");
			statement("// If determinant is zero, matrix is not invertible, so leave it unchanged.");
5149 statement("return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
5150 end_scope();
5151 statement("");
5152 break;
5153
5154 case SPVFuncImplInverse3x3:
5155 if (spv_function_implementations.count(SPVFuncImplInverse4x4) == 0)
5156 {
5157 statement("// Returns the determinant of a 2x2 matrix.");
5158 statement(force_inline);
5159 statement("float spvDet2x2(float a1, float a2, float b1, float b2)");
5160 begin_scope();
5161 statement("return a1 * b2 - b1 * a2;");
5162 end_scope();
5163 statement("");
5164 }
5165
5166 statement("// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
5167 statement("// adjoint and dividing by the determinant. The contents of the matrix are changed.");
5168 statement(force_inline);
5169 statement("float3x3 spvInverse3x3(float3x3 m)");
5170 begin_scope();
5171 statement("float3x3 adj; // The adjoint matrix (inverse after dividing by determinant)");
5172 statement_no_indent("");
5173 statement("// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
5174 statement("adj[0][0] = spvDet2x2(m[1][1], m[1][2], m[2][1], m[2][2]);");
5175 statement("adj[0][1] = -spvDet2x2(m[0][1], m[0][2], m[2][1], m[2][2]);");
5176 statement("adj[0][2] = spvDet2x2(m[0][1], m[0][2], m[1][1], m[1][2]);");
5177 statement_no_indent("");
5178 statement("adj[1][0] = -spvDet2x2(m[1][0], m[1][2], m[2][0], m[2][2]);");
5179 statement("adj[1][1] = spvDet2x2(m[0][0], m[0][2], m[2][0], m[2][2]);");
5180 statement("adj[1][2] = -spvDet2x2(m[0][0], m[0][2], m[1][0], m[1][2]);");
5181 statement_no_indent("");
5182 statement("adj[2][0] = spvDet2x2(m[1][0], m[1][1], m[2][0], m[2][1]);");
5183 statement("adj[2][1] = -spvDet2x2(m[0][0], m[0][1], m[2][0], m[2][1]);");
5184 statement("adj[2][2] = spvDet2x2(m[0][0], m[0][1], m[1][0], m[1][1]);");
5185 statement_no_indent("");
5186 statement("// Calculate the determinant as a combination of the cofactors of the first row.");
5187 statement("float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]) + (adj[0][2] * m[2][0]);");
5188 statement_no_indent("");
5189 statement("// Divide the classical adjoint matrix by the determinant.");
			statement("// If determinant is zero, matrix is not invertible, so leave it unchanged.");
5191 statement("return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
5192 end_scope();
5193 statement("");
5194 break;
5195
5196 case SPVFuncImplInverse2x2:
5197 statement("// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
5198 statement("// adjoint and dividing by the determinant. The contents of the matrix are changed.");
5199 statement(force_inline);
5200 statement("float2x2 spvInverse2x2(float2x2 m)");
5201 begin_scope();
5202 statement("float2x2 adj; // The adjoint matrix (inverse after dividing by determinant)");
5203 statement_no_indent("");
5204 statement("// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
5205 statement("adj[0][0] = m[1][1];");
5206 statement("adj[0][1] = -m[0][1];");
5207 statement_no_indent("");
5208 statement("adj[1][0] = -m[1][0];");
5209 statement("adj[1][1] = m[0][0];");
5210 statement_no_indent("");
5211 statement("// Calculate the determinant as a combination of the cofactors of the first row.");
5212 statement("float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]);");
5213 statement_no_indent("");
5214 statement("// Divide the classical adjoint matrix by the determinant.");
			statement("// If determinant is zero, matrix is not invertible, so leave it unchanged.");
5216 statement("return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
5217 end_scope();
5218 statement("");
5219 break;
5220
5221 case SPVFuncImplForwardArgs:
5222 statement("template<typename T> struct spvRemoveReference { typedef T type; };");
5223 statement("template<typename T> struct spvRemoveReference<thread T&> { typedef T type; };");
5224 statement("template<typename T> struct spvRemoveReference<thread T&&> { typedef T type; };");
5225 statement("template<typename T> inline constexpr thread T&& spvForward(thread typename "
5226 "spvRemoveReference<T>::type& x)");
5227 begin_scope();
5228 statement("return static_cast<thread T&&>(x);");
5229 end_scope();
5230 statement("template<typename T> inline constexpr thread T&& spvForward(thread typename "
5231 "spvRemoveReference<T>::type&& x)");
5232 begin_scope();
5233 statement("return static_cast<thread T&&>(x);");
5234 end_scope();
5235 statement("");
5236 break;
5237
5238 case SPVFuncImplGetSwizzle:
5239 statement("enum class spvSwizzle : uint");
5240 begin_scope();
5241 statement("none = 0,");
5242 statement("zero,");
5243 statement("one,");
5244 statement("red,");
5245 statement("green,");
5246 statement("blue,");
5247 statement("alpha");
5248 end_scope_decl();
5249 statement("");
5250 statement("template<typename T>");
5251 statement("inline T spvGetSwizzle(vec<T, 4> x, T c, spvSwizzle s)");
5252 begin_scope();
5253 statement("switch (s)");
5254 begin_scope();
5255 statement("case spvSwizzle::none:");
5256 statement(" return c;");
5257 statement("case spvSwizzle::zero:");
5258 statement(" return 0;");
5259 statement("case spvSwizzle::one:");
5260 statement(" return 1;");
5261 statement("case spvSwizzle::red:");
5262 statement(" return x.r;");
5263 statement("case spvSwizzle::green:");
5264 statement(" return x.g;");
5265 statement("case spvSwizzle::blue:");
5266 statement(" return x.b;");
5267 statement("case spvSwizzle::alpha:");
5268 statement(" return x.a;");
5269 end_scope();
5270 end_scope();
5271 statement("");
5272 break;
5273
5274 case SPVFuncImplTextureSwizzle:
5275 statement("// Wrapper function that swizzles texture samples and fetches.");
5276 statement("template<typename T>");
5277 statement("inline vec<T, 4> spvTextureSwizzle(vec<T, 4> x, uint s)");
5278 begin_scope();
5279 statement("if (!s)");
5280 statement(" return x;");
5281 statement("return vec<T, 4>(spvGetSwizzle(x, x.r, spvSwizzle((s >> 0) & 0xFF)), "
5282 "spvGetSwizzle(x, x.g, spvSwizzle((s >> 8) & 0xFF)), spvGetSwizzle(x, x.b, spvSwizzle((s >> 16) "
5283 "& 0xFF)), "
5284 "spvGetSwizzle(x, x.a, spvSwizzle((s >> 24) & 0xFF)));");
5285 end_scope();
5286 statement("");
5287 statement("template<typename T>");
5288 statement("inline T spvTextureSwizzle(T x, uint s)");
5289 begin_scope();
5290 statement("return spvTextureSwizzle(vec<T, 4>(x, 0, 0, 1), s).x;");
5291 end_scope();
5292 statement("");
5293 break;
5294
5295 case SPVFuncImplGatherSwizzle:
5296 statement("// Wrapper function that swizzles texture gathers.");
5297 statement("template<typename T, template<typename, access = access::sample, typename = void> class Tex, "
5298 "typename... Ts>");
5299 statement("inline vec<T, 4> spvGatherSwizzle(const thread Tex<T>& t, sampler s, "
5300 "uint sw, component c, Ts... params) METAL_CONST_ARG(c)");
5301 begin_scope();
5302 statement("if (sw)");
5303 begin_scope();
5304 statement("switch (spvSwizzle((sw >> (uint(c) * 8)) & 0xFF))");
5305 begin_scope();
5306 statement("case spvSwizzle::none:");
5307 statement(" break;");
5308 statement("case spvSwizzle::zero:");
5309 statement(" return vec<T, 4>(0, 0, 0, 0);");
5310 statement("case spvSwizzle::one:");
5311 statement(" return vec<T, 4>(1, 1, 1, 1);");
5312 statement("case spvSwizzle::red:");
5313 statement(" return t.gather(s, spvForward<Ts>(params)..., component::x);");
5314 statement("case spvSwizzle::green:");
5315 statement(" return t.gather(s, spvForward<Ts>(params)..., component::y);");
5316 statement("case spvSwizzle::blue:");
5317 statement(" return t.gather(s, spvForward<Ts>(params)..., component::z);");
5318 statement("case spvSwizzle::alpha:");
5319 statement(" return t.gather(s, spvForward<Ts>(params)..., component::w);");
5320 end_scope();
5321 end_scope();
5322 // texture::gather insists on its component parameter being a constant
5323 // expression, so we need this silly workaround just to compile the shader.
5324 statement("switch (c)");
5325 begin_scope();
5326 statement("case component::x:");
5327 statement(" return t.gather(s, spvForward<Ts>(params)..., component::x);");
5328 statement("case component::y:");
5329 statement(" return t.gather(s, spvForward<Ts>(params)..., component::y);");
5330 statement("case component::z:");
5331 statement(" return t.gather(s, spvForward<Ts>(params)..., component::z);");
5332 statement("case component::w:");
5333 statement(" return t.gather(s, spvForward<Ts>(params)..., component::w);");
5334 end_scope();
5335 end_scope();
5336 statement("");
5337 break;
5338
5339 case SPVFuncImplGatherCompareSwizzle:
5340 statement("// Wrapper function that swizzles depth texture gathers.");
5341 statement("template<typename T, template<typename, access = access::sample, typename = void> class Tex, "
5342 "typename... Ts>");
			statement("inline vec<T, 4> spvGatherCompareSwizzle(const thread Tex<T>& t, sampler "
			          "s, uint sw, Ts... params)");
5345 begin_scope();
5346 statement("if (sw)");
5347 begin_scope();
5348 statement("switch (spvSwizzle(sw & 0xFF))");
5349 begin_scope();
5350 statement("case spvSwizzle::none:");
5351 statement("case spvSwizzle::red:");
5352 statement(" break;");
5353 statement("case spvSwizzle::zero:");
5354 statement("case spvSwizzle::green:");
5355 statement("case spvSwizzle::blue:");
5356 statement("case spvSwizzle::alpha:");
5357 statement(" return vec<T, 4>(0, 0, 0, 0);");
5358 statement("case spvSwizzle::one:");
5359 statement(" return vec<T, 4>(1, 1, 1, 1);");
5360 end_scope();
5361 end_scope();
5362 statement("return t.gather_compare(s, spvForward<Ts>(params)...);");
5363 end_scope();
5364 statement("");
5365 break;
5366
5367 case SPVFuncImplSubgroupBroadcast:
5368 // Metal doesn't allow broadcasting boolean values directly, but we can work around that by broadcasting
5369 // them as integers.
5370 statement("template<typename T>");
5371 statement("inline T spvSubgroupBroadcast(T value, ushort lane)");
5372 begin_scope();
5373 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5374 statement("return quad_broadcast(value, lane);");
5375 else
5376 statement("return simd_broadcast(value, lane);");
5377 end_scope();
5378 statement("");
5379 statement("template<>");
5380 statement("inline bool spvSubgroupBroadcast(bool value, ushort lane)");
5381 begin_scope();
5382 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5383 statement("return !!quad_broadcast((ushort)value, lane);");
5384 else
5385 statement("return !!simd_broadcast((ushort)value, lane);");
5386 end_scope();
5387 statement("");
5388 statement("template<uint N>");
5389 statement("inline vec<bool, N> spvSubgroupBroadcast(vec<bool, N> value, ushort lane)");
5390 begin_scope();
5391 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5392 statement("return (vec<bool, N>)quad_broadcast((vec<ushort, N>)value, lane);");
5393 else
5394 statement("return (vec<bool, N>)simd_broadcast((vec<ushort, N>)value, lane);");
5395 end_scope();
5396 statement("");
5397 break;
5398
5399 case SPVFuncImplSubgroupBroadcastFirst:
5400 statement("template<typename T>");
5401 statement("inline T spvSubgroupBroadcastFirst(T value)");
5402 begin_scope();
5403 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5404 statement("return quad_broadcast_first(value);");
5405 else
5406 statement("return simd_broadcast_first(value);");
5407 end_scope();
5408 statement("");
5409 statement("template<>");
5410 statement("inline bool spvSubgroupBroadcastFirst(bool value)");
5411 begin_scope();
5412 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5413 statement("return !!quad_broadcast_first((ushort)value);");
5414 else
5415 statement("return !!simd_broadcast_first((ushort)value);");
5416 end_scope();
5417 statement("");
5418 statement("template<uint N>");
5419 statement("inline vec<bool, N> spvSubgroupBroadcastFirst(vec<bool, N> value)");
5420 begin_scope();
5421 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5422 statement("return (vec<bool, N>)quad_broadcast_first((vec<ushort, N>)value);");
5423 else
5424 statement("return (vec<bool, N>)simd_broadcast_first((vec<ushort, N>)value);");
5425 end_scope();
5426 statement("");
5427 break;
5428
5429 case SPVFuncImplSubgroupBallot:
5430 statement("inline uint4 spvSubgroupBallot(bool value)");
5431 begin_scope();
5432 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5433 {
5434 statement("return uint4((quad_vote::vote_t)quad_ballot(value), 0, 0, 0);");
5435 }
5436 else if (msl_options.is_ios())
5437 {
5438 // The current simd_vote on iOS uses a 32-bit integer-like object.
5439 statement("return uint4((simd_vote::vote_t)simd_ballot(value), 0, 0, 0);");
5440 }
5441 else
5442 {
5443 statement("simd_vote vote = simd_ballot(value);");
5444 statement("// simd_ballot() returns a 64-bit integer-like object, but");
5445 statement("// SPIR-V callers expect a uint4. We must convert.");
5446 statement("// FIXME: This won't include higher bits if Apple ever supports");
			statement("// 128 lanes in a SIMD-group.");
5448 statement("return uint4(as_type<uint2>((simd_vote::vote_t)vote), 0, 0);");
5449 }
5450 end_scope();
5451 statement("");
5452 break;
5453
5454 case SPVFuncImplSubgroupBallotBitExtract:
5455 statement("inline bool spvSubgroupBallotBitExtract(uint4 ballot, uint bit)");
5456 begin_scope();
5457 statement("return !!extract_bits(ballot[bit / 32], bit % 32, 1);");
5458 end_scope();
5459 statement("");
5460 break;
5461
5462 case SPVFuncImplSubgroupBallotFindLSB:
5463 statement("inline uint spvSubgroupBallotFindLSB(uint4 ballot, uint gl_SubgroupSize)");
5464 begin_scope();
5465 if (msl_options.is_ios())
5466 {
5467 statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, gl_SubgroupSize), uint3(0));");
5468 }
5469 else
5470 {
5471 statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u)), "
5472 "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0)), uint2(0));");
5473 }
5474 statement("ballot &= mask;");
5475 statement("return select(ctz(ballot.x), select(32 + ctz(ballot.y), select(64 + ctz(ballot.z), select(96 + "
5476 "ctz(ballot.w), uint(-1), ballot.w == 0), ballot.z == 0), ballot.y == 0), ballot.x == 0);");
5477 end_scope();
5478 statement("");
5479 break;
5480
5481 case SPVFuncImplSubgroupBallotFindMSB:
5482 statement("inline uint spvSubgroupBallotFindMSB(uint4 ballot, uint gl_SubgroupSize)");
5483 begin_scope();
5484 if (msl_options.is_ios())
5485 {
5486 statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, gl_SubgroupSize), uint3(0));");
5487 }
5488 else
5489 {
5490 statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u)), "
5491 "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0)), uint2(0));");
5492 }
5493 statement("ballot &= mask;");
5494 statement("return select(128 - (clz(ballot.w) + 1), select(96 - (clz(ballot.z) + 1), select(64 - "
5495 "(clz(ballot.y) + 1), select(32 - (clz(ballot.x) + 1), uint(-1), ballot.x == 0), ballot.y == 0), "
5496 "ballot.z == 0), ballot.w == 0);");
5497 end_scope();
5498 statement("");
5499 break;
5500
5501 case SPVFuncImplSubgroupBallotBitCount:
5502 statement("inline uint spvPopCount4(uint4 ballot)");
5503 begin_scope();
5504 statement("return popcount(ballot.x) + popcount(ballot.y) + popcount(ballot.z) + popcount(ballot.w);");
5505 end_scope();
5506 statement("");
5507 statement("inline uint spvSubgroupBallotBitCount(uint4 ballot, uint gl_SubgroupSize)");
5508 begin_scope();
5509 if (msl_options.is_ios())
5510 {
5511 statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, gl_SubgroupSize), uint3(0));");
5512 }
5513 else
5514 {
5515 statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupSize, 32u)), "
5516 "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupSize - 32, 0)), uint2(0));");
5517 }
5518 statement("return spvPopCount4(ballot & mask);");
5519 end_scope();
5520 statement("");
5521 statement("inline uint spvSubgroupBallotInclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)");
5522 begin_scope();
5523 if (msl_options.is_ios())
5524 {
5525 statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, gl_SubgroupInvocationID + 1), uint3(0));");
5526 }
5527 else
5528 {
5529 statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID + 1, 32u)), "
5530 "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID + 1 - 32, 0)), "
5531 "uint2(0));");
5532 }
5533 statement("return spvPopCount4(ballot & mask);");
5534 end_scope();
5535 statement("");
5536 statement("inline uint spvSubgroupBallotExclusiveBitCount(uint4 ballot, uint gl_SubgroupInvocationID)");
5537 begin_scope();
5538 if (msl_options.is_ios())
5539 {
5540 statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, gl_SubgroupInvocationID), uint2(0));");
5541 }
5542 else
5543 {
5544 statement("uint4 mask = uint4(extract_bits(0xFFFFFFFF, 0, min(gl_SubgroupInvocationID, 32u)), "
5545 "extract_bits(0xFFFFFFFF, 0, (uint)max((int)gl_SubgroupInvocationID - 32, 0)), uint2(0));");
5546 }
5547 statement("return spvPopCount4(ballot & mask);");
5548 end_scope();
5549 statement("");
5550 break;
5551
5552 case SPVFuncImplSubgroupAllEqual:
5553 // Metal doesn't provide a function to evaluate this directly. But, we can
5554 // implement this by comparing every thread's value to one thread's value
5555 // (in this case, the value of the first active thread). Then, by the transitive
5556 // property of equality, if all comparisons return true, then they are all equal.
5557 statement("template<typename T>");
5558 statement("inline bool spvSubgroupAllEqual(T value)");
5559 begin_scope();
5560 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5561 statement("return quad_all(all(value == quad_broadcast_first(value)));");
5562 else
5563 statement("return simd_all(all(value == simd_broadcast_first(value)));");
5564 end_scope();
5565 statement("");
5566 statement("template<>");
5567 statement("inline bool spvSubgroupAllEqual(bool value)");
5568 begin_scope();
5569 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5570 statement("return quad_all(value) || !quad_any(value);");
5571 else
5572 statement("return simd_all(value) || !simd_any(value);");
5573 end_scope();
5574 statement("");
5575 statement("template<uint N>");
5576 statement("inline bool spvSubgroupAllEqual(vec<bool, N> value)");
5577 begin_scope();
5578 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5579 statement("return quad_all(all(value == (vec<bool, N>)quad_broadcast_first((vec<ushort, N>)value)));");
5580 else
5581 statement("return simd_all(all(value == (vec<bool, N>)simd_broadcast_first((vec<ushort, N>)value)));");
5582 end_scope();
5583 statement("");
5584 break;
5585
5586 case SPVFuncImplSubgroupShuffle:
5587 statement("template<typename T>");
5588 statement("inline T spvSubgroupShuffle(T value, ushort lane)");
5589 begin_scope();
5590 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5591 statement("return quad_shuffle(value, lane);");
5592 else
5593 statement("return simd_shuffle(value, lane);");
5594 end_scope();
5595 statement("");
5596 statement("template<>");
5597 statement("inline bool spvSubgroupShuffle(bool value, ushort lane)");
5598 begin_scope();
5599 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5600 statement("return !!quad_shuffle((ushort)value, lane);");
5601 else
5602 statement("return !!simd_shuffle((ushort)value, lane);");
5603 end_scope();
5604 statement("");
5605 statement("template<uint N>");
5606 statement("inline vec<bool, N> spvSubgroupShuffle(vec<bool, N> value, ushort lane)");
5607 begin_scope();
5608 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5609 statement("return (vec<bool, N>)quad_shuffle((vec<ushort, N>)value, lane);");
5610 else
5611 statement("return (vec<bool, N>)simd_shuffle((vec<ushort, N>)value, lane);");
5612 end_scope();
5613 statement("");
5614 break;
5615
5616 case SPVFuncImplSubgroupShuffleXor:
5617 statement("template<typename T>");
5618 statement("inline T spvSubgroupShuffleXor(T value, ushort mask)");
5619 begin_scope();
5620 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5621 statement("return quad_shuffle_xor(value, mask);");
5622 else
5623 statement("return simd_shuffle_xor(value, mask);");
5624 end_scope();
5625 statement("");
5626 statement("template<>");
5627 statement("inline bool spvSubgroupShuffleXor(bool value, ushort mask)");
5628 begin_scope();
5629 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
5630 statement("return !!quad_shuffle_xor((ushort)value, mask);");
5631 else
5632 statement("return !!simd_shuffle_xor((ushort)value, mask);");
5633 end_scope();
5634 statement("");
5635 statement("template<uint N>");
5636 statement("inline vec<bool, N> spvSubgroupShuffleXor(vec<bool, N> value, ushort mask)");
5637 begin_scope();
5638 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
			statement("return (vec<bool, N>)quad_shuffle_xor((vec<ushort, N>)value, mask);");
		else
			statement("return (vec<bool, N>)simd_shuffle_xor((vec<ushort, N>)value, mask);");
		end_scope();
		statement("");
		break;

	case SPVFuncImplSubgroupShuffleUp:
		statement("template<typename T>");
		statement("inline T spvSubgroupShuffleUp(T value, ushort delta)");
		begin_scope();
		if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
			statement("return quad_shuffle_up(value, delta);");
		else
			statement("return simd_shuffle_up(value, delta);");
		end_scope();
		statement("");
		statement("template<>");
		statement("inline bool spvSubgroupShuffleUp(bool value, ushort delta)");
		begin_scope();
		if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
			statement("return !!quad_shuffle_up((ushort)value, delta);");
		else
			statement("return !!simd_shuffle_up((ushort)value, delta);");
		end_scope();
		statement("");
		statement("template<uint N>");
		statement("inline vec<bool, N> spvSubgroupShuffleUp(vec<bool, N> value, ushort delta)");
		begin_scope();
		if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
			statement("return (vec<bool, N>)quad_shuffle_up((vec<ushort, N>)value, delta);");
		else
			statement("return (vec<bool, N>)simd_shuffle_up((vec<ushort, N>)value, delta);");
		end_scope();
		statement("");
		break;

	case SPVFuncImplSubgroupShuffleDown:
		statement("template<typename T>");
		statement("inline T spvSubgroupShuffleDown(T value, ushort delta)");
		begin_scope();
		if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
			statement("return quad_shuffle_down(value, delta);");
		else
			statement("return simd_shuffle_down(value, delta);");
		end_scope();
		statement("");
		statement("template<>");
		statement("inline bool spvSubgroupShuffleDown(bool value, ushort delta)");
		begin_scope();
		if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
			statement("return !!quad_shuffle_down((ushort)value, delta);");
		else
			statement("return !!simd_shuffle_down((ushort)value, delta);");
		end_scope();
		statement("");
		statement("template<uint N>");
		statement("inline vec<bool, N> spvSubgroupShuffleDown(vec<bool, N> value, ushort delta)");
		begin_scope();
		if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
			statement("return (vec<bool, N>)quad_shuffle_down((vec<ushort, N>)value, delta);");
		else
			statement("return (vec<bool, N>)simd_shuffle_down((vec<ushort, N>)value, delta);");
		end_scope();
		statement("");
		break;

	case SPVFuncImplQuadBroadcast:
		statement("template<typename T>");
		statement("inline T spvQuadBroadcast(T value, uint lane)");
		begin_scope();
		statement("return quad_broadcast(value, lane);");
		end_scope();
		statement("");
		statement("template<>");
		statement("inline bool spvQuadBroadcast(bool value, uint lane)");
		begin_scope();
		statement("return !!quad_broadcast((ushort)value, lane);");
		end_scope();
		statement("");
		statement("template<uint N>");
		statement("inline vec<bool, N> spvQuadBroadcast(vec<bool, N> value, uint lane)");
		begin_scope();
		statement("return (vec<bool, N>)quad_broadcast((vec<ushort, N>)value, lane);");
		end_scope();
		statement("");
		break;

	case SPVFuncImplQuadSwap:
		// We can implement this easily based on the following table giving
		// the target lane ID from the direction and current lane ID:
		//        Direction
		//      | 0 | 1 | 2 |
		//   ---+---+---+---+
		// L 0  | 1   2   3
		// a 1  | 0   3   2
		// n 2  | 3   0   1
		// e 3  | 2   1   0
		// Notice that target = source ^ (direction + 1).
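		// For example, with direction 1 (vertical swap) and lane 2,
		// the partner lane is 2 ^ (1 + 1) = 0, matching the table above.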
		statement("template<typename T>");
		statement("inline T spvQuadSwap(T value, uint dir)");
		begin_scope();
		statement("return quad_shuffle_xor(value, dir + 1);");
		end_scope();
		statement("");
		statement("template<>");
		statement("inline bool spvQuadSwap(bool value, uint dir)");
		begin_scope();
		statement("return !!quad_shuffle_xor((ushort)value, dir + 1);");
		end_scope();
		statement("");
		statement("template<uint N>");
		statement("inline vec<bool, N> spvQuadSwap(vec<bool, N> value, uint dir)");
		begin_scope();
		statement("return (vec<bool, N>)quad_shuffle_xor((vec<ushort, N>)value, dir + 1);");
		end_scope();
		statement("");
		break;

	case SPVFuncImplReflectScalar:
		// Metal does not support scalar versions of these functions.
		// Ensure fast-math is disabled to match Vulkan results.
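		// reflect(i, n) = i - 2 * dot(n, i) * n; for scalars dot(n, i) is just
		// n * i, which gives the i - 2*i*n*n form emitted below.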
		statement("template<typename T>");
		statement("[[clang::optnone]] T spvReflect(T i, T n)");
		begin_scope();
		statement("return i - T(2) * i * n * n;");
		end_scope();
		statement("");
		break;

	case SPVFuncImplRefractScalar:
		// Metal does not support scalar versions of these functions.
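		// Scalar refract per the GLSL/SPIR-V definition:
		// k = 1 - eta^2 * (1 - dot(n, i)^2); k < 0 signals total internal
		// reflection, in which case zero is returned.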
		statement("template<typename T>");
		statement("inline T spvRefract(T i, T n, T eta)");
		begin_scope();
		statement("T NoI = n * i;");
		statement("T NoI2 = NoI * NoI;");
		statement("T k = T(1) - eta * eta * (T(1) - NoI2);");
		statement("if (k < T(0))");
		begin_scope();
		statement("return T(0);");
		end_scope();
		statement("else");
		begin_scope();
		statement("return eta * i - (eta * NoI + sqrt(k)) * n;");
		end_scope();
		end_scope();
		statement("");
		break;

	case SPVFuncImplFaceForwardScalar:
		// Metal does not support scalar versions of these functions.
		statement("template<typename T>");
		statement("inline T spvFaceForward(T n, T i, T nref)");
		begin_scope();
		statement("return i * nref < T(0) ? n : -n;");
		end_scope();
		statement("");
		break;

	case SPVFuncImplChromaReconstructNearest2Plane:
		statement("template<typename T, typename... LodOptions>");
		statement("inline vec<T, 4> spvChromaReconstructNearest(texture2d<T> plane0, texture2d<T> plane1, sampler "
		          "samp, float2 coord, LodOptions... options)");
		begin_scope();
		statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
		statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
		statement("ycbcr.br = plane1.sample(samp, coord, spvForward<LodOptions>(options)...).rg;");
		statement("return ycbcr;");
		end_scope();
		statement("");
		break;

	case SPVFuncImplChromaReconstructNearest3Plane:
		statement("template<typename T, typename... LodOptions>");
		statement("inline vec<T, 4> spvChromaReconstructNearest(texture2d<T> plane0, texture2d<T> plane1, "
		          "texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
		begin_scope();
		statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
		statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
		statement("ycbcr.b = plane1.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
		statement("ycbcr.r = plane2.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
		statement("return ycbcr;");
		end_scope();
		statement("");
		break;

	case SPVFuncImplChromaReconstructLinear422CositedEven2Plane:
		statement("template<typename T, typename... LodOptions>");
		statement("inline vec<T, 4> spvChromaReconstructLinear422CositedEven(texture2d<T> plane0, texture2d<T> "
		          "plane1, sampler samp, float2 coord, LodOptions... options)");
		begin_scope();
		statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
		statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
		statement("if (fract(coord.x * plane1.get_width()) != 0.0)");
		begin_scope();
		statement("ycbcr.br = vec<T, 2>(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
		          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), 0.5).rg);");
		end_scope();
		statement("else");
		begin_scope();
		statement("ycbcr.br = plane1.sample(samp, coord, spvForward<LodOptions>(options)...).rg;");
		end_scope();
		statement("return ycbcr;");
		end_scope();
		statement("");
		break;

	case SPVFuncImplChromaReconstructLinear422CositedEven3Plane:
		statement("template<typename T, typename... LodOptions>");
		statement("inline vec<T, 4> spvChromaReconstructLinear422CositedEven(texture2d<T> plane0, texture2d<T> "
		          "plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
		begin_scope();
		statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
		statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
		statement("if (fract(coord.x * plane1.get_width()) != 0.0)");
		begin_scope();
		statement("ycbcr.b = T(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
		          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), 0.5).r);");
		statement("ycbcr.r = T(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
		          "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), 0.5).r);");
		end_scope();
		statement("else");
		begin_scope();
		statement("ycbcr.b = plane1.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
		statement("ycbcr.r = plane2.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
		end_scope();
		statement("return ycbcr;");
		end_scope();
		statement("");
		break;

	case SPVFuncImplChromaReconstructLinear422Midpoint2Plane:
		statement("template<typename T, typename... LodOptions>");
		statement("inline vec<T, 4> spvChromaReconstructLinear422Midpoint(texture2d<T> plane0, texture2d<T> "
		          "plane1, sampler samp, float2 coord, LodOptions... options)");
		begin_scope();
		statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
		statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
		statement("int2 offs = int2(fract(coord.x * plane1.get_width()) != 0.0 ? 1 : -1, 0);");
		statement("ycbcr.br = vec<T, 2>(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
		          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., offs), 0.25).rg);");
		statement("return ycbcr;");
		end_scope();
		statement("");
		break;

	case SPVFuncImplChromaReconstructLinear422Midpoint3Plane:
		statement("template<typename T, typename... LodOptions>");
		statement("inline vec<T, 4> spvChromaReconstructLinear422Midpoint(texture2d<T> plane0, texture2d<T> "
		          "plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
		begin_scope();
		statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
		statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
		statement("int2 offs = int2(fract(coord.x * plane1.get_width()) != 0.0 ? 1 : -1, 0);");
		statement("ycbcr.b = T(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
		          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., offs), 0.25).r);");
		statement("ycbcr.r = T(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
		          "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., offs), 0.25).r);");
		statement("return ycbcr;");
		end_scope();
		statement("");
		break;

	case SPVFuncImplChromaReconstructLinear420XCositedEvenYCositedEven2Plane:
		statement("template<typename T, typename... LodOptions>");
		statement("inline vec<T, 4> spvChromaReconstructLinear420XCositedEvenYCositedEven(texture2d<T> plane0, "
		          "texture2d<T> plane1, sampler samp, float2 coord, LodOptions... options)");
		begin_scope();
		statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
		statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
		statement("float2 ab = fract(round(coord * float2(plane0.get_width(), plane0.get_height())) * 0.5);");
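		// round(coord * luma_dims) recovers the integer luma texel; scaling by
		// 0.5 and taking fract() yields 0.0 for even and 0.5 for odd texels,
		// which the mix() calls below use as bilinear weights between
		// adjacent chroma samples.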
		statement("ycbcr.br = vec<T, 2>(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
		          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
		          "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
		          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).rg);");
		statement("return ycbcr;");
		end_scope();
		statement("");
		break;

	case SPVFuncImplChromaReconstructLinear420XCositedEvenYCositedEven3Plane:
		statement("template<typename T, typename... LodOptions>");
		statement("inline vec<T, 4> spvChromaReconstructLinear420XCositedEvenYCositedEven(texture2d<T> plane0, "
		          "texture2d<T> plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
		begin_scope();
		statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
		statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
		statement("float2 ab = fract(round(coord * float2(plane0.get_width(), plane0.get_height())) * 0.5);");
		statement("ycbcr.b = T(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
		          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
		          "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
		          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
		statement("ycbcr.r = T(mix(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
		          "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
		          "mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
		          "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
		statement("return ycbcr;");
		end_scope();
		statement("");
		break;

	case SPVFuncImplChromaReconstructLinear420XMidpointYCositedEven2Plane:
		statement("template<typename T, typename... LodOptions>");
		statement("inline vec<T, 4> spvChromaReconstructLinear420XMidpointYCositedEven(texture2d<T> plane0, "
		          "texture2d<T> plane1, sampler samp, float2 coord, LodOptions... options)");
		begin_scope();
		statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
		statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
		statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0.5, "
		          "0)) * 0.5);");
		statement("ycbcr.br = vec<T, 2>(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
		          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
		          "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
		          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).rg);");
		statement("return ycbcr;");
		end_scope();
		statement("");
		break;

	case SPVFuncImplChromaReconstructLinear420XMidpointYCositedEven3Plane:
		statement("template<typename T, typename... LodOptions>");
		statement("inline vec<T, 4> spvChromaReconstructLinear420XMidpointYCositedEven(texture2d<T> plane0, "
		          "texture2d<T> plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
		begin_scope();
		statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
		statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
		statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0.5, "
		          "0)) * 0.5);");
		statement("ycbcr.b = T(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
		          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
		          "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
		          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
		statement("ycbcr.r = T(mix(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
		          "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
		          "mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
		          "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
		statement("return ycbcr;");
		end_scope();
		statement("");
		break;

	case SPVFuncImplChromaReconstructLinear420XCositedEvenYMidpoint2Plane:
		statement("template<typename T, typename... LodOptions>");
		statement("inline vec<T, 4> spvChromaReconstructLinear420XCositedEvenYMidpoint(texture2d<T> plane0, "
		          "texture2d<T> plane1, sampler samp, float2 coord, LodOptions... options)");
		begin_scope();
		statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
		statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
		statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0, "
		          "0.5)) * 0.5);");
		statement("ycbcr.br = vec<T, 2>(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
		          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
		          "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
		          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).rg);");
		statement("return ycbcr;");
		end_scope();
		statement("");
		break;

	case SPVFuncImplChromaReconstructLinear420XCositedEvenYMidpoint3Plane:
		statement("template<typename T, typename... LodOptions>");
		statement("inline vec<T, 4> spvChromaReconstructLinear420XCositedEvenYMidpoint(texture2d<T> plane0, "
		          "texture2d<T> plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
		begin_scope();
		statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
		statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
		statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0, "
		          "0.5)) * 0.5);");
		statement("ycbcr.b = T(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
		          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
		          "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
		          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
		statement("ycbcr.r = T(mix(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
		          "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
		          "mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
		          "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
		statement("return ycbcr;");
		end_scope();
		statement("");
		break;

	case SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint2Plane:
		statement("template<typename T, typename... LodOptions>");
		statement("inline vec<T, 4> spvChromaReconstructLinear420XMidpointYMidpoint(texture2d<T> plane0, "
		          "texture2d<T> plane1, sampler samp, float2 coord, LodOptions... options)");
		begin_scope();
		statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
		statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
		statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0.5, "
		          "0.5)) * 0.5);");
		statement("ycbcr.br = vec<T, 2>(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
		          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
		          "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
		          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).rg);");
		statement("return ycbcr;");
		end_scope();
		statement("");
		break;

	case SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint3Plane:
		statement("template<typename T, typename... LodOptions>");
		statement("inline vec<T, 4> spvChromaReconstructLinear420XMidpointYMidpoint(texture2d<T> plane0, "
		          "texture2d<T> plane1, texture2d<T> plane2, sampler samp, float2 coord, LodOptions... options)");
		begin_scope();
		statement("vec<T, 4> ycbcr = vec<T, 4>(0, 0, 0, 1);");
		statement("ycbcr.g = plane0.sample(samp, coord, spvForward<LodOptions>(options)...).r;");
		statement("float2 ab = fract((round(coord * float2(plane0.get_width(), plane0.get_height())) - float2(0.5, "
		          "0.5)) * 0.5);");
		statement("ycbcr.b = T(mix(mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)...), "
		          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
		          "mix(plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
		          "plane1.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
		statement("ycbcr.r = T(mix(mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)...), "
		          "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 0)), ab.x), "
		          "mix(plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(0, 1)), "
		          "plane2.sample(samp, coord, spvForward<LodOptions>(options)..., int2(1, 1)), ab.x), ab.y).r);");
		statement("return ycbcr;");
		end_scope();
		statement("");
		break;

	case SPVFuncImplExpandITUFullRange:
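		// Full-range n-bit chroma is stored biased by 2^(n-1)/(2^n - 1)
		// (128.0/255.0 for n == 8); subtract that bias to recenter Cb/Cr on zero.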
		statement("template<typename T>");
		statement("inline vec<T, 4> spvExpandITUFullRange(vec<T, 4> ycbcr, int n)");
		begin_scope();
		statement("ycbcr.br -= exp2(T(n-1))/(exp2(T(n))-1);");
		statement("return ycbcr;");
		end_scope();
		statement("");
		break;

	case SPVFuncImplExpandITUNarrowRange:
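		// Narrow (video) range: for n == 8, Y spans [16, 235] (219 steps) and
		// CbCr spans [16, 240] about a 128 bias (224 steps); the ldexp() calls
		// scale those 8-bit constants to n bits.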
		statement("template<typename T>");
		statement("inline vec<T, 4> spvExpandITUNarrowRange(vec<T, 4> ycbcr, int n)");
		begin_scope();
		statement("ycbcr.g = (ycbcr.g * (exp2(T(n)) - 1) - ldexp(T(16), n - 8))/ldexp(T(219), n - 8);");
		statement("ycbcr.br = (ycbcr.br * (exp2(T(n)) - 1) - ldexp(T(128), n - 8))/ldexp(T(224), n - 8);");
		statement("return ycbcr;");
		end_scope();
		statement("");
		break;

	case SPVFuncImplConvertYCbCrBT709:
		statement("// cf. Khronos Data Format Specification, section 15.1.1");
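		// With Kr = 0.2126, Kb = 0.0722, Kg = 1 - Kr - Kb = 0.7152, the
		// column-major factors follow: 2*(1 - Kr) = 1.5748 for Cr, 2*(1 - Kb) =
		// 1.8556 for Cb, and -2*Kb*(1 - Kb)/Kg, -2*Kr*(1 - Kr)/Kg on the G row.
		// The BT.601 and BT.2020 tables below are derived the same way from
		// their respective Kr/Kb constants.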
		statement("constant float3x3 spvBT709Factors = {{1, 1, 1}, {0, -0.13397432/0.7152, 1.8556}, {1.5748, "
		          "-0.33480248/0.7152, 0}};");
		statement("");
		statement("template<typename T>");
		statement("inline vec<T, 4> spvConvertYCbCrBT709(vec<T, 4> ycbcr)");
		begin_scope();
		statement("vec<T, 4> rgba;");
		statement("rgba.rgb = vec<T, 3>(spvBT709Factors * ycbcr.gbr);");
		statement("rgba.a = ycbcr.a;");
		statement("return rgba;");
		end_scope();
		statement("");
		break;

	case SPVFuncImplConvertYCbCrBT601:
		statement("// cf. Khronos Data Format Specification, section 15.1.2");
		statement("constant float3x3 spvBT601Factors = {{1, 1, 1}, {0, -0.202008/0.587, 1.772}, {1.402, "
		          "-0.419198/0.587, 0}};");
		statement("");
		statement("template<typename T>");
		statement("inline vec<T, 4> spvConvertYCbCrBT601(vec<T, 4> ycbcr)");
		begin_scope();
		statement("vec<T, 4> rgba;");
		statement("rgba.rgb = vec<T, 3>(spvBT601Factors * ycbcr.gbr);");
		statement("rgba.a = ycbcr.a;");
		statement("return rgba;");
		end_scope();
		statement("");
		break;

	case SPVFuncImplConvertYCbCrBT2020:
		statement("// cf. Khronos Data Format Specification, section 15.1.3");
		statement("constant float3x3 spvBT2020Factors = {{1, 1, 1}, {0, -0.11156702/0.6780, 1.8814}, {1.4746, "
		          "-0.38737742/0.6780, 0}};");
		statement("");
		statement("template<typename T>");
		statement("inline vec<T, 4> spvConvertYCbCrBT2020(vec<T, 4> ycbcr)");
		begin_scope();
		statement("vec<T, 4> rgba;");
		statement("rgba.rgb = vec<T, 3>(spvBT2020Factors * ycbcr.gbr);");
		statement("rgba.a = ycbcr.a;");
		statement("return rgba;");
		end_scope();
		statement("");
		break;

	case SPVFuncImplDynamicImageSampler:
		statement("enum class spvFormatResolution");
		begin_scope();
		statement("_444 = 0,");
		statement("_422,");
		statement("_420");
		end_scope_decl();
		statement("");
		statement("enum class spvChromaFilter");
		begin_scope();
		statement("nearest = 0,");
		statement("linear");
		end_scope_decl();
		statement("");
		statement("enum class spvXChromaLocation");
		begin_scope();
		statement("cosited_even = 0,");
		statement("midpoint");
		end_scope_decl();
		statement("");
		statement("enum class spvYChromaLocation");
		begin_scope();
		statement("cosited_even = 0,");
		statement("midpoint");
		end_scope_decl();
		statement("");
		statement("enum class spvYCbCrModelConversion");
		begin_scope();
		statement("rgb_identity = 0,");
		statement("ycbcr_identity,");
		statement("ycbcr_bt_709,");
		statement("ycbcr_bt_601,");
		statement("ycbcr_bt_2020");
		end_scope_decl();
		statement("");
		statement("enum class spvYCbCrRange");
		begin_scope();
		statement("itu_full = 0,");
		statement("itu_narrow");
		end_scope_decl();
		statement("");
		statement("struct spvComponentBits");
		begin_scope();
		statement("constexpr explicit spvComponentBits(int v) thread : value(v) {}");
		statement("uchar value : 6;");
		end_scope_decl();
		statement("// A class corresponding to metal::sampler which holds sampler");
		statement("// Y'CbCr conversion info.");
		statement("struct spvYCbCrSampler");
		begin_scope();
		statement("constexpr spvYCbCrSampler() thread : val(build()) {}");
		statement("template<typename... Ts>");
		statement("constexpr spvYCbCrSampler(Ts... t) thread : val(build(t...)) {}");
		statement("constexpr spvYCbCrSampler(const thread spvYCbCrSampler& s) thread = default;");
		statement("");
		statement("spvFormatResolution get_resolution() const thread");
		begin_scope();
		statement("return spvFormatResolution((val & resolution_mask) >> resolution_base);");
		end_scope();
		statement("spvChromaFilter get_chroma_filter() const thread");
		begin_scope();
		statement("return spvChromaFilter((val & chroma_filter_mask) >> chroma_filter_base);");
		end_scope();
		statement("spvXChromaLocation get_x_chroma_offset() const thread");
		begin_scope();
		statement("return spvXChromaLocation((val & x_chroma_off_mask) >> x_chroma_off_base);");
		end_scope();
		statement("spvYChromaLocation get_y_chroma_offset() const thread");
		begin_scope();
		statement("return spvYChromaLocation((val & y_chroma_off_mask) >> y_chroma_off_base);");
		end_scope();
		statement("spvYCbCrModelConversion get_ycbcr_model() const thread");
		begin_scope();
		statement("return spvYCbCrModelConversion((val & ycbcr_model_mask) >> ycbcr_model_base);");
		end_scope();
		statement("spvYCbCrRange get_ycbcr_range() const thread");
		begin_scope();
		statement("return spvYCbCrRange((val & ycbcr_range_mask) >> ycbcr_range_base);");
		end_scope();
		statement("int get_bpc() const thread { return (val & bpc_mask) >> bpc_base; }");
		statement("");
		statement("private:");
		statement("ushort val;");
		statement("");
		statement("constexpr static constant ushort resolution_bits = 2;");
		statement("constexpr static constant ushort chroma_filter_bits = 2;");
		statement("constexpr static constant ushort x_chroma_off_bit = 1;");
		statement("constexpr static constant ushort y_chroma_off_bit = 1;");
		statement("constexpr static constant ushort ycbcr_model_bits = 3;");
		statement("constexpr static constant ushort ycbcr_range_bit = 1;");
		statement("constexpr static constant ushort bpc_bits = 6;");
		statement("");
		statement("constexpr static constant ushort resolution_base = 0;");
		statement("constexpr static constant ushort chroma_filter_base = 2;");
		statement("constexpr static constant ushort x_chroma_off_base = 4;");
		statement("constexpr static constant ushort y_chroma_off_base = 5;");
		statement("constexpr static constant ushort ycbcr_model_base = 6;");
		statement("constexpr static constant ushort ycbcr_range_base = 9;");
		statement("constexpr static constant ushort bpc_base = 10;");
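		// Resulting bit layout of val, LSB first: resolution:2, chroma_filter:2,
		// x_chroma_off:1, y_chroma_off:1, ycbcr_model:3, ycbcr_range:1, bpc:6.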
		statement("");
		statement(
		    "constexpr static constant ushort resolution_mask = ((1 << resolution_bits) - 1) << resolution_base;");
		statement("constexpr static constant ushort chroma_filter_mask = ((1 << chroma_filter_bits) - 1) << "
		          "chroma_filter_base;");
		statement("constexpr static constant ushort x_chroma_off_mask = ((1 << x_chroma_off_bit) - 1) << "
		          "x_chroma_off_base;");
		statement("constexpr static constant ushort y_chroma_off_mask = ((1 << y_chroma_off_bit) - 1) << "
		          "y_chroma_off_base;");
		statement("constexpr static constant ushort ycbcr_model_mask = ((1 << ycbcr_model_bits) - 1) << "
		          "ycbcr_model_base;");
		statement("constexpr static constant ushort ycbcr_range_mask = ((1 << ycbcr_range_bit) - 1) << "
		          "ycbcr_range_base;");
		statement("constexpr static constant ushort bpc_mask = ((1 << bpc_bits) - 1) << bpc_base;");
		statement("");
		statement("static constexpr ushort build()");
		begin_scope();
		statement("return 0;");
		end_scope();
		statement("");
		statement("template<typename... Ts>");
		statement("static constexpr ushort build(spvFormatResolution res, Ts... t)");
		begin_scope();
		statement("return (ushort(res) << resolution_base) | (build(t...) & ~resolution_mask);");
		end_scope();
		statement("");
		statement("template<typename... Ts>");
		statement("static constexpr ushort build(spvChromaFilter filt, Ts... t)");
		begin_scope();
		statement("return (ushort(filt) << chroma_filter_base) | (build(t...) & ~chroma_filter_mask);");
		end_scope();
		statement("");
		statement("template<typename... Ts>");
		statement("static constexpr ushort build(spvXChromaLocation loc, Ts... t)");
		begin_scope();
		statement("return (ushort(loc) << x_chroma_off_base) | (build(t...) & ~x_chroma_off_mask);");
		end_scope();
		statement("");
		statement("template<typename... Ts>");
		statement("static constexpr ushort build(spvYChromaLocation loc, Ts... t)");
		begin_scope();
		statement("return (ushort(loc) << y_chroma_off_base) | (build(t...) & ~y_chroma_off_mask);");
		end_scope();
		statement("");
		statement("template<typename... Ts>");
		statement("static constexpr ushort build(spvYCbCrModelConversion model, Ts... t)");
		begin_scope();
		statement("return (ushort(model) << ycbcr_model_base) | (build(t...) & ~ycbcr_model_mask);");
		end_scope();
		statement("");
		statement("template<typename... Ts>");
		statement("static constexpr ushort build(spvYCbCrRange range, Ts... t)");
		begin_scope();
		statement("return (ushort(range) << ycbcr_range_base) | (build(t...) & ~ycbcr_range_mask);");
		end_scope();
		statement("");
		statement("template<typename... Ts>");
		statement("static constexpr ushort build(spvComponentBits bpc, Ts... t)");
		begin_scope();
		statement("return (ushort(bpc.value) << bpc_base) | (build(t...) & ~bpc_mask);");
		end_scope();
		end_scope_decl();
		statement("");
		statement("// A class which can hold up to three textures and a sampler, including");
		statement("// Y'CbCr conversion info, used to pass combined image-samplers");
		statement("// dynamically to functions.");
		statement("template<typename T>");
		statement("struct spvDynamicImageSampler");
		begin_scope();
		statement("texture2d<T> plane0;");
		statement("texture2d<T> plane1;");
		statement("texture2d<T> plane2;");
		statement("sampler samp;");
		statement("spvYCbCrSampler ycbcr_samp;");
		statement("uint swizzle = 0;");
		statement("");
		if (msl_options.swizzle_texture_samples)
		{
			statement("constexpr spvDynamicImageSampler(texture2d<T> tex, sampler samp, uint sw) thread :");
			statement("    plane0(tex), samp(samp), swizzle(sw) {}");
		}
		else
		{
			statement("constexpr spvDynamicImageSampler(texture2d<T> tex, sampler samp) thread :");
			statement("    plane0(tex), samp(samp) {}");
		}
		statement("constexpr spvDynamicImageSampler(texture2d<T> tex, sampler samp, spvYCbCrSampler ycbcr_samp, "
		          "uint sw) thread :");
		statement("    plane0(tex), samp(samp), ycbcr_samp(ycbcr_samp), swizzle(sw) {}");
		statement("constexpr spvDynamicImageSampler(texture2d<T> plane0, texture2d<T> plane1,");
		statement("                                 sampler samp, spvYCbCrSampler ycbcr_samp, uint sw) thread :");
		statement("    plane0(plane0), plane1(plane1), samp(samp), ycbcr_samp(ycbcr_samp), swizzle(sw) {}");
		statement(
		    "constexpr spvDynamicImageSampler(texture2d<T> plane0, texture2d<T> plane1, texture2d<T> plane2,");
		statement("                                 sampler samp, spvYCbCrSampler ycbcr_samp, uint sw) thread :");
		statement("    plane0(plane0), plane1(plane1), plane2(plane2), samp(samp), ycbcr_samp(ycbcr_samp), "
		          "swizzle(sw) {}");
		statement("");
		// XXX This is really hard to follow... I've left comments to make it a bit easier.
		statement("template<typename... LodOptions>");
		statement("vec<T, 4> do_sample(float2 coord, LodOptions... options) const thread");
		begin_scope();
		statement("if (!is_null_texture(plane1))");
		begin_scope();
		statement("if (ycbcr_samp.get_resolution() == spvFormatResolution::_444 ||");
		statement("    ycbcr_samp.get_chroma_filter() == spvChromaFilter::nearest)");
		begin_scope();
		statement("if (!is_null_texture(plane2))");
		statement("    return spvChromaReconstructNearest(plane0, plane1, plane2, samp, coord,");
		statement("        spvForward<LodOptions>(options)...);");
		statement(
		    "return spvChromaReconstructNearest(plane0, plane1, samp, coord, spvForward<LodOptions>(options)...);");
		end_scope(); // if (resolution == 444 || chroma_filter == nearest)
		statement("switch (ycbcr_samp.get_resolution())");
		begin_scope();
		statement("case spvFormatResolution::_444: break;");
		statement("case spvFormatResolution::_422:");
		begin_scope();
		statement("switch (ycbcr_samp.get_x_chroma_offset())");
		begin_scope();
		statement("case spvXChromaLocation::cosited_even:");
		statement("    if (!is_null_texture(plane2))");
		statement("        return spvChromaReconstructLinear422CositedEven(");
		statement("            plane0, plane1, plane2, samp,");
		statement("            coord, spvForward<LodOptions>(options)...);");
		statement("    return spvChromaReconstructLinear422CositedEven(");
		statement("        plane0, plane1, samp, coord,");
		statement("        spvForward<LodOptions>(options)...);");
		statement("case spvXChromaLocation::midpoint:");
		statement("    if (!is_null_texture(plane2))");
		statement("        return spvChromaReconstructLinear422Midpoint(");
		statement("            plane0, plane1, plane2, samp,");
		statement("            coord, spvForward<LodOptions>(options)...);");
		statement("    return spvChromaReconstructLinear422Midpoint(");
		statement("        plane0, plane1, samp, coord,");
		statement("        spvForward<LodOptions>(options)...);");
		end_scope(); // switch (x_chroma_offset)
		end_scope(); // case 422:
		statement("case spvFormatResolution::_420:");
		begin_scope();
		statement("switch (ycbcr_samp.get_x_chroma_offset())");
		begin_scope();
		statement("case spvXChromaLocation::cosited_even:");
		begin_scope();
		statement("switch (ycbcr_samp.get_y_chroma_offset())");
		begin_scope();
		statement("case spvYChromaLocation::cosited_even:");
		statement("    if (!is_null_texture(plane2))");
		statement("        return spvChromaReconstructLinear420XCositedEvenYCositedEven(");
		statement("            plane0, plane1, plane2, samp,");
		statement("            coord, spvForward<LodOptions>(options)...);");
		statement("    return spvChromaReconstructLinear420XCositedEvenYCositedEven(");
		statement("        plane0, plane1, samp, coord,");
		statement("        spvForward<LodOptions>(options)...);");
		statement("case spvYChromaLocation::midpoint:");
		statement("    if (!is_null_texture(plane2))");
		statement("        return spvChromaReconstructLinear420XCositedEvenYMidpoint(");
		statement("            plane0, plane1, plane2, samp,");
		statement("            coord, spvForward<LodOptions>(options)...);");
		statement("    return spvChromaReconstructLinear420XCositedEvenYMidpoint(");
		statement("        plane0, plane1, samp, coord,");
		statement("        spvForward<LodOptions>(options)...);");
		end_scope(); // switch (y_chroma_offset)
		end_scope(); // case x::cosited_even:
		statement("case spvXChromaLocation::midpoint:");
		begin_scope();
		statement("switch (ycbcr_samp.get_y_chroma_offset())");
		begin_scope();
		statement("case spvYChromaLocation::cosited_even:");
		statement("    if (!is_null_texture(plane2))");
		statement("        return spvChromaReconstructLinear420XMidpointYCositedEven(");
		statement("            plane0, plane1, plane2, samp,");
		statement("            coord, spvForward<LodOptions>(options)...);");
		statement("    return spvChromaReconstructLinear420XMidpointYCositedEven(");
		statement("        plane0, plane1, samp, coord,");
		statement("        spvForward<LodOptions>(options)...);");
		statement("case spvYChromaLocation::midpoint:");
		statement("    if (!is_null_texture(plane2))");
		statement("        return spvChromaReconstructLinear420XMidpointYMidpoint(");
		statement("            plane0, plane1, plane2, samp,");
		statement("            coord, spvForward<LodOptions>(options)...);");
		statement("    return spvChromaReconstructLinear420XMidpointYMidpoint(");
		statement("        plane0, plane1, samp, coord,");
		statement("        spvForward<LodOptions>(options)...);");
		end_scope(); // switch (y_chroma_offset)
		end_scope(); // case x::midpoint
		end_scope(); // switch (x_chroma_offset)
		end_scope(); // case 420:
		end_scope(); // switch (resolution)
		end_scope(); // if (multiplanar)
		statement("return plane0.sample(samp, coord, spvForward<LodOptions>(options)...);");
		end_scope(); // do_sample()
		statement("template <typename... LodOptions>");
		statement("vec<T, 4> sample(float2 coord, LodOptions... options) const thread");
		begin_scope();
		statement(
		    "vec<T, 4> s = spvTextureSwizzle(do_sample(coord, spvForward<LodOptions>(options)...), swizzle);");
		statement("if (ycbcr_samp.get_ycbcr_model() == spvYCbCrModelConversion::rgb_identity)");
		statement("    return s;");
		statement("");
		statement("switch (ycbcr_samp.get_ycbcr_range())");
		begin_scope();
		statement("case spvYCbCrRange::itu_full:");
		statement("    s = spvExpandITUFullRange(s, ycbcr_samp.get_bpc());");
		statement("    break;");
		statement("case spvYCbCrRange::itu_narrow:");
		statement("    s = spvExpandITUNarrowRange(s, ycbcr_samp.get_bpc());");
		statement("    break;");
		end_scope();
		statement("");
		statement("switch (ycbcr_samp.get_ycbcr_model())");
		begin_scope();
		statement("case spvYCbCrModelConversion::rgb_identity:"); // Silence Clang warning.
		statement("case spvYCbCrModelConversion::ycbcr_identity:");
		statement("    return s;");
		statement("case spvYCbCrModelConversion::ycbcr_bt_709:");
		statement("    return spvConvertYCbCrBT709(s);");
		statement("case spvYCbCrModelConversion::ycbcr_bt_601:");
		statement("    return spvConvertYCbCrBT601(s);");
		statement("case spvYCbCrModelConversion::ycbcr_bt_2020:");
		statement("    return spvConvertYCbCrBT2020(s);");
		end_scope();
		end_scope();
		statement("");
		// Sampler Y'CbCr conversion forbids offsets.
		statement("vec<T, 4> sample(float2 coord, int2 offset) const thread");
		begin_scope();
		if (msl_options.swizzle_texture_samples)
			statement("return spvTextureSwizzle(plane0.sample(samp, coord, offset), swizzle);");
		else
			statement("return plane0.sample(samp, coord, offset);");
		end_scope();
		statement("template<typename lod_options>");
		statement("vec<T, 4> sample(float2 coord, lod_options options, int2 offset) const thread");
		begin_scope();
		if (msl_options.swizzle_texture_samples)
			statement("return spvTextureSwizzle(plane0.sample(samp, coord, options, offset), swizzle);");
		else
			statement("return plane0.sample(samp, coord, options, offset);");
		end_scope();
		statement("#if __HAVE_MIN_LOD_CLAMP__");
		statement("vec<T, 4> sample(float2 coord, bias b, min_lod_clamp min_lod, int2 offset) const thread");
		begin_scope();
		statement("return plane0.sample(samp, coord, b, min_lod, offset);");
		end_scope();
		statement(
		    "vec<T, 4> sample(float2 coord, gradient2d grad, min_lod_clamp min_lod, int2 offset) const thread");
		begin_scope();
		statement("return plane0.sample(samp, coord, grad, min_lod, offset);");
		end_scope();
		statement("#endif");
		statement("");
		// Y'CbCr conversion forbids all operations but sampling.
		statement("vec<T, 4> read(uint2 coord, uint lod = 0) const thread");
		begin_scope();
		statement("return plane0.read(coord, lod);");
		end_scope();
		statement("");
		statement("vec<T, 4> gather(float2 coord, int2 offset = int2(0), component c = component::x) const thread");
		begin_scope();
		if (msl_options.swizzle_texture_samples)
			statement("return spvGatherSwizzle(plane0, samp, swizzle, c, coord, offset);");
		else
			statement("return plane0.gather(samp, coord, offset, c);");
		end_scope();
		end_scope_decl();
		statement("");

	default:
		break;
	}
}

static string inject_top_level_storage_qualifier(const string &expr, const string &qualifier)
{
	// Easier to do this through text munging since the qualifier does not exist in the type system at all,
	// and plumbing in all that information is not very helpful.
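	// E.g. "float x" becomes "constant float x", while "device float* x"
	// becomes "device float* constant x", qualifying the pointer itself
	// rather than the pointee (hence "top level").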
	size_t last_reference = expr.find_last_of('&');
	size_t last_pointer = expr.find_last_of('*');
	size_t last_significant = string::npos;

	if (last_reference == string::npos)
		last_significant = last_pointer;
	else if (last_pointer == string::npos)
		last_significant = last_reference;
	else
		last_significant = std::max(last_reference, last_pointer);

	if (last_significant == string::npos)
		return join(qualifier, " ", expr);
	else
	{
		return join(expr.substr(0, last_significant + 1), " ",
		            qualifier, expr.substr(last_significant + 1, string::npos));
	}
}

// Undefined global memory is not allowed in MSL.
// Declare these as constant and initialize to zeros. Use {}, as global constructors can break Metal.
void CompilerMSL::declare_undefined_values()
{
	bool emitted = false;
	ir.for_each_typed_id<SPIRUndef>([&](uint32_t, SPIRUndef &undef) {
		auto &type = this->get<SPIRType>(undef.basetype);
		// OpUndef can be void for some reason ...
		if (type.basetype == SPIRType::Void)
			return;

		statement(inject_top_level_storage_qualifier(
		              variable_decl(type, to_name(undef.self), undef.self),
		              "constant"),
		          " = {};");
		emitted = true;
	});

	if (emitted)
		statement("");
}

void CompilerMSL::declare_constant_arrays()
{
	bool fully_inlined = ir.ids_for_type[TypeFunction].size() == 1;

	// MSL cannot declare arrays inline (except when declaring a variable), so we must move them out to
	// global constants directly, so that we can use constants as variable expressions.
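	// Illustrative output (the exact array form and name depend on compiler
	// options and the module), e.g.:
	//   constant float _16[4] = { 1.0, 2.0, 3.0, 4.0 };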
	bool emitted = false;

	ir.for_each_typed_id<SPIRConstant>([&](uint32_t, SPIRConstant &c) {
		if (c.specialization)
			return;

		auto &type = this->get<SPIRType>(c.constant_type);
		// Constant arrays of non-primitive types (e.g. matrices) won't link properly into Metal libraries.
		// FIXME: However, hoisting constants to main() means we need to pass down constant arrays to leaf functions if they are used there.
		// If there are multiple functions in the module, drop this case to avoid breaking use cases which do not need to
		// link into Metal libraries. This is hacky.
		if (!type.array.empty() && (!fully_inlined || is_scalar(type) || is_vector(type)))
		{
			add_resource_name(c.self);
			auto name = to_name(c.self);
			statement(inject_top_level_storage_qualifier(variable_decl(type, name), "constant"),
			          " = ", constant_expression(c), ";");
			emitted = true;
		}
	});

	if (emitted)
		statement("");
}

// Constant arrays of non-primitive types (e.g. matrices) won't link properly into Metal libraries.
void CompilerMSL::declare_complex_constant_arrays()
{
	// If we do not have a fully inlined module, we did not opt in to
	// declaring constant arrays of complex types. See CompilerMSL::declare_constant_arrays().
	bool fully_inlined = ir.ids_for_type[TypeFunction].size() == 1;
	if (!fully_inlined)
		return;

	// MSL cannot declare arrays inline (except when declaring a variable), so we must move them out to
	// global constants directly, so that we can use constants as variable expressions.
	bool emitted = false;

	ir.for_each_typed_id<SPIRConstant>([&](uint32_t, SPIRConstant &c) {
		if (c.specialization)
			return;

		auto &type = this->get<SPIRType>(c.constant_type);
		if (!type.array.empty() && !(is_scalar(type) || is_vector(type)))
		{
			add_resource_name(c.self);
			auto name = to_name(c.self);
			statement("", variable_decl(type, name), " = ", constant_expression(c), ";");
			emitted = true;
		}
	});

	if (emitted)
		statement("");
}

void CompilerMSL::emit_resources()
{
	declare_constant_arrays();
	declare_undefined_values();

	// Emit the special [[stage_in]] and [[stage_out]] interface blocks which we created.
	emit_interface_block(stage_out_var_id);
	emit_interface_block(patch_stage_out_var_id);
	emit_interface_block(stage_in_var_id);
	emit_interface_block(patch_stage_in_var_id);
}

// Emit declarations for the specialization Metal function constants.
void CompilerMSL::emit_specialization_constants_and_structs()
{
	SpecializationConstant wg_x, wg_y, wg_z;
	ID workgroup_size_id = get_work_group_size_specialization_constants(wg_x, wg_y, wg_z);
	bool emitted = false;

	unordered_set<uint32_t> declared_structs;
	unordered_set<uint32_t> aligned_structs;

	// First, we need to deal with scalar block layout.
	// It is possible that a struct may have to be placed at an alignment which does not match the innate alignment of the struct itself.
	// If such a case exists for a struct, we must force all elements of the struct to become packed_ types.
	// This makes the struct alignment as small as physically possible.
	// When we actually align the struct later, we can insert padding as necessary to make the packed members behave like normally aligned types.
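	// E.g. under scalar block layout a float3 member may land on a 4-byte
	// boundary, which a plain MSL float3 (16-byte aligned) cannot express,
	// so it becomes packed_float3 and padding is reintroduced during alignment.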
	ir.for_each_typed_id<SPIRType>([&](uint32_t type_id, const SPIRType &type) {
		if (type.basetype == SPIRType::Struct &&
		    has_extended_decoration(type_id, SPIRVCrossDecorationBufferBlockRepacked))
			mark_scalar_layout_structs(type);
	});

	bool builtin_block_type_is_required = false;
	// Very special case. If gl_PerVertex is initialized as an array (tessellation),
	// we potentially have to emit the gl_PerVertex struct type so that we can emit a constant LUT.
	ir.for_each_typed_id<SPIRConstant>([&](uint32_t, SPIRConstant &c) {
		auto &type = this->get<SPIRType>(c.constant_type);
		if (is_array(type) && has_decoration(type.self, DecorationBlock) && is_builtin_type(type))
			builtin_block_type_is_required = true;
	});

	// Very particular use of the soft loop lock.
	// align_struct may need to create custom types on the fly, but we don't care about
	// these types for the purpose of iterating over them in ir.ids_for_type and friends.
	auto loop_lock = ir.create_loop_soft_lock();

	for (auto &id_ : ir.ids_for_constant_or_type)
	{
		auto &id = ir.ids[id_];

		if (id.get_type() == TypeConstant)
		{
			auto &c = id.get<SPIRConstant>();

			if (c.self == workgroup_size_id)
			{
				// TODO: This can be expressed as a [[threads_per_threadgroup]] input semantic, but we need to know
				// the work group size at compile time in SPIR-V, and [[threads_per_threadgroup]] would need to be passed around as a global.
				// The work group size may be a specialization constant.
				statement("constant uint3 ", builtin_to_glsl(BuiltInWorkgroupSize, StorageClassWorkgroup),
				          " [[maybe_unused]] = ", constant_expression(get<SPIRConstant>(workgroup_size_id)), ";");
				emitted = true;
			}
			else if (c.specialization)
			{
				auto &type = get<SPIRType>(c.constant_type);
				string sc_type_name = type_to_glsl(type);
				add_resource_name(c.self);
				string sc_name = to_name(c.self);
				string sc_tmp_name = sc_name + "_tmp";

				// Function constants are only supported in MSL 1.2 and later.
				// If we don't support them, just declare the "default" directly.
				// This "default" value can be overridden to the true specialization constant by the API user.
				// Specialization constants which are used as array length expressions cannot be function constants in MSL,
				// so just fall back to macros.
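				// Illustrative output for a float constant with SpecId 10 and
				// default 1.0 (variable names are compiler-generated):
				//   constant float sc_tmp [[function_constant(10)]];
				//   constant float sc = is_function_constant_defined(sc_tmp) ? sc_tmp : 1.0;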
				if (msl_options.supports_msl_version(1, 2) && has_decoration(c.self, DecorationSpecId) &&
				    !c.is_used_as_array_length)
				{
					uint32_t constant_id = get_decoration(c.self, DecorationSpecId);
					// Only scalar, non-composite values can be function constants.
					statement("constant ", sc_type_name, " ", sc_tmp_name, " [[function_constant(", constant_id,
					          ")]];");
					statement("constant ", sc_type_name, " ", sc_name, " = is_function_constant_defined(", sc_tmp_name,
					          ") ? ", sc_tmp_name, " : ", constant_expression(c), ";");
				}
				else if (has_decoration(c.self, DecorationSpecId))
				{
					// Fall back to macro overrides.
					c.specialization_constant_macro_name =
					    constant_value_macro_name(get_decoration(c.self, DecorationSpecId));

					statement("#ifndef ", c.specialization_constant_macro_name);
					statement("#define ", c.specialization_constant_macro_name, " ", constant_expression(c));
					statement("#endif");
					statement("constant ", sc_type_name, " ", sc_name, " = ", c.specialization_constant_macro_name,
					          ";");
				}
				else
				{
					// Composite specialization constants must be built from other specialization constants.
					statement("constant ", sc_type_name, " ", sc_name, " = ", constant_expression(c), ";");
				}
				emitted = true;
			}
		}
		else if (id.get_type() == TypeConstantOp)
		{
			auto &c = id.get<SPIRConstantOp>();
			auto &type = get<SPIRType>(c.basetype);
			add_resource_name(c.self);
			auto name = to_name(c.self);
			statement("constant ", variable_decl(type, name), " = ", constant_op_expression(c), ";");
			emitted = true;
		}
		else if (id.get_type() == TypeType)
		{
			// Output non-builtin interface structs. These include local function structs
			// and structs nested within uniform and read-write buffers.
			auto &type = id.get<SPIRType>();
			TypeID type_id = type.self;

			bool is_struct = (type.basetype == SPIRType::Struct) && type.array.empty() && !type.pointer;
			bool is_block =
			    has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock);

			bool is_builtin_block = is_block && is_builtin_type(type);
			bool is_declarable_struct = is_struct && (!is_builtin_block || builtin_block_type_is_required);

			// We'll declare this later.
			if (stage_out_var_id && get_stage_out_struct_type().self == type_id)
				is_declarable_struct = false;
			if (patch_stage_out_var_id && get_patch_stage_out_struct_type().self == type_id)
				is_declarable_struct = false;
			if (stage_in_var_id && get_stage_in_struct_type().self == type_id)
				is_declarable_struct = false;
			if (patch_stage_in_var_id && get_patch_stage_in_struct_type().self == type_id)
				is_declarable_struct = false;

			// Special case. Declare the builtin struct anyway if we need to emit a threadgroup version of it.
			if (stage_out_masked_builtin_type_id == type_id)
				is_declarable_struct = true;

			// Align and emit declarable structs...but avoid declaring each more than once.
			if (is_declarable_struct && declared_structs.count(type_id) == 0)
			{
				if (emitted)
					statement("");
				emitted = false;

				declared_structs.insert(type_id);

				if (has_extended_decoration(type_id, SPIRVCrossDecorationBufferBlockRepacked))
					align_struct(type, aligned_structs);

				// Make sure we declare the underlying struct type, and not the "decorated" type with pointers, etc.
				emit_struct(get<SPIRType>(type_id));
			}
		}
	}

	if (emitted)
		statement("");
}

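// SPIR-V's unordered comparisons (e.g. OpFUnordEqual) are true when either
// operand is NaN or when the comparison itself holds. MSL has no direct
// equivalent, so this emits "(isunordered(a, b) || a <op> b)".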
void CompilerMSL::emit_binary_unord_op(uint32_t result_type, uint32_t result_id, uint32_t op0, uint32_t op1,
                                       const char *op)
{
	bool forward = should_forward(op0) && should_forward(op1);
	emit_op(result_type, result_id,
	        join("(isunordered(", to_enclosed_unpacked_expression(op0), ", ", to_enclosed_unpacked_expression(op1),
	             ") || ", to_enclosed_unpacked_expression(op0), " ", op, " ", to_enclosed_unpacked_expression(op1),
	             ")"),
	        forward);

	inherit_expression_dependencies(result_id, op0);
	inherit_expression_dependencies(result_id, op1);
}

bool CompilerMSL::emit_tessellation_io_load(uint32_t result_type_id, uint32_t id, uint32_t ptr)
{
	auto &ptr_type = expression_type(ptr);
	auto &result_type = get<SPIRType>(result_type_id);
	if (ptr_type.storage != StorageClassInput && ptr_type.storage != StorageClassOutput)
		return false;
	if (ptr_type.storage == StorageClassOutput && get_execution_model() == ExecutionModelTessellationEvaluation)
		return false;

	if (has_decoration(ptr, DecorationPatch))
		return false;
	bool ptr_is_io_variable = ir.ids[ptr].get_type() == TypeVariable;

	bool flattened_io = variable_storage_requires_stage_io(ptr_type.storage);

	bool flat_data_type = flattened_io &&
	                      (is_matrix(result_type) || is_array(result_type) || result_type.basetype == SPIRType::Struct);

	// Edge case: even with multi-patch workgroups, we still need to unroll the load
	// if we're loading control points directly.
	if (ptr_is_io_variable && is_array(result_type))
		flat_data_type = true;

	if (!flat_data_type)
		return false;

	// Now, we must unflatten a composite type and take care of interleaving array access with gl_in/gl_out.
	// Lots of painful code duplication, since we *really* should not unroll these kinds of loads in entry point fixup
	// unless we're forced to by code emitting suboptimal OpLoads.
	string expr;

	uint32_t interface_index = get_extended_decoration(ptr, SPIRVCrossDecorationInterfaceMemberIndex);
	auto *var = maybe_get_backing_variable(ptr);
	auto &expr_type = get_pointee_type(ptr_type.self);

	const auto &iface_type = expression_type(stage_in_ptr_var_id);

	if (!flattened_io)
	{
		// Simplest case for multi-patch workgroups: just unroll the array as-is.
		if (interface_index == uint32_t(-1))
			return false;

		expr += type_to_glsl(result_type) + "({ ";
		uint32_t num_control_points = to_array_size_literal(result_type, uint32_t(result_type.array.size()) - 1);

		for (uint32_t i = 0; i < num_control_points; i++)
		{
			const uint32_t indices[2] = { i, interface_index };
			AccessChainMeta meta;
			expr += access_chain_internal(stage_in_ptr_var_id, indices, 2,
			                              ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
			if (i + 1 < num_control_points)
				expr += ", ";
		}
		expr += " })";
	}
	else if (result_type.array.size() > 2)
	{
		SPIRV_CROSS_THROW("Cannot load tessellation IO variables with more than 2 dimensions.");
	}
	else if (result_type.array.size() == 2)
	{
		if (!ptr_is_io_variable)
			SPIRV_CROSS_THROW("An array-of-array must be loaded directly from an IO variable.");
		if (interface_index == uint32_t(-1))
			SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue.");
		if (result_type.basetype == SPIRType::Struct || is_matrix(result_type))
			SPIRV_CROSS_THROW("Cannot load array-of-array of composite type in tessellation IO.");

		expr += type_to_glsl(result_type) + "({ ";
		uint32_t num_control_points = to_array_size_literal(result_type, 1);
		uint32_t base_interface_index = interface_index;

		auto &sub_type = get<SPIRType>(result_type.parent_type);

		for (uint32_t i = 0; i < num_control_points; i++)
		{
			expr += type_to_glsl(sub_type) + "({ ";
			interface_index = base_interface_index;
			uint32_t array_size = to_array_size_literal(result_type, 0);
			for (uint32_t j = 0; j < array_size; j++, interface_index++)
			{
				const uint32_t indices[2] = { i, interface_index };

				AccessChainMeta meta;
				expr += access_chain_internal(stage_in_ptr_var_id, indices, 2,
				                              ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
				if (!is_matrix(sub_type) && sub_type.basetype != SPIRType::Struct &&
				    expr_type.vecsize > sub_type.vecsize)
					expr += vector_swizzle(sub_type.vecsize, 0);

				if (j + 1 < array_size)
					expr += ", ";
			}
			expr += " })";
			if (i + 1 < num_control_points)
				expr += ", ";
		}
		expr += " })";
	}
	else if (result_type.basetype == SPIRType::Struct)
	{
		bool is_array_of_struct = is_array(result_type);
		if (is_array_of_struct && !ptr_is_io_variable)
			SPIRV_CROSS_THROW("Loading an array of structs must come directly from an IO variable.");
6897
6898 uint32_t num_control_points = 1;
6899 if (is_array_of_struct)
6900 {
6901 num_control_points = to_array_size_literal(result_type, 0);
6902 expr += type_to_glsl(result_type) + "({ ";
6903 }
6904
6905 auto &struct_type = is_array_of_struct ? get<SPIRType>(result_type.parent_type) : result_type;
6906 assert(struct_type.array.empty());
6907
6908 for (uint32_t i = 0; i < num_control_points; i++)
6909 {
6910 expr += type_to_glsl(struct_type) + "{ ";
6911 for (uint32_t j = 0; j < uint32_t(struct_type.member_types.size()); j++)
6912 {
6913 // The base interface index is stored per variable for structs.
6914 if (var)
6915 {
6916 interface_index =
6917 get_extended_member_decoration(var->self, j, SPIRVCrossDecorationInterfaceMemberIndex);
6918 }
6919
6920 if (interface_index == uint32_t(-1))
6921 SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue.");
6922
6923 const auto &mbr_type = get<SPIRType>(struct_type.member_types[j]);
6924 const auto &expr_mbr_type = get<SPIRType>(expr_type.member_types[j]);
6925 if (is_matrix(mbr_type) && ptr_type.storage == StorageClassInput)
6926 {
6927 expr += type_to_glsl(mbr_type) + "(";
6928 for (uint32_t k = 0; k < mbr_type.columns; k++, interface_index++)
6929 {
6930 if (is_array_of_struct)
6931 {
6932 const uint32_t indices[2] = { i, interface_index };
6933 AccessChainMeta meta;
6934 expr += access_chain_internal(
6935 stage_in_ptr_var_id, indices, 2,
6936 ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
6937 }
6938 else
6939 expr += to_expression(ptr) + "." + to_member_name(iface_type, interface_index);
6940 if (expr_mbr_type.vecsize > mbr_type.vecsize)
6941 expr += vector_swizzle(mbr_type.vecsize, 0);
6942
6943 if (k + 1 < mbr_type.columns)
6944 expr += ", ";
6945 }
6946 expr += ")";
6947 }
6948 else if (is_array(mbr_type))
6949 {
6950 expr += type_to_glsl(mbr_type) + "({ ";
6951 uint32_t array_size = to_array_size_literal(mbr_type, 0);
6952 for (uint32_t k = 0; k < array_size; k++, interface_index++)
6953 {
6954 if (is_array_of_struct)
6955 {
6956 const uint32_t indices[2] = { i, interface_index };
6957 AccessChainMeta meta;
6958 expr += access_chain_internal(
6959 stage_in_ptr_var_id, indices, 2,
6960 ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
6961 }
6962 else
6963 expr += to_expression(ptr) + "." + to_member_name(iface_type, interface_index);
6964 if (expr_mbr_type.vecsize > mbr_type.vecsize)
6965 expr += vector_swizzle(mbr_type.vecsize, 0);
6966
6967 if (k + 1 < array_size)
6968 expr += ", ";
6969 }
6970 expr += " })";
6971 }
6972 else
6973 {
6974 if (is_array_of_struct)
6975 {
6976 const uint32_t indices[2] = { i, interface_index };
6977 AccessChainMeta meta;
6978 expr += access_chain_internal(stage_in_ptr_var_id, indices, 2,
6979 ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT,
6980 &meta);
6981 }
6982 else
6983 expr += to_expression(ptr) + "." + to_member_name(iface_type, interface_index);
6984 if (expr_mbr_type.vecsize > mbr_type.vecsize)
6985 expr += vector_swizzle(mbr_type.vecsize, 0);
6986 }
6987
6988 if (j + 1 < struct_type.member_types.size())
6989 expr += ", ";
6990 }
6991 expr += " }";
6992 if (i + 1 < num_control_points)
6993 expr += ", ";
6994 }
6995 if (is_array_of_struct)
6996 expr += " })";
6997 }
6998 else if (is_matrix(result_type))
6999 {
7000 bool is_array_of_matrix = is_array(result_type);
7001 if (is_array_of_matrix && !ptr_is_io_variable)
			SPIRV_CROSS_THROW("An array of matrices must be loaded directly from an IO variable.");
7003 if (interface_index == uint32_t(-1))
7004 SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue.");
7005
7006 if (is_array_of_matrix)
7007 {
7008 // Loading a matrix from each control point.
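			// e.g. a float2x2 per control point is rebuilt roughly as (illustrative names):
			//   spvUnsafeArray<float2x2, N>({ float2x2(in[0].m_0, in[0].m_1), float2x2(in[1].m_0, in[1].m_1), ... })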
7009 uint32_t base_interface_index = interface_index;
7010 uint32_t num_control_points = to_array_size_literal(result_type, 0);
7011 expr += type_to_glsl(result_type) + "({ ";
7012
7013 auto &matrix_type = get_variable_element_type(get<SPIRVariable>(ptr));
7014
7015 for (uint32_t i = 0; i < num_control_points; i++)
7016 {
7017 interface_index = base_interface_index;
7018 expr += type_to_glsl(matrix_type) + "(";
7019 for (uint32_t j = 0; j < result_type.columns; j++, interface_index++)
7020 {
7021 const uint32_t indices[2] = { i, interface_index };
7022
7023 AccessChainMeta meta;
7024 expr += access_chain_internal(stage_in_ptr_var_id, indices, 2,
7025 ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
7026 if (expr_type.vecsize > result_type.vecsize)
7027 expr += vector_swizzle(result_type.vecsize, 0);
7028 if (j + 1 < result_type.columns)
7029 expr += ", ";
7030 }
7031 expr += ")";
7032 if (i + 1 < num_control_points)
7033 expr += ", ";
7034 }
7035
7036 expr += " })";
7037 }
7038 else
7039 {
7040 expr += type_to_glsl(result_type) + "(";
7041 for (uint32_t i = 0; i < result_type.columns; i++, interface_index++)
7042 {
7043 expr += to_expression(ptr) + "." + to_member_name(iface_type, interface_index);
7044 if (expr_type.vecsize > result_type.vecsize)
7045 expr += vector_swizzle(result_type.vecsize, 0);
7046 if (i + 1 < result_type.columns)
7047 expr += ", ";
7048 }
7049 expr += ")";
7050 }
7051 }
7052 else if (ptr_is_io_variable)
7053 {
7054 assert(is_array(result_type));
7055 assert(result_type.array.size() == 1);
7056 if (interface_index == uint32_t(-1))
7057 SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue.");
7058
7059 // We're loading an array directly from a global variable.
7060 // This means we're loading one member from each control point.
7061 expr += type_to_glsl(result_type) + "({ ";
7062 uint32_t num_control_points = to_array_size_literal(result_type, 0);
7063
7064 for (uint32_t i = 0; i < num_control_points; i++)
7065 {
7066 const uint32_t indices[2] = { i, interface_index };
7067
7068 AccessChainMeta meta;
7069 expr += access_chain_internal(stage_in_ptr_var_id, indices, 2,
7070 ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_PTR_CHAIN_BIT, &meta);
7071 if (expr_type.vecsize > result_type.vecsize)
7072 expr += vector_swizzle(result_type.vecsize, 0);
7073
7074 if (i + 1 < num_control_points)
7075 expr += ", ";
7076 }
7077 expr += " })";
7078 }
7079 else
7080 {
7081 // We're loading an array from a concrete control point.
7082 assert(is_array(result_type));
7083 assert(result_type.array.size() == 1);
7084 if (interface_index == uint32_t(-1))
7085 SPIRV_CROSS_THROW("Interface index is unknown. Cannot continue.");
7086
7087 expr += type_to_glsl(result_type) + "({ ";
7088 uint32_t array_size = to_array_size_literal(result_type, 0);
7089 for (uint32_t i = 0; i < array_size; i++, interface_index++)
7090 {
7091 expr += to_expression(ptr) + "." + to_member_name(iface_type, interface_index);
7092 if (expr_type.vecsize > result_type.vecsize)
7093 expr += vector_swizzle(result_type.vecsize, 0);
7094 if (i + 1 < array_size)
7095 expr += ", ";
7096 }
7097 expr += " })";
7098 }
7099
7100 emit_op(result_type_id, id, expr, false);
7101 register_read(id, ptr, false);
7102 return true;
7103}
7104
7105bool CompilerMSL::emit_tessellation_access_chain(const uint32_t *ops, uint32_t length)
7106{
7107 // If this is a per-vertex output, remap it to the I/O array buffer.
7108
7109 // Any object which did not go through IO flattening shenanigans will go there instead.
	// We will instead unflatten on demand as needed, but not all possible cases can be supported, especially with arrays.
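	// For example, in a tessellation control shader an access chain like
	// gl_out[gl_InvocationID].gl_Position is redirected to index into the flattened stage-out
	// control point array rather than a plain struct member; the exact member naming is decided
	// by the interface block remapping below.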
7111
7112 auto *var = maybe_get_backing_variable(ops[2]);
7113 bool patch = false;
7114 bool flat_data = false;
7115 bool ptr_is_chain = false;
7116 bool flatten_composites = false;
7117
7118 bool is_block = false;
7119
7120 if (var)
7121 is_block = has_decoration(get_variable_data_type(*var).self, DecorationBlock);
7122
7123 if (var)
7124 {
7125 flatten_composites = variable_storage_requires_stage_io(var->storage);
7126 patch = has_decoration(ops[2], DecorationPatch) || is_patch_block(get_variable_data_type(*var));
7127
7128 // Should match strip_array in add_interface_block.
7129 flat_data = var->storage == StorageClassInput ||
7130 (var->storage == StorageClassOutput && get_execution_model() == ExecutionModelTessellationControl);
7131
7132 // Patch inputs are treated as normal block IO variables, so they don't deal with this path at all.
7133 if (patch && (!is_block || var->storage == StorageClassInput))
7134 flat_data = false;
7135
7136 // We might have a chained access chain, where
7137 // we first take the access chain to the control point, and then we chain into a member or something similar.
7138 // In this case, we need to skip gl_in/gl_out remapping.
7139 // Also, skip ptr chain for patches.
7140 ptr_is_chain = var->self != ID(ops[2]);
7141 }
7142
7143 bool builtin_variable = false;
7144 bool variable_is_flat = false;
7145
7146 if (var && flat_data)
7147 {
7148 builtin_variable = is_builtin_variable(*var);
7149
7150 BuiltIn bi_type = BuiltInMax;
7151 if (builtin_variable && !is_block)
7152 bi_type = BuiltIn(get_decoration(var->self, DecorationBuiltIn));
7153
7154 variable_is_flat = !builtin_variable || is_block ||
7155 bi_type == BuiltInPosition || bi_type == BuiltInPointSize ||
7156 bi_type == BuiltInClipDistance || bi_type == BuiltInCullDistance;
7157 }
7158
7159 if (variable_is_flat)
7160 {
7161 // If output is masked, it is emitted as a "normal" variable, just go through normal code paths.
7162 // Only check this for the first level of access chain.
7163 // Dealing with this for partial access chains should be possible, but awkward.
7164 if (var->storage == StorageClassOutput && !ptr_is_chain)
7165 {
7166 bool masked = false;
7167 if (is_block)
7168 {
7169 uint32_t relevant_member_index = patch ? 3 : 4;
				// FIXME: This won't work properly if the application first access chains into a gl_out element,
				// then access chains into the member. Super weird, but theoretically possible ...
7172 if (length > relevant_member_index)
7173 {
7174 uint32_t mbr_idx = get<SPIRConstant>(ops[relevant_member_index]).scalar();
7175 masked = is_stage_output_block_member_masked(*var, mbr_idx, true);
7176 }
7177 }
7178 else if (var)
7179 masked = is_stage_output_variable_masked(*var);
7180
7181 if (masked)
7182 return false;
7183 }
7184
7185 AccessChainMeta meta;
7186 SmallVector<uint32_t> indices;
7187 uint32_t next_id = ir.increase_bound_by(1);
7188
7189 indices.reserve(length - 3 + 1);
7190
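		// Access chain operand layout: ops[0] = result type, ops[1] = result id, ops[2] = base,
		// with indices starting at ops[3]. For a fresh, non-patch chain, ops[3] selects the
		// control point, so flattened member indices begin at ops[4]; chained or patch accesses
		// start one operand earlier.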
7191 uint32_t first_non_array_index = (ptr_is_chain ? 3 : 4) - (patch ? 1 : 0);
7192
7193 VariableID stage_var_id;
7194 if (patch)
7195 stage_var_id = var->storage == StorageClassInput ? patch_stage_in_var_id : patch_stage_out_var_id;
7196 else
7197 stage_var_id = var->storage == StorageClassInput ? stage_in_ptr_var_id : stage_out_ptr_var_id;
7198
7199 VariableID ptr = ptr_is_chain ? VariableID(ops[2]) : stage_var_id;
7200 if (!ptr_is_chain && !patch)
7201 {
7202 // Index into gl_in/gl_out with first array index.
7203 indices.push_back(ops[first_non_array_index - 1]);
7204 }
7205
7206 auto &result_ptr_type = get<SPIRType>(ops[0]);
7207
7208 uint32_t const_mbr_id = next_id++;
7209 uint32_t index = get_extended_decoration(ops[2], SPIRVCrossDecorationInterfaceMemberIndex);
7210
7211 // If we have a pointer chain expression, and we are no longer pointing to a composite
7212 // object, we are in the clear. There is no longer a need to flatten anything.
7213 bool further_access_chain_is_trivial = false;
7214 if (ptr_is_chain && flatten_composites)
7215 {
7216 auto &ptr_type = expression_type(ptr);
7217 if (!is_array(ptr_type) && !is_matrix(ptr_type) && ptr_type.basetype != SPIRType::Struct)
7218 further_access_chain_is_trivial = true;
7219 }
7220
7221 if (!further_access_chain_is_trivial && (flatten_composites || is_block))
7222 {
7223 uint32_t i = first_non_array_index;
7224 auto *type = &get_variable_element_type(*var);
7225 if (index == uint32_t(-1) && length >= (first_non_array_index + 1))
7226 {
7227 // Maybe this is a struct type in the input class, in which case
7228 // we put it as a decoration on the corresponding member.
7229 uint32_t mbr_idx = get_constant(ops[first_non_array_index]).scalar();
7230 index = get_extended_member_decoration(var->self, mbr_idx,
7231 SPIRVCrossDecorationInterfaceMemberIndex);
7232 assert(index != uint32_t(-1));
7233 i++;
7234 type = &get<SPIRType>(type->member_types[mbr_idx]);
7235 }
7236
7237 // In this case, we're poking into flattened structures and arrays, so now we have to
7238 // combine the following indices. If we encounter a non-constant index,
7239 // we're hosed.
7240 for (; flatten_composites && i < length; ++i)
7241 {
7242 if (!is_array(*type) && !is_matrix(*type) && type->basetype != SPIRType::Struct)
7243 break;
7244
7245 auto *c = maybe_get<SPIRConstant>(ops[i]);
7246 if (!c || c->specialization)
7247 SPIRV_CROSS_THROW("Trying to dynamically index into an array interface variable in tessellation. "
7248 "This is currently unsupported.");
7249
7250 // We're in flattened space, so just increment the member index into IO block.
7251 // We can only do this once in the current implementation, so either:
7252 // Struct, Matrix or 1-dimensional array for a control point.
7253 if (type->basetype == SPIRType::Struct && var->storage == StorageClassOutput)
7254 {
7255 // Need to consider holes, since individual block members might be masked away.
7256 uint32_t mbr_idx = c->scalar();
7257 for (uint32_t j = 0; j < mbr_idx; j++)
7258 if (!is_stage_output_block_member_masked(*var, j, true))
7259 index++;
7260 }
7261 else
7262 index += c->scalar();
7263
7264 if (type->parent_type)
7265 type = &get<SPIRType>(type->parent_type);
7266 else if (type->basetype == SPIRType::Struct)
7267 type = &get<SPIRType>(type->member_types[c->scalar()]);
7268 }
7269
7270 // We're not going to emit the actual member name, we let any further OpLoad take care of that.
7271 // Tag the access chain with the member index we're referencing.
7272 bool defer_access_chain = flatten_composites && (is_matrix(result_ptr_type) || is_array(result_ptr_type) ||
7273 result_ptr_type.basetype == SPIRType::Struct);
7274
7275 if (!defer_access_chain)
7276 {
7277 // Access the appropriate member of gl_in/gl_out.
7278 set<SPIRConstant>(const_mbr_id, get_uint_type_id(), index, false);
7279 indices.push_back(const_mbr_id);
7280
7281 // Member index is now irrelevant.
7282 index = uint32_t(-1);
7283
7284 // Append any straggling access chain indices.
7285 if (i < length)
7286 indices.insert(indices.end(), ops + i, ops + length);
7287 }
7288 else
7289 {
7290 // We must have consumed the entire access chain if we're deferring it.
7291 assert(i == length);
7292 }
7293
7294 if (index != uint32_t(-1))
7295 set_extended_decoration(ops[1], SPIRVCrossDecorationInterfaceMemberIndex, index);
7296 else
7297 unset_extended_decoration(ops[1], SPIRVCrossDecorationInterfaceMemberIndex);
7298 }
7299 else
7300 {
7301 if (index != uint32_t(-1))
7302 {
7303 set<SPIRConstant>(const_mbr_id, get_uint_type_id(), index, false);
7304 indices.push_back(const_mbr_id);
7305 }
7306
7307 // Member index is now irrelevant.
7308 index = uint32_t(-1);
7309 unset_extended_decoration(ops[1], SPIRVCrossDecorationInterfaceMemberIndex);
7310
7311 indices.insert(indices.end(), ops + first_non_array_index, ops + length);
7312 }
7313
7314 // We use the pointer to the base of the input/output array here,
7315 // so this is always a pointer chain.
7316 string e;
7317
7318 if (!ptr_is_chain)
7319 {
7320 // This is the start of an access chain, use ptr_chain to index into control point array.
7321 e = access_chain(ptr, indices.data(), uint32_t(indices.size()), result_ptr_type, &meta, !patch);
7322 }
7323 else
7324 {
7325 // If we're accessing a struct, we need to use member indices which are based on the IO block,
7326 // not actual struct type, so we have to use a split access chain here where
7327 // first path resolves the control point index, i.e. gl_in[index], and second half deals with
7328 // looking up flattened member name.
7329
7330 // However, it is possible that we partially accessed a struct,
7331 // by taking pointer to member inside the control-point array.
7332 // For this case, we fall back to a natural access chain since we have already dealt with remapping struct members.
7333 // One way to check this here is if we have 2 implied read expressions.
7334 // First one is the gl_in/gl_out struct itself, then an index into that array.
7335 // If we have traversed further, we use a normal access chain formulation.
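			// In the split case, the result is emitted as the existing chain expression
			// (e.g. "gl_in[i]") followed by a member lookup resolved against the flattened IO block.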
7336 auto *ptr_expr = maybe_get<SPIRExpression>(ptr);
7337 bool split_access_chain_formulation = flatten_composites && ptr_expr &&
7338 ptr_expr->implied_read_expressions.size() == 2 &&
7339 !further_access_chain_is_trivial;
7340
7341 if (split_access_chain_formulation)
7342 {
7343 e = join(to_expression(ptr),
7344 access_chain_internal(stage_var_id, indices.data(), uint32_t(indices.size()),
7345 ACCESS_CHAIN_CHAIN_ONLY_BIT, &meta));
7346 }
7347 else
7348 {
7349 e = access_chain_internal(ptr, indices.data(), uint32_t(indices.size()), 0, &meta);
7350 }
7351 }
7352
7353 // Get the actual type of the object that was accessed. If it's a vector type and we changed it,
7354 // then we'll need to add a swizzle.
7355 // For this, we can't necessarily rely on the type of the base expression, because it might be
7356 // another access chain, and it will therefore already have the "correct" type.
7357 auto *expr_type = &get_variable_data_type(*var);
7358 if (has_extended_decoration(ops[2], SPIRVCrossDecorationTessIOOriginalInputTypeID))
7359 expr_type = &get<SPIRType>(get_extended_decoration(ops[2], SPIRVCrossDecorationTessIOOriginalInputTypeID));
7360 for (uint32_t i = 3; i < length; i++)
7361 {
7362 if (!is_array(*expr_type) && expr_type->basetype == SPIRType::Struct)
7363 expr_type = &get<SPIRType>(expr_type->member_types[get<SPIRConstant>(ops[i]).scalar()]);
7364 else
7365 expr_type = &get<SPIRType>(expr_type->parent_type);
7366 }
7367 if (!is_array(*expr_type) && !is_matrix(*expr_type) && expr_type->basetype != SPIRType::Struct &&
7368 expr_type->vecsize > result_ptr_type.vecsize)
7369 e += vector_swizzle(result_ptr_type.vecsize, 0);
7370
7371 auto &expr = set<SPIRExpression>(ops[1], move(e), ops[0], should_forward(ops[2]));
7372 expr.loaded_from = var->self;
7373 expr.need_transpose = meta.need_transpose;
7374 expr.access_chain = true;
7375
7376 // Mark the result as being packed if necessary.
7377 if (meta.storage_is_packed)
7378 set_extended_decoration(ops[1], SPIRVCrossDecorationPhysicalTypePacked);
7379 if (meta.storage_physical_type != 0)
7380 set_extended_decoration(ops[1], SPIRVCrossDecorationPhysicalTypeID, meta.storage_physical_type);
7381 if (meta.storage_is_invariant)
7382 set_decoration(ops[1], DecorationInvariant);
7383 // Save the type we found in case the result is used in another access chain.
7384 set_extended_decoration(ops[1], SPIRVCrossDecorationTessIOOriginalInputTypeID, expr_type->self);
7385
7386 // If we have some expression dependencies in our access chain, this access chain is technically a forwarded
7387 // temporary which could be subject to invalidation.
		// Need to assume we're forwarded while calling inherit_expression_dependencies.
7389 forwarded_temporaries.insert(ops[1]);
7390 // The access chain itself is never forced to a temporary, but its dependencies might.
7391 suppressed_usage_tracking.insert(ops[1]);
7392
7393 for (uint32_t i = 2; i < length; i++)
7394 {
7395 inherit_expression_dependencies(ops[1], ops[i]);
7396 add_implied_read_expression(expr, ops[i]);
7397 }
7398
		// If we end up with no dependencies, i.e., all indices in the access chain are immutable temporaries,
		// we're not forwarded after all.
7401 if (expr.expression_dependencies.empty())
7402 forwarded_temporaries.erase(ops[1]);
7403
7404 return true;
7405 }
7406
7407 // If this is the inner tessellation level, and we're tessellating triangles,
7408 // drop the last index. It isn't an array in this case, so we can't have an
7409 // array reference here. We need to make this ID a variable instead of an
7410 // expression so we don't try to dereference it as a variable pointer.
7411 // Don't do this if the index is a constant 1, though. We need to drop stores
7412 // to that one.
7413 auto *m = ir.find_meta(var ? var->self : ID(0));
7414 if (get_execution_model() == ExecutionModelTessellationControl && var && m &&
7415 m->decoration.builtin_type == BuiltInTessLevelInner && get_entry_point().flags.get(ExecutionModeTriangles))
7416 {
7417 auto *c = maybe_get<SPIRConstant>(ops[3]);
7418 if (c && c->scalar() == 1)
7419 return false;
7420 auto &dest_var = set<SPIRVariable>(ops[1], *var);
7421 dest_var.basetype = ops[0];
7422 ir.meta[ops[1]] = ir.meta[ops[2]];
7423 inherit_expression_dependencies(ops[1], ops[2]);
7424 return true;
7425 }
7426
7427 return false;
7428}
7429
7430bool CompilerMSL::is_out_of_bounds_tessellation_level(uint32_t id_lhs)
7431{
7432 if (!get_entry_point().flags.get(ExecutionModeTriangles))
7433 return false;
7434
7435 // In SPIR-V, TessLevelInner always has two elements and TessLevelOuter always has
7436 // four. This is true even if we are tessellating triangles. This allows clients
7437 // to use a single tessellation control shader with multiple tessellation evaluation
7438 // shaders.
7439 // In Metal, however, only the first element of TessLevelInner and the first three
7440 // of TessLevelOuter are accessible. This stems from how in Metal, the tessellation
7441 // levels must be stored to a dedicated buffer in a particular format that depends
7442 // on the patch type. Therefore, in Triangles mode, any access to the second
7443 // inner level or the fourth outer level must be dropped.
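	// Concretely: in Triangles mode, gl_TessLevelInner[1] and gl_TessLevelOuter[3] are the
	// accesses flagged as out of bounds here (OpStore then drops them).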
7444 const auto *e = maybe_get<SPIRExpression>(id_lhs);
7445 if (!e || !e->access_chain)
7446 return false;
7447 BuiltIn builtin = BuiltIn(get_decoration(e->loaded_from, DecorationBuiltIn));
7448 if (builtin != BuiltInTessLevelInner && builtin != BuiltInTessLevelOuter)
7449 return false;
7450 auto *c = maybe_get<SPIRConstant>(e->implied_read_expressions[1]);
7451 if (!c)
7452 return false;
7453 return (builtin == BuiltInTessLevelInner && c->scalar() == 1) ||
7454 (builtin == BuiltInTessLevelOuter && c->scalar() == 3);
7455}
7456
7457void CompilerMSL::prepare_access_chain_for_scalar_access(std::string &expr, const SPIRType &type,
7458 spv::StorageClass storage, bool &is_packed)
7459{
7460 // If there is any risk of writes happening with the access chain in question,
7461 // and there is a risk of concurrent write access to other components,
7462 // we must cast the access chain to a plain pointer to ensure we only access the exact scalars we expect.
7463 // The MSL compiler refuses to allow component-level access for any non-packed vector types.
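	// For example, a device-storage access like buf.v.y is rewritten to roughly
	// ((device float*)&buf.v), and packed (array-index) rules then turn the component
	// access into ((device float*)&buf.v)[1].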
7464 if (!is_packed && (storage == StorageClassStorageBuffer || storage == StorageClassWorkgroup))
7465 {
7466 const char *addr_space = storage == StorageClassWorkgroup ? "threadgroup" : "device";
7467 expr = join("((", addr_space, " ", type_to_glsl(type), "*)&", enclose_expression(expr), ")");
7468
7469 // Further indexing should happen with packed rules (array index, not swizzle).
7470 is_packed = true;
7471 }
7472}
7473
7474bool CompilerMSL::access_chain_needs_stage_io_builtin_translation(uint32_t base)
7475{
7476 auto *var = maybe_get_backing_variable(base);
7477 if (!var || !is_tessellation_shader())
7478 return true;
7479
7480 // We only need to rewrite builtin access chains when accessing flattened builtins like gl_ClipDistance_N.
7481 // Avoid overriding it back to just gl_ClipDistance.
	// This can only happen in scenarios where we cannot flatten/unflatten access chains, so the only case
7483 // where this triggers is evaluation shader inputs.
7484 bool redirect_builtin = get_execution_model() == ExecutionModelTessellationEvaluation ?
7485 var->storage == StorageClassOutput : false;
7486 return redirect_builtin;
7487}
7488
7489// Sets the interface member index for an access chain to a pull-model interpolant.
7490void CompilerMSL::fix_up_interpolant_access_chain(const uint32_t *ops, uint32_t length)
7491{
7492 auto *var = maybe_get_backing_variable(ops[2]);
7493 if (!var || !pull_model_inputs.count(var->self))
7494 return;
7495 // Get the base index.
7496 uint32_t interface_index;
7497 auto &var_type = get_variable_data_type(*var);
7498 auto &result_type = get<SPIRType>(ops[0]);
7499 auto *type = &var_type;
7500 if (has_extended_decoration(ops[2], SPIRVCrossDecorationInterfaceMemberIndex))
7501 {
7502 interface_index = get_extended_decoration(ops[2], SPIRVCrossDecorationInterfaceMemberIndex);
7503 }
7504 else
7505 {
7506 // Assume an access chain into a struct variable.
7507 assert(var_type.basetype == SPIRType::Struct);
7508 auto &c = get<SPIRConstant>(ops[3 + var_type.array.size()]);
7509 interface_index =
7510 get_extended_member_decoration(var->self, c.scalar(), SPIRVCrossDecorationInterfaceMemberIndex);
7511 }
	// Accumulate indices. We'll have to skip over the one for the struct, if present, because we already
	// accounted for it when getting the base index.
7514 for (uint32_t i = 3; i < length; ++i)
7515 {
7516 if (is_vector(*type) && !is_array(*type) && is_scalar(result_type))
7517 {
7518 // We don't want to combine the next index. Actually, we need to save it
7519 // so we know to apply a swizzle to the result of the interpolation.
7520 set_extended_decoration(ops[1], SPIRVCrossDecorationInterpolantComponentExpr, ops[i]);
7521 break;
7522 }
7523
7524 auto *c = maybe_get<SPIRConstant>(ops[i]);
7525 if (!c || c->specialization)
7526 SPIRV_CROSS_THROW("Trying to dynamically index into an array interface variable using pull-model "
7527 "interpolation. This is currently unsupported.");
7528
7529 if (type->parent_type)
7530 type = &get<SPIRType>(type->parent_type);
7531 else if (type->basetype == SPIRType::Struct)
7532 type = &get<SPIRType>(type->member_types[c->scalar()]);
7533
7534 if (!has_extended_decoration(ops[2], SPIRVCrossDecorationInterfaceMemberIndex) &&
7535 i - 3 == var_type.array.size())
7536 continue;
7537
7538 interface_index += c->scalar();
7539 }
7540 // Save this to the access chain itself so we can recover it later when calling an interpolation function.
7541 set_extended_decoration(ops[1], SPIRVCrossDecorationInterfaceMemberIndex, interface_index);
7542}
7543
7544// Override for MSL-specific syntax instructions
7545void CompilerMSL::emit_instruction(const Instruction &instruction)
7546{
7547#define MSL_BOP(op) emit_binary_op(ops[0], ops[1], ops[2], ops[3], #op)
7548#define MSL_BOP_CAST(op, type) \
7549 emit_binary_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode))
7550#define MSL_UOP(op) emit_unary_op(ops[0], ops[1], ops[2], #op)
7551#define MSL_QFOP(op) emit_quaternary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], ops[5], #op)
7552#define MSL_TFOP(op) emit_trinary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], #op)
7553#define MSL_BFOP(op) emit_binary_func_op(ops[0], ops[1], ops[2], ops[3], #op)
7554#define MSL_BFOP_CAST(op, type) \
7555 emit_binary_func_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode))
7556#define MSL_UFOP(op) emit_unary_func_op(ops[0], ops[1], ops[2], #op)
7557#define MSL_UNORD_BOP(op) emit_binary_unord_op(ops[0], ops[1], ops[2], ops[3], #op)
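	// For example, MSL_BOP(+) on OpFAdd emits "a + b" and MSL_BFOP(fmod) on OpFRem emits
	// "fmod(a, b)"; the _CAST variants first reconcile operand signedness via bitcasts.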
7558
7559 auto ops = stream(instruction);
7560 auto opcode = static_cast<Op>(instruction.op);
7561
7562 // If we need to do implicit bitcasts, make sure we do it with the correct type.
7563 uint32_t integer_width = get_integer_width_for_instruction(instruction);
7564 auto int_type = to_signed_basetype(integer_width);
7565 auto uint_type = to_unsigned_basetype(integer_width);
7566
7567 switch (opcode)
7568 {
7569 case OpLoad:
7570 {
7571 uint32_t id = ops[1];
7572 uint32_t ptr = ops[2];
7573 if (is_tessellation_shader())
7574 {
7575 if (!emit_tessellation_io_load(ops[0], id, ptr))
7576 CompilerGLSL::emit_instruction(instruction);
7577 }
7578 else
7579 {
7580 // Sample mask input for Metal is not an array
7581 if (BuiltIn(get_decoration(ptr, DecorationBuiltIn)) == BuiltInSampleMask)
7582 set_decoration(id, DecorationBuiltIn, BuiltInSampleMask);
7583 CompilerGLSL::emit_instruction(instruction);
7584 }
7585 break;
7586 }
7587
7588 // Comparisons
7589 case OpIEqual:
7590 MSL_BOP_CAST(==, int_type);
7591 break;
7592
7593 case OpLogicalEqual:
7594 case OpFOrdEqual:
7595 MSL_BOP(==);
7596 break;
7597
7598 case OpINotEqual:
7599 MSL_BOP_CAST(!=, int_type);
7600 break;
7601
7602 case OpLogicalNotEqual:
7603 case OpFOrdNotEqual:
7604 MSL_BOP(!=);
7605 break;
7606
7607 case OpUGreaterThan:
7608 MSL_BOP_CAST(>, uint_type);
7609 break;
7610
7611 case OpSGreaterThan:
7612 MSL_BOP_CAST(>, int_type);
7613 break;
7614
7615 case OpFOrdGreaterThan:
7616 MSL_BOP(>);
7617 break;
7618
7619 case OpUGreaterThanEqual:
7620 MSL_BOP_CAST(>=, uint_type);
7621 break;
7622
7623 case OpSGreaterThanEqual:
7624 MSL_BOP_CAST(>=, int_type);
7625 break;
7626
7627 case OpFOrdGreaterThanEqual:
7628 MSL_BOP(>=);
7629 break;
7630
7631 case OpULessThan:
7632 MSL_BOP_CAST(<, uint_type);
7633 break;
7634
7635 case OpSLessThan:
7636 MSL_BOP_CAST(<, int_type);
7637 break;
7638
7639 case OpFOrdLessThan:
7640 MSL_BOP(<);
7641 break;
7642
7643 case OpULessThanEqual:
7644 MSL_BOP_CAST(<=, uint_type);
7645 break;
7646
7647 case OpSLessThanEqual:
7648 MSL_BOP_CAST(<=, int_type);
7649 break;
7650
7651 case OpFOrdLessThanEqual:
7652 MSL_BOP(<=);
7653 break;
7654
7655 case OpFUnordEqual:
7656 MSL_UNORD_BOP(==);
7657 break;
7658
7659 case OpFUnordNotEqual:
7660 MSL_UNORD_BOP(!=);
7661 break;
7662
7663 case OpFUnordGreaterThan:
7664 MSL_UNORD_BOP(>);
7665 break;
7666
7667 case OpFUnordGreaterThanEqual:
7668 MSL_UNORD_BOP(>=);
7669 break;
7670
7671 case OpFUnordLessThan:
7672 MSL_UNORD_BOP(<);
7673 break;
7674
7675 case OpFUnordLessThanEqual:
7676 MSL_UNORD_BOP(<=);
7677 break;
7678
7679 // Derivatives
7680 case OpDPdx:
7681 case OpDPdxFine:
7682 case OpDPdxCoarse:
7683 MSL_UFOP(dfdx);
7684 register_control_dependent_expression(ops[1]);
7685 break;
7686
7687 case OpDPdy:
7688 case OpDPdyFine:
7689 case OpDPdyCoarse:
7690 MSL_UFOP(dfdy);
7691 register_control_dependent_expression(ops[1]);
7692 break;
7693
7694 case OpFwidth:
7695 case OpFwidthCoarse:
7696 case OpFwidthFine:
7697 MSL_UFOP(fwidth);
7698 register_control_dependent_expression(ops[1]);
7699 break;
7700
7701 // Bitfield
7702 case OpBitFieldInsert:
7703 {
7704 emit_bitfield_insert_op(ops[0], ops[1], ops[2], ops[3], ops[4], ops[5], "insert_bits", SPIRType::UInt);
7705 break;
7706 }
7707
7708 case OpBitFieldSExtract:
7709 {
7710 emit_trinary_func_op_bitextract(ops[0], ops[1], ops[2], ops[3], ops[4], "extract_bits", int_type, int_type,
7711 SPIRType::UInt, SPIRType::UInt);
7712 break;
7713 }
7714
7715 case OpBitFieldUExtract:
7716 {
7717 emit_trinary_func_op_bitextract(ops[0], ops[1], ops[2], ops[3], ops[4], "extract_bits", uint_type, uint_type,
7718 SPIRType::UInt, SPIRType::UInt);
7719 break;
7720 }
7721
7722 case OpBitReverse:
7723 // BitReverse does not have issues with sign since result type must match input type.
7724 MSL_UFOP(reverse_bits);
7725 break;
7726
7727 case OpBitCount:
7728 {
7729 auto basetype = expression_type(ops[2]).basetype;
7730 emit_unary_func_op_cast(ops[0], ops[1], ops[2], "popcount", basetype, basetype);
7731 break;
7732 }
7733
7734 case OpFRem:
7735 MSL_BFOP(fmod);
7736 break;
7737
7738 case OpFMul:
7739 if (msl_options.invariant_float_math || has_decoration(ops[1], DecorationNoContraction))
7740 MSL_BFOP(spvFMul);
7741 else
7742 MSL_BOP(*);
7743 break;
7744
7745 case OpFAdd:
7746 if (msl_options.invariant_float_math || has_decoration(ops[1], DecorationNoContraction))
7747 MSL_BFOP(spvFAdd);
7748 else
7749 MSL_BOP(+);
7750 break;
7751
7752 case OpFSub:
7753 if (msl_options.invariant_float_math || has_decoration(ops[1], DecorationNoContraction))
7754 MSL_BFOP(spvFSub);
7755 else
7756 MSL_BOP(-);
7757 break;
7758
7759 // Atomics
7760 case OpAtomicExchange:
7761 {
7762 uint32_t result_type = ops[0];
7763 uint32_t id = ops[1];
7764 uint32_t ptr = ops[2];
7765 uint32_t mem_sem = ops[4];
7766 uint32_t val = ops[5];
7767 emit_atomic_func_op(result_type, id, "atomic_exchange_explicit", opcode, mem_sem, mem_sem, false, ptr, val);
7768 break;
7769 }
7770
7771 case OpAtomicCompareExchange:
7772 {
7773 uint32_t result_type = ops[0];
7774 uint32_t id = ops[1];
7775 uint32_t ptr = ops[2];
7776 uint32_t mem_sem_pass = ops[4];
7777 uint32_t mem_sem_fail = ops[5];
7778 uint32_t val = ops[6];
7779 uint32_t comp = ops[7];
7780 emit_atomic_func_op(result_type, id, "atomic_compare_exchange_weak_explicit", opcode,
7781 mem_sem_pass, mem_sem_fail, true,
7782 ptr, comp, true, false, val);
7783 break;
7784 }
7785
7786 case OpAtomicCompareExchangeWeak:
7787 SPIRV_CROSS_THROW("OpAtomicCompareExchangeWeak is only supported in kernel profile.");
7788
7789 case OpAtomicLoad:
7790 {
7791 uint32_t result_type = ops[0];
7792 uint32_t id = ops[1];
7793 uint32_t ptr = ops[2];
7794 uint32_t mem_sem = ops[4];
7795 emit_atomic_func_op(result_type, id, "atomic_load_explicit", opcode, mem_sem, mem_sem, false, ptr, 0);
7796 break;
7797 }
7798
7799 case OpAtomicStore:
7800 {
7801 uint32_t result_type = expression_type(ops[0]).self;
7802 uint32_t id = ops[0];
7803 uint32_t ptr = ops[0];
7804 uint32_t mem_sem = ops[2];
7805 uint32_t val = ops[3];
7806 emit_atomic_func_op(result_type, id, "atomic_store_explicit", opcode, mem_sem, mem_sem, false, ptr, val);
7807 break;
7808 }
7809
7810#define MSL_AFMO_IMPL(op, valsrc, valconst) \
7811 do \
7812 { \
7813 uint32_t result_type = ops[0]; \
7814 uint32_t id = ops[1]; \
7815 uint32_t ptr = ops[2]; \
7816 uint32_t mem_sem = ops[4]; \
7817 uint32_t val = valsrc; \
7818 emit_atomic_func_op(result_type, id, "atomic_fetch_" #op "_explicit", opcode, \
7819 mem_sem, mem_sem, false, ptr, val, \
7820 false, valconst); \
7821 } while (false)
7822
7823#define MSL_AFMO(op) MSL_AFMO_IMPL(op, ops[5], false)
7824#define MSL_AFMIO(op) MSL_AFMO_IMPL(op, 1, true)
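	// For example, OpAtomicIAdd lowers to a call to atomic_fetch_add_explicit(...), and
	// OpAtomicIIncrement takes the same path with a constant operand of 1.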
7825
7826 case OpAtomicIIncrement:
7827 MSL_AFMIO(add);
7828 break;
7829
7830 case OpAtomicIDecrement:
7831 MSL_AFMIO(sub);
7832 break;
7833
7834 case OpAtomicIAdd:
7835 MSL_AFMO(add);
7836 break;
7837
7838 case OpAtomicISub:
7839 MSL_AFMO(sub);
7840 break;
7841
7842 case OpAtomicSMin:
7843 case OpAtomicUMin:
7844 MSL_AFMO(min);
7845 break;
7846
7847 case OpAtomicSMax:
7848 case OpAtomicUMax:
7849 MSL_AFMO(max);
7850 break;
7851
7852 case OpAtomicAnd:
7853 MSL_AFMO(and);
7854 break;
7855
7856 case OpAtomicOr:
7857 MSL_AFMO(or);
7858 break;
7859
7860 case OpAtomicXor:
7861 MSL_AFMO(xor);
7862 break;
7863
7864 // Images
7865
7866 // Reads == Fetches in Metal
7867 case OpImageRead:
7868 {
7869 // Mark that this shader reads from this image
7870 uint32_t img_id = ops[2];
7871 auto &type = expression_type(img_id);
7872 if (type.image.dim != DimSubpassData)
7873 {
7874 auto *p_var = maybe_get_backing_variable(img_id);
7875 if (p_var && has_decoration(p_var->self, DecorationNonReadable))
7876 {
7877 unset_decoration(p_var->self, DecorationNonReadable);
7878 force_recompile();
7879 }
7880 }
7881
7882 emit_texture_op(instruction, false);
7883 break;
7884 }
7885
7886 // Emulate texture2D atomic operations
7887 case OpImageTexelPointer:
7888 {
7889 // When using the pointer, we need to know which variable it is actually loaded from.
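		// An image that needs atomics is shadowed by a separate device buffer aliasing its storage
		// (named with an "_atomic" suffix below); the texel pointer becomes an index into that buffer.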
7890 auto *var = maybe_get_backing_variable(ops[2]);
7891 if (var && atomic_image_vars.count(var->self))
7892 {
7893 uint32_t result_type = ops[0];
7894 uint32_t id = ops[1];
7895
7896 std::string coord = to_expression(ops[3]);
7897 auto &type = expression_type(ops[2]);
7898 if (type.image.dim == Dim2D)
7899 {
7900 coord = join("spvImage2DAtomicCoord(", coord, ", ", to_expression(ops[2]), ")");
7901 }
7902
7903 auto &e = set<SPIRExpression>(id, join(to_expression(ops[2]), "_atomic[", coord, "]"), result_type, true);
7904 e.loaded_from = var ? var->self : ID(0);
7905 inherit_expression_dependencies(id, ops[3]);
7906 }
7907 else
7908 {
7909 uint32_t result_type = ops[0];
7910 uint32_t id = ops[1];
7911 auto &e =
7912 set<SPIRExpression>(id, join(to_expression(ops[2]), ", ", to_expression(ops[3])), result_type, true);
7913
7914 // When using the pointer, we need to know which variable it is actually loaded from.
7915 e.loaded_from = var ? var->self : ID(0);
7916 inherit_expression_dependencies(id, ops[3]);
7917 }
7918 break;
7919 }
7920
7921 case OpImageWrite:
7922 {
7923 uint32_t img_id = ops[0];
7924 uint32_t coord_id = ops[1];
7925 uint32_t texel_id = ops[2];
7926 const uint32_t *opt = &ops[3];
7927 uint32_t length = instruction.length - 3;
7928
7929 // Bypass pointers because we need the real image struct
7930 auto &type = expression_type(img_id);
7931 auto &img_type = get<SPIRType>(type.self);
7932
		// Ensure this image has been marked as being written to and force a
		// recompile so that the image type output will include write access.
7935 auto *p_var = maybe_get_backing_variable(img_id);
7936 if (p_var && has_decoration(p_var->self, DecorationNonWritable))
7937 {
7938 unset_decoration(p_var->self, DecorationNonWritable);
7939 force_recompile();
7940 }
7941
7942 bool forward = false;
7943 uint32_t bias = 0;
7944 uint32_t lod = 0;
7945 uint32_t flags = 0;
7946
7947 if (length)
7948 {
7949 flags = *opt++;
7950 length--;
7951 }
7952
7953 auto test = [&](uint32_t &v, uint32_t flag) {
7954 if (length && (flags & flag))
7955 {
7956 v = *opt++;
7957 length--;
7958 }
7959 };
7960
7961 test(bias, ImageOperandsBiasMask);
7962 test(lod, ImageOperandsLodMask);
7963
7964 auto &texel_type = expression_type(texel_id);
7965 auto store_type = texel_type;
7966 store_type.vecsize = 4;
7967
7968 TextureFunctionArguments args = {};
7969 args.base.img = img_id;
7970 args.base.imgtype = &img_type;
7971 args.base.is_fetch = true;
7972 args.coord = coord_id;
7973 args.lod = lod;
7974 statement(join(to_expression(img_id), ".write(",
7975 remap_swizzle(store_type, texel_type.vecsize, to_expression(texel_id)), ", ",
7976 CompilerMSL::to_function_args(args, &forward), ");"));
7977
7978 if (p_var && variable_storage_is_aliased(*p_var))
7979 flush_all_aliased_variables();
7980
7981 break;
7982 }
7983
7984 case OpImageQuerySize:
7985 case OpImageQuerySizeLod:
7986 {
7987 uint32_t rslt_type_id = ops[0];
7988 auto &rslt_type = get<SPIRType>(rslt_type_id);
7989
7990 uint32_t id = ops[1];
7991
7992 uint32_t img_id = ops[2];
7993 string img_exp = to_expression(img_id);
7994 auto &img_type = expression_type(img_id);
7995 Dim img_dim = img_type.image.dim;
7996 bool img_is_array = img_type.image.arrayed;
7997
7998 if (img_type.basetype != SPIRType::Image)
7999 SPIRV_CROSS_THROW("Invalid type for OpImageQuerySize.");
8000
8001 string lod;
8002 if (opcode == OpImageQuerySizeLod)
8003 {
			// LOD index defaults to zero, so don't bother outputting a level-zero index
8005 string decl_lod = to_expression(ops[3]);
8006 if (decl_lod != "0")
8007 lod = decl_lod;
8008 }
8009
8010 string expr = type_to_glsl(rslt_type) + "(";
8011 expr += img_exp + ".get_width(" + lod + ")";
8012
8013 if (img_dim == Dim2D || img_dim == DimCube || img_dim == Dim3D)
8014 expr += ", " + img_exp + ".get_height(" + lod + ")";
8015
8016 if (img_dim == Dim3D)
8017 expr += ", " + img_exp + ".get_depth(" + lod + ")";
8018
8019 if (img_is_array)
8020 {
8021 expr += ", " + img_exp + ".get_array_size()";
8022 if (img_dim == DimCube && msl_options.emulate_cube_array)
8023 expr += " / 6";
8024 }
8025
8026 expr += ")";
8027
8028 emit_op(rslt_type_id, id, expr, should_forward(img_id));
8029
8030 break;
8031 }
8032
8033 case OpImageQueryLod:
8034 {
8035 if (!msl_options.supports_msl_version(2, 2))
8036 SPIRV_CROSS_THROW("ImageQueryLod is only supported on MSL 2.2 and up.");
8037 uint32_t result_type = ops[0];
8038 uint32_t id = ops[1];
8039 uint32_t image_id = ops[2];
8040 uint32_t coord_id = ops[3];
8041 emit_uninitialized_temporary_expression(result_type, id);
8042
8043 auto sampler_expr = to_sampler_expression(image_id);
8044 auto *combined = maybe_get<SPIRCombinedImageSampler>(image_id);
8045 auto image_expr = combined ? to_expression(combined->image) : to_expression(image_id);
8046
		// TODO: It is unclear if calculate_clamped_lod also conditionally rounds
8048 // the reported LOD based on the sampler. NEAREST miplevel should
8049 // round the LOD, but LINEAR miplevel should not round.
8050 // Let's hope this does not become an issue ...
8051 statement(to_expression(id), ".x = ", image_expr, ".calculate_clamped_lod(", sampler_expr, ", ",
8052 to_expression(coord_id), ");");
8053 statement(to_expression(id), ".y = ", image_expr, ".calculate_unclamped_lod(", sampler_expr, ", ",
8054 to_expression(coord_id), ");");
8055 register_control_dependent_expression(id);
8056 break;
8057 }
8058
8059#define MSL_ImgQry(qrytype) \
8060 do \
8061 { \
8062 uint32_t rslt_type_id = ops[0]; \
8063 auto &rslt_type = get<SPIRType>(rslt_type_id); \
8064 uint32_t id = ops[1]; \
8065 uint32_t img_id = ops[2]; \
8066 string img_exp = to_expression(img_id); \
8067 string expr = type_to_glsl(rslt_type) + "(" + img_exp + ".get_num_" #qrytype "())"; \
8068 emit_op(rslt_type_id, id, expr, should_forward(img_id)); \
8069 } while (false)
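	// For example, OpImageQueryLevels on a texture expression "tex" emits roughly
	// "uint(tex.get_num_mip_levels())".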
8070
8071 case OpImageQueryLevels:
8072 MSL_ImgQry(mip_levels);
8073 break;
8074
8075 case OpImageQuerySamples:
8076 MSL_ImgQry(samples);
8077 break;
8078
8079 case OpImage:
8080 {
8081 uint32_t result_type = ops[0];
8082 uint32_t id = ops[1];
8083 auto *combined = maybe_get<SPIRCombinedImageSampler>(ops[2]);
8084
8085 if (combined)
8086 {
8087 auto &e = emit_op(result_type, id, to_expression(combined->image), true, true);
8088 auto *var = maybe_get_backing_variable(combined->image);
8089 if (var)
8090 e.loaded_from = var->self;
8091 }
8092 else
8093 {
8094 auto *var = maybe_get_backing_variable(ops[2]);
8095 SPIRExpression *e;
8096 if (var && has_extended_decoration(var->self, SPIRVCrossDecorationDynamicImageSampler))
8097 e = &emit_op(result_type, id, join(to_expression(ops[2]), ".plane0"), true, true);
8098 else
8099 e = &emit_op(result_type, id, to_expression(ops[2]), true, true);
8100 if (var)
8101 e->loaded_from = var->self;
8102 }
8103 break;
8104 }
8105
8106 // Casting
8107 case OpQuantizeToF16:
8108 {
8109 uint32_t result_type = ops[0];
8110 uint32_t id = ops[1];
8111 uint32_t arg = ops[2];
8112 string exp = join("spvQuantizeToF16(", to_expression(arg), ")");
8113 emit_op(result_type, id, exp, should_forward(arg));
8114 break;
8115 }
8116
8117 case OpInBoundsAccessChain:
8118 case OpAccessChain:
8119 case OpPtrAccessChain:
8120 if (is_tessellation_shader())
8121 {
8122 if (!emit_tessellation_access_chain(ops, instruction.length))
8123 CompilerGLSL::emit_instruction(instruction);
8124 }
8125 else
8126 CompilerGLSL::emit_instruction(instruction);
8127 fix_up_interpolant_access_chain(ops, instruction.length);
8128 break;
8129
8130 case OpStore:
8131 if (is_out_of_bounds_tessellation_level(ops[0]))
8132 break;
8133
8134 if (maybe_emit_array_assignment(ops[0], ops[1]))
8135 break;
8136
8137 CompilerGLSL::emit_instruction(instruction);
8138 break;
8139
8140 // Compute barriers
8141 case OpMemoryBarrier:
8142 emit_barrier(0, ops[0], ops[1]);
8143 break;
8144
8145 case OpControlBarrier:
8146 // In GLSL a memory barrier is often followed by a control barrier.
8147 // But in MSL, memory barriers are also control barriers, so don't
8148 // emit a simple control barrier if a memory barrier has just been emitted.
8149 if (previous_instruction_opcode != OpMemoryBarrier)
8150 emit_barrier(ops[0], ops[1], ops[2]);
8151 break;
8152
8153 case OpOuterProduct:
8154 {
8155 uint32_t result_type = ops[0];
8156 uint32_t id = ops[1];
8157 uint32_t a = ops[2];
8158 uint32_t b = ops[3];
8159
8160 auto &type = get<SPIRType>(result_type);
8161 string expr = type_to_glsl_constructor(type);
8162 expr += "(";
8163 for (uint32_t col = 0; col < type.columns; col++)
8164 {
8165 expr += to_enclosed_unpacked_expression(a);
8166 expr += " * ";
8167 expr += to_extract_component_expression(b, col);
8168 if (col + 1 < type.columns)
8169 expr += ", ";
8170 }
8171 expr += ")";
8172 emit_op(result_type, id, expr, should_forward(a) && should_forward(b));
8173 inherit_expression_dependencies(id, a);
8174 inherit_expression_dependencies(id, b);
8175 break;
8176 }
8177
8178 case OpVectorTimesMatrix:
8179 case OpMatrixTimesVector:
8180 {
8181 if (!msl_options.invariant_float_math && !has_decoration(ops[1], DecorationNoContraction))
8182 {
8183 CompilerGLSL::emit_instruction(instruction);
8184 break;
8185 }
8186
8187 // If the matrix needs transpose, just flip the multiply order.
8188 auto *e = maybe_get<SPIRExpression>(ops[opcode == OpMatrixTimesVector ? 2 : 3]);
8189 if (e && e->need_transpose)
8190 {
8191 e->need_transpose = false;
8192 string expr;
8193
8194 if (opcode == OpMatrixTimesVector)
8195 {
8196 expr = join("spvFMulVectorMatrix(", to_enclosed_unpacked_expression(ops[3]), ", ",
8197 to_unpacked_row_major_matrix_expression(ops[2]), ")");
8198 }
8199 else
8200 {
8201 expr = join("spvFMulMatrixVector(", to_unpacked_row_major_matrix_expression(ops[3]), ", ",
8202 to_enclosed_unpacked_expression(ops[2]), ")");
8203 }
8204
8205 bool forward = should_forward(ops[2]) && should_forward(ops[3]);
8206 emit_op(ops[0], ops[1], expr, forward);
8207 e->need_transpose = true;
8208 inherit_expression_dependencies(ops[1], ops[2]);
8209 inherit_expression_dependencies(ops[1], ops[3]);
8210 }
8211 else
8212 {
8213 if (opcode == OpMatrixTimesVector)
8214 MSL_BFOP(spvFMulMatrixVector);
8215 else
8216 MSL_BFOP(spvFMulVectorMatrix);
8217 }
8218 break;
8219 }
8220
8221 case OpMatrixTimesMatrix:
8222 {
8223 if (!msl_options.invariant_float_math && !has_decoration(ops[1], DecorationNoContraction))
8224 {
8225 CompilerGLSL::emit_instruction(instruction);
8226 break;
8227 }
8228
8229 auto *a = maybe_get<SPIRExpression>(ops[2]);
8230 auto *b = maybe_get<SPIRExpression>(ops[3]);
8231
8232 // If both matrices need transpose, we can multiply in flipped order and tag the expression as transposed.
8233 // a^T * b^T = (b * a)^T.
8234 if (a && b && a->need_transpose && b->need_transpose)
8235 {
8236 a->need_transpose = false;
8237 b->need_transpose = false;
8238
8239 auto expr =
8240 join("spvFMulMatrixMatrix(", enclose_expression(to_unpacked_row_major_matrix_expression(ops[3])), ", ",
8241 enclose_expression(to_unpacked_row_major_matrix_expression(ops[2])), ")");
8242
8243 bool forward = should_forward(ops[2]) && should_forward(ops[3]);
8244 auto &e = emit_op(ops[0], ops[1], expr, forward);
8245 e.need_transpose = true;
8246 a->need_transpose = true;
8247 b->need_transpose = true;
8248 inherit_expression_dependencies(ops[1], ops[2]);
8249 inherit_expression_dependencies(ops[1], ops[3]);
8250 }
8251 else
8252 MSL_BFOP(spvFMulMatrixMatrix);
8253
8254 break;
8255 }
8256
8257 case OpIAddCarry:
8258 case OpISubBorrow:
8259 {
8260 uint32_t result_type = ops[0];
8261 uint32_t result_id = ops[1];
8262 uint32_t op0 = ops[2];
8263 uint32_t op1 = ops[3];
8264 auto &type = get<SPIRType>(result_type);
8265 emit_uninitialized_temporary_expression(result_type, result_id);
8266
8267 auto &res_type = get<SPIRType>(type.member_types[1]);
8268 if (opcode == OpIAddCarry)
8269 {
8270 statement(to_expression(result_id), ".", to_member_name(type, 0), " = ",
8271 to_enclosed_unpacked_expression(op0), " + ", to_enclosed_unpacked_expression(op1), ";");
8272 statement(to_expression(result_id), ".", to_member_name(type, 1), " = select(", type_to_glsl(res_type),
8273 "(1), ", type_to_glsl(res_type), "(0), ", to_unpacked_expression(result_id), ".", to_member_name(type, 0),
8274 " >= max(", to_unpacked_expression(op0), ", ", to_unpacked_expression(op1), "));");
8275 }
8276 else
8277 {
8278 statement(to_expression(result_id), ".", to_member_name(type, 0), " = ", to_enclosed_unpacked_expression(op0), " - ",
8279 to_enclosed_unpacked_expression(op1), ";");
8280 statement(to_expression(result_id), ".", to_member_name(type, 1), " = select(", type_to_glsl(res_type),
8281 "(1), ", type_to_glsl(res_type), "(0), ", to_enclosed_unpacked_expression(op0),
8282 " >= ", to_enclosed_unpacked_expression(op1), ");");
8283 }
8284 break;
8285 }
8286
8287 case OpUMulExtended:
8288 case OpSMulExtended:
8289 {
8290 uint32_t result_type = ops[0];
8291 uint32_t result_id = ops[1];
8292 uint32_t op0 = ops[2];
8293 uint32_t op1 = ops[3];
8294 auto &type = get<SPIRType>(result_type);
8295 emit_uninitialized_temporary_expression(result_type, result_id);
8296
8297 statement(to_expression(result_id), ".", to_member_name(type, 0), " = ",
8298 to_enclosed_unpacked_expression(op0), " * ", to_enclosed_unpacked_expression(op1), ";");
8299 statement(to_expression(result_id), ".", to_member_name(type, 1), " = mulhi(",
8300 to_unpacked_expression(op0), ", ", to_unpacked_expression(op1), ");");
8301 break;
8302 }
8303
8304 case OpArrayLength:
8305 {
8306 auto &type = expression_type(ops[2]);
8307 uint32_t offset = type_struct_member_offset(type, ops[3]);
8308 uint32_t stride = type_struct_member_array_stride(type, ops[3]);
8309
8310 auto expr = join("(", to_buffer_size_expression(ops[2]), " - ", offset, ") / ", stride);
8311 emit_op(ops[0], ops[1], expr, true);
8312 break;
8313 }
8314
8315 // SPV_INTEL_shader_integer_functions2
8316 case OpUCountLeadingZerosINTEL:
8317 MSL_UFOP(clz);
8318 break;
8319
8320 case OpUCountTrailingZerosINTEL:
8321 MSL_UFOP(ctz);
8322 break;
8323
8324 case OpAbsISubINTEL:
8325 case OpAbsUSubINTEL:
8326 MSL_BFOP(absdiff);
8327 break;
8328
8329 case OpIAddSatINTEL:
8330 case OpUAddSatINTEL:
8331 MSL_BFOP(addsat);
8332 break;
8333
8334 case OpIAverageINTEL:
8335 case OpUAverageINTEL:
8336 MSL_BFOP(hadd);
8337 break;
8338
8339 case OpIAverageRoundedINTEL:
8340 case OpUAverageRoundedINTEL:
8341 MSL_BFOP(rhadd);
8342 break;
8343
8344 case OpISubSatINTEL:
8345 case OpUSubSatINTEL:
8346 MSL_BFOP(subsat);
8347 break;
8348
8349 case OpIMul32x16INTEL:
8350 {
8351 uint32_t result_type = ops[0];
8352 uint32_t id = ops[1];
8353 uint32_t a = ops[2], b = ops[3];
8354 bool forward = should_forward(a) && should_forward(b);
8355 emit_op(result_type, id, join("int(short(", to_unpacked_expression(a), ")) * int(short(", to_unpacked_expression(b), "))"), forward);
8356 inherit_expression_dependencies(id, a);
8357 inherit_expression_dependencies(id, b);
8358 break;
8359 }
8360
8361 case OpUMul32x16INTEL:
8362 {
8363 uint32_t result_type = ops[0];
8364 uint32_t id = ops[1];
8365 uint32_t a = ops[2], b = ops[3];
8366 bool forward = should_forward(a) && should_forward(b);
8367 emit_op(result_type, id, join("uint(ushort(", to_unpacked_expression(a), ")) * uint(ushort(", to_unpacked_expression(b), "))"), forward);
8368 inherit_expression_dependencies(id, a);
8369 inherit_expression_dependencies(id, b);
8370 break;
8371 }
8372
8373 // SPV_EXT_demote_to_helper_invocation
8374 case OpDemoteToHelperInvocationEXT:
8375 if (!msl_options.supports_msl_version(2, 3))
8376 SPIRV_CROSS_THROW("discard_fragment() does not formally have demote semantics until MSL 2.3.");
8377 CompilerGLSL::emit_instruction(instruction);
8378 break;
8379
8380 case OpIsHelperInvocationEXT:
8381 if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 3))
8382 SPIRV_CROSS_THROW("simd_is_helper_thread() requires MSL 2.3 on iOS.");
8383 else if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 1))
8384 SPIRV_CROSS_THROW("simd_is_helper_thread() requires MSL 2.1 on macOS.");
8385 emit_op(ops[0], ops[1], "simd_is_helper_thread()", false);
8386 break;
8387
8388 case OpBeginInvocationInterlockEXT:
8389 case OpEndInvocationInterlockEXT:
8390 if (!msl_options.supports_msl_version(2, 0))
8391 SPIRV_CROSS_THROW("Raster order groups require MSL 2.0.");
8392 break; // Nothing to do in the body
8393
8394 case OpConvertUToAccelerationStructureKHR:
8395 SPIRV_CROSS_THROW("ConvertUToAccelerationStructure is not supported in MSL.");
8396 case OpRayQueryGetIntersectionInstanceShaderBindingTableRecordOffsetKHR:
8397 SPIRV_CROSS_THROW("BindingTableRecordOffset is not supported in MSL.");
8398
8399 case OpRayQueryInitializeKHR:
8400 {
8401 flush_variable_declaration(ops[0]);
8402
8403 statement(to_expression(ops[0]), ".reset(", "ray(", to_expression(ops[4]), ", ", to_expression(ops[6]), ", ",
8404 to_expression(ops[5]), ", ", to_expression(ops[7]), "), ", to_expression(ops[1]),
8405 ", intersection_params());");
8406 break;
8407 }
8408 case OpRayQueryProceedKHR:
8409 {
8410 flush_variable_declaration(ops[0]);
8411 emit_op(ops[0], ops[1], join(to_expression(ops[2]), ".next()"), false);
8412 break;
8413 }
8414#define MSL_RAY_QUERY_IS_CANDIDATE get<SPIRConstant>(ops[3]).scalar_i32() == 0
8415
8416#define MSL_RAY_QUERY_GET_OP(op, msl_op) \
8417 case OpRayQueryGet##op##KHR: \
8418 flush_variable_declaration(ops[2]); \
8419 emit_op(ops[0], ops[1], join(to_expression(ops[2]), ".get_" #msl_op "()"), false); \
8420 break
8421
8422#define MSL_RAY_QUERY_OP_INNER2(op, msl_prefix, msl_op) \
8423 case OpRayQueryGet##op##KHR: \
8424 flush_variable_declaration(ops[2]); \
8425 if (MSL_RAY_QUERY_IS_CANDIDATE) \
8426 emit_op(ops[0], ops[1], join(to_expression(ops[2]), #msl_prefix "_candidate_" #msl_op "()"), false); \
8427 else \
8428 emit_op(ops[0], ops[1], join(to_expression(ops[2]), #msl_prefix "_committed_" #msl_op "()"), false); \
8429 break
8430
8431#define MSL_RAY_QUERY_GET_OP2(op, msl_op) MSL_RAY_QUERY_OP_INNER2(op, .get, msl_op)
8432#define MSL_RAY_QUERY_IS_OP2(op, msl_op) MSL_RAY_QUERY_OP_INNER2(op, .is, msl_op)
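	// For example, OpRayQueryGetIntersectionInstanceIdKHR expands to either
	// rayQuery.get_candidate_instance_id() or rayQuery.get_committed_instance_id(),
	// selected by the constant intersection-type operand in ops[3].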
8433
8434 MSL_RAY_QUERY_GET_OP(RayTMin, ray_min_distance);
	MSL_RAY_QUERY_GET_OP(WorldRayOrigin, world_space_ray_origin);
	MSL_RAY_QUERY_GET_OP(WorldRayDirection, world_space_ray_direction);
8437 MSL_RAY_QUERY_GET_OP2(IntersectionInstanceId, instance_id);
8438 MSL_RAY_QUERY_GET_OP2(IntersectionInstanceCustomIndex, user_instance_id);
8439 MSL_RAY_QUERY_GET_OP2(IntersectionBarycentrics, triangle_barycentric_coord);
8440 MSL_RAY_QUERY_GET_OP2(IntersectionPrimitiveIndex, primitive_id);
8441 MSL_RAY_QUERY_GET_OP2(IntersectionGeometryIndex, geometry_id);
8442 MSL_RAY_QUERY_GET_OP2(IntersectionObjectRayOrigin, ray_origin);
8443 MSL_RAY_QUERY_GET_OP2(IntersectionObjectRayDirection, ray_direction);
8444 MSL_RAY_QUERY_GET_OP2(IntersectionObjectToWorld, object_to_world_transform);
8445 MSL_RAY_QUERY_GET_OP2(IntersectionWorldToObject, world_to_object_transform);
8446 MSL_RAY_QUERY_IS_OP2(IntersectionFrontFace, triangle_front_facing);
8447
8448 case OpRayQueryGetIntersectionTypeKHR:
8449 flush_variable_declaration(ops[2]);
8450 if (MSL_RAY_QUERY_IS_CANDIDATE)
8451 emit_op(ops[0], ops[1], join("uint(", to_expression(ops[2]), ".get_candidate_intersection_type()) - 1"),
8452 false);
8453 else
8454 emit_op(ops[0], ops[1], join("uint(", to_expression(ops[2]), ".get_committed_intersection_type())"), false);
8455 break;
8456 case OpRayQueryGetIntersectionTKHR:
8457 flush_variable_declaration(ops[2]);
8458 if (MSL_RAY_QUERY_IS_CANDIDATE)
8459 emit_op(ops[0], ops[1], join(to_expression(ops[2]), ".get_candidate_triangle_distance()"), false);
8460 else
8461 emit_op(ops[0], ops[1], join(to_expression(ops[2]), ".get_committed_distance()"), false);
8462 break;
8463 case OpRayQueryGetIntersectionCandidateAABBOpaqueKHR:
8464 {
8465 flush_variable_declaration(ops[0]);
8466 emit_op(ops[0], ops[1], join(to_expression(ops[2]), ".is_candidate_non_opaque_bounding_box()"), false);
8467 break;
8468 }
8469 case OpRayQueryConfirmIntersectionKHR:
8470 flush_variable_declaration(ops[0]);
8471 statement(to_expression(ops[0]), ".commit_triangle_intersection();");
8472 break;
8473 case OpRayQueryGenerateIntersectionKHR:
8474 flush_variable_declaration(ops[0]);
8475 statement(to_expression(ops[0]), ".commit_bounding_box_intersection(", to_expression(ops[1]), ");");
8476 break;
8477 case OpRayQueryTerminateKHR:
8478 flush_variable_declaration(ops[0]);
8479 statement(to_expression(ops[0]), ".abort();");
8480 break;
8481#undef MSL_RAY_QUERY_GET_OP
8482#undef MSL_RAY_QUERY_IS_CANDIDATE
8483#undef MSL_RAY_QUERY_IS_OP2
8484#undef MSL_RAY_QUERY_GET_OP2
8485#undef MSL_RAY_QUERY_OP_INNER2
8486 default:
8487 CompilerGLSL::emit_instruction(instruction);
8488 break;
8489 }
8490
8491 previous_instruction_opcode = opcode;
8492}
8493
8494void CompilerMSL::emit_texture_op(const Instruction &i, bool sparse)
8495{
8496 if (sparse)
8497 SPIRV_CROSS_THROW("Sparse feedback not yet supported in MSL.");
8498
8499 if (msl_options.use_framebuffer_fetch_subpasses)
8500 {
8501 auto *ops = stream(i);
8502
8503 uint32_t result_type_id = ops[0];
8504 uint32_t id = ops[1];
8505 uint32_t img = ops[2];
8506
8507 auto &type = expression_type(img);
8508 auto &imgtype = get<SPIRType>(type.self);
8509
8510 // Use Metal's native frame-buffer fetch API for subpass inputs.
8511 if (imgtype.image.dim == DimSubpassData)
8512 {
8513 // Subpass inputs cannot be invalidated,
8514 // so just forward the expression directly.
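			// e.g. the subpass input is declared as a [[color(N)]] fragment function argument,
			// so reading it is simply referencing that argument.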
8515 string expr = to_expression(img);
8516 emit_op(result_type_id, id, expr, true);
8517 return;
8518 }
8519 }
8520
8521 // Fallback to default implementation
8522 CompilerGLSL::emit_texture_op(i, sparse);
8523}
8524
8525void CompilerMSL::emit_barrier(uint32_t id_exe_scope, uint32_t id_mem_scope, uint32_t id_mem_sem)
8526{
8527 if (get_execution_model() != ExecutionModelGLCompute && get_execution_model() != ExecutionModelTessellationControl)
8528 return;
8529
8530 uint32_t exe_scope = id_exe_scope ? evaluate_constant_u32(id_exe_scope) : uint32_t(ScopeInvocation);
8531 uint32_t mem_scope = id_mem_scope ? evaluate_constant_u32(id_mem_scope) : uint32_t(ScopeInvocation);
8532 // Use the wider of the two scopes (smaller value)
8533 exe_scope = min(exe_scope, mem_scope);
8534
8535 if (msl_options.emulate_subgroups && exe_scope >= ScopeSubgroup && !id_mem_sem)
8536 // In this case, we assume a "subgroup" size of 1. The barrier, then, is a noop.
8537 return;
8538
8539 string bar_stmt;
8540 if ((msl_options.is_ios() && msl_options.supports_msl_version(1, 2)) || msl_options.supports_msl_version(2))
8541 bar_stmt = exe_scope < ScopeSubgroup ? "threadgroup_barrier" : "simdgroup_barrier";
8542 else
8543 bar_stmt = "threadgroup_barrier";
8544 bar_stmt += "(";
8545
8546 uint32_t mem_sem = id_mem_sem ? evaluate_constant_u32(id_mem_sem) : uint32_t(MemorySemanticsMaskNone);
8547
8548 // Use the | operator to combine flags if we can.
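	// For illustration, device + threadgroup semantics combine into a single call:
	//   threadgroup_barrier(mem_flags::mem_device | mem_flags::mem_threadgroup);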
8549 if (msl_options.supports_msl_version(1, 2))
8550 {
8551 string mem_flags = "";
8552 // For tesc shaders, this also affects objects in the Output storage class.
8553 // Since in Metal, these are placed in a device buffer, we have to sync device memory here.
8554 if (get_execution_model() == ExecutionModelTessellationControl ||
8555 (mem_sem & (MemorySemanticsUniformMemoryMask | MemorySemanticsCrossWorkgroupMemoryMask)))
8556 mem_flags += "mem_flags::mem_device";
8557
8558 // Fix tessellation patch function processing
8559 if (get_execution_model() == ExecutionModelTessellationControl ||
8560 (mem_sem & (MemorySemanticsSubgroupMemoryMask | MemorySemanticsWorkgroupMemoryMask)))
8561 {
8562 if (!mem_flags.empty())
8563 mem_flags += " | ";
8564 mem_flags += "mem_flags::mem_threadgroup";
8565 }
8566 if (mem_sem & MemorySemanticsImageMemoryMask)
8567 {
8568 if (!mem_flags.empty())
8569 mem_flags += " | ";
8570 mem_flags += "mem_flags::mem_texture";
8571 }
8572
8573 if (mem_flags.empty())
8574 mem_flags = "mem_flags::mem_none";
8575
8576 bar_stmt += mem_flags;
8577 }
8578 else
8579 {
8580 if ((mem_sem & (MemorySemanticsUniformMemoryMask | MemorySemanticsCrossWorkgroupMemoryMask)) &&
8581 (mem_sem & (MemorySemanticsSubgroupMemoryMask | MemorySemanticsWorkgroupMemoryMask)))
8582 bar_stmt += "mem_flags::mem_device_and_threadgroup";
8583 else if (mem_sem & (MemorySemanticsUniformMemoryMask | MemorySemanticsCrossWorkgroupMemoryMask))
8584 bar_stmt += "mem_flags::mem_device";
8585 else if (mem_sem & (MemorySemanticsSubgroupMemoryMask | MemorySemanticsWorkgroupMemoryMask))
8586 bar_stmt += "mem_flags::mem_threadgroup";
8587 else if (mem_sem & MemorySemanticsImageMemoryMask)
8588 bar_stmt += "mem_flags::mem_texture";
8589 else
8590 bar_stmt += "mem_flags::mem_none";
8591 }
8592
8593 bar_stmt += ");";
8594
8595 statement(bar_stmt);
8596
8597 assert(current_emitting_block);
8598 flush_control_dependent_expressions(current_emitting_block->self);
8599 flush_all_active_variables();
8600}
8601
8602static bool storage_class_array_is_thread(StorageClass storage)
8603{
8604 switch (storage)
8605 {
8606 case StorageClassInput:
8607 case StorageClassOutput:
8608 case StorageClassGeneric:
8609 case StorageClassFunction:
8610 case StorageClassPrivate:
8611 return true;
8612
8613 default:
8614 return false;
8615 }
8616}
8617
8618void CompilerMSL::emit_array_copy(const string &lhs, uint32_t lhs_id, uint32_t rhs_id,
8619 StorageClass lhs_storage, StorageClass rhs_storage)
8620{
8621 // Allow Metal to use the array<T> template to make arrays a value type.
	// This, however, cannot be used for the threadgroup address space, so fall back to the custom array copy functions in that case.
8623 bool lhs_is_thread_storage = storage_class_array_is_thread(lhs_storage);
8624 bool rhs_is_thread_storage = storage_class_array_is_thread(rhs_storage);
8625
8626 bool lhs_is_array_template = lhs_is_thread_storage;
8627 bool rhs_is_array_template = rhs_is_thread_storage;
8628
8629 // Special considerations for stage IO variables.
	// If the variable is actually backed by non-user-visible device storage, we use array templates for those.
8631 //
8632 // Another special consideration is given to thread local variables which happen to have Offset decorations
8633 // applied to them. Block-like types do not use array templates, so we need to force POD path if we detect
8634 // these scenarios. This check isn't perfect since it would be technically possible to mix and match these things,
8635 // and for a fully correct solution we might have to track array template state through access chains as well,
8636 // but for all reasonable use cases, this should suffice.
8637 // This special case should also only apply to Function/Private storage classes.
8638 // We should not check backing variable for temporaries.
8639 auto *lhs_var = maybe_get_backing_variable(lhs_id);
8640 if (lhs_var && lhs_storage == StorageClassStorageBuffer && storage_class_array_is_thread(lhs_var->storage))
8641 lhs_is_array_template = true;
8642 else if (lhs_var && (lhs_storage == StorageClassFunction || lhs_storage == StorageClassPrivate) &&
8643 type_is_block_like(get<SPIRType>(lhs_var->basetype)))
8644 lhs_is_array_template = false;
8645
8646 auto *rhs_var = maybe_get_backing_variable(rhs_id);
8647 if (rhs_var && rhs_storage == StorageClassStorageBuffer && storage_class_array_is_thread(rhs_var->storage))
8648 rhs_is_array_template = true;
8649 else if (rhs_var && (rhs_storage == StorageClassFunction || rhs_storage == StorageClassPrivate) &&
8650 type_is_block_like(get<SPIRType>(rhs_var->basetype)))
8651 rhs_is_array_template = false;
8652
	// If threadgroup storage qualifiers are *not* used, avoid the spvArrayCopy* wrapper functions;
	// the spvUnsafeArray<> template cannot be used with the threadgroup storage qualifier.
8655 if (lhs_is_array_template && rhs_is_array_template && !using_builtin_array())
8656 {
8657 statement(lhs, " = ", to_expression(rhs_id), ";");
8658 }
8659 else
8660 {
8661 // Assignment from an array initializer is fine.
8662 auto &type = expression_type(rhs_id);
8663 auto *var = maybe_get_backing_variable(rhs_id);
8664
8665 // Unfortunately, we cannot template on address space in MSL,
8666 // so explicit address space redirection it is ...
8667 bool is_constant = false;
8668 if (ir.ids[rhs_id].get_type() == TypeConstant)
8669 {
8670 is_constant = true;
8671 }
8672 else if (var && var->remapped_variable && var->statically_assigned &&
8673 ir.ids[var->static_expression].get_type() == TypeConstant)
8674 {
8675 is_constant = true;
8676 }
8677 else if (rhs_storage == StorageClassUniform || rhs_storage == StorageClassUniformConstant)
8678 {
8679 is_constant = true;
8680 }
8681
8682 // For the case where we have OpLoad triggering an array copy,
8683 // we cannot easily detect this case ahead of time since it's
8684 // context dependent. We might have to force a recompile here
8685 // if this is the only use of array copies in our shader.
8686 if (type.array.size() > 1)
8687 {
8688 if (type.array.size() > kArrayCopyMultidimMax)
8689 SPIRV_CROSS_THROW("Cannot support this many dimensions for arrays of arrays.");
8690 auto func = static_cast<SPVFuncImpl>(SPVFuncImplArrayCopyMultidimBase + type.array.size());
8691 add_spv_func_and_recompile(func);
8692 }
8693 else
8694 add_spv_func_and_recompile(SPVFuncImplArrayCopy);
8695
8696 const char *tag = nullptr;
8697 if (lhs_is_thread_storage && is_constant)
8698 tag = "FromConstantToStack";
8699 else if (lhs_storage == StorageClassWorkgroup && is_constant)
8700 tag = "FromConstantToThreadGroup";
8701 else if (lhs_is_thread_storage && rhs_is_thread_storage)
8702 tag = "FromStackToStack";
8703 else if (lhs_storage == StorageClassWorkgroup && rhs_is_thread_storage)
8704 tag = "FromStackToThreadGroup";
8705 else if (lhs_is_thread_storage && rhs_storage == StorageClassWorkgroup)
8706 tag = "FromThreadGroupToStack";
8707 else if (lhs_storage == StorageClassWorkgroup && rhs_storage == StorageClassWorkgroup)
8708 tag = "FromThreadGroupToThreadGroup";
8709 else if (lhs_storage == StorageClassStorageBuffer && rhs_storage == StorageClassStorageBuffer)
8710 tag = "FromDeviceToDevice";
8711 else if (lhs_storage == StorageClassStorageBuffer && is_constant)
8712 tag = "FromConstantToDevice";
8713 else if (lhs_storage == StorageClassStorageBuffer && rhs_storage == StorageClassWorkgroup)
8714 tag = "FromThreadGroupToDevice";
8715 else if (lhs_storage == StorageClassStorageBuffer && rhs_is_thread_storage)
8716 tag = "FromStackToDevice";
8717 else if (lhs_storage == StorageClassWorkgroup && rhs_storage == StorageClassStorageBuffer)
8718 tag = "FromDeviceToThreadGroup";
8719 else if (lhs_is_thread_storage && rhs_storage == StorageClassStorageBuffer)
8720 tag = "FromDeviceToStack";
8721 else
8722 SPIRV_CROSS_THROW("Unknown storage class used for copying arrays.");
8723
8724 // Pass internal array of spvUnsafeArray<> into wrapper functions
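		// For illustration, a one-dimensional thread-to-thread copy emits something like
		// (dst/src names hypothetical):
		//   spvArrayCopyFromStackToStack1(dst.elements, src.elements);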
8725 if (lhs_is_array_template && rhs_is_array_template && !msl_options.force_native_arrays)
8726 statement("spvArrayCopy", tag, type.array.size(), "(", lhs, ".elements, ", to_expression(rhs_id), ".elements);");
		else if (lhs_is_array_template && !msl_options.force_native_arrays)
8728 statement("spvArrayCopy", tag, type.array.size(), "(", lhs, ".elements, ", to_expression(rhs_id), ");");
8729 else if (rhs_is_array_template && !msl_options.force_native_arrays)
8730 statement("spvArrayCopy", tag, type.array.size(), "(", lhs, ", ", to_expression(rhs_id), ".elements);");
8731 else
8732 statement("spvArrayCopy", tag, type.array.size(), "(", lhs, ", ", to_expression(rhs_id), ");");
8733 }
8734}
8735
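// Returns the fixed physical array size Metal uses for the tessellation level builtins:
// triangle patches use 1 inner and 3 outer factors; quad patches use 2 inner and 4 outer.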
8736uint32_t CompilerMSL::get_physical_tess_level_array_size(spv::BuiltIn builtin) const
8737{
8738 if (get_execution_mode_bitset().get(ExecutionModeTriangles))
8739 return builtin == BuiltInTessLevelInner ? 1 : 3;
8740 else
8741 return builtin == BuiltInTessLevelInner ? 2 : 4;
8742}
8743
8744// Since MSL does not allow arrays to be copied via simple variable assignment,
8745// if the LHS and RHS represent an assignment of an entire array, it must be
8746// implemented by calling an array copy function.
// Returns whether the array assignment was emitted.
8748bool CompilerMSL::maybe_emit_array_assignment(uint32_t id_lhs, uint32_t id_rhs)
8749{
8750 // We only care about assignments of an entire array
8751 auto &type = expression_type(id_rhs);
8752 if (type.array.size() == 0)
8753 return false;
8754
8755 auto *var = maybe_get<SPIRVariable>(id_lhs);
8756
8757 // Is this a remapped, static constant? Don't do anything.
8758 if (var && var->remapped_variable && var->statically_assigned)
8759 return true;
8760
8761 if (ir.ids[id_rhs].get_type() == TypeConstant && var && var->deferred_declaration)
8762 {
8763 // Special case, if we end up declaring a variable when assigning the constant array,
8764 // we can avoid the copy by directly assigning the constant expression.
8765 // This is likely necessary to be able to use a variable as a true look-up table, as it is unlikely
8766 // the compiler will be able to optimize the spvArrayCopy() into a constant LUT.
8767 // After a variable has been declared, we can no longer assign constant arrays in MSL unfortunately.
8768 statement(to_expression(id_lhs), " = ", constant_expression(get<SPIRConstant>(id_rhs)), ";");
8769 return true;
8770 }
8771
8772 if (get_execution_model() == ExecutionModelTessellationControl &&
8773 has_decoration(id_lhs, DecorationBuiltIn))
8774 {
8775 auto builtin = BuiltIn(get_decoration(id_lhs, DecorationBuiltIn));
8776 // Need to manually unroll the array store.
8777 if (builtin == BuiltInTessLevelInner || builtin == BuiltInTessLevelOuter)
8778 {
8779 uint32_t array_size = get_physical_tess_level_array_size(builtin);
8780 if (array_size == 1)
8781 statement(to_expression(id_lhs), " = half(", to_expression(id_rhs), "[0]);");
8782 else
8783 {
8784 for (uint32_t i = 0; i < array_size; i++)
8785 statement(to_expression(id_lhs), "[", i, "] = half(", to_expression(id_rhs), "[", i, "]);");
8786 }
8787 return true;
8788 }
8789 }
8790
8791 // Ensure the LHS variable has been declared
8792 auto *p_v_lhs = maybe_get_backing_variable(id_lhs);
8793 if (p_v_lhs)
8794 flush_variable_declaration(p_v_lhs->self);
8795
8796 auto lhs_storage = get_expression_effective_storage_class(id_lhs);
8797 auto rhs_storage = get_expression_effective_storage_class(id_rhs);
8798 emit_array_copy(to_expression(id_lhs), id_lhs, id_rhs, lhs_storage, rhs_storage);
8799 register_write(id_lhs);
8800
8801 return true;
8802}
8803
// Emits one of the atomic functions. In MSL, the atomic functions operate on pointers.
8805void CompilerMSL::emit_atomic_func_op(uint32_t result_type, uint32_t result_id, const char *op, Op opcode,
8806 uint32_t mem_order_1, uint32_t mem_order_2, bool has_mem_order_2, uint32_t obj, uint32_t op1,
8807 bool op1_is_pointer, bool op1_is_literal, uint32_t op2)
8808{
8809 string exp = string(op) + "(";
8810
8811 auto &type = get_pointee_type(expression_type(obj));
8812 auto expected_type = type.basetype;
8813 if (opcode == OpAtomicUMax || opcode == OpAtomicUMin)
8814 expected_type = to_unsigned_basetype(type.width);
8815 else if (opcode == OpAtomicSMax || opcode == OpAtomicSMin)
8816 expected_type = to_signed_basetype(type.width);
8817
8818 auto remapped_type = type;
8819 remapped_type.basetype = expected_type;
8820
8821 exp += "(";
8822 auto *var = maybe_get_backing_variable(obj);
8823 if (!var)
8824 SPIRV_CROSS_THROW("No backing variable for atomic operation.");
8825
8826 // Emulate texture2D atomic operations
8827 const auto &res_type = get<SPIRType>(var->basetype);
8828 if (res_type.storage == StorageClassUniformConstant && res_type.basetype == SPIRType::Image)
8829 {
8830 exp += "device";
8831 }
8832 else
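			// For illustration, coord.z holds (layer * 6 + face), so face = z % 6 and layer = z / 6.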
8833 {
8834 exp += get_argument_address_space(*var);
8835 }
8836
8837 exp += " atomic_";
8838 // For signed and unsigned min/max, we can signal this through the pointer type.
8839 // There is no other way, since C++ does not have explicit signage for atomics.
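	// For illustration, OpAtomicUMax on a signed int pointer would emit something like
	// (names hypothetical):
	//   atomic_fetch_max_explicit((device atomic_uint*)&buf.val, ..., memory_order_relaxed)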
8840 exp += type_to_glsl(remapped_type);
8841 exp += "*)";
8842
8843 exp += "&";
8844 exp += to_enclosed_expression(obj);
8845
8846 bool is_atomic_compare_exchange_strong = op1_is_pointer && op1;
8847
8848 if (is_atomic_compare_exchange_strong)
8849 {
8850 assert(strcmp(op, "atomic_compare_exchange_weak_explicit") == 0);
8851 assert(op2);
8852 assert(has_mem_order_2);
8853 exp += ", &";
8854 exp += to_name(result_id);
8855 exp += ", ";
8856 exp += to_expression(op2);
8857 exp += ", ";
8858 exp += get_memory_order(mem_order_1);
8859 exp += ", ";
8860 exp += get_memory_order(mem_order_2);
8861 exp += ")";
8862
8863 // MSL only supports the weak atomic compare exchange, so emit a CAS loop here.
8864 // The MSL function returns false if the atomic write fails OR the comparison test fails,
8865 // so we must validate that it wasn't the comparison test that failed before continuing
8866 // the CAS loop, otherwise it will loop infinitely, with the comparison test always failing.
		// The function updates the comparator value from the memory value, so the additional
8868 // comparison test evaluates the memory value against the expected value.
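		// For illustration, the emitted loop has roughly this shape (names hypothetical):
		//   do { _tmp = expected; } while (!atomic_compare_exchange_weak_explicit(
		//       (device atomic_uint*)&obj, &_tmp, desired, memory_order_relaxed,
		//       memory_order_relaxed) && _tmp == expected);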
8869 emit_uninitialized_temporary_expression(result_type, result_id);
8870 statement("do");
8871 begin_scope();
8872 statement(to_name(result_id), " = ", to_expression(op1), ";");
8873 end_scope_decl(join("while (!", exp, " && ", to_name(result_id), " == ", to_enclosed_expression(op1), ")"));
8874 }
8875 else
8876 {
8877 assert(strcmp(op, "atomic_compare_exchange_weak_explicit") != 0);
8878 if (op1)
8879 {
8880 if (op1_is_literal)
8881 exp += join(", ", op1);
8882 else
8883 exp += ", " + bitcast_expression(expected_type, op1);
8884 }
8885 if (op2)
8886 exp += ", " + to_expression(op2);
8887
8888 exp += string(", ") + get_memory_order(mem_order_1);
8889 if (has_mem_order_2)
8890 exp += string(", ") + get_memory_order(mem_order_2);
8891
8892 exp += ")";
8893
8894 if (expected_type != type.basetype)
8895 exp = bitcast_expression(type, expected_type, exp);
8896
8897 if (strcmp(op, "atomic_store_explicit") != 0)
8898 emit_op(result_type, result_id, exp, false);
8899 else
8900 statement(exp, ";");
8901 }
8902
8903 flush_all_atomic_capable_variables();
8904}
8905
8906// Metal only supports relaxed memory order for now
8907const char *CompilerMSL::get_memory_order(uint32_t)
8908{
8909 return "memory_order_relaxed";
8910}
8911
8912// Override for MSL-specific extension syntax instructions.
8913// In some cases, deliberately select either the fast or precise versions of the MSL functions to match Vulkan math precision results.
8914void CompilerMSL::emit_glsl_op(uint32_t result_type, uint32_t id, uint32_t eop, const uint32_t *args, uint32_t count)
8915{
8916 auto op = static_cast<GLSLstd450>(eop);
8917
8918 // If we need to do implicit bitcasts, make sure we do it with the correct type.
8919 uint32_t integer_width = get_integer_width_for_glsl_instruction(op, args, count);
8920 auto int_type = to_signed_basetype(integer_width);
8921 auto uint_type = to_unsigned_basetype(integer_width);
8922
8923 switch (op)
8924 {
8925 case GLSLstd450Sinh:
8926 emit_unary_func_op(result_type, id, args[0], "fast::sinh");
8927 break;
8928 case GLSLstd450Cosh:
8929 emit_unary_func_op(result_type, id, args[0], "fast::cosh");
8930 break;
8931 case GLSLstd450Tanh:
8932 emit_unary_func_op(result_type, id, args[0], "precise::tanh");
8933 break;
8934 case GLSLstd450Atan2:
8935 emit_binary_func_op(result_type, id, args[0], args[1], "precise::atan2");
8936 break;
8937 case GLSLstd450InverseSqrt:
8938 emit_unary_func_op(result_type, id, args[0], "rsqrt");
8939 break;
8940 case GLSLstd450RoundEven:
8941 emit_unary_func_op(result_type, id, args[0], "rint");
8942 break;
8943
8944 case GLSLstd450FindILsb:
8945 {
8946 // In this template version of findLSB, we return T.
8947 auto basetype = expression_type(args[0]).basetype;
8948 emit_unary_func_op_cast(result_type, id, args[0], "spvFindLSB", basetype, basetype);
8949 break;
8950 }
8951
8952 case GLSLstd450FindSMsb:
8953 emit_unary_func_op_cast(result_type, id, args[0], "spvFindSMSB", int_type, int_type);
8954 break;
8955
8956 case GLSLstd450FindUMsb:
8957 emit_unary_func_op_cast(result_type, id, args[0], "spvFindUMSB", uint_type, uint_type);
8958 break;
8959
8960 case GLSLstd450PackSnorm4x8:
8961 emit_unary_func_op(result_type, id, args[0], "pack_float_to_snorm4x8");
8962 break;
8963 case GLSLstd450PackUnorm4x8:
8964 emit_unary_func_op(result_type, id, args[0], "pack_float_to_unorm4x8");
8965 break;
8966 case GLSLstd450PackSnorm2x16:
8967 emit_unary_func_op(result_type, id, args[0], "pack_float_to_snorm2x16");
8968 break;
8969 case GLSLstd450PackUnorm2x16:
8970 emit_unary_func_op(result_type, id, args[0], "pack_float_to_unorm2x16");
8971 break;
8972
8973 case GLSLstd450PackHalf2x16:
8974 {
8975 auto expr = join("as_type<uint>(half2(", to_expression(args[0]), "))");
8976 emit_op(result_type, id, expr, should_forward(args[0]));
8977 inherit_expression_dependencies(id, args[0]);
8978 break;
8979 }
8980
8981 case GLSLstd450UnpackSnorm4x8:
8982 emit_unary_func_op(result_type, id, args[0], "unpack_snorm4x8_to_float");
8983 break;
8984 case GLSLstd450UnpackUnorm4x8:
8985 emit_unary_func_op(result_type, id, args[0], "unpack_unorm4x8_to_float");
8986 break;
8987 case GLSLstd450UnpackSnorm2x16:
8988 emit_unary_func_op(result_type, id, args[0], "unpack_snorm2x16_to_float");
8989 break;
8990 case GLSLstd450UnpackUnorm2x16:
8991 emit_unary_func_op(result_type, id, args[0], "unpack_unorm2x16_to_float");
8992 break;
8993
8994 case GLSLstd450UnpackHalf2x16:
8995 {
8996 auto expr = join("float2(as_type<half2>(", to_expression(args[0]), "))");
8997 emit_op(result_type, id, expr, should_forward(args[0]));
8998 inherit_expression_dependencies(id, args[0]);
8999 break;
9000 }
9001
9002 case GLSLstd450PackDouble2x32:
9003 emit_unary_func_op(result_type, id, args[0], "unsupported_GLSLstd450PackDouble2x32"); // Currently unsupported
9004 break;
9005 case GLSLstd450UnpackDouble2x32:
9006 emit_unary_func_op(result_type, id, args[0], "unsupported_GLSLstd450UnpackDouble2x32"); // Currently unsupported
9007 break;
9008
9009 case GLSLstd450MatrixInverse:
9010 {
9011 auto &mat_type = get<SPIRType>(result_type);
9012 switch (mat_type.columns)
9013 {
9014 case 2:
9015 emit_unary_func_op(result_type, id, args[0], "spvInverse2x2");
9016 break;
9017 case 3:
9018 emit_unary_func_op(result_type, id, args[0], "spvInverse3x3");
9019 break;
9020 case 4:
9021 emit_unary_func_op(result_type, id, args[0], "spvInverse4x4");
9022 break;
9023 default:
9024 break;
9025 }
9026 break;
9027 }
9028
9029 case GLSLstd450FMin:
9030 // If the result type isn't float, don't bother calling the specific
9031 // precise::/fast:: version. Metal doesn't have those for half and
9032 // double types.
9033 if (get<SPIRType>(result_type).basetype != SPIRType::Float)
9034 emit_binary_func_op(result_type, id, args[0], args[1], "min");
9035 else
9036 emit_binary_func_op(result_type, id, args[0], args[1], "fast::min");
9037 break;
9038
9039 case GLSLstd450FMax:
9040 if (get<SPIRType>(result_type).basetype != SPIRType::Float)
9041 emit_binary_func_op(result_type, id, args[0], args[1], "max");
9042 else
9043 emit_binary_func_op(result_type, id, args[0], args[1], "fast::max");
9044 break;
9045
9046 case GLSLstd450FClamp:
9047 // TODO: If args[1] is 0 and args[2] is 1, emit a saturate() call.
9048 if (get<SPIRType>(result_type).basetype != SPIRType::Float)
9049 emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "clamp");
9050 else
9051 emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "fast::clamp");
9052 break;
9053
9054 case GLSLstd450NMin:
9055 if (get<SPIRType>(result_type).basetype != SPIRType::Float)
9056 emit_binary_func_op(result_type, id, args[0], args[1], "min");
9057 else
9058 emit_binary_func_op(result_type, id, args[0], args[1], "precise::min");
9059 break;
9060
9061 case GLSLstd450NMax:
9062 if (get<SPIRType>(result_type).basetype != SPIRType::Float)
9063 emit_binary_func_op(result_type, id, args[0], args[1], "max");
9064 else
9065 emit_binary_func_op(result_type, id, args[0], args[1], "precise::max");
9066 break;
9067
9068 case GLSLstd450NClamp:
9069 // TODO: If args[1] is 0 and args[2] is 1, emit a saturate() call.
9070 if (get<SPIRType>(result_type).basetype != SPIRType::Float)
9071 emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "clamp");
9072 else
9073 emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "precise::clamp");
9074 break;
9075
9076 case GLSLstd450InterpolateAtCentroid:
9077 {
9078 // We can't just emit the expression normally, because the qualified name contains a call to the default
9079 // interpolate method, or refers to a local variable. We saved the interface index we need; use it to construct
9080 // the base for the method call.
9081 uint32_t interface_index = get_extended_decoration(args[0], SPIRVCrossDecorationInterfaceMemberIndex);
9082 string component;
9083 if (has_extended_decoration(args[0], SPIRVCrossDecorationInterpolantComponentExpr))
9084 {
9085 uint32_t index_expr = get_extended_decoration(args[0], SPIRVCrossDecorationInterpolantComponentExpr);
9086 auto *c = maybe_get<SPIRConstant>(index_expr);
9087 if (!c || c->specialization)
9088 component = join("[", to_expression(index_expr), "]");
9089 else
9090 component = join(".", index_to_swizzle(c->scalar()));
9091 }
9092 emit_op(result_type, id,
9093 join(to_name(stage_in_var_id), ".", to_member_name(get_stage_in_struct_type(), interface_index),
9094 ".interpolate_at_centroid()", component),
9095 should_forward(args[0]));
9096 break;
9097 }
9098
9099 case GLSLstd450InterpolateAtSample:
9100 {
9101 uint32_t interface_index = get_extended_decoration(args[0], SPIRVCrossDecorationInterfaceMemberIndex);
9102 string component;
9103 if (has_extended_decoration(args[0], SPIRVCrossDecorationInterpolantComponentExpr))
9104 {
9105 uint32_t index_expr = get_extended_decoration(args[0], SPIRVCrossDecorationInterpolantComponentExpr);
9106 auto *c = maybe_get<SPIRConstant>(index_expr);
9107 if (!c || c->specialization)
9108 component = join("[", to_expression(index_expr), "]");
9109 else
9110 component = join(".", index_to_swizzle(c->scalar()));
9111 }
9112 emit_op(result_type, id,
9113 join(to_name(stage_in_var_id), ".", to_member_name(get_stage_in_struct_type(), interface_index),
9114 ".interpolate_at_sample(", to_expression(args[1]), ")", component),
9115 should_forward(args[0]) && should_forward(args[1]));
9116 break;
9117 }
9118
9119 case GLSLstd450InterpolateAtOffset:
9120 {
9121 uint32_t interface_index = get_extended_decoration(args[0], SPIRVCrossDecorationInterfaceMemberIndex);
9122 string component;
9123 if (has_extended_decoration(args[0], SPIRVCrossDecorationInterpolantComponentExpr))
9124 {
9125 uint32_t index_expr = get_extended_decoration(args[0], SPIRVCrossDecorationInterpolantComponentExpr);
9126 auto *c = maybe_get<SPIRConstant>(index_expr);
9127 if (!c || c->specialization)
9128 component = join("[", to_expression(index_expr), "]");
9129 else
9130 component = join(".", index_to_swizzle(c->scalar()));
9131 }
9132 // Like Direct3D, Metal puts the (0, 0) at the upper-left corner, not the center as SPIR-V and GLSL do.
9133 // Offset the offset by (1/2 - 1/16), or 0.4375, to compensate for this.
9134 // It has to be (1/2 - 1/16) and not 1/2, or several CTS tests subtly break on Intel.
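		// For illustration, the emitted call looks like (member name hypothetical):
		//   in.m_foo.interpolate_at_offset(offset + 0.4375)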
9135 emit_op(result_type, id,
9136 join(to_name(stage_in_var_id), ".", to_member_name(get_stage_in_struct_type(), interface_index),
9137 ".interpolate_at_offset(", to_expression(args[1]), " + 0.4375)", component),
9138 should_forward(args[0]) && should_forward(args[1]));
9139 break;
9140 }
9141
9142 case GLSLstd450Distance:
9143 // MSL does not support scalar versions here.
9144 if (expression_type(args[0]).vecsize == 1)
9145 {
9146 // Equivalent to length(a - b) -> abs(a - b).
9147 emit_op(result_type, id,
9148 join("abs(", to_enclosed_unpacked_expression(args[0]), " - ",
9149 to_enclosed_unpacked_expression(args[1]), ")"),
9150 should_forward(args[0]) && should_forward(args[1]));
9151 inherit_expression_dependencies(id, args[0]);
9152 inherit_expression_dependencies(id, args[1]);
9153 }
9154 else
9155 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
9156 break;
9157
9158 case GLSLstd450Length:
9159 // MSL does not support scalar versions, so use abs().
9160 if (expression_type(args[0]).vecsize == 1)
9161 emit_unary_func_op(result_type, id, args[0], "abs");
9162 else
9163 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
9164 break;
9165
9166 case GLSLstd450Normalize:
9167 {
9168 auto &exp_type = expression_type(args[0]);
9169 // MSL does not support scalar versions here.
9170 // MSL has no implementation for normalize in the fast:: namespace for half2 and half3
9171 // Returns -1 or 1 for valid input, sign() does the job.
9172 if (exp_type.vecsize == 1)
9173 emit_unary_func_op(result_type, id, args[0], "sign");
9174 else if (exp_type.vecsize <= 3 && exp_type.basetype == SPIRType::Half)
9175 emit_unary_func_op(result_type, id, args[0], "normalize");
9176 else
9177 emit_unary_func_op(result_type, id, args[0], "fast::normalize");
9178 break;
9179 }
9180 case GLSLstd450Reflect:
9181 if (get<SPIRType>(result_type).vecsize == 1)
9182 emit_binary_func_op(result_type, id, args[0], args[1], "spvReflect");
9183 else
9184 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
9185 break;
9186
9187 case GLSLstd450Refract:
9188 if (get<SPIRType>(result_type).vecsize == 1)
9189 emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "spvRefract");
9190 else
9191 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
9192 break;
9193
9194 case GLSLstd450FaceForward:
9195 if (get<SPIRType>(result_type).vecsize == 1)
9196 emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "spvFaceForward");
9197 else
9198 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
9199 break;
9200
9201 case GLSLstd450Modf:
9202 case GLSLstd450Frexp:
9203 {
		// Special case. If the variable is a scalar access chain, we cannot use it directly; we have to emit a temporary.
		// Another special case is when the variable is in a storage class other than thread.
9206 auto *ptr = maybe_get<SPIRExpression>(args[1]);
9207 auto &type = expression_type(args[1]);
9208
9209 bool is_thread_storage = storage_class_array_is_thread(type.storage);
9210 if (type.storage == StorageClassOutput && capture_output_to_buffer)
9211 is_thread_storage = false;
9212
9213 if (!is_thread_storage ||
9214 (ptr && ptr->access_chain && is_scalar(expression_type(args[1]))))
9215 {
9216 register_call_out_argument(args[1]);
9217 forced_temporaries.insert(id);
9218
9219 // Need to create temporaries and copy over to access chain after.
9220 // We cannot directly take the reference of a vector swizzle in MSL, even if it's scalar ...
9221 uint32_t &tmp_id = extra_sub_expressions[id];
9222 if (!tmp_id)
9223 tmp_id = ir.increase_bound_by(1);
9224
9225 uint32_t tmp_type_id = get_pointee_type_id(expression_type_id(args[1]));
9226 emit_uninitialized_temporary_expression(tmp_type_id, tmp_id);
9227 emit_binary_func_op(result_type, id, args[0], tmp_id, eop == GLSLstd450Modf ? "modf" : "frexp");
9228 statement(to_expression(args[1]), " = ", to_expression(tmp_id), ";");
9229 }
9230 else
9231 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
9232 break;
9233 }
9234
9235 default:
9236 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
9237 break;
9238 }
9239}
9240
9241void CompilerMSL::emit_spv_amd_shader_trinary_minmax_op(uint32_t result_type, uint32_t id, uint32_t eop,
9242 const uint32_t *args, uint32_t count)
9243{
9244 enum AMDShaderTrinaryMinMax
9245 {
9246 FMin3AMD = 1,
9247 UMin3AMD = 2,
9248 SMin3AMD = 3,
9249 FMax3AMD = 4,
9250 UMax3AMD = 5,
9251 SMax3AMD = 6,
9252 FMid3AMD = 7,
9253 UMid3AMD = 8,
9254 SMid3AMD = 9
9255 };
9256
9257 if (!msl_options.supports_msl_version(2, 1))
9258 SPIRV_CROSS_THROW("Trinary min/max functions require MSL 2.1.");
9259
9260 auto op = static_cast<AMDShaderTrinaryMinMax>(eop);
9261
9262 switch (op)
9263 {
9264 case FMid3AMD:
9265 case UMid3AMD:
9266 case SMid3AMD:
9267 emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "median3");
9268 break;
9269 default:
9270 CompilerGLSL::emit_spv_amd_shader_trinary_minmax_op(result_type, id, eop, args, count);
9271 break;
9272 }
9273}
9274
9275// Emit a structure declaration for the specified interface variable.
9276void CompilerMSL::emit_interface_block(uint32_t ib_var_id)
9277{
9278 if (ib_var_id)
9279 {
9280 auto &ib_var = get<SPIRVariable>(ib_var_id);
9281 auto &ib_type = get_variable_data_type(ib_var);
9282 //assert(ib_type.basetype == SPIRType::Struct && !ib_type.member_types.empty());
9283 assert(ib_type.basetype == SPIRType::Struct);
9284 emit_struct(ib_type);
9285 }
9286}
9287
9288// Emits the declaration signature of the specified function.
9289// If this is the entry point function, Metal-specific return value and function arguments are added.
9290void CompilerMSL::emit_function_prototype(SPIRFunction &func, const Bitset &)
9291{
9292 if (func.self != ir.default_entry_point)
9293 add_function_overload(func);
9294
9295 local_variable_names = resource_names;
9296 string decl;
9297
9298 processing_entry_point = func.self == ir.default_entry_point;
9299
	// Metal helper functions must be static force-inline; otherwise they will cause problems when linked together in a single Metallib.
9301 if (!processing_entry_point)
9302 statement(force_inline);
9303
9304 auto &type = get<SPIRType>(func.return_type);
9305
9306 if (!type.array.empty() && msl_options.force_native_arrays)
9307 {
9308 // We cannot return native arrays in MSL, so "return" through an out variable.
9309 decl += "void";
9310 }
9311 else
9312 {
9313 decl += func_type_decl(type);
9314 }
9315
9316 decl += " ";
9317 decl += to_name(func.self);
9318 decl += "(";
9319
9320 if (!type.array.empty() && msl_options.force_native_arrays)
9321 {
		// Fake array returns by writing to an out array instead.
9323 decl += "thread ";
9324 decl += type_to_glsl(type);
9325 decl += " (&spvReturnValue)";
9326 decl += type_to_array_glsl(type);
9327 if (!func.arguments.empty())
9328 decl += ", ";
9329 }
9330
9331 if (processing_entry_point)
9332 {
9333 if (msl_options.argument_buffers)
9334 decl += entry_point_args_argument_buffer(!func.arguments.empty());
9335 else
9336 decl += entry_point_args_classic(!func.arguments.empty());
9337
9338 // If entry point function has variables that require early declaration,
9339 // ensure they each have an empty initializer, creating one if needed.
9340 // This is done at this late stage because the initialization expression
9341 // is cleared after each compilation pass.
9342 for (auto var_id : vars_needing_early_declaration)
9343 {
9344 auto &ed_var = get<SPIRVariable>(var_id);
9345 ID &initializer = ed_var.initializer;
9346 if (!initializer)
9347 initializer = ir.increase_bound_by(1);
9348
9349 // Do not override proper initializers.
9350 if (ir.ids[initializer].get_type() == TypeNone || ir.ids[initializer].get_type() == TypeExpression)
9351 set<SPIRExpression>(ed_var.initializer, "{}", ed_var.basetype, true);
9352 }
9353 }
9354
9355 for (auto &arg : func.arguments)
9356 {
9357 uint32_t name_id = arg.id;
9358
9359 auto *var = maybe_get<SPIRVariable>(arg.id);
9360 if (var)
9361 {
9362 // If we need to modify the name of the variable, make sure we modify the original variable.
9363 // Our alias is just a shadow variable.
9364 if (arg.alias_global_variable && var->basevariable)
9365 name_id = var->basevariable;
9366
9367 var->parameter = &arg; // Hold a pointer to the parameter so we can invalidate the readonly field if needed.
9368 }
9369
9370 add_local_variable_name(name_id);
9371
9372 decl += argument_decl(arg);
9373
9374 bool is_dynamic_img_sampler = has_extended_decoration(arg.id, SPIRVCrossDecorationDynamicImageSampler);
9375
9376 auto &arg_type = get<SPIRType>(arg.type);
9377 if (arg_type.basetype == SPIRType::SampledImage && !is_dynamic_img_sampler)
9378 {
9379 // Manufacture automatic plane args for multiplanar texture
9380 uint32_t planes = 1;
9381 if (auto *constexpr_sampler = find_constexpr_sampler(name_id))
9382 if (constexpr_sampler->ycbcr_conversion_enable)
9383 planes = constexpr_sampler->planes;
9384 for (uint32_t i = 1; i < planes; i++)
9385 decl += join(", ", argument_decl(arg), plane_name_suffix, i);
9386
9387 // Manufacture automatic sampler arg for SampledImage texture
9388 if (arg_type.image.dim != DimBuffer)
9389 {
9390 if (arg_type.array.empty())
9391 {
9392 decl += join(", ", sampler_type(arg_type, arg.id), " ", to_sampler_expression(arg.id));
9393 }
9394 else
9395 {
9396 const char *sampler_address_space =
9397 descriptor_address_space(name_id,
9398 StorageClassUniformConstant,
9399 "thread const");
9400 decl += join(", ", sampler_address_space, " ", sampler_type(arg_type, arg.id), "& ", to_sampler_expression(arg.id));
9401 }
9402 }
9403 }
9404
9405 // Manufacture automatic swizzle arg.
9406 if (msl_options.swizzle_texture_samples && has_sampled_images && is_sampled_image_type(arg_type) &&
9407 !is_dynamic_img_sampler)
9408 {
9409 bool arg_is_array = !arg_type.array.empty();
9410 decl += join(", constant uint", arg_is_array ? "* " : "& ", to_swizzle_expression(arg.id));
9411 }
9412
9413 if (buffers_requiring_array_length.count(name_id))
9414 {
9415 bool arg_is_array = !arg_type.array.empty();
9416 decl += join(", constant uint", arg_is_array ? "* " : "& ", to_buffer_size_expression(name_id));
9417 }
9418
9419 if (&arg != &func.arguments.back())
9420 decl += ", ";
9421 }
9422
9423 decl += ")";
9424 statement(decl);
9425}
9426
9427static bool needs_chroma_reconstruction(const MSLConstexprSampler *constexpr_sampler)
9428{
9429 // For now, only multiplanar images need explicit reconstruction. GBGR and BGRG images
9430 // use implicit reconstruction.
9431 return constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable && constexpr_sampler->planes > 1;
9432}
9433
9434// Returns the texture sampling function string for the specified image and sampling characteristics.
9435string CompilerMSL::to_function_name(const TextureFunctionNameArguments &args)
9436{
9437 VariableID img = args.base.img;
9438 const MSLConstexprSampler *constexpr_sampler = nullptr;
9439 bool is_dynamic_img_sampler = false;
9440 if (auto *var = maybe_get_backing_variable(img))
9441 {
9442 constexpr_sampler = find_constexpr_sampler(var->basevariable ? var->basevariable : VariableID(var->self));
9443 is_dynamic_img_sampler = has_extended_decoration(var->self, SPIRVCrossDecorationDynamicImageSampler);
9444 }
9445
9446 // Special-case gather. We have to alter the component being looked up
9447 // in the swizzle case.
9448 if (msl_options.swizzle_texture_samples && args.base.is_gather && !is_dynamic_img_sampler &&
9449 (!constexpr_sampler || !constexpr_sampler->ycbcr_conversion_enable))
9450 {
9451 bool is_compare = comparison_ids.count(img);
9452 add_spv_func_and_recompile(is_compare ? SPVFuncImplGatherCompareSwizzle : SPVFuncImplGatherSwizzle);
9453 return is_compare ? "spvGatherCompareSwizzle" : "spvGatherSwizzle";
9454 }
9455
9456 auto *combined = maybe_get<SPIRCombinedImageSampler>(img);
9457
9458 // Texture reference
9459 string fname;
9460 if (needs_chroma_reconstruction(constexpr_sampler) && !is_dynamic_img_sampler)
9461 {
9462 if (constexpr_sampler->planes != 2 && constexpr_sampler->planes != 3)
9463 SPIRV_CROSS_THROW("Unhandled number of color image planes!");
9464 // 444 images aren't downsampled, so we don't need to do linear filtering.
9465 if (constexpr_sampler->resolution == MSL_FORMAT_RESOLUTION_444 ||
9466 constexpr_sampler->chroma_filter == MSL_SAMPLER_FILTER_NEAREST)
9467 {
9468 if (constexpr_sampler->planes == 2)
9469 add_spv_func_and_recompile(SPVFuncImplChromaReconstructNearest2Plane);
9470 else
9471 add_spv_func_and_recompile(SPVFuncImplChromaReconstructNearest3Plane);
9472 fname = "spvChromaReconstructNearest";
9473 }
9474 else // Linear with a downsampled format
9475 {
9476 fname = "spvChromaReconstructLinear";
9477 switch (constexpr_sampler->resolution)
9478 {
9479 case MSL_FORMAT_RESOLUTION_444:
9480 assert(false);
9481 break; // not reached
9482 case MSL_FORMAT_RESOLUTION_422:
9483 switch (constexpr_sampler->x_chroma_offset)
9484 {
9485 case MSL_CHROMA_LOCATION_COSITED_EVEN:
9486 if (constexpr_sampler->planes == 2)
9487 add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear422CositedEven2Plane);
9488 else
9489 add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear422CositedEven3Plane);
9490 fname += "422CositedEven";
9491 break;
9492 case MSL_CHROMA_LOCATION_MIDPOINT:
9493 if (constexpr_sampler->planes == 2)
9494 add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear422Midpoint2Plane);
9495 else
9496 add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear422Midpoint3Plane);
9497 fname += "422Midpoint";
9498 break;
9499 default:
9500 SPIRV_CROSS_THROW("Invalid chroma location.");
9501 }
9502 break;
9503 case MSL_FORMAT_RESOLUTION_420:
9504 fname += "420";
9505 switch (constexpr_sampler->x_chroma_offset)
9506 {
9507 case MSL_CHROMA_LOCATION_COSITED_EVEN:
9508 switch (constexpr_sampler->y_chroma_offset)
9509 {
9510 case MSL_CHROMA_LOCATION_COSITED_EVEN:
9511 if (constexpr_sampler->planes == 2)
9512 add_spv_func_and_recompile(
9513 SPVFuncImplChromaReconstructLinear420XCositedEvenYCositedEven2Plane);
9514 else
9515 add_spv_func_and_recompile(
9516 SPVFuncImplChromaReconstructLinear420XCositedEvenYCositedEven3Plane);
9517 fname += "XCositedEvenYCositedEven";
9518 break;
9519 case MSL_CHROMA_LOCATION_MIDPOINT:
9520 if (constexpr_sampler->planes == 2)
9521 add_spv_func_and_recompile(
9522 SPVFuncImplChromaReconstructLinear420XCositedEvenYMidpoint2Plane);
9523 else
9524 add_spv_func_and_recompile(
9525 SPVFuncImplChromaReconstructLinear420XCositedEvenYMidpoint3Plane);
9526 fname += "XCositedEvenYMidpoint";
9527 break;
9528 default:
9529 SPIRV_CROSS_THROW("Invalid Y chroma location.");
9530 }
9531 break;
9532 case MSL_CHROMA_LOCATION_MIDPOINT:
9533 switch (constexpr_sampler->y_chroma_offset)
9534 {
9535 case MSL_CHROMA_LOCATION_COSITED_EVEN:
9536 if (constexpr_sampler->planes == 2)
9537 add_spv_func_and_recompile(
9538 SPVFuncImplChromaReconstructLinear420XMidpointYCositedEven2Plane);
9539 else
9540 add_spv_func_and_recompile(
9541 SPVFuncImplChromaReconstructLinear420XMidpointYCositedEven3Plane);
9542 fname += "XMidpointYCositedEven";
9543 break;
9544 case MSL_CHROMA_LOCATION_MIDPOINT:
9545 if (constexpr_sampler->planes == 2)
9546 add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint2Plane);
9547 else
9548 add_spv_func_and_recompile(SPVFuncImplChromaReconstructLinear420XMidpointYMidpoint3Plane);
9549 fname += "XMidpointYMidpoint";
9550 break;
9551 default:
9552 SPIRV_CROSS_THROW("Invalid Y chroma location.");
9553 }
9554 break;
9555 default:
9556 SPIRV_CROSS_THROW("Invalid X chroma location.");
9557 }
9558 break;
9559 default:
9560 SPIRV_CROSS_THROW("Invalid format resolution.");
9561 }
9562 }
9563 }
9564 else
9565 {
9566 fname = to_expression(combined ? combined->image : img) + ".";
9567
9568 // Texture function and sampler
9569 if (args.base.is_fetch)
9570 fname += "read";
9571 else if (args.base.is_gather)
9572 fname += "gather";
9573 else
9574 fname += "sample";
9575
9576 if (args.has_dref)
9577 fname += "_compare";
9578 }
9579
9580 return fname;
9581}
9582
9583string CompilerMSL::convert_to_f32(const string &expr, uint32_t components)
9584{
9585 SPIRType t;
9586 t.basetype = SPIRType::Float;
9587 t.vecsize = components;
9588 t.columns = 1;
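	// For illustration, convert_to_f32("coord", 2) produces "float2(coord)".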
9589 return join(type_to_glsl_constructor(t), "(", expr, ")");
9590}
9591
9592static inline bool sampling_type_needs_f32_conversion(const SPIRType &type)
9593{
	// Double is not supported to begin with, but it doesn't hurt to check for completeness.
9595 return type.basetype == SPIRType::Half || type.basetype == SPIRType::Double;
9596}
9597
9598// Returns the function args for a texture sampling function for the specified image and sampling characteristics.
9599string CompilerMSL::to_function_args(const TextureFunctionArguments &args, bool *p_forward)
9600{
9601 VariableID img = args.base.img;
9602 auto &imgtype = *args.base.imgtype;
9603 uint32_t lod = args.lod;
9604 uint32_t grad_x = args.grad_x;
9605 uint32_t grad_y = args.grad_y;
9606 uint32_t bias = args.bias;
9607
9608 const MSLConstexprSampler *constexpr_sampler = nullptr;
9609 bool is_dynamic_img_sampler = false;
9610 if (auto *var = maybe_get_backing_variable(img))
9611 {
9612 constexpr_sampler = find_constexpr_sampler(var->basevariable ? var->basevariable : VariableID(var->self));
9613 is_dynamic_img_sampler = has_extended_decoration(var->self, SPIRVCrossDecorationDynamicImageSampler);
9614 }
9615
9616 string farg_str;
9617 bool forward = true;
9618
9619 if (!is_dynamic_img_sampler)
9620 {
9621 // Texture reference (for some cases)
9622 if (needs_chroma_reconstruction(constexpr_sampler))
9623 {
9624 // Multiplanar images need two or three textures.
9625 farg_str += to_expression(img);
9626 for (uint32_t i = 1; i < constexpr_sampler->planes; i++)
9627 farg_str += join(", ", to_expression(img), plane_name_suffix, i);
9628 }
9629 else if ((!constexpr_sampler || !constexpr_sampler->ycbcr_conversion_enable) &&
9630 msl_options.swizzle_texture_samples && args.base.is_gather)
9631 {
9632 auto *combined = maybe_get<SPIRCombinedImageSampler>(img);
9633 farg_str += to_expression(combined ? combined->image : img);
9634 }
9635
9636 // Sampler reference
9637 if (!args.base.is_fetch)
9638 {
9639 if (!farg_str.empty())
9640 farg_str += ", ";
9641 farg_str += to_sampler_expression(img);
9642 }
9643
9644 if ((!constexpr_sampler || !constexpr_sampler->ycbcr_conversion_enable) &&
9645 msl_options.swizzle_texture_samples && args.base.is_gather)
9646 {
9647 // Add the swizzle constant from the swizzle buffer.
9648 farg_str += ", " + to_swizzle_expression(img);
9649 used_swizzle_buffer = true;
9650 }
9651
9652 // Swizzled gather puts the component before the other args, to allow template
9653 // deduction to work.
9654 if (args.component && msl_options.swizzle_texture_samples)
9655 {
9656 forward = should_forward(args.component);
9657 farg_str += ", " + to_component_argument(args.component);
9658 }
9659 }
9660
9661 // Texture coordinates
9662 forward = forward && should_forward(args.coord);
9663 auto coord_expr = to_enclosed_expression(args.coord);
9664 auto &coord_type = expression_type(args.coord);
9665 bool coord_is_fp = type_is_floating_point(coord_type);
9666 bool is_cube_fetch = false;
9667
9668 string tex_coords = coord_expr;
9669 uint32_t alt_coord_component = 0;
9670
9671 switch (imgtype.image.dim)
	{
	case Dim1D:
9675 if (coord_type.vecsize > 1)
9676 tex_coords = enclose_expression(tex_coords) + ".x";
9677
9678 if (args.base.is_fetch)
9679 tex_coords = "uint(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
9680 else if (sampling_type_needs_f32_conversion(coord_type))
9681 tex_coords = convert_to_f32(tex_coords, 1);
9682
9683 if (msl_options.texture_1D_as_2D)
9684 {
9685 if (args.base.is_fetch)
9686 tex_coords = "uint2(" + tex_coords + ", 0)";
9687 else
9688 tex_coords = "float2(" + tex_coords + ", 0.5)";
9689 }
9690
9691 alt_coord_component = 1;
9692 break;
9693
9694 case DimBuffer:
9695 if (coord_type.vecsize > 1)
9696 tex_coords = enclose_expression(tex_coords) + ".x";
9697
9698 if (msl_options.texture_buffer_native)
9699 {
9700 tex_coords = "uint(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
9701 }
9702 else
9703 {
9704 // Metal texel buffer textures are 2D, so convert 1D coord to 2D.
9705 // Support for Metal 2.1's new texture_buffer type.
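			// For illustration, the spvTexelBufferCoord() helper roughly computes
			// uint2(coord % width, coord / width) to address the 2D texture.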
9706 if (args.base.is_fetch)
9707 {
9708 if (msl_options.texel_buffer_texture_width > 0)
9709 {
9710 tex_coords = "spvTexelBufferCoord(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
9711 }
9712 else
9713 {
9714 tex_coords = "spvTexelBufferCoord(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ", " +
9715 to_expression(img) + ")";
9716 }
9717 }
9718 }
9719
9720 alt_coord_component = 1;
9721 break;
9722
9723 case DimSubpassData:
9724 // If we're using Metal's native frame-buffer fetch API for subpass inputs,
9725 // this path will not be hit.
9726 tex_coords = "uint2(gl_FragCoord.xy)";
9727 alt_coord_component = 2;
9728 break;
9729
9730 case Dim2D:
9731 if (coord_type.vecsize > 2)
9732 tex_coords = enclose_expression(tex_coords) + ".xy";
9733
9734 if (args.base.is_fetch)
9735 tex_coords = "uint2(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
9736 else if (sampling_type_needs_f32_conversion(coord_type))
9737 tex_coords = convert_to_f32(tex_coords, 2);
9738
9739 alt_coord_component = 2;
9740 break;
9741
9742 case Dim3D:
9743 if (coord_type.vecsize > 3)
9744 tex_coords = enclose_expression(tex_coords) + ".xyz";
9745
9746 if (args.base.is_fetch)
9747 tex_coords = "uint3(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
9748 else if (sampling_type_needs_f32_conversion(coord_type))
9749 tex_coords = convert_to_f32(tex_coords, 3);
9750
9751 alt_coord_component = 3;
9752 break;
9753
9754 case DimCube:
9755 if (args.base.is_fetch)
9756 {
9757 is_cube_fetch = true;
9758 tex_coords += ".xy";
9759 tex_coords = "uint2(" + round_fp_tex_coords(tex_coords, coord_is_fp) + ")";
9760 }
9761 else
9762 {
9763 if (coord_type.vecsize > 3)
9764 tex_coords = enclose_expression(tex_coords) + ".xyz";
9765 }
9766
9767 if (sampling_type_needs_f32_conversion(coord_type))
9768 tex_coords = convert_to_f32(tex_coords, 3);
9769
9770 alt_coord_component = 3;
9771 break;
9772
9773 default:
9774 break;
9775 }
9776
9777 if (args.base.is_fetch && (args.offset || args.coffset))
9778 {
9779 uint32_t offset_expr = args.offset ? args.offset : args.coffset;
9780 // Fetch offsets must be applied directly to the coordinate.
9781 forward = forward && should_forward(offset_expr);
9782 auto &type = expression_type(offset_expr);
9783 if (imgtype.image.dim == Dim1D && msl_options.texture_1D_as_2D)
9784 {
9785 if (type.basetype != SPIRType::UInt)
9786 tex_coords += join(" + uint2(", bitcast_expression(SPIRType::UInt, offset_expr), ", 0)");
9787 else
9788 tex_coords += join(" + uint2(", to_enclosed_expression(offset_expr), ", 0)");
9789 }
9790 else
9791 {
9792 if (type.basetype != SPIRType::UInt)
9793 tex_coords += " + " + bitcast_expression(SPIRType::UInt, offset_expr);
9794 else
9795 tex_coords += " + " + to_enclosed_expression(offset_expr);
9796 }
9797 }
9798
9799 // If projection, use alt coord as divisor
9800 if (args.base.is_proj)
9801 {
9802 if (sampling_type_needs_f32_conversion(coord_type))
9803 tex_coords += " / " + convert_to_f32(to_extract_component_expression(args.coord, alt_coord_component), 1);
9804 else
9805 tex_coords += " / " + to_extract_component_expression(args.coord, alt_coord_component);
9806 }
9807
9808 if (!farg_str.empty())
9809 farg_str += ", ";
9810
9811 if (imgtype.image.dim == DimCube && imgtype.image.arrayed && msl_options.emulate_cube_array)
9812 {
9813 farg_str += "spvCubemapTo2DArrayFace(" + tex_coords + ").xy";
9814
9815 if (is_cube_fetch)
9816 farg_str += ", uint(" + to_extract_component_expression(args.coord, 2) + ")";
9817 else
9818 farg_str +=
9819 ", uint(spvCubemapTo2DArrayFace(" + tex_coords + ").z) + (uint(" +
9820 round_fp_tex_coords(to_extract_component_expression(args.coord, alt_coord_component), coord_is_fp) +
9821 ") * 6u)";
9822
9823 add_spv_func_and_recompile(SPVFuncImplCubemapTo2DArrayFace);
9824 }
9825 else
9826 {
9827 farg_str += tex_coords;
9828
9829 // If fetch from cube, add face explicitly
9830 if (is_cube_fetch)
9831 {
9832 // Special case for cube arrays, face and layer are packed in one dimension.
9833 if (imgtype.image.arrayed)
9834 farg_str += ", uint(" + to_extract_component_expression(args.coord, 2) + ") % 6u";
9835 else
9836 farg_str +=
9837 ", uint(" + round_fp_tex_coords(to_extract_component_expression(args.coord, 2), coord_is_fp) + ")";
9838 }
9839
9840 // If array, use alt coord
9841 if (imgtype.image.arrayed)
9842 {
9843 // Special case for cube arrays, face and layer are packed in one dimension.
9844 if (imgtype.image.dim == DimCube && args.base.is_fetch)
9845 {
9846 farg_str += ", uint(" + to_extract_component_expression(args.coord, 2) + ") / 6u";
9847 }
9848 else
9849 {
9850 farg_str +=
9851 ", uint(" +
9852 round_fp_tex_coords(to_extract_component_expression(args.coord, alt_coord_component), coord_is_fp) +
9853 ")";
9854 if (imgtype.image.dim == DimSubpassData)
9855 {
9856 if (msl_options.multiview)
9857 farg_str += " + gl_ViewIndex";
9858 else if (msl_options.arrayed_subpass_input)
9859 farg_str += " + gl_Layer";
9860 }
9861 }
9862 }
9863 else if (imgtype.image.dim == DimSubpassData)
9864 {
9865 if (msl_options.multiview)
9866 farg_str += ", gl_ViewIndex";
9867 else if (msl_options.arrayed_subpass_input)
9868 farg_str += ", gl_Layer";
9869 }
9870 }
9871
9872 // Depth compare reference value
9873 if (args.dref)
9874 {
9875 forward = forward && should_forward(args.dref);
9876 farg_str += ", ";
9877
9878 auto &dref_type = expression_type(args.dref);
9879
9880 string dref_expr;
9881 if (args.base.is_proj)
9882 dref_expr = join(to_enclosed_expression(args.dref), " / ",
9883 to_extract_component_expression(args.coord, alt_coord_component));
9884 else
9885 dref_expr = to_expression(args.dref);
9886
9887 if (sampling_type_needs_f32_conversion(dref_type))
9888 dref_expr = convert_to_f32(dref_expr, 1);
9889
9890 farg_str += dref_expr;
9891
9892 if (msl_options.is_macos() && (grad_x || grad_y))
9893 {
			// For sample_compare, MSL does not support gradient2d on all targets (only iOS, according to the docs).
9895 // However, the most common case here is to have a constant gradient of 0, as that is the only way to express
9896 // LOD == 0 in GLSL with sampler2DArrayShadow (cascaded shadow mapping).
9897 // We will detect a compile-time constant 0 value for gradient and promote that to level(0) on MSL.
9898 bool constant_zero_x = !grad_x || expression_is_constant_null(grad_x);
9899 bool constant_zero_y = !grad_y || expression_is_constant_null(grad_y);
9900 if (constant_zero_x && constant_zero_y)
9901 {
9902 lod = 0;
9903 grad_x = 0;
9904 grad_y = 0;
9905 farg_str += ", level(0)";
9906 }
9907 else if (!msl_options.supports_msl_version(2, 3))
9908 {
9909 SPIRV_CROSS_THROW("Using non-constant 0.0 gradient() qualifier for sample_compare. This is not "
9910 "supported on macOS prior to MSL 2.3.");
9911 }
9912 }
9913
9914 if (msl_options.is_macos() && bias)
9915 {
9916 // Bias is not supported either on macOS with sample_compare.
9917 // Verify it is compile-time zero, and drop the argument.
9918 if (expression_is_constant_null(bias))
9919 {
9920 bias = 0;
9921 }
9922 else if (!msl_options.supports_msl_version(2, 3))
9923 {
9924 SPIRV_CROSS_THROW("Using non-constant 0.0 bias() qualifier for sample_compare. This is not supported "
9925 "on macOS prior to MSL 2.3.");
9926 }
9927 }
9928 }
9929
9930 // LOD Options
9931 // Metal does not support LOD for 1D textures.
9932 if (bias && (imgtype.image.dim != Dim1D || msl_options.texture_1D_as_2D))
9933 {
9934 forward = forward && should_forward(bias);
9935 farg_str += ", bias(" + to_expression(bias) + ")";
9936 }
9937
9938 // Metal does not support LOD for 1D textures.
9939 if (lod && (imgtype.image.dim != Dim1D || msl_options.texture_1D_as_2D))
9940 {
9941 forward = forward && should_forward(lod);
9942 if (args.base.is_fetch)
9943 {
9944 farg_str += ", " + to_expression(lod);
9945 }
9946 else
9947 {
9948 farg_str += ", level(" + to_expression(lod) + ")";
9949 }
9950 }
9951 else if (args.base.is_fetch && !lod && (imgtype.image.dim != Dim1D || msl_options.texture_1D_as_2D) &&
9952 imgtype.image.dim != DimBuffer && !imgtype.image.ms && imgtype.image.sampled != 2)
9953 {
		// The LOD argument is optional in OpImageFetch, but Metal requires one, so pick 0 as the default.
9955 // Check for sampled type as well, because is_fetch is also used for OpImageRead in MSL.
9956 farg_str += ", 0";
9957 }
9958
9959 // Metal does not support LOD for 1D textures.
9960 if ((grad_x || grad_y) && (imgtype.image.dim != Dim1D || msl_options.texture_1D_as_2D))
9961 {
9962 forward = forward && should_forward(grad_x);
9963 forward = forward && should_forward(grad_y);
9964 string grad_opt;
9965 switch (imgtype.image.dim)
9966 {
9967 case Dim1D:
9968 case Dim2D:
9969 grad_opt = "2d";
9970 break;
9971 case Dim3D:
9972 grad_opt = "3d";
9973 break;
9974 case DimCube:
9975 if (imgtype.image.arrayed && msl_options.emulate_cube_array)
9976 grad_opt = "2d";
9977 else
9978 grad_opt = "cube";
9979 break;
9980 default:
9981 grad_opt = "unsupported_gradient_dimension";
9982 break;
9983 }
9984 farg_str += ", gradient" + grad_opt + "(" + to_expression(grad_x) + ", " + to_expression(grad_y) + ")";
9985 }
9986
9987 if (args.min_lod)
9988 {
9989 if (!msl_options.supports_msl_version(2, 2))
			SPIRV_CROSS_THROW("min_lod_clamp() is only supported in MSL 2.2 and up.");
9991
9992 forward = forward && should_forward(args.min_lod);
9993 farg_str += ", min_lod_clamp(" + to_expression(args.min_lod) + ")";
9994 }
9995
9996 // Add offsets
9997 string offset_expr;
9998 const SPIRType *offset_type = nullptr;
9999 if (args.coffset && !args.base.is_fetch)
10000 {
10001 forward = forward && should_forward(args.coffset);
10002 offset_expr = to_expression(args.coffset);
10003 offset_type = &expression_type(args.coffset);
10004 }
10005 else if (args.offset && !args.base.is_fetch)
10006 {
10007 forward = forward && should_forward(args.offset);
10008 offset_expr = to_expression(args.offset);
10009 offset_type = &expression_type(args.offset);
10010 }
10011
10012 if (!offset_expr.empty())
10013 {
10014 switch (imgtype.image.dim)
10015 {
10016 case Dim1D:
10017 if (!msl_options.texture_1D_as_2D)
10018 break;
10019 if (offset_type->vecsize > 1)
10020 offset_expr = enclose_expression(offset_expr) + ".x";
10021
10022 farg_str += join(", int2(", offset_expr, ", 0)");
10023 break;
10024
10025 case Dim2D:
10026 if (offset_type->vecsize > 2)
10027 offset_expr = enclose_expression(offset_expr) + ".xy";
10028
10029 farg_str += ", " + offset_expr;
10030 break;
10031
10032 case Dim3D:
10033 if (offset_type->vecsize > 3)
10034 offset_expr = enclose_expression(offset_expr) + ".xyz";
10035
10036 farg_str += ", " + offset_expr;
10037 break;
10038
10039 default:
10040 break;
10041 }
10042 }
10043
10044 if (args.component)
10045 {
		// If a 2D gather has a component, ensure it also has an offset arg
10047 if (imgtype.image.dim == Dim2D && offset_expr.empty())
10048 farg_str += ", int2(0)";
10049
10050 if (!msl_options.swizzle_texture_samples || is_dynamic_img_sampler)
10051 {
10052 forward = forward && should_forward(args.component);
10053
10054 uint32_t image_var = 0;
10055 if (const auto *combined = maybe_get<SPIRCombinedImageSampler>(img))
10056 {
10057 if (const auto *img_var = maybe_get_backing_variable(combined->image))
10058 image_var = img_var->self;
10059 }
10060 else if (const auto *var = maybe_get_backing_variable(img))
10061 {
10062 image_var = var->self;
10063 }
10064
10065 if (image_var == 0 || !is_depth_image(expression_type(image_var), image_var))
10066 farg_str += ", " + to_component_argument(args.component);
10067 }
10068 }
10069
10070 if (args.sample)
10071 {
10072 forward = forward && should_forward(args.sample);
10073 farg_str += ", ";
10074 farg_str += to_expression(args.sample);
10075 }
10076
10077 *p_forward = forward;
10078
10079 return farg_str;
10080}
10081
10082// If the texture coordinates are floating point, invokes MSL round() function to round them.
10083string CompilerMSL::round_fp_tex_coords(string tex_coords, bool coord_is_fp)
10084{
10085 return coord_is_fp ? ("round(" + tex_coords + ")") : tex_coords;
10086}
10087
10088// Returns a string to use in an image sampling function argument.
10089// The ID must be a scalar constant.
10090string CompilerMSL::to_component_argument(uint32_t id)
10091{
10092 uint32_t component_index = evaluate_constant_u32(id);
10093 switch (component_index)
10094 {
10095 case 0:
10096 return "component::x";
10097 case 1:
10098 return "component::y";
10099 case 2:
10100 return "component::z";
10101 case 3:
10102 return "component::w";
10103
10104 default:
		SPIRV_CROSS_THROW("The value (" + to_string(component_index) + ") of OpConstant ID " + to_string(id) +
		                  " is not a valid component index; it must be 0, 1, 2, or 3.");
10107 }
10108}
10109
10110// Establish sampled image as expression object and assign the sampler to it.
10111void CompilerMSL::emit_sampled_image_op(uint32_t result_type, uint32_t result_id, uint32_t image_id, uint32_t samp_id)
10112{
10113 set<SPIRCombinedImageSampler>(result_id, result_type, image_id, samp_id);
10114}
10115
10116string CompilerMSL::to_texture_op(const Instruction &i, bool sparse, bool *forward,
10117 SmallVector<uint32_t> &inherited_expressions)
10118{
10119 auto *ops = stream(i);
10120 uint32_t result_type_id = ops[0];
10121 uint32_t img = ops[2];
10122 auto &result_type = get<SPIRType>(result_type_id);
10123 auto op = static_cast<Op>(i.op);
10124 bool is_gather = (op == OpImageGather || op == OpImageDrefGather);
10125
10126 // Bypass pointers because we need the real image struct
10127 auto &type = expression_type(img);
10128 auto &imgtype = get<SPIRType>(type.self);
10129
10130 const MSLConstexprSampler *constexpr_sampler = nullptr;
10131 bool is_dynamic_img_sampler = false;
10132 if (auto *var = maybe_get_backing_variable(img))
10133 {
10134 constexpr_sampler = find_constexpr_sampler(var->basevariable ? var->basevariable : VariableID(var->self));
10135 is_dynamic_img_sampler = has_extended_decoration(var->self, SPIRVCrossDecorationDynamicImageSampler);
10136 }
10137
10138 string expr;
10139 if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable && !is_dynamic_img_sampler)
10140 {
10141 // If this needs sampler Y'CbCr conversion, we need to do some additional
10142 // processing.
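		// The conversion wrappers nest: the model conversion (if any) wraps the
		// range expansion, which wraps the sampled value, e.g.
		// spvConvertYCbCrBT709(spvExpandITUNarrowRange(<sample>, bpc)).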
10143 switch (constexpr_sampler->ycbcr_model)
10144 {
10145 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_RGB_IDENTITY:
10146 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_IDENTITY:
10147 // Default
10148 break;
10149 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_709:
10150 add_spv_func_and_recompile(SPVFuncImplConvertYCbCrBT709);
10151 expr += "spvConvertYCbCrBT709(";
10152 break;
10153 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_601:
10154 add_spv_func_and_recompile(SPVFuncImplConvertYCbCrBT601);
10155 expr += "spvConvertYCbCrBT601(";
10156 break;
10157 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_2020:
10158 add_spv_func_and_recompile(SPVFuncImplConvertYCbCrBT2020);
10159 expr += "spvConvertYCbCrBT2020(";
10160 break;
10161 default:
10162 SPIRV_CROSS_THROW("Invalid Y'CbCr model conversion.");
10163 }
10164
10165 if (constexpr_sampler->ycbcr_model != MSL_SAMPLER_YCBCR_MODEL_CONVERSION_RGB_IDENTITY)
10166 {
10167 switch (constexpr_sampler->ycbcr_range)
10168 {
10169 case MSL_SAMPLER_YCBCR_RANGE_ITU_FULL:
10170 add_spv_func_and_recompile(SPVFuncImplExpandITUFullRange);
10171 expr += "spvExpandITUFullRange(";
10172 break;
10173 case MSL_SAMPLER_YCBCR_RANGE_ITU_NARROW:
10174 add_spv_func_and_recompile(SPVFuncImplExpandITUNarrowRange);
10175 expr += "spvExpandITUNarrowRange(";
10176 break;
10177 default:
10178 SPIRV_CROSS_THROW("Invalid Y'CbCr range.");
10179 }
10180 }
10181 }
10182 else if (msl_options.swizzle_texture_samples && !is_gather && is_sampled_image_type(imgtype) &&
10183 !is_dynamic_img_sampler)
10184 {
10185 add_spv_func_and_recompile(SPVFuncImplTextureSwizzle);
10186 expr += "spvTextureSwizzle(";
10187 }
10188
10189 string inner_expr = CompilerGLSL::to_texture_op(i, sparse, forward, inherited_expressions);
10190
10191 if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable && !is_dynamic_img_sampler)
10192 {
10193 if (!constexpr_sampler->swizzle_is_identity())
10194 {
10195 static const char swizzle_names[] = "rgba";
10196 if (!constexpr_sampler->swizzle_has_one_or_zero())
10197 {
10198 // If we can, do it inline.
10199 expr += inner_expr + ".";
10200 for (uint32_t c = 0; c < 4; c++)
10201 {
10202 switch (constexpr_sampler->swizzle[c])
10203 {
10204 case MSL_COMPONENT_SWIZZLE_IDENTITY:
10205 expr += swizzle_names[c];
10206 break;
10207 case MSL_COMPONENT_SWIZZLE_R:
10208 case MSL_COMPONENT_SWIZZLE_G:
10209 case MSL_COMPONENT_SWIZZLE_B:
10210 case MSL_COMPONENT_SWIZZLE_A:
10211 expr += swizzle_names[constexpr_sampler->swizzle[c] - MSL_COMPONENT_SWIZZLE_R];
10212 break;
10213 default:
10214 SPIRV_CROSS_THROW("Invalid component swizzle.");
10215 }
10216 }
10217 }
10218 else
10219 {
10220 // Otherwise, we need to emit a temporary and swizzle that.
10221 uint32_t temp_id = ir.increase_bound_by(1);
10222 emit_op(result_type_id, temp_id, inner_expr, false);
10223 for (auto &inherit : inherited_expressions)
10224 inherit_expression_dependencies(temp_id, inherit);
10225 inherited_expressions.clear();
10226 inherited_expressions.push_back(temp_id);
10227
10228 switch (op)
10229 {
10230 case OpImageSampleDrefImplicitLod:
10231 case OpImageSampleImplicitLod:
10232 case OpImageSampleProjImplicitLod:
10233 case OpImageSampleProjDrefImplicitLod:
10234 register_control_dependent_expression(temp_id);
10235 break;
10236
10237 default:
10238 break;
10239 }
10240 expr += type_to_glsl(result_type) + "(";
10241 for (uint32_t c = 0; c < 4; c++)
10242 {
10243 switch (constexpr_sampler->swizzle[c])
10244 {
10245 case MSL_COMPONENT_SWIZZLE_IDENTITY:
10246 expr += to_expression(temp_id) + "." + swizzle_names[c];
10247 break;
10248 case MSL_COMPONENT_SWIZZLE_ZERO:
10249 expr += "0";
10250 break;
10251 case MSL_COMPONENT_SWIZZLE_ONE:
10252 expr += "1";
10253 break;
10254 case MSL_COMPONENT_SWIZZLE_R:
10255 case MSL_COMPONENT_SWIZZLE_G:
10256 case MSL_COMPONENT_SWIZZLE_B:
10257 case MSL_COMPONENT_SWIZZLE_A:
10258 expr += to_expression(temp_id) + "." +
10259 swizzle_names[constexpr_sampler->swizzle[c] - MSL_COMPONENT_SWIZZLE_R];
10260 break;
10261 default:
10262 SPIRV_CROSS_THROW("Invalid component swizzle.");
10263 }
10264 if (c < 3)
10265 expr += ", ";
10266 }
10267 expr += ")";
10268 }
10269 }
10270 else
10271 expr += inner_expr;
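		// Close the range-expansion call, which takes the component bit depth
		// as its last argument, then the model-conversion wrapper if one was
		// opened; identity model conversions open no wrapper.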
10272 if (constexpr_sampler->ycbcr_model != MSL_SAMPLER_YCBCR_MODEL_CONVERSION_RGB_IDENTITY)
10273 {
10274 expr += join(", ", constexpr_sampler->bpc, ")");
10275 if (constexpr_sampler->ycbcr_model != MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_IDENTITY)
10276 expr += ")";
10277 }
10278 }
10279 else
10280 {
10281 expr += inner_expr;
10282 if (msl_options.swizzle_texture_samples && !is_gather && is_sampled_image_type(imgtype) &&
10283 !is_dynamic_img_sampler)
10284 {
10285 // Add the swizzle constant from the swizzle buffer.
10286 expr += ", " + to_swizzle_expression(img) + ")";
10287 used_swizzle_buffer = true;
10288 }
10289 }
10290
10291 return expr;
10292}
10293
10294static string create_swizzle(MSLComponentSwizzle swizzle)
10295{
10296 switch (swizzle)
10297 {
10298 case MSL_COMPONENT_SWIZZLE_IDENTITY:
10299 return "spvSwizzle::none";
10300 case MSL_COMPONENT_SWIZZLE_ZERO:
10301 return "spvSwizzle::zero";
10302 case MSL_COMPONENT_SWIZZLE_ONE:
10303 return "spvSwizzle::one";
10304 case MSL_COMPONENT_SWIZZLE_R:
10305 return "spvSwizzle::red";
10306 case MSL_COMPONENT_SWIZZLE_G:
10307 return "spvSwizzle::green";
10308 case MSL_COMPONENT_SWIZZLE_B:
10309 return "spvSwizzle::blue";
10310 case MSL_COMPONENT_SWIZZLE_A:
10311 return "spvSwizzle::alpha";
10312 default:
10313 SPIRV_CROSS_THROW("Invalid component swizzle.");
10314 }
10315}
10316
10317// Returns a string representation of the ID, usable as a function arg.
10318// Manufacture automatic sampler arg for SampledImage texture.
10319string CompilerMSL::to_func_call_arg(const SPIRFunction::Parameter &arg, uint32_t id)
10320{
10321 string arg_str;
10322
10323 auto &type = expression_type(id);
10324 bool is_dynamic_img_sampler = has_extended_decoration(arg.id, SPIRVCrossDecorationDynamicImageSampler);
10325 // If the argument *itself* is a "dynamic" combined-image sampler, then we can just pass that around.
10326 bool arg_is_dynamic_img_sampler = has_extended_decoration(id, SPIRVCrossDecorationDynamicImageSampler);
10327 if (is_dynamic_img_sampler && !arg_is_dynamic_img_sampler)
10328 arg_str = join("spvDynamicImageSampler<", type_to_glsl(get<SPIRType>(type.image.type)), ">(");
10329
10330 auto *c = maybe_get<SPIRConstant>(id);
10331 if (msl_options.force_native_arrays && c && !get<SPIRType>(c->constant_type).array.empty())
10332 {
10333 // If we are passing a constant array directly to a function for some reason,
10334 // the callee will expect an argument in thread const address space
10335 // (since we can only bind to arrays with references in MSL).
10336 // To resolve this, we must emit a copy in this address space.
10337 // This kind of code gen should be rare enough that performance is not a real concern.
10338 // Inline the SPIR-V to avoid this kind of suboptimal codegen.
10339 //
10340 // We risk calling this inside a continue block (invalid code),
10341 // so just create a thread local copy in the current function.
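		// E.g. for ID 42 the argument becomes "_42_array_copy", a copy declared
		// at the top of the current function.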
10342 arg_str = join("_", id, "_array_copy");
10343 auto &constants = current_function->constant_arrays_needed_on_stack;
10344 auto itr = find(begin(constants), end(constants), ID(id));
10345 if (itr == end(constants))
10346 {
10347 force_recompile();
10348 constants.push_back(id);
10349 }
10350 }
10351 else
10352 arg_str += CompilerGLSL::to_func_call_arg(arg, id);
10353
10354 // Need to check the base variable in case we need to apply a qualified alias.
10355 uint32_t var_id = 0;
10356 auto *var = maybe_get<SPIRVariable>(id);
10357 if (var)
10358 var_id = var->basevariable;
10359
10360 if (!arg_is_dynamic_img_sampler)
10361 {
10362 auto *constexpr_sampler = find_constexpr_sampler(var_id ? var_id : id);
10363 if (type.basetype == SPIRType::SampledImage)
10364 {
10365 // Manufacture automatic plane args for multiplanar texture
10366 uint32_t planes = 1;
10367 if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable)
10368 {
10369 planes = constexpr_sampler->planes;
10370 // If this parameter isn't aliasing a global, then we need to use
10371 // the special "dynamic image-sampler" class to pass it--and we need
10372 // to use it for *every* non-alias parameter, in case a combined
10373 // image-sampler with a Y'CbCr conversion is passed. Hopefully, this
10374 // pathological case is so rare that it should never be hit in practice.
10375 if (!arg.alias_global_variable)
10376 add_spv_func_and_recompile(SPVFuncImplDynamicImageSampler);
10377 }
10378 for (uint32_t i = 1; i < planes; i++)
10379 arg_str += join(", ", CompilerGLSL::to_func_call_arg(arg, id), plane_name_suffix, i);
10380 // Manufacture automatic sampler arg if the arg is a SampledImage texture.
10381 if (type.image.dim != DimBuffer)
10382 arg_str += ", " + to_sampler_expression(var_id ? var_id : id);
10383
10384 // Add sampler Y'CbCr conversion info if we have it
10385 if (is_dynamic_img_sampler && constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable)
10386 {
10387 SmallVector<string> samp_args;
10388
10389 switch (constexpr_sampler->resolution)
10390 {
10391 case MSL_FORMAT_RESOLUTION_444:
10392 // Default
10393 break;
10394 case MSL_FORMAT_RESOLUTION_422:
10395 samp_args.push_back("spvFormatResolution::_422");
10396 break;
10397 case MSL_FORMAT_RESOLUTION_420:
10398 samp_args.push_back("spvFormatResolution::_420");
10399 break;
10400 default:
10401 SPIRV_CROSS_THROW("Invalid format resolution.");
10402 }
10403
10404 if (constexpr_sampler->chroma_filter != MSL_SAMPLER_FILTER_NEAREST)
10405 samp_args.push_back("spvChromaFilter::linear");
10406
10407 if (constexpr_sampler->x_chroma_offset != MSL_CHROMA_LOCATION_COSITED_EVEN)
10408 samp_args.push_back("spvXChromaLocation::midpoint");
10409 if (constexpr_sampler->y_chroma_offset != MSL_CHROMA_LOCATION_COSITED_EVEN)
10410 samp_args.push_back("spvYChromaLocation::midpoint");
10411 switch (constexpr_sampler->ycbcr_model)
10412 {
10413 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_RGB_IDENTITY:
10414 // Default
10415 break;
10416 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_IDENTITY:
10417 samp_args.push_back("spvYCbCrModelConversion::ycbcr_identity");
10418 break;
10419 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_709:
10420 samp_args.push_back("spvYCbCrModelConversion::ycbcr_bt_709");
10421 break;
10422 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_601:
10423 samp_args.push_back("spvYCbCrModelConversion::ycbcr_bt_601");
10424 break;
10425 case MSL_SAMPLER_YCBCR_MODEL_CONVERSION_YCBCR_BT_2020:
10426 samp_args.push_back("spvYCbCrModelConversion::ycbcr_bt_2020");
10427 break;
10428 default:
10429 SPIRV_CROSS_THROW("Invalid Y'CbCr model conversion.");
10430 }
10431 if (constexpr_sampler->ycbcr_range != MSL_SAMPLER_YCBCR_RANGE_ITU_FULL)
10432 samp_args.push_back("spvYCbCrRange::itu_narrow");
10433 samp_args.push_back(join("spvComponentBits(", constexpr_sampler->bpc, ")"));
10434 arg_str += join(", spvYCbCrSampler(", merge(samp_args), ")");
10435 }
10436 }
10437
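		// Pack the four swizzle components into a single uint argument,
		// one byte per component, with component 0 (r) in the lowest byte.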
10438 if (is_dynamic_img_sampler && constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable)
10439 arg_str += join(", (uint(", create_swizzle(constexpr_sampler->swizzle[3]), ") << 24) | (uint(",
10440 create_swizzle(constexpr_sampler->swizzle[2]), ") << 16) | (uint(",
10441 create_swizzle(constexpr_sampler->swizzle[1]), ") << 8) | uint(",
10442 create_swizzle(constexpr_sampler->swizzle[0]), ")");
10443 else if (msl_options.swizzle_texture_samples && has_sampled_images && is_sampled_image_type(type))
10444 arg_str += ", " + to_swizzle_expression(var_id ? var_id : id);
10445
10446 if (buffers_requiring_array_length.count(var_id))
10447 arg_str += ", " + to_buffer_size_expression(var_id ? var_id : id);
10448
10449 if (is_dynamic_img_sampler)
10450 arg_str += ")";
10451 }
10452
10453 // Emulate texture2D atomic operations
10454 auto *backing_var = maybe_get_backing_variable(var_id);
10455 if (backing_var && atomic_image_vars.count(backing_var->self))
10456 {
10457 arg_str += ", " + to_expression(var_id) + "_atomic";
10458 }
10459
10460 return arg_str;
10461}
10462
10463// If the ID represents a sampled image that has been assigned a sampler already,
10464// generate an expression for the sampler, otherwise generate a fake sampler name
10465// by appending a suffix to the expression constructed from the ID.
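// E.g. "myTex" becomes "myTexSmplr" with the default sampler_name_suffix
// (names illustrative), and "myTexArr[i]" becomes "myTexArrSmplr[i]".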
10466string CompilerMSL::to_sampler_expression(uint32_t id)
10467{
10468 auto *combined = maybe_get<SPIRCombinedImageSampler>(id);
10469 auto expr = to_expression(combined ? combined->image : VariableID(id));
10470 auto index = expr.find_first_of('[');
10471
10472 uint32_t samp_id = 0;
10473 if (combined)
10474 samp_id = combined->sampler;
10475
10476 if (index == string::npos)
10477 return samp_id ? to_expression(samp_id) : expr + sampler_name_suffix;
10478 else
10479 {
10480 auto image_expr = expr.substr(0, index);
10481 auto array_expr = expr.substr(index);
10482 return samp_id ? to_expression(samp_id) : (image_expr + sampler_name_suffix + array_expr);
10483 }
10484}
10485
10486string CompilerMSL::to_swizzle_expression(uint32_t id)
10487{
10488 auto *combined = maybe_get<SPIRCombinedImageSampler>(id);
10489
10490 auto expr = to_expression(combined ? combined->image : VariableID(id));
10491 auto index = expr.find_first_of('[');
10492
10493 // If an image is part of an argument buffer translate this to a legal identifier.
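	// E.g. "spvDescriptorSet0.myTex[i]" becomes "spvDescriptorSet0_myTexSwzl[i]"
	// with the default swizzle_name_suffix (names illustrative).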
10494 string::size_type period = 0;
10495 while ((period = expr.find_first_of('.', period)) != string::npos && period < index)
10496 expr[period] = '_';
10497
10498 if (index == string::npos)
10499 return expr + swizzle_name_suffix;
10500 else
10501 {
10502 auto image_expr = expr.substr(0, index);
10503 auto array_expr = expr.substr(index);
10504 return image_expr + swizzle_name_suffix + array_expr;
10505 }
10506}
10507
10508string CompilerMSL::to_buffer_size_expression(uint32_t id)
10509{
10510 auto expr = to_expression(id);
10511 auto index = expr.find_first_of('[');
10512
10513 // This is quite crude, but we need to translate the reference name (*spvDescriptorSetN.name) to
10514 // the pointer expression spvDescriptorSetN.name to make a reasonable expression here.
10515 // This only happens if we have argument buffers and we are using OpArrayLength on a lone SSBO in that set.
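	// E.g. "(*spvDescriptorSet0.ssbo)" becomes "spvDescriptorSet0.ssbo", which
	// the '.' replacement below then turns into "spvDescriptorSet0_ssbo".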
10516 if (expr.size() >= 3 && expr[0] == '(' && expr[1] == '*')
10517 expr = address_of_expression(expr);
10518
10519 // If a buffer is part of an argument buffer translate this to a legal identifier.
10520 for (auto &c : expr)
10521 if (c == '.')
10522 c = '_';
10523
10524 if (index == string::npos)
10525 return expr + buffer_size_name_suffix;
10526 else
10527 {
10528 auto buffer_expr = expr.substr(0, index);
10529 auto array_expr = expr.substr(index);
10530 return buffer_expr + buffer_size_name_suffix + array_expr;
10531 }
10532}
10533
10534// Checks whether the type is a Block all of whose members have DecorationPatch.
10535bool CompilerMSL::is_patch_block(const SPIRType &type)
10536{
10537 if (!has_decoration(type.self, DecorationBlock))
10538 return false;
10539
10540 for (uint32_t i = 0; i < type.member_types.size(); i++)
10541 {
10542 if (!has_member_decoration(type.self, i, DecorationPatch))
10543 return false;
10544 }
10545
10546 return true;
10547}
10548
10549// Checks whether the ID is a row_major matrix that requires conversion before use
10550bool CompilerMSL::is_non_native_row_major_matrix(uint32_t id)
10551{
10552 auto *e = maybe_get<SPIRExpression>(id);
10553 if (e)
10554 return e->need_transpose;
10555 else
10556 return has_decoration(id, DecorationRowMajor);
10557}
10558
10559// Checks whether the member is a row_major matrix that requires conversion before use
10560bool CompilerMSL::member_is_non_native_row_major_matrix(const SPIRType &type, uint32_t index)
10561{
10562 return has_member_decoration(type.self, index, DecorationRowMajor);
10563}
10564
10565string CompilerMSL::convert_row_major_matrix(string exp_str, const SPIRType &exp_type, uint32_t physical_type_id,
10566 bool is_packed)
10567{
10568 if (!is_matrix(exp_type))
10569 {
10570 return CompilerGLSL::convert_row_major_matrix(move(exp_str), exp_type, physical_type_id, is_packed);
10571 }
10572 else
10573 {
10574 strip_enclosed_expression(exp_str);
10575 if (physical_type_id != 0 || is_packed)
10576 exp_str = unpack_expression_type(exp_str, exp_type, physical_type_id, is_packed, true);
10577 return join("transpose(", exp_str, ")");
10578 }
10579}
10580
10581// Called automatically at the end of the entry point function
10582void CompilerMSL::emit_fixup()
10583{
10584 if (is_vertex_like_shader() && stage_out_var_id && !qual_pos_var_name.empty() && !capture_output_to_buffer)
10585 {
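		// GL-style clip space has z in [-w, w], while Metal expects [0, w]:
		// z' = (z + w) * 0.5 remaps NDC z from [-1, 1] to [0, 1].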
10586 if (options.vertex.fixup_clipspace)
10587 statement(qual_pos_var_name, ".z = (", qual_pos_var_name, ".z + ", qual_pos_var_name,
10588 ".w) * 0.5; // Adjust clip-space for Metal");
10589
10590 if (options.vertex.flip_vert_y)
10591 statement(qual_pos_var_name, ".y = -(", qual_pos_var_name, ".y);", " // Invert Y-axis for Metal");
10592 }
10593}
10594
10595// Return a string defining a structure member, with padding and packing.
10596string CompilerMSL::to_struct_member(const SPIRType &type, uint32_t member_type_id, uint32_t index,
10597 const string &qualifier)
10598{
10599 if (member_is_remapped_physical_type(type, index))
10600 member_type_id = get_extended_member_decoration(type.self, index, SPIRVCrossDecorationPhysicalTypeID);
10601 auto &physical_type = get<SPIRType>(member_type_id);
10602
10603 // If this member is packed, mark it as so.
10604 string pack_pfx;
10605
10606 // Allow Metal to use the array<T> template to make arrays a value type
10607 uint32_t orig_id = 0;
10608 if (has_extended_member_decoration(type.self, index, SPIRVCrossDecorationInterfaceOrigID))
10609 orig_id = get_extended_member_decoration(type.self, index, SPIRVCrossDecorationInterfaceOrigID);
10610
10611 bool row_major = false;
10612 if (is_matrix(physical_type))
10613 row_major = has_member_decoration(type.self, index, DecorationRowMajor);
10614
10615 SPIRType row_major_physical_type;
10616 const SPIRType *declared_type = &physical_type;
10617
10618 // If a struct is being declared with physical layout,
10619 // do not use array<T> wrappers.
10620 // This avoids a lot of complicated cases with packed vectors and matrices,
10621 // and generally we cannot copy full arrays in and out of buffers into Function
10622 // address space.
10623 // Array of resources should also be declared as builtin arrays.
10624 if (has_member_decoration(type.self, index, DecorationOffset))
10625 is_using_builtin_array = true;
10626 else if (has_extended_member_decoration(type.self, index, SPIRVCrossDecorationResourceIndexPrimary))
10627 is_using_builtin_array = true;
10628
10629 if (member_is_packed_physical_type(type, index))
10630 {
10631 // If we're packing a matrix, output an appropriate typedef
10632 if (physical_type.basetype == SPIRType::Struct)
10633 {
10634 SPIRV_CROSS_THROW("Cannot emit a packed struct currently.");
10635 }
10636 else if (is_matrix(physical_type))
10637 {
10638 uint32_t rows = physical_type.vecsize;
10639 uint32_t cols = physical_type.columns;
10640 pack_pfx = "packed_";
10641 if (row_major)
10642 {
10643 // These are stored transposed.
10644 rows = physical_type.columns;
10645 cols = physical_type.vecsize;
10646 pack_pfx = "packed_rm_";
10647 }
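			// Emits e.g. "typedef packed_float4 packed_float4x4[4];", or the
			// packed_rm_ variant with rows and columns swapped for row-major.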
10648 string base_type = physical_type.width == 16 ? "half" : "float";
10649 string td_line = "typedef ";
10650 td_line += "packed_" + base_type + to_string(rows);
10651 td_line += " " + pack_pfx;
10652 // Use the actual matrix size here.
10653 td_line += base_type + to_string(physical_type.columns) + "x" + to_string(physical_type.vecsize);
10654 td_line += "[" + to_string(cols) + "]";
10655 td_line += ";";
10656 add_typedef_line(td_line);
10657 }
10658 else if (!is_scalar(physical_type)) // scalar type is already packed.
10659 pack_pfx = "packed_";
10660 }
10661 else if (row_major)
10662 {
10663 // Need to declare type with flipped vecsize/columns.
10664 row_major_physical_type = physical_type;
10665 swap(row_major_physical_type.vecsize, row_major_physical_type.columns);
10666 declared_type = &row_major_physical_type;
10667 }
10668
	// Very specifically, image load-store in argument buffers is disallowed in MSL on iOS.
10670 if (msl_options.is_ios() && physical_type.basetype == SPIRType::Image && physical_type.image.sampled == 2)
10671 {
10672 if (!has_decoration(orig_id, DecorationNonWritable))
10673 SPIRV_CROSS_THROW("Writable images are not allowed in argument buffers on iOS.");
10674 }
10675
10676 // Array information is baked into these types.
10677 string array_type;
10678 if (physical_type.basetype != SPIRType::Image && physical_type.basetype != SPIRType::Sampler &&
10679 physical_type.basetype != SPIRType::SampledImage)
10680 {
10681 BuiltIn builtin = BuiltInMax;
10682
10683 // Special handling. In [[stage_out]] or [[stage_in]] blocks,
10684 // we need flat arrays, but if we're somehow declaring gl_PerVertex for constant array reasons, we want
10685 // template array types to be declared.
10686 bool is_ib_in_out =
10687 ((stage_out_var_id && get_stage_out_struct_type().self == type.self &&
10688 variable_storage_requires_stage_io(StorageClassOutput)) ||
10689 (stage_in_var_id && get_stage_in_struct_type().self == type.self &&
10690 variable_storage_requires_stage_io(StorageClassInput)));
10691 if (is_ib_in_out && is_member_builtin(type, index, &builtin))
10692 is_using_builtin_array = true;
10693 array_type = type_to_array_glsl(physical_type);
10694 }
10695
10696 auto result = join(pack_pfx, type_to_glsl(*declared_type, orig_id), " ", qualifier, to_member_name(type, index),
10697 member_attribute_qualifier(type, index), array_type, ";");
10698
10699 is_using_builtin_array = false;
10700 return result;
10701}
10702
// Emit a structure member, padding and packing to maintain the correct member alignments.
10704void CompilerMSL::emit_struct_member(const SPIRType &type, uint32_t member_type_id, uint32_t index,
10705 const string &qualifier, uint32_t)
10706{
10707 // If this member requires padding to maintain its declared offset, emit a dummy padding member before it.
10708 if (has_extended_member_decoration(type.self, index, SPIRVCrossDecorationPaddingTarget))
10709 {
10710 uint32_t pad_len = get_extended_member_decoration(type.self, index, SPIRVCrossDecorationPaddingTarget);
10711 statement("char _m", index, "_pad", "[", pad_len, "];");
10712 }
10713
10714 // Handle HLSL-style 0-based vertex/instance index.
10715 builtin_declaration = true;
10716 statement(to_struct_member(type, member_type_id, index, qualifier));
10717 builtin_declaration = false;
10718}
10719
10720void CompilerMSL::emit_struct_padding_target(const SPIRType &type)
10721{
10722 uint32_t struct_size = get_declared_struct_size_msl(type, true, true);
10723 uint32_t target_size = get_extended_decoration(type.self, SPIRVCrossDecorationPaddingTarget);
10724 if (target_size < struct_size)
10725 SPIRV_CROSS_THROW("Cannot pad with negative bytes.");
10726 else if (target_size > struct_size)
10727 statement("char _m0_final_padding[", target_size - struct_size, "];");
10728}
10729
10730// Return a MSL qualifier for the specified function attribute member
10731string CompilerMSL::member_attribute_qualifier(const SPIRType &type, uint32_t index)
10732{
10733 auto &execution = get_entry_point();
10734
10735 uint32_t mbr_type_id = type.member_types[index];
10736 auto &mbr_type = get<SPIRType>(mbr_type_id);
10737
10738 BuiltIn builtin = BuiltInMax;
10739 bool is_builtin = is_member_builtin(type, index, &builtin);
10740
10741 if (has_extended_member_decoration(type.self, index, SPIRVCrossDecorationResourceIndexPrimary))
10742 {
10743 string quals = join(
10744 " [[id(", get_extended_member_decoration(type.self, index, SPIRVCrossDecorationResourceIndexPrimary), ")");
10745 if (interlocked_resources.count(
10746 get_extended_member_decoration(type.self, index, SPIRVCrossDecorationInterfaceOrigID)))
10747 quals += ", raster_order_group(0)";
10748 quals += "]]";
10749 return quals;
10750 }
10751
10752 // Vertex function inputs
10753 if (execution.model == ExecutionModelVertex && type.storage == StorageClassInput)
10754 {
10755 if (is_builtin)
10756 {
10757 switch (builtin)
10758 {
10759 case BuiltInVertexId:
10760 case BuiltInVertexIndex:
10761 case BuiltInBaseVertex:
10762 case BuiltInInstanceId:
10763 case BuiltInInstanceIndex:
10764 case BuiltInBaseInstance:
10765 if (msl_options.vertex_for_tessellation)
10766 return "";
10767 return string(" [[") + builtin_qualifier(builtin) + "]]";
10768
10769 case BuiltInDrawIndex:
10770 SPIRV_CROSS_THROW("DrawIndex is not supported in MSL.");
10771
10772 default:
10773 return "";
10774 }
10775 }
10776
10777 uint32_t locn;
10778 if (is_builtin)
10779 locn = get_or_allocate_builtin_input_member_location(builtin, type.self, index);
10780 else
10781 locn = get_member_location(type.self, index);
10782
10783 if (locn != k_unknown_location)
10784 return string(" [[attribute(") + convert_to_string(locn) + ")]]";
10785 }
10786
10787 // Vertex and tessellation evaluation function outputs
10788 if (((execution.model == ExecutionModelVertex && !msl_options.vertex_for_tessellation) ||
10789 execution.model == ExecutionModelTessellationEvaluation) &&
10790 type.storage == StorageClassOutput)
10791 {
10792 if (is_builtin)
10793 {
10794 switch (builtin)
10795 {
10796 case BuiltInPointSize:
10797 // Only mark the PointSize builtin if really rendering points.
10798 // Some shaders may include a PointSize builtin even when used to render
10799 // non-point topologies, and Metal will reject this builtin when compiling
10800 // the shader into a render pipeline that uses a non-point topology.
10801 return msl_options.enable_point_size_builtin ? (string(" [[") + builtin_qualifier(builtin) + "]]") : "";
10802
10803 case BuiltInViewportIndex:
10804 if (!msl_options.supports_msl_version(2, 0))
10805 SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0.");
10806 /* fallthrough */
10807 case BuiltInPosition:
10808 case BuiltInLayer:
10809 return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");
10810
10811 case BuiltInClipDistance:
10812 if (has_member_decoration(type.self, index, DecorationIndex))
10813 return join(" [[user(clip", get_member_decoration(type.self, index, DecorationIndex), ")]]");
10814 else
10815 return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");
10816
10817 case BuiltInCullDistance:
10818 if (has_member_decoration(type.self, index, DecorationIndex))
10819 return join(" [[user(cull", get_member_decoration(type.self, index, DecorationIndex), ")]]");
10820 else
10821 return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");
10822
10823 default:
10824 return "";
10825 }
10826 }
10827 string loc_qual = member_location_attribute_qualifier(type, index);
10828 if (!loc_qual.empty())
10829 return join(" [[", loc_qual, "]]");
10830 }
10831
10832 // Tessellation control function inputs
10833 if (execution.model == ExecutionModelTessellationControl && type.storage == StorageClassInput)
10834 {
10835 if (is_builtin)
10836 {
10837 switch (builtin)
10838 {
10839 case BuiltInInvocationId:
10840 case BuiltInPrimitiveId:
10841 if (msl_options.multi_patch_workgroup)
10842 return "";
10843 return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");
10844 case BuiltInSubgroupLocalInvocationId: // FIXME: Should work in any stage
10845 case BuiltInSubgroupSize: // FIXME: Should work in any stage
10846 if (msl_options.emulate_subgroups)
10847 return "";
10848 return string(" [[") + builtin_qualifier(builtin) + "]]" + (mbr_type.array.empty() ? "" : " ");
10849 case BuiltInPatchVertices:
10850 return "";
10851 // Others come from stage input.
10852 default:
10853 break;
10854 }
10855 }
10856 if (msl_options.multi_patch_workgroup)
10857 return "";
10858
10859 uint32_t locn;
10860 if (is_builtin)
10861 locn = get_or_allocate_builtin_input_member_location(builtin, type.self, index);
10862 else
10863 locn = get_member_location(type.self, index);
10864
10865 if (locn != k_unknown_location)
10866 return string(" [[attribute(") + convert_to_string(locn) + ")]]";
10867 }
10868
10869 // Tessellation control function outputs
10870 if (execution.model == ExecutionModelTessellationControl && type.storage == StorageClassOutput)
10871 {
10872 // For this type of shader, we always arrange for it to capture its
10873 // output to a buffer. For this reason, qualifiers are irrelevant here.
10874 return "";
10875 }
10876
10877 // Tessellation evaluation function inputs
10878 if (execution.model == ExecutionModelTessellationEvaluation && type.storage == StorageClassInput)
10879 {
10880 if (is_builtin)
10881 {
10882 switch (builtin)
10883 {
10884 case BuiltInPrimitiveId:
10885 case BuiltInTessCoord:
10886 return string(" [[") + builtin_qualifier(builtin) + "]]";
10887 case BuiltInPatchVertices:
10888 return "";
10889 // Others come from stage input.
10890 default:
10891 break;
10892 }
10893 }
10894 // The special control point array must not be marked with an attribute.
10895 if (get_type(type.member_types[index]).basetype == SPIRType::ControlPointArray)
10896 return "";
10897
10898 uint32_t locn;
10899 if (is_builtin)
10900 locn = get_or_allocate_builtin_input_member_location(builtin, type.self, index);
10901 else
10902 locn = get_member_location(type.self, index);
10903
10904 if (locn != k_unknown_location)
10905 return string(" [[attribute(") + convert_to_string(locn) + ")]]";
10906 }
10907
10908 // Tessellation evaluation function outputs were handled above.
10909
10910 // Fragment function inputs
10911 if (execution.model == ExecutionModelFragment && type.storage == StorageClassInput)
10912 {
10913 string quals;
10914 if (is_builtin)
10915 {
10916 switch (builtin)
10917 {
10918 case BuiltInViewIndex:
10919 if (!msl_options.multiview || !msl_options.multiview_layered_rendering)
10920 break;
10921 /* fallthrough */
10922 case BuiltInFrontFacing:
10923 case BuiltInPointCoord:
10924 case BuiltInFragCoord:
10925 case BuiltInSampleId:
10926 case BuiltInSampleMask:
10927 case BuiltInLayer:
10928 case BuiltInBaryCoordNV:
10929 case BuiltInBaryCoordNoPerspNV:
10930 quals = builtin_qualifier(builtin);
10931 break;
10932
10933 case BuiltInClipDistance:
10934 return join(" [[user(clip", get_member_decoration(type.self, index, DecorationIndex), ")]]");
10935 case BuiltInCullDistance:
10936 return join(" [[user(cull", get_member_decoration(type.self, index, DecorationIndex), ")]]");
10937
10938 default:
10939 break;
10940 }
10941 }
10942 else
10943 quals = member_location_attribute_qualifier(type, index);
10944
10945 if (builtin == BuiltInBaryCoordNV || builtin == BuiltInBaryCoordNoPerspNV)
10946 {
10947 if (has_member_decoration(type.self, index, DecorationFlat) ||
10948 has_member_decoration(type.self, index, DecorationCentroid) ||
10949 has_member_decoration(type.self, index, DecorationSample) ||
10950 has_member_decoration(type.self, index, DecorationNoPerspective))
10951 {
10952 // NoPerspective is baked into the builtin type.
10953 SPIRV_CROSS_THROW(
10954 "Flat, Centroid, Sample, NoPerspective decorations are not supported for BaryCoord inputs.");
10955 }
10956 }
10957
10958 // Don't bother decorating integers with the 'flat' attribute; it's
10959 // the default (in fact, the only option). Also don't bother with the
10960 // FragCoord builtin; it's always noperspective on Metal.
10961 if (!type_is_integral(mbr_type) && (!is_builtin || builtin != BuiltInFragCoord))
10962 {
10963 if (has_member_decoration(type.self, index, DecorationFlat))
10964 {
10965 if (!quals.empty())
10966 quals += ", ";
10967 quals += "flat";
10968 }
10969 else if (has_member_decoration(type.self, index, DecorationCentroid))
10970 {
10971 if (!quals.empty())
10972 quals += ", ";
10973 if (has_member_decoration(type.self, index, DecorationNoPerspective))
10974 quals += "centroid_no_perspective";
10975 else
10976 quals += "centroid_perspective";
10977 }
10978 else if (has_member_decoration(type.self, index, DecorationSample))
10979 {
10980 if (!quals.empty())
10981 quals += ", ";
10982 if (has_member_decoration(type.self, index, DecorationNoPerspective))
10983 quals += "sample_no_perspective";
10984 else
10985 quals += "sample_perspective";
10986 }
10987 else if (has_member_decoration(type.self, index, DecorationNoPerspective))
10988 {
10989 if (!quals.empty())
10990 quals += ", ";
10991 quals += "center_no_perspective";
10992 }
10993 }
10994
10995 if (!quals.empty())
10996 return " [[" + quals + "]]";
10997 }
10998
10999 // Fragment function outputs
11000 if (execution.model == ExecutionModelFragment && type.storage == StorageClassOutput)
11001 {
11002 if (is_builtin)
11003 {
11004 switch (builtin)
11005 {
11006 case BuiltInFragStencilRefEXT:
11007 // Similar to PointSize, only mark FragStencilRef if there's a stencil buffer.
11008 // Some shaders may include a FragStencilRef builtin even when used to render
11009 // without a stencil attachment, and Metal will reject this builtin
11010 // when compiling the shader into a render pipeline that does not set
11011 // stencilAttachmentPixelFormat.
11012 if (!msl_options.enable_frag_stencil_ref_builtin)
11013 return "";
11014 if (!msl_options.supports_msl_version(2, 1))
11015 SPIRV_CROSS_THROW("Stencil export only supported in MSL 2.1 and up.");
11016 return string(" [[") + builtin_qualifier(builtin) + "]]";
11017
11018 case BuiltInFragDepth:
11019 // Ditto FragDepth.
11020 if (!msl_options.enable_frag_depth_builtin)
11021 return "";
11022 /* fallthrough */
11023 case BuiltInSampleMask:
11024 return string(" [[") + builtin_qualifier(builtin) + "]]";
11025
11026 default:
11027 return "";
11028 }
11029 }
11030 uint32_t locn = get_member_location(type.self, index);
		// Skip the color attribute for outputs masked out via enable_frag_output_mask;
		// Metal would likely complain about the missing color attachment otherwise.
11032 if (locn != k_unknown_location && !(msl_options.enable_frag_output_mask & (1 << locn)))
11033 return "";
11034 if (locn != k_unknown_location && has_member_decoration(type.self, index, DecorationIndex))
11035 return join(" [[color(", locn, "), index(", get_member_decoration(type.self, index, DecorationIndex),
11036 ")]]");
11037 else if (locn != k_unknown_location)
11038 return join(" [[color(", locn, ")]]");
11039 else if (has_member_decoration(type.self, index, DecorationIndex))
11040 return join(" [[index(", get_member_decoration(type.self, index, DecorationIndex), ")]]");
11041 else
11042 return "";
11043 }
11044
11045 // Compute function inputs
11046 if (execution.model == ExecutionModelGLCompute && type.storage == StorageClassInput)
11047 {
11048 if (is_builtin)
11049 {
11050 switch (builtin)
11051 {
11052 case BuiltInNumSubgroups:
11053 case BuiltInSubgroupId:
11054 case BuiltInSubgroupLocalInvocationId: // FIXME: Should work in any stage
11055 case BuiltInSubgroupSize: // FIXME: Should work in any stage
11056 if (msl_options.emulate_subgroups)
11057 break;
11058 /* fallthrough */
11059 case BuiltInGlobalInvocationId:
11060 case BuiltInWorkgroupId:
11061 case BuiltInNumWorkgroups:
11062 case BuiltInLocalInvocationId:
11063 case BuiltInLocalInvocationIndex:
11064 return string(" [[") + builtin_qualifier(builtin) + "]]";
11065
11066 default:
11067 return "";
11068 }
11069 }
11070 }
11071
11072 return "";
11073}
11074
11075// A user-defined output variable is considered to match an input variable in the subsequent
11076// stage if the two variables are declared with the same Location and Component decoration and
11077// match in type and decoration, except that interpolation decorations are not required to match.
11078// For the purposes of interface matching, variables declared without a Component decoration are
11079// considered to have a Component decoration of zero.
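// E.g. Location 3 with Component 1 yields "user(locn3_1)"; a zero or absent
// Component is omitted, yielding "user(locn3)".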
11080string CompilerMSL::member_location_attribute_qualifier(const SPIRType &type, uint32_t index)
11081{
11082 string quals;
11083 uint32_t comp;
11084 uint32_t locn = get_member_location(type.self, index, &comp);
11085 if (locn != k_unknown_location)
11086 {
11087 quals += "user(locn";
11088 quals += convert_to_string(locn);
11089 if (comp != k_unknown_component && comp != 0)
11090 {
11091 quals += "_";
11092 quals += convert_to_string(comp);
11093 }
11094 quals += ")";
11095 }
11096 return quals;
11097}
11098
11099// Returns the location decoration of the member with the specified index in the specified type.
11100// If the location of the member has been explicitly set, that location is used. If not, this
11101// function assumes the members are ordered in their location order, and simply returns the
11102// index as the location.
11103uint32_t CompilerMSL::get_member_location(uint32_t type_id, uint32_t index, uint32_t *comp) const
11104{
11105 if (comp)
11106 {
11107 if (has_member_decoration(type_id, index, DecorationComponent))
11108 *comp = get_member_decoration(type_id, index, DecorationComponent);
11109 else
11110 *comp = k_unknown_component;
11111 }
11112
11113 if (has_member_decoration(type_id, index, DecorationLocation))
11114 return get_member_decoration(type_id, index, DecorationLocation);
11115 else
11116 return k_unknown_location;
11117}
11118
11119uint32_t CompilerMSL::get_or_allocate_builtin_input_member_location(spv::BuiltIn builtin,
11120 uint32_t type_id, uint32_t index,
11121 uint32_t *comp)
11122{
11123 uint32_t loc = get_member_location(type_id, index, comp);
11124 if (loc != k_unknown_location)
11125 return loc;
11126
11127 if (comp)
11128 *comp = k_unknown_component;
11129
11130 // Late allocation. Find a location which is unused by the application.
11131 // This can happen for built-in inputs in tessellation which are mixed and matched with user inputs.
11132 auto &mbr_type = get<SPIRType>(get<SPIRType>(type_id).member_types[index]);
11133 uint32_t count = type_to_location_count(mbr_type);
11134
11135 loc = 0;
11136
11137 const auto location_range_in_use = [this](uint32_t location, uint32_t location_count) -> bool {
11138 for (uint32_t i = 0; i < location_count; i++)
11139 if (location_inputs_in_use.count(location + i) != 0)
11140 return true;
11141 return false;
11142 };
11143
11144 while (location_range_in_use(loc, count))
11145 loc++;
11146
11147 set_member_decoration(type_id, index, DecorationLocation, loc);
11148
	// Triangle tess level inputs are shared in one packed float4,
	// so mark both builtins as sharing one location.
11151 if (get_execution_mode_bitset().get(ExecutionModeTriangles) &&
11152 (builtin == BuiltInTessLevelInner || builtin == BuiltInTessLevelOuter))
11153 {
11154 builtin_to_automatic_input_location[BuiltInTessLevelInner] = loc;
11155 builtin_to_automatic_input_location[BuiltInTessLevelOuter] = loc;
11156 }
11157 else
11158 builtin_to_automatic_input_location[builtin] = loc;
11159
11160 mark_location_as_used_by_shader(loc, mbr_type, StorageClassInput, true);
11161 return loc;
11162}
11163
11164// Returns the type declaration for a function, including the
11165// entry type if the current function is the entry point function
11166string CompilerMSL::func_type_decl(SPIRType &type)
11167{
11168 // The regular function return type. If not processing the entry point function, that's all we need
11169 string return_type = type_to_glsl(type) + type_to_array_glsl(type);
11170 if (!processing_entry_point)
11171 return return_type;
11172
11173 // If an outgoing interface block has been defined, and it should be returned, override the entry point return type
11174 bool ep_should_return_output = !get_is_rasterization_disabled();
11175 if (stage_out_var_id && ep_should_return_output)
11176 return_type = type_to_glsl(get_stage_out_struct_type()) + type_to_array_glsl(type);
11177
	// Prepend an entry type, based on the execution model
11179 string entry_type;
11180 auto &execution = get_entry_point();
11181 switch (execution.model)
11182 {
11183 case ExecutionModelVertex:
11184 if (msl_options.vertex_for_tessellation && !msl_options.supports_msl_version(1, 2))
11185 SPIRV_CROSS_THROW("Tessellation requires Metal 1.2.");
11186 entry_type = msl_options.vertex_for_tessellation ? "kernel" : "vertex";
11187 break;
11188 case ExecutionModelTessellationEvaluation:
11189 if (!msl_options.supports_msl_version(1, 2))
11190 SPIRV_CROSS_THROW("Tessellation requires Metal 1.2.");
11191 if (execution.flags.get(ExecutionModeIsolines))
11192 SPIRV_CROSS_THROW("Metal does not support isoline tessellation.");
11193 if (msl_options.is_ios())
11194 entry_type =
11195 join("[[ patch(", execution.flags.get(ExecutionModeTriangles) ? "triangle" : "quad", ") ]] vertex");
11196 else
11197 entry_type = join("[[ patch(", execution.flags.get(ExecutionModeTriangles) ? "triangle" : "quad", ", ",
11198 execution.output_vertices, ") ]] vertex");
11199 break;
11200 case ExecutionModelFragment:
11201 entry_type = uses_explicit_early_fragment_test() ? "[[ early_fragment_tests ]] fragment" : "fragment";
11202 break;
11203 case ExecutionModelTessellationControl:
11204 if (!msl_options.supports_msl_version(1, 2))
11205 SPIRV_CROSS_THROW("Tessellation requires Metal 1.2.");
11206 if (execution.flags.get(ExecutionModeIsolines))
11207 SPIRV_CROSS_THROW("Metal does not support isoline tessellation.");
11208 /* fallthrough */
11209 case ExecutionModelGLCompute:
11210 case ExecutionModelKernel:
11211 entry_type = "kernel";
11212 break;
11213 default:
11214 entry_type = "unknown";
11215 break;
11216 }
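	// E.g. "vertex main0_out", "kernel void", or
	// "[[ patch(quad, 4) ]] vertex main0_out" for tessellation evaluation
	// (struct names illustrative).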
11217
11218 return entry_type + " " + return_type;
11219}
11220
11221bool CompilerMSL::uses_explicit_early_fragment_test()
11222{
11223 auto &ep_flags = get_entry_point().flags;
11224 return ep_flags.get(ExecutionModeEarlyFragmentTests) || ep_flags.get(ExecutionModePostDepthCoverage);
11225}
11226
11227// In MSL, address space qualifiers are required for all pointer or reference variables
11228string CompilerMSL::get_argument_address_space(const SPIRVariable &argument)
11229{
11230 const auto &type = get<SPIRType>(argument.basetype);
11231 return get_type_address_space(type, argument.self, true);
11232}
11233
11234string CompilerMSL::get_type_address_space(const SPIRType &type, uint32_t id, bool argument)
11235{
11236 // This can be called for variable pointer contexts as well, so be very careful about which method we choose.
11237 Bitset flags;
11238 auto *var = maybe_get<SPIRVariable>(id);
11239 if (var && type.basetype == SPIRType::Struct &&
11240 (has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock)))
11241 flags = get_buffer_block_flags(id);
11242 else
11243 flags = get_decoration_bitset(id);
11244
11245 const char *addr_space = nullptr;
11246 switch (type.storage)
11247 {
11248 case StorageClassWorkgroup:
11249 addr_space = "threadgroup";
11250 break;
11251
11252 case StorageClassStorageBuffer:
11253 {
11254 // For arguments from variable pointers, we use the write count deduction, so
11255 // we should not assume any constness here. Only for global SSBOs.
11256 bool readonly = false;
11257 if (!var || has_decoration(type.self, DecorationBlock))
11258 readonly = flags.get(DecorationNonWritable);
11259
11260 addr_space = readonly ? "const device" : "device";
11261 break;
11262 }
11263
11264 case StorageClassUniform:
11265 case StorageClassUniformConstant:
11266 case StorageClassPushConstant:
11267 if (type.basetype == SPIRType::Struct)
11268 {
11269 bool ssbo = has_decoration(type.self, DecorationBufferBlock);
11270 if (ssbo)
11271 addr_space = flags.get(DecorationNonWritable) ? "const device" : "device";
11272 else
11273 addr_space = "constant";
11274 }
11275 else if (!argument)
11276 {
11277 addr_space = "constant";
11278 }
11279 else if (type_is_msl_framebuffer_fetch(type))
11280 {
11281 // Subpass inputs are passed around by value.
11282 addr_space = "";
11283 }
11284 break;
11285
11286 case StorageClassFunction:
11287 case StorageClassGeneric:
11288 break;
11289
11290 case StorageClassInput:
11291 if (get_execution_model() == ExecutionModelTessellationControl && var &&
11292 var->basevariable == stage_in_ptr_var_id)
11293 addr_space = msl_options.multi_patch_workgroup ? "constant" : "threadgroup";
11294 if (get_execution_model() == ExecutionModelFragment && var && var->basevariable == stage_in_var_id)
11295 addr_space = "thread";
11296 break;
11297
11298 case StorageClassOutput:
11299 if (capture_output_to_buffer)
11300 {
11301 if (var && type.storage == StorageClassOutput)
11302 {
11303 bool is_masked = is_stage_output_variable_masked(*var);
11304
11305 if (is_masked)
11306 {
11307 if (is_tessellation_shader())
11308 addr_space = "threadgroup";
11309 else
11310 addr_space = "thread";
11311 }
11312 else if (variable_decl_is_remapped_storage(*var, StorageClassWorkgroup))
11313 addr_space = "threadgroup";
11314 }
11315
11316 if (!addr_space)
11317 addr_space = "device";
11318 }
11319 break;
11320
11321 default:
11322 break;
11323 }
11324
11325 if (!addr_space)
11326 {
11327 // No address space for plain values.
11328 addr_space = type.pointer || (argument && type.basetype == SPIRType::ControlPointArray) ? "thread" : "";
11329 }
11330
11331 return join(flags.get(DecorationVolatile) || flags.get(DecorationCoherent) ? "volatile " : "", addr_space);
11332}
11333
11334const char *CompilerMSL::to_restrict(uint32_t id, bool space)
11335{
11336 // This can be called for variable pointer contexts as well, so be very careful about which method we choose.
11337 Bitset flags;
11338 if (ir.ids[id].get_type() == TypeVariable)
11339 {
11340 uint32_t type_id = expression_type_id(id);
11341 auto &type = expression_type(id);
11342 if (type.basetype == SPIRType::Struct &&
11343 (has_decoration(type_id, DecorationBlock) || has_decoration(type_id, DecorationBufferBlock)))
11344 flags = get_buffer_block_flags(id);
11345 else
11346 flags = get_decoration_bitset(id);
11347 }
11348 else
11349 flags = get_decoration_bitset(id);
11350
11351 return flags.get(DecorationRestrict) ? (space ? "restrict " : "restrict") : "";
11352}
11353
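// Returns the declaration of the [[stage_in]] argument for the entry point,
// e.g. "main0_in in [[stage_in]]" (names illustrative), or an empty string if there is none.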
11354string CompilerMSL::entry_point_arg_stage_in()
11355{
11356 string decl;
11357
11358 if (get_execution_model() == ExecutionModelTessellationControl && msl_options.multi_patch_workgroup)
11359 return decl;
11360
11361 // Stage-in structure
11362 uint32_t stage_in_id;
11363 if (get_execution_model() == ExecutionModelTessellationEvaluation)
11364 stage_in_id = patch_stage_in_var_id;
11365 else
11366 stage_in_id = stage_in_var_id;
11367
11368 if (stage_in_id)
11369 {
11370 auto &var = get<SPIRVariable>(stage_in_id);
11371 auto &type = get_variable_data_type(var);
11372
11373 add_resource_name(var.self);
11374 decl = join(type_to_glsl(type), " ", to_name(var.self), " [[stage_in]]");
11375 }
11376
11377 return decl;
11378}
11379
11380// Returns true if this input builtin should be a direct parameter on a shader function parameter list,
11381// and false for builtins that should be passed or calculated some other way.
11382bool CompilerMSL::is_direct_input_builtin(BuiltIn bi_type)
11383{
11384 switch (bi_type)
11385 {
11386 // Vertex function in
11387 case BuiltInVertexId:
11388 case BuiltInVertexIndex:
11389 case BuiltInBaseVertex:
11390 case BuiltInInstanceId:
11391 case BuiltInInstanceIndex:
11392 case BuiltInBaseInstance:
11393 return get_execution_model() != ExecutionModelVertex || !msl_options.vertex_for_tessellation;
11394 // Tess. control function in
11395 case BuiltInPosition:
11396 case BuiltInPointSize:
11397 case BuiltInClipDistance:
11398 case BuiltInCullDistance:
11399 case BuiltInPatchVertices:
11400 return false;
11401 case BuiltInInvocationId:
11402 case BuiltInPrimitiveId:
11403 return get_execution_model() != ExecutionModelTessellationControl || !msl_options.multi_patch_workgroup;
11404 // Tess. evaluation function in
11405 case BuiltInTessLevelInner:
11406 case BuiltInTessLevelOuter:
11407 return false;
11408 // Fragment function in
11409 case BuiltInSamplePosition:
11410 case BuiltInHelperInvocation:
11411 case BuiltInBaryCoordNV:
11412 case BuiltInBaryCoordNoPerspNV:
11413 return false;
11414 case BuiltInViewIndex:
11415 return get_execution_model() == ExecutionModelFragment && msl_options.multiview &&
11416 msl_options.multiview_layered_rendering;
11417 // Compute function in
11418 case BuiltInSubgroupId:
11419 case BuiltInNumSubgroups:
11420 return !msl_options.emulate_subgroups;
11421 // Any stage function in
11422 case BuiltInDeviceIndex:
11423 case BuiltInSubgroupEqMask:
11424 case BuiltInSubgroupGeMask:
11425 case BuiltInSubgroupGtMask:
11426 case BuiltInSubgroupLeMask:
11427 case BuiltInSubgroupLtMask:
11428 return false;
11429 case BuiltInSubgroupSize:
11430 if (msl_options.fixed_subgroup_size != 0)
11431 return false;
11432 /* fallthrough */
11433 case BuiltInSubgroupLocalInvocationId:
11434 return !msl_options.emulate_subgroups;
11435 default:
11436 return true;
11437 }
11438}
11439
11440// Returns true if this is a fragment shader that runs per sample, and false otherwise.
11441bool CompilerMSL::is_sample_rate() const
11442{
11443 auto &caps = get_declared_capabilities();
11444 return get_execution_model() == ExecutionModelFragment &&
11445 (msl_options.force_sample_rate_shading ||
11446 std::find(caps.begin(), caps.end(), CapabilitySampleRateShading) != caps.end() ||
11447 (msl_options.use_framebuffer_fetch_subpasses && need_subpass_input));
11448}
11449
11450bool CompilerMSL::is_intersection_query() const
11451{
11452 auto &caps = get_declared_capabilities();
11453 return std::find(caps.begin(), caps.end(), CapabilityRayQueryKHR) != caps.end();
11454}
11455
11456void CompilerMSL::entry_point_args_builtin(string &ep_args)
11457{
11458 // Builtin variables
11459 SmallVector<pair<SPIRVariable *, BuiltIn>, 8> active_builtins;
11460 ir.for_each_typed_id<SPIRVariable>([&](uint32_t var_id, SPIRVariable &var) {
11461 if (var.storage != StorageClassInput)
11462 return;
11463
11464 auto bi_type = BuiltIn(get_decoration(var_id, DecorationBuiltIn));
11465
11466 // Don't emit SamplePosition as a separate parameter. In the entry
11467 // point, we get that by calling get_sample_position() on the sample ID.
11468 if (is_builtin_variable(var) &&
11469 get_variable_data_type(var).basetype != SPIRType::Struct &&
11470 get_variable_data_type(var).basetype != SPIRType::ControlPointArray)
11471 {
11472 // If the builtin is not part of the active input builtin set, don't emit it.
11473 // Relevant for multiple entry-point modules which might declare unused builtins.
11474 if (!active_input_builtins.get(bi_type) || !interface_variable_exists_in_entry_point(var_id))
11475 return;
11476
11477 // Remember this variable. We may need to correct its type.
11478 active_builtins.push_back(make_pair(&var, bi_type));
11479
11480 if (is_direct_input_builtin(bi_type))
11481 {
11482 if (!ep_args.empty())
11483 ep_args += ", ";
11484
11485 // Handle HLSL-style 0-based vertex/instance index.
11486 builtin_declaration = true;
11487
11488 // Handle different MSL gl_TessCoord types. (float2, float3)
11489 if (bi_type == BuiltInTessCoord && get_entry_point().flags.get(ExecutionModeQuads))
11490 ep_args += "float2 " + to_expression(var_id) + "In";
11491 else
11492 ep_args += builtin_type_decl(bi_type, var_id) + " " + to_expression(var_id);
11493
11494 ep_args += " [[" + builtin_qualifier(bi_type);
11495 if (bi_type == BuiltInSampleMask && get_entry_point().flags.get(ExecutionModePostDepthCoverage))
11496 {
11497 if (!msl_options.supports_msl_version(2))
11498 SPIRV_CROSS_THROW("Post-depth coverage requires MSL 2.0.");
11499 if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 3))
11500 SPIRV_CROSS_THROW("Post-depth coverage on Mac requires MSL 2.3.");
11501 ep_args += ", post_depth_coverage";
11502 }
11503 ep_args += "]]";
11504 builtin_declaration = false;
11505 }
11506 }
11507
11508 if (has_extended_decoration(var_id, SPIRVCrossDecorationBuiltInDispatchBase))
11509 {
11510 // This is a special implicit builtin, not corresponding to any SPIR-V builtin,
11511 // which holds the base that was passed to vkCmdDispatchBase() or vkCmdDrawIndexed(). If it's present,
11512 // assume we emitted it for a good reason.
11513 assert(msl_options.supports_msl_version(1, 2));
11514 if (!ep_args.empty())
11515 ep_args += ", ";
11516
11517 ep_args += type_to_glsl(get_variable_data_type(var)) + " " + to_expression(var_id) + " [[grid_origin]]";
11518 }
11519
11520 if (has_extended_decoration(var_id, SPIRVCrossDecorationBuiltInStageInputSize))
11521 {
11522 // This is another special implicit builtin, not corresponding to any SPIR-V builtin,
11523 // which holds the number of vertices and instances to draw. If it's present,
11524 // assume we emitted it for a good reason.
11525 assert(msl_options.supports_msl_version(1, 2));
11526 if (!ep_args.empty())
11527 ep_args += ", ";
11528
11529 ep_args += type_to_glsl(get_variable_data_type(var)) + " " + to_expression(var_id) + " [[grid_size]]";
11530 }
11531 });
11532
11533 // Correct the types of all encountered active builtins. We couldn't do this before
11534 // because ensure_correct_builtin_type() may increase the bound, which isn't allowed
11535 // while iterating over IDs.
11536 for (auto &var : active_builtins)
11537 var.first->basetype = ensure_correct_builtin_type(var.first->basetype, var.second);
11538
11539 // Handle HLSL-style 0-based vertex/instance index.
11540 if (needs_base_vertex_arg == TriState::Yes)
11541 ep_args += built_in_func_arg(BuiltInBaseVertex, !ep_args.empty());
11542
11543 if (needs_base_instance_arg == TriState::Yes)
11544 ep_args += built_in_func_arg(BuiltInBaseInstance, !ep_args.empty());
11545
11546 if (capture_output_to_buffer)
11547 {
11548 // Add parameters to hold the indirect draw parameters and the shader output. This has to be handled
11549 // specially because it needs to be a pointer, not a reference.
11550 if (stage_out_var_id)
11551 {
11552 if (!ep_args.empty())
11553 ep_args += ", ";
11554 ep_args += join("device ", type_to_glsl(get_stage_out_struct_type()), "* ", output_buffer_var_name,
11555 " [[buffer(", msl_options.shader_output_buffer_index, ")]]");
11556 }
11557
11558 if (get_execution_model() == ExecutionModelTessellationControl)
11559 {
11560 if (!ep_args.empty())
11561 ep_args += ", ";
11562 ep_args +=
11563 join("constant uint* spvIndirectParams [[buffer(", msl_options.indirect_params_buffer_index, ")]]");
11564 }
11565 else if (stage_out_var_id &&
11566 !(get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation))
11567 {
11568 if (!ep_args.empty())
11569 ep_args += ", ";
11570 ep_args +=
11571 join("device uint* spvIndirectParams [[buffer(", msl_options.indirect_params_buffer_index, ")]]");
11572 }
11573
11574 if (get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation &&
11575 (active_input_builtins.get(BuiltInVertexIndex) || active_input_builtins.get(BuiltInVertexId)) &&
11576 msl_options.vertex_index_type != Options::IndexType::None)
11577 {
11578 // Add the index buffer so we can set gl_VertexIndex correctly.
11579 if (!ep_args.empty())
11580 ep_args += ", ";
11581 switch (msl_options.vertex_index_type)
11582 {
11583 case Options::IndexType::None:
11584 break;
11585 case Options::IndexType::UInt16:
11586 ep_args += join("const device ushort* ", index_buffer_var_name, " [[buffer(",
11587 msl_options.shader_index_buffer_index, ")]]");
11588 break;
11589 case Options::IndexType::UInt32:
11590 ep_args += join("const device uint* ", index_buffer_var_name, " [[buffer(",
11591 msl_options.shader_index_buffer_index, ")]]");
11592 break;
11593 }
11594 }
11595
11596 // Tessellation control shaders get three additional parameters:
11597 // a buffer to hold the per-patch data, a buffer to hold the per-patch
11598 // tessellation levels, and a block of workgroup memory to hold the
11599 // input control point data.
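		// For illustration, with the default variable names and buffer indices this
		// can append something like:
		//   device main0_patchOut* spvPatchOut [[buffer(27)]],
		//   device MTLQuadTessellationFactorsHalf* spvTessLevel [[buffer(26)]],
		//   threadgroup main0_in* gl_in [[threadgroup(0)]]
		// (struct names derive from the entry point name; indices come from msl_options).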
11600 if (get_execution_model() == ExecutionModelTessellationControl)
11601 {
11602 if (patch_stage_out_var_id)
11603 {
11604 if (!ep_args.empty())
11605 ep_args += ", ";
11606 ep_args +=
11607 join("device ", type_to_glsl(get_patch_stage_out_struct_type()), "* ", patch_output_buffer_var_name,
11608 " [[buffer(", convert_to_string(msl_options.shader_patch_output_buffer_index), ")]]");
11609 }
11610 if (!ep_args.empty())
11611 ep_args += ", ";
11612 ep_args += join("device ", get_tess_factor_struct_name(), "* ", tess_factor_buffer_var_name, " [[buffer(",
11613 convert_to_string(msl_options.shader_tess_factor_buffer_index), ")]]");
11614
11615 // Initializer for tess factors must be handled specially since it's never declared as a normal variable.
11616 uint32_t outer_factor_initializer_id = 0;
11617 uint32_t inner_factor_initializer_id = 0;
11618 ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
11619 if (!has_decoration(var.self, DecorationBuiltIn) || var.storage != StorageClassOutput || !var.initializer)
11620 return;
11621
11622 BuiltIn builtin = BuiltIn(get_decoration(var.self, DecorationBuiltIn));
11623 if (builtin == BuiltInTessLevelInner)
11624 inner_factor_initializer_id = var.initializer;
11625 else if (builtin == BuiltInTessLevelOuter)
11626 outer_factor_initializer_id = var.initializer;
11627 });
11628
11629 const SPIRConstant *c = nullptr;
11630
11631 if (outer_factor_initializer_id && (c = maybe_get<SPIRConstant>(outer_factor_initializer_id)))
11632 {
11633 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
11634 entry_func.fixup_hooks_in.push_back([=]() {
11635 uint32_t components = get_execution_mode_bitset().get(ExecutionModeTriangles) ? 3 : 4;
11636 for (uint32_t i = 0; i < components; i++)
11637 {
11638 statement(builtin_to_glsl(BuiltInTessLevelOuter, StorageClassOutput), "[", i, "] = ",
11639 "half(", to_expression(c->subconstants[i]), ");");
11640 }
11641 });
11642 }
11643
11644 if (inner_factor_initializer_id && (c = maybe_get<SPIRConstant>(inner_factor_initializer_id)))
11645 {
11646 auto &entry_func = get<SPIRFunction>(ir.default_entry_point);
11647 if (get_execution_mode_bitset().get(ExecutionModeTriangles))
11648 {
11649 entry_func.fixup_hooks_in.push_back([=]() {
11650 statement(builtin_to_glsl(BuiltInTessLevelInner, StorageClassOutput), " = ", "half(",
11651 to_expression(c->subconstants[0]), ");");
11652 });
11653 }
11654 else
11655 {
11656 entry_func.fixup_hooks_in.push_back([=]() {
11657 for (uint32_t i = 0; i < 2; i++)
11658 {
11659 statement(builtin_to_glsl(BuiltInTessLevelInner, StorageClassOutput), "[", i, "] = ",
11660 "half(", to_expression(c->subconstants[i]), ");");
11661 }
11662 });
11663 }
11664 }
11665
11666 if (stage_in_var_id)
11667 {
11668 if (!ep_args.empty())
11669 ep_args += ", ";
11670 if (msl_options.multi_patch_workgroup)
11671 {
11672 ep_args += join("device ", type_to_glsl(get_stage_in_struct_type()), "* ", input_buffer_var_name,
11673 " [[buffer(", convert_to_string(msl_options.shader_input_buffer_index), ")]]");
11674 }
11675 else
11676 {
11677 ep_args += join("threadgroup ", type_to_glsl(get_stage_in_struct_type()), "* ", input_wg_var_name,
11678 " [[threadgroup(", convert_to_string(msl_options.shader_input_wg_index), ")]]");
11679 }
11680 }
11681 }
11682 }
11683}
11684
11685string CompilerMSL::entry_point_args_argument_buffer(bool append_comma)
11686{
11687 string ep_args = entry_point_arg_stage_in();
11688 Bitset claimed_bindings;
11689
11690 for (uint32_t i = 0; i < kMaxArgumentBuffers; i++)
11691 {
11692 uint32_t id = argument_buffer_ids[i];
11693 if (id == 0)
11694 continue;
11695
11696 add_resource_name(id);
11697 auto &var = get<SPIRVariable>(id);
11698 auto &type = get_variable_data_type(var);
11699
11700 if (!ep_args.empty())
11701 ep_args += ", ";
11702
11703 // Check if the argument buffer binding itself has been remapped.
11704 uint32_t buffer_binding;
11705 auto itr = resource_bindings.find({ get_entry_point().model, i, kArgumentBufferBinding });
11706 if (itr != end(resource_bindings))
11707 {
11708 buffer_binding = itr->second.first.msl_buffer;
11709 itr->second.second = true;
11710 }
11711 else
11712 {
11713 // As a fallback, directly map desc set <-> binding.
11714 // If that was taken, take the next buffer binding.
11715 if (claimed_bindings.get(i))
11716 buffer_binding = next_metal_resource_index_buffer;
11717 else
11718 buffer_binding = i;
11719 }
11720
11721 claimed_bindings.set(buffer_binding);
11722
11723 ep_args += get_argument_address_space(var) + " " + type_to_glsl(type) + "& " + to_restrict(id) + to_name(id);
11724 ep_args += " [[buffer(" + convert_to_string(buffer_binding) + ")]]";
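		// e.g. "constant spvDescriptorSetBuffer0& spvDescriptorSet0 [[buffer(0)]]",
		// or with a "device" address space if the set was marked device-addressable.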
11725
11726 next_metal_resource_index_buffer = max(next_metal_resource_index_buffer, buffer_binding + 1);
11727 }
11728
11729 entry_point_args_discrete_descriptors(ep_args);
11730 entry_point_args_builtin(ep_args);
11731
11732 if (!ep_args.empty() && append_comma)
11733 ep_args += ", ";
11734
11735 return ep_args;
11736}
11737
11738const MSLConstexprSampler *CompilerMSL::find_constexpr_sampler(uint32_t id) const
11739{
11740 // Try by ID.
11741 {
11742 auto itr = constexpr_samplers_by_id.find(id);
11743 if (itr != end(constexpr_samplers_by_id))
11744 return &itr->second;
11745 }
11746
11747 // Try by binding.
11748 {
11749 uint32_t desc_set = get_decoration(id, DecorationDescriptorSet);
11750 uint32_t binding = get_decoration(id, DecorationBinding);
11751
11752 auto itr = constexpr_samplers_by_binding.find({ desc_set, binding });
11753 if (itr != end(constexpr_samplers_by_binding))
11754 return &itr->second;
11755 }
11756
11757 return nullptr;
11758}
11759
11760void CompilerMSL::entry_point_args_discrete_descriptors(string &ep_args)
11761{
11762 // Output resources, sorted by resource index & type
11763 // We need to sort to work around a bug on macOS 10.13 with NVidia drivers where switching between shaders
11764 // with different order of buffers can result in issues with buffer assignments inside the driver.
11765 struct Resource
11766 {
11767 SPIRVariable *var;
11768 string name;
11769 SPIRType::BaseType basetype;
11770 uint32_t index;
11771 uint32_t plane;
11772 uint32_t secondary_index;
11773 };
11774
11775 SmallVector<Resource> resources;
11776
11777 ir.for_each_typed_id<SPIRVariable>([&](uint32_t var_id, SPIRVariable &var) {
11778 if ((var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant ||
11779 var.storage == StorageClassPushConstant || var.storage == StorageClassStorageBuffer) &&
11780 !is_hidden_variable(var))
11781 {
11782 auto &type = get_variable_data_type(var);
11783
11784 if (is_supported_argument_buffer_type(type) && var.storage != StorageClassPushConstant)
11785 {
11786 uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet);
11787 if (descriptor_set_is_argument_buffer(desc_set))
11788 return;
11789 }
11790
11791 const MSLConstexprSampler *constexpr_sampler = nullptr;
11792 if (type.basetype == SPIRType::SampledImage || type.basetype == SPIRType::Sampler)
11793 {
11794 constexpr_sampler = find_constexpr_sampler(var_id);
11795 if (constexpr_sampler)
11796 {
11797 // Mark this ID as a constexpr sampler for later in case it came from set/bindings.
11798 constexpr_samplers_by_id[var_id] = *constexpr_sampler;
11799 }
11800 }
11801
11802 // Emulate texture2D atomic operations
11803 uint32_t secondary_index = 0;
11804 if (atomic_image_vars.count(var.self))
11805 {
11806 secondary_index = get_metal_resource_index(var, SPIRType::AtomicCounter, 0);
11807 }
11808
11809 if (type.basetype == SPIRType::SampledImage)
11810 {
11811 add_resource_name(var_id);
11812
11813 uint32_t plane_count = 1;
11814 if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable)
11815 plane_count = constexpr_sampler->planes;
11816
11817 for (uint32_t i = 0; i < plane_count; i++)
11818 resources.push_back({ &var, to_name(var_id), SPIRType::Image,
11819 get_metal_resource_index(var, SPIRType::Image, i), i, secondary_index });
11820
11821 if (type.image.dim != DimBuffer && !constexpr_sampler)
11822 {
11823 resources.push_back({ &var, to_sampler_expression(var_id), SPIRType::Sampler,
11824 get_metal_resource_index(var, SPIRType::Sampler), 0, 0 });
11825 }
11826 }
11827 else if (!constexpr_sampler)
11828 {
11829 // constexpr samplers are not declared as resources.
11830 add_resource_name(var_id);
11831 resources.push_back({ &var, to_name(var_id), type.basetype,
11832 get_metal_resource_index(var, type.basetype), 0, secondary_index });
11833 }
11834 }
11835 });
11836
11837 sort(resources.begin(), resources.end(), [](const Resource &lhs, const Resource &rhs) {
11838 return tie(lhs.basetype, lhs.index) < tie(rhs.basetype, rhs.index);
11839 });
11840
11841 for (auto &r : resources)
11842 {
11843 auto &var = *r.var;
11844 auto &type = get_variable_data_type(var);
11845
11846 uint32_t var_id = var.self;
11847
11848 switch (r.basetype)
11849 {
11850 case SPIRType::Struct:
11851 {
11852 auto &m = ir.meta[type.self];
11853 if (m.members.size() == 0)
11854 break;
11855 if (!type.array.empty())
11856 {
11857 if (type.array.size() > 1)
11858 SPIRV_CROSS_THROW("Arrays of arrays of buffers are not supported.");
11859
11860 // Metal doesn't directly support this, so we must expand the
11861 // array. We'll declare a local array to hold these elements
11862 // later.
11863 uint32_t array_size = to_array_size_literal(type);
11864
11865 if (array_size == 0)
11866 SPIRV_CROSS_THROW("Unsized arrays of buffers are not supported in MSL.");
11867
				// These buffers cannot be declared with the array<T> template; each
				// element is bound as its own argument, and a plain C-style local
				// array is declared later to gather the elements.
				is_using_builtin_array = true;
11870 buffer_arrays.push_back(var_id);
11871 for (uint32_t i = 0; i < array_size; ++i)
11872 {
11873 if (!ep_args.empty())
11874 ep_args += ", ";
11875 ep_args += get_argument_address_space(var) + " " + type_to_glsl(type) + "* " + to_restrict(var_id) +
11876 r.name + "_" + convert_to_string(i);
11877 ep_args += " [[buffer(" + convert_to_string(r.index + i) + ")";
11878 if (interlocked_resources.count(var_id))
11879 ep_args += ", raster_order_group(0)";
11880 ep_args += "]]";
11881 }
11882 is_using_builtin_array = false;
11883 }
11884 else
11885 {
11886 if (!ep_args.empty())
11887 ep_args += ", ";
11888 ep_args +=
11889 get_argument_address_space(var) + " " + type_to_glsl(type) + "& " + to_restrict(var_id) + r.name;
11890 ep_args += " [[buffer(" + convert_to_string(r.index) + ")";
11891 if (interlocked_resources.count(var_id))
11892 ep_args += ", raster_order_group(0)";
11893 ep_args += "]]";
11894 }
11895 break;
11896 }
11897 case SPIRType::Sampler:
11898 if (!ep_args.empty())
11899 ep_args += ", ";
11900 ep_args += sampler_type(type, var_id) + " " + r.name;
11901 ep_args += " [[sampler(" + convert_to_string(r.index) + ")]]";
11902 break;
11903 case SPIRType::Image:
11904 {
11905 if (!ep_args.empty())
11906 ep_args += ", ";
11907
11908 // Use Metal's native frame-buffer fetch API for subpass inputs.
11909 const auto &basetype = get<SPIRType>(var.basetype);
11910 if (!type_is_msl_framebuffer_fetch(basetype))
11911 {
11912 ep_args += image_type_glsl(type, var_id) + " " + r.name;
11913 if (r.plane > 0)
11914 ep_args += join(plane_name_suffix, r.plane);
11915 ep_args += " [[texture(" + convert_to_string(r.index) + ")";
11916 if (interlocked_resources.count(var_id))
11917 ep_args += ", raster_order_group(0)";
11918 ep_args += "]]";
11919 }
11920 else
11921 {
11922 if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 3))
11923 SPIRV_CROSS_THROW("Framebuffer fetch on Mac is not supported before MSL 2.3.");
11924 ep_args += image_type_glsl(type, var_id) + " " + r.name;
11925 ep_args += " [[color(" + convert_to_string(r.index) + ")]]";
11926 }
11927
11928 // Emulate texture2D atomic operations
11929 if (atomic_image_vars.count(var.self))
11930 {
11931 ep_args += ", device atomic_" + type_to_glsl(get<SPIRType>(basetype.image.type), 0);
11932 ep_args += "* " + r.name + "_atomic";
11933 ep_args += " [[buffer(" + convert_to_string(r.secondary_index) + ")";
11934 if (interlocked_resources.count(var_id))
11935 ep_args += ", raster_order_group(0)";
11936 ep_args += "]]";
11937 }
11938 break;
11939 }
		case SPIRType::AccelerationStructure:
			if (!ep_args.empty())
				ep_args += ", ";
			ep_args += type_to_glsl(type, var_id) + " " + r.name;
			ep_args += " [[buffer(" + convert_to_string(r.index) + ")]]";
			break;
11944 default:
11945 if (!ep_args.empty())
11946 ep_args += ", ";
11947 if (!type.pointer)
11948 ep_args += get_type_address_space(get<SPIRType>(var.basetype), var_id) + " " +
11949 type_to_glsl(type, var_id) + "& " + r.name;
11950 else
11951 ep_args += type_to_glsl(type, var_id) + " " + r.name;
11952 ep_args += " [[buffer(" + convert_to_string(r.index) + ")";
11953 if (interlocked_resources.count(var_id))
11954 ep_args += ", raster_order_group(0)";
11955 ep_args += "]]";
11956 break;
11957 }
11958 }
11959}
11960
11961// Returns a string containing a comma-delimited list of args for the entry point function
11962// This is the "classic" method of MSL 1 when we don't have argument buffer support.
11963string CompilerMSL::entry_point_args_classic(bool append_comma)
11964{
11965 string ep_args = entry_point_arg_stage_in();
11966 entry_point_args_discrete_descriptors(ep_args);
11967 entry_point_args_builtin(ep_args);
11968
11969 if (!ep_args.empty() && append_comma)
11970 ep_args += ", ";
11971
11972 return ep_args;
11973}
11974
11975void CompilerMSL::fix_up_shader_inputs_outputs()
11976{
11977 auto &entry_func = this->get<SPIRFunction>(ir.default_entry_point);
11978
11979 // Emit a guard to ensure we don't execute beyond the last vertex.
11980 // Vertex shaders shouldn't have the problems with barriers in non-uniform control flow that
11981 // tessellation control shaders do, so early returns should be OK. We may need to revisit this
11982 // if it ever becomes possible to use barriers from a vertex shader.
11983 if (get_execution_model() == ExecutionModelVertex && msl_options.vertex_for_tessellation)
11984 {
11985 entry_func.fixup_hooks_in.push_back([this]() {
11986 statement("if (any(", to_expression(builtin_invocation_id_id),
11987 " >= ", to_expression(builtin_stage_input_size_id), "))");
11988 statement(" return;");
11989 });
11990 }
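	// The emitted guard reads roughly:
	//   if (any(gl_GlobalInvocationID >= spvStageInputSize))
	//       return;
	// (actual names depend on how the implicit builtins were declared).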
11991
	// Look for sampled images and buffers. Add hooks to set up the swizzle constants or array lengths.
11993 ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
11994 auto &type = get_variable_data_type(var);
11995 uint32_t var_id = var.self;
11996 bool ssbo = has_decoration(type.self, DecorationBufferBlock);
11997
11998 if (var.storage == StorageClassUniformConstant && !is_hidden_variable(var))
11999 {
12000 if (msl_options.swizzle_texture_samples && has_sampled_images && is_sampled_image_type(type))
12001 {
12002 entry_func.fixup_hooks_in.push_back([this, &type, &var, var_id]() {
12003 bool is_array_type = !type.array.empty();
12004
12005 uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet);
12006 if (descriptor_set_is_argument_buffer(desc_set))
12007 {
12008 statement("constant uint", is_array_type ? "* " : "& ", to_swizzle_expression(var_id),
12009 is_array_type ? " = &" : " = ", to_name(argument_buffer_ids[desc_set]),
12010 ".spvSwizzleConstants", "[",
12011 convert_to_string(get_metal_resource_index(var, SPIRType::Image)), "];");
12012 }
12013 else
12014 {
12015 // If we have an array of images, we need to be able to index into it, so take a pointer instead.
12016 statement("constant uint", is_array_type ? "* " : "& ", to_swizzle_expression(var_id),
12017 is_array_type ? " = &" : " = ", to_name(swizzle_buffer_id), "[",
12018 convert_to_string(get_metal_resource_index(var, SPIRType::Image)), "];");
12019 }
12020 });
12021 }
12022 }
12023 else if ((var.storage == StorageClassStorageBuffer || (var.storage == StorageClassUniform && ssbo)) &&
12024 !is_hidden_variable(var))
12025 {
12026 if (buffers_requiring_array_length.count(var.self))
12027 {
12028 entry_func.fixup_hooks_in.push_back([this, &type, &var, var_id]() {
12029 bool is_array_type = !type.array.empty();
12030
12031 uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet);
12032 if (descriptor_set_is_argument_buffer(desc_set))
12033 {
					statement("constant uint", is_array_type ? "* " : "& ", to_buffer_size_expression(var_id),
					          is_array_type ? " = &" : " = ", to_name(argument_buffer_ids[desc_set]),
					          ".spvBufferSizeConstants", "[",
					          convert_to_string(get_metal_resource_index(var, type.basetype)), "];");
12038 }
12039 else
12040 {
					// If we have an array of buffers, we need to be able to index into it, so take a pointer instead.
12042 statement("constant uint", is_array_type ? "* " : "& ", to_buffer_size_expression(var_id),
12043 is_array_type ? " = &" : " = ", to_name(buffer_size_buffer_id), "[",
12044 convert_to_string(get_metal_resource_index(var, type.basetype)), "];");
12045 }
12046 });
12047 }
12048 }
12049 });
12050
12051 // Builtin variables
12052 ir.for_each_typed_id<SPIRVariable>([this, &entry_func](uint32_t, SPIRVariable &var) {
12053 uint32_t var_id = var.self;
12054 BuiltIn bi_type = ir.meta[var_id].decoration.builtin_type;
12055
12056 if (var.storage != StorageClassInput && var.storage != StorageClassOutput)
12057 return;
12058 if (!interface_variable_exists_in_entry_point(var.self))
12059 return;
12060
12061 if (var.storage == StorageClassInput && is_builtin_variable(var) && active_input_builtins.get(bi_type))
12062 {
12063 switch (bi_type)
12064 {
12065 case BuiltInSamplePosition:
12066 entry_func.fixup_hooks_in.push_back([=]() {
12067 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = get_sample_position(",
12068 to_expression(builtin_sample_id_id), ");");
12069 });
12070 break;
12071 case BuiltInFragCoord:
12072 if (is_sample_rate())
12073 {
12074 entry_func.fixup_hooks_in.push_back([=]() {
12075 statement(to_expression(var_id), ".xy += get_sample_position(",
12076 to_expression(builtin_sample_id_id), ") - 0.5;");
12077 });
12078 }
12079 break;
12080 case BuiltInHelperInvocation:
12081 if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 3))
12082 SPIRV_CROSS_THROW("simd_is_helper_thread() requires version 2.3 on iOS.");
12083 else if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 1))
12084 SPIRV_CROSS_THROW("simd_is_helper_thread() requires version 2.1 on macOS.");
12085
12086 entry_func.fixup_hooks_in.push_back([=]() {
12087 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = simd_is_helper_thread();");
12088 });
12089 break;
12090 case BuiltInInvocationId:
12091 // This is direct-mapped without multi-patch workgroups.
12092 if (get_execution_model() != ExecutionModelTessellationControl || !msl_options.multi_patch_workgroup)
12093 break;
12094
12095 entry_func.fixup_hooks_in.push_back([=]() {
12096 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
12097 to_expression(builtin_invocation_id_id), ".x % ", this->get_entry_point().output_vertices,
12098 ";");
12099 });
12100 break;
12101 case BuiltInPrimitiveId:
12102 // This is natively supported by fragment and tessellation evaluation shaders.
12103 // In tessellation control shaders, this is direct-mapped without multi-patch workgroups.
12104 if (get_execution_model() != ExecutionModelTessellationControl || !msl_options.multi_patch_workgroup)
12105 break;
12106
12107 entry_func.fixup_hooks_in.push_back([=]() {
12108 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = min(",
12109 to_expression(builtin_invocation_id_id), ".x / ", this->get_entry_point().output_vertices,
12110 ", spvIndirectParams[1] - 1);");
12111 });
12112 break;
12113 case BuiltInPatchVertices:
12114 if (get_execution_model() == ExecutionModelTessellationEvaluation)
12115 entry_func.fixup_hooks_in.push_back([=]() {
12116 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
12117 to_expression(patch_stage_in_var_id), ".gl_in.size();");
12118 });
12119 else
12120 entry_func.fixup_hooks_in.push_back([=]() {
12121 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = spvIndirectParams[0];");
12122 });
12123 break;
12124 case BuiltInTessCoord:
12125 if (get_entry_point().flags.get(ExecutionModeQuads))
12126 {
12127 // The entry point will only have a float2 TessCoord variable.
12128 // Pad to float3.
12129 entry_func.fixup_hooks_in.push_back([=]() {
12130 auto name = builtin_to_glsl(BuiltInTessCoord, StorageClassInput);
12131 statement("float3 " + name + " = float3(" + name + "In.x, " + name + "In.y, 0.0);");
12132 });
12133 }
12134
12135 // Emit a fixup to account for the shifted domain. Don't do this for triangles;
12136 // MoltenVK will just reverse the winding order instead.
12137 if (msl_options.tess_domain_origin_lower_left && !get_entry_point().flags.get(ExecutionModeTriangles))
12138 {
12139 string tc = to_expression(var_id);
12140 entry_func.fixup_hooks_in.push_back([=]() { statement(tc, ".y = 1.0 - ", tc, ".y;"); });
12141 }
12142 break;
12143 case BuiltInSubgroupId:
12144 if (!msl_options.emulate_subgroups)
12145 break;
12146 // For subgroup emulation, this is the same as the local invocation index.
12147 entry_func.fixup_hooks_in.push_back([=]() {
12148 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
12149 to_expression(builtin_local_invocation_index_id), ";");
12150 });
12151 break;
12152 case BuiltInNumSubgroups:
12153 if (!msl_options.emulate_subgroups)
12154 break;
12155 // For subgroup emulation, this is the same as the workgroup size.
12156 entry_func.fixup_hooks_in.push_back([=]() {
12157 auto &type = expression_type(builtin_workgroup_size_id);
12158 string size_expr = to_expression(builtin_workgroup_size_id);
12159 if (type.vecsize >= 3)
12160 size_expr = join(size_expr, ".x * ", size_expr, ".y * ", size_expr, ".z");
12161 else if (type.vecsize == 2)
12162 size_expr = join(size_expr, ".x * ", size_expr, ".y");
12163 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ", size_expr, ";");
12164 });
12165 break;
12166 case BuiltInSubgroupLocalInvocationId:
12167 if (!msl_options.emulate_subgroups)
12168 break;
12169 // For subgroup emulation, assume subgroups of size 1.
12170 entry_func.fixup_hooks_in.push_back(
12171 [=]() { statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = 0;"); });
12172 break;
12173 case BuiltInSubgroupSize:
12174 if (msl_options.emulate_subgroups)
12175 {
12176 // For subgroup emulation, assume subgroups of size 1.
12177 entry_func.fixup_hooks_in.push_back(
12178 [=]() { statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = 1;"); });
12179 }
12180 else if (msl_options.fixed_subgroup_size != 0)
12181 {
12182 entry_func.fixup_hooks_in.push_back([=]() {
12183 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
12184 msl_options.fixed_subgroup_size, ";");
12185 });
12186 }
12187 break;
12188 case BuiltInSubgroupEqMask:
12189 if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 2))
12190 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.2 on iOS.");
12191 if (!msl_options.supports_msl_version(2, 1))
12192 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
12193 entry_func.fixup_hooks_in.push_back([=]() {
12194 if (msl_options.is_ios())
12195 {
12196 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ", "uint4(1 << ",
12197 to_expression(builtin_subgroup_invocation_id_id), ", uint3(0));");
12198 }
12199 else
12200 {
12201 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
12202 to_expression(builtin_subgroup_invocation_id_id), " >= 32 ? uint4(0, (1 << (",
12203 to_expression(builtin_subgroup_invocation_id_id), " - 32)), uint2(0)) : uint4(1 << ",
12204 to_expression(builtin_subgroup_invocation_id_id), ", uint3(0));");
12205 }
12206 });
12207 break;
12208 case BuiltInSubgroupGeMask:
12209 if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 2))
12210 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.2 on iOS.");
12211 if (!msl_options.supports_msl_version(2, 1))
12212 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
12213 if (msl_options.fixed_subgroup_size != 0)
12214 add_spv_func_and_recompile(SPVFuncImplSubgroupBallot);
12215 entry_func.fixup_hooks_in.push_back([=]() {
12216 // Case where index < 32, size < 32:
12217 // mask0 = bfi(0, 0xFFFFFFFF, index, size - index);
12218 // mask1 = bfi(0, 0xFFFFFFFF, 0, 0); // Gives 0
12219 // Case where index < 32 but size >= 32:
12220 // mask0 = bfi(0, 0xFFFFFFFF, index, 32 - index);
12221 // mask1 = bfi(0, 0xFFFFFFFF, 0, size - 32);
12222 // Case where index >= 32:
12223 // mask0 = bfi(0, 0xFFFFFFFF, 32, 0); // Gives 0
12224 // mask1 = bfi(0, 0xFFFFFFFF, index - 32, size - index);
12225 // This is expressed without branches to avoid divergent
12226 // control flow--hence the complicated min/max expressions.
12227 // This is further complicated by the fact that if you attempt
12228 // to bfi/bfe out-of-bounds on Metal, undefined behavior is the
12229 // result.
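					// Worked example (illustrative): index = 40, size = 64 falls in the
					// "index >= 32" case:
					//   mask0 = bfi(0, 0xFFFFFFFF, min(40, 32) = 32, max(32 - 40, 0) = 0) = 0
					//   mask1 = bfi(0, 0xFFFFFFFF, 40 - 32 = 8, 64 - 40 = 24) -> bits 8..31 set
					// so the GE mask covers invocations 40..63 via bits 8..31 of the second word.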
12230 if (msl_options.fixed_subgroup_size > 32)
12231 {
12232 // Don't use the subgroup size variable with fixed subgroup sizes,
12233 // since the variables could be defined in the wrong order.
12234 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
12235 " = uint4(insert_bits(0u, 0xFFFFFFFF, min(",
12236 to_expression(builtin_subgroup_invocation_id_id), ", 32u), (uint)max(32 - (int)",
12237 to_expression(builtin_subgroup_invocation_id_id),
12238 ", 0)), insert_bits(0u, 0xFFFFFFFF,"
12239 " (uint)max((int)",
12240 to_expression(builtin_subgroup_invocation_id_id), " - 32, 0), ",
12241 msl_options.fixed_subgroup_size, " - max(",
12242 to_expression(builtin_subgroup_invocation_id_id),
12243 ", 32u)), uint2(0));");
12244 }
12245 else if (msl_options.fixed_subgroup_size != 0)
12246 {
12247 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
12248 " = uint4(insert_bits(0u, 0xFFFFFFFF, ",
12249 to_expression(builtin_subgroup_invocation_id_id), ", ",
12250 msl_options.fixed_subgroup_size, " - ",
12251 to_expression(builtin_subgroup_invocation_id_id),
12252 "), uint3(0));");
12253 }
12254 else if (msl_options.is_ios())
12255 {
12256 // On iOS, the SIMD-group size will currently never exceed 32.
12257 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
12258 " = uint4(insert_bits(0u, 0xFFFFFFFF, ",
12259 to_expression(builtin_subgroup_invocation_id_id), ", ",
12260 to_expression(builtin_subgroup_size_id), " - ",
12261 to_expression(builtin_subgroup_invocation_id_id), "), uint3(0));");
12262 }
12263 else
12264 {
12265 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
12266 " = uint4(insert_bits(0u, 0xFFFFFFFF, min(",
12267 to_expression(builtin_subgroup_invocation_id_id), ", 32u), (uint)max(min((int)",
12268 to_expression(builtin_subgroup_size_id), ", 32) - (int)",
12269 to_expression(builtin_subgroup_invocation_id_id),
12270 ", 0)), insert_bits(0u, 0xFFFFFFFF, (uint)max((int)",
12271 to_expression(builtin_subgroup_invocation_id_id), " - 32, 0), (uint)max((int)",
12272 to_expression(builtin_subgroup_size_id), " - (int)max(",
12273 to_expression(builtin_subgroup_invocation_id_id), ", 32u), 0)), uint2(0));");
12274 }
12275 });
12276 break;
12277 case BuiltInSubgroupGtMask:
12278 if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 2))
12279 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.2 on iOS.");
12280 if (!msl_options.supports_msl_version(2, 1))
12281 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
12282 add_spv_func_and_recompile(SPVFuncImplSubgroupBallot);
12283 entry_func.fixup_hooks_in.push_back([=]() {
12284 // The same logic applies here, except now the index is one
12285 // more than the subgroup invocation ID.
12286 if (msl_options.fixed_subgroup_size > 32)
12287 {
12288 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
12289 " = uint4(insert_bits(0u, 0xFFFFFFFF, min(",
12290 to_expression(builtin_subgroup_invocation_id_id), " + 1, 32u), (uint)max(32 - (int)",
12291 to_expression(builtin_subgroup_invocation_id_id),
12292 " - 1, 0)), insert_bits(0u, 0xFFFFFFFF, (uint)max((int)",
12293 to_expression(builtin_subgroup_invocation_id_id), " + 1 - 32, 0), ",
12294 msl_options.fixed_subgroup_size, " - max(",
12295 to_expression(builtin_subgroup_invocation_id_id),
12296 " + 1, 32u)), uint2(0));");
12297 }
12298 else if (msl_options.fixed_subgroup_size != 0)
12299 {
12300 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
12301 " = uint4(insert_bits(0u, 0xFFFFFFFF, ",
12302 to_expression(builtin_subgroup_invocation_id_id), " + 1, ",
12303 msl_options.fixed_subgroup_size, " - ",
12304 to_expression(builtin_subgroup_invocation_id_id),
12305 " - 1), uint3(0));");
12306 }
12307 else if (msl_options.is_ios())
12308 {
12309 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
12310 " = uint4(insert_bits(0u, 0xFFFFFFFF, ",
12311 to_expression(builtin_subgroup_invocation_id_id), " + 1, ",
12312 to_expression(builtin_subgroup_size_id), " - ",
12313 to_expression(builtin_subgroup_invocation_id_id), " - 1), uint3(0));");
12314 }
12315 else
12316 {
12317 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
12318 " = uint4(insert_bits(0u, 0xFFFFFFFF, min(",
12319 to_expression(builtin_subgroup_invocation_id_id), " + 1, 32u), (uint)max(min((int)",
12320 to_expression(builtin_subgroup_size_id), ", 32) - (int)",
12321 to_expression(builtin_subgroup_invocation_id_id),
12322 " - 1, 0)), insert_bits(0u, 0xFFFFFFFF, (uint)max((int)",
12323 to_expression(builtin_subgroup_invocation_id_id), " + 1 - 32, 0), (uint)max((int)",
12324 to_expression(builtin_subgroup_size_id), " - (int)max(",
12325 to_expression(builtin_subgroup_invocation_id_id), " + 1, 32u), 0)), uint2(0));");
12326 }
12327 });
12328 break;
12329 case BuiltInSubgroupLeMask:
12330 if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 2))
12331 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.2 on iOS.");
12332 if (!msl_options.supports_msl_version(2, 1))
12333 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
12334 add_spv_func_and_recompile(SPVFuncImplSubgroupBallot);
12335 entry_func.fixup_hooks_in.push_back([=]() {
12336 if (msl_options.is_ios())
12337 {
12338 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
12339 " = uint4(extract_bits(0xFFFFFFFF, 0, ",
12340 to_expression(builtin_subgroup_invocation_id_id), " + 1), uint3(0));");
12341 }
12342 else
12343 {
12344 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
12345 " = uint4(extract_bits(0xFFFFFFFF, 0, min(",
12346 to_expression(builtin_subgroup_invocation_id_id),
12347 " + 1, 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)",
12348 to_expression(builtin_subgroup_invocation_id_id), " + 1 - 32, 0)), uint2(0));");
12349 }
12350 });
12351 break;
12352 case BuiltInSubgroupLtMask:
12353 if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 2))
12354 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.2 on iOS.");
12355 if (!msl_options.supports_msl_version(2, 1))
12356 SPIRV_CROSS_THROW("Subgroup ballot functionality requires Metal 2.1.");
12357 add_spv_func_and_recompile(SPVFuncImplSubgroupBallot);
12358 entry_func.fixup_hooks_in.push_back([=]() {
12359 if (msl_options.is_ios())
12360 {
12361 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
12362 " = uint4(extract_bits(0xFFFFFFFF, 0, ",
12363 to_expression(builtin_subgroup_invocation_id_id), "), uint3(0));");
12364 }
12365 else
12366 {
12367 statement(builtin_type_decl(bi_type), " ", to_expression(var_id),
12368 " = uint4(extract_bits(0xFFFFFFFF, 0, min(",
12369 to_expression(builtin_subgroup_invocation_id_id),
12370 ", 32u)), extract_bits(0xFFFFFFFF, 0, (uint)max((int)",
12371 to_expression(builtin_subgroup_invocation_id_id), " - 32, 0)), uint2(0));");
12372 }
12373 });
12374 break;
12375 case BuiltInViewIndex:
12376 if (!msl_options.multiview)
12377 {
12378 // According to the Vulkan spec, when not running under a multiview
12379 // render pass, ViewIndex is 0.
12380 entry_func.fixup_hooks_in.push_back([=]() {
12381 statement("const ", builtin_type_decl(bi_type), " ", to_expression(var_id), " = 0;");
12382 });
12383 }
12384 else if (msl_options.view_index_from_device_index)
12385 {
12386 // In this case, we take the view index from that of the device we're running on.
12387 entry_func.fixup_hooks_in.push_back([=]() {
12388 statement("const ", builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
12389 msl_options.device_index, ";");
12390 });
12391 // We actually don't want to set the render_target_array_index here.
12392 // Since every physical device is rendering a different view,
12393 // there's no need for layered rendering here.
12394 }
12395 else if (!msl_options.multiview_layered_rendering)
12396 {
12397 // In this case, the views are rendered one at a time. The view index, then,
12398 // is just the first part of the "view mask".
12399 entry_func.fixup_hooks_in.push_back([=]() {
12400 statement("const ", builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
12401 to_expression(view_mask_buffer_id), "[0];");
12402 });
12403 }
12404 else if (get_execution_model() == ExecutionModelFragment)
12405 {
12406 // Because we adjusted the view index in the vertex shader, we have to
12407 // adjust it back here.
12408 entry_func.fixup_hooks_in.push_back([=]() {
12409 statement(to_expression(var_id), " += ", to_expression(view_mask_buffer_id), "[0];");
12410 });
12411 }
12412 else if (get_execution_model() == ExecutionModelVertex)
12413 {
12414 // Metal provides no special support for multiview, so we smuggle
12415 // the view index in the instance index.
12416 entry_func.fixup_hooks_in.push_back([=]() {
12417 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
12418 to_expression(view_mask_buffer_id), "[0] + (", to_expression(builtin_instance_idx_id),
12419 " - ", to_expression(builtin_base_instance_id), ") % ",
12420 to_expression(view_mask_buffer_id), "[1];");
12421 statement(to_expression(builtin_instance_idx_id), " = (",
12422 to_expression(builtin_instance_idx_id), " - ",
12423 to_expression(builtin_base_instance_id), ") / ", to_expression(view_mask_buffer_id),
12424 "[1] + ", to_expression(builtin_base_instance_id), ";");
12425 });
12426 // In addition to setting the variable itself, we also need to
12427 // set the render_target_array_index with it on output. We have to
12428 // offset this by the base view index, because Metal isn't in on
12429 // our little game here.
12430 entry_func.fixup_hooks_out.push_back([=]() {
12431 statement(to_expression(builtin_layer_id), " = ", to_expression(var_id), " - ",
12432 to_expression(view_mask_buffer_id), "[0];");
12433 });
12434 }
12435 break;
12436 case BuiltInDeviceIndex:
12437 // Metal pipelines belong to the devices which create them, so we'll
12438 // need to create a MTLPipelineState for every MTLDevice in a grouped
12439 // VkDevice. We can assume, then, that the device index is constant.
12440 entry_func.fixup_hooks_in.push_back([=]() {
12441 statement("const ", builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
12442 msl_options.device_index, ";");
12443 });
12444 break;
12445 case BuiltInWorkgroupId:
12446 if (!msl_options.dispatch_base || !active_input_builtins.get(BuiltInWorkgroupId))
12447 break;
12448
12449 // The vkCmdDispatchBase() command lets the client set the base value
12450 // of WorkgroupId. Metal has no direct equivalent; we must make this
12451 // adjustment ourselves.
12452 entry_func.fixup_hooks_in.push_back([=]() {
12453 statement(to_expression(var_id), " += ", to_dereferenced_expression(builtin_dispatch_base_id), ";");
12454 });
12455 break;
12456 case BuiltInGlobalInvocationId:
12457 if (!msl_options.dispatch_base || !active_input_builtins.get(BuiltInGlobalInvocationId))
12458 break;
12459
12460 // GlobalInvocationId is defined as LocalInvocationId + WorkgroupId * WorkgroupSize.
12461 // This needs to be adjusted too.
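				// With a literal workgroup size of, say, (8, 8, 1), this emits roughly:
				//   gl_GlobalInvocationID += spvDispatchBase * uint3(8, 8, 1);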
12462 entry_func.fixup_hooks_in.push_back([=]() {
12463 auto &execution = this->get_entry_point();
12464 uint32_t workgroup_size_id = execution.workgroup_size.constant;
12465 if (workgroup_size_id)
12466 statement(to_expression(var_id), " += ", to_dereferenced_expression(builtin_dispatch_base_id),
12467 " * ", to_expression(workgroup_size_id), ";");
12468 else
12469 statement(to_expression(var_id), " += ", to_dereferenced_expression(builtin_dispatch_base_id),
12470 " * uint3(", execution.workgroup_size.x, ", ", execution.workgroup_size.y, ", ",
12471 execution.workgroup_size.z, ");");
12472 });
12473 break;
12474 case BuiltInVertexId:
12475 case BuiltInVertexIndex:
12476 // This is direct-mapped normally.
12477 if (!msl_options.vertex_for_tessellation)
12478 break;
12479
12480 entry_func.fixup_hooks_in.push_back([=]() {
12481 builtin_declaration = true;
12482 switch (msl_options.vertex_index_type)
12483 {
12484 case Options::IndexType::None:
12485 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
12486 to_expression(builtin_invocation_id_id), ".x + ",
12487 to_expression(builtin_dispatch_base_id), ".x;");
12488 break;
12489 case Options::IndexType::UInt16:
12490 case Options::IndexType::UInt32:
12491 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ", index_buffer_var_name,
12492 "[", to_expression(builtin_invocation_id_id), ".x] + ",
12493 to_expression(builtin_dispatch_base_id), ".x;");
12494 break;
12495 }
12496 builtin_declaration = false;
12497 });
12498 break;
12499 case BuiltInBaseVertex:
12500 // This is direct-mapped normally.
12501 if (!msl_options.vertex_for_tessellation)
12502 break;
12503
12504 entry_func.fixup_hooks_in.push_back([=]() {
12505 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
12506 to_expression(builtin_dispatch_base_id), ".x;");
12507 });
12508 break;
12509 case BuiltInInstanceId:
12510 case BuiltInInstanceIndex:
12511 // This is direct-mapped normally.
12512 if (!msl_options.vertex_for_tessellation)
12513 break;
12514
12515 entry_func.fixup_hooks_in.push_back([=]() {
12516 builtin_declaration = true;
12517 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
12518 to_expression(builtin_invocation_id_id), ".y + ", to_expression(builtin_dispatch_base_id),
12519 ".y;");
12520 builtin_declaration = false;
12521 });
12522 break;
12523 case BuiltInBaseInstance:
12524 // This is direct-mapped normally.
12525 if (!msl_options.vertex_for_tessellation)
12526 break;
12527
12528 entry_func.fixup_hooks_in.push_back([=]() {
12529 statement(builtin_type_decl(bi_type), " ", to_expression(var_id), " = ",
12530 to_expression(builtin_dispatch_base_id), ".y;");
12531 });
12532 break;
12533 default:
12534 break;
12535 }
12536 }
12537 else if (var.storage == StorageClassOutput && get_execution_model() == ExecutionModelFragment &&
12538 is_builtin_variable(var) && active_output_builtins.get(bi_type) &&
12539 bi_type == BuiltInSampleMask && has_additional_fixed_sample_mask())
12540 {
12541 // If the additional fixed sample mask was set, we need to adjust the sample_mask
12542 // output to reflect that. If the shader outputs the sample_mask itself too, we need
12543 // to AND the two masks to get the final one.
12544 string op_str = does_shader_write_sample_mask ? " &= " : " = ";
12545 entry_func.fixup_hooks_out.push_back([=]() {
12546 statement(to_expression(builtin_sample_mask_id), op_str, additional_fixed_sample_mask_str(), ";");
12547 });
12548 }
12549 });
12550}
12551
12552// Returns the Metal index of the resource of the specified type as used by the specified variable.
12553uint32_t CompilerMSL::get_metal_resource_index(SPIRVariable &var, SPIRType::BaseType basetype, uint32_t plane)
12554{
12555 auto &execution = get_entry_point();
12556 auto &var_dec = ir.meta[var.self].decoration;
12557 auto &var_type = get<SPIRType>(var.basetype);
12558 uint32_t var_desc_set = (var.storage == StorageClassPushConstant) ? kPushConstDescSet : var_dec.set;
12559 uint32_t var_binding = (var.storage == StorageClassPushConstant) ? kPushConstBinding : var_dec.binding;
12560
12561 // If a matching binding has been specified, find and use it.
12562 auto itr = resource_bindings.find({ execution.model, var_desc_set, var_binding });
12563
12564 // Atomic helper buffers for image atomics need to use secondary bindings as well.
12565 bool use_secondary_binding = (var_type.basetype == SPIRType::SampledImage && basetype == SPIRType::Sampler) ||
12566 basetype == SPIRType::AtomicCounter;
12567
12568 auto resource_decoration =
12569 use_secondary_binding ? SPIRVCrossDecorationResourceIndexSecondary : SPIRVCrossDecorationResourceIndexPrimary;
12570
12571 if (plane == 1)
12572 resource_decoration = SPIRVCrossDecorationResourceIndexTertiary;
12573 if (plane == 2)
12574 resource_decoration = SPIRVCrossDecorationResourceIndexQuaternary;
12575
12576 if (itr != end(resource_bindings))
12577 {
12578 auto &remap = itr->second;
12579 remap.second = true;
12580 switch (basetype)
12581 {
12582 case SPIRType::Image:
12583 set_extended_decoration(var.self, resource_decoration, remap.first.msl_texture + plane);
12584 return remap.first.msl_texture + plane;
12585 case SPIRType::Sampler:
12586 set_extended_decoration(var.self, resource_decoration, remap.first.msl_sampler);
12587 return remap.first.msl_sampler;
12588 default:
12589 set_extended_decoration(var.self, resource_decoration, remap.first.msl_buffer);
12590 return remap.first.msl_buffer;
12591 }
12592 }
12593
12594 // If we have already allocated an index, keep using it.
12595 if (has_extended_decoration(var.self, resource_decoration))
12596 return get_extended_decoration(var.self, resource_decoration);
12597
12598 auto &type = get<SPIRType>(var.basetype);
12599
12600 if (type_is_msl_framebuffer_fetch(type))
12601 {
12602 // Frame-buffer fetch gets its fallback resource index from the input attachment index,
12603 // which is then treated as color index.
12604 return get_decoration(var.self, DecorationInputAttachmentIndex);
12605 }
12606 else if (msl_options.enable_decoration_binding)
12607 {
12608 // Allow user to enable decoration binding.
12609 // If there is no explicit mapping of bindings to MSL, use the declared binding as a fallback.
12610 if (has_decoration(var.self, DecorationBinding))
12611 {
12612 var_binding = get_decoration(var.self, DecorationBinding);
12613 // Avoid emitting sentinel bindings.
12614 if (var_binding < 0x80000000u)
12615 return var_binding;
12616 }
12617 }
12618
12619 // If we did not explicitly remap, allocate bindings on demand.
12620 // We cannot reliably use Binding decorations since SPIR-V and MSL's binding models are very different.
12621
12622 bool allocate_argument_buffer_ids = false;
12623
12624 if (var.storage != StorageClassPushConstant)
12625 allocate_argument_buffer_ids = descriptor_set_is_argument_buffer(var_desc_set);
12626
12627 uint32_t binding_stride = 1;
12628 for (uint32_t i = 0; i < uint32_t(type.array.size()); i++)
12629 binding_stride *= to_array_size_literal(type, i);
12630
12631 assert(binding_stride != 0);
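	// e.g. a descriptor declared as "texture2d<float> t[4]" has a binding_stride of 4,
	// so it claims four consecutive [[texture(n)]] slots from the counters below.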
12632
12633 // If a binding has not been specified, revert to incrementing resource indices.
12634 uint32_t resource_index;
12635
12636 if (allocate_argument_buffer_ids)
12637 {
12638 // Allocate from a flat ID binding space.
12639 resource_index = next_metal_resource_ids[var_desc_set];
12640 next_metal_resource_ids[var_desc_set] += binding_stride;
12641 }
12642 else
12643 {
12644 // Allocate from plain bindings which are allocated per resource type.
12645 switch (basetype)
12646 {
12647 case SPIRType::Image:
12648 resource_index = next_metal_resource_index_texture;
12649 next_metal_resource_index_texture += binding_stride;
12650 break;
12651 case SPIRType::Sampler:
12652 resource_index = next_metal_resource_index_sampler;
12653 next_metal_resource_index_sampler += binding_stride;
12654 break;
12655 default:
12656 resource_index = next_metal_resource_index_buffer;
12657 next_metal_resource_index_buffer += binding_stride;
12658 break;
12659 }
12660 }
12661
12662 set_extended_decoration(var.self, resource_decoration, resource_index);
12663 return resource_index;
12664}
12665
12666bool CompilerMSL::type_is_msl_framebuffer_fetch(const SPIRType &type) const
12667{
12668 return type.basetype == SPIRType::Image && type.image.dim == DimSubpassData &&
12669 msl_options.use_framebuffer_fetch_subpasses;
12670}
12671
12672bool CompilerMSL::type_is_pointer(const SPIRType &type) const
12673{
12674 if (!type.pointer)
12675 return false;
12676 auto &parent_type = get<SPIRType>(type.parent_type);
12677 // Safeguards when we forget to set pointer_depth (there is an assert for it in type_to_glsl),
12678 // but the extra check shouldn't hurt.
12679 return (type.pointer_depth > parent_type.pointer_depth) || !parent_type.pointer;
12680}
12681
12682bool CompilerMSL::type_is_pointer_to_pointer(const SPIRType &type) const
12683{
12684 if (!type.pointer)
12685 return false;
12686 auto &parent_type = get<SPIRType>(type.parent_type);
12687 return type.pointer_depth > parent_type.pointer_depth && type_is_pointer(parent_type);
12688}
12689
12690const char *CompilerMSL::descriptor_address_space(uint32_t id, StorageClass storage, const char *plain_address_space) const
12691{
12692 if (msl_options.argument_buffers)
12693 {
12694 bool storage_class_is_descriptor = storage == StorageClassUniform ||
12695 storage == StorageClassStorageBuffer ||
12696 storage == StorageClassUniformConstant;
12697
12698 uint32_t desc_set = get_decoration(id, DecorationDescriptorSet);
12699 if (storage_class_is_descriptor && descriptor_set_is_argument_buffer(desc_set))
12700 {
12701 // An awkward case where we need to emit *more* address space declarations (yay!).
12702 // An example is where we pass down an array of buffer pointers to leaf functions.
12703 // It's a constant array containing pointers to constants.
12704 // The pointer array is always constant however. E.g.
12705 // device SSBO * constant (&array)[N].
12706 // const device SSBO * constant (&array)[N].
12707 // constant SSBO * constant (&array)[N].
12708 // However, this only matters for argument buffers, since for MSL 1.0 style codegen,
12709 // we emit the buffer array on stack instead, and that seems to work just fine apparently.
12710
12711 // If the argument was marked as being in device address space, any pointer to member would
12712 // be const device, not constant.
12713 if (argument_buffer_device_storage_mask & (1u << desc_set))
12714 return "const device";
12715 else
12716 return "constant";
12717 }
12718 }
12719
12720 return plain_address_space;
12721}
12722
12723string CompilerMSL::argument_decl(const SPIRFunction::Parameter &arg)
12724{
12725 auto &var = get<SPIRVariable>(arg.id);
12726 auto &type = get_variable_data_type(var);
12727 auto &var_type = get<SPIRType>(arg.type);
12728 StorageClass type_storage = var_type.storage;
12729 bool is_pointer = var_type.pointer;
12730
12731 // If we need to modify the name of the variable, make sure we use the original variable.
12732 // Our alias is just a shadow variable.
12733 uint32_t name_id = var.self;
12734 if (arg.alias_global_variable && var.basevariable)
12735 name_id = var.basevariable;
12736
12737 bool constref = !arg.alias_global_variable && is_pointer && arg.write_count == 0;
	// Framebuffer fetch is a plain value; const would look out of place there, though it would not be wrong.
12739 if (type_is_msl_framebuffer_fetch(type))
12740 constref = false;
12741 else if (type_storage == StorageClassUniformConstant)
12742 constref = true;
12743
12744 bool type_is_image = type.basetype == SPIRType::Image || type.basetype == SPIRType::SampledImage ||
12745 type.basetype == SPIRType::Sampler;
12746
12747 // For opaque types we handle const later due to descriptor address spaces.
12748 const char *cv_qualifier = (constref && !type_is_image) ? "const " : "";
12749 string decl;
12750
12751 // If this is a combined image-sampler for a 2D image with floating-point type,
12752 // we emitted the 'spvDynamicImageSampler' type, and this is *not* an alias parameter
12753 // for a global, then we need to emit a "dynamic" combined image-sampler.
12754 // Unfortunately, this is necessary to properly support passing around
12755 // combined image-samplers with Y'CbCr conversions on them.
12756 bool is_dynamic_img_sampler = !arg.alias_global_variable && type.basetype == SPIRType::SampledImage &&
12757 type.image.dim == Dim2D && type_is_floating_point(get<SPIRType>(type.image.type)) &&
12758 spv_function_implementations.count(SPVFuncImplDynamicImageSampler);
12759
	// The logic below decides whether Metal can use the array<T> template to make
	// arrays a value type, or whether a plain builtin array must be used instead.
12761 string address_space = get_argument_address_space(var);
12762 bool builtin = has_decoration(var.self, DecorationBuiltIn);
12763 auto builtin_type = BuiltIn(get_decoration(arg.id, DecorationBuiltIn));
12764
12765 if (address_space == "threadgroup")
12766 is_using_builtin_array = true;
12767
12768 if (var.basevariable && (var.basevariable == stage_in_ptr_var_id || var.basevariable == stage_out_ptr_var_id))
12769 decl = join(cv_qualifier, type_to_glsl(type, arg.id));
12770 else if (builtin)
12771 {
		// Only use templated array for Clip/Cull distance when feasible.
		// In other scenarios, we need to override array length for tess levels (if used as outputs),
		// or we need to emit the expected type for builtins (uint vs int).
12775 auto storage = get<SPIRType>(var.basetype).storage;
12776
12777 if (storage == StorageClassInput &&
12778 (builtin_type == BuiltInTessLevelInner || builtin_type == BuiltInTessLevelOuter))
12779 {
12780 is_using_builtin_array = false;
12781 }
12782 else if (builtin_type != BuiltInClipDistance && builtin_type != BuiltInCullDistance)
12783 {
12784 is_using_builtin_array = true;
12785 }
12786
12787 if (storage == StorageClassOutput && variable_storage_requires_stage_io(storage) &&
12788 !is_stage_output_builtin_masked(builtin_type))
12789 is_using_builtin_array = true;
12790
12791 if (is_using_builtin_array)
12792 decl = join(cv_qualifier, builtin_type_decl(builtin_type, arg.id));
12793 else
12794 decl = join(cv_qualifier, type_to_glsl(type, arg.id));
12795 }
12796 else if ((type_storage == StorageClassUniform || type_storage == StorageClassStorageBuffer) && is_array(type))
12797 {
12798 is_using_builtin_array = true;
12799 decl += join(cv_qualifier, type_to_glsl(type, arg.id), "*");
12800 }
12801 else if (is_dynamic_img_sampler)
12802 {
12803 decl = join(cv_qualifier, "spvDynamicImageSampler<", type_to_glsl(get<SPIRType>(type.image.type)), ">");
12804 // Mark the variable so that we can handle passing it to another function.
12805 set_extended_decoration(arg.id, SPIRVCrossDecorationDynamicImageSampler);
12806 }
12807 else
12808 {
		// The type is a pointer type; we need to emit the cv_qualifier late.
12810 if (type_is_pointer(type))
12811 {
12812 decl = type_to_glsl(type, arg.id);
12813 if (*cv_qualifier != '\0')
12814 decl += join(" ", cv_qualifier);
12815 }
12816 else
12817 decl = join(cv_qualifier, type_to_glsl(type, arg.id));
12818 }
12819
12820 if (!builtin && !is_pointer &&
12821 (type_storage == StorageClassFunction || type_storage == StorageClassGeneric))
12822 {
12823 // If the argument is a pure value and not an opaque type, we will pass by value.
12824 if (msl_options.force_native_arrays && is_array(type))
12825 {
12826 // We are receiving an array by value. This is problematic.
12827 // We cannot be sure of the target address space since we are supposed to receive a copy,
12828 // but this is not possible with MSL without some extra work.
12829 // We will have to assume we're getting a reference in thread address space.
12830 // If we happen to get a reference in constant address space, the caller must emit a copy and pass that.
12831 // Thread const therefore becomes the only logical choice, since we cannot "create" a constant array from
12832 // non-constant arrays, but we can create thread const from constant.
12833 decl = string("thread const ") + decl;
12834 decl += " (&";
12835 const char *restrict_kw = to_restrict(name_id);
12836 if (*restrict_kw)
12837 {
12838 decl += " ";
12839 decl += restrict_kw;
12840 }
12841 decl += to_expression(name_id);
12842 decl += ")";
12843 decl += type_to_array_glsl(type);
12844 }
12845 else
12846 {
12847 if (!address_space.empty())
12848 decl = join(address_space, " ", decl);
12849 decl += " ";
12850 decl += to_expression(name_id);
12851 }
12852 }
12853 else if (is_array(type) && !type_is_image)
12854 {
12855 // Arrays of opaque types are special cased.
12856 if (!address_space.empty())
12857 decl = join(address_space, " ", decl);
12858
12859 const char *argument_buffer_space = descriptor_address_space(name_id, type_storage, nullptr);
12860 if (argument_buffer_space)
12861 {
12862 decl += " ";
12863 decl += argument_buffer_space;
12864 }
12865
12866 // Special case, need to override the array size here if we're using tess level as an argument.
12867 if (get_execution_model() == ExecutionModelTessellationControl && builtin &&
12868 (builtin_type == BuiltInTessLevelInner || builtin_type == BuiltInTessLevelOuter))
12869 {
12870 uint32_t array_size = get_physical_tess_level_array_size(builtin_type);
12871 if (array_size == 1)
12872 {
12873 decl += " &";
12874 decl += to_expression(name_id);
12875 }
12876 else
12877 {
12878 decl += " (&";
12879 decl += to_expression(name_id);
12880 decl += ")";
12881 decl += join("[", array_size, "]");
12882 }
12883 }
12884 else
12885 {
12886 auto array_size_decl = type_to_array_glsl(type);
12887 if (array_size_decl.empty())
12888 decl += "& ";
12889 else
12890 decl += " (&";
12891
12892 const char *restrict_kw = to_restrict(name_id);
12893 if (*restrict_kw)
12894 {
12895 decl += " ";
12896 decl += restrict_kw;
12897 }
12898 decl += to_expression(name_id);
12899
12900 if (!array_size_decl.empty())
12901 {
12902 decl += ")";
12903 decl += array_size_decl;
12904 }
12905 }
12906 }
12907 else if (!type_is_image && (!pull_model_inputs.count(var.basevariable) || type.basetype == SPIRType::Struct))
12908 {
12909 // If this is going to be a reference to a variable pointer, the address space
12910 // for the reference has to go before the '&', but after the '*'.
12911 if (!address_space.empty())
12912 {
12913 if (type_is_pointer(type))
12914 {
12915 if (*cv_qualifier == '\0')
12916 decl += ' ';
12917 decl += join(address_space, " ");
12918 }
12919 else
12920 decl = join(address_space, " ", decl);
12921 }
12922 decl += "&";
12923 decl += " ";
12924 decl += to_restrict(name_id);
12925 decl += to_expression(name_id);
12926 }
12927 else if (type_is_image)
12928 {
12929 if (type.array.empty())
12930 {
12931 // For non-arrayed types we can just pass opaque descriptors by value.
12932 // This fixes problems if descriptors are passed by value from argument buffers and plain descriptors
12933 // in same shader.
12934 // There is no address space we can actually use, but value will work.
12935 // This will break if applications attempt to pass down descriptor arrays as arguments, but
12936 // fortunately that is extremely unlikely ...
12937 decl += " ";
12938 decl += to_expression(name_id);
12939 }
12940 else
12941 {
12942 const char *img_address_space = descriptor_address_space(name_id, type_storage, "thread const");
12943 decl = join(img_address_space, " ", decl);
12944 decl += "& ";
12945 decl += to_expression(name_id);
12946 }
12947 }
12948 else
12949 {
12950 if (!address_space.empty())
12951 decl = join(address_space, " ", decl);
12952 decl += " ";
12953 decl += to_expression(name_id);
12954 }
12955
12956 // Emulate texture2D atomic operations
12957 auto *backing_var = maybe_get_backing_variable(name_id);
12958 if (backing_var && atomic_image_vars.count(backing_var->self))
12959 {
12960 decl += ", device atomic_" + type_to_glsl(get<SPIRType>(var_type.image.type), 0);
12961 decl += "* " + to_expression(name_id) + "_atomic";
12962 }
12963
12964 is_using_builtin_array = false;
12965
12966 return decl;
12967}
12968
// If we're currently in the entry point function and the object has a
// qualified name, use it; otherwise use the standard name.
12971string CompilerMSL::to_name(uint32_t id, bool allow_alias) const
12972{
12973 if (current_function && (current_function->self == ir.default_entry_point))
12974 {
12975 auto *m = ir.find_meta(id);
12976 if (m && !m->decoration.qualified_alias.empty())
12977 return m->decoration.qualified_alias;
12978 }
12979 return Compiler::to_name(id, allow_alias);
12980}
12981
// Returns a name that combines the name of the struct with the name of the member, except for builtins.
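// E.g. member "_mbr" of struct "S" yields "S_mbr", while a builtin member simply yields its builtin name.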
12983string CompilerMSL::to_qualified_member_name(const SPIRType &type, uint32_t index)
12984{
12985 // Don't qualify Builtin names because they are unique and are treated as such when building expressions
12986 BuiltIn builtin = BuiltInMax;
12987 if (is_member_builtin(type, index, &builtin))
12988 return builtin_to_glsl(builtin, type.storage);
12989
12990 // Strip any underscore prefix from member name
12991 string mbr_name = to_member_name(type, index);
12992 size_t startPos = mbr_name.find_first_not_of("_");
12993 mbr_name = (startPos != string::npos) ? mbr_name.substr(startPos) : "";
12994 return join(to_name(type.self), "_", mbr_name);
12995}
12996
12997// Ensures that the specified name is permanently usable by prepending a prefix
12998// if the first chars are _ and a digit, which indicate a transient name.
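// E.g. ensure_valid_name("_37", "m") returns "m_37", while "foo" is returned unchanged.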
12999string CompilerMSL::ensure_valid_name(string name, string pfx)
13000{
13001 return (name.size() >= 2 && name[0] == '_' && isdigit(name[1])) ? (pfx + name) : name;
13002}
13003
13004const std::unordered_set<std::string> &CompilerMSL::get_reserved_keyword_set()
13005{
13006 static const unordered_set<string> keywords = {
13007 "kernel",
13008 "vertex",
13009 "fragment",
13010 "compute",
13011 "bias",
13012 "level",
13013 "gradient2d",
13014 "gradientcube",
13015 "gradient3d",
13016 "min_lod_clamp",
13017 "assert",
13018 "VARIABLE_TRACEPOINT",
13019 "STATIC_DATA_TRACEPOINT",
13020 "STATIC_DATA_TRACEPOINT_V",
13021 "METAL_ALIGN",
13022 "METAL_ASM",
13023 "METAL_CONST",
13024 "METAL_DEPRECATED",
13025 "METAL_ENABLE_IF",
13026 "METAL_FUNC",
13027 "METAL_INTERNAL",
13028 "METAL_NON_NULL_RETURN",
13029 "METAL_NORETURN",
13030 "METAL_NOTHROW",
13031 "METAL_PURE",
13032 "METAL_UNAVAILABLE",
13033 "METAL_IMPLICIT",
13034 "METAL_EXPLICIT",
13035 "METAL_CONST_ARG",
13036 "METAL_ARG_UNIFORM",
13037 "METAL_ZERO_ARG",
13038 "METAL_VALID_LOD_ARG",
13039 "METAL_VALID_LEVEL_ARG",
13040 "METAL_VALID_STORE_ORDER",
13041 "METAL_VALID_LOAD_ORDER",
13042 "METAL_VALID_COMPARE_EXCHANGE_FAILURE_ORDER",
13043 "METAL_COMPATIBLE_COMPARE_EXCHANGE_ORDERS",
13044 "METAL_VALID_RENDER_TARGET",
13045 "is_function_constant_defined",
13046 "CHAR_BIT",
13047 "SCHAR_MAX",
13048 "SCHAR_MIN",
13049 "UCHAR_MAX",
13050 "CHAR_MAX",
13051 "CHAR_MIN",
13052 "USHRT_MAX",
13053 "SHRT_MAX",
13054 "SHRT_MIN",
13055 "UINT_MAX",
13056 "INT_MAX",
13057 "INT_MIN",
13058 "FLT_DIG",
13059 "FLT_MANT_DIG",
13060 "FLT_MAX_10_EXP",
13061 "FLT_MAX_EXP",
13062 "FLT_MIN_10_EXP",
13063 "FLT_MIN_EXP",
13064 "FLT_RADIX",
13065 "FLT_MAX",
13066 "FLT_MIN",
13067 "FLT_EPSILON",
13068 "FP_ILOGB0",
13069 "FP_ILOGBNAN",
13070 "MAXFLOAT",
13071 "HUGE_VALF",
13072 "INFINITY",
13073 "NAN",
13074 "M_E_F",
13075 "M_LOG2E_F",
13076 "M_LOG10E_F",
13077 "M_LN2_F",
13078 "M_LN10_F",
13079 "M_PI_F",
13080 "M_PI_2_F",
13081 "M_PI_4_F",
13082 "M_1_PI_F",
13083 "M_2_PI_F",
13084 "M_2_SQRTPI_F",
13085 "M_SQRT2_F",
13086 "M_SQRT1_2_F",
13087 "HALF_DIG",
13088 "HALF_MANT_DIG",
13089 "HALF_MAX_10_EXP",
13090 "HALF_MAX_EXP",
13091 "HALF_MIN_10_EXP",
13092 "HALF_MIN_EXP",
13093 "HALF_RADIX",
13094 "HALF_MAX",
13095 "HALF_MIN",
13096 "HALF_EPSILON",
13097 "MAXHALF",
13098 "HUGE_VALH",
13099 "M_E_H",
13100 "M_LOG2E_H",
13101 "M_LOG10E_H",
13102 "M_LN2_H",
13103 "M_LN10_H",
13104 "M_PI_H",
13105 "M_PI_2_H",
13106 "M_PI_4_H",
13107 "M_1_PI_H",
13108 "M_2_PI_H",
13109 "M_2_SQRTPI_H",
13110 "M_SQRT2_H",
13111 "M_SQRT1_2_H",
13112 "DBL_DIG",
13113 "DBL_MANT_DIG",
13114 "DBL_MAX_10_EXP",
13115 "DBL_MAX_EXP",
13116 "DBL_MIN_10_EXP",
13117 "DBL_MIN_EXP",
13118 "DBL_RADIX",
13119 "DBL_MAX",
13120 "DBL_MIN",
13121 "DBL_EPSILON",
13122 "HUGE_VAL",
13123 "M_E",
13124 "M_LOG2E",
13125 "M_LOG10E",
13126 "M_LN2",
13127 "M_LN10",
13128 "M_PI",
13129 "M_PI_2",
13130 "M_PI_4",
13131 "M_1_PI",
13132 "M_2_PI",
13133 "M_2_SQRTPI",
13134 "M_SQRT2",
13135 "M_SQRT1_2",
13136 "quad_broadcast",
13137 };
13138
13139 return keywords;
13140}
13141
13142const std::unordered_set<std::string> &CompilerMSL::get_illegal_func_names()
13143{
13144 static const unordered_set<string> illegal_func_names = {
13145 "main",
13146 "saturate",
13147 "assert",
13148 "fmin3",
13149 "fmax3",
13150 "VARIABLE_TRACEPOINT",
13151 "STATIC_DATA_TRACEPOINT",
13152 "STATIC_DATA_TRACEPOINT_V",
13153 "METAL_ALIGN",
13154 "METAL_ASM",
13155 "METAL_CONST",
13156 "METAL_DEPRECATED",
13157 "METAL_ENABLE_IF",
13158 "METAL_FUNC",
13159 "METAL_INTERNAL",
13160 "METAL_NON_NULL_RETURN",
13161 "METAL_NORETURN",
13162 "METAL_NOTHROW",
13163 "METAL_PURE",
13164 "METAL_UNAVAILABLE",
13165 "METAL_IMPLICIT",
13166 "METAL_EXPLICIT",
13167 "METAL_CONST_ARG",
13168 "METAL_ARG_UNIFORM",
13169 "METAL_ZERO_ARG",
13170 "METAL_VALID_LOD_ARG",
13171 "METAL_VALID_LEVEL_ARG",
13172 "METAL_VALID_STORE_ORDER",
13173 "METAL_VALID_LOAD_ORDER",
13174 "METAL_VALID_COMPARE_EXCHANGE_FAILURE_ORDER",
13175 "METAL_COMPATIBLE_COMPARE_EXCHANGE_ORDERS",
13176 "METAL_VALID_RENDER_TARGET",
13177 "is_function_constant_defined",
13178 "CHAR_BIT",
13179 "SCHAR_MAX",
13180 "SCHAR_MIN",
13181 "UCHAR_MAX",
13182 "CHAR_MAX",
13183 "CHAR_MIN",
13184 "USHRT_MAX",
13185 "SHRT_MAX",
13186 "SHRT_MIN",
13187 "UINT_MAX",
13188 "INT_MAX",
13189 "INT_MIN",
13190 "FLT_DIG",
13191 "FLT_MANT_DIG",
13192 "FLT_MAX_10_EXP",
13193 "FLT_MAX_EXP",
13194 "FLT_MIN_10_EXP",
13195 "FLT_MIN_EXP",
13196 "FLT_RADIX",
13197 "FLT_MAX",
13198 "FLT_MIN",
13199 "FLT_EPSILON",
13200 "FP_ILOGB0",
13201 "FP_ILOGBNAN",
13202 "MAXFLOAT",
13203 "HUGE_VALF",
13204 "INFINITY",
13205 "NAN",
13206 "M_E_F",
13207 "M_LOG2E_F",
13208 "M_LOG10E_F",
13209 "M_LN2_F",
13210 "M_LN10_F",
13211 "M_PI_F",
13212 "M_PI_2_F",
13213 "M_PI_4_F",
13214 "M_1_PI_F",
13215 "M_2_PI_F",
13216 "M_2_SQRTPI_F",
13217 "M_SQRT2_F",
13218 "M_SQRT1_2_F",
13219 "HALF_DIG",
13220 "HALF_MANT_DIG",
13221 "HALF_MAX_10_EXP",
13222 "HALF_MAX_EXP",
13223 "HALF_MIN_10_EXP",
13224 "HALF_MIN_EXP",
13225 "HALF_RADIX",
13226 "HALF_MAX",
13227 "HALF_MIN",
13228 "HALF_EPSILON",
13229 "MAXHALF",
13230 "HUGE_VALH",
13231 "M_E_H",
13232 "M_LOG2E_H",
13233 "M_LOG10E_H",
13234 "M_LN2_H",
13235 "M_LN10_H",
13236 "M_PI_H",
13237 "M_PI_2_H",
13238 "M_PI_4_H",
13239 "M_1_PI_H",
13240 "M_2_PI_H",
13241 "M_2_SQRTPI_H",
13242 "M_SQRT2_H",
13243 "M_SQRT1_2_H",
13244 "DBL_DIG",
13245 "DBL_MANT_DIG",
13246 "DBL_MAX_10_EXP",
13247 "DBL_MAX_EXP",
13248 "DBL_MIN_10_EXP",
13249 "DBL_MIN_EXP",
13250 "DBL_RADIX",
13251 "DBL_MAX",
13252 "DBL_MIN",
13253 "DBL_EPSILON",
13254 "HUGE_VAL",
13255 "M_E",
13256 "M_LOG2E",
13257 "M_LOG10E",
13258 "M_LN2",
13259 "M_LN10",
13260 "M_PI",
13261 "M_PI_2",
13262 "M_PI_4",
13263 "M_1_PI",
13264 "M_2_PI",
13265 "M_2_SQRTPI",
13266 "M_SQRT2",
13267 "M_SQRT1_2",
13268 };
13269
13270 return illegal_func_names;
13271}
13272
13273// Replace all names that match MSL keywords or Metal Standard Library functions.
13274void CompilerMSL::replace_illegal_names()
13275{
13276 // FIXME: MSL and GLSL are doing two different things here.
13277 // Agree on convention and remove this override.
13278 auto &keywords = get_reserved_keyword_set();
13279 auto &illegal_func_names = get_illegal_func_names();
13280
13281 ir.for_each_typed_id<SPIRVariable>([&](uint32_t self, SPIRVariable &) {
13282 auto *meta = ir.find_meta(self);
13283 if (!meta)
13284 return;
13285
13286 auto &dec = meta->decoration;
13287 if (keywords.find(dec.alias) != end(keywords))
13288 dec.alias += "0";
13289 });
13290
13291 ir.for_each_typed_id<SPIRFunction>([&](uint32_t self, SPIRFunction &) {
13292 auto *meta = ir.find_meta(self);
13293 if (!meta)
13294 return;
13295
13296 auto &dec = meta->decoration;
13297 if (illegal_func_names.find(dec.alias) != end(illegal_func_names))
13298 dec.alias += "0";
13299 });
13300
13301 ir.for_each_typed_id<SPIRType>([&](uint32_t self, SPIRType &) {
13302 auto *meta = ir.find_meta(self);
13303 if (!meta)
13304 return;
13305
13306 for (auto &mbr_dec : meta->members)
13307 if (keywords.find(mbr_dec.alias) != end(keywords))
13308 mbr_dec.alias += "0";
13309 });
13310
13311 CompilerGLSL::replace_illegal_names();
13312}
13313
13314void CompilerMSL::replace_illegal_entry_point_names()
13315{
13316 auto &illegal_func_names = get_illegal_func_names();
13317
	// It is important to do this before we fix up identifiers,
	// since if ep_name is reserved, we will need to fix that up,
	// and then copy the alias back into entry.name after the fixup.
13321 for (auto &entry : ir.entry_points)
13322 {
13323 // Change both the entry point name and the alias, to keep them synced.
13324 string &ep_name = entry.second.name;
13325 if (illegal_func_names.find(ep_name) != end(illegal_func_names))
13326 ep_name += "0";
13327
13328 ir.meta[entry.first].decoration.alias = ep_name;
13329 }
13330}
13331
13332void CompilerMSL::sync_entry_point_aliases_and_names()
13333{
13334 for (auto &entry : ir.entry_points)
13335 entry.second.name = ir.meta[entry.first].decoration.alias;
13336}
13337
13338string CompilerMSL::to_member_reference(uint32_t base, const SPIRType &type, uint32_t index, bool ptr_chain)
13339{
13340 auto *var = maybe_get<SPIRVariable>(base);
13341 // If this is a buffer array, we have to dereference the buffer pointers.
13342 // Otherwise, if this is a pointer expression, dereference it.
13343
13344 bool declared_as_pointer = false;
13345
13346 if (var)
13347 {
13348 // Only allow -> dereference for block types. This is so we get expressions like
13349 // buffer[i]->first_member.second_member, rather than buffer[i]->first->second.
13350 bool is_block = has_decoration(type.self, DecorationBlock) || has_decoration(type.self, DecorationBufferBlock);
13351
13352 bool is_buffer_variable =
13353 is_block && (var->storage == StorageClassUniform || var->storage == StorageClassStorageBuffer);
13354 declared_as_pointer = is_buffer_variable && is_array(get<SPIRType>(var->basetype));
13355 }
13356
13357 if (declared_as_pointer || (!ptr_chain && should_dereference(base)))
13358 return join("->", to_member_name(type, index));
13359 else
13360 return join(".", to_member_name(type, index));
13361}
13362
13363string CompilerMSL::to_qualifiers_glsl(uint32_t id)
13364{
13365 string quals;
13366
13367 auto *var = maybe_get<SPIRVariable>(id);
13368 auto &type = expression_type(id);
13369
13370 if (type.storage == StorageClassWorkgroup || (var && variable_decl_is_remapped_storage(*var, StorageClassWorkgroup)))
13371 quals += "threadgroup ";
13372
13373 return quals;
13374}
13375
// The optional id parameter indicates the object whose type we are trying
// to find the description for. Most type descriptions do not depend on a
// specific object's use of that type.
13379string CompilerMSL::type_to_glsl(const SPIRType &type, uint32_t id)
13380{
13381 string type_name;
13382
13383 // Pointer?
13384 if (type.pointer)
13385 {
13386 assert(type.pointer_depth > 0);
13387
13388 const char *restrict_kw;
13389
13390 auto type_address_space = get_type_address_space(type, id);
13391 auto type_decl = type_to_glsl(get<SPIRType>(type.parent_type), id);
13392
		// Work around C pointer qualifier rules. If glsl_type is a pointer type as well,
		// we'll need to emit the address space to the right.
		// We could always go this route, but it makes the code unnatural.
		// Prefer emitting thread T *foo over T thread* foo since it's more readable,
		// but for pointer-to-pointer types we must emit the outer address space
		// to the right of the inner pointer declaration.
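		// For instance, a device pointer to a thread float pointer recurses to
		// "thread float*" first, then appends the outer address space and '*',
		// yielding "thread float* device *".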
13398 if (type_is_pointer_to_pointer(type))
13399 type_name = join(type_decl, " ", type_address_space, " ");
13400 else
13401 type_name = join(type_address_space, " ", type_decl);
13402
13403 switch (type.basetype)
13404 {
13405 case SPIRType::Image:
13406 case SPIRType::SampledImage:
13407 case SPIRType::Sampler:
13408 // These are handles.
13409 break;
13410 default:
13411 // Anything else can be a raw pointer.
13412 type_name += "*";
13413 restrict_kw = to_restrict(id);
13414 if (*restrict_kw)
13415 {
13416 type_name += " ";
13417 type_name += restrict_kw;
13418 }
13419 break;
13420 }
13421 return type_name;
13422 }
13423
13424 switch (type.basetype)
13425 {
13426 case SPIRType::Struct:
13427 // Need OpName lookup here to get a "sensible" name for a struct.
13429 type_name = to_name(type.self);
13430 break;
13431
13432 case SPIRType::Image:
13433 case SPIRType::SampledImage:
13434 return image_type_glsl(type, id);
13435
13436 case SPIRType::Sampler:
13437 return sampler_type(type, id);
13438
13439 case SPIRType::Void:
13440 return "void";
13441
13442 case SPIRType::AtomicCounter:
13443 return "atomic_uint";
13444
13445 case SPIRType::ControlPointArray:
13446 return join("patch_control_point<", type_to_glsl(get<SPIRType>(type.parent_type), id), ">");
13447
13448 case SPIRType::Interpolant:
13449 return join("interpolant<", type_to_glsl(get<SPIRType>(type.parent_type), id), ", interpolation::",
13450 has_decoration(type.self, DecorationNoPerspective) ? "no_perspective" : "perspective", ">");
13451
13452 // Scalars
13453 case SPIRType::Boolean:
13454 {
13455 auto *var = maybe_get_backing_variable(id);
13456 if (var && var->basevariable)
13457 var = &get<SPIRVariable>(var->basevariable);
13458
13459 // Need to special-case threadgroup booleans. They are supposed to be logical
13460 // storage, but MSL compilers will sometimes crash if you use threadgroup bool.
		// Work around this by using 16-bit types instead, and fix up on loads from and stores to this data.
13462 // FIXME: We have no sane way of working around this problem if a struct member is boolean
13463 // and that struct is used as a threadgroup variable, but ... sigh.
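		// E.g. a "threadgroup bool" variable is therefore declared as "threadgroup short".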
13464 if ((var && var->storage == StorageClassWorkgroup) || type.storage == StorageClassWorkgroup)
13465 type_name = "short";
13466 else
13467 type_name = "bool";
13468 break;
13469 }
13470
13471 case SPIRType::Char:
13472 case SPIRType::SByte:
13473 type_name = "char";
13474 break;
13475 case SPIRType::UByte:
13476 type_name = "uchar";
13477 break;
13478 case SPIRType::Short:
13479 type_name = "short";
13480 break;
13481 case SPIRType::UShort:
13482 type_name = "ushort";
13483 break;
13484 case SPIRType::Int:
13485 type_name = "int";
13486 break;
13487 case SPIRType::UInt:
13488 type_name = "uint";
13489 break;
13490 case SPIRType::Int64:
13491 if (!msl_options.supports_msl_version(2, 2))
13492 SPIRV_CROSS_THROW("64-bit integers are only supported in MSL 2.2 and above.");
13493 type_name = "long";
13494 break;
13495 case SPIRType::UInt64:
13496 if (!msl_options.supports_msl_version(2, 2))
13497 SPIRV_CROSS_THROW("64-bit integers are only supported in MSL 2.2 and above.");
13498 type_name = "ulong";
13499 break;
13500 case SPIRType::Half:
13501 type_name = "half";
13502 break;
13503 case SPIRType::Float:
13504 type_name = "float";
13505 break;
13506 case SPIRType::Double:
13507 type_name = "double"; // Currently unsupported
13508 break;
13509 case SPIRType::AccelerationStructure:
13510 if (msl_options.supports_msl_version(2, 4))
13511 type_name = "acceleration_structure<instancing>";
13512 else if (msl_options.supports_msl_version(2, 3))
13513 type_name = "instance_acceleration_structure";
13514 else
			SPIRV_CROSS_THROW("Acceleration Structure Type is only supported in MSL 2.3 and above.");
13516 break;
13517 case SPIRType::RayQuery:
13518 return "intersection_query<instancing, triangle_data>";
13519
13520 default:
13521 return "unknown_type";
13522 }
13523
13524 // Matrix?
13525 if (type.columns > 1)
13526 type_name += to_string(type.columns) + "x";
13527
13528 // Vector or Matrix?
13529 if (type.vecsize > 1)
13530 type_name += to_string(type.vecsize);
13531
13532 if (type.array.empty() || using_builtin_array())
13533 {
13534 return type_name;
13535 }
13536 else
13537 {
13538 // Allow Metal to use the array<T> template to make arrays a value type
13539 add_spv_func_and_recompile(SPVFuncImplUnsafeArray);
13540 string res;
13541 string sizes;
13542
13543 for (uint32_t i = 0; i < uint32_t(type.array.size()); i++)
13544 {
13545 res += "spvUnsafeArray<";
13546 sizes += ", ";
13547 sizes += to_array_size(type, i);
13548 sizes += ">";
13549 }
13550
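		// The dimensions nest innermost-first; e.g. a two-dimensional array becomes
		// "spvUnsafeArray<spvUnsafeArray<T, N0>, N1>", where N0 is the innermost size.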
13551 res += type_name + sizes;
13552 return res;
13553 }
13554}
13555
13556string CompilerMSL::type_to_array_glsl(const SPIRType &type)
13557{
13558 // Allow Metal to use the array<T> template to make arrays a value type
13559 switch (type.basetype)
13560 {
13561 case SPIRType::AtomicCounter:
13562 case SPIRType::ControlPointArray:
13563 case SPIRType::RayQuery:
13564 {
13565 return CompilerGLSL::type_to_array_glsl(type);
13566 }
13567 default:
13568 {
13569 if (using_builtin_array())
13570 return CompilerGLSL::type_to_array_glsl(type);
13571 else
13572 return "";
13573 }
13574 }
13575}
13576
13577string CompilerMSL::constant_op_expression(const SPIRConstantOp &cop)
13578{
13579 switch (cop.opcode)
13580 {
13581 case OpQuantizeToF16:
13582 add_spv_func_and_recompile(SPVFuncImplQuantizeToF16);
13583 return join("spvQuantizeToF16(", to_expression(cop.arguments[0]), ")");
13584 default:
13585 return CompilerGLSL::constant_op_expression(cop);
13586 }
13587}
13588
13589bool CompilerMSL::variable_decl_is_remapped_storage(const SPIRVariable &variable, spv::StorageClass storage) const
13590{
13591 if (variable.storage == storage)
13592 return true;
13593
13594 if (storage == StorageClassWorkgroup)
13595 {
13596 auto model = get_execution_model();
13597
13598 // Specially masked IO block variable.
13599 // Normally, we will never access IO blocks directly here.
		// The only scenario in which that should occur is with a masked IO block.
13601 if (model == ExecutionModelTessellationControl && variable.storage == StorageClassOutput &&
13602 has_decoration(get<SPIRType>(variable.basetype).self, DecorationBlock))
13603 {
13604 return true;
13605 }
13606
13607 return variable.storage == StorageClassOutput &&
13608 model == ExecutionModelTessellationControl &&
13609 is_stage_output_variable_masked(variable);
13610 }
13611 else if (storage == StorageClassStorageBuffer)
13612 {
		// We won't be able to catch writes to control point outputs here, since variable
		// refers to a function-local pointer.
		// This is fine, as there cannot be concurrent writers to that memory anyway,
		// so we just ignore that case.
13617
13618 return (variable.storage == StorageClassOutput || variable.storage == StorageClassInput) &&
13619 !variable_storage_requires_stage_io(variable.storage) &&
13620 (variable.storage != StorageClassOutput || !is_stage_output_variable_masked(variable));
13621 }
13622 else
13623 {
13624 return false;
13625 }
13626}
13627
13628std::string CompilerMSL::variable_decl(const SPIRVariable &variable)
13629{
13630 bool old_is_using_builtin_array = is_using_builtin_array;
13631
13632 // Threadgroup arrays can't have a wrapper type.
13633 if (variable_decl_is_remapped_storage(variable, StorageClassWorkgroup))
13634 is_using_builtin_array = true;
13635
13636 auto expr = CompilerGLSL::variable_decl(variable);
13637 is_using_builtin_array = old_is_using_builtin_array;
13638 return expr;
13639}
13640
// GCC workaround for lambdas calling protected member functions.
13642std::string CompilerMSL::variable_decl(const SPIRType &type, const std::string &name, uint32_t id)
13643{
13644 return CompilerGLSL::variable_decl(type, name, id);
13645}
13646
13647std::string CompilerMSL::sampler_type(const SPIRType &type, uint32_t id)
13648{
13649 auto *var = maybe_get<SPIRVariable>(id);
13650 if (var && var->basevariable)
13651 {
13652 // Check against the base variable, and not a fake ID which might have been generated for this variable.
13653 id = var->basevariable;
13654 }
13655
13656 if (!type.array.empty())
13657 {
13658 if (!msl_options.supports_msl_version(2))
13659 SPIRV_CROSS_THROW("MSL 2.0 or greater is required for arrays of samplers.");
13660
13661 if (type.array.size() > 1)
13662 SPIRV_CROSS_THROW("Arrays of arrays of samplers are not supported in MSL.");
13663
		// Arrays of samplers in MSL must be declared with a special array<T, N> syntax, a la C++11 std::array.
13665 // If we have a runtime array, it could be a variable-count descriptor set binding.
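		// E.g. an array of 4 samplers is emitted as "array<sampler, 4>".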
13666 uint32_t array_size = to_array_size_literal(type);
13667 if (array_size == 0)
13668 array_size = get_resource_array_size(id);
13669
13670 if (array_size == 0)
13671 SPIRV_CROSS_THROW("Unsized array of samplers is not supported in MSL.");
13672
13673 auto &parent = get<SPIRType>(get_pointee_type(type).parent_type);
13674 return join("array<", sampler_type(parent, id), ", ", array_size, ">");
13675 }
13676 else
13677 return "sampler";
13678}
13679
13680// Returns an MSL string describing the SPIR-V image type
13681string CompilerMSL::image_type_glsl(const SPIRType &type, uint32_t id)
13682{
13683 auto *var = maybe_get<SPIRVariable>(id);
13684 if (var && var->basevariable)
13685 {
13686 // For comparison images, check against the base variable,
13687 // and not the fake ID which might have been generated for this variable.
13688 id = var->basevariable;
13689 }
13690
13691 if (!type.array.empty())
13692 {
13693 uint32_t major = 2, minor = 0;
13694 if (msl_options.is_ios())
13695 {
13696 major = 1;
13697 minor = 2;
13698 }
13699 if (!msl_options.supports_msl_version(major, minor))
13700 {
13701 if (msl_options.is_ios())
13702 SPIRV_CROSS_THROW("MSL 1.2 or greater is required for arrays of textures.");
13703 else
13704 SPIRV_CROSS_THROW("MSL 2.0 or greater is required for arrays of textures.");
13705 }
13706
13707 if (type.array.size() > 1)
13708 SPIRV_CROSS_THROW("Arrays of arrays of textures are not supported in MSL.");
13709
		// Arrays of images in MSL must be declared with a special array<T, N> syntax, a la C++11 std::array.
13711 // If we have a runtime array, it could be a variable-count descriptor set binding.
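		// E.g. an array of 8 sampled 2D float textures becomes "array<texture2d<float>, 8>".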
13712 uint32_t array_size = to_array_size_literal(type);
13713 if (array_size == 0)
13714 array_size = get_resource_array_size(id);
13715
13716 if (array_size == 0)
13717 SPIRV_CROSS_THROW("Unsized array of images is not supported in MSL.");
13718
13719 auto &parent = get<SPIRType>(get_pointee_type(type).parent_type);
13720 return join("array<", image_type_glsl(parent, id), ", ", array_size, ">");
13721 }
13722
13723 string img_type_name;
13724
13725 // Bypass pointers because we need the real image struct
13726 auto &img_type = get<SPIRType>(type.self).image;
13727 if (is_depth_image(type, id))
13728 {
13729 switch (img_type.dim)
13730 {
13731 case Dim1D:
13732 case Dim2D:
13733 if (img_type.dim == Dim1D && !msl_options.texture_1D_as_2D)
13734 {
				// Metal has no native 1D depth texture; emit a deliberately invalid type name.
13736 img_type_name += "depth1d_unsupported_by_metal";
13737 break;
13738 }
13739
13740 if (img_type.ms && img_type.arrayed)
13741 {
13742 if (!msl_options.supports_msl_version(2, 1))
13743 SPIRV_CROSS_THROW("Multisampled array textures are supported from 2.1.");
13744 img_type_name += "depth2d_ms_array";
13745 }
13746 else if (img_type.ms)
13747 img_type_name += "depth2d_ms";
13748 else if (img_type.arrayed)
13749 img_type_name += "depth2d_array";
13750 else
13751 img_type_name += "depth2d";
13752 break;
13753 case Dim3D:
13754 img_type_name += "depth3d_unsupported_by_metal";
13755 break;
13756 case DimCube:
13757 if (!msl_options.emulate_cube_array)
13758 img_type_name += (img_type.arrayed ? "depthcube_array" : "depthcube");
13759 else
13760 img_type_name += (img_type.arrayed ? "depth2d_array" : "depthcube");
13761 break;
13762 default:
13763 img_type_name += "unknown_depth_texture_type";
13764 break;
13765 }
13766 }
13767 else
13768 {
13769 switch (img_type.dim)
13770 {
13771 case DimBuffer:
13772 if (img_type.ms || img_type.arrayed)
13773 SPIRV_CROSS_THROW("Cannot use texel buffers with multisampling or array layers.");
13774
13775 if (msl_options.texture_buffer_native)
13776 {
13777 if (!msl_options.supports_msl_version(2, 1))
13778 SPIRV_CROSS_THROW("Native texture_buffer type is only supported in MSL 2.1.");
13779 img_type_name = "texture_buffer";
13780 }
13781 else
13782 img_type_name += "texture2d";
13783 break;
13784 case Dim1D:
13785 case Dim2D:
13786 case DimSubpassData:
13787 {
13788 bool subpass_array =
13789 img_type.dim == DimSubpassData && (msl_options.multiview || msl_options.arrayed_subpass_input);
13790 if (img_type.dim == Dim1D && !msl_options.texture_1D_as_2D)
13791 {
13792 // Use a native Metal 1D texture
13793 img_type_name += (img_type.arrayed ? "texture1d_array" : "texture1d");
13794 break;
13795 }
13796
13797 // Use Metal's native frame-buffer fetch API for subpass inputs.
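			// The image then degenerates to a plain vector value; e.g. a subpass input
			// with a float sampled type is emitted simply as "float4".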
13798 if (type_is_msl_framebuffer_fetch(type))
13799 {
13800 auto img_type_4 = get<SPIRType>(img_type.type);
13801 img_type_4.vecsize = 4;
13802 return type_to_glsl(img_type_4);
13803 }
13804 if (img_type.ms && (img_type.arrayed || subpass_array))
13805 {
13806 if (!msl_options.supports_msl_version(2, 1))
13807 SPIRV_CROSS_THROW("Multisampled array textures are supported from 2.1.");
13808 img_type_name += "texture2d_ms_array";
13809 }
13810 else if (img_type.ms)
13811 img_type_name += "texture2d_ms";
13812 else if (img_type.arrayed || subpass_array)
13813 img_type_name += "texture2d_array";
13814 else
13815 img_type_name += "texture2d";
13816 break;
13817 }
13818 case Dim3D:
13819 img_type_name += "texture3d";
13820 break;
13821 case DimCube:
13822 if (!msl_options.emulate_cube_array)
13823 img_type_name += (img_type.arrayed ? "texturecube_array" : "texturecube");
13824 else
13825 img_type_name += (img_type.arrayed ? "texture2d_array" : "texturecube");
13826 break;
13827 default:
13828 img_type_name += "unknown_texture_type";
13829 break;
13830 }
13831 }
13832
13833 // Append the pixel type
13834 img_type_name += "<";
13835 img_type_name += type_to_glsl(get<SPIRType>(img_type.type));
13836
13837 // For unsampled images, append the sample/read/write access qualifier.
	// For kernel images, the access qualifier may be supplied directly by SPIR-V.
13839 // Otherwise it may be set based on whether the image is read from or written to within the shader.
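	// E.g. a read-write storage image with a float component type emits
	// "texture2d<float, access::read_write>".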
13840 if (type.basetype == SPIRType::Image && type.image.sampled == 2 && type.image.dim != DimSubpassData)
13841 {
13842 switch (img_type.access)
13843 {
13844 case AccessQualifierReadOnly:
13845 img_type_name += ", access::read";
13846 break;
13847
13848 case AccessQualifierWriteOnly:
13849 img_type_name += ", access::write";
13850 break;
13851
13852 case AccessQualifierReadWrite:
13853 img_type_name += ", access::read_write";
13854 break;
13855
13856 default:
13857 {
13858 auto *p_var = maybe_get_backing_variable(id);
13859 if (p_var && p_var->basevariable)
13860 p_var = maybe_get<SPIRVariable>(p_var->basevariable);
13861 if (p_var && !has_decoration(p_var->self, DecorationNonWritable))
13862 {
13863 img_type_name += ", access::";
13864
13865 if (!has_decoration(p_var->self, DecorationNonReadable))
13866 img_type_name += "read_";
13867
13868 img_type_name += "write";
13869 }
13870 break;
13871 }
13872 }
13873 }
13874
13875 img_type_name += ">";
13876
13877 return img_type_name;
13878}
13879
13880void CompilerMSL::emit_subgroup_op(const Instruction &i)
13881{
13882 const uint32_t *ops = stream(i);
13883 auto op = static_cast<Op>(i.op);
13884
13885 if (msl_options.emulate_subgroups)
13886 {
13887 // In this mode, only the GroupNonUniform cap is supported. The only op
13888 // we need to handle, then, is OpGroupNonUniformElect.
13889 if (op != OpGroupNonUniformElect)
13890 SPIRV_CROSS_THROW("Subgroup emulation does not support operations other than Elect.");
13891 // In this mode, the subgroup size is assumed to be one, so every invocation
13892 // is elected.
13893 emit_op(ops[0], ops[1], "true", true);
13894 return;
13895 }
13896
13897 // Metal 2.0 is required. iOS only supports quad ops on 11.0 (2.0), with
13898 // full support in 13.0 (2.2). macOS only supports broadcast and shuffle on
13899 // 10.13 (2.0), with full support in 10.14 (2.1).
13900 // Note that Apple GPUs before A13 make no distinction between a quad-group
13901 // and a SIMD-group; all SIMD-groups are quad-groups on those.
13902 if (!msl_options.supports_msl_version(2))
13903 SPIRV_CROSS_THROW("Subgroups are only supported in Metal 2.0 and up.");
13904
13905 // If we need to do implicit bitcasts, make sure we do it with the correct type.
13906 uint32_t integer_width = get_integer_width_for_instruction(i);
13907 auto int_type = to_signed_basetype(integer_width);
13908 auto uint_type = to_unsigned_basetype(integer_width);
13909
13910 if (msl_options.is_ios() && (!msl_options.supports_msl_version(2, 3) || !msl_options.ios_use_simdgroup_functions))
13911 {
13912 switch (op)
13913 {
13914 default:
13915 SPIRV_CROSS_THROW("Subgroup ops beyond broadcast, ballot, and shuffle on iOS require Metal 2.3 and up.");
13916 case OpGroupNonUniformBroadcastFirst:
13917 if (!msl_options.supports_msl_version(2, 2))
13918 SPIRV_CROSS_THROW("BroadcastFirst on iOS requires Metal 2.2 and up.");
13919 break;
13920 case OpGroupNonUniformElect:
13921 if (!msl_options.supports_msl_version(2, 2))
13922 SPIRV_CROSS_THROW("Elect on iOS requires Metal 2.2 and up.");
13923 break;
13924 case OpGroupNonUniformAny:
13925 case OpGroupNonUniformAll:
13926 case OpGroupNonUniformAllEqual:
13927 case OpGroupNonUniformBallot:
13928 case OpGroupNonUniformInverseBallot:
13929 case OpGroupNonUniformBallotBitExtract:
13930 case OpGroupNonUniformBallotFindLSB:
13931 case OpGroupNonUniformBallotFindMSB:
13932 case OpGroupNonUniformBallotBitCount:
13933 if (!msl_options.supports_msl_version(2, 2))
				SPIRV_CROSS_THROW("Ballot ops on iOS require Metal 2.2 and up.");
13935 break;
13936 case OpGroupNonUniformBroadcast:
13937 case OpGroupNonUniformShuffle:
13938 case OpGroupNonUniformShuffleXor:
13939 case OpGroupNonUniformShuffleUp:
13940 case OpGroupNonUniformShuffleDown:
13941 case OpGroupNonUniformQuadSwap:
13942 case OpGroupNonUniformQuadBroadcast:
13943 break;
13944 }
13945 }
13946
13947 if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 1))
13948 {
13949 switch (op)
13950 {
13951 default:
13952 SPIRV_CROSS_THROW("Subgroup ops beyond broadcast and shuffle on macOS require Metal 2.1 and up.");
13953 case OpGroupNonUniformBroadcast:
13954 case OpGroupNonUniformShuffle:
13955 case OpGroupNonUniformShuffleXor:
13956 case OpGroupNonUniformShuffleUp:
13957 case OpGroupNonUniformShuffleDown:
13958 break;
13959 }
13960 }
13961
13962 uint32_t result_type = ops[0];
13963 uint32_t id = ops[1];
13964
13965 auto scope = static_cast<Scope>(evaluate_constant_u32(ops[2]));
13966 if (scope != ScopeSubgroup)
13967 SPIRV_CROSS_THROW("Only subgroup scope is supported.");
13968
13969 switch (op)
13970 {
13971 case OpGroupNonUniformElect:
13972 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
13973 emit_op(result_type, id, "quad_is_first()", false);
13974 else
13975 emit_op(result_type, id, "simd_is_first()", false);
13976 break;
13977
13978 case OpGroupNonUniformBroadcast:
13979 emit_binary_func_op(result_type, id, ops[3], ops[4], "spvSubgroupBroadcast");
13980 break;
13981
13982 case OpGroupNonUniformBroadcastFirst:
13983 emit_unary_func_op(result_type, id, ops[3], "spvSubgroupBroadcastFirst");
13984 break;
13985
13986 case OpGroupNonUniformBallot:
13987 emit_unary_func_op(result_type, id, ops[3], "spvSubgroupBallot");
13988 break;
13989
13990 case OpGroupNonUniformInverseBallot:
13991 emit_binary_func_op(result_type, id, ops[3], builtin_subgroup_invocation_id_id, "spvSubgroupBallotBitExtract");
13992 break;
13993
13994 case OpGroupNonUniformBallotBitExtract:
13995 emit_binary_func_op(result_type, id, ops[3], ops[4], "spvSubgroupBallotBitExtract");
13996 break;
13997
13998 case OpGroupNonUniformBallotFindLSB:
13999 emit_binary_func_op(result_type, id, ops[3], builtin_subgroup_size_id, "spvSubgroupBallotFindLSB");
14000 break;
14001
14002 case OpGroupNonUniformBallotFindMSB:
14003 emit_binary_func_op(result_type, id, ops[3], builtin_subgroup_size_id, "spvSubgroupBallotFindMSB");
14004 break;
14005
14006 case OpGroupNonUniformBallotBitCount:
14007 {
14008 auto operation = static_cast<GroupOperation>(ops[3]);
14009 switch (operation)
14010 {
14011 case GroupOperationReduce:
14012 emit_binary_func_op(result_type, id, ops[4], builtin_subgroup_size_id, "spvSubgroupBallotBitCount");
14013 break;
14014 case GroupOperationInclusiveScan:
14015 emit_binary_func_op(result_type, id, ops[4], builtin_subgroup_invocation_id_id,
14016 "spvSubgroupBallotInclusiveBitCount");
14017 break;
14018 case GroupOperationExclusiveScan:
14019 emit_binary_func_op(result_type, id, ops[4], builtin_subgroup_invocation_id_id,
14020 "spvSubgroupBallotExclusiveBitCount");
14021 break;
14022 default:
14023 SPIRV_CROSS_THROW("Invalid BitCount operation.");
14024 }
14025 break;
14026 }
14027
14028 case OpGroupNonUniformShuffle:
14029 emit_binary_func_op(result_type, id, ops[3], ops[4], "spvSubgroupShuffle");
14030 break;
14031
14032 case OpGroupNonUniformShuffleXor:
14033 emit_binary_func_op(result_type, id, ops[3], ops[4], "spvSubgroupShuffleXor");
14034 break;
14035
14036 case OpGroupNonUniformShuffleUp:
14037 emit_binary_func_op(result_type, id, ops[3], ops[4], "spvSubgroupShuffleUp");
14038 break;
14039
14040 case OpGroupNonUniformShuffleDown:
14041 emit_binary_func_op(result_type, id, ops[3], ops[4], "spvSubgroupShuffleDown");
14042 break;
14043
14044 case OpGroupNonUniformAll:
14045 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
14046 emit_unary_func_op(result_type, id, ops[3], "quad_all");
14047 else
14048 emit_unary_func_op(result_type, id, ops[3], "simd_all");
14049 break;
14050
14051 case OpGroupNonUniformAny:
14052 if (msl_options.is_ios() && !msl_options.ios_use_simdgroup_functions)
14053 emit_unary_func_op(result_type, id, ops[3], "quad_any");
14054 else
14055 emit_unary_func_op(result_type, id, ops[3], "simd_any");
14056 break;
14057
14058 case OpGroupNonUniformAllEqual:
14059 emit_unary_func_op(result_type, id, ops[3], "spvSubgroupAllEqual");
14060 break;
14061
14062 // clang-format off
14063#define MSL_GROUP_OP(op, msl_op) \
14064case OpGroupNonUniform##op: \
14065 { \
14066 auto operation = static_cast<GroupOperation>(ops[3]); \
14067 if (operation == GroupOperationReduce) \
14068 emit_unary_func_op(result_type, id, ops[4], "simd_" #msl_op); \
14069 else if (operation == GroupOperationInclusiveScan) \
14070 emit_unary_func_op(result_type, id, ops[4], "simd_prefix_inclusive_" #msl_op); \
14071 else if (operation == GroupOperationExclusiveScan) \
14072 emit_unary_func_op(result_type, id, ops[4], "simd_prefix_exclusive_" #msl_op); \
14073 else if (operation == GroupOperationClusteredReduce) \
14074 { \
14075 /* Only cluster sizes of 4 are supported. */ \
14076 uint32_t cluster_size = evaluate_constant_u32(ops[5]); \
14077 if (cluster_size != 4) \
14078 SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
14079 emit_unary_func_op(result_type, id, ops[4], "quad_" #msl_op); \
14080 } \
14081 else \
14082 SPIRV_CROSS_THROW("Invalid group operation."); \
14083 break; \
14084 }
14085 MSL_GROUP_OP(FAdd, sum)
14086 MSL_GROUP_OP(FMul, product)
14087 MSL_GROUP_OP(IAdd, sum)
14088 MSL_GROUP_OP(IMul, product)
14089#undef MSL_GROUP_OP
14090 // The others, unfortunately, don't support InclusiveScan or ExclusiveScan.
14091
14092#define MSL_GROUP_OP(op, msl_op) \
14093case OpGroupNonUniform##op: \
14094 { \
14095 auto operation = static_cast<GroupOperation>(ops[3]); \
14096 if (operation == GroupOperationReduce) \
14097 emit_unary_func_op(result_type, id, ops[4], "simd_" #msl_op); \
14098 else if (operation == GroupOperationInclusiveScan) \
14099 SPIRV_CROSS_THROW("Metal doesn't support InclusiveScan for OpGroupNonUniform" #op "."); \
14100 else if (operation == GroupOperationExclusiveScan) \
14101 SPIRV_CROSS_THROW("Metal doesn't support ExclusiveScan for OpGroupNonUniform" #op "."); \
14102 else if (operation == GroupOperationClusteredReduce) \
14103 { \
14104 /* Only cluster sizes of 4 are supported. */ \
14105 uint32_t cluster_size = evaluate_constant_u32(ops[5]); \
14106 if (cluster_size != 4) \
14107 SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
14108 emit_unary_func_op(result_type, id, ops[4], "quad_" #msl_op); \
14109 } \
14110 else \
14111 SPIRV_CROSS_THROW("Invalid group operation."); \
14112 break; \
14113 }
14114
14115#define MSL_GROUP_OP_CAST(op, msl_op, type) \
14116case OpGroupNonUniform##op: \
14117 { \
14118 auto operation = static_cast<GroupOperation>(ops[3]); \
14119 if (operation == GroupOperationReduce) \
14120 emit_unary_func_op_cast(result_type, id, ops[4], "simd_" #msl_op, type, type); \
14121 else if (operation == GroupOperationInclusiveScan) \
14122 SPIRV_CROSS_THROW("Metal doesn't support InclusiveScan for OpGroupNonUniform" #op "."); \
14123 else if (operation == GroupOperationExclusiveScan) \
14124 SPIRV_CROSS_THROW("Metal doesn't support ExclusiveScan for OpGroupNonUniform" #op "."); \
14125 else if (operation == GroupOperationClusteredReduce) \
14126 { \
14127 /* Only cluster sizes of 4 are supported. */ \
14128 uint32_t cluster_size = evaluate_constant_u32(ops[5]); \
14129 if (cluster_size != 4) \
14130 SPIRV_CROSS_THROW("Metal only supports quad ClusteredReduce."); \
14131 emit_unary_func_op_cast(result_type, id, ops[4], "quad_" #msl_op, type, type); \
14132 } \
14133 else \
14134 SPIRV_CROSS_THROW("Invalid group operation."); \
14135 break; \
14136 }
14137
14138 MSL_GROUP_OP(FMin, min)
14139 MSL_GROUP_OP(FMax, max)
14140 MSL_GROUP_OP_CAST(SMin, min, int_type)
14141 MSL_GROUP_OP_CAST(SMax, max, int_type)
14142 MSL_GROUP_OP_CAST(UMin, min, uint_type)
14143 MSL_GROUP_OP_CAST(UMax, max, uint_type)
14144 MSL_GROUP_OP(BitwiseAnd, and)
14145 MSL_GROUP_OP(BitwiseOr, or)
14146 MSL_GROUP_OP(BitwiseXor, xor)
14147 MSL_GROUP_OP(LogicalAnd, and)
14148 MSL_GROUP_OP(LogicalOr, or)
14149 MSL_GROUP_OP(LogicalXor, xor)
14150 // clang-format on
14151#undef MSL_GROUP_OP
14152#undef MSL_GROUP_OP_CAST
14153
14154 case OpGroupNonUniformQuadSwap:
14155 emit_binary_func_op(result_type, id, ops[3], ops[4], "spvQuadSwap");
14156 break;
14157
14158 case OpGroupNonUniformQuadBroadcast:
14159 emit_binary_func_op(result_type, id, ops[3], ops[4], "spvQuadBroadcast");
14160 break;
14161
14162 default:
14163 SPIRV_CROSS_THROW("Invalid opcode for subgroup.");
14164 }
14165
14166 register_control_dependent_expression(id);
14167}
14168
14169string CompilerMSL::bitcast_glsl_op(const SPIRType &out_type, const SPIRType &in_type)
14170{
14171 if (out_type.basetype == in_type.basetype)
14172 return "";
14173
14174 assert(out_type.basetype != SPIRType::Boolean);
14175 assert(in_type.basetype != SPIRType::Boolean);
14176
14177 bool integral_cast = type_is_integral(out_type) && type_is_integral(in_type) && (out_type.vecsize == in_type.vecsize);
14178 bool same_size_cast = (out_type.width * out_type.vecsize) == (in_type.width * in_type.vecsize);
14179
14180 // Bitcasting can only be used between types of the same overall size.
14181 // And always formally cast between integers, because it's trivial, and also
14182 // because Metal can internally cast the results of some integer ops to a larger
	// size (e.g. short shift right becomes int), which means chaining integer ops
14184 // together may introduce size variations that SPIR-V doesn't know about.
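	// E.g. bitcasting float -> uint returns "as_type<uint>", which the caller then
	// typically wraps around the expression, while uint -> int simply returns "int"
	// for a constructor-style cast.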
14185 if (same_size_cast && !integral_cast)
14186 {
14187 return "as_type<" + type_to_glsl(out_type) + ">";
14188 }
14189 else
14190 {
14191 return type_to_glsl(out_type);
14192 }
14193}
14194
14195bool CompilerMSL::emit_complex_bitcast(uint32_t, uint32_t, uint32_t)
14196{
14197 return false;
14198}
14199
14200// Returns an MSL string identifying the name of a SPIR-V builtin.
14201// Output builtins are qualified with the name of the stage out structure.
14202string CompilerMSL::builtin_to_glsl(BuiltIn builtin, StorageClass storage)
14203{
14204 switch (builtin)
14205 {
14206 // Handle HLSL-style 0-based vertex/instance index.
14207 // Override GLSL compiler strictness
14208 case BuiltInVertexId:
14209 ensure_builtin(StorageClassInput, BuiltInVertexId);
14210 if (msl_options.enable_base_index_zero && msl_options.supports_msl_version(1, 1) &&
14211 (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
14212 {
14213 if (builtin_declaration)
14214 {
14215 if (needs_base_vertex_arg != TriState::No)
14216 needs_base_vertex_arg = TriState::Yes;
14217 return "gl_VertexID";
14218 }
14219 else
14220 {
14221 ensure_builtin(StorageClassInput, BuiltInBaseVertex);
14222 return "(gl_VertexID - gl_BaseVertex)";
14223 }
14224 }
14225 else
14226 {
14227 return "gl_VertexID";
14228 }
14229 case BuiltInInstanceId:
14230 ensure_builtin(StorageClassInput, BuiltInInstanceId);
14231 if (msl_options.enable_base_index_zero && msl_options.supports_msl_version(1, 1) &&
14232 (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
14233 {
14234 if (builtin_declaration)
14235 {
14236 if (needs_base_instance_arg != TriState::No)
14237 needs_base_instance_arg = TriState::Yes;
14238 return "gl_InstanceID";
14239 }
14240 else
14241 {
14242 ensure_builtin(StorageClassInput, BuiltInBaseInstance);
14243 return "(gl_InstanceID - gl_BaseInstance)";
14244 }
14245 }
14246 else
14247 {
14248 return "gl_InstanceID";
14249 }
14250 case BuiltInVertexIndex:
14251 ensure_builtin(StorageClassInput, BuiltInVertexIndex);
14252 if (msl_options.enable_base_index_zero && msl_options.supports_msl_version(1, 1) &&
14253 (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
14254 {
14255 if (builtin_declaration)
14256 {
14257 if (needs_base_vertex_arg != TriState::No)
14258 needs_base_vertex_arg = TriState::Yes;
14259 return "gl_VertexIndex";
14260 }
14261 else
14262 {
14263 ensure_builtin(StorageClassInput, BuiltInBaseVertex);
14264 return "(gl_VertexIndex - gl_BaseVertex)";
14265 }
14266 }
14267 else
14268 {
14269 return "gl_VertexIndex";
14270 }
14271 case BuiltInInstanceIndex:
14272 ensure_builtin(StorageClassInput, BuiltInInstanceIndex);
14273 if (msl_options.enable_base_index_zero && msl_options.supports_msl_version(1, 1) &&
14274 (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
14275 {
14276 if (builtin_declaration)
14277 {
14278 if (needs_base_instance_arg != TriState::No)
14279 needs_base_instance_arg = TriState::Yes;
14280 return "gl_InstanceIndex";
14281 }
14282 else
14283 {
14284 ensure_builtin(StorageClassInput, BuiltInBaseInstance);
14285 return "(gl_InstanceIndex - gl_BaseInstance)";
14286 }
14287 }
14288 else
14289 {
14290 return "gl_InstanceIndex";
14291 }
14292 case BuiltInBaseVertex:
14293 if (msl_options.supports_msl_version(1, 1) &&
14294 (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
14295 {
14296 needs_base_vertex_arg = TriState::No;
14297 return "gl_BaseVertex";
14298 }
14299 else
14300 {
14301 SPIRV_CROSS_THROW("BaseVertex requires Metal 1.1 and Mac or Apple A9+ hardware.");
14302 }
14303 case BuiltInBaseInstance:
14304 if (msl_options.supports_msl_version(1, 1) &&
14305 (msl_options.ios_support_base_vertex_instance || msl_options.is_macos()))
14306 {
14307 needs_base_instance_arg = TriState::No;
14308 return "gl_BaseInstance";
14309 }
14310 else
14311 {
14312 SPIRV_CROSS_THROW("BaseInstance requires Metal 1.1 and Mac or Apple A9+ hardware.");
14313 }
14314 case BuiltInDrawIndex:
14315 SPIRV_CROSS_THROW("DrawIndex is not supported in MSL.");
14316
	// When used in the entry function, output builtins are qualified with the output struct name.
	// Test that the storage class is NOT Input, as output builtins might be part of a generic type.
	// Also don't do this for tessellation control shaders.
14320 case BuiltInViewportIndex:
14321 if (!msl_options.supports_msl_version(2, 0))
14322 SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0.");
14323 /* fallthrough */
14324 case BuiltInFragDepth:
14325 case BuiltInFragStencilRefEXT:
14326 if ((builtin == BuiltInFragDepth && !msl_options.enable_frag_depth_builtin) ||
14327 (builtin == BuiltInFragStencilRefEXT && !msl_options.enable_frag_stencil_ref_builtin))
14328 break;
14329 /* fallthrough */
14330 case BuiltInPosition:
14331 case BuiltInPointSize:
14332 case BuiltInClipDistance:
14333 case BuiltInCullDistance:
14334 case BuiltInLayer:
14335 if (get_execution_model() == ExecutionModelTessellationControl)
14336 break;
14337 if (storage != StorageClassInput && current_function && (current_function->self == ir.default_entry_point) &&
14338 !is_stage_output_builtin_masked(builtin))
14339 return stage_out_var_name + "." + CompilerGLSL::builtin_to_glsl(builtin, storage);
14340 break;
14341
14342 case BuiltInSampleMask:
14343 if (storage == StorageClassInput && current_function && (current_function->self == ir.default_entry_point) &&
14344 (has_additional_fixed_sample_mask() || needs_sample_id))
14345 {
14346 string samp_mask_in;
14347 samp_mask_in += "(" + CompilerGLSL::builtin_to_glsl(builtin, storage);
14348 if (has_additional_fixed_sample_mask())
14349 samp_mask_in += " & " + additional_fixed_sample_mask_str();
14350 if (needs_sample_id)
14351 samp_mask_in += " & (1 << gl_SampleID)";
14352 samp_mask_in += ")";
14353 return samp_mask_in;
14354 }
14355 if (storage != StorageClassInput && current_function && (current_function->self == ir.default_entry_point) &&
14356 !is_stage_output_builtin_masked(builtin))
14357 return stage_out_var_name + "." + CompilerGLSL::builtin_to_glsl(builtin, storage);
14358 break;
14359
14360 case BuiltInBaryCoordNV:
14361 case BuiltInBaryCoordNoPerspNV:
14362 if (storage == StorageClassInput && current_function && (current_function->self == ir.default_entry_point))
14363 return stage_in_var_name + "." + CompilerGLSL::builtin_to_glsl(builtin, storage);
14364 break;
14365
14366 case BuiltInTessLevelOuter:
14367 if (get_execution_model() == ExecutionModelTessellationControl &&
14368 storage != StorageClassInput && current_function && (current_function->self == ir.default_entry_point))
14369 {
14370 return join(tess_factor_buffer_var_name, "[", to_expression(builtin_primitive_id_id),
14371 "].edgeTessellationFactor");
14372 }
14373 break;
14374
14375 case BuiltInTessLevelInner:
14376 if (get_execution_model() == ExecutionModelTessellationControl &&
14377 storage != StorageClassInput && current_function && (current_function->self == ir.default_entry_point))
14378 {
14379 return join(tess_factor_buffer_var_name, "[", to_expression(builtin_primitive_id_id),
14380 "].insideTessellationFactor");
14381 }
14382 break;
14383
14384 default:
14385 break;
14386 }
14387
14388 return CompilerGLSL::builtin_to_glsl(builtin, storage);
14389}
14390
// Returns an MSL string attribute qualifier for a SPIR-V builtin.
14392string CompilerMSL::builtin_qualifier(BuiltIn builtin)
14393{
14394 auto &execution = get_entry_point();
14395
14396 switch (builtin)
14397 {
14398 // Vertex function in
14399 case BuiltInVertexId:
14400 return "vertex_id";
14401 case BuiltInVertexIndex:
14402 return "vertex_id";
14403 case BuiltInBaseVertex:
14404 return "base_vertex";
14405 case BuiltInInstanceId:
14406 return "instance_id";
14407 case BuiltInInstanceIndex:
14408 return "instance_id";
14409 case BuiltInBaseInstance:
14410 return "base_instance";
14411 case BuiltInDrawIndex:
14412 SPIRV_CROSS_THROW("DrawIndex is not supported in MSL.");
14413
14414 // Vertex function out
14415 case BuiltInClipDistance:
14416 return "clip_distance";
14417 case BuiltInPointSize:
14418 return "point_size";
14419 case BuiltInPosition:
14420 if (position_invariant)
14421 {
14422 if (!msl_options.supports_msl_version(2, 1))
14423 SPIRV_CROSS_THROW("Invariant position is only supported on MSL 2.1 and up.");
14424 return "position, invariant";
14425 }
14426 else
14427 return "position";
14428 case BuiltInLayer:
14429 return "render_target_array_index";
14430 case BuiltInViewportIndex:
14431 if (!msl_options.supports_msl_version(2, 0))
14432 SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0.");
14433 return "viewport_array_index";
14434
14435 // Tess. control function in
14436 case BuiltInInvocationId:
14437 if (msl_options.multi_patch_workgroup)
14438 {
14439 // Shouldn't be reached.
14440 SPIRV_CROSS_THROW("InvocationId is computed manually with multi-patch workgroups in MSL.");
14441 }
14442 return "thread_index_in_threadgroup";
14443 case BuiltInPatchVertices:
14444 // Shouldn't be reached.
14445 SPIRV_CROSS_THROW("PatchVertices is derived from the auxiliary buffer in MSL.");
14446 case BuiltInPrimitiveId:
14447 switch (execution.model)
14448 {
14449 case ExecutionModelTessellationControl:
14450 if (msl_options.multi_patch_workgroup)
14451 {
14452 // Shouldn't be reached.
14453 SPIRV_CROSS_THROW("PrimitiveId is computed manually with multi-patch workgroups in MSL.");
14454 }
14455 return "threadgroup_position_in_grid";
14456 case ExecutionModelTessellationEvaluation:
14457 return "patch_id";
14458 case ExecutionModelFragment:
14459 if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 3))
14460 SPIRV_CROSS_THROW("PrimitiveId on iOS requires MSL 2.3.");
14461 else if (msl_options.is_macos() && !msl_options.supports_msl_version(2, 2))
14462 SPIRV_CROSS_THROW("PrimitiveId on macOS requires MSL 2.2.");
14463 return "primitive_id";
14464 default:
14465 SPIRV_CROSS_THROW("PrimitiveId is not supported in this execution model.");
14466 }
14467
14468 // Tess. control function out
14469 case BuiltInTessLevelOuter:
14470 case BuiltInTessLevelInner:
14471 // Shouldn't be reached.
14472 SPIRV_CROSS_THROW("Tessellation levels are handled specially in MSL.");
14473
14474 // Tess. evaluation function in
14475 case BuiltInTessCoord:
14476 return "position_in_patch";
14477
14478 // Fragment function in
14479 case BuiltInFrontFacing:
14480 return "front_facing";
14481 case BuiltInPointCoord:
14482 return "point_coord";
14483 case BuiltInFragCoord:
14484 return "position";
14485 case BuiltInSampleId:
14486 return "sample_id";
14487 case BuiltInSampleMask:
14488 return "sample_mask";
14489 case BuiltInSamplePosition:
14490 // Shouldn't be reached.
14491 SPIRV_CROSS_THROW("Sample position is retrieved by a function in MSL.");
14492 case BuiltInViewIndex:
14493 if (execution.model != ExecutionModelFragment)
14494 SPIRV_CROSS_THROW("ViewIndex is handled specially outside fragment shaders.");
14495 // The ViewIndex was implicitly used in the prior stages to set the render_target_array_index,
14496 // so we can get it from there.
14497 return "render_target_array_index";
14498
14499 // Fragment function out
14500 case BuiltInFragDepth:
14501 if (execution.flags.get(ExecutionModeDepthGreater))
14502 return "depth(greater)";
14503 else if (execution.flags.get(ExecutionModeDepthLess))
14504 return "depth(less)";
14505 else
14506 return "depth(any)";
14507
14508 case BuiltInFragStencilRefEXT:
14509 return "stencil";
14510
14511 // Compute function in
14512 case BuiltInGlobalInvocationId:
14513 return "thread_position_in_grid";
14514
14515 case BuiltInWorkgroupId:
14516 return "threadgroup_position_in_grid";
14517
14518 case BuiltInNumWorkgroups:
14519 return "threadgroups_per_grid";
14520
14521 case BuiltInLocalInvocationId:
14522 return "thread_position_in_threadgroup";
14523
14524 case BuiltInLocalInvocationIndex:
14525 return "thread_index_in_threadgroup";
14526
14527 case BuiltInSubgroupSize:
14528 if (msl_options.emulate_subgroups || msl_options.fixed_subgroup_size != 0)
14529 // Shouldn't be reached.
14530 SPIRV_CROSS_THROW("Emitting threads_per_simdgroup attribute with fixed subgroup size??");
14531 if (execution.model == ExecutionModelFragment)
14532 {
14533 if (!msl_options.supports_msl_version(2, 2))
14534 SPIRV_CROSS_THROW("threads_per_simdgroup requires Metal 2.2 in fragment shaders.");
14535 return "threads_per_simdgroup";
14536 }
14537 else
14538 {
		// thread_execution_width is an alias for threads_per_simdgroup, and it has been
		// available since MSL 1.0, though not in fragment functions.
14541 return "thread_execution_width";
14542 }
14543
14544 case BuiltInNumSubgroups:
14545 if (msl_options.emulate_subgroups)
14546 // Shouldn't be reached.
14547 SPIRV_CROSS_THROW("NumSubgroups is handled specially with emulation.");
14548 if (!msl_options.supports_msl_version(2))
14549 SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0.");
14550 return msl_options.is_ios() ? "quadgroups_per_threadgroup" : "simdgroups_per_threadgroup";
14551
14552 case BuiltInSubgroupId:
14553 if (msl_options.emulate_subgroups)
14554 // Shouldn't be reached.
14555 SPIRV_CROSS_THROW("SubgroupId is handled specially with emulation.");
14556 if (!msl_options.supports_msl_version(2))
14557 SPIRV_CROSS_THROW("Subgroup builtins require Metal 2.0.");
14558 return msl_options.is_ios() ? "quadgroup_index_in_threadgroup" : "simdgroup_index_in_threadgroup";
14559
14560 case BuiltInSubgroupLocalInvocationId:
14561 if (msl_options.emulate_subgroups)
14562 // Shouldn't be reached.
14563 SPIRV_CROSS_THROW("SubgroupLocalInvocationId is handled specially with emulation.");
14564 if (execution.model == ExecutionModelFragment)
14565 {
14566 if (!msl_options.supports_msl_version(2, 2))
14567 SPIRV_CROSS_THROW("thread_index_in_simdgroup requires Metal 2.2 in fragment shaders.");
14568 return "thread_index_in_simdgroup";
14569 }
14570 else if (execution.model == ExecutionModelKernel || execution.model == ExecutionModelGLCompute ||
14571 execution.model == ExecutionModelTessellationControl ||
14572 (execution.model == ExecutionModelVertex && msl_options.vertex_for_tessellation))
14573 {
14574 // We are generating a Metal kernel function.
14575 if (!msl_options.supports_msl_version(2))
14576 SPIRV_CROSS_THROW("Subgroup builtins in kernel functions require Metal 2.0.");
14577 return msl_options.is_ios() ? "thread_index_in_quadgroup" : "thread_index_in_simdgroup";
14578 }
14579 else
14580 SPIRV_CROSS_THROW("Subgroup builtins are not available in this type of function.");
14581
14582 case BuiltInSubgroupEqMask:
14583 case BuiltInSubgroupGeMask:
14584 case BuiltInSubgroupGtMask:
14585 case BuiltInSubgroupLeMask:
14586 case BuiltInSubgroupLtMask:
14587 // Shouldn't be reached.
14588 SPIRV_CROSS_THROW("Subgroup ballot masks are handled specially in MSL.");
14589
14590 case BuiltInBaryCoordNV:
14591 // TODO: AMD barycentrics as well? Seem to have different swizzle and 2 components rather than 3.
14592 if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 3))
14593 SPIRV_CROSS_THROW("Barycentrics are only supported in MSL 2.3 and above on iOS.");
14594 else if (!msl_options.supports_msl_version(2, 2))
14595 SPIRV_CROSS_THROW("Barycentrics are only supported in MSL 2.2 and above on macOS.");
14596 return "barycentric_coord, center_perspective";
14597
14598 case BuiltInBaryCoordNoPerspNV:
14599 // TODO: AMD barycentrics as well? Seem to have different swizzle and 2 components rather than 3.
14600 if (msl_options.is_ios() && !msl_options.supports_msl_version(2, 3))
14601 SPIRV_CROSS_THROW("Barycentrics are only supported in MSL 2.3 and above on iOS.");
14602 else if (!msl_options.supports_msl_version(2, 2))
14603 SPIRV_CROSS_THROW("Barycentrics are only supported in MSL 2.2 and above on macOS.");
14604 return "barycentric_coord, center_no_perspective";
14605
14606 default:
14607 return "unsupported-built-in";
14608 }
14609}
14610
14611// Returns an MSL string type declaration for a SPIR-V builtin
14612string CompilerMSL::builtin_type_decl(BuiltIn builtin, uint32_t id)
14613{
14614 const SPIREntryPoint &execution = get_entry_point();
14615 switch (builtin)
14616 {
14617 // Vertex function in
	case BuiltInVertexId:
	case BuiltInVertexIndex:
	case BuiltInBaseVertex:
	case BuiltInInstanceId:
	case BuiltInInstanceIndex:
	case BuiltInBaseInstance:
		return "uint";
14630 case BuiltInDrawIndex:
14631 SPIRV_CROSS_THROW("DrawIndex is not supported in MSL.");
14632
14633 // Vertex function out
	case BuiltInClipDistance:
	case BuiltInCullDistance:
	case BuiltInPointSize:
		return "float";
14639 case BuiltInPosition:
14640 return "float4";
14641 case BuiltInLayer:
14642 return "uint";
14643 case BuiltInViewportIndex:
14644 if (!msl_options.supports_msl_version(2, 0))
14645 SPIRV_CROSS_THROW("ViewportIndex requires Metal 2.0.");
14646 return "uint";
14647
14648 // Tess. control function in
	case BuiltInInvocationId:
	case BuiltInPatchVertices:
	case BuiltInPrimitiveId:
		return "uint";
14655
14656 // Tess. control function out
14657 case BuiltInTessLevelInner:
14658 if (execution.model == ExecutionModelTessellationEvaluation)
14659 return !execution.flags.get(ExecutionModeTriangles) ? "float2" : "float";
14660 return "half";
14661 case BuiltInTessLevelOuter:
14662 if (execution.model == ExecutionModelTessellationEvaluation)
14663 return !execution.flags.get(ExecutionModeTriangles) ? "float4" : "float";
14664 return "half";
14665
14666 // Tess. evaluation function in
14667 case BuiltInTessCoord:
14668 return "float3";
14669
14670 // Fragment function in
14671 case BuiltInFrontFacing:
14672 return "bool";
14673 case BuiltInPointCoord:
14674 return "float2";
14675 case BuiltInFragCoord:
14676 return "float4";
	case BuiltInSampleId:
	case BuiltInSampleMask:
		return "uint";
14681 case BuiltInSamplePosition:
14682 return "float2";
14683 case BuiltInViewIndex:
14684 return "uint";
14685
14686 case BuiltInHelperInvocation:
14687 return "bool";
14688
14689 case BuiltInBaryCoordNV:
14690 case BuiltInBaryCoordNoPerspNV:
		// Use the type as declared; it can have 1, 2 or 3 components.
14692 return type_to_glsl(get_variable_data_type(get<SPIRVariable>(id)));
14693
14694 // Fragment function out
14695 case BuiltInFragDepth:
14696 return "float";
14697
14698 case BuiltInFragStencilRefEXT:
14699 return "uint";
14700
14701 // Compute function in
14702 case BuiltInGlobalInvocationId:
14703 case BuiltInLocalInvocationId:
14704 case BuiltInNumWorkgroups:
14705 case BuiltInWorkgroupId:
14706 return "uint3";
14707 case BuiltInLocalInvocationIndex:
14708 case BuiltInNumSubgroups:
14709 case BuiltInSubgroupId:
14710 case BuiltInSubgroupSize:
14711 case BuiltInSubgroupLocalInvocationId:
14712 return "uint";
14713 case BuiltInSubgroupEqMask:
14714 case BuiltInSubgroupGeMask:
14715 case BuiltInSubgroupGtMask:
14716 case BuiltInSubgroupLeMask:
14717 case BuiltInSubgroupLtMask:
14718 return "uint4";
14719
14720 case BuiltInDeviceIndex:
14721 return "int";
14722
14723 default:
14724 return "unsupported-built-in-type";
14725 }
14726}
14727
14728// Returns the declaration of a built-in argument to a function
14729string CompilerMSL::built_in_func_arg(BuiltIn builtin, bool prefix_comma)
14730{
14731 string bi_arg;
14732 if (prefix_comma)
14733 bi_arg += ", ";
14734
14735 // Handle HLSL-style 0-based vertex/instance index.
14736 builtin_declaration = true;
14737 bi_arg += builtin_type_decl(builtin);
14738 bi_arg += " " + builtin_to_glsl(builtin, StorageClassInput);
14739 bi_arg += " [[" + builtin_qualifier(builtin) + "]]";
14740 builtin_declaration = false;
14741
14742 return bi_arg;
14743}
14744
14745const SPIRType &CompilerMSL::get_physical_member_type(const SPIRType &type, uint32_t index) const
14746{
14747 if (member_is_remapped_physical_type(type, index))
14748 return get<SPIRType>(get_extended_member_decoration(type.self, index, SPIRVCrossDecorationPhysicalTypeID));
14749 else
14750 return get<SPIRType>(type.member_types[index]);
14751}
14752
14753SPIRType CompilerMSL::get_presumed_input_type(const SPIRType &ib_type, uint32_t index) const
14754{
14755 SPIRType type = get_physical_member_type(ib_type, index);
14756 uint32_t loc = get_member_decoration(ib_type.self, index, DecorationLocation);
14757 uint32_t cmp = get_member_decoration(ib_type.self, index, DecorationComponent);
14758 auto p_va = inputs_by_location.find({loc, cmp});
14759 if (p_va != end(inputs_by_location) && p_va->second.vecsize > type.vecsize)
14760 type.vecsize = p_va->second.vecsize;
14761
14762 return type;
14763}
14764
14765uint32_t CompilerMSL::get_declared_type_array_stride_msl(const SPIRType &type, bool is_packed, bool row_major) const
14766{
	// Array stride in MSL is always size * array_size, since sizeof(float3) == 16 in MSL,
	// unlike GLSL and HLSL, where the array stride would be 16 while the element size is 12.
14769
14770 // We could use parent type here and recurse, but that makes creating physical type remappings
14771 // far more complicated. We'd rather just create the final type, and ignore having to create the entire type
14772 // hierarchy in order to compute this value, so make a temporary type on the stack.
14773
14774 auto basic_type = type;
14775 basic_type.array.clear();
14776 basic_type.array_size_literal.clear();
14777 uint32_t value_size = get_declared_type_size_msl(basic_type, is_packed, row_major);
14778
14779 uint32_t dimensions = uint32_t(type.array.size());
14780 assert(dimensions > 0);
14781 dimensions--;
14782
14783 // Multiply together every dimension, except the last one.
14784 for (uint32_t dim = 0; dim < dimensions; dim++)
14785 {
14786 uint32_t array_size = to_array_size_literal(type, dim);
14787 value_size *= max(array_size, 1u);
14788 }
14789
14790 return value_size;
14791}
14792
14793uint32_t CompilerMSL::get_declared_struct_member_array_stride_msl(const SPIRType &type, uint32_t index) const
14794{
14795 return get_declared_type_array_stride_msl(get_physical_member_type(type, index),
14796 member_is_packed_physical_type(type, index),
14797 has_member_decoration(type.self, index, DecorationRowMajor));
14798}
14799
14800uint32_t CompilerMSL::get_declared_input_array_stride_msl(const SPIRType &type, uint32_t index) const
14801{
14802 return get_declared_type_array_stride_msl(get_presumed_input_type(type, index), false,
14803 has_member_decoration(type.self, index, DecorationRowMajor));
14804}
14805
14806uint32_t CompilerMSL::get_declared_type_matrix_stride_msl(const SPIRType &type, bool packed, bool row_major) const
14807{
14808 // For packed matrices, we just use the size of the vector type.
14809 // Otherwise, MatrixStride == alignment, which is the size of the underlying vector type.
14810 if (packed)
14811 return (type.width / 8) * ((row_major && type.columns > 1) ? type.columns : type.vecsize);
14812 else
14813 return get_declared_type_alignment_msl(type, false, row_major);
14814}
14815
14816uint32_t CompilerMSL::get_declared_struct_member_matrix_stride_msl(const SPIRType &type, uint32_t index) const
14817{
14818 return get_declared_type_matrix_stride_msl(get_physical_member_type(type, index),
14819 member_is_packed_physical_type(type, index),
14820 has_member_decoration(type.self, index, DecorationRowMajor));
14821}
14822
14823uint32_t CompilerMSL::get_declared_input_matrix_stride_msl(const SPIRType &type, uint32_t index) const
14824{
14825 return get_declared_type_matrix_stride_msl(get_presumed_input_type(type, index), false,
14826 has_member_decoration(type.self, index, DecorationRowMajor));
14827}
14828
14829uint32_t CompilerMSL::get_declared_struct_size_msl(const SPIRType &struct_type, bool ignore_alignment,
14830 bool ignore_padding) const
14831{
14832 // If we have a target size, that is the declared size as well.
14833 if (!ignore_padding && has_extended_decoration(struct_type.self, SPIRVCrossDecorationPaddingTarget))
14834 return get_extended_decoration(struct_type.self, SPIRVCrossDecorationPaddingTarget);
14835
14836 if (struct_type.member_types.empty())
14837 return 0;
14838
14839 uint32_t mbr_cnt = uint32_t(struct_type.member_types.size());
14840
14841 // In MSL, a struct's alignment is equal to the maximum alignment of any of its members.
14842 uint32_t alignment = 1;
14843
14844 if (!ignore_alignment)
14845 {
14846 for (uint32_t i = 0; i < mbr_cnt; i++)
14847 {
14848 uint32_t mbr_alignment = get_declared_struct_member_alignment_msl(struct_type, i);
14849 alignment = max(alignment, mbr_alignment);
14850 }
14851 }
14852
	// The last member is always matched to the final Offset decoration, but the size of the
	// struct in MSL depends on the physical MSL size of that member, and the struct size is
	// then rounded up to the struct alignment.
14855 uint32_t spirv_offset = type_struct_member_offset(struct_type, mbr_cnt - 1);
14856 uint32_t msl_size = spirv_offset + get_declared_struct_member_size_msl(struct_type, mbr_cnt - 1);
14857 msl_size = (msl_size + alignment - 1) & ~(alignment - 1);
14858 return msl_size;
14859}
14860
// Returns the byte size of a type as it is declared in MSL.
14862uint32_t CompilerMSL::get_declared_type_size_msl(const SPIRType &type, bool is_packed, bool row_major) const
14863{
14864 switch (type.basetype)
14865 {
14866 case SPIRType::Unknown:
14867 case SPIRType::Void:
14868 case SPIRType::AtomicCounter:
14869 case SPIRType::Image:
14870 case SPIRType::SampledImage:
14871 case SPIRType::Sampler:
14872 SPIRV_CROSS_THROW("Querying size of opaque object.");
14873
14874 default:
14875 {
14876 if (!type.array.empty())
14877 {
14878 uint32_t array_size = to_array_size_literal(type);
14879 return get_declared_type_array_stride_msl(type, is_packed, row_major) * max(array_size, 1u);
14880 }
14881
14882 if (type.basetype == SPIRType::Struct)
14883 return get_declared_struct_size_msl(type);
14884
14885 if (is_packed)
14886 {
14887 return type.vecsize * type.columns * (type.width / 8);
14888 }
14889 else
14890 {
14891 // An unpacked 3-element vector or matrix column is the same memory size as a 4-element.
14892 uint32_t vecsize = type.vecsize;
14893 uint32_t columns = type.columns;
14894
14895 if (row_major && columns > 1)
14896 swap(vecsize, columns);
14897
14898 if (vecsize == 3)
14899 vecsize = 4;
14900
14901 return vecsize * columns * (type.width / 8);
14902 }
14903 }
14904 }
14905}
14906
14907uint32_t CompilerMSL::get_declared_struct_member_size_msl(const SPIRType &type, uint32_t index) const
14908{
14909 return get_declared_type_size_msl(get_physical_member_type(type, index),
14910 member_is_packed_physical_type(type, index),
14911 has_member_decoration(type.self, index, DecorationRowMajor));
14912}
14913
14914uint32_t CompilerMSL::get_declared_input_size_msl(const SPIRType &type, uint32_t index) const
14915{
14916 return get_declared_type_size_msl(get_presumed_input_type(type, index), false,
14917 has_member_decoration(type.self, index, DecorationRowMajor));
14918}
14919
14920// Returns the byte alignment of a type.
14921uint32_t CompilerMSL::get_declared_type_alignment_msl(const SPIRType &type, bool is_packed, bool row_major) const
14922{
14923 switch (type.basetype)
14924 {
14925 case SPIRType::Unknown:
14926 case SPIRType::Void:
14927 case SPIRType::AtomicCounter:
14928 case SPIRType::Image:
14929 case SPIRType::SampledImage:
14930 case SPIRType::Sampler:
14931 SPIRV_CROSS_THROW("Querying alignment of opaque object.");
14932
14933 case SPIRType::Double:
14934 SPIRV_CROSS_THROW("double types are not supported in buffers in MSL.");
14935
14936 case SPIRType::Struct:
14937 {
14938 // In MSL, a struct's alignment is equal to the maximum alignment of any of its members.
14939 uint32_t alignment = 1;
14940 for (uint32_t i = 0; i < type.member_types.size(); i++)
14941 alignment = max(alignment, uint32_t(get_declared_struct_member_alignment_msl(type, i)));
14942 return alignment;
14943 }
14944
14945 default:
14946 {
14947 if (type.basetype == SPIRType::Int64 && !msl_options.supports_msl_version(2, 3))
14948 SPIRV_CROSS_THROW("long types in buffers are only supported in MSL 2.3 and above.");
14949 if (type.basetype == SPIRType::UInt64 && !msl_options.supports_msl_version(2, 3))
14950 SPIRV_CROSS_THROW("ulong types in buffers are only supported in MSL 2.3 and above.");
			// Alignment of a packed type is the same as its underlying component or column size.
			// Alignment of an unpacked type is the same as its vector size in bytes.
			// Alignment of a 3-element vector is the same as that of a 4-element vector
			// (this includes packed matrices, where the column is the alignment unit).
14954 if (is_packed)
14955 {
14956 // If we have packed_T and friends, the alignment is always scalar.
14957 return type.width / 8;
14958 }
14959 else
14960 {
14961 // This is the general rule for MSL. Size == alignment.
14962 uint32_t vecsize = (row_major && type.columns > 1) ? type.columns : type.vecsize;
14963 return (type.width / 8) * (vecsize == 3 ? 4 : vecsize);
14964 }
14965 }
14966 }
14967}
14968
14969uint32_t CompilerMSL::get_declared_struct_member_alignment_msl(const SPIRType &type, uint32_t index) const
14970{
14971 return get_declared_type_alignment_msl(get_physical_member_type(type, index),
14972 member_is_packed_physical_type(type, index),
14973 has_member_decoration(type.self, index, DecorationRowMajor));
14974}
14975
14976uint32_t CompilerMSL::get_declared_input_alignment_msl(const SPIRType &type, uint32_t index) const
14977{
14978 return get_declared_type_alignment_msl(get_presumed_input_type(type, index), false,
14979 has_member_decoration(type.self, index, DecorationRowMajor));
14980}
14981
14982bool CompilerMSL::skip_argument(uint32_t) const
14983{
14984 return false;
14985}
14986
14987void CompilerMSL::analyze_sampled_image_usage()
14988{
14989 if (msl_options.swizzle_texture_samples)
14990 {
14991 SampledImageScanner scanner(*this);
14992 traverse_all_reachable_opcodes(get<SPIRFunction>(ir.default_entry_point), scanner);
14993 }
14994}
14995
14996bool CompilerMSL::SampledImageScanner::handle(spv::Op opcode, const uint32_t *args, uint32_t length)
14997{
14998 switch (opcode)
14999 {
15000 case OpLoad:
15001 case OpImage:
15002 case OpSampledImage:
15003 {
15004 if (length < 3)
15005 return false;
15006
15007 uint32_t result_type = args[0];
15008 auto &type = compiler.get<SPIRType>(result_type);
15009 if ((type.basetype != SPIRType::Image && type.basetype != SPIRType::SampledImage) || type.image.sampled != 1)
15010 return true;
15011
15012 uint32_t id = args[1];
15013 compiler.set<SPIRExpression>(id, "", result_type, true);
15014 break;
15015 }
15016 case OpImageSampleExplicitLod:
15017 case OpImageSampleProjExplicitLod:
15018 case OpImageSampleDrefExplicitLod:
15019 case OpImageSampleProjDrefExplicitLod:
15020 case OpImageSampleImplicitLod:
15021 case OpImageSampleProjImplicitLod:
15022 case OpImageSampleDrefImplicitLod:
15023 case OpImageSampleProjDrefImplicitLod:
15024 case OpImageFetch:
15025 case OpImageGather:
15026 case OpImageDrefGather:
15027 compiler.has_sampled_images =
15028 compiler.has_sampled_images || compiler.is_sampled_image_type(compiler.expression_type(args[2]));
15029 compiler.needs_swizzle_buffer_def = compiler.needs_swizzle_buffer_def || compiler.has_sampled_images;
15030 break;
15031 default:
15032 break;
15033 }
15034 return true;
15035}
15036
15037// If a needed custom function wasn't added before, add it and force a recompile.
15038void CompilerMSL::add_spv_func_and_recompile(SPVFuncImpl spv_func)
15039{
15040 if (spv_function_implementations.count(spv_func) == 0)
15041 {
15042 spv_function_implementations.insert(spv_func);
15043 suppress_missing_prototypes = true;
15044 force_recompile();
15045 }
15046}
15047
15048bool CompilerMSL::OpCodePreprocessor::handle(Op opcode, const uint32_t *args, uint32_t length)
15049{
15050 // Since MSL exists in a single execution scope, function prototype declarations are not
15051 // needed, and clutter the output. If secondary functions are output (either as a SPIR-V
15052 // function implementation or as indicated by the presence of OpFunctionCall), then set
15053 // suppress_missing_prototypes to suppress compiler warnings of missing function prototypes.
15054
	// Mark if the input requires the implementation of a SPIR-V function that does not exist in Metal.
15056 SPVFuncImpl spv_func = get_spv_func_impl(opcode, args);
15057 if (spv_func != SPVFuncImplNone)
15058 {
15059 compiler.spv_function_implementations.insert(spv_func);
15060 suppress_missing_prototypes = true;
15061 }
15062
15063 switch (opcode)
15064 {
15065
15066 case OpFunctionCall:
15067 suppress_missing_prototypes = true;
15068 break;
15069
15070 // Emulate texture2D atomic operations
15071 case OpImageTexelPointer:
15072 {
15073 auto *var = compiler.maybe_get_backing_variable(args[2]);
15074 image_pointers[args[1]] = var ? var->self : ID(0);
15075 break;
15076 }
15077
15078 case OpImageWrite:
15079 if (!compiler.msl_options.supports_msl_version(2, 2))
15080 uses_resource_write = true;
15081 break;
15082
15083 case OpStore:
15084 check_resource_write(args[0]);
15085 break;
15086
15087 // Emulate texture2D atomic operations
15088 case OpAtomicExchange:
15089 case OpAtomicCompareExchange:
15090 case OpAtomicCompareExchangeWeak:
15091 case OpAtomicIIncrement:
15092 case OpAtomicIDecrement:
15093 case OpAtomicIAdd:
15094 case OpAtomicISub:
15095 case OpAtomicSMin:
15096 case OpAtomicUMin:
15097 case OpAtomicSMax:
15098 case OpAtomicUMax:
15099 case OpAtomicAnd:
15100 case OpAtomicOr:
15101 case OpAtomicXor:
15102 {
15103 uses_atomics = true;
15104 auto it = image_pointers.find(args[2]);
15105 if (it != image_pointers.end())
15106 {
15107 compiler.atomic_image_vars.insert(it->second);
15108 }
15109 check_resource_write(args[2]);
15110 break;
15111 }
15112
15113 case OpAtomicStore:
15114 {
15115 uses_atomics = true;
15116 auto it = image_pointers.find(args[0]);
15117 if (it != image_pointers.end())
15118 {
15119 compiler.atomic_image_vars.insert(it->second);
15120 }
15121 check_resource_write(args[0]);
15122 break;
15123 }
15124
15125 case OpAtomicLoad:
15126 {
15127 uses_atomics = true;
15128 auto it = image_pointers.find(args[2]);
15129 if (it != image_pointers.end())
15130 {
15131 compiler.atomic_image_vars.insert(it->second);
15132 }
15133 break;
15134 }
15135
15136 case OpGroupNonUniformInverseBallot:
15137 needs_subgroup_invocation_id = true;
15138 break;
15139
15140 case OpGroupNonUniformBallotFindLSB:
15141 case OpGroupNonUniformBallotFindMSB:
15142 needs_subgroup_size = true;
15143 break;
15144
15145 case OpGroupNonUniformBallotBitCount:
15146 if (args[3] == GroupOperationReduce)
15147 needs_subgroup_size = true;
15148 else
15149 needs_subgroup_invocation_id = true;
15150 break;
15151
15152 case OpArrayLength:
15153 {
15154 auto *var = compiler.maybe_get_backing_variable(args[2]);
15155 if (var)
15156 compiler.buffers_requiring_array_length.insert(var->self);
15157 break;
15158 }
15159
15160 case OpInBoundsAccessChain:
15161 case OpAccessChain:
15162 case OpPtrAccessChain:
15163 {
		// OpArrayLength might need to know whether we are taking the ArrayLength of an array of SSBOs.
15165 uint32_t result_type = args[0];
15166 uint32_t id = args[1];
15167 uint32_t ptr = args[2];
15168
15169 compiler.set<SPIRExpression>(id, "", result_type, true);
15170 compiler.register_read(id, ptr, true);
15171 compiler.ir.ids[id].set_allow_type_rewrite();
15172 break;
15173 }
15174
15175 case OpExtInst:
15176 {
15177 uint32_t extension_set = args[2];
15178 if (compiler.get<SPIRExtension>(extension_set).ext == SPIRExtension::GLSL)
15179 {
15180 auto op_450 = static_cast<GLSLstd450>(args[3]);
15181 switch (op_450)
15182 {
15183 case GLSLstd450InterpolateAtCentroid:
15184 case GLSLstd450InterpolateAtSample:
15185 case GLSLstd450InterpolateAtOffset:
15186 {
15187 if (!compiler.msl_options.supports_msl_version(2, 3))
15188 SPIRV_CROSS_THROW("Pull-model interpolation requires MSL 2.3.");
15189 // Fragment varyings used with pull-model interpolation need special handling,
15190 // due to the way pull-model interpolation works in Metal.
15191 auto *var = compiler.maybe_get_backing_variable(args[4]);
15192 if (var)
15193 {
15194 compiler.pull_model_inputs.insert(var->self);
15195 auto &var_type = compiler.get_variable_element_type(*var);
15196 // In addition, if this variable has a 'Sample' decoration, we need the sample ID
15197 // in order to do default interpolation.
15198 if (compiler.has_decoration(var->self, DecorationSample))
15199 {
15200 needs_sample_id = true;
15201 }
15202 else if (var_type.basetype == SPIRType::Struct)
15203 {
15204 // Now we need to check each member and see if it has this decoration.
15205 for (uint32_t i = 0; i < var_type.member_types.size(); ++i)
15206 {
15207 if (compiler.has_member_decoration(var_type.self, i, DecorationSample))
15208 {
15209 needs_sample_id = true;
15210 break;
15211 }
15212 }
15213 }
15214 }
15215 break;
15216 }
15217 default:
15218 break;
15219 }
15220 }
15221 break;
15222 }
15223
15224 default:
15225 break;
15226 }
15227
15228 // If it has one, keep track of the instruction's result type, mapped by ID
15229 uint32_t result_type, result_id;
15230 if (compiler.instruction_to_result_type(result_type, result_id, opcode, args, length))
15231 result_types[result_id] = result_type;
15232
15233 return true;
15234}
15235
15236// If the variable is a Uniform or StorageBuffer, mark that a resource has been written to.
15237void CompilerMSL::OpCodePreprocessor::check_resource_write(uint32_t var_id)
15238{
15239 auto *p_var = compiler.maybe_get_backing_variable(var_id);
15240 StorageClass sc = p_var ? p_var->storage : StorageClassMax;
15241 if (!compiler.msl_options.supports_msl_version(2, 1) &&
15242 (sc == StorageClassUniform || sc == StorageClassStorageBuffer))
15243 uses_resource_write = true;
15244}
15245
15246// Returns an enumeration of a SPIR-V function that needs to be output for certain Op codes.
15247CompilerMSL::SPVFuncImpl CompilerMSL::OpCodePreprocessor::get_spv_func_impl(Op opcode, const uint32_t *args)
15248{
15249 switch (opcode)
15250 {
15251 case OpFMod:
15252 return SPVFuncImplMod;
15253
15254 case OpFAdd:
15255 case OpFSub:
15256 if (compiler.msl_options.invariant_float_math ||
15257 compiler.has_decoration(args[1], DecorationNoContraction))
15258 {
15259 return opcode == OpFAdd ? SPVFuncImplFAdd : SPVFuncImplFSub;
15260 }
15261 break;
15262
15263 case OpFMul:
15264 case OpOuterProduct:
15265 case OpMatrixTimesVector:
15266 case OpVectorTimesMatrix:
15267 case OpMatrixTimesMatrix:
15268 if (compiler.msl_options.invariant_float_math ||
15269 compiler.has_decoration(args[1], DecorationNoContraction))
15270 {
15271 return SPVFuncImplFMul;
15272 }
15273 break;
15274
15275 case OpQuantizeToF16:
15276 return SPVFuncImplQuantizeToF16;
15277
15278 case OpTypeArray:
15279 {
15280 // Allow Metal to use the array<T> template to make arrays a value type
15281 return SPVFuncImplUnsafeArray;
15282 }
15283
15284 // Emulate texture2D atomic operations
15285 case OpAtomicExchange:
15286 case OpAtomicCompareExchange:
15287 case OpAtomicCompareExchangeWeak:
15288 case OpAtomicIIncrement:
15289 case OpAtomicIDecrement:
15290 case OpAtomicIAdd:
15291 case OpAtomicISub:
15292 case OpAtomicSMin:
15293 case OpAtomicUMin:
15294 case OpAtomicSMax:
15295 case OpAtomicUMax:
15296 case OpAtomicAnd:
15297 case OpAtomicOr:
15298 case OpAtomicXor:
15299 case OpAtomicLoad:
15300 case OpAtomicStore:
15301 {
15302 auto it = image_pointers.find(args[opcode == OpAtomicStore ? 0 : 2]);
15303 if (it != image_pointers.end())
15304 {
15305 uint32_t tid = compiler.get<SPIRVariable>(it->second).basetype;
15306 if (tid && compiler.get<SPIRType>(tid).image.dim == Dim2D)
15307 return SPVFuncImplImage2DAtomicCoords;
15308 }
15309 break;
15310 }
15311
15312 case OpImageFetch:
15313 case OpImageRead:
15314 case OpImageWrite:
15315 {
15316 // Retrieve the image type, and if it's a Buffer, emit a texel coordinate function
15317 uint32_t tid = result_types[args[opcode == OpImageWrite ? 0 : 2]];
15318 if (tid && compiler.get<SPIRType>(tid).image.dim == DimBuffer && !compiler.msl_options.texture_buffer_native)
15319 return SPVFuncImplTexelBufferCoords;
15320 break;
15321 }
15322
15323 case OpExtInst:
15324 {
15325 uint32_t extension_set = args[2];
15326 if (compiler.get<SPIRExtension>(extension_set).ext == SPIRExtension::GLSL)
15327 {
15328 auto op_450 = static_cast<GLSLstd450>(args[3]);
15329 switch (op_450)
15330 {
15331 case GLSLstd450Radians:
15332 return SPVFuncImplRadians;
15333 case GLSLstd450Degrees:
15334 return SPVFuncImplDegrees;
15335 case GLSLstd450FindILsb:
15336 return SPVFuncImplFindILsb;
15337 case GLSLstd450FindSMsb:
15338 return SPVFuncImplFindSMsb;
15339 case GLSLstd450FindUMsb:
15340 return SPVFuncImplFindUMsb;
15341 case GLSLstd450SSign:
15342 return SPVFuncImplSSign;
15343 case GLSLstd450Reflect:
15344 {
15345 auto &type = compiler.get<SPIRType>(args[0]);
15346 if (type.vecsize == 1)
15347 return SPVFuncImplReflectScalar;
15348 break;
15349 }
15350 case GLSLstd450Refract:
15351 {
15352 auto &type = compiler.get<SPIRType>(args[0]);
15353 if (type.vecsize == 1)
15354 return SPVFuncImplRefractScalar;
15355 break;
15356 }
15357 case GLSLstd450FaceForward:
15358 {
15359 auto &type = compiler.get<SPIRType>(args[0]);
15360 if (type.vecsize == 1)
15361 return SPVFuncImplFaceForwardScalar;
15362 break;
15363 }
15364 case GLSLstd450MatrixInverse:
15365 {
15366 auto &mat_type = compiler.get<SPIRType>(args[0]);
15367 switch (mat_type.columns)
15368 {
15369 case 2:
15370 return SPVFuncImplInverse2x2;
15371 case 3:
15372 return SPVFuncImplInverse3x3;
15373 case 4:
15374 return SPVFuncImplInverse4x4;
15375 default:
15376 break;
15377 }
15378 break;
15379 }
15380 default:
15381 break;
15382 }
15383 }
15384 break;
15385 }
15386
15387 case OpGroupNonUniformBroadcast:
15388 return SPVFuncImplSubgroupBroadcast;
15389
15390 case OpGroupNonUniformBroadcastFirst:
15391 return SPVFuncImplSubgroupBroadcastFirst;
15392
15393 case OpGroupNonUniformBallot:
15394 return SPVFuncImplSubgroupBallot;
15395
15396 case OpGroupNonUniformInverseBallot:
15397 case OpGroupNonUniformBallotBitExtract:
15398 return SPVFuncImplSubgroupBallotBitExtract;
15399
15400 case OpGroupNonUniformBallotFindLSB:
15401 return SPVFuncImplSubgroupBallotFindLSB;
15402
15403 case OpGroupNonUniformBallotFindMSB:
15404 return SPVFuncImplSubgroupBallotFindMSB;
15405
15406 case OpGroupNonUniformBallotBitCount:
15407 return SPVFuncImplSubgroupBallotBitCount;
15408
15409 case OpGroupNonUniformAllEqual:
15410 return SPVFuncImplSubgroupAllEqual;
15411
15412 case OpGroupNonUniformShuffle:
15413 return SPVFuncImplSubgroupShuffle;
15414
15415 case OpGroupNonUniformShuffleXor:
15416 return SPVFuncImplSubgroupShuffleXor;
15417
15418 case OpGroupNonUniformShuffleUp:
15419 return SPVFuncImplSubgroupShuffleUp;
15420
15421 case OpGroupNonUniformShuffleDown:
15422 return SPVFuncImplSubgroupShuffleDown;
15423
15424 case OpGroupNonUniformQuadBroadcast:
15425 return SPVFuncImplQuadBroadcast;
15426
15427 case OpGroupNonUniformQuadSwap:
15428 return SPVFuncImplQuadSwap;
15429
15430 default:
15431 break;
15432 }
15433 return SPVFuncImplNone;
15434}
15435
15436// Sort both type and meta member content based on builtin status (put builtins at end),
15437// then by the required sorting aspect.
15438void CompilerMSL::MemberSorter::sort()
15439{
15440 // Create a temporary array of consecutive member indices and sort it based on how
15441 // the members should be reordered, based on builtin and sorting aspect meta info.
15442 size_t mbr_cnt = type.member_types.size();
15443 SmallVector<uint32_t> mbr_idxs(mbr_cnt);
15444 std::iota(mbr_idxs.begin(), mbr_idxs.end(), 0); // Fill with consecutive indices
15445 std::stable_sort(mbr_idxs.begin(), mbr_idxs.end(), *this); // Sort member indices based on sorting aspect
15446
15447 bool sort_is_identity = true;
15448 for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
15449 {
15450 if (mbr_idx != mbr_idxs[mbr_idx])
15451 {
15452 sort_is_identity = false;
15453 break;
15454 }
15455 }
15456
15457 if (sort_is_identity)
15458 return;
15459
15460 if (meta.members.size() < type.member_types.size())
15461 {
		// This should never trigger in normal circumstances, but resize to be safe.
15463 meta.members.resize(type.member_types.size());
15464 }
15465
15466 // Move type and meta member info to the order defined by the sorted member indices.
15467 // This is done by creating temporary copies of both member types and meta, and then
15468 // copying back to the original content at the sorted indices.
15469 auto mbr_types_cpy = type.member_types;
15470 auto mbr_meta_cpy = meta.members;
15471 for (uint32_t mbr_idx = 0; mbr_idx < mbr_cnt; mbr_idx++)
15472 {
15473 type.member_types[mbr_idx] = mbr_types_cpy[mbr_idxs[mbr_idx]];
15474 meta.members[mbr_idx] = mbr_meta_cpy[mbr_idxs[mbr_idx]];
15475 }
15476
15477 // If we're sorting by Offset, this might affect user code which accesses a buffer block.
15478 // We will need to redirect member indices from defined index to sorted index using reverse lookup.
15479 if (sort_aspect == SortAspect::Offset)
15480 {
15481 type.member_type_index_redirection.resize(mbr_cnt);
15482 for (uint32_t map_idx = 0; map_idx < mbr_cnt; map_idx++)
15483 type.member_type_index_redirection[mbr_idxs[map_idx]] = map_idx;
15484 }
15485}
15486
15487bool CompilerMSL::MemberSorter::operator()(uint32_t mbr_idx1, uint32_t mbr_idx2)
15488{
15489 auto &mbr_meta1 = meta.members[mbr_idx1];
15490 auto &mbr_meta2 = meta.members[mbr_idx2];
15491
15492 if (sort_aspect == LocationThenBuiltInType)
15493 {
15494 // Sort first by builtin status (put builtins at end), then by the sorting aspect.
15495 if (mbr_meta1.builtin != mbr_meta2.builtin)
15496 return mbr_meta2.builtin;
15497 else if (mbr_meta1.builtin)
15498 return mbr_meta1.builtin_type < mbr_meta2.builtin_type;
15499 else if (mbr_meta1.location == mbr_meta2.location)
15500 return mbr_meta1.component < mbr_meta2.component;
15501 else
15502 return mbr_meta1.location < mbr_meta2.location;
15503 }
15504 else
15505 return mbr_meta1.offset < mbr_meta2.offset;
15506}
15507
15508CompilerMSL::MemberSorter::MemberSorter(SPIRType &t, Meta &m, SortAspect sa)
15509 : type(t)
15510 , meta(m)
15511 , sort_aspect(sa)
15512{
15513 // Ensure enough meta info is available
15514 meta.members.resize(max(type.member_types.size(), meta.members.size()));
15515}
15516
15517void CompilerMSL::remap_constexpr_sampler(VariableID id, const MSLConstexprSampler &sampler)
15518{
15519 auto &type = get<SPIRType>(get<SPIRVariable>(id).basetype);
15520 if (type.basetype != SPIRType::SampledImage && type.basetype != SPIRType::Sampler)
15521 SPIRV_CROSS_THROW("Can only remap SampledImage and Sampler type.");
15522 if (!type.array.empty())
15523 SPIRV_CROSS_THROW("Can not remap array of samplers.");
15524 constexpr_samplers_by_id[id] = sampler;
15525}
15526
15527void CompilerMSL::remap_constexpr_sampler_by_binding(uint32_t desc_set, uint32_t binding,
15528 const MSLConstexprSampler &sampler)
15529{
15530 constexpr_samplers_by_binding[{ desc_set, binding }] = sampler;
15531}
15532
15533void CompilerMSL::cast_from_variable_load(uint32_t source_id, std::string &expr, const SPIRType &expr_type)
15534{
15535 auto *var = maybe_get_backing_variable(source_id);
15536 if (var)
15537 source_id = var->self;
15538
15539 // Type fixups for workgroup variables if they are booleans.
15540 if (var && var->storage == StorageClassWorkgroup && expr_type.basetype == SPIRType::Boolean)
15541 expr = join(type_to_glsl(expr_type), "(", expr, ")");
15542
15543 // Only interested in standalone builtin variables.
15544 if (!has_decoration(source_id, DecorationBuiltIn))
15545 return;
15546
15547 auto builtin = static_cast<BuiltIn>(get_decoration(source_id, DecorationBuiltIn));
15548 auto expected_type = expr_type.basetype;
15549 auto expected_width = expr_type.width;
15550 switch (builtin)
15551 {
15552 case BuiltInGlobalInvocationId:
15553 case BuiltInLocalInvocationId:
15554 case BuiltInWorkgroupId:
15555 case BuiltInLocalInvocationIndex:
15556 case BuiltInWorkgroupSize:
15557 case BuiltInNumWorkgroups:
15558 case BuiltInLayer:
15559 case BuiltInViewportIndex:
15560 case BuiltInFragStencilRefEXT:
15561 case BuiltInPrimitiveId:
15562 case BuiltInSubgroupSize:
15563 case BuiltInSubgroupLocalInvocationId:
15564 case BuiltInViewIndex:
15565 case BuiltInVertexIndex:
15566 case BuiltInInstanceIndex:
15567 case BuiltInBaseInstance:
15568 case BuiltInBaseVertex:
15569 expected_type = SPIRType::UInt;
15570 expected_width = 32;
15571 break;
15572
15573 case BuiltInTessLevelInner:
15574 case BuiltInTessLevelOuter:
15575 if (get_execution_model() == ExecutionModelTessellationControl)
15576 {
15577 expected_type = SPIRType::Half;
15578 expected_width = 16;
15579 }
15580 break;
15581
15582 default:
15583 break;
15584 }
15585
15586 if (expected_type != expr_type.basetype)
15587 {
15588 if (!expr_type.array.empty() && (builtin == BuiltInTessLevelInner || builtin == BuiltInTessLevelOuter))
15589 {
15590 // Triggers when loading TessLevel directly as an array.
15591 // Need explicit padding + cast.
15592 auto wrap_expr = join(type_to_glsl(expr_type), "({ ");
15593
15594 uint32_t array_size = get_physical_tess_level_array_size(builtin);
15595 for (uint32_t i = 0; i < array_size; i++)
15596 {
15597 if (array_size > 1)
15598 wrap_expr += join("float(", expr, "[", i, "])");
15599 else
15600 wrap_expr += join("float(", expr, ")");
15601 if (i + 1 < array_size)
15602 wrap_expr += ", ";
15603 }
15604
15605 if (get_execution_mode_bitset().get(ExecutionModeTriangles))
15606 wrap_expr += ", 0.0";
15607
15608 wrap_expr += " })";
15609 expr = std::move(wrap_expr);
15610 }
15611 else
15612 {
15613 // These are of different widths, so we cannot do a straight bitcast.
15614 if (expected_width != expr_type.width)
15615 expr = join(type_to_glsl(expr_type), "(", expr, ")");
15616 else
15617 expr = bitcast_expression(expr_type, expected_type, expr);
15618 }
15619 }
15620}
15621
15622void CompilerMSL::cast_to_variable_store(uint32_t target_id, std::string &expr, const SPIRType &expr_type)
15623{
15624 auto *var = maybe_get_backing_variable(target_id);
15625 if (var)
15626 target_id = var->self;
15627
15628 // Type fixups for workgroup variables if they are booleans.
15629 if (var && var->storage == StorageClassWorkgroup && expr_type.basetype == SPIRType::Boolean)
15630 {
15631 auto short_type = expr_type;
15632 short_type.basetype = SPIRType::Short;
15633 expr = join(type_to_glsl(short_type), "(", expr, ")");
15634 }
15635
15636 // Only interested in standalone builtin variables.
15637 if (!has_decoration(target_id, DecorationBuiltIn))
15638 return;
15639
15640 auto builtin = static_cast<BuiltIn>(get_decoration(target_id, DecorationBuiltIn));
15641 auto expected_type = expr_type.basetype;
15642 auto expected_width = expr_type.width;
15643 switch (builtin)
15644 {
15645 case BuiltInLayer:
15646 case BuiltInViewportIndex:
15647 case BuiltInFragStencilRefEXT:
15648 case BuiltInPrimitiveId:
15649 case BuiltInViewIndex:
15650 expected_type = SPIRType::UInt;
15651 expected_width = 32;
15652 break;
15653
15654 case BuiltInTessLevelInner:
15655 case BuiltInTessLevelOuter:
15656 expected_type = SPIRType::Half;
15657 expected_width = 16;
15658 break;
15659
15660 default:
15661 break;
15662 }
15663
15664 if (expected_type != expr_type.basetype)
15665 {
15666 if (expected_width != expr_type.width)
15667 {
15668 // These are of different widths, so we cannot do a straight bitcast.
15669 auto type = expr_type;
15670 type.basetype = expected_type;
15671 type.width = expected_width;
15672 expr = join(type_to_glsl(type), "(", expr, ")");
15673 }
15674 else
15675 {
15676 auto type = expr_type;
15677 type.basetype = expected_type;
15678 expr = bitcast_expression(type, expr_type.basetype, expr);
15679 }
15680 }
15681}
15682
15683string CompilerMSL::to_initializer_expression(const SPIRVariable &var)
15684{
	// With MSL, we risk getting an array initializer here if the variable is an array.
	// FIXME: We cannot handle non-constant arrays being initialized.
	// We will need to inject spvArrayCopy here somehow ...
15688 auto &type = get<SPIRType>(var.basetype);
15689 string expr;
15690 if (ir.ids[var.initializer].get_type() == TypeConstant &&
15691 (!type.array.empty() || type.basetype == SPIRType::Struct))
15692 expr = constant_expression(get<SPIRConstant>(var.initializer));
15693 else
15694 expr = CompilerGLSL::to_initializer_expression(var);
15695 // If the initializer has more vector components than the variable, add a swizzle.
15696 // FIXME: This can't handle arrays or structs.
15697 auto &init_type = expression_type(var.initializer);
15698 if (type.array.empty() && type.basetype != SPIRType::Struct && init_type.vecsize > type.vecsize)
15699 expr = enclose_expression(expr + vector_swizzle(type.vecsize, 0));
15700 return expr;
15701}
15702
15703string CompilerMSL::to_zero_initialized_expression(uint32_t)
15704{
15705 return "{}";
15706}
15707
15708bool CompilerMSL::descriptor_set_is_argument_buffer(uint32_t desc_set) const
15709{
15710 if (!msl_options.argument_buffers)
15711 return false;
15712 if (desc_set >= kMaxArgumentBuffers)
15713 return false;
15714
15715 return (argument_buffer_discrete_mask & (1u << desc_set)) == 0;
15716}
15717
15718bool CompilerMSL::is_supported_argument_buffer_type(const SPIRType &type) const
15719{
	// Very specifically, image load-store in argument buffers is disallowed by MSL on iOS.
	// But we won't know when the argument buffer is encoded whether this image will have
	// a NonWritable decoration, so just use discrete arguments for all storage images
	// on iOS.
15724 bool is_storage_image = type.basetype == SPIRType::Image && type.image.sampled == 2;
15725 bool is_supported_type = !msl_options.is_ios() || !is_storage_image;
15726 return !type_is_msl_framebuffer_fetch(type) && is_supported_type;
15727}
15728
15729void CompilerMSL::analyze_argument_buffers()
15730{
15731 // Gather all used resources and sort them out into argument buffers.
15732 // Each argument buffer corresponds to a descriptor set in SPIR-V.
	// The [[id(N)]] values used correspond to the resource mapping we have for MSL.
	// Otherwise, the binding number is used, but this is generally not safe for some types,
	// like combined image samplers and arrays of resources. Metal needs different indices here,
	// while SPIR-V can have one descriptor set binding. To use argument buffers in practice,
	// you will need to use the remapping from the API.
15738 for (auto &id : argument_buffer_ids)
15739 id = 0;
15740
15741 // Output resources, sorted by resource index & type.
15742 struct Resource
15743 {
15744 SPIRVariable *var;
15745 string name;
15746 SPIRType::BaseType basetype;
15747 uint32_t index;
15748 uint32_t plane;
15749 };
15750 SmallVector<Resource> resources_in_set[kMaxArgumentBuffers];
15751 SmallVector<uint32_t> inline_block_vars;
15752
15753 bool set_needs_swizzle_buffer[kMaxArgumentBuffers] = {};
15754 bool set_needs_buffer_sizes[kMaxArgumentBuffers] = {};
15755 bool needs_buffer_sizes = false;
15756
15757 ir.for_each_typed_id<SPIRVariable>([&](uint32_t self, SPIRVariable &var) {
15758 if ((var.storage == StorageClassUniform || var.storage == StorageClassUniformConstant ||
15759 var.storage == StorageClassStorageBuffer) &&
15760 !is_hidden_variable(var))
15761 {
15762 uint32_t desc_set = get_decoration(self, DecorationDescriptorSet);
			// Ignore descriptor sets which are not promoted to argument buffers
			// (e.g. discrete or push descriptor sets).
15764 if (!descriptor_set_is_argument_buffer(desc_set))
15765 return;
15766
15767 uint32_t var_id = var.self;
15768 auto &type = get_variable_data_type(var);
15769
15770 if (desc_set >= kMaxArgumentBuffers)
15771 SPIRV_CROSS_THROW("Descriptor set index is out of range.");
15772
15773 const MSLConstexprSampler *constexpr_sampler = nullptr;
15774 if (type.basetype == SPIRType::SampledImage || type.basetype == SPIRType::Sampler)
15775 {
15776 constexpr_sampler = find_constexpr_sampler(var_id);
15777 if (constexpr_sampler)
15778 {
15779 // Mark this ID as a constexpr sampler for later in case it came from set/bindings.
15780 constexpr_samplers_by_id[var_id] = *constexpr_sampler;
15781 }
15782 }
15783
15784 uint32_t binding = get_decoration(var_id, DecorationBinding);
15785 if (type.basetype == SPIRType::SampledImage)
15786 {
15787 add_resource_name(var_id);
15788
15789 uint32_t plane_count = 1;
15790 if (constexpr_sampler && constexpr_sampler->ycbcr_conversion_enable)
15791 plane_count = constexpr_sampler->planes;
15792
15793 for (uint32_t i = 0; i < plane_count; i++)
15794 {
15795 uint32_t image_resource_index = get_metal_resource_index(var, SPIRType::Image, i);
15796 resources_in_set[desc_set].push_back(
15797 { &var, to_name(var_id), SPIRType::Image, image_resource_index, i });
15798 }
15799
15800 if (type.image.dim != DimBuffer && !constexpr_sampler)
15801 {
15802 uint32_t sampler_resource_index = get_metal_resource_index(var, SPIRType::Sampler);
15803 resources_in_set[desc_set].push_back(
15804 { &var, to_sampler_expression(var_id), SPIRType::Sampler, sampler_resource_index, 0 });
15805 }
15806 }
15807 else if (inline_uniform_blocks.count(SetBindingPair{ desc_set, binding }))
15808 {
15809 inline_block_vars.push_back(var_id);
15810 }
15811 else if (!constexpr_sampler && is_supported_argument_buffer_type(type))
15812 {
15813 // constexpr samplers are not declared as resources.
15814 // Inline uniform blocks are always emitted at the end.
15815 add_resource_name(var_id);
15816 resources_in_set[desc_set].push_back(
15817 { &var, to_name(var_id), type.basetype, get_metal_resource_index(var, type.basetype), 0 });
15818
15819 // Emulate texture2D atomic operations
15820 if (atomic_image_vars.count(var.self))
15821 {
15822 uint32_t buffer_resource_index = get_metal_resource_index(var, SPIRType::AtomicCounter, 0);
15823 resources_in_set[desc_set].push_back(
15824 { &var, to_name(var_id) + "_atomic", SPIRType::Struct, buffer_resource_index, 0 });
15825 }
15826 }
15827
15828 // Check if this descriptor set needs a swizzle buffer.
15829 if (needs_swizzle_buffer_def && is_sampled_image_type(type))
15830 set_needs_swizzle_buffer[desc_set] = true;
15831 else if (buffers_requiring_array_length.count(var_id) != 0)
15832 {
15833 set_needs_buffer_sizes[desc_set] = true;
15834 needs_buffer_sizes = true;
15835 }
15836 }
15837 });
15838
15839 if (needs_swizzle_buffer_def || needs_buffer_sizes)
15840 {
15841 uint32_t uint_ptr_type_id = 0;
15842
15843 // We might have to add a swizzle buffer resource to the set.
15844 for (uint32_t desc_set = 0; desc_set < kMaxArgumentBuffers; desc_set++)
15845 {
15846 if (!set_needs_swizzle_buffer[desc_set] && !set_needs_buffer_sizes[desc_set])
15847 continue;
15848
15849 if (uint_ptr_type_id == 0)
15850 {
15851 uint_ptr_type_id = ir.increase_bound_by(1);
15852
15853 // Create a buffer to hold extra data, including the swizzle constants.
15854 SPIRType uint_type_pointer = get_uint_type();
15855 uint_type_pointer.pointer = true;
15856 uint_type_pointer.pointer_depth++;
15857 uint_type_pointer.parent_type = get_uint_type_id();
15858 uint_type_pointer.storage = StorageClassUniform;
15859 set<SPIRType>(uint_ptr_type_id, uint_type_pointer);
15860 set_decoration(uint_ptr_type_id, DecorationArrayStride, 4);
15861 }
15862
15863 if (set_needs_swizzle_buffer[desc_set])
15864 {
15865 uint32_t var_id = ir.increase_bound_by(1);
15866 auto &var = set<SPIRVariable>(var_id, uint_ptr_type_id, StorageClassUniformConstant);
15867 set_name(var_id, "spvSwizzleConstants");
15868 set_decoration(var_id, DecorationDescriptorSet, desc_set);
15869 set_decoration(var_id, DecorationBinding, kSwizzleBufferBinding);
15870 resources_in_set[desc_set].push_back(
15871 { &var, to_name(var_id), SPIRType::UInt, get_metal_resource_index(var, SPIRType::UInt), 0 });
15872 }
15873
15874 if (set_needs_buffer_sizes[desc_set])
15875 {
15876 uint32_t var_id = ir.increase_bound_by(1);
15877 auto &var = set<SPIRVariable>(var_id, uint_ptr_type_id, StorageClassUniformConstant);
15878 set_name(var_id, "spvBufferSizeConstants");
15879 set_decoration(var_id, DecorationDescriptorSet, desc_set);
15880 set_decoration(var_id, DecorationBinding, kBufferSizeBufferBinding);
15881 resources_in_set[desc_set].push_back(
15882 { &var, to_name(var_id), SPIRType::UInt, get_metal_resource_index(var, SPIRType::UInt), 0 });
15883 }
15884 }
15885 }
15886
15887 // Now add inline uniform blocks.
15888 for (uint32_t var_id : inline_block_vars)
15889 {
15890 auto &var = get<SPIRVariable>(var_id);
15891 uint32_t desc_set = get_decoration(var_id, DecorationDescriptorSet);
15892 add_resource_name(var_id);
15893 resources_in_set[desc_set].push_back(
15894 { &var, to_name(var_id), SPIRType::Struct, get_metal_resource_index(var, SPIRType::Struct), 0 });
15895 }
15896
15897 for (uint32_t desc_set = 0; desc_set < kMaxArgumentBuffers; desc_set++)
15898 {
15899 auto &resources = resources_in_set[desc_set];
15900 if (resources.empty())
15901 continue;
15902
15903 assert(descriptor_set_is_argument_buffer(desc_set));
15904
15905 uint32_t next_id = ir.increase_bound_by(3);
15906 uint32_t type_id = next_id + 1;
15907 uint32_t ptr_type_id = next_id + 2;
15908 argument_buffer_ids[desc_set] = next_id;
15909
15910 auto &buffer_type = set<SPIRType>(type_id);
15911
15912 buffer_type.basetype = SPIRType::Struct;
15913
15914 if ((argument_buffer_device_storage_mask & (1u << desc_set)) != 0)
15915 {
15916 buffer_type.storage = StorageClassStorageBuffer;
15917 // Make sure the argument buffer gets marked as const device.
15918 set_decoration(next_id, DecorationNonWritable);
15919 // Need to mark the type as a Block to enable this.
15920 set_decoration(type_id, DecorationBlock);
15921 }
15922 else
15923 buffer_type.storage = StorageClassUniform;
15924
15925 set_name(type_id, join("spvDescriptorSetBuffer", desc_set));
15926
15927 auto &ptr_type = set<SPIRType>(ptr_type_id);
15928 ptr_type = buffer_type;
15929 ptr_type.pointer = true;
15930 ptr_type.pointer_depth++;
15931 ptr_type.parent_type = type_id;
15932
15933 uint32_t buffer_variable_id = next_id;
15934 set<SPIRVariable>(buffer_variable_id, ptr_type_id, StorageClassUniform);
15935 set_name(buffer_variable_id, join("spvDescriptorSet", desc_set));
15936
		// IDs must be emitted in ID order.
15938 sort(begin(resources), end(resources), [&](const Resource &lhs, const Resource &rhs) -> bool {
15939 return tie(lhs.index, lhs.basetype) < tie(rhs.index, rhs.basetype);
15940 });
15941
15942 uint32_t member_index = 0;
15943 uint32_t next_arg_buff_index = 0;
15944 for (auto &resource : resources)
15945 {
15946 auto &var = *resource.var;
15947 auto &type = get_variable_data_type(var);
15948
15949 // If needed, synthesize and add padding members.
15950 // member_index and next_arg_buff_index are incremented when padding members are added.
15951 if (msl_options.pad_argument_buffer_resources)
15952 {
15953 while (resource.index > next_arg_buff_index)
15954 {
15955 auto &rez_bind = get_argument_buffer_resource(desc_set, next_arg_buff_index);
15956 switch (rez_bind.basetype)
15957 {
15958 case SPIRType::Void:
15959 case SPIRType::Boolean:
15960 case SPIRType::SByte:
15961 case SPIRType::UByte:
15962 case SPIRType::Short:
15963 case SPIRType::UShort:
15964 case SPIRType::Int:
15965 case SPIRType::UInt:
15966 case SPIRType::Int64:
15967 case SPIRType::UInt64:
15968 case SPIRType::AtomicCounter:
15969 case SPIRType::Half:
15970 case SPIRType::Float:
15971 case SPIRType::Double:
15972 add_argument_buffer_padding_buffer_type(buffer_type, member_index, next_arg_buff_index, rez_bind);
15973 break;
15974 case SPIRType::Image:
15975 add_argument_buffer_padding_image_type(buffer_type, member_index, next_arg_buff_index, rez_bind);
15976 break;
15977 case SPIRType::Sampler:
15978 add_argument_buffer_padding_sampler_type(buffer_type, member_index, next_arg_buff_index, rez_bind);
15979 break;
15980 case SPIRType::SampledImage:
15981 if (next_arg_buff_index == rez_bind.msl_sampler)
15982 add_argument_buffer_padding_sampler_type(buffer_type, member_index, next_arg_buff_index, rez_bind);
15983 else
15984 add_argument_buffer_padding_image_type(buffer_type, member_index, next_arg_buff_index, rez_bind);
15985 break;
15986 default:
15987 break;
15988 }
15989 }
15990
			// Adjust the number of slots consumed by the current member itself.
			// If the actual member is an array, allow runtime array resolution as well.
15993 uint32_t elem_cnt = type.array.empty() ? 1 : to_array_size_literal(type);
15994 if (elem_cnt == 0)
15995 elem_cnt = get_resource_array_size(var.self);
15996
15997 next_arg_buff_index += elem_cnt;
15998 }
15999
16000 string mbr_name = ensure_valid_name(resource.name, "m");
16001 if (resource.plane > 0)
16002 mbr_name += join(plane_name_suffix, resource.plane);
16003 set_member_name(buffer_type.self, member_index, mbr_name);
16004
16005 if (resource.basetype == SPIRType::Sampler && type.basetype != SPIRType::Sampler)
16006 {
16007 // Have to synthesize a sampler type here.
16008
16009 bool type_is_array = !type.array.empty();
16010 uint32_t sampler_type_id = ir.increase_bound_by(type_is_array ? 2 : 1);
16011 auto &new_sampler_type = set<SPIRType>(sampler_type_id);
16012 new_sampler_type.basetype = SPIRType::Sampler;
16013 new_sampler_type.storage = StorageClassUniformConstant;
16014
16015 if (type_is_array)
16016 {
16017 uint32_t sampler_type_array_id = sampler_type_id + 1;
16018 auto &sampler_type_array = set<SPIRType>(sampler_type_array_id);
16019 sampler_type_array = new_sampler_type;
16020 sampler_type_array.array = type.array;
16021 sampler_type_array.array_size_literal = type.array_size_literal;
16022 sampler_type_array.parent_type = sampler_type_id;
16023 buffer_type.member_types.push_back(sampler_type_array_id);
16024 }
16025 else
16026 buffer_type.member_types.push_back(sampler_type_id);
16027 }
16028 else
16029 {
16030 uint32_t binding = get_decoration(var.self, DecorationBinding);
16031 SetBindingPair pair = { desc_set, binding };
16032
16033 if (resource.basetype == SPIRType::Image || resource.basetype == SPIRType::Sampler ||
16034 resource.basetype == SPIRType::SampledImage)
16035 {
16036 // Drop pointer information when we emit the resources into a struct.
16037 buffer_type.member_types.push_back(get_variable_data_type_id(var));
16038 if (resource.plane == 0)
16039 set_qualified_name(var.self, join(to_name(buffer_variable_id), ".", mbr_name));
16040 }
16041 else if (buffers_requiring_dynamic_offset.count(pair))
16042 {
16043 // Don't set the qualified name here; we'll define a variable holding the corrected buffer address later.
16044 buffer_type.member_types.push_back(var.basetype);
16045 buffers_requiring_dynamic_offset[pair].second = var.self;
16046 }
16047 else if (inline_uniform_blocks.count(pair))
16048 {
16049 // Put the buffer block itself into the argument buffer.
16050 buffer_type.member_types.push_back(get_variable_data_type_id(var));
16051 set_qualified_name(var.self, join(to_name(buffer_variable_id), ".", mbr_name));
16052 }
16053 else if (atomic_image_vars.count(var.self))
16054 {
16055 // Emulate texture2D atomic operations.
16056 // Don't set the qualified name: it's already set for this variable,
16057 // and the code that references the buffer manually appends "_atomic"
16058 // to the name.
16059 uint32_t offset = ir.increase_bound_by(2);
16060 uint32_t atomic_type_id = offset;
16061 uint32_t type_ptr_id = offset + 1;
16062
16063 SPIRType atomic_type;
16064 atomic_type.basetype = SPIRType::AtomicCounter;
16065 atomic_type.width = 32;
16066 atomic_type.vecsize = 1;
16067 set<SPIRType>(atomic_type_id, atomic_type);
16068
16069 atomic_type.pointer = true;
16070 atomic_type.pointer_depth++;
16071 atomic_type.parent_type = atomic_type_id;
16072 atomic_type.storage = StorageClassStorageBuffer;
16073 auto &atomic_ptr_type = set<SPIRType>(type_ptr_id, atomic_type);
16074 atomic_ptr_type.self = atomic_type_id;
16075
16076 buffer_type.member_types.push_back(type_ptr_id);
16077 }
16078 else
16079 {
16080 // Resources will be declared as pointers not references, so automatically dereference as appropriate.
16081 buffer_type.member_types.push_back(var.basetype);
16082 if (type.array.empty())
16083 set_qualified_name(var.self, join("(*", to_name(buffer_variable_id), ".", mbr_name, ")"));
16084 else
16085 set_qualified_name(var.self, join(to_name(buffer_variable_id), ".", mbr_name));
16086 }
16087 }
16088
16089 set_extended_member_decoration(buffer_type.self, member_index, SPIRVCrossDecorationResourceIndexPrimary,
16090 resource.index);
16091 set_extended_member_decoration(buffer_type.self, member_index, SPIRVCrossDecorationInterfaceOrigID,
16092 var.self);
16093 member_index++;
16094 }
16095 }
16096}
16097
// Returns the resource binding of the app-provided resource in this descriptor set
// whose resource index matches the given argument buffer index.
// This is a two-step lookup: first look up the resource binding number from the argument
// buffer index, then look up the resource binding using that binding number.
16102MSLResourceBinding &CompilerMSL::get_argument_buffer_resource(uint32_t desc_set, uint32_t arg_idx)
16103{
16104 auto stage = get_entry_point().model;
16105 StageSetBinding arg_idx_tuple = { stage, desc_set, arg_idx };
16106 auto arg_itr = resource_arg_buff_idx_to_binding_number.find(arg_idx_tuple);
16107 if (arg_itr != end(resource_arg_buff_idx_to_binding_number))
16108 {
16109 StageSetBinding bind_tuple = { stage, desc_set, arg_itr->second };
16110 auto bind_itr = resource_bindings.find(bind_tuple);
16111 if (bind_itr != end(resource_bindings))
16112 return bind_itr->second.first;
16113 }
16114 SPIRV_CROSS_THROW("Argument buffer resource base type could not be determined. When padding argument buffer "
16115 "elements, all descriptor set resources must be supplied with a base type by the app.");
16116}
16117
// Adds one or more argument buffer padding buffer members to the struct type at the member index.
// Metal does not support arrays of buffers, so these are emitted as multiple struct members.
16120void CompilerMSL::add_argument_buffer_padding_buffer_type(SPIRType &struct_type, uint32_t &mbr_idx,
16121 uint32_t &arg_buff_index, MSLResourceBinding &rez_bind)
16122{
16123 if (!argument_buffer_padding_buffer_type_id)
16124 {
16125 uint32_t buff_type_id = ir.increase_bound_by(2);
16126 auto &buff_type = set<SPIRType>(buff_type_id);
16127 buff_type.basetype = rez_bind.basetype;
16128 buff_type.storage = StorageClassUniformConstant;
16129
16130 uint32_t ptr_type_id = buff_type_id + 1;
16131 auto &ptr_type = set<SPIRType>(ptr_type_id);
16132 ptr_type = buff_type;
16133 ptr_type.pointer = true;
16134 ptr_type.pointer_depth++;
16135 ptr_type.parent_type = buff_type_id;
16136
16137 argument_buffer_padding_buffer_type_id = ptr_type_id;
16138 }
16139
16140 for (uint32_t rez_idx = 0; rez_idx < rez_bind.count; rez_idx++)
16141 add_argument_buffer_padding_type(argument_buffer_padding_buffer_type_id, struct_type, mbr_idx, arg_buff_index, 1);
16142}
16143
// Adds an argument buffer padding image member to the struct type at the member index.
16145void CompilerMSL::add_argument_buffer_padding_image_type(SPIRType &struct_type, uint32_t &mbr_idx,
16146 uint32_t &arg_buff_index, MSLResourceBinding &rez_bind)
16147{
16148 if (!argument_buffer_padding_image_type_id)
16149 {
16150 uint32_t base_type_id = ir.increase_bound_by(2);
16151 auto &base_type = set<SPIRType>(base_type_id);
16152 base_type.basetype = SPIRType::Float;
16153 base_type.width = 32;
16154
16155 uint32_t img_type_id = base_type_id + 1;
16156 auto &img_type = set<SPIRType>(img_type_id);
16157 img_type.basetype = SPIRType::Image;
16158 img_type.storage = StorageClassUniformConstant;
16159
16160 img_type.image.type = base_type_id;
16161 img_type.image.dim = Dim2D;
16162 img_type.image.depth = false;
16163 img_type.image.arrayed = false;
16164 img_type.image.ms = false;
16165 img_type.image.sampled = 1;
16166 img_type.image.format = ImageFormatUnknown;
16167 img_type.image.access = AccessQualifierMax;
16168
16169 argument_buffer_padding_image_type_id = img_type_id;
16170 }
16171
16172 add_argument_buffer_padding_type(argument_buffer_padding_image_type_id, struct_type, mbr_idx, arg_buff_index, rez_bind.count);
16173}
16174
// Adds an argument buffer padding sampler member to the struct type at the member index.
16176void CompilerMSL::add_argument_buffer_padding_sampler_type(SPIRType &struct_type, uint32_t &mbr_idx,
16177 uint32_t &arg_buff_index, MSLResourceBinding &rez_bind)
16178{
16179 if (!argument_buffer_padding_sampler_type_id)
16180 {
16181 uint32_t samp_type_id = ir.increase_bound_by(1);
16182 auto &samp_type = set<SPIRType>(samp_type_id);
16183 samp_type.basetype = SPIRType::Sampler;
16184 samp_type.storage = StorageClassUniformConstant;
16185
16186 argument_buffer_padding_sampler_type_id = samp_type_id;
16187 }
16188
16189 add_argument_buffer_padding_type(argument_buffer_padding_sampler_type_id, struct_type, mbr_idx, arg_buff_index, rez_bind.count);
16190}
16191
// Adds the given padding type as a member of the struct type at the member index.
// Advances both arg_buff_index and mbr_idx to the next argument slots.
16194void CompilerMSL::add_argument_buffer_padding_type(uint32_t mbr_type_id, SPIRType &struct_type, uint32_t &mbr_idx,
16195 uint32_t &arg_buff_index, uint32_t count)
16196{
16197 uint32_t type_id = mbr_type_id;
16198 if (count > 1)
16199 {
16200 uint32_t ary_type_id = ir.increase_bound_by(1);
16201 auto &ary_type = set<SPIRType>(ary_type_id);
16202 ary_type = get<SPIRType>(type_id);
16203 ary_type.array.push_back(count);
16204 ary_type.array_size_literal.push_back(true);
16205 ary_type.parent_type = type_id;
16206 type_id = ary_type_id;
16207 }
16208
16209 set_member_name(struct_type.self, mbr_idx, join("_m", arg_buff_index, "_pad"));
16210 set_extended_member_decoration(struct_type.self, mbr_idx, SPIRVCrossDecorationResourceIndexPrimary, arg_buff_index);
16211 struct_type.member_types.push_back(type_id);
16212
16213 arg_buff_index += count;
16214 mbr_idx++;
16215}
16216
16217void CompilerMSL::activate_argument_buffer_resources()
16218{
16219 // For ABI compatibility, force-enable all resources which are part of argument buffers.
16220 ir.for_each_typed_id<SPIRVariable>([&](uint32_t self, const SPIRVariable &) {
16221 if (!has_decoration(self, DecorationDescriptorSet))
16222 return;
16223
16224 uint32_t desc_set = get_decoration(self, DecorationDescriptorSet);
16225 if (descriptor_set_is_argument_buffer(desc_set))
16226 active_interface_variables.insert(self);
16227 });
16228}
16229
16230bool CompilerMSL::using_builtin_array() const
16231{
16232 return msl_options.force_native_arrays || is_using_builtin_array;
16233}
16234
16235void CompilerMSL::set_combined_sampler_suffix(const char *suffix)
16236{
16237 sampler_name_suffix = suffix;
16238}
16239
16240const char *CompilerMSL::get_combined_sampler_suffix() const
16241{
16242 return sampler_name_suffix.c_str();
16243}
16244
16245void CompilerMSL::emit_block_hints(const SPIRBlock &)
16246{
16247}
16248
16249string CompilerMSL::additional_fixed_sample_mask_str() const
16250{
	char print_buffer[32];
	snprintf(print_buffer, sizeof(print_buffer), "0x%x", msl_options.additional_fixed_sample_mask);
16253 return print_buffer;
16254}
16255