1/*
2 * Copyright 2016-2021 Robert Konrad
3 * SPDX-License-Identifier: Apache-2.0 OR MIT
4 *
5 * Licensed under the Apache License, Version 2.0 (the "License");
6 * you may not use this file except in compliance with the License.
7 * You may obtain a copy of the License at
8 *
9 * http://www.apache.org/licenses/LICENSE-2.0
10 *
11 * Unless required by applicable law or agreed to in writing, software
12 * distributed under the License is distributed on an "AS IS" BASIS,
13 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14 * See the License for the specific language governing permissions and
15 * limitations under the License.
16 *
17 */
18
19/*
20 * At your option, you may choose to accept this material under either:
21 * 1. The Apache License, Version 2.0, found at <http://www.apache.org/licenses/LICENSE-2.0>, or
22 * 2. The MIT License, found at <http://opensource.org/licenses/MIT>.
23 */
24
25#include "spirv_hlsl.hpp"
26#include "GLSL.std.450.h"
27#include <algorithm>
28#include <assert.h>
29
30using namespace spv;
31using namespace SPIRV_CROSS_NAMESPACE;
32using namespace std;
33
34enum class ImageFormatNormalizedState
35{
36 None = 0,
37 Unorm = 1,
38 Snorm = 2
39};
40
41static ImageFormatNormalizedState image_format_to_normalized_state(ImageFormat fmt)
42{
43 switch (fmt)
44 {
45 case ImageFormatR8:
46 case ImageFormatR16:
47 case ImageFormatRg8:
48 case ImageFormatRg16:
49 case ImageFormatRgba8:
50 case ImageFormatRgba16:
51 case ImageFormatRgb10A2:
52 return ImageFormatNormalizedState::Unorm;
53
54 case ImageFormatR8Snorm:
55 case ImageFormatR16Snorm:
56 case ImageFormatRg8Snorm:
57 case ImageFormatRg16Snorm:
58 case ImageFormatRgba8Snorm:
59 case ImageFormatRgba16Snorm:
60 return ImageFormatNormalizedState::Snorm;
61
62 default:
63 break;
64 }
65
66 return ImageFormatNormalizedState::None;
67}
68
69static unsigned image_format_to_components(ImageFormat fmt)
70{
71 switch (fmt)
72 {
73 case ImageFormatR8:
74 case ImageFormatR16:
75 case ImageFormatR8Snorm:
76 case ImageFormatR16Snorm:
77 case ImageFormatR16f:
78 case ImageFormatR32f:
79 case ImageFormatR8i:
80 case ImageFormatR16i:
81 case ImageFormatR32i:
82 case ImageFormatR8ui:
83 case ImageFormatR16ui:
84 case ImageFormatR32ui:
85 return 1;
86
87 case ImageFormatRg8:
88 case ImageFormatRg16:
89 case ImageFormatRg8Snorm:
90 case ImageFormatRg16Snorm:
91 case ImageFormatRg16f:
92 case ImageFormatRg32f:
93 case ImageFormatRg8i:
94 case ImageFormatRg16i:
95 case ImageFormatRg32i:
96 case ImageFormatRg8ui:
97 case ImageFormatRg16ui:
98 case ImageFormatRg32ui:
99 return 2;
100
101 case ImageFormatR11fG11fB10f:
102 return 3;
103
104 case ImageFormatRgba8:
105 case ImageFormatRgba16:
106 case ImageFormatRgb10A2:
107 case ImageFormatRgba8Snorm:
108 case ImageFormatRgba16Snorm:
109 case ImageFormatRgba16f:
110 case ImageFormatRgba32f:
111 case ImageFormatRgba8i:
112 case ImageFormatRgba16i:
113 case ImageFormatRgba32i:
114 case ImageFormatRgba8ui:
115 case ImageFormatRgba16ui:
116 case ImageFormatRgba32ui:
117 case ImageFormatRgb10a2ui:
118 return 4;
119
120 case ImageFormatUnknown:
121 return 4; // Assume 4.
122
123 default:
124 SPIRV_CROSS_THROW("Unrecognized typed image format.");
125 }
126}
127
128static string image_format_to_type(ImageFormat fmt, SPIRType::BaseType basetype)
129{
130 switch (fmt)
131 {
132 case ImageFormatR8:
133 case ImageFormatR16:
134 if (basetype != SPIRType::Float)
135 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
136 return "unorm float";
137 case ImageFormatRg8:
138 case ImageFormatRg16:
139 if (basetype != SPIRType::Float)
140 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
141 return "unorm float2";
142 case ImageFormatRgba8:
143 case ImageFormatRgba16:
144 if (basetype != SPIRType::Float)
145 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
146 return "unorm float4";
147 case ImageFormatRgb10A2:
148 if (basetype != SPIRType::Float)
149 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
150 return "unorm float4";
151
152 case ImageFormatR8Snorm:
153 case ImageFormatR16Snorm:
154 if (basetype != SPIRType::Float)
155 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
156 return "snorm float";
157 case ImageFormatRg8Snorm:
158 case ImageFormatRg16Snorm:
159 if (basetype != SPIRType::Float)
160 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
161 return "snorm float2";
162 case ImageFormatRgba8Snorm:
163 case ImageFormatRgba16Snorm:
164 if (basetype != SPIRType::Float)
165 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
166 return "snorm float4";
167
168 case ImageFormatR16f:
169 case ImageFormatR32f:
170 if (basetype != SPIRType::Float)
171 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
172 return "float";
173 case ImageFormatRg16f:
174 case ImageFormatRg32f:
175 if (basetype != SPIRType::Float)
176 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
177 return "float2";
178 case ImageFormatRgba16f:
179 case ImageFormatRgba32f:
180 if (basetype != SPIRType::Float)
181 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
182 return "float4";
183
184 case ImageFormatR11fG11fB10f:
185 if (basetype != SPIRType::Float)
186 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
187 return "float3";
188
189 case ImageFormatR8i:
190 case ImageFormatR16i:
191 case ImageFormatR32i:
192 if (basetype != SPIRType::Int)
193 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
194 return "int";
195 case ImageFormatRg8i:
196 case ImageFormatRg16i:
197 case ImageFormatRg32i:
198 if (basetype != SPIRType::Int)
199 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
200 return "int2";
201 case ImageFormatRgba8i:
202 case ImageFormatRgba16i:
203 case ImageFormatRgba32i:
204 if (basetype != SPIRType::Int)
205 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
206 return "int4";
207
208 case ImageFormatR8ui:
209 case ImageFormatR16ui:
210 case ImageFormatR32ui:
211 if (basetype != SPIRType::UInt)
212 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
213 return "uint";
214 case ImageFormatRg8ui:
215 case ImageFormatRg16ui:
216 case ImageFormatRg32ui:
217 if (basetype != SPIRType::UInt)
218 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
219 return "uint2";
220 case ImageFormatRgba8ui:
221 case ImageFormatRgba16ui:
222 case ImageFormatRgba32ui:
223 if (basetype != SPIRType::UInt)
224 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
225 return "uint4";
226 case ImageFormatRgb10a2ui:
227 if (basetype != SPIRType::UInt)
228 SPIRV_CROSS_THROW("Mismatch in image type and base type of image.");
229 return "uint4";
230
231 case ImageFormatUnknown:
232 switch (basetype)
233 {
234 case SPIRType::Float:
235 return "float4";
236 case SPIRType::Int:
237 return "int4";
238 case SPIRType::UInt:
239 return "uint4";
240 default:
241 SPIRV_CROSS_THROW("Unsupported base type for image.");
242 }
243
244 default:
245 SPIRV_CROSS_THROW("Unrecognized typed image format.");
246 }
247}
248
249string CompilerHLSL::image_type_hlsl_modern(const SPIRType &type, uint32_t id)
250{
251 auto &imagetype = get<SPIRType>(type.image.type);
252 const char *dim = nullptr;
253 bool typed_load = false;
254 uint32_t components = 4;
255
256 bool force_image_srv = hlsl_options.nonwritable_uav_texture_as_srv && has_decoration(id, DecorationNonWritable);
257
258 switch (type.image.dim)
259 {
260 case Dim1D:
261 typed_load = type.image.sampled == 2;
262 dim = "1D";
263 break;
264 case Dim2D:
265 typed_load = type.image.sampled == 2;
266 dim = "2D";
267 break;
268 case Dim3D:
269 typed_load = type.image.sampled == 2;
270 dim = "3D";
271 break;
272 case DimCube:
273 if (type.image.sampled == 2)
274 SPIRV_CROSS_THROW("RWTextureCube does not exist in HLSL.");
275 dim = "Cube";
276 break;
277 case DimRect:
278 SPIRV_CROSS_THROW("Rectangle texture support is not yet implemented for HLSL."); // TODO
279 case DimBuffer:
280 if (type.image.sampled == 1)
281 return join("Buffer<", type_to_glsl(imagetype), components, ">");
282 else if (type.image.sampled == 2)
283 {
284 if (interlocked_resources.count(id))
285 return join("RasterizerOrderedBuffer<", image_format_to_type(type.image.format, imagetype.basetype),
286 ">");
287
288 typed_load = !force_image_srv && type.image.sampled == 2;
289
290 const char *rw = force_image_srv ? "" : "RW";
291 return join(rw, "Buffer<",
292 typed_load ? image_format_to_type(type.image.format, imagetype.basetype) :
293 join(type_to_glsl(imagetype), components),
294 ">");
295 }
296 else
297 SPIRV_CROSS_THROW("Sampler buffers must be either sampled or unsampled. Cannot deduce in runtime.");
298 case DimSubpassData:
299 dim = "2D";
300 typed_load = false;
301 break;
302 default:
303 SPIRV_CROSS_THROW("Invalid dimension.");
304 }
305 const char *arrayed = type.image.arrayed ? "Array" : "";
306 const char *ms = type.image.ms ? "MS" : "";
307 const char *rw = typed_load && !force_image_srv ? "RW" : "";
308
309 if (force_image_srv)
310 typed_load = false;
311
312 if (typed_load && interlocked_resources.count(id))
313 rw = "RasterizerOrdered";
314
315 return join(rw, "Texture", dim, ms, arrayed, "<",
316 typed_load ? image_format_to_type(type.image.format, imagetype.basetype) :
317 join(type_to_glsl(imagetype), components),
318 ">");
319}
320
321string CompilerHLSL::image_type_hlsl_legacy(const SPIRType &type, uint32_t /*id*/)
322{
323 auto &imagetype = get<SPIRType>(type.image.type);
324 string res;
325
326 switch (imagetype.basetype)
327 {
328 case SPIRType::Int:
329 res = "i";
330 break;
331 case SPIRType::UInt:
332 res = "u";
333 break;
334 default:
335 break;
336 }
337
338 if (type.basetype == SPIRType::Image && type.image.dim == DimSubpassData)
339 return res + "subpassInput" + (type.image.ms ? "MS" : "");
340
341 // If we're emulating subpassInput with samplers, force sampler2D
342 // so we don't have to specify format.
343 if (type.basetype == SPIRType::Image && type.image.dim != DimSubpassData)
344 {
345 // Sampler buffers are always declared as samplerBuffer even though they might be separate images in the SPIR-V.
346 if (type.image.dim == DimBuffer && type.image.sampled == 1)
347 res += "sampler";
348 else
349 res += type.image.sampled == 2 ? "image" : "texture";
350 }
351 else
352 res += "sampler";
353
354 switch (type.image.dim)
355 {
356 case Dim1D:
357 res += "1D";
358 break;
359 case Dim2D:
360 res += "2D";
361 break;
362 case Dim3D:
363 res += "3D";
364 break;
365 case DimCube:
366 res += "CUBE";
367 break;
368
369 case DimBuffer:
370 res += "Buffer";
371 break;
372
373 case DimSubpassData:
374 res += "2D";
375 break;
376 default:
377 SPIRV_CROSS_THROW("Only 1D, 2D, 3D, Buffer, InputTarget and Cube textures supported.");
378 }
379
380 if (type.image.ms)
381 res += "MS";
382 if (type.image.arrayed)
383 res += "Array";
384
385 return res;
386}
387
388string CompilerHLSL::image_type_hlsl(const SPIRType &type, uint32_t id)
389{
390 if (hlsl_options.shader_model <= 30)
391 return image_type_hlsl_legacy(type, id);
392 else
393 return image_type_hlsl_modern(type, id);
394}
395
396// The optional id parameter indicates the object whose type we are trying
397// to find the description for. It is optional. Most type descriptions do not
398// depend on a specific object's use of that type.
399string CompilerHLSL::type_to_glsl(const SPIRType &type, uint32_t id)
400{
401 // Ignore the pointer type since GLSL doesn't have pointers.
402
403 switch (type.basetype)
404 {
405 case SPIRType::Struct:
406 // Need OpName lookup here to get a "sensible" name for a struct.
407 if (backend.explicit_struct_type)
408 return join("struct ", to_name(type.self));
409 else
410 return to_name(type.self);
411
412 case SPIRType::Image:
413 case SPIRType::SampledImage:
414 return image_type_hlsl(type, id);
415
416 case SPIRType::Sampler:
417 return comparison_ids.count(id) ? "SamplerComparisonState" : "SamplerState";
418
419 case SPIRType::Void:
420 return "void";
421
422 default:
423 break;
424 }
425
426 if (type.vecsize == 1 && type.columns == 1) // Scalar builtin
427 {
428 switch (type.basetype)
429 {
430 case SPIRType::Boolean:
431 return "bool";
432 case SPIRType::Int:
433 return backend.basic_int_type;
434 case SPIRType::UInt:
435 return backend.basic_uint_type;
436 case SPIRType::AtomicCounter:
437 return "atomic_uint";
438 case SPIRType::Half:
439 if (hlsl_options.enable_16bit_types)
440 return "half";
441 else
442 return "min16float";
443 case SPIRType::Short:
444 if (hlsl_options.enable_16bit_types)
445 return "int16_t";
446 else
447 return "min16int";
448 case SPIRType::UShort:
449 if (hlsl_options.enable_16bit_types)
450 return "uint16_t";
451 else
452 return "min16uint";
453 case SPIRType::Float:
454 return "float";
455 case SPIRType::Double:
456 return "double";
457 case SPIRType::Int64:
458 if (hlsl_options.shader_model < 60)
459 SPIRV_CROSS_THROW("64-bit integers only supported in SM 6.0.");
460 return "int64_t";
461 case SPIRType::UInt64:
462 if (hlsl_options.shader_model < 60)
463 SPIRV_CROSS_THROW("64-bit integers only supported in SM 6.0.");
464 return "uint64_t";
465 default:
466 return "???";
467 }
468 }
469 else if (type.vecsize > 1 && type.columns == 1) // Vector builtin
470 {
471 switch (type.basetype)
472 {
473 case SPIRType::Boolean:
474 return join("bool", type.vecsize);
475 case SPIRType::Int:
476 return join("int", type.vecsize);
477 case SPIRType::UInt:
478 return join("uint", type.vecsize);
479 case SPIRType::Half:
480 return join(hlsl_options.enable_16bit_types ? "half" : "min16float", type.vecsize);
481 case SPIRType::Short:
482 return join(hlsl_options.enable_16bit_types ? "int16_t" : "min16int", type.vecsize);
483 case SPIRType::UShort:
484 return join(hlsl_options.enable_16bit_types ? "uint16_t" : "min16uint", type.vecsize);
485 case SPIRType::Float:
486 return join("float", type.vecsize);
487 case SPIRType::Double:
488 return join("double", type.vecsize);
489 case SPIRType::Int64:
490 return join("i64vec", type.vecsize);
491 case SPIRType::UInt64:
492 return join("u64vec", type.vecsize);
493 default:
494 return "???";
495 }
496 }
497 else
498 {
499 switch (type.basetype)
500 {
501 case SPIRType::Boolean:
502 return join("bool", type.columns, "x", type.vecsize);
503 case SPIRType::Int:
504 return join("int", type.columns, "x", type.vecsize);
505 case SPIRType::UInt:
506 return join("uint", type.columns, "x", type.vecsize);
507 case SPIRType::Half:
508 return join(hlsl_options.enable_16bit_types ? "half" : "min16float", type.columns, "x", type.vecsize);
509 case SPIRType::Short:
510 return join(hlsl_options.enable_16bit_types ? "int16_t" : "min16int", type.columns, "x", type.vecsize);
511 case SPIRType::UShort:
512 return join(hlsl_options.enable_16bit_types ? "uint16_t" : "min16uint", type.columns, "x", type.vecsize);
513 case SPIRType::Float:
514 return join("float", type.columns, "x", type.vecsize);
515 case SPIRType::Double:
516 return join("double", type.columns, "x", type.vecsize);
517 // Matrix types not supported for int64/uint64.
518 default:
519 return "???";
520 }
521 }
522}
523
524void CompilerHLSL::emit_header()
525{
526 for (auto &header : header_lines)
527 statement(header);
528
529 if (header_lines.size() > 0)
530 {
531 statement("");
532 }
533}
534
535void CompilerHLSL::emit_interface_block_globally(const SPIRVariable &var)
536{
537 add_resource_name(var.self);
538
539 // The global copies of I/O variables should not contain interpolation qualifiers.
540 // These are emitted inside the interface structs.
541 auto &flags = ir.meta[var.self].decoration.decoration_flags;
542 auto old_flags = flags;
543 flags.reset();
544 statement("static ", variable_decl(var), ";");
545 flags = old_flags;
546}
547
548const char *CompilerHLSL::to_storage_qualifiers_glsl(const SPIRVariable &var)
549{
550 // Input and output variables are handled specially in HLSL backend.
551 // The variables are declared as global, private variables, and do not need any qualifiers.
552 if (var.storage == StorageClassUniformConstant || var.storage == StorageClassUniform ||
553 var.storage == StorageClassPushConstant)
554 {
555 return "uniform ";
556 }
557
558 return "";
559}
560
561void CompilerHLSL::emit_builtin_outputs_in_struct()
562{
563 auto &execution = get_entry_point();
564
565 bool legacy = hlsl_options.shader_model <= 30;
566 active_output_builtins.for_each_bit([&](uint32_t i) {
567 const char *type = nullptr;
568 const char *semantic = nullptr;
569 auto builtin = static_cast<BuiltIn>(i);
570 switch (builtin)
571 {
572 case BuiltInPosition:
573 type = is_position_invariant() && backend.support_precise_qualifier ? "precise float4" : "float4";
574 semantic = legacy ? "POSITION" : "SV_Position";
575 break;
576
577 case BuiltInSampleMask:
578 if (hlsl_options.shader_model < 41 || execution.model != ExecutionModelFragment)
579 SPIRV_CROSS_THROW("Sample Mask output is only supported in PS 4.1 or higher.");
580 type = "uint";
581 semantic = "SV_Coverage";
582 break;
583
584 case BuiltInFragDepth:
585 type = "float";
586 if (legacy)
587 {
588 semantic = "DEPTH";
589 }
590 else
591 {
592 if (hlsl_options.shader_model >= 50 && execution.flags.get(ExecutionModeDepthGreater))
593 semantic = "SV_DepthGreaterEqual";
594 else if (hlsl_options.shader_model >= 50 && execution.flags.get(ExecutionModeDepthLess))
595 semantic = "SV_DepthLessEqual";
596 else
597 semantic = "SV_Depth";
598 }
599 break;
600
601 case BuiltInClipDistance:
602 // HLSL is a bit weird here, use SV_ClipDistance0, SV_ClipDistance1 and so on with vectors.
603 for (uint32_t clip = 0; clip < clip_distance_count; clip += 4)
604 {
605 uint32_t to_declare = clip_distance_count - clip;
606 if (to_declare > 4)
607 to_declare = 4;
608
609 uint32_t semantic_index = clip / 4;
610
611 static const char *types[] = { "float", "float2", "float3", "float4" };
612 statement(types[to_declare - 1], " ", builtin_to_glsl(builtin, StorageClassOutput), semantic_index,
613 " : SV_ClipDistance", semantic_index, ";");
614 }
615 break;
616
617 case BuiltInCullDistance:
618 // HLSL is a bit weird here, use SV_CullDistance0, SV_CullDistance1 and so on with vectors.
619 for (uint32_t cull = 0; cull < cull_distance_count; cull += 4)
620 {
621 uint32_t to_declare = cull_distance_count - cull;
622 if (to_declare > 4)
623 to_declare = 4;
624
625 uint32_t semantic_index = cull / 4;
626
627 static const char *types[] = { "float", "float2", "float3", "float4" };
628 statement(types[to_declare - 1], " ", builtin_to_glsl(builtin, StorageClassOutput), semantic_index,
629 " : SV_CullDistance", semantic_index, ";");
630 }
631 break;
632
633 case BuiltInPointSize:
634 // If point_size_compat is enabled, just ignore PointSize.
635 // PointSize does not exist in HLSL, but some code bases might want to be able to use these shaders,
636 // even if it means working around the missing feature.
637 if (hlsl_options.point_size_compat)
638 break;
639 else
640 SPIRV_CROSS_THROW("Unsupported builtin in HLSL.");
641
642 default:
643 SPIRV_CROSS_THROW("Unsupported builtin in HLSL.");
644 }
645
646 if (type && semantic)
647 statement(type, " ", builtin_to_glsl(builtin, StorageClassOutput), " : ", semantic, ";");
648 });
649}
650
651void CompilerHLSL::emit_builtin_inputs_in_struct()
652{
653 bool legacy = hlsl_options.shader_model <= 30;
654 active_input_builtins.for_each_bit([&](uint32_t i) {
655 const char *type = nullptr;
656 const char *semantic = nullptr;
657 auto builtin = static_cast<BuiltIn>(i);
658 switch (builtin)
659 {
660 case BuiltInFragCoord:
661 type = "float4";
662 semantic = legacy ? "VPOS" : "SV_Position";
663 break;
664
665 case BuiltInVertexId:
666 case BuiltInVertexIndex:
667 if (legacy)
668 SPIRV_CROSS_THROW("Vertex index not supported in SM 3.0 or lower.");
669 type = "uint";
670 semantic = "SV_VertexID";
671 break;
672
673 case BuiltInInstanceId:
674 case BuiltInInstanceIndex:
675 if (legacy)
676 SPIRV_CROSS_THROW("Instance index not supported in SM 3.0 or lower.");
677 type = "uint";
678 semantic = "SV_InstanceID";
679 break;
680
681 case BuiltInSampleId:
682 if (legacy)
683 SPIRV_CROSS_THROW("Sample ID not supported in SM 3.0 or lower.");
684 type = "uint";
685 semantic = "SV_SampleIndex";
686 break;
687
688 case BuiltInSampleMask:
689 if (hlsl_options.shader_model < 50 || get_entry_point().model != ExecutionModelFragment)
690 SPIRV_CROSS_THROW("Sample Mask input is only supported in PS 5.0 or higher.");
691 type = "uint";
692 semantic = "SV_Coverage";
693 break;
694
695 case BuiltInGlobalInvocationId:
696 type = "uint3";
697 semantic = "SV_DispatchThreadID";
698 break;
699
700 case BuiltInLocalInvocationId:
701 type = "uint3";
702 semantic = "SV_GroupThreadID";
703 break;
704
705 case BuiltInLocalInvocationIndex:
706 type = "uint";
707 semantic = "SV_GroupIndex";
708 break;
709
710 case BuiltInWorkgroupId:
711 type = "uint3";
712 semantic = "SV_GroupID";
713 break;
714
715 case BuiltInFrontFacing:
716 type = "bool";
717 semantic = "SV_IsFrontFace";
718 break;
719
720 case BuiltInNumWorkgroups:
721 case BuiltInSubgroupSize:
722 case BuiltInSubgroupLocalInvocationId:
723 case BuiltInSubgroupEqMask:
724 case BuiltInSubgroupLtMask:
725 case BuiltInSubgroupLeMask:
726 case BuiltInSubgroupGtMask:
727 case BuiltInSubgroupGeMask:
728 // Handled specially.
729 break;
730
731 case BuiltInClipDistance:
732 // HLSL is a bit weird here, use SV_ClipDistance0, SV_ClipDistance1 and so on with vectors.
733 for (uint32_t clip = 0; clip < clip_distance_count; clip += 4)
734 {
735 uint32_t to_declare = clip_distance_count - clip;
736 if (to_declare > 4)
737 to_declare = 4;
738
739 uint32_t semantic_index = clip / 4;
740
741 static const char *types[] = { "float", "float2", "float3", "float4" };
742 statement(types[to_declare - 1], " ", builtin_to_glsl(builtin, StorageClassInput), semantic_index,
743 " : SV_ClipDistance", semantic_index, ";");
744 }
745 break;
746
747 case BuiltInCullDistance:
748 // HLSL is a bit weird here, use SV_CullDistance0, SV_CullDistance1 and so on with vectors.
749 for (uint32_t cull = 0; cull < cull_distance_count; cull += 4)
750 {
751 uint32_t to_declare = cull_distance_count - cull;
752 if (to_declare > 4)
753 to_declare = 4;
754
755 uint32_t semantic_index = cull / 4;
756
757 static const char *types[] = { "float", "float2", "float3", "float4" };
758 statement(types[to_declare - 1], " ", builtin_to_glsl(builtin, StorageClassInput), semantic_index,
759 " : SV_CullDistance", semantic_index, ";");
760 }
761 break;
762
763 case BuiltInPointCoord:
764 // PointCoord is not supported, but provide a way to just ignore that, similar to PointSize.
765 if (hlsl_options.point_coord_compat)
766 break;
767 else
768 SPIRV_CROSS_THROW("Unsupported builtin in HLSL.");
769
770 default:
771 SPIRV_CROSS_THROW("Unsupported builtin in HLSL.");
772 }
773
774 if (type && semantic)
775 statement(type, " ", builtin_to_glsl(builtin, StorageClassInput), " : ", semantic, ";");
776 });
777}
778
779uint32_t CompilerHLSL::type_to_consumed_locations(const SPIRType &type) const
780{
781 // TODO: Need to verify correctness.
782 uint32_t elements = 0;
783
784 if (type.basetype == SPIRType::Struct)
785 {
786 for (uint32_t i = 0; i < uint32_t(type.member_types.size()); i++)
787 elements += type_to_consumed_locations(get<SPIRType>(type.member_types[i]));
788 }
789 else
790 {
791 uint32_t array_multiplier = 1;
792 for (uint32_t i = 0; i < uint32_t(type.array.size()); i++)
793 {
794 if (type.array_size_literal[i])
795 array_multiplier *= type.array[i];
796 else
797 array_multiplier *= evaluate_constant_u32(type.array[i]);
798 }
799 elements += array_multiplier * type.columns;
800 }
801 return elements;
802}
803
804string CompilerHLSL::to_interpolation_qualifiers(const Bitset &flags)
805{
806 string res;
807 //if (flags & (1ull << DecorationSmooth))
808 // res += "linear ";
809 if (flags.get(DecorationFlat))
810 res += "nointerpolation ";
811 if (flags.get(DecorationNoPerspective))
812 res += "noperspective ";
813 if (flags.get(DecorationCentroid))
814 res += "centroid ";
815 if (flags.get(DecorationPatch))
816 res += "patch "; // Seems to be different in actual HLSL.
817 if (flags.get(DecorationSample))
818 res += "sample ";
819 if (flags.get(DecorationInvariant) && backend.support_precise_qualifier)
820 res += "precise "; // Not supported?
821
822 return res;
823}
824
825std::string CompilerHLSL::to_semantic(uint32_t location, ExecutionModel em, StorageClass sc)
826{
827 if (em == ExecutionModelVertex && sc == StorageClassInput)
828 {
829 // We have a vertex attribute - we should look at remapping it if the user provided
830 // vertex attribute hints.
831 for (auto &attribute : remap_vertex_attributes)
832 if (attribute.location == location)
833 return attribute.semantic;
834 }
835
836 // Not a vertex attribute, or no remap_vertex_attributes entry.
837 return join("TEXCOORD", location);
838}
839
840std::string CompilerHLSL::to_initializer_expression(const SPIRVariable &var)
841{
842 // We cannot emit static const initializer for block constants for practical reasons,
843 // so just inline the initializer.
844 // FIXME: There is a theoretical problem here if someone tries to composite extract
845 // into this initializer since we don't declare it properly, but that is somewhat non-sensical.
846 auto &type = get<SPIRType>(var.basetype);
847 bool is_block = has_decoration(type.self, DecorationBlock);
848 auto *c = maybe_get<SPIRConstant>(var.initializer);
849 if (is_block && c)
850 return constant_expression(*c);
851 else
852 return CompilerGLSL::to_initializer_expression(var);
853}
854
855void CompilerHLSL::emit_interface_block_member_in_struct(const SPIRVariable &var, uint32_t member_index,
856 uint32_t location,
857 std::unordered_set<uint32_t> &active_locations)
858{
859 auto &execution = get_entry_point();
860 auto type = get<SPIRType>(var.basetype);
861 auto semantic = to_semantic(location, execution.model, var.storage);
862 auto mbr_name = join(to_name(type.self), "_", to_member_name(type, member_index));
863 auto &mbr_type = get<SPIRType>(type.member_types[member_index]);
864
865 statement(to_interpolation_qualifiers(get_member_decoration_bitset(type.self, member_index)),
866 type_to_glsl(mbr_type),
867 " ", mbr_name, type_to_array_glsl(mbr_type),
868 " : ", semantic, ";");
869
870 // Structs and arrays should consume more locations.
871 uint32_t consumed_locations = type_to_consumed_locations(mbr_type);
872 for (uint32_t i = 0; i < consumed_locations; i++)
873 active_locations.insert(location + i);
874}
875
876void CompilerHLSL::emit_interface_block_in_struct(const SPIRVariable &var, unordered_set<uint32_t> &active_locations)
877{
878 auto &execution = get_entry_point();
879 auto type = get<SPIRType>(var.basetype);
880
881 string binding;
882 bool use_location_number = true;
883 bool legacy = hlsl_options.shader_model <= 30;
884 if (execution.model == ExecutionModelFragment && var.storage == StorageClassOutput)
885 {
886 // Dual-source blending is achieved in HLSL by emitting to SV_Target0 and 1.
887 uint32_t index = get_decoration(var.self, DecorationIndex);
888 uint32_t location = get_decoration(var.self, DecorationLocation);
889
890 if (index != 0 && location != 0)
891 SPIRV_CROSS_THROW("Dual-source blending is only supported on MRT #0 in HLSL.");
892
893 binding = join(legacy ? "COLOR" : "SV_Target", location + index);
894 use_location_number = false;
895 if (legacy) // COLOR must be a four-component vector on legacy shader model targets (HLSL ERR_COLOR_4COMP)
896 type.vecsize = 4;
897 }
898
899 const auto get_vacant_location = [&]() -> uint32_t {
900 for (uint32_t i = 0; i < 64; i++)
901 if (!active_locations.count(i))
902 return i;
903 SPIRV_CROSS_THROW("All locations from 0 to 63 are exhausted.");
904 };
905
906 bool need_matrix_unroll = var.storage == StorageClassInput && execution.model == ExecutionModelVertex;
907
908 auto name = to_name(var.self);
909 if (use_location_number)
910 {
911 uint32_t location_number;
912
913 // If an explicit location exists, use it with TEXCOORD[N] semantic.
914 // Otherwise, pick a vacant location.
915 if (has_decoration(var.self, DecorationLocation))
916 location_number = get_decoration(var.self, DecorationLocation);
917 else
918 location_number = get_vacant_location();
919
920 // Allow semantic remap if specified.
921 auto semantic = to_semantic(location_number, execution.model, var.storage);
922
923 if (need_matrix_unroll && type.columns > 1)
924 {
925 if (!type.array.empty())
926 SPIRV_CROSS_THROW("Arrays of matrices used as input/output. This is not supported.");
927
928 // Unroll matrices.
929 for (uint32_t i = 0; i < type.columns; i++)
930 {
931 SPIRType newtype = type;
932 newtype.columns = 1;
933
934 string effective_semantic;
935 if (hlsl_options.flatten_matrix_vertex_input_semantics)
936 effective_semantic = to_semantic(location_number, execution.model, var.storage);
937 else
938 effective_semantic = join(semantic, "_", i);
939
940 statement(to_interpolation_qualifiers(get_decoration_bitset(var.self)),
941 variable_decl(newtype, join(name, "_", i)), " : ", effective_semantic, ";");
942 active_locations.insert(location_number++);
943 }
944 }
945 else
946 {
947 statement(to_interpolation_qualifiers(get_decoration_bitset(var.self)), variable_decl(type, name), " : ",
948 semantic, ";");
949
950 // Structs and arrays should consume more locations.
951 uint32_t consumed_locations = type_to_consumed_locations(type);
952 for (uint32_t i = 0; i < consumed_locations; i++)
953 active_locations.insert(location_number + i);
954 }
955 }
956 else
957 statement(variable_decl(type, name), " : ", binding, ";");
958}
959
960std::string CompilerHLSL::builtin_to_glsl(spv::BuiltIn builtin, spv::StorageClass storage)
961{
962 switch (builtin)
963 {
964 case BuiltInVertexId:
965 return "gl_VertexID";
966 case BuiltInInstanceId:
967 return "gl_InstanceID";
968 case BuiltInNumWorkgroups:
969 {
970 if (!num_workgroups_builtin)
971 SPIRV_CROSS_THROW("NumWorkgroups builtin is used, but remap_num_workgroups_builtin() was not called. "
972 "Cannot emit code for this builtin.");
973
974 auto &var = get<SPIRVariable>(num_workgroups_builtin);
975 auto &type = get<SPIRType>(var.basetype);
976 auto ret = join(to_name(num_workgroups_builtin), "_", get_member_name(type.self, 0));
977 ParsedIR::sanitize_underscores(ret);
978 return ret;
979 }
980 case BuiltInPointCoord:
981 // Crude hack, but there is no real alternative. This path is only enabled if point_coord_compat is set.
982 return "float2(0.5f, 0.5f)";
983 case BuiltInSubgroupLocalInvocationId:
984 return "WaveGetLaneIndex()";
985 case BuiltInSubgroupSize:
986 return "WaveGetLaneCount()";
987
988 default:
989 return CompilerGLSL::builtin_to_glsl(builtin, storage);
990 }
991}
992
993void CompilerHLSL::emit_builtin_variables()
994{
995 Bitset builtins = active_input_builtins;
996 builtins.merge_or(active_output_builtins);
997
998 bool need_base_vertex_info = false;
999
1000 std::unordered_map<uint32_t, ID> builtin_to_initializer;
1001 ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
1002 if (!is_builtin_variable(var) || var.storage != StorageClassOutput || !var.initializer)
1003 return;
1004
1005 auto *c = this->maybe_get<SPIRConstant>(var.initializer);
1006 if (!c)
1007 return;
1008
1009 auto &type = this->get<SPIRType>(var.basetype);
1010 if (type.basetype == SPIRType::Struct)
1011 {
1012 uint32_t member_count = uint32_t(type.member_types.size());
1013 for (uint32_t i = 0; i < member_count; i++)
1014 {
1015 if (has_member_decoration(type.self, i, DecorationBuiltIn))
1016 {
1017 builtin_to_initializer[get_member_decoration(type.self, i, DecorationBuiltIn)] =
1018 c->subconstants[i];
1019 }
1020 }
1021 }
1022 else if (has_decoration(var.self, DecorationBuiltIn))
1023 builtin_to_initializer[get_decoration(var.self, DecorationBuiltIn)] = var.initializer;
1024 });
1025
1026 // Emit global variables for the interface variables which are statically used by the shader.
1027 builtins.for_each_bit([&](uint32_t i) {
1028 const char *type = nullptr;
1029 auto builtin = static_cast<BuiltIn>(i);
1030 uint32_t array_size = 0;
1031
1032 string init_expr;
1033 auto init_itr = builtin_to_initializer.find(builtin);
1034 if (init_itr != builtin_to_initializer.end())
1035 init_expr = join(" = ", to_expression(init_itr->second));
1036
1037 switch (builtin)
1038 {
1039 case BuiltInFragCoord:
1040 case BuiltInPosition:
1041 type = "float4";
1042 break;
1043
1044 case BuiltInFragDepth:
1045 type = "float";
1046 break;
1047
1048 case BuiltInVertexId:
1049 case BuiltInVertexIndex:
1050 case BuiltInInstanceIndex:
1051 type = "int";
1052 if (hlsl_options.support_nonzero_base_vertex_base_instance)
1053 need_base_vertex_info = true;
1054 break;
1055
1056 case BuiltInInstanceId:
1057 case BuiltInSampleId:
1058 type = "int";
1059 break;
1060
1061 case BuiltInPointSize:
1062 if (hlsl_options.point_size_compat)
1063 {
1064 // Just emit the global variable, it will be ignored.
1065 type = "float";
1066 break;
1067 }
1068 else
1069 SPIRV_CROSS_THROW(join("Unsupported builtin in HLSL: ", unsigned(builtin)));
1070
1071 case BuiltInGlobalInvocationId:
1072 case BuiltInLocalInvocationId:
1073 case BuiltInWorkgroupId:
1074 type = "uint3";
1075 break;
1076
1077 case BuiltInLocalInvocationIndex:
1078 type = "uint";
1079 break;
1080
1081 case BuiltInFrontFacing:
1082 type = "bool";
1083 break;
1084
1085 case BuiltInNumWorkgroups:
1086 case BuiltInPointCoord:
1087 // Handled specially.
1088 break;
1089
1090 case BuiltInSubgroupLocalInvocationId:
1091 case BuiltInSubgroupSize:
1092 if (hlsl_options.shader_model < 60)
1093 SPIRV_CROSS_THROW("Need SM 6.0 for Wave ops.");
1094 break;
1095
1096 case BuiltInSubgroupEqMask:
1097 case BuiltInSubgroupLtMask:
1098 case BuiltInSubgroupLeMask:
1099 case BuiltInSubgroupGtMask:
1100 case BuiltInSubgroupGeMask:
1101 if (hlsl_options.shader_model < 60)
1102 SPIRV_CROSS_THROW("Need SM 6.0 for Wave ops.");
1103 type = "uint4";
1104 break;
1105
1106 case BuiltInClipDistance:
1107 array_size = clip_distance_count;
1108 type = "float";
1109 break;
1110
1111 case BuiltInCullDistance:
1112 array_size = cull_distance_count;
1113 type = "float";
1114 break;
1115
1116 case BuiltInSampleMask:
1117 type = "int";
1118 break;
1119
1120 default:
1121 SPIRV_CROSS_THROW(join("Unsupported builtin in HLSL: ", unsigned(builtin)));
1122 }
1123
1124 StorageClass storage = active_input_builtins.get(i) ? StorageClassInput : StorageClassOutput;
1125
1126 if (type)
1127 {
1128 if (array_size)
1129 statement("static ", type, " ", builtin_to_glsl(builtin, storage), "[", array_size, "]", init_expr, ";");
1130 else
1131 statement("static ", type, " ", builtin_to_glsl(builtin, storage), init_expr, ";");
1132 }
1133
1134 // SampleMask can be both in and out with sample builtin, in this case we have already
1135 // declared the input variable and we need to add the output one now.
1136 if (builtin == BuiltInSampleMask && storage == StorageClassInput && this->active_output_builtins.get(i))
1137 {
1138 statement("static ", type, " ", this->builtin_to_glsl(builtin, StorageClassOutput), init_expr, ";");
1139 }
1140 });
1141
1142 if (need_base_vertex_info)
1143 {
1144 statement("cbuffer SPIRV_Cross_VertexInfo");
1145 begin_scope();
1146 statement("int SPIRV_Cross_BaseVertex;");
1147 statement("int SPIRV_Cross_BaseInstance;");
1148 end_scope_decl();
1149 statement("");
1150 }
1151}
1152
1153void CompilerHLSL::emit_composite_constants()
1154{
1155 // HLSL cannot declare structs or arrays inline, so we must move them out to
1156 // global constants directly.
1157 bool emitted = false;
1158
1159 ir.for_each_typed_id<SPIRConstant>([&](uint32_t, SPIRConstant &c) {
1160 if (c.specialization)
1161 return;
1162
1163 auto &type = this->get<SPIRType>(c.constant_type);
1164
1165 if (type.basetype == SPIRType::Struct && is_builtin_type(type))
1166 return;
1167
1168 if (type.basetype == SPIRType::Struct || !type.array.empty())
1169 {
1170 add_resource_name(c.self);
1171 auto name = to_name(c.self);
1172 statement("static const ", variable_decl(type, name), " = ", constant_expression(c), ";");
1173 emitted = true;
1174 }
1175 });
1176
1177 if (emitted)
1178 statement("");
1179}
1180
1181void CompilerHLSL::emit_specialization_constants_and_structs()
1182{
1183 bool emitted = false;
1184 SpecializationConstant wg_x, wg_y, wg_z;
1185 ID workgroup_size_id = get_work_group_size_specialization_constants(wg_x, wg_y, wg_z);
1186
1187 std::unordered_set<TypeID> io_block_types;
1188 ir.for_each_typed_id<SPIRVariable>([&](uint32_t, const SPIRVariable &var) {
1189 auto &type = this->get<SPIRType>(var.basetype);
1190 if ((var.storage == StorageClassInput || var.storage == StorageClassOutput) &&
1191 !var.remapped_variable && type.pointer && !is_builtin_variable(var) &&
1192 interface_variable_exists_in_entry_point(var.self) &&
1193 has_decoration(type.self, DecorationBlock))
1194 {
1195 io_block_types.insert(type.self);
1196 }
1197 });
1198
1199 auto loop_lock = ir.create_loop_hard_lock();
1200 for (auto &id_ : ir.ids_for_constant_or_type)
1201 {
1202 auto &id = ir.ids[id_];
1203
1204 if (id.get_type() == TypeConstant)
1205 {
1206 auto &c = id.get<SPIRConstant>();
1207
1208 if (c.self == workgroup_size_id)
1209 {
1210 statement("static const uint3 gl_WorkGroupSize = ",
1211 constant_expression(get<SPIRConstant>(workgroup_size_id)), ";");
1212 emitted = true;
1213 }
1214 else if (c.specialization)
1215 {
1216 auto &type = get<SPIRType>(c.constant_type);
1217 add_resource_name(c.self);
1218 auto name = to_name(c.self);
1219
1220 if (has_decoration(c.self, DecorationSpecId))
1221 {
1222 // HLSL does not support specialization constants, so fallback to macros.
1223 c.specialization_constant_macro_name =
1224 constant_value_macro_name(get_decoration(c.self, DecorationSpecId));
1225
1226 statement("#ifndef ", c.specialization_constant_macro_name);
1227 statement("#define ", c.specialization_constant_macro_name, " ", constant_expression(c));
1228 statement("#endif");
1229 statement("static const ", variable_decl(type, name), " = ", c.specialization_constant_macro_name, ";");
1230 }
1231 else
1232 statement("static const ", variable_decl(type, name), " = ", constant_expression(c), ";");
1233
1234 emitted = true;
1235 }
1236 }
1237 else if (id.get_type() == TypeConstantOp)
1238 {
1239 auto &c = id.get<SPIRConstantOp>();
1240 auto &type = get<SPIRType>(c.basetype);
1241 add_resource_name(c.self);
1242 auto name = to_name(c.self);
1243 statement("static const ", variable_decl(type, name), " = ", constant_op_expression(c), ";");
1244 emitted = true;
1245 }
1246 else if (id.get_type() == TypeType)
1247 {
1248 auto &type = id.get<SPIRType>();
1249 bool is_non_io_block = has_decoration(type.self, DecorationBlock) &&
1250 io_block_types.count(type.self) == 0;
1251 bool is_buffer_block = has_decoration(type.self, DecorationBufferBlock);
1252 if (type.basetype == SPIRType::Struct && type.array.empty() &&
1253 !type.pointer && !is_non_io_block && !is_buffer_block)
1254 {
1255 if (emitted)
1256 statement("");
1257 emitted = false;
1258
1259 emit_struct(type);
1260 }
1261 }
1262 }
1263
1264 if (emitted)
1265 statement("");
1266}
1267
1268void CompilerHLSL::replace_illegal_names()
1269{
1270 static const unordered_set<string> keywords = {
1271 // Additional HLSL specific keywords.
1272 "line", "linear", "matrix", "point", "row_major", "sampler", "vector"
1273 };
1274
1275 CompilerGLSL::replace_illegal_names(keywords);
1276 CompilerGLSL::replace_illegal_names();
1277}
1278
1279void CompilerHLSL::declare_undefined_values()
1280{
1281 bool emitted = false;
1282 ir.for_each_typed_id<SPIRUndef>([&](uint32_t, const SPIRUndef &undef) {
1283 auto &type = this->get<SPIRType>(undef.basetype);
1284 // OpUndef can be void for some reason ...
1285 if (type.basetype == SPIRType::Void)
1286 return;
1287
1288 string initializer;
1289 if (options.force_zero_initialized_variables && type_can_zero_initialize(type))
1290 initializer = join(" = ", to_zero_initialized_expression(undef.basetype));
1291
1292 statement("static ", variable_decl(type, to_name(undef.self), undef.self), initializer, ";");
1293 emitted = true;
1294 });
1295
1296 if (emitted)
1297 statement("");
1298}
1299
1300void CompilerHLSL::emit_resources()
1301{
1302 auto &execution = get_entry_point();
1303
1304 replace_illegal_names();
1305
1306 emit_specialization_constants_and_structs();
1307 emit_composite_constants();
1308
1309 bool emitted = false;
1310
1311 // Output UBOs and SSBOs
1312 ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
1313 auto &type = this->get<SPIRType>(var.basetype);
1314
1315 bool is_block_storage = type.storage == StorageClassStorageBuffer || type.storage == StorageClassUniform;
1316 bool has_block_flags = ir.meta[type.self].decoration.decoration_flags.get(DecorationBlock) ||
1317 ir.meta[type.self].decoration.decoration_flags.get(DecorationBufferBlock);
1318
1319 if (var.storage != StorageClassFunction && type.pointer && is_block_storage && !is_hidden_variable(var) &&
1320 has_block_flags)
1321 {
1322 emit_buffer_block(var);
1323 emitted = true;
1324 }
1325 });
1326
1327 // Output push constant blocks
1328 ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
1329 auto &type = this->get<SPIRType>(var.basetype);
1330 if (var.storage != StorageClassFunction && type.pointer && type.storage == StorageClassPushConstant &&
1331 !is_hidden_variable(var))
1332 {
1333 emit_push_constant_block(var);
1334 emitted = true;
1335 }
1336 });
1337
1338 if (execution.model == ExecutionModelVertex && hlsl_options.shader_model <= 30)
1339 {
1340 statement("uniform float4 gl_HalfPixel;");
1341 emitted = true;
1342 }
1343
1344 bool skip_separate_image_sampler = !combined_image_samplers.empty() || hlsl_options.shader_model <= 30;
1345
1346 // Output Uniform Constants (values, samplers, images, etc).
1347 ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
1348 auto &type = this->get<SPIRType>(var.basetype);
1349
1350 // If we're remapping separate samplers and images, only emit the combined samplers.
1351 if (skip_separate_image_sampler)
1352 {
1353 // Sampler buffers are always used without a sampler, and they will also work in regular D3D.
1354 bool sampler_buffer = type.basetype == SPIRType::Image && type.image.dim == DimBuffer;
1355 bool separate_image = type.basetype == SPIRType::Image && type.image.sampled == 1;
1356 bool separate_sampler = type.basetype == SPIRType::Sampler;
1357 if (!sampler_buffer && (separate_image || separate_sampler))
1358 return;
1359 }
1360
1361 if (var.storage != StorageClassFunction && !is_builtin_variable(var) && !var.remapped_variable &&
1362 type.pointer && (type.storage == StorageClassUniformConstant || type.storage == StorageClassAtomicCounter) &&
1363 !is_hidden_variable(var))
1364 {
1365 emit_uniform(var);
1366 emitted = true;
1367 }
1368 });
1369
1370 if (emitted)
1371 statement("");
1372 emitted = false;
1373
1374 // Emit builtin input and output variables here.
1375 emit_builtin_variables();
1376
1377 ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
1378 auto &type = this->get<SPIRType>(var.basetype);
1379
1380 if (var.storage != StorageClassFunction && !var.remapped_variable && type.pointer &&
1381 (var.storage == StorageClassInput || var.storage == StorageClassOutput) && !is_builtin_variable(var) &&
1382 interface_variable_exists_in_entry_point(var.self))
1383 {
1384 // Builtin variables are handled separately.
1385 emit_interface_block_globally(var);
1386 emitted = true;
1387 }
1388 });
1389
1390 if (emitted)
1391 statement("");
1392 emitted = false;
1393
1394 require_input = false;
1395 require_output = false;
1396 unordered_set<uint32_t> active_inputs;
1397 unordered_set<uint32_t> active_outputs;
1398
1399 struct IOVariable
1400 {
1401 const SPIRVariable *var;
1402 uint32_t location;
1403 uint32_t block_member_index;
1404 bool block;
1405 };
1406
1407 SmallVector<IOVariable> input_variables;
1408 SmallVector<IOVariable> output_variables;
1409
1410 ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
1411 auto &type = this->get<SPIRType>(var.basetype);
1412 bool block = has_decoration(type.self, DecorationBlock);
1413
1414 if (var.storage != StorageClassInput && var.storage != StorageClassOutput)
1415 return;
1416
1417 if (!var.remapped_variable && type.pointer && !is_builtin_variable(var) &&
1418 interface_variable_exists_in_entry_point(var.self))
1419 {
1420 if (block)
1421 {
1422 for (uint32_t i = 0; i < uint32_t(type.member_types.size()); i++)
1423 {
1424 uint32_t location = get_declared_member_location(var, i, false);
1425 if (var.storage == StorageClassInput)
1426 input_variables.push_back({ &var, location, i, true });
1427 else
1428 output_variables.push_back({ &var, location, i, true });
1429 }
1430 }
1431 else
1432 {
1433 uint32_t location = get_decoration(var.self, DecorationLocation);
1434 if (var.storage == StorageClassInput)
1435 input_variables.push_back({ &var, location, 0, false });
1436 else
1437 output_variables.push_back({ &var, location, 0, false });
1438 }
1439 }
1440 });
1441
1442 const auto variable_compare = [&](const IOVariable &a, const IOVariable &b) -> bool {
1443 // Sort input and output variables based on, from more robust to less robust:
1444 // - Location
1445 // - Variable has a location
1446 // - Name comparison
1447 // - Variable has a name
1448 // - Fallback: ID
1449 bool has_location_a = a.block || has_decoration(a.var->self, DecorationLocation);
1450 bool has_location_b = b.block || has_decoration(b.var->self, DecorationLocation);
1451
1452 if (has_location_a && has_location_b)
1453 return a.location < b.location;
1454 else if (has_location_a && !has_location_b)
1455 return true;
1456 else if (!has_location_a && has_location_b)
1457 return false;
1458
1459 const auto &name1 = to_name(a.var->self);
1460 const auto &name2 = to_name(b.var->self);
1461
1462 if (name1.empty() && name2.empty())
1463 return a.var->self < b.var->self;
1464 else if (name1.empty())
1465 return true;
1466 else if (name2.empty())
1467 return false;
1468
1469 return name1.compare(name2) < 0;
1470 };
1471
1472 auto input_builtins = active_input_builtins;
1473 input_builtins.clear(BuiltInNumWorkgroups);
1474 input_builtins.clear(BuiltInPointCoord);
1475 input_builtins.clear(BuiltInSubgroupSize);
1476 input_builtins.clear(BuiltInSubgroupLocalInvocationId);
1477 input_builtins.clear(BuiltInSubgroupEqMask);
1478 input_builtins.clear(BuiltInSubgroupLtMask);
1479 input_builtins.clear(BuiltInSubgroupLeMask);
1480 input_builtins.clear(BuiltInSubgroupGtMask);
1481 input_builtins.clear(BuiltInSubgroupGeMask);
1482
1483 if (!input_variables.empty() || !input_builtins.empty())
1484 {
1485 require_input = true;
1486 statement("struct SPIRV_Cross_Input");
1487
1488 begin_scope();
1489 sort(input_variables.begin(), input_variables.end(), variable_compare);
1490 for (auto &var : input_variables)
1491 {
1492 if (var.block)
1493 emit_interface_block_member_in_struct(*var.var, var.block_member_index, var.location, active_inputs);
1494 else
1495 emit_interface_block_in_struct(*var.var, active_inputs);
1496 }
1497 emit_builtin_inputs_in_struct();
1498 end_scope_decl();
1499 statement("");
1500 }
1501
1502 if (!output_variables.empty() || !active_output_builtins.empty())
1503 {
1504 require_output = true;
1505 statement("struct SPIRV_Cross_Output");
1506
1507 begin_scope();
1508 sort(output_variables.begin(), output_variables.end(), variable_compare);
1509 for (auto &var : output_variables)
1510 {
1511 if (var.block)
1512 emit_interface_block_member_in_struct(*var.var, var.block_member_index, var.location, active_outputs);
1513 else
1514 emit_interface_block_in_struct(*var.var, active_outputs);
1515 }
1516 emit_builtin_outputs_in_struct();
1517 end_scope_decl();
1518 statement("");
1519 }
1520
1521 // Global variables.
1522 for (auto global : global_variables)
1523 {
1524 auto &var = get<SPIRVariable>(global);
1525 if (is_hidden_variable(var, true))
1526 continue;
1527
1528 if (var.storage != StorageClassOutput)
1529 {
1530 if (!variable_is_lut(var))
1531 {
1532 add_resource_name(var.self);
1533
1534 const char *storage = nullptr;
1535 switch (var.storage)
1536 {
1537 case StorageClassWorkgroup:
1538 storage = "groupshared";
1539 break;
1540
1541 default:
1542 storage = "static";
1543 break;
1544 }
1545
1546 string initializer;
1547 if (options.force_zero_initialized_variables && var.storage == StorageClassPrivate &&
1548 !var.initializer && !var.static_expression && type_can_zero_initialize(get_variable_data_type(var)))
1549 {
1550 initializer = join(" = ", to_zero_initialized_expression(get_variable_data_type_id(var)));
1551 }
1552 statement(storage, " ", variable_decl(var), initializer, ";");
1553
1554 emitted = true;
1555 }
1556 }
1557 }
1558
1559 if (emitted)
1560 statement("");
1561
1562 declare_undefined_values();
1563
1564 if (requires_op_fmod)
1565 {
1566 static const char *types[] = {
1567 "float",
1568 "float2",
1569 "float3",
1570 "float4",
1571 };
1572
1573 for (auto &type : types)
1574 {
1575 statement(type, " mod(", type, " x, ", type, " y)");
1576 begin_scope();
1577 statement("return x - y * floor(x / y);");
1578 end_scope();
1579 statement("");
1580 }
1581 }
1582
1583 emit_texture_size_variants(required_texture_size_variants.srv, "4", false, "");
1584 for (uint32_t norm = 0; norm < 3; norm++)
1585 {
1586 for (uint32_t comp = 0; comp < 4; comp++)
1587 {
1588 static const char *qualifiers[] = { "", "unorm ", "snorm " };
1589 static const char *vecsizes[] = { "", "2", "3", "4" };
1590 emit_texture_size_variants(required_texture_size_variants.uav[norm][comp], vecsizes[comp], true,
1591 qualifiers[norm]);
1592 }
1593 }
1594
1595 if (requires_fp16_packing)
1596 {
1597 // HLSL does not pack into a single word sadly :(
1598 statement("uint spvPackHalf2x16(float2 value)");
1599 begin_scope();
1600 statement("uint2 Packed = f32tof16(value);");
1601 statement("return Packed.x | (Packed.y << 16);");
1602 end_scope();
1603 statement("");
1604
1605 statement("float2 spvUnpackHalf2x16(uint value)");
1606 begin_scope();
1607 statement("return f16tof32(uint2(value & 0xffff, value >> 16));");
1608 end_scope();
1609 statement("");
1610 }
1611
1612 if (requires_uint2_packing)
1613 {
1614 statement("uint64_t spvPackUint2x32(uint2 value)");
1615 begin_scope();
1616 statement("return (uint64_t(value.y) << 32) | uint64_t(value.x);");
1617 end_scope();
1618 statement("");
1619
1620 statement("uint2 spvUnpackUint2x32(uint64_t value)");
1621 begin_scope();
1622 statement("uint2 Unpacked;");
1623 statement("Unpacked.x = uint(value & 0xffffffff);");
1624 statement("Unpacked.y = uint(value >> 32);");
1625 statement("return Unpacked;");
1626 end_scope();
1627 statement("");
1628 }
1629
1630 if (requires_explicit_fp16_packing)
1631 {
1632 // HLSL does not pack into a single word sadly :(
1633 statement("uint spvPackFloat2x16(min16float2 value)");
1634 begin_scope();
1635 statement("uint2 Packed = f32tof16(value);");
1636 statement("return Packed.x | (Packed.y << 16);");
1637 end_scope();
1638 statement("");
1639
1640 statement("min16float2 spvUnpackFloat2x16(uint value)");
1641 begin_scope();
1642 statement("return min16float2(f16tof32(uint2(value & 0xffff, value >> 16)));");
1643 end_scope();
1644 statement("");
1645 }
1646
1647 // HLSL does not seem to have builtins for these operation, so roll them by hand ...
1648 if (requires_unorm8_packing)
1649 {
1650 statement("uint spvPackUnorm4x8(float4 value)");
1651 begin_scope();
1652 statement("uint4 Packed = uint4(round(saturate(value) * 255.0));");
1653 statement("return Packed.x | (Packed.y << 8) | (Packed.z << 16) | (Packed.w << 24);");
1654 end_scope();
1655 statement("");
1656
1657 statement("float4 spvUnpackUnorm4x8(uint value)");
1658 begin_scope();
1659 statement("uint4 Packed = uint4(value & 0xff, (value >> 8) & 0xff, (value >> 16) & 0xff, value >> 24);");
1660 statement("return float4(Packed) / 255.0;");
1661 end_scope();
1662 statement("");
1663 }
1664
1665 if (requires_snorm8_packing)
1666 {
1667 statement("uint spvPackSnorm4x8(float4 value)");
1668 begin_scope();
1669 statement("int4 Packed = int4(round(clamp(value, -1.0, 1.0) * 127.0)) & 0xff;");
1670 statement("return uint(Packed.x | (Packed.y << 8) | (Packed.z << 16) | (Packed.w << 24));");
1671 end_scope();
1672 statement("");
1673
1674 statement("float4 spvUnpackSnorm4x8(uint value)");
1675 begin_scope();
1676 statement("int SignedValue = int(value);");
1677 statement("int4 Packed = int4(SignedValue << 24, SignedValue << 16, SignedValue << 8, SignedValue) >> 24;");
1678 statement("return clamp(float4(Packed) / 127.0, -1.0, 1.0);");
1679 end_scope();
1680 statement("");
1681 }
1682
1683 if (requires_unorm16_packing)
1684 {
1685 statement("uint spvPackUnorm2x16(float2 value)");
1686 begin_scope();
1687 statement("uint2 Packed = uint2(round(saturate(value) * 65535.0));");
1688 statement("return Packed.x | (Packed.y << 16);");
1689 end_scope();
1690 statement("");
1691
1692 statement("float2 spvUnpackUnorm2x16(uint value)");
1693 begin_scope();
1694 statement("uint2 Packed = uint2(value & 0xffff, value >> 16);");
1695 statement("return float2(Packed) / 65535.0;");
1696 end_scope();
1697 statement("");
1698 }
1699
1700 if (requires_snorm16_packing)
1701 {
1702 statement("uint spvPackSnorm2x16(float2 value)");
1703 begin_scope();
1704 statement("int2 Packed = int2(round(clamp(value, -1.0, 1.0) * 32767.0)) & 0xffff;");
1705 statement("return uint(Packed.x | (Packed.y << 16));");
1706 end_scope();
1707 statement("");
1708
1709 statement("float2 spvUnpackSnorm2x16(uint value)");
1710 begin_scope();
1711 statement("int SignedValue = int(value);");
1712 statement("int2 Packed = int2(SignedValue << 16, SignedValue) >> 16;");
1713 statement("return clamp(float2(Packed) / 32767.0, -1.0, 1.0);");
1714 end_scope();
1715 statement("");
1716 }
1717
1718 if (requires_bitfield_insert)
1719 {
1720 static const char *types[] = { "uint", "uint2", "uint3", "uint4" };
1721 for (auto &type : types)
1722 {
1723 statement(type, " spvBitfieldInsert(", type, " Base, ", type, " Insert, uint Offset, uint Count)");
1724 begin_scope();
1725 statement("uint Mask = Count == 32 ? 0xffffffff : (((1u << Count) - 1) << (Offset & 31));");
1726 statement("return (Base & ~Mask) | ((Insert << Offset) & Mask);");
1727 end_scope();
1728 statement("");
1729 }
1730 }
1731
1732 if (requires_bitfield_extract)
1733 {
1734 static const char *unsigned_types[] = { "uint", "uint2", "uint3", "uint4" };
1735 for (auto &type : unsigned_types)
1736 {
1737 statement(type, " spvBitfieldUExtract(", type, " Base, uint Offset, uint Count)");
1738 begin_scope();
1739 statement("uint Mask = Count == 32 ? 0xffffffff : ((1 << Count) - 1);");
1740 statement("return (Base >> Offset) & Mask;");
1741 end_scope();
1742 statement("");
1743 }
1744
1745 // In this overload, we will have to do sign-extension, which we will emulate by shifting up and down.
1746 static const char *signed_types[] = { "int", "int2", "int3", "int4" };
1747 for (auto &type : signed_types)
1748 {
1749 statement(type, " spvBitfieldSExtract(", type, " Base, int Offset, int Count)");
1750 begin_scope();
1751 statement("int Mask = Count == 32 ? -1 : ((1 << Count) - 1);");
1752 statement(type, " Masked = (Base >> Offset) & Mask;");
1753 statement("int ExtendShift = (32 - Count) & 31;");
1754 statement("return (Masked << ExtendShift) >> ExtendShift;");
1755 end_scope();
1756 statement("");
1757 }
1758 }
1759
1760 if (requires_inverse_2x2)
1761 {
1762 statement("// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
1763 statement("// adjoint and dividing by the determinant. The contents of the matrix are changed.");
1764 statement("float2x2 spvInverse(float2x2 m)");
1765 begin_scope();
1766 statement("float2x2 adj; // The adjoint matrix (inverse after dividing by determinant)");
1767 statement_no_indent("");
1768 statement("// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
1769 statement("adj[0][0] = m[1][1];");
1770 statement("adj[0][1] = -m[0][1];");
1771 statement_no_indent("");
1772 statement("adj[1][0] = -m[1][0];");
1773 statement("adj[1][1] = m[0][0];");
1774 statement_no_indent("");
1775 statement("// Calculate the determinant as a combination of the cofactors of the first row.");
1776 statement("float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]);");
1777 statement_no_indent("");
1778 statement("// Divide the classical adjoint matrix by the determinant.");
1779 statement("// If determinant is zero, matrix is not invertable, so leave it unchanged.");
1780 statement("return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
1781 end_scope();
1782 statement("");
1783 }
1784
1785 if (requires_inverse_3x3)
1786 {
1787 statement("// Returns the determinant of a 2x2 matrix.");
1788 statement("float spvDet2x2(float a1, float a2, float b1, float b2)");
1789 begin_scope();
1790 statement("return a1 * b2 - b1 * a2;");
1791 end_scope();
1792 statement_no_indent("");
1793 statement("// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
1794 statement("// adjoint and dividing by the determinant. The contents of the matrix are changed.");
1795 statement("float3x3 spvInverse(float3x3 m)");
1796 begin_scope();
1797 statement("float3x3 adj; // The adjoint matrix (inverse after dividing by determinant)");
1798 statement_no_indent("");
1799 statement("// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
1800 statement("adj[0][0] = spvDet2x2(m[1][1], m[1][2], m[2][1], m[2][2]);");
1801 statement("adj[0][1] = -spvDet2x2(m[0][1], m[0][2], m[2][1], m[2][2]);");
1802 statement("adj[0][2] = spvDet2x2(m[0][1], m[0][2], m[1][1], m[1][2]);");
1803 statement_no_indent("");
1804 statement("adj[1][0] = -spvDet2x2(m[1][0], m[1][2], m[2][0], m[2][2]);");
1805 statement("adj[1][1] = spvDet2x2(m[0][0], m[0][2], m[2][0], m[2][2]);");
1806 statement("adj[1][2] = -spvDet2x2(m[0][0], m[0][2], m[1][0], m[1][2]);");
1807 statement_no_indent("");
1808 statement("adj[2][0] = spvDet2x2(m[1][0], m[1][1], m[2][0], m[2][1]);");
1809 statement("adj[2][1] = -spvDet2x2(m[0][0], m[0][1], m[2][0], m[2][1]);");
1810 statement("adj[2][2] = spvDet2x2(m[0][0], m[0][1], m[1][0], m[1][1]);");
1811 statement_no_indent("");
1812 statement("// Calculate the determinant as a combination of the cofactors of the first row.");
1813 statement("float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]) + (adj[0][2] * m[2][0]);");
1814 statement_no_indent("");
1815 statement("// Divide the classical adjoint matrix by the determinant.");
1816 statement("// If determinant is zero, matrix is not invertable, so leave it unchanged.");
1817 statement("return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
1818 end_scope();
1819 statement("");
1820 }
1821
1822 if (requires_inverse_4x4)
1823 {
1824 if (!requires_inverse_3x3)
1825 {
1826 statement("// Returns the determinant of a 2x2 matrix.");
1827 statement("float spvDet2x2(float a1, float a2, float b1, float b2)");
1828 begin_scope();
1829 statement("return a1 * b2 - b1 * a2;");
1830 end_scope();
1831 statement("");
1832 }
1833
1834 statement("// Returns the determinant of a 3x3 matrix.");
1835 statement("float spvDet3x3(float a1, float a2, float a3, float b1, float b2, float b3, float c1, "
1836 "float c2, float c3)");
1837 begin_scope();
1838 statement("return a1 * spvDet2x2(b2, b3, c2, c3) - b1 * spvDet2x2(a2, a3, c2, c3) + c1 * "
1839 "spvDet2x2(a2, a3, "
1840 "b2, b3);");
1841 end_scope();
1842 statement_no_indent("");
1843 statement("// Returns the inverse of a matrix, by using the algorithm of calculating the classical");
1844 statement("// adjoint and dividing by the determinant. The contents of the matrix are changed.");
1845 statement("float4x4 spvInverse(float4x4 m)");
1846 begin_scope();
1847 statement("float4x4 adj; // The adjoint matrix (inverse after dividing by determinant)");
1848 statement_no_indent("");
1849 statement("// Create the transpose of the cofactors, as the classical adjoint of the matrix.");
1850 statement(
1851 "adj[0][0] = spvDet3x3(m[1][1], m[1][2], m[1][3], m[2][1], m[2][2], m[2][3], m[3][1], m[3][2], "
1852 "m[3][3]);");
1853 statement(
1854 "adj[0][1] = -spvDet3x3(m[0][1], m[0][2], m[0][3], m[2][1], m[2][2], m[2][3], m[3][1], m[3][2], "
1855 "m[3][3]);");
1856 statement(
1857 "adj[0][2] = spvDet3x3(m[0][1], m[0][2], m[0][3], m[1][1], m[1][2], m[1][3], m[3][1], m[3][2], "
1858 "m[3][3]);");
1859 statement(
1860 "adj[0][3] = -spvDet3x3(m[0][1], m[0][2], m[0][3], m[1][1], m[1][2], m[1][3], m[2][1], m[2][2], "
1861 "m[2][3]);");
1862 statement_no_indent("");
1863 statement(
1864 "adj[1][0] = -spvDet3x3(m[1][0], m[1][2], m[1][3], m[2][0], m[2][2], m[2][3], m[3][0], m[3][2], "
1865 "m[3][3]);");
1866 statement(
1867 "adj[1][1] = spvDet3x3(m[0][0], m[0][2], m[0][3], m[2][0], m[2][2], m[2][3], m[3][0], m[3][2], "
1868 "m[3][3]);");
1869 statement(
1870 "adj[1][2] = -spvDet3x3(m[0][0], m[0][2], m[0][3], m[1][0], m[1][2], m[1][3], m[3][0], m[3][2], "
1871 "m[3][3]);");
1872 statement(
1873 "adj[1][3] = spvDet3x3(m[0][0], m[0][2], m[0][3], m[1][0], m[1][2], m[1][3], m[2][0], m[2][2], "
1874 "m[2][3]);");
1875 statement_no_indent("");
1876 statement(
1877 "adj[2][0] = spvDet3x3(m[1][0], m[1][1], m[1][3], m[2][0], m[2][1], m[2][3], m[3][0], m[3][1], "
1878 "m[3][3]);");
1879 statement(
1880 "adj[2][1] = -spvDet3x3(m[0][0], m[0][1], m[0][3], m[2][0], m[2][1], m[2][3], m[3][0], m[3][1], "
1881 "m[3][3]);");
1882 statement(
1883 "adj[2][2] = spvDet3x3(m[0][0], m[0][1], m[0][3], m[1][0], m[1][1], m[1][3], m[3][0], m[3][1], "
1884 "m[3][3]);");
1885 statement(
1886 "adj[2][3] = -spvDet3x3(m[0][0], m[0][1], m[0][3], m[1][0], m[1][1], m[1][3], m[2][0], m[2][1], "
1887 "m[2][3]);");
1888 statement_no_indent("");
1889 statement(
1890 "adj[3][0] = -spvDet3x3(m[1][0], m[1][1], m[1][2], m[2][0], m[2][1], m[2][2], m[3][0], m[3][1], "
1891 "m[3][2]);");
1892 statement(
1893 "adj[3][1] = spvDet3x3(m[0][0], m[0][1], m[0][2], m[2][0], m[2][1], m[2][2], m[3][0], m[3][1], "
1894 "m[3][2]);");
1895 statement(
1896 "adj[3][2] = -spvDet3x3(m[0][0], m[0][1], m[0][2], m[1][0], m[1][1], m[1][2], m[3][0], m[3][1], "
1897 "m[3][2]);");
1898 statement(
1899 "adj[3][3] = spvDet3x3(m[0][0], m[0][1], m[0][2], m[1][0], m[1][1], m[1][2], m[2][0], m[2][1], "
1900 "m[2][2]);");
1901 statement_no_indent("");
1902 statement("// Calculate the determinant as a combination of the cofactors of the first row.");
1903 statement("float det = (adj[0][0] * m[0][0]) + (adj[0][1] * m[1][0]) + (adj[0][2] * m[2][0]) + (adj[0][3] "
1904 "* m[3][0]);");
1905 statement_no_indent("");
1906 statement("// Divide the classical adjoint matrix by the determinant.");
1907 statement("// If determinant is zero, matrix is not invertable, so leave it unchanged.");
1908 statement("return (det != 0.0f) ? (adj * (1.0f / det)) : m;");
1909 end_scope();
1910 statement("");
1911 }
1912
1913 if (requires_scalar_reflect)
1914 {
1915 // FP16/FP64? No templates in HLSL.
1916 statement("float spvReflect(float i, float n)");
1917 begin_scope();
1918 statement("return i - 2.0 * dot(n, i) * n;");
1919 end_scope();
1920 statement("");
1921 }
1922
1923 if (requires_scalar_refract)
1924 {
1925 // FP16/FP64? No templates in HLSL.
1926 statement("float spvRefract(float i, float n, float eta)");
1927 begin_scope();
1928 statement("float NoI = n * i;");
1929 statement("float NoI2 = NoI * NoI;");
1930 statement("float k = 1.0 - eta * eta * (1.0 - NoI2);");
1931 statement("if (k < 0.0)");
1932 begin_scope();
1933 statement("return 0.0;");
1934 end_scope();
1935 statement("else");
1936 begin_scope();
1937 statement("return eta * i - (eta * NoI + sqrt(k)) * n;");
1938 end_scope();
1939 end_scope();
1940 statement("");
1941 }
1942
1943 if (requires_scalar_faceforward)
1944 {
1945 // FP16/FP64? No templates in HLSL.
1946 statement("float spvFaceForward(float n, float i, float nref)");
1947 begin_scope();
1948 statement("return i * nref < 0.0 ? n : -n;");
1949 end_scope();
1950 statement("");
1951 }
1952
1953 for (TypeID type_id : composite_selection_workaround_types)
1954 {
1955 // Need out variable since HLSL does not support returning arrays.
1956 auto &type = get<SPIRType>(type_id);
1957 auto type_str = type_to_glsl(type);
1958 auto type_arr_str = type_to_array_glsl(type);
1959 statement("void spvSelectComposite(out ", type_str, " out_value", type_arr_str, ", bool cond, ",
1960 type_str, " true_val", type_arr_str, ", ",
1961 type_str, " false_val", type_arr_str, ")");
1962 begin_scope();
1963 statement("if (cond)");
1964 begin_scope();
1965 statement("out_value = true_val;");
1966 end_scope();
1967 statement("else");
1968 begin_scope();
1969 statement("out_value = false_val;");
1970 end_scope();
1971 end_scope();
1972 statement("");
1973 }
1974}
1975
1976void CompilerHLSL::emit_texture_size_variants(uint64_t variant_mask, const char *vecsize_qualifier, bool uav,
1977 const char *type_qualifier)
1978{
1979 if (variant_mask == 0)
1980 return;
1981
1982 static const char *types[QueryTypeCount] = { "float", "int", "uint" };
1983 static const char *dims[QueryDimCount] = { "Texture1D", "Texture1DArray", "Texture2D", "Texture2DArray",
1984 "Texture3D", "Buffer", "TextureCube", "TextureCubeArray",
1985 "Texture2DMS", "Texture2DMSArray" };
1986
1987 static const bool has_lod[QueryDimCount] = { true, true, true, true, true, false, true, true, false, false };
1988
1989 static const char *ret_types[QueryDimCount] = {
1990 "uint", "uint2", "uint2", "uint3", "uint3", "uint", "uint2", "uint3", "uint2", "uint3",
1991 };
1992
1993 static const uint32_t return_arguments[QueryDimCount] = {
1994 1, 2, 2, 3, 3, 1, 2, 3, 2, 3,
1995 };
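 // For example, with type "float" and a "4" vecsize qualifier (both caller-provided), the
 // Texture2D SRV variant emitted below is roughly:
 //   uint2 spvTextureSize(Texture2D<float4> Tex, uint Level, out uint Param)
 //   { uint2 ret; Tex.GetDimensions(Level, ret.x, ret.y, Param); return ret; }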
1996
1997 for (uint32_t index = 0; index < QueryDimCount; index++)
1998 {
1999 for (uint32_t type_index = 0; type_index < QueryTypeCount; type_index++)
2000 {
2001 uint32_t bit = 16 * type_index + index;
2002 uint64_t mask = 1ull << bit;
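 // For example, the Texture2D variant (index 2) queried with an int sample type (type_index 1)
 // corresponds to bit 16 * 1 + 2 = 18 of variant_mask.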
2003
2004 if ((variant_mask & mask) == 0)
2005 continue;
2006
2007 statement(ret_types[index], " spv", (uav ? "Image" : "Texture"), "Size(", (uav ? "RW" : ""),
2008 dims[index], "<", type_qualifier, types[type_index], vecsize_qualifier, "> Tex, ",
2009 (uav ? "" : "uint Level, "), "out uint Param)");
2010 begin_scope();
2011 statement(ret_types[index], " ret;");
2012 switch (return_arguments[index])
2013 {
2014 case 1:
2015 if (has_lod[index] && !uav)
2016 statement("Tex.GetDimensions(Level, ret.x, Param);");
2017 else
2018 {
2019 statement("Tex.GetDimensions(ret.x);");
2020 statement("Param = 0u;");
2021 }
2022 break;
2023 case 2:
2024 if (has_lod[index] && !uav)
2025 statement("Tex.GetDimensions(Level, ret.x, ret.y, Param);");
2026 else if (!uav)
2027 statement("Tex.GetDimensions(ret.x, ret.y, Param);");
2028 else
2029 {
2030 statement("Tex.GetDimensions(ret.x, ret.y);");
2031 statement("Param = 0u;");
2032 }
2033 break;
2034 case 3:
2035 if (has_lod[index] && !uav)
2036 statement("Tex.GetDimensions(Level, ret.x, ret.y, ret.z, Param);");
2037 else if (!uav)
2038 statement("Tex.GetDimensions(ret.x, ret.y, ret.z, Param);");
2039 else
2040 {
2041 statement("Tex.GetDimensions(ret.x, ret.y, ret.z);");
2042 statement("Param = 0u;");
2043 }
2044 break;
2045 }
2046
2047 statement("return ret;");
2048 end_scope();
2049 statement("");
2050 }
2051 }
2052}
2053
2054string CompilerHLSL::layout_for_member(const SPIRType &type, uint32_t index)
2055{
2056 auto &flags = get_member_decoration_bitset(type.self, index);
2057
2058 // HLSL can emit row_major or column_major decoration in any struct.
2059 // Do not try to merge combined decorations for children like in GLSL.
2060
2061 // Flip the convention. HLSL is a bit odd in that the memory layout is column major ... but the language API is "row-major".
2062 // The way to deal with this is to multiply everything in inverse order, and reverse the memory layout.
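 // For example, a member decorated ColMajor in SPIR-V is emitted as "row_major" in HLSL below.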
2063 if (flags.get(DecorationColMajor))
2064 return "row_major ";
2065 else if (flags.get(DecorationRowMajor))
2066 return "column_major ";
2067
2068 return "";
2069}
2070
2071void CompilerHLSL::emit_struct_member(const SPIRType &type, uint32_t member_type_id, uint32_t index,
2072 const string &qualifier, uint32_t base_offset)
2073{
2074 auto &membertype = get<SPIRType>(member_type_id);
2075
2076 Bitset memberflags;
2077 auto &memb = ir.meta[type.self].members;
2078 if (index < memb.size())
2079 memberflags = memb[index].decoration_flags;
2080
2081 string packing_offset;
2082 bool is_push_constant = type.storage == StorageClassPushConstant;
2083
2084 if ((has_extended_decoration(type.self, SPIRVCrossDecorationExplicitOffset) || is_push_constant) &&
2085 has_member_decoration(type.self, index, DecorationOffset))
2086 {
2087 uint32_t offset = memb[index].offset - base_offset;
2088 if (offset & 3)
2089 SPIRV_CROSS_THROW("Cannot pack on tighter bounds than 4 bytes in HLSL.");
2090
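 // For example, a member at byte offset 20 resolves to register c1 (20 / 16) and component
 // .y ((20 & 15) >> 2 == 1), i.e. " : packoffset(c1.y)".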
2091 static const char *packing_swizzle[] = { "", ".y", ".z", ".w" };
2092 packing_offset = join(" : packoffset(c", offset / 16, packing_swizzle[(offset & 15) >> 2], ")");
2093 }
2094
2095 statement(layout_for_member(type, index), qualifier,
2096 variable_decl(membertype, to_member_name(type, index)), packing_offset, ";");
2097}
2098
2099void CompilerHLSL::emit_buffer_block(const SPIRVariable &var)
2100{
2101 auto &type = get<SPIRType>(var.basetype);
2102
2103 bool is_uav = var.storage == StorageClassStorageBuffer || has_decoration(type.self, DecorationBufferBlock);
2104
2105 if (is_uav)
2106 {
2107 Bitset flags = ir.get_buffer_block_flags(var);
2108 bool is_readonly = flags.get(DecorationNonWritable) && !is_hlsl_force_storage_buffer_as_uav(var.self);
2109 bool is_coherent = flags.get(DecorationCoherent) && !is_readonly;
2110 bool is_interlocked = interlocked_resources.count(var.self) > 0;
2111 const char *type_name = "ByteAddressBuffer ";
2112 if (!is_readonly)
2113 type_name = is_interlocked ? "RasterizerOrderedByteAddressBuffer " : "RWByteAddressBuffer ";
2114 add_resource_name(var.self);
2115 statement(is_coherent ? "globallycoherent " : "", type_name, to_name(var.self), type_to_array_glsl(type),
2116 to_resource_binding(var), ";");
2117 }
2118 else
2119 {
2120 if (type.array.empty())
2121 {
 // Flatten the top-level struct so we can use packoffset;
 // this restriction is similar to GLSL, where layout(offset) is not possible on sub-structs.
2124 flattened_structs[var.self] = false;
2125
2126 // Prefer the block name if possible.
2127 auto buffer_name = to_name(type.self, false);
2128 if (ir.meta[type.self].decoration.alias.empty() ||
2129 resource_names.find(buffer_name) != end(resource_names) ||
2130 block_names.find(buffer_name) != end(block_names))
2131 {
2132 buffer_name = get_block_fallback_name(var.self);
2133 }
2134
2135 add_variable(block_names, resource_names, buffer_name);
2136
2137 // If for some reason buffer_name is an illegal name, make a final fallback to a workaround name.
2138 // This cannot conflict with anything else, so we're safe now.
2139 if (buffer_name.empty())
2140 buffer_name = join("_", get<SPIRType>(var.basetype).self, "_", var.self);
2141
2142 uint32_t failed_index = 0;
2143 if (buffer_is_packing_standard(type, BufferPackingHLSLCbufferPackOffset, &failed_index))
2144 set_extended_decoration(type.self, SPIRVCrossDecorationExplicitOffset);
2145 else
2146 {
2147 SPIRV_CROSS_THROW(join("cbuffer ID ", var.self, " (name: ", buffer_name, "), member index ",
2148 failed_index, " (name: ", to_member_name(type, failed_index),
2149 ") cannot be expressed with either HLSL packing layout or packoffset."));
2150 }
2151
2152 block_names.insert(buffer_name);
2153
2154 // Save for post-reflection later.
2155 declared_block_names[var.self] = buffer_name;
2156
2157 type.member_name_cache.clear();
2158 // var.self can be used as a backup name for the block name,
2159 // so we need to make sure we don't disturb the name here on a recompile.
2160 // It will need to be reset if we have to recompile.
2161 preserve_alias_on_reset(var.self);
2162 add_resource_name(var.self);
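 // The block below emits something like the following (names purely illustrative):
 //   cbuffer UBO : register(b0)
 //   {
 //       float4 UBO_member : packoffset(c0);
 //   };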
2163 statement("cbuffer ", buffer_name, to_resource_binding(var));
2164 begin_scope();
2165
2166 uint32_t i = 0;
2167 for (auto &member : type.member_types)
2168 {
2169 add_member_name(type, i);
2170 auto backup_name = get_member_name(type.self, i);
2171 auto member_name = to_member_name(type, i);
2172 member_name = join(to_name(var.self), "_", member_name);
2173 ParsedIR::sanitize_underscores(member_name);
2174 set_member_name(type.self, i, member_name);
2175 emit_struct_member(type, member, i, "");
2176 set_member_name(type.self, i, backup_name);
2177 i++;
2178 }
2179
2180 end_scope_decl();
2181 statement("");
2182 }
2183 else
2184 {
2185 if (hlsl_options.shader_model < 51)
2186 SPIRV_CROSS_THROW(
2187 "Need ConstantBuffer<T> to use arrays of UBOs, but this is only supported in SM 5.1.");
2188
2189 add_resource_name(type.self);
2190 add_resource_name(var.self);
2191
 // ConstantBuffer<T> does not support packoffset, so it is unusable unless everything aligns as we expect.
2193 uint32_t failed_index = 0;
2194 if (!buffer_is_packing_standard(type, BufferPackingHLSLCbuffer, &failed_index))
2195 {
2196 SPIRV_CROSS_THROW(join("HLSL ConstantBuffer<T> ID ", var.self, " (name: ", to_name(type.self),
2197 "), member index ", failed_index, " (name: ", to_member_name(type, failed_index),
2198 ") cannot be expressed with normal HLSL packing rules."));
2199 }
2200
2201 emit_struct(get<SPIRType>(type.self));
2202 statement("ConstantBuffer<", to_name(type.self), "> ", to_name(var.self), type_to_array_glsl(type),
2203 to_resource_binding(var), ";");
2204 }
2205 }
2206}
2207
2208void CompilerHLSL::emit_push_constant_block(const SPIRVariable &var)
2209{
2210 if (root_constants_layout.empty())
2211 {
2212 emit_buffer_block(var);
2213 }
2214 else
2215 {
2216 for (const auto &layout : root_constants_layout)
2217 {
2218 auto &type = get<SPIRType>(var.basetype);
2219
2220 uint32_t failed_index = 0;
2221 if (buffer_is_packing_standard(type, BufferPackingHLSLCbufferPackOffset, &failed_index, layout.start,
2222 layout.end))
2223 set_extended_decoration(type.self, SPIRVCrossDecorationExplicitOffset);
2224 else
2225 {
2226 SPIRV_CROSS_THROW(join("Root constant cbuffer ID ", var.self, " (name: ", to_name(type.self), ")",
2227 ", member index ", failed_index, " (name: ", to_member_name(type, failed_index),
2228 ") cannot be expressed with either HLSL packing layout or packoffset."));
2229 }
2230
2231 flattened_structs[var.self] = false;
2232 type.member_name_cache.clear();
2233 add_resource_name(var.self);
2234 auto &memb = ir.meta[type.self].members;
2235
2236 statement("cbuffer SPIRV_CROSS_RootConstant_", to_name(var.self),
2237 to_resource_register(HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT, 'b', layout.binding, layout.space));
2238 begin_scope();
2239
 // Index of the next field in the generated root constant cbuffer.
2241 auto constant_index = 0u;
2242
 // Iterate over all members of the push constant block and check which of the fields
 // fit into the given root constant layout.
2245 for (auto i = 0u; i < memb.size(); i++)
2246 {
2247 const auto offset = memb[i].offset;
2248 if (layout.start <= offset && offset < layout.end)
2249 {
2250 const auto &member = type.member_types[i];
2251
2252 add_member_name(type, constant_index);
2253 auto backup_name = get_member_name(type.self, i);
2254 auto member_name = to_member_name(type, i);
2255 member_name = join(to_name(var.self), "_", member_name);
2256 ParsedIR::sanitize_underscores(member_name);
2257 set_member_name(type.self, constant_index, member_name);
2258 emit_struct_member(type, member, i, "", layout.start);
2259 set_member_name(type.self, constant_index, backup_name);
2260
2261 constant_index++;
2262 }
2263 }
2264
2265 end_scope_decl();
2266 }
2267 }
2268}
2269
2270string CompilerHLSL::to_sampler_expression(uint32_t id)
2271{
2272 auto expr = join("_", to_non_uniform_aware_expression(id));
2273 auto index = expr.find_first_of('[');
2274 if (index == string::npos)
2275 {
2276 return expr + "_sampler";
2277 }
2278 else
2279 {
2280 // We have an expression like _ident[array], so we cannot tack on _sampler, insert it inside the string instead.
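 // For example (identifier illustrative), "_MyTextures[i]" becomes "_MyTextures_sampler[i]".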
2281 return expr.insert(index, "_sampler");
2282 }
2283}
2284
2285void CompilerHLSL::emit_sampled_image_op(uint32_t result_type, uint32_t result_id, uint32_t image_id, uint32_t samp_id)
2286{
2287 if (hlsl_options.shader_model >= 40 && combined_image_samplers.empty())
2288 {
2289 set<SPIRCombinedImageSampler>(result_id, result_type, image_id, samp_id);
2290 }
2291 else
2292 {
2293 // Make sure to suppress usage tracking. It is illegal to create temporaries of opaque types.
2294 emit_op(result_type, result_id, to_combined_image_sampler(image_id, samp_id), true, true);
2295 }
2296}
2297
2298string CompilerHLSL::to_func_call_arg(const SPIRFunction::Parameter &arg, uint32_t id)
2299{
2300 string arg_str = CompilerGLSL::to_func_call_arg(arg, id);
2301
2302 if (hlsl_options.shader_model <= 30)
2303 return arg_str;
2304
2305 // Manufacture automatic sampler arg if the arg is a SampledImage texture and we're in modern HLSL.
2306 auto &type = expression_type(id);
2307
2308 // We don't have to consider combined image samplers here via OpSampledImage because
2309 // those variables cannot be passed as arguments to functions.
2310 // Only global SampledImage variables may be used as arguments.
2311 if (type.basetype == SPIRType::SampledImage && type.image.dim != DimBuffer)
2312 arg_str += ", " + to_sampler_expression(id);
2313
2314 return arg_str;
2315}
2316
2317void CompilerHLSL::emit_function_prototype(SPIRFunction &func, const Bitset &return_flags)
2318{
2319 if (func.self != ir.default_entry_point)
2320 add_function_overload(func);
2321
2322 auto &execution = get_entry_point();
2323 // Avoid shadow declarations.
2324 local_variable_names = resource_names;
2325
2326 string decl;
2327
2328 auto &type = get<SPIRType>(func.return_type);
2329 if (type.array.empty())
2330 {
2331 decl += flags_to_qualifiers_glsl(type, return_flags);
2332 decl += type_to_glsl(type);
2333 decl += " ";
2334 }
2335 else
2336 {
2337 // We cannot return arrays in HLSL, so "return" through an out variable.
2338 decl = "void ";
2339 }
2340
2341 if (func.self == ir.default_entry_point)
2342 {
2343 if (execution.model == ExecutionModelVertex)
2344 decl += "vert_main";
2345 else if (execution.model == ExecutionModelFragment)
2346 decl += "frag_main";
2347 else if (execution.model == ExecutionModelGLCompute)
2348 decl += "comp_main";
2349 else
2350 SPIRV_CROSS_THROW("Unsupported execution model.");
2351 processing_entry_point = true;
2352 }
2353 else
2354 decl += to_name(func.self);
2355
2356 decl += "(";
2357 SmallVector<string> arglist;
2358
2359 if (!type.array.empty())
2360 {
2361 // Fake array returns by writing to an out array instead.
2362 string out_argument;
2363 out_argument += "out ";
2364 out_argument += type_to_glsl(type);
2365 out_argument += " ";
2366 out_argument += "spvReturnValue";
2367 out_argument += type_to_array_glsl(type);
2368 arglist.push_back(move(out_argument));
2369 }
2370
2371 for (auto &arg : func.arguments)
2372 {
2373 // Do not pass in separate images or samplers if we're remapping
2374 // to combined image samplers.
2375 if (skip_argument(arg.id))
2376 continue;
2377
2378 // Might change the variable name if it already exists in this function.
 // SPIR-V OpName doesn't have any semantic effect, so it's valid for an implementation
 // to use the same name for multiple variables.
 // Since we want to make the output debuggable and somewhat sane, use fallback names for variables which are duplicates.
2382 add_local_variable_name(arg.id);
2383
2384 arglist.push_back(argument_decl(arg));
2385
2386 // Flatten a combined sampler to two separate arguments in modern HLSL.
2387 auto &arg_type = get<SPIRType>(arg.type);
2388 if (hlsl_options.shader_model > 30 && arg_type.basetype == SPIRType::SampledImage &&
2389 arg_type.image.dim != DimBuffer)
2390 {
2391 // Manufacture automatic sampler arg for SampledImage texture
2392 arglist.push_back(join(is_depth_image(arg_type, arg.id) ? "SamplerComparisonState " : "SamplerState ",
2393 to_sampler_expression(arg.id), type_to_array_glsl(arg_type)));
2394 }
2395
2396 // Hold a pointer to the parameter so we can invalidate the readonly field if needed.
2397 auto *var = maybe_get<SPIRVariable>(arg.id);
2398 if (var)
2399 var->parameter = &arg;
2400 }
2401
2402 for (auto &arg : func.shadow_arguments)
2403 {
2404 // Might change the variable name if it already exists in this function.
 // SPIR-V OpName doesn't have any semantic effect, so it's valid for an implementation
 // to use the same name for multiple variables.
 // Since we want to make the output debuggable and somewhat sane, use fallback names for variables which are duplicates.
2408 add_local_variable_name(arg.id);
2409
2410 arglist.push_back(argument_decl(arg));
2411
2412 // Hold a pointer to the parameter so we can invalidate the readonly field if needed.
2413 auto *var = maybe_get<SPIRVariable>(arg.id);
2414 if (var)
2415 var->parameter = &arg;
2416 }
2417
2418 decl += merge(arglist);
2419 decl += ")";
2420 statement(decl);
2421}
2422
2423void CompilerHLSL::emit_hlsl_entry_point()
2424{
2425 SmallVector<string> arguments;
2426
2427 if (require_input)
2428 arguments.push_back("SPIRV_Cross_Input stage_input");
2429
2430 auto &execution = get_entry_point();
2431
2432 switch (execution.model)
2433 {
2434 case ExecutionModelGLCompute:
2435 {
2436 SpecializationConstant wg_x, wg_y, wg_z;
2437 get_work_group_size_specialization_constants(wg_x, wg_y, wg_z);
2438
2439 uint32_t x = execution.workgroup_size.x;
2440 uint32_t y = execution.workgroup_size.y;
2441 uint32_t z = execution.workgroup_size.z;
2442
2443 if (!execution.workgroup_size.constant && execution.flags.get(ExecutionModeLocalSizeId))
2444 {
2445 if (execution.workgroup_size.id_x)
2446 x = get<SPIRConstant>(execution.workgroup_size.id_x).scalar();
2447 if (execution.workgroup_size.id_y)
2448 y = get<SPIRConstant>(execution.workgroup_size.id_y).scalar();
2449 if (execution.workgroup_size.id_z)
2450 z = get<SPIRConstant>(execution.workgroup_size.id_z).scalar();
2451 }
2452
2453 auto x_expr = wg_x.id ? get<SPIRConstant>(wg_x.id).specialization_constant_macro_name : to_string(x);
2454 auto y_expr = wg_y.id ? get<SPIRConstant>(wg_y.id).specialization_constant_macro_name : to_string(y);
2455 auto z_expr = wg_z.id ? get<SPIRConstant>(wg_z.id).specialization_constant_macro_name : to_string(z);
2456
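 // For a fixed 8x8x1 workgroup size this emits "[numthreads(8, 8, 1)]"; specialization-constant
 // sizes use their macro names instead.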
2457 statement("[numthreads(", x_expr, ", ", y_expr, ", ", z_expr, ")]");
2458 break;
2459 }
2460 case ExecutionModelFragment:
2461 if (execution.flags.get(ExecutionModeEarlyFragmentTests))
2462 statement("[earlydepthstencil]");
2463 break;
2464 default:
2465 break;
2466 }
2467
2468 statement(require_output ? "SPIRV_Cross_Output " : "void ", "main(", merge(arguments), ")");
2469 begin_scope();
2470 bool legacy = hlsl_options.shader_model <= 30;
2471
2472 // Copy builtins from entry point arguments to globals.
2473 active_input_builtins.for_each_bit([&](uint32_t i) {
2474 auto builtin = builtin_to_glsl(static_cast<BuiltIn>(i), StorageClassInput);
2475 switch (static_cast<BuiltIn>(i))
2476 {
2477 case BuiltInFragCoord:
2478 // VPOS in D3D9 is sampled at integer locations, apply half-pixel offset to be consistent.
2479 // TODO: Do we need an option here? Any reason why a D3D9 shader would be used
2480 // on a D3D10+ system with a different rasterization config?
2481 if (legacy)
2482 statement(builtin, " = stage_input.", builtin, " + float4(0.5f, 0.5f, 0.0f, 0.0f);");
2483 else
2484 {
2485 statement(builtin, " = stage_input.", builtin, ";");
 // ZW are undefined in D3D9, so only do this fixup here.
2487 statement(builtin, ".w = 1.0 / ", builtin, ".w;");
2488 }
2489 break;
2490
2491 case BuiltInVertexId:
2492 case BuiltInVertexIndex:
2493 case BuiltInInstanceIndex:
2494 // D3D semantics are uint, but shader wants int.
2495 if (hlsl_options.support_nonzero_base_vertex_base_instance)
2496 {
2497 if (static_cast<BuiltIn>(i) == BuiltInInstanceIndex)
2498 statement(builtin, " = int(stage_input.", builtin, ") + SPIRV_Cross_BaseInstance;");
2499 else
2500 statement(builtin, " = int(stage_input.", builtin, ") + SPIRV_Cross_BaseVertex;");
2501 }
2502 else
2503 statement(builtin, " = int(stage_input.", builtin, ");");
2504 break;
2505
2506 case BuiltInInstanceId:
2507 // D3D semantics are uint, but shader wants int.
2508 statement(builtin, " = int(stage_input.", builtin, ");");
2509 break;
2510
2511 case BuiltInNumWorkgroups:
2512 case BuiltInPointCoord:
2513 case BuiltInSubgroupSize:
2514 case BuiltInSubgroupLocalInvocationId:
2515 break;
2516
2517 case BuiltInSubgroupEqMask:
2518 // Emulate these ...
2519 // No 64-bit in HLSL, so have to do it in 32-bit and unroll.
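 // For example, lane 35 ends up with gl_SubgroupEqMask = uint4(0u, 1u << 3u, 0u, 0u)
 // once the fixups below have zeroed the other components.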
2520 statement("gl_SubgroupEqMask = 1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96));");
2521 statement("if (WaveGetLaneIndex() >= 32) gl_SubgroupEqMask.x = 0;");
2522 statement("if (WaveGetLaneIndex() >= 64 || WaveGetLaneIndex() < 32) gl_SubgroupEqMask.y = 0;");
2523 statement("if (WaveGetLaneIndex() >= 96 || WaveGetLaneIndex() < 64) gl_SubgroupEqMask.z = 0;");
2524 statement("if (WaveGetLaneIndex() < 96) gl_SubgroupEqMask.w = 0;");
2525 break;
2526
2527 case BuiltInSubgroupGeMask:
2528 // Emulate these ...
2529 // No 64-bit in HLSL, so have to do it in 32-bit and unroll.
2530 statement("gl_SubgroupGeMask = ~((1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96))) - 1u);");
2531 statement("if (WaveGetLaneIndex() >= 32) gl_SubgroupGeMask.x = 0u;");
2532 statement("if (WaveGetLaneIndex() >= 64) gl_SubgroupGeMask.y = 0u;");
2533 statement("if (WaveGetLaneIndex() >= 96) gl_SubgroupGeMask.z = 0u;");
2534 statement("if (WaveGetLaneIndex() < 32) gl_SubgroupGeMask.y = ~0u;");
2535 statement("if (WaveGetLaneIndex() < 64) gl_SubgroupGeMask.z = ~0u;");
2536 statement("if (WaveGetLaneIndex() < 96) gl_SubgroupGeMask.w = ~0u;");
2537 break;
2538
2539 case BuiltInSubgroupGtMask:
2540 // Emulate these ...
2541 // No 64-bit in HLSL, so have to do it in 32-bit and unroll.
2542 statement("uint gt_lane_index = WaveGetLaneIndex() + 1;");
2543 statement("gl_SubgroupGtMask = ~((1u << (gt_lane_index - uint4(0, 32, 64, 96))) - 1u);");
2544 statement("if (gt_lane_index >= 32) gl_SubgroupGtMask.x = 0u;");
2545 statement("if (gt_lane_index >= 64) gl_SubgroupGtMask.y = 0u;");
2546 statement("if (gt_lane_index >= 96) gl_SubgroupGtMask.z = 0u;");
2547 statement("if (gt_lane_index >= 128) gl_SubgroupGtMask.w = 0u;");
2548 statement("if (gt_lane_index < 32) gl_SubgroupGtMask.y = ~0u;");
2549 statement("if (gt_lane_index < 64) gl_SubgroupGtMask.z = ~0u;");
2550 statement("if (gt_lane_index < 96) gl_SubgroupGtMask.w = ~0u;");
2551 break;
2552
2553 case BuiltInSubgroupLeMask:
2554 // Emulate these ...
2555 // No 64-bit in HLSL, so have to do it in 32-bit and unroll.
2556 statement("uint le_lane_index = WaveGetLaneIndex() + 1;");
2557 statement("gl_SubgroupLeMask = (1u << (le_lane_index - uint4(0, 32, 64, 96))) - 1u;");
2558 statement("if (le_lane_index >= 32) gl_SubgroupLeMask.x = ~0u;");
2559 statement("if (le_lane_index >= 64) gl_SubgroupLeMask.y = ~0u;");
2560 statement("if (le_lane_index >= 96) gl_SubgroupLeMask.z = ~0u;");
2561 statement("if (le_lane_index >= 128) gl_SubgroupLeMask.w = ~0u;");
2562 statement("if (le_lane_index < 32) gl_SubgroupLeMask.y = 0u;");
2563 statement("if (le_lane_index < 64) gl_SubgroupLeMask.z = 0u;");
2564 statement("if (le_lane_index < 96) gl_SubgroupLeMask.w = 0u;");
2565 break;
2566
2567 case BuiltInSubgroupLtMask:
2568 // Emulate these ...
2569 // No 64-bit in HLSL, so have to do it in 32-bit and unroll.
2570 statement("gl_SubgroupLtMask = (1u << (WaveGetLaneIndex() - uint4(0, 32, 64, 96))) - 1u;");
2571 statement("if (WaveGetLaneIndex() >= 32) gl_SubgroupLtMask.x = ~0u;");
2572 statement("if (WaveGetLaneIndex() >= 64) gl_SubgroupLtMask.y = ~0u;");
2573 statement("if (WaveGetLaneIndex() >= 96) gl_SubgroupLtMask.z = ~0u;");
2574 statement("if (WaveGetLaneIndex() < 32) gl_SubgroupLtMask.y = 0u;");
2575 statement("if (WaveGetLaneIndex() < 64) gl_SubgroupLtMask.z = 0u;");
2576 statement("if (WaveGetLaneIndex() < 96) gl_SubgroupLtMask.w = 0u;");
2577 break;
2578
2579 case BuiltInClipDistance:
2580 for (uint32_t clip = 0; clip < clip_distance_count; clip++)
2581 statement("gl_ClipDistance[", clip, "] = stage_input.gl_ClipDistance", clip / 4, ".", "xyzw"[clip & 3],
2582 ";");
2583 break;
2584
2585 case BuiltInCullDistance:
2586 for (uint32_t cull = 0; cull < cull_distance_count; cull++)
2587 statement("gl_CullDistance[", cull, "] = stage_input.gl_CullDistance", cull / 4, ".", "xyzw"[cull & 3],
2588 ";");
2589 break;
2590
2591 default:
2592 statement(builtin, " = stage_input.", builtin, ";");
2593 break;
2594 }
2595 });
2596
2597 // Copy from stage input struct to globals.
2598 ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
2599 auto &type = this->get<SPIRType>(var.basetype);
2600 bool block = has_decoration(type.self, DecorationBlock);
2601
2602 if (var.storage != StorageClassInput)
2603 return;
2604
2605 bool need_matrix_unroll = var.storage == StorageClassInput && execution.model == ExecutionModelVertex;
2606
2607 if (!var.remapped_variable && type.pointer && !is_builtin_variable(var) &&
2608 interface_variable_exists_in_entry_point(var.self))
2609 {
2610 if (block)
2611 {
2612 auto type_name = to_name(type.self);
2613 auto var_name = to_name(var.self);
2614 for (uint32_t mbr_idx = 0; mbr_idx < uint32_t(type.member_types.size()); mbr_idx++)
2615 {
2616 auto mbr_name = to_member_name(type, mbr_idx);
2617 auto flat_name = join(type_name, "_", mbr_name);
2618 statement(var_name, ".", mbr_name, " = stage_input.", flat_name, ";");
2619 }
2620 }
2621 else
2622 {
2623 auto name = to_name(var.self);
2624 auto &mtype = this->get<SPIRType>(var.basetype);
2625 if (need_matrix_unroll && mtype.columns > 1)
2626 {
2627 // Unroll matrices.
2628 for (uint32_t col = 0; col < mtype.columns; col++)
2629 statement(name, "[", col, "] = stage_input.", name, "_", col, ";");
2630 }
2631 else
2632 {
2633 statement(name, " = stage_input.", name, ";");
2634 }
2635 }
2636 }
2637 });
2638
2639 // Run the shader.
2640 if (execution.model == ExecutionModelVertex)
2641 statement("vert_main();");
2642 else if (execution.model == ExecutionModelFragment)
2643 statement("frag_main();");
2644 else if (execution.model == ExecutionModelGLCompute)
2645 statement("comp_main();");
2646 else
2647 SPIRV_CROSS_THROW("Unsupported shader stage.");
2648
2649 // Copy stage outputs.
2650 if (require_output)
2651 {
2652 statement("SPIRV_Cross_Output stage_output;");
2653
2654 // Copy builtins from globals to return struct.
2655 active_output_builtins.for_each_bit([&](uint32_t i) {
2656 // PointSize doesn't exist in HLSL.
2657 if (i == BuiltInPointSize)
2658 return;
2659
2660 switch (static_cast<BuiltIn>(i))
2661 {
2662 case BuiltInClipDistance:
2663 for (uint32_t clip = 0; clip < clip_distance_count; clip++)
2664 statement("stage_output.gl_ClipDistance", clip / 4, ".", "xyzw"[clip & 3], " = gl_ClipDistance[",
2665 clip, "];");
2666 break;
2667
2668 case BuiltInCullDistance:
2669 for (uint32_t cull = 0; cull < cull_distance_count; cull++)
2670 statement("stage_output.gl_CullDistance", cull / 4, ".", "xyzw"[cull & 3], " = gl_CullDistance[",
2671 cull, "];");
2672 break;
2673
2674 default:
2675 {
2676 auto builtin_expr = builtin_to_glsl(static_cast<BuiltIn>(i), StorageClassOutput);
2677 statement("stage_output.", builtin_expr, " = ", builtin_expr, ";");
2678 break;
2679 }
2680 }
2681 });
2682
2683 ir.for_each_typed_id<SPIRVariable>([&](uint32_t, SPIRVariable &var) {
2684 auto &type = this->get<SPIRType>(var.basetype);
2685 bool block = has_decoration(type.self, DecorationBlock);
2686
2687 if (var.storage != StorageClassOutput)
2688 return;
2689
2690 if (!var.remapped_variable && type.pointer &&
2691 !is_builtin_variable(var) &&
2692 interface_variable_exists_in_entry_point(var.self))
2693 {
2694 if (block)
2695 {
2696 // I/O blocks need to flatten output.
2697 auto type_name = to_name(type.self);
2698 auto var_name = to_name(var.self);
2699 for (uint32_t mbr_idx = 0; mbr_idx < uint32_t(type.member_types.size()); mbr_idx++)
2700 {
2701 auto mbr_name = to_member_name(type, mbr_idx);
2702 auto flat_name = join(type_name, "_", mbr_name);
2703 statement("stage_output.", flat_name, " = ", var_name, ".", mbr_name, ";");
2704 }
2705 }
2706 else
2707 {
2708 auto name = to_name(var.self);
2709
2710 if (legacy && execution.model == ExecutionModelFragment)
2711 {
2712 string output_filler;
2713 for (uint32_t size = type.vecsize; size < 4; ++size)
2714 output_filler += ", 0.0";
2715
2716 statement("stage_output.", name, " = float4(", name, output_filler, ");");
2717 }
2718 else
2719 {
2720 statement("stage_output.", name, " = ", name, ";");
2721 }
2722 }
2723 }
2724 });
2725
2726 statement("return stage_output;");
2727 }
2728
2729 end_scope();
2730}
2731
2732void CompilerHLSL::emit_fixup()
2733{
2734 if (is_vertex_like_shader())
2735 {
2736 // Do various mangling on the gl_Position.
2737 if (hlsl_options.shader_model <= 30)
2738 {
2739 statement("gl_Position.x = gl_Position.x - gl_HalfPixel.x * "
2740 "gl_Position.w;");
2741 statement("gl_Position.y = gl_Position.y + gl_HalfPixel.y * "
2742 "gl_Position.w;");
2743 }
2744
2745 if (options.vertex.flip_vert_y)
2746 statement("gl_Position.y = -gl_Position.y;");
2747 if (options.vertex.fixup_clipspace)
2748 statement("gl_Position.z = (gl_Position.z + gl_Position.w) * 0.5;");
2749 }
2750}
2751
2752void CompilerHLSL::emit_texture_op(const Instruction &i, bool sparse)
2753{
2754 if (sparse)
2755 SPIRV_CROSS_THROW("Sparse feedback not yet supported in HLSL.");
2756
2757 auto *ops = stream(i);
2758 auto op = static_cast<Op>(i.op);
2759 uint32_t length = i.length;
2760
2761 SmallVector<uint32_t> inherited_expressions;
2762
2763 uint32_t result_type = ops[0];
2764 uint32_t id = ops[1];
2765 VariableID img = ops[2];
2766 uint32_t coord = ops[3];
2767 uint32_t dref = 0;
2768 uint32_t comp = 0;
2769 bool gather = false;
2770 bool proj = false;
2771 const uint32_t *opt = nullptr;
2772 auto *combined_image = maybe_get<SPIRCombinedImageSampler>(img);
2773
2774 if (combined_image && has_decoration(img, DecorationNonUniform))
2775 {
2776 set_decoration(combined_image->image, DecorationNonUniform);
2777 set_decoration(combined_image->sampler, DecorationNonUniform);
2778 }
2779
2780 auto img_expr = to_non_uniform_aware_expression(combined_image ? combined_image->image : img);
2781
2782 inherited_expressions.push_back(coord);
2783
2784 switch (op)
2785 {
2786 case OpImageSampleDrefImplicitLod:
2787 case OpImageSampleDrefExplicitLod:
2788 dref = ops[4];
2789 opt = &ops[5];
2790 length -= 5;
2791 break;
2792
2793 case OpImageSampleProjDrefImplicitLod:
2794 case OpImageSampleProjDrefExplicitLod:
2795 dref = ops[4];
2796 proj = true;
2797 opt = &ops[5];
2798 length -= 5;
2799 break;
2800
2801 case OpImageDrefGather:
2802 dref = ops[4];
2803 opt = &ops[5];
2804 gather = true;
2805 length -= 5;
2806 break;
2807
2808 case OpImageGather:
2809 comp = ops[4];
2810 opt = &ops[5];
2811 gather = true;
2812 length -= 5;
2813 break;
2814
2815 case OpImageSampleProjImplicitLod:
2816 case OpImageSampleProjExplicitLod:
2817 opt = &ops[4];
2818 length -= 4;
2819 proj = true;
2820 break;
2821
2822 case OpImageQueryLod:
2823 opt = &ops[4];
2824 length -= 4;
2825 break;
2826
2827 default:
2828 opt = &ops[4];
2829 length -= 4;
2830 break;
2831 }
2832
2833 auto &imgtype = expression_type(img);
2834 uint32_t coord_components = 0;
2835 switch (imgtype.image.dim)
2836 {
2837 case spv::Dim1D:
2838 coord_components = 1;
2839 break;
2840 case spv::Dim2D:
2841 coord_components = 2;
2842 break;
2843 case spv::Dim3D:
2844 coord_components = 3;
2845 break;
2846 case spv::DimCube:
2847 coord_components = 3;
2848 break;
2849 case spv::DimBuffer:
2850 coord_components = 1;
2851 break;
2852 default:
2853 coord_components = 2;
2854 break;
2855 }
2856
2857 if (dref)
2858 inherited_expressions.push_back(dref);
2859
2860 if (imgtype.image.arrayed)
2861 coord_components++;
2862
2863 uint32_t bias = 0;
2864 uint32_t lod = 0;
2865 uint32_t grad_x = 0;
2866 uint32_t grad_y = 0;
2867 uint32_t coffset = 0;
2868 uint32_t offset = 0;
2869 uint32_t coffsets = 0;
2870 uint32_t sample = 0;
2871 uint32_t minlod = 0;
2872 uint32_t flags = 0;
2873
2874 if (length)
2875 {
2876 flags = opt[0];
2877 opt++;
2878 length--;
2879 }
2880
2881 auto test = [&](uint32_t &v, uint32_t flag) {
2882 if (length && (flags & flag))
2883 {
2884 v = *opt++;
2885 inherited_expressions.push_back(v);
2886 length--;
2887 }
2888 };
2889
2890 test(bias, ImageOperandsBiasMask);
2891 test(lod, ImageOperandsLodMask);
2892 test(grad_x, ImageOperandsGradMask);
2893 test(grad_y, ImageOperandsGradMask);
2894 test(coffset, ImageOperandsConstOffsetMask);
2895 test(offset, ImageOperandsOffsetMask);
2896 test(coffsets, ImageOperandsConstOffsetsMask);
2897 test(sample, ImageOperandsSampleMask);
2898 test(minlod, ImageOperandsMinLodMask);
2899
2900 string expr;
2901 string texop;
2902
2903 if (minlod != 0)
2904 SPIRV_CROSS_THROW("MinLod texture operand not supported in HLSL.");
2905
2906 if (op == OpImageFetch)
2907 {
2908 if (hlsl_options.shader_model < 40)
2909 {
2910 SPIRV_CROSS_THROW("texelFetch is not supported in HLSL shader model 2/3.");
2911 }
2912 texop += img_expr;
2913 texop += ".Load";
2914 }
2915 else if (op == OpImageQueryLod)
2916 {
2917 texop += img_expr;
2918 texop += ".CalculateLevelOfDetail";
2919 }
2920 else
2921 {
2922 auto &imgformat = get<SPIRType>(imgtype.image.type);
2923 if (imgformat.basetype != SPIRType::Float)
2924 {
2925 SPIRV_CROSS_THROW("Sampling non-float textures is not supported in HLSL.");
2926 }
2927
2928 if (hlsl_options.shader_model >= 40)
2929 {
2930 texop += img_expr;
2931
2932 if (is_depth_image(imgtype, img))
2933 {
2934 if (gather)
2935 {
2936 SPIRV_CROSS_THROW("GatherCmp does not exist in HLSL.");
2937 }
2938 else if (lod || grad_x || grad_y)
2939 {
2940 // Assume we want a fixed level, and the only thing we can get in HLSL is SampleCmpLevelZero.
2941 texop += ".SampleCmpLevelZero";
2942 }
2943 else
2944 texop += ".SampleCmp";
2945 }
2946 else if (gather)
2947 {
2948 uint32_t comp_num = evaluate_constant_u32(comp);
2949 if (hlsl_options.shader_model >= 50)
2950 {
2951 switch (comp_num)
2952 {
2953 case 0:
2954 texop += ".GatherRed";
2955 break;
2956 case 1:
2957 texop += ".GatherGreen";
2958 break;
2959 case 2:
2960 texop += ".GatherBlue";
2961 break;
2962 case 3:
2963 texop += ".GatherAlpha";
2964 break;
2965 default:
2966 SPIRV_CROSS_THROW("Invalid component.");
2967 }
2968 }
2969 else
2970 {
2971 if (comp_num == 0)
2972 texop += ".Gather";
2973 else
2974 SPIRV_CROSS_THROW("HLSL shader model 4 can only gather from the red component.");
2975 }
2976 }
2977 else if (bias)
2978 texop += ".SampleBias";
2979 else if (grad_x || grad_y)
2980 texop += ".SampleGrad";
2981 else if (lod)
2982 texop += ".SampleLevel";
2983 else
2984 texop += ".Sample";
2985 }
2986 else
2987 {
2988 switch (imgtype.image.dim)
2989 {
2990 case Dim1D:
2991 texop += "tex1D";
2992 break;
2993 case Dim2D:
2994 texop += "tex2D";
2995 break;
2996 case Dim3D:
2997 texop += "tex3D";
2998 break;
2999 case DimCube:
3000 texop += "texCUBE";
3001 break;
3002 case DimRect:
3003 case DimBuffer:
3004 case DimSubpassData:
3005 SPIRV_CROSS_THROW("Buffer texture support is not yet implemented for HLSL"); // TODO
3006 default:
3007 SPIRV_CROSS_THROW("Invalid dimension.");
3008 }
3009
3010 if (gather)
3011 SPIRV_CROSS_THROW("textureGather is not supported in HLSL shader model 2/3.");
3012 if (offset || coffset)
3013 SPIRV_CROSS_THROW("textureOffset is not supported in HLSL shader model 2/3.");
3014
3015 if (grad_x || grad_y)
3016 texop += "grad";
3017 else if (lod)
3018 texop += "lod";
3019 else if (bias)
3020 texop += "bias";
3021 else if (proj || dref)
3022 texop += "proj";
3023 }
3024 }
3025
3026 expr += texop;
3027 expr += "(";
3028 if (hlsl_options.shader_model < 40)
3029 {
3030 if (combined_image)
3031 SPIRV_CROSS_THROW("Separate images/samplers are not supported in HLSL shader model 2/3.");
3032 expr += to_expression(img);
3033 }
3034 else if (op != OpImageFetch)
3035 {
3036 string sampler_expr;
3037 if (combined_image)
3038 sampler_expr = to_non_uniform_aware_expression(combined_image->sampler);
3039 else
3040 sampler_expr = to_sampler_expression(img);
3041 expr += sampler_expr;
3042 }
3043
3044 auto swizzle = [](uint32_t comps, uint32_t in_comps) -> const char * {
3045 if (comps == in_comps)
3046 return "";
3047
3048 switch (comps)
3049 {
3050 case 1:
3051 return ".x";
3052 case 2:
3053 return ".xy";
3054 case 3:
3055 return ".xyz";
3056 default:
3057 return "";
3058 }
3059 };
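 // For example, a float4 coordinate sampled from a non-arrayed 2D image is truncated with ".xy" below.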
3060
3061 bool forward = should_forward(coord);
3062
3063 // The IR can give us more components than we need, so chop them off as needed.
3064 string coord_expr;
3065 auto &coord_type = expression_type(coord);
3066 if (coord_components != coord_type.vecsize)
3067 coord_expr = to_enclosed_expression(coord) + swizzle(coord_components, expression_type(coord).vecsize);
3068 else
3069 coord_expr = to_expression(coord);
3070
3071 if (proj && hlsl_options.shader_model >= 40) // Legacy HLSL has "proj" operations which do this for us.
3072 coord_expr = coord_expr + " / " + to_extract_component_expression(coord, coord_components);
3073
3074 if (hlsl_options.shader_model < 40)
3075 {
3076 if (dref)
3077 {
3078 if (imgtype.image.dim != spv::Dim1D && imgtype.image.dim != spv::Dim2D)
3079 {
3080 SPIRV_CROSS_THROW(
3081 "Depth comparison is only supported for 1D and 2D textures in HLSL shader model 2/3.");
3082 }
3083
3084 if (grad_x || grad_y)
3085 SPIRV_CROSS_THROW("Depth comparison is not supported for grad sampling in HLSL shader model 2/3.");
3086
3087 for (uint32_t size = coord_components; size < 2; ++size)
3088 coord_expr += ", 0.0";
3089
3090 forward = forward && should_forward(dref);
3091 coord_expr += ", " + to_expression(dref);
3092 }
3093 else if (lod || bias || proj)
3094 {
3095 for (uint32_t size = coord_components; size < 3; ++size)
3096 coord_expr += ", 0.0";
3097 }
3098
3099 if (lod)
3100 {
3101 coord_expr = "float4(" + coord_expr + ", " + to_expression(lod) + ")";
3102 }
3103 else if (bias)
3104 {
3105 coord_expr = "float4(" + coord_expr + ", " + to_expression(bias) + ")";
3106 }
3107 else if (proj)
3108 {
3109 coord_expr = "float4(" + coord_expr + ", " + to_extract_component_expression(coord, coord_components) + ")";
3110 }
3111 else if (dref)
3112 {
3113 // A "normal" sample gets fed into tex2Dproj as well, because the
3114 // regular tex2D accepts only two coordinates.
3115 coord_expr = "float4(" + coord_expr + ", 1.0)";
3116 }
3117
3118 if (!!lod + !!bias + !!proj > 1)
3119 SPIRV_CROSS_THROW("Legacy HLSL can only use one of lod/bias/proj modifiers.");
3120 }
3121
3122 if (op == OpImageFetch)
3123 {
3124 if (imgtype.image.dim != DimBuffer && !imgtype.image.ms)
3125 coord_expr =
3126 join("int", coord_components + 1, "(", coord_expr, ", ", lod ? to_expression(lod) : string("0"), ")");
3127 }
3128 else
3129 expr += ", ";
3130 expr += coord_expr;
3131
3132 if (dref && hlsl_options.shader_model >= 40)
3133 {
3134 forward = forward && should_forward(dref);
3135 expr += ", ";
3136
3137 if (proj)
3138 expr += to_enclosed_expression(dref) + " / " + to_extract_component_expression(coord, coord_components);
3139 else
3140 expr += to_expression(dref);
3141 }
3142
3143 if (!dref && (grad_x || grad_y))
3144 {
3145 forward = forward && should_forward(grad_x);
3146 forward = forward && should_forward(grad_y);
3147 expr += ", ";
3148 expr += to_expression(grad_x);
3149 expr += ", ";
3150 expr += to_expression(grad_y);
3151 }
3152
3153 if (!dref && lod && hlsl_options.shader_model >= 40 && op != OpImageFetch)
3154 {
3155 forward = forward && should_forward(lod);
3156 expr += ", ";
3157 expr += to_expression(lod);
3158 }
3159
3160 if (!dref && bias && hlsl_options.shader_model >= 40)
3161 {
3162 forward = forward && should_forward(bias);
3163 expr += ", ";
3164 expr += to_expression(bias);
3165 }
3166
3167 if (coffset)
3168 {
3169 forward = forward && should_forward(coffset);
3170 expr += ", ";
3171 expr += to_expression(coffset);
3172 }
3173 else if (offset)
3174 {
3175 forward = forward && should_forward(offset);
3176 expr += ", ";
3177 expr += to_expression(offset);
3178 }
3179
3180 if (sample)
3181 {
3182 expr += ", ";
3183 expr += to_expression(sample);
3184 }
3185
3186 expr += ")";
3187
3188 if (dref && hlsl_options.shader_model < 40)
3189 expr += ".x";
3190
3191 if (op == OpImageQueryLod)
3192 {
 // This is rather awkward.
 // textureQueryLod returns two values: the "accessed level",
 // as well as the actual LOD lambda.
 // As far as I can tell, there is no way to compute the .x component
 // per the GLSL spec, since it depends on the sampler itself.
 // Just assume X == Y, so we will need to splat the result to a float2.
3199 statement("float _", id, "_tmp = ", expr, ";");
3200 statement("float2 _", id, " = _", id, "_tmp.xx;");
3201 set<SPIRExpression>(id, join("_", id), result_type, true);
3202 }
3203 else
3204 {
3205 emit_op(result_type, id, expr, forward, false);
3206 }
3207
3208 for (auto &inherit : inherited_expressions)
3209 inherit_expression_dependencies(id, inherit);
3210
3211 switch (op)
3212 {
3213 case OpImageSampleDrefImplicitLod:
3214 case OpImageSampleImplicitLod:
3215 case OpImageSampleProjImplicitLod:
3216 case OpImageSampleProjDrefImplicitLod:
3217 register_control_dependent_expression(id);
3218 break;
3219
3220 default:
3221 break;
3222 }
3223}
3224
3225string CompilerHLSL::to_resource_binding(const SPIRVariable &var)
3226{
3227 const auto &type = get<SPIRType>(var.basetype);
3228
3229 // We can remap push constant blocks, even if they don't have any binding decoration.
3230 if (type.storage != StorageClassPushConstant && !has_decoration(var.self, DecorationBinding))
3231 return "";
3232
3233 char space = '\0';
3234
3235 HLSLBindingFlagBits resource_flags = HLSL_BINDING_AUTO_NONE_BIT;
3236
3237 switch (type.basetype)
3238 {
3239 case SPIRType::SampledImage:
3240 space = 't'; // SRV
3241 resource_flags = HLSL_BINDING_AUTO_SRV_BIT;
3242 break;
3243
3244 case SPIRType::Image:
3245 if (type.image.sampled == 2 && type.image.dim != DimSubpassData)
3246 {
3247 if (has_decoration(var.self, DecorationNonWritable) && hlsl_options.nonwritable_uav_texture_as_srv)
3248 {
3249 space = 't'; // SRV
3250 resource_flags = HLSL_BINDING_AUTO_SRV_BIT;
3251 }
3252 else
3253 {
3254 space = 'u'; // UAV
3255 resource_flags = HLSL_BINDING_AUTO_UAV_BIT;
3256 }
3257 }
3258 else
3259 {
3260 space = 't'; // SRV
3261 resource_flags = HLSL_BINDING_AUTO_SRV_BIT;
3262 }
3263 break;
3264
3265 case SPIRType::Sampler:
3266 space = 's';
3267 resource_flags = HLSL_BINDING_AUTO_SAMPLER_BIT;
3268 break;
3269
3270 case SPIRType::Struct:
3271 {
3272 auto storage = type.storage;
3273 if (storage == StorageClassUniform)
3274 {
3275 if (has_decoration(type.self, DecorationBufferBlock))
3276 {
3277 Bitset flags = ir.get_buffer_block_flags(var);
3278 bool is_readonly = flags.get(DecorationNonWritable) && !is_hlsl_force_storage_buffer_as_uav(var.self);
 space = is_readonly ? 't' : 'u'; // SRV if read-only, otherwise UAV
3280 resource_flags = is_readonly ? HLSL_BINDING_AUTO_SRV_BIT : HLSL_BINDING_AUTO_UAV_BIT;
3281 }
3282 else if (has_decoration(type.self, DecorationBlock))
3283 {
3284 space = 'b'; // Constant buffers
3285 resource_flags = HLSL_BINDING_AUTO_CBV_BIT;
3286 }
3287 }
3288 else if (storage == StorageClassPushConstant)
3289 {
3290 space = 'b'; // Constant buffers
3291 resource_flags = HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT;
3292 }
3293 else if (storage == StorageClassStorageBuffer)
3294 {
3295 // UAV or SRV depending on readonly flag.
3296 Bitset flags = ir.get_buffer_block_flags(var);
3297 bool is_readonly = flags.get(DecorationNonWritable) && !is_hlsl_force_storage_buffer_as_uav(var.self);
3298 space = is_readonly ? 't' : 'u';
3299 resource_flags = is_readonly ? HLSL_BINDING_AUTO_SRV_BIT : HLSL_BINDING_AUTO_UAV_BIT;
3300 }
3301
3302 break;
3303 }
3304 default:
3305 break;
3306 }
3307
3308 if (!space)
3309 return "";
3310
3311 uint32_t desc_set =
3312 resource_flags == HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT ? ResourceBindingPushConstantDescriptorSet : 0u;
3313 uint32_t binding = resource_flags == HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT ? ResourceBindingPushConstantBinding : 0u;
3314
3315 if (has_decoration(var.self, DecorationBinding))
3316 binding = get_decoration(var.self, DecorationBinding);
3317 if (has_decoration(var.self, DecorationDescriptorSet))
3318 desc_set = get_decoration(var.self, DecorationDescriptorSet);
3319
3320 return to_resource_register(resource_flags, space, binding, desc_set);
3321}
3322
3323string CompilerHLSL::to_resource_binding_sampler(const SPIRVariable &var)
3324{
3325 // For combined image samplers.
3326 if (!has_decoration(var.self, DecorationBinding))
3327 return "";
3328
3329 return to_resource_register(HLSL_BINDING_AUTO_SAMPLER_BIT, 's', get_decoration(var.self, DecorationBinding),
3330 get_decoration(var.self, DecorationDescriptorSet));
3331}
3332
3333void CompilerHLSL::remap_hlsl_resource_binding(HLSLBindingFlagBits type, uint32_t &desc_set, uint32_t &binding)
3334{
3335 auto itr = resource_bindings.find({ get_execution_model(), desc_set, binding });
3336 if (itr != end(resource_bindings))
3337 {
3338 auto &remap = itr->second;
3339 remap.second = true;
3340
3341 switch (type)
3342 {
3343 case HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT:
3344 case HLSL_BINDING_AUTO_CBV_BIT:
3345 desc_set = remap.first.cbv.register_space;
3346 binding = remap.first.cbv.register_binding;
3347 break;
3348
3349 case HLSL_BINDING_AUTO_SRV_BIT:
3350 desc_set = remap.first.srv.register_space;
3351 binding = remap.first.srv.register_binding;
3352 break;
3353
3354 case HLSL_BINDING_AUTO_SAMPLER_BIT:
3355 desc_set = remap.first.sampler.register_space;
3356 binding = remap.first.sampler.register_binding;
3357 break;
3358
3359 case HLSL_BINDING_AUTO_UAV_BIT:
3360 desc_set = remap.first.uav.register_space;
3361 binding = remap.first.uav.register_binding;
3362 break;
3363
3364 default:
3365 break;
3366 }
3367 }
3368}
3369
3370string CompilerHLSL::to_resource_register(HLSLBindingFlagBits flag, char space, uint32_t binding, uint32_t space_set)
3371{
3372 if ((flag & resource_binding_flags) == 0)
3373 {
3374 remap_hlsl_resource_binding(flag, space_set, binding);
3375
 // The push constant block did not have a binding, and there was no remap for it,
 // so declare it without a register binding.
3378 if (flag == HLSL_BINDING_AUTO_PUSH_CONSTANT_BIT && space_set == ResourceBindingPushConstantDescriptorSet)
3379 return "";
3380
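 // For example, an SRV with binding 3 in descriptor set 1 becomes " : register(t3, space1)"
 // on SM 5.1, or " : register(t3)" on older shader models.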
3381 if (hlsl_options.shader_model >= 51)
3382 return join(" : register(", space, binding, ", space", space_set, ")");
3383 else
3384 return join(" : register(", space, binding, ")");
3385 }
3386 else
3387 return "";
3388}
3389
3390void CompilerHLSL::emit_modern_uniform(const SPIRVariable &var)
3391{
3392 auto &type = get<SPIRType>(var.basetype);
3393 switch (type.basetype)
3394 {
3395 case SPIRType::SampledImage:
3396 case SPIRType::Image:
3397 {
3398 bool is_coherent = false;
3399 if (type.basetype == SPIRType::Image && type.image.sampled == 2)
3400 is_coherent = has_decoration(var.self, DecorationCoherent);
3401
3402 statement(is_coherent ? "globallycoherent " : "", image_type_hlsl_modern(type, var.self), " ",
3403 to_name(var.self), type_to_array_glsl(type), to_resource_binding(var), ";");
3404
3405 if (type.basetype == SPIRType::SampledImage && type.image.dim != DimBuffer)
3406 {
3407 // For combined image samplers, also emit a combined image sampler.
3408 if (is_depth_image(type, var.self))
3409 statement("SamplerComparisonState ", to_sampler_expression(var.self), type_to_array_glsl(type),
3410 to_resource_binding_sampler(var), ";");
3411 else
3412 statement("SamplerState ", to_sampler_expression(var.self), type_to_array_glsl(type),
3413 to_resource_binding_sampler(var), ";");
3414 }
3415 break;
3416 }
3417
3418 case SPIRType::Sampler:
3419 if (comparison_ids.count(var.self))
3420 statement("SamplerComparisonState ", to_name(var.self), type_to_array_glsl(type), to_resource_binding(var),
3421 ";");
3422 else
3423 statement("SamplerState ", to_name(var.self), type_to_array_glsl(type), to_resource_binding(var), ";");
3424 break;
3425
3426 default:
3427 statement(variable_decl(var), to_resource_binding(var), ";");
3428 break;
3429 }
3430}
3431
3432void CompilerHLSL::emit_legacy_uniform(const SPIRVariable &var)
3433{
3434 auto &type = get<SPIRType>(var.basetype);
3435 switch (type.basetype)
3436 {
3437 case SPIRType::Sampler:
3438 case SPIRType::Image:
3439 SPIRV_CROSS_THROW("Separate image and samplers not supported in legacy HLSL.");
3440
3441 default:
3442 statement(variable_decl(var), ";");
3443 break;
3444 }
3445}
3446
3447void CompilerHLSL::emit_uniform(const SPIRVariable &var)
3448{
3449 add_resource_name(var.self);
3450 if (hlsl_options.shader_model >= 40)
3451 emit_modern_uniform(var);
3452 else
3453 emit_legacy_uniform(var);
3454}
3455
3456bool CompilerHLSL::emit_complex_bitcast(uint32_t, uint32_t, uint32_t)
3457{
3458 return false;
3459}
3460
3461string CompilerHLSL::bitcast_glsl_op(const SPIRType &out_type, const SPIRType &in_type)
3462{
3463 if (out_type.basetype == SPIRType::UInt && in_type.basetype == SPIRType::Int)
3464 return type_to_glsl(out_type);
3465 else if (out_type.basetype == SPIRType::UInt64 && in_type.basetype == SPIRType::Int64)
3466 return type_to_glsl(out_type);
3467 else if (out_type.basetype == SPIRType::UInt && in_type.basetype == SPIRType::Float)
3468 return "asuint";
3469 else if (out_type.basetype == SPIRType::Int && in_type.basetype == SPIRType::UInt)
3470 return type_to_glsl(out_type);
3471 else if (out_type.basetype == SPIRType::Int64 && in_type.basetype == SPIRType::UInt64)
3472 return type_to_glsl(out_type);
3473 else if (out_type.basetype == SPIRType::Int && in_type.basetype == SPIRType::Float)
3474 return "asint";
3475 else if (out_type.basetype == SPIRType::Float && in_type.basetype == SPIRType::UInt)
3476 return "asfloat";
3477 else if (out_type.basetype == SPIRType::Float && in_type.basetype == SPIRType::Int)
3478 return "asfloat";
3479 else if (out_type.basetype == SPIRType::Int64 && in_type.basetype == SPIRType::Double)
3480 SPIRV_CROSS_THROW("Double to Int64 is not supported in HLSL.");
3481 else if (out_type.basetype == SPIRType::UInt64 && in_type.basetype == SPIRType::Double)
3482 SPIRV_CROSS_THROW("Double to UInt64 is not supported in HLSL.");
3483 else if (out_type.basetype == SPIRType::Double && in_type.basetype == SPIRType::Int64)
3484 return "asdouble";
3485 else if (out_type.basetype == SPIRType::Double && in_type.basetype == SPIRType::UInt64)
3486 return "asdouble";
3487 else if (out_type.basetype == SPIRType::Half && in_type.basetype == SPIRType::UInt && in_type.vecsize == 1)
3488 {
3489 if (!requires_explicit_fp16_packing)
3490 {
3491 requires_explicit_fp16_packing = true;
3492 force_recompile();
3493 }
3494 return "spvUnpackFloat2x16";
3495 }
3496 else if (out_type.basetype == SPIRType::UInt && in_type.basetype == SPIRType::Half && in_type.vecsize == 2)
3497 {
3498 if (!requires_explicit_fp16_packing)
3499 {
3500 requires_explicit_fp16_packing = true;
3501 force_recompile();
3502 }
3503 return "spvPackFloat2x16";
3504 }
3505 else
3506 return "";
3507}
3508
3509void CompilerHLSL::emit_glsl_op(uint32_t result_type, uint32_t id, uint32_t eop, const uint32_t *args, uint32_t count)
3510{
3511 auto op = static_cast<GLSLstd450>(eop);
3512
3513 // If we need to do implicit bitcasts, make sure we do it with the correct type.
3514 uint32_t integer_width = get_integer_width_for_glsl_instruction(op, args, count);
3515 auto int_type = to_signed_basetype(integer_width);
3516 auto uint_type = to_unsigned_basetype(integer_width);
3517
3518 switch (op)
3519 {
3520 case GLSLstd450InverseSqrt:
3521 emit_unary_func_op(result_type, id, args[0], "rsqrt");
3522 break;
3523
3524 case GLSLstd450Fract:
3525 emit_unary_func_op(result_type, id, args[0], "frac");
3526 break;
3527
3528 case GLSLstd450RoundEven:
3529 if (hlsl_options.shader_model < 40)
3530 SPIRV_CROSS_THROW("roundEven is not supported in HLSL shader model 2/3.");
3531 emit_unary_func_op(result_type, id, args[0], "round");
3532 break;
3533
3534 case GLSLstd450Acosh:
3535 case GLSLstd450Asinh:
3536 case GLSLstd450Atanh:
3537  SPIRV_CROSS_THROW("Inverse hyperbolics are not supported in HLSL.");
3538
3539 case GLSLstd450FMix:
3540 case GLSLstd450IMix:
3541 emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "lerp");
3542 break;
3543
3544 case GLSLstd450Atan2:
3545 emit_binary_func_op(result_type, id, args[0], args[1], "atan2");
3546 break;
3547
3548 case GLSLstd450Fma:
3549 emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "mad");
3550 break;
3551
3552 case GLSLstd450InterpolateAtCentroid:
3553 emit_unary_func_op(result_type, id, args[0], "EvaluateAttributeAtCentroid");
3554 break;
3555 case GLSLstd450InterpolateAtSample:
3556 emit_binary_func_op(result_type, id, args[0], args[1], "EvaluateAttributeAtSample");
3557 break;
3558 case GLSLstd450InterpolateAtOffset:
3559 emit_binary_func_op(result_type, id, args[0], args[1], "EvaluateAttributeSnapped");
3560 break;
3561
3562 case GLSLstd450PackHalf2x16:
3563 if (!requires_fp16_packing)
3564 {
3565 requires_fp16_packing = true;
3566 force_recompile();
3567 }
3568 emit_unary_func_op(result_type, id, args[0], "spvPackHalf2x16");
3569 break;
3570
3571 case GLSLstd450UnpackHalf2x16:
3572 if (!requires_fp16_packing)
3573 {
3574 requires_fp16_packing = true;
3575 force_recompile();
3576 }
3577 emit_unary_func_op(result_type, id, args[0], "spvUnpackHalf2x16");
3578 break;
3579
3580 case GLSLstd450PackSnorm4x8:
3581 if (!requires_snorm8_packing)
3582 {
3583 requires_snorm8_packing = true;
3584 force_recompile();
3585 }
3586 emit_unary_func_op(result_type, id, args[0], "spvPackSnorm4x8");
3587 break;
3588
3589 case GLSLstd450UnpackSnorm4x8:
3590 if (!requires_snorm8_packing)
3591 {
3592 requires_snorm8_packing = true;
3593 force_recompile();
3594 }
3595 emit_unary_func_op(result_type, id, args[0], "spvUnpackSnorm4x8");
3596 break;
3597
3598 case GLSLstd450PackUnorm4x8:
3599 if (!requires_unorm8_packing)
3600 {
3601 requires_unorm8_packing = true;
3602 force_recompile();
3603 }
3604 emit_unary_func_op(result_type, id, args[0], "spvPackUnorm4x8");
3605 break;
3606
3607 case GLSLstd450UnpackUnorm4x8:
3608 if (!requires_unorm8_packing)
3609 {
3610 requires_unorm8_packing = true;
3611 force_recompile();
3612 }
3613 emit_unary_func_op(result_type, id, args[0], "spvUnpackUnorm4x8");
3614 break;
3615
3616 case GLSLstd450PackSnorm2x16:
3617 if (!requires_snorm16_packing)
3618 {
3619 requires_snorm16_packing = true;
3620 force_recompile();
3621 }
3622 emit_unary_func_op(result_type, id, args[0], "spvPackSnorm2x16");
3623 break;
3624
3625 case GLSLstd450UnpackSnorm2x16:
3626 if (!requires_snorm16_packing)
3627 {
3628 requires_snorm16_packing = true;
3629 force_recompile();
3630 }
3631 emit_unary_func_op(result_type, id, args[0], "spvUnpackSnorm2x16");
3632 break;
3633
3634 case GLSLstd450PackUnorm2x16:
3635 if (!requires_unorm16_packing)
3636 {
3637 requires_unorm16_packing = true;
3638 force_recompile();
3639 }
3640 emit_unary_func_op(result_type, id, args[0], "spvPackUnorm2x16");
3641 break;
3642
3643 case GLSLstd450UnpackUnorm2x16:
3644 if (!requires_unorm16_packing)
3645 {
3646 requires_unorm16_packing = true;
3647 force_recompile();
3648 }
3649 emit_unary_func_op(result_type, id, args[0], "spvUnpackUnorm2x16");
3650 break;
3651
3652 case GLSLstd450PackDouble2x32:
3653 case GLSLstd450UnpackDouble2x32:
3654 SPIRV_CROSS_THROW("packDouble2x32/unpackDouble2x32 not supported in HLSL.");
3655
3656 case GLSLstd450FindILsb:
3657 {
3658 auto basetype = expression_type(args[0]).basetype;
3659 emit_unary_func_op_cast(result_type, id, args[0], "firstbitlow", basetype, basetype);
3660 break;
3661 }
3662
3663 case GLSLstd450FindSMsb:
3664 emit_unary_func_op_cast(result_type, id, args[0], "firstbithigh", int_type, int_type);
3665 break;
3666
3667 case GLSLstd450FindUMsb:
3668 emit_unary_func_op_cast(result_type, id, args[0], "firstbithigh", uint_type, uint_type);
3669 break;
3670
3671 case GLSLstd450MatrixInverse:
3672 {
3673 auto &type = get<SPIRType>(result_type);
3674 if (type.vecsize == 2 && type.columns == 2)
3675 {
3676 if (!requires_inverse_2x2)
3677 {
3678 requires_inverse_2x2 = true;
3679 force_recompile();
3680 }
3681 }
3682 else if (type.vecsize == 3 && type.columns == 3)
3683 {
3684 if (!requires_inverse_3x3)
3685 {
3686 requires_inverse_3x3 = true;
3687 force_recompile();
3688 }
3689 }
3690 else if (type.vecsize == 4 && type.columns == 4)
3691 {
3692 if (!requires_inverse_4x4)
3693 {
3694 requires_inverse_4x4 = true;
3695 force_recompile();
3696 }
3697 }
3698 emit_unary_func_op(result_type, id, args[0], "spvInverse");
3699 break;
3700 }
3701
3702 case GLSLstd450Normalize:
3703 // HLSL does not support scalar versions here.
3704 if (expression_type(args[0]).vecsize == 1)
3705 {
3706 // Returns -1 or 1 for valid input, sign() does the job.
3707 emit_unary_func_op(result_type, id, args[0], "sign");
3708 }
3709 else
3710 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
3711 break;
3712
3713 case GLSLstd450Reflect:
3714 if (get<SPIRType>(result_type).vecsize == 1)
3715 {
3716 if (!requires_scalar_reflect)
3717 {
3718 requires_scalar_reflect = true;
3719 force_recompile();
3720 }
3721 emit_binary_func_op(result_type, id, args[0], args[1], "spvReflect");
3722 }
3723 else
3724 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
3725 break;
3726
3727 case GLSLstd450Refract:
3728 if (get<SPIRType>(result_type).vecsize == 1)
3729 {
3730 if (!requires_scalar_refract)
3731 {
3732 requires_scalar_refract = true;
3733 force_recompile();
3734 }
3735 emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "spvRefract");
3736 }
3737 else
3738 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
3739 break;
3740
3741 case GLSLstd450FaceForward:
3742 if (get<SPIRType>(result_type).vecsize == 1)
3743 {
3744 if (!requires_scalar_faceforward)
3745 {
3746 requires_scalar_faceforward = true;
3747 force_recompile();
3748 }
3749 emit_trinary_func_op(result_type, id, args[0], args[1], args[2], "spvFaceForward");
3750 }
3751 else
3752 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
3753 break;
3754
3755 default:
3756 CompilerGLSL::emit_glsl_op(result_type, id, eop, args, count);
3757 break;
3758 }
3759}
3760
3761void CompilerHLSL::read_access_chain_array(const string &lhs, const SPIRAccessChain &chain)
3762{
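 // Arrays are read back element by element through an unrolled loop, roughly:
 //   [unroll]
 //   for (int _i = 0; _i < N; _i++)
 //       lhs[_i] = <element load at _i * array_stride + dynamic_index>;
 // where _i stands in for the reserved identifier generated below.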
3763 auto &type = get<SPIRType>(chain.basetype);
3764
3765 // Use a reserved identifier for the loop counter, so it cannot shadow an identifier in the access chain input or other loops.
3766 auto ident = get_unique_identifier();
3767
3768 statement("[unroll]");
3769 statement("for (int ", ident, " = 0; ", ident, " < ", to_array_size(type, uint32_t(type.array.size() - 1)), "; ",
3770 ident, "++)");
3771 begin_scope();
3772 auto subchain = chain;
3773 subchain.dynamic_index = join(ident, " * ", chain.array_stride, " + ", chain.dynamic_index);
3774 subchain.basetype = type.parent_type;
3775 if (!get<SPIRType>(subchain.basetype).array.empty())
3776 subchain.array_stride = get_decoration(subchain.basetype, DecorationArrayStride);
3777 read_access_chain(nullptr, join(lhs, "[", ident, "]"), subchain);
3778 end_scope();
3779}
3780
3781void CompilerHLSL::read_access_chain_struct(const string &lhs, const SPIRAccessChain &chain)
3782{
3783 auto &type = get<SPIRType>(chain.basetype);
3784 auto subchain = chain;
3785 uint32_t member_count = uint32_t(type.member_types.size());
3786
3787 for (uint32_t i = 0; i < member_count; i++)
3788 {
3789 uint32_t offset = type_struct_member_offset(type, i);
3790 subchain.static_index = chain.static_index + offset;
3791 subchain.basetype = type.member_types[i];
3792
3793 subchain.matrix_stride = 0;
3794 subchain.array_stride = 0;
3795 subchain.row_major_matrix = false;
3796
3797 auto &member_type = get<SPIRType>(subchain.basetype);
3798 if (member_type.columns > 1)
3799 {
3800 subchain.matrix_stride = type_struct_member_matrix_stride(type, i);
3801 subchain.row_major_matrix = has_member_decoration(type.self, i, DecorationRowMajor);
3802 }
3803
3804 if (!member_type.array.empty())
3805 subchain.array_stride = type_struct_member_array_stride(type, i);
3806
3807 read_access_chain(nullptr, join(lhs, ".", to_member_name(type, i)), subchain);
3808 }
3809}
3810
3811void CompilerHLSL::read_access_chain(string *expr, const string &lhs, const SPIRAccessChain &chain)
3812{
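 // Turns a physical ByteAddressBuffer access chain into a load expression. Illustrative output
 // for a float4 member at some byte offset (buffer/offset names are placeholders):
 //   asfloat(_buf.Load4(offset))   // pre SM 6.2: load raw uints, then bitcast
 //   _buf.Load<float4>(offset)     // SM 6.2+: templated load of the real type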
3813 auto &type = get<SPIRType>(chain.basetype);
3814
3815 SPIRType target_type;
3816 target_type.basetype = SPIRType::UInt;
3817 target_type.vecsize = type.vecsize;
3818 target_type.columns = type.columns;
3819
3820 if (!type.array.empty())
3821 {
3822 read_access_chain_array(lhs, chain);
3823 return;
3824 }
3825 else if (type.basetype == SPIRType::Struct)
3826 {
3827 read_access_chain_struct(lhs, chain);
3828 return;
3829 }
3830 else if (type.width != 32 && !hlsl_options.enable_16bit_types)
3831 SPIRV_CROSS_THROW("Reading types other than 32-bit from ByteAddressBuffer not yet supported, unless SM 6.2 and "
3832 "native 16-bit types are enabled.");
3833
3834 string base = chain.base;
3835 if (has_decoration(chain.self, DecorationNonUniform))
3836 convert_non_uniform_expression(base, chain.self);
3837
3838 bool templated_load = hlsl_options.shader_model >= 62;
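 // SM 6.2+ ByteAddressBuffers support templated Load<T>(), so the real type can be loaded
 // directly instead of loading raw uints and bitcasting the result afterwards.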
3839 string load_expr;
3840
3841 string template_expr;
3842 if (templated_load)
3843 template_expr = join("<", type_to_glsl(type), ">");
3844
3845 // Load a vector or scalar.
3846 if (type.columns == 1 && !chain.row_major_matrix)
3847 {
3848 const char *load_op = nullptr;
3849 switch (type.vecsize)
3850 {
3851 case 1:
3852 load_op = "Load";
3853 break;
3854 case 2:
3855 load_op = "Load2";
3856 break;
3857 case 3:
3858 load_op = "Load3";
3859 break;
3860 case 4:
3861 load_op = "Load4";
3862 break;
3863 default:
3864 SPIRV_CROSS_THROW("Unknown vector size.");
3865 }
3866
3867 if (templated_load)
3868 load_op = "Load";
3869
3870 load_expr = join(base, ".", load_op, template_expr, "(", chain.dynamic_index, chain.static_index, ")");
3871 }
3872 else if (type.columns == 1)
3873 {
3874 // Strided load since we are loading a column from a row-major matrix.
3875 if (templated_load)
3876 {
3877 auto scalar_type = type;
3878 scalar_type.vecsize = 1;
3879 scalar_type.columns = 1;
3880 template_expr = join("<", type_to_glsl(scalar_type), ">");
3881 if (type.vecsize > 1)
3882 load_expr += type_to_glsl(type) + "(";
3883 }
3884 else if (type.vecsize > 1)
3885 {
3886 load_expr = type_to_glsl(target_type);
3887 load_expr += "(";
3888 }
3889
3890 for (uint32_t r = 0; r < type.vecsize; r++)
3891 {
3892 load_expr += join(base, ".Load", template_expr, "(", chain.dynamic_index,
3893 chain.static_index + r * chain.matrix_stride, ")");
3894 if (r + 1 < type.vecsize)
3895 load_expr += ", ";
3896 }
3897
3898 if (type.vecsize > 1)
3899 load_expr += ")";
3900 }
3901 else if (!chain.row_major_matrix)
3902 {
3903 // Load a matrix, column-major, the easy case.
3904 const char *load_op = nullptr;
3905 switch (type.vecsize)
3906 {
3907 case 1:
3908 load_op = "Load";
3909 break;
3910 case 2:
3911 load_op = "Load2";
3912 break;
3913 case 3:
3914 load_op = "Load3";
3915 break;
3916 case 4:
3917 load_op = "Load4";
3918 break;
3919 default:
3920 SPIRV_CROSS_THROW("Unknown vector size.");
3921 }
3922
3923 if (templated_load)
3924 {
3925 auto vector_type = type;
3926 vector_type.columns = 1;
3927 template_expr = join("<", type_to_glsl(vector_type), ">");
3928 load_expr = type_to_glsl(type);
3929 load_op = "Load";
3930 }
3931 else
3932 {
3933 // Note, this loading style in HLSL is *actually* row-major, but we always treat matrices as transposed in this backend,
3934 // so row-major is technically column-major ...
3935 load_expr = type_to_glsl(target_type);
3936 }
3937 load_expr += "(";
3938
3939 for (uint32_t c = 0; c < type.columns; c++)
3940 {
3941 load_expr += join(base, ".", load_op, template_expr, "(", chain.dynamic_index,
3942 chain.static_index + c * chain.matrix_stride, ")");
3943 if (c + 1 < type.columns)
3944 load_expr += ", ";
3945 }
3946 load_expr += ")";
3947 }
3948 else
3949 {
3950 // Pick out elements one by one ... Hopefully compilers are smart enough to recognize this pattern
3951 // considering HLSL is "row-major decl", but "column-major" memory layout (basically implicit transpose model, ugh) ...
3952
3953 if (templated_load)
3954 {
3955 load_expr = type_to_glsl(type);
3956 auto scalar_type = type;
3957 scalar_type.vecsize = 1;
3958 scalar_type.columns = 1;
3959 template_expr = join("<", type_to_glsl(scalar_type), ">");
3960 }
3961 else
3962 load_expr = type_to_glsl(target_type);
3963
3964 load_expr += "(";
3965
3966 for (uint32_t c = 0; c < type.columns; c++)
3967 {
3968 for (uint32_t r = 0; r < type.vecsize; r++)
3969 {
3970 load_expr += join(base, ".Load", template_expr, "(", chain.dynamic_index,
3971 chain.static_index + c * (type.width / 8) + r * chain.matrix_stride, ")");
3972
3973 if ((r + 1 < type.vecsize) || (c + 1 < type.columns))
3974 load_expr += ", ";
3975 }
3976 }
3977 load_expr += ")";
3978 }
3979
3980 if (!templated_load)
3981 {
3982 auto bitcast_op = bitcast_glsl_op(type, target_type);
3983 if (!bitcast_op.empty())
3984 load_expr = join(bitcast_op, "(", load_expr, ")");
3985 }
3986
3987 if (lhs.empty())
3988 {
3989 assert(expr);
3990  *expr = std::move(load_expr);
3991 }
3992 else
3993 statement(lhs, " = ", load_expr, ";");
3994}
3995
3996void CompilerHLSL::emit_load(const Instruction &instruction)
3997{
3998 auto ops = stream(instruction);
3999
4000 auto *chain = maybe_get<SPIRAccessChain>(ops[2]);
4001 if (chain)
4002 {
4003 uint32_t result_type = ops[0];
4004 uint32_t id = ops[1];
4005 uint32_t ptr = ops[2];
4006
4007 auto &type = get<SPIRType>(result_type);
4008 bool composite_load = !type.array.empty() || type.basetype == SPIRType::Struct;
4009
4010 if (composite_load)
4011 {
4012 // We cannot make this work in one single expression as we might have nested structures and arrays,
4013 // so unroll the load to an uninitialized temporary.
4014 emit_uninitialized_temporary_expression(result_type, id);
4015 read_access_chain(nullptr, to_expression(id), *chain);
4016 track_expression_read(chain->self);
4017 }
4018 else
4019 {
4020 string load_expr;
4021 read_access_chain(&load_expr, "", *chain);
4022
4023 bool forward = should_forward(ptr) && forced_temporaries.find(id) == end(forced_temporaries);
4024
4025 // If we are forwarding this load,
4026 // don't register the read to access chain here, defer that to when we actually use the expression,
4027 // using the add_implied_read_expression mechanism.
4028 if (!forward)
4029 track_expression_read(chain->self);
4030
4031 // Do not forward complex load sequences like matrices, structs and arrays.
4032 if (type.columns > 1)
4033 forward = false;
4034
4035 auto &e = emit_op(result_type, id, load_expr, forward, true);
4036 e.need_transpose = false;
4037 register_read(id, ptr, forward);
4038 inherit_expression_dependencies(id, ptr);
4039 if (forward)
4040 add_implied_read_expression(e, chain->self);
4041 }
4042 }
4043 else
4044 CompilerGLSL::emit_instruction(instruction);
4045}
4046
4047void CompilerHLSL::write_access_chain_array(const SPIRAccessChain &chain, uint32_t value,
4048 const SmallVector<uint32_t> &composite_chain)
4049{
4050 auto &type = get<SPIRType>(chain.basetype);
4051
4052 // Use a reserved identifier for the loop counter, so it cannot shadow an identifier in the access chain input or other loops.
4053 auto ident = get_unique_identifier();
4054
4055 uint32_t id = ir.increase_bound_by(2);
4056 uint32_t int_type_id = id + 1;
4057 SPIRType int_type;
4058 int_type.basetype = SPIRType::Int;
4059 int_type.width = 32;
4060 set<SPIRType>(int_type_id, int_type);
4061 set<SPIRExpression>(id, ident, int_type_id, true);
4062 set_name(id, ident);
4063 suppressed_usage_tracking.insert(id);
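 // The loop counter is registered as a SPIRExpression so it can be referenced by ID from the
 // composite chain below; the MSB set on that entry marks it as an ID rather than a literal index.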
4064
4065 statement("[unroll]");
4066 statement("for (int ", ident, " = 0; ", ident, " < ", to_array_size(type, uint32_t(type.array.size() - 1)), "; ",
4067 ident, "++)");
4068 begin_scope();
4069 auto subchain = chain;
4070 subchain.dynamic_index = join(ident, " * ", chain.array_stride, " + ", chain.dynamic_index);
4071 subchain.basetype = type.parent_type;
4072
4073 // Forcefully allow us to use an ID here by setting MSB.
4074 auto subcomposite_chain = composite_chain;
4075 subcomposite_chain.push_back(0x80000000u | id);
4076
4077 if (!get<SPIRType>(subchain.basetype).array.empty())
4078 subchain.array_stride = get_decoration(subchain.basetype, DecorationArrayStride);
4079
4080 write_access_chain(subchain, value, subcomposite_chain);
4081 end_scope();
4082}
4083
4084void CompilerHLSL::write_access_chain_struct(const SPIRAccessChain &chain, uint32_t value,
4085 const SmallVector<uint32_t> &composite_chain)
4086{
4087 auto &type = get<SPIRType>(chain.basetype);
4088 uint32_t member_count = uint32_t(type.member_types.size());
4089 auto subchain = chain;
4090
4091 auto subcomposite_chain = composite_chain;
4092 subcomposite_chain.push_back(0);
4093
4094 for (uint32_t i = 0; i < member_count; i++)
4095 {
4096 uint32_t offset = type_struct_member_offset(type, i);
4097 subchain.static_index = chain.static_index + offset;
4098 subchain.basetype = type.member_types[i];
4099
4100 subchain.matrix_stride = 0;
4101 subchain.array_stride = 0;
4102 subchain.row_major_matrix = false;
4103
4104 auto &member_type = get<SPIRType>(subchain.basetype);
4105 if (member_type.columns > 1)
4106 {
4107 subchain.matrix_stride = type_struct_member_matrix_stride(type, i);
4108 subchain.row_major_matrix = has_member_decoration(type.self, i, DecorationRowMajor);
4109 }
4110
4111 if (!member_type.array.empty())
4112 subchain.array_stride = type_struct_member_array_stride(type, i);
4113
4114 subcomposite_chain.back() = i;
4115 write_access_chain(subchain, value, subcomposite_chain);
4116 }
4117}
4118
4119string CompilerHLSL::write_access_chain_value(uint32_t value, const SmallVector<uint32_t> &composite_chain,
4120 bool enclose)
4121{
4122 string ret;
4123 if (composite_chain.empty())
4124 ret = to_expression(value);
4125 else
4126 {
4127 AccessChainMeta meta;
4128 ret = access_chain_internal(value, composite_chain.data(), uint32_t(composite_chain.size()),
4129 ACCESS_CHAIN_INDEX_IS_LITERAL_BIT | ACCESS_CHAIN_LITERAL_MSB_FORCE_ID, &meta);
4130 }
4131
4132 if (enclose)
4133 ret = enclose_expression(ret);
4134 return ret;
4135}
4136
4137void CompilerHLSL::write_access_chain(const SPIRAccessChain &chain, uint32_t value,
4138 const SmallVector<uint32_t> &composite_chain)
4139{
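 // Mirror of read_access_chain() for stores. Illustrative output for a float4 member at some
 // byte offset (buffer/offset/value names are placeholders):
 //   _buf.Store4(offset, asuint(value));   // pre SM 6.2: bitcast to raw uints, then store
 //   _buf.Store<float4>(offset, value);    // SM 6.2+: templated store of the real type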
4140 auto &type = get<SPIRType>(chain.basetype);
4141
4142 // Make sure we trigger a read of the constituents in the access chain.
4143 track_expression_read(chain.self);
4144
4145 SPIRType target_type;
4146 target_type.basetype = SPIRType::UInt;
4147 target_type.vecsize = type.vecsize;
4148 target_type.columns = type.columns;
4149
4150 if (!type.array.empty())
4151 {
4152 write_access_chain_array(chain, value, composite_chain);
4153 register_write(chain.self);
4154 return;
4155 }
4156 else if (type.basetype == SPIRType::Struct)
4157 {
4158 write_access_chain_struct(chain, value, composite_chain);
4159 register_write(chain.self);
4160 return;
4161 }
4162 else if (type.width != 32 && !hlsl_options.enable_16bit_types)
4163 SPIRV_CROSS_THROW("Writing types other than 32-bit to RWByteAddressBuffer not yet supported, unless SM 6.2 and "
4164 "native 16-bit types are enabled.");
4165
4166 bool templated_store = hlsl_options.shader_model >= 62;
4167
4168 auto base = chain.base;
4169 if (has_decoration(chain.self, DecorationNonUniform))
4170 convert_non_uniform_expression(base, chain.self);
4171
4172 string template_expr;
4173 if (templated_store)
4174 template_expr = join("<", type_to_glsl(type), ">");
4175
4176 if (type.columns == 1 && !chain.row_major_matrix)
4177 {
4178 const char *store_op = nullptr;
4179 switch (type.vecsize)
4180 {
4181 case 1:
4182 store_op = "Store";
4183 break;
4184 case 2:
4185 store_op = "Store2";
4186 break;
4187 case 3:
4188 store_op = "Store3";
4189 break;
4190 case 4:
4191 store_op = "Store4";
4192 break;
4193 default:
4194 SPIRV_CROSS_THROW("Unknown vector size.");
4195 }
4196
4197 auto store_expr = write_access_chain_value(value, composite_chain, false);
4198
4199 if (!templated_store)
4200 {
4201 auto bitcast_op = bitcast_glsl_op(target_type, type);
4202 if (!bitcast_op.empty())
4203 store_expr = join(bitcast_op, "(", store_expr, ")");
4204 }
4205 else
4206 store_op = "Store";
4207 statement(base, ".", store_op, template_expr, "(", chain.dynamic_index, chain.static_index, ", ",
4208 store_expr, ");");
4209 }
4210 else if (type.columns == 1)
4211 {
4212 if (templated_store)
4213 {
4214 auto scalar_type = type;
4215 scalar_type.vecsize = 1;
4216 scalar_type.columns = 1;
4217 template_expr = join("<", type_to_glsl(scalar_type), ">");
4218 }
4219
4220 // Strided store.
4221 for (uint32_t r = 0; r < type.vecsize; r++)
4222 {
4223 auto store_expr = write_access_chain_value(value, composite_chain, true);
4224 if (type.vecsize > 1)
4225 {
4226 store_expr += ".";
4227 store_expr += index_to_swizzle(r);
4228 }
4229 remove_duplicate_swizzle(store_expr);
4230
4231 if (!templated_store)
4232 {
4233 auto bitcast_op = bitcast_glsl_op(target_type, type);
4234 if (!bitcast_op.empty())
4235 store_expr = join(bitcast_op, "(", store_expr, ")");
4236 }
4237
4238 statement(base, ".Store", template_expr, "(", chain.dynamic_index,
4239 chain.static_index + chain.matrix_stride * r, ", ", store_expr, ");");
4240 }
4241 }
4242 else if (!chain.row_major_matrix)
4243 {
4244 const char *store_op = nullptr;
4245 switch (type.vecsize)
4246 {
4247 case 1:
4248 store_op = "Store";
4249 break;
4250 case 2:
4251 store_op = "Store2";
4252 break;
4253 case 3:
4254 store_op = "Store3";
4255 break;
4256 case 4:
4257 store_op = "Store4";
4258 break;
4259 default:
4260 SPIRV_CROSS_THROW("Unknown vector size.");
4261 }
4262
4263 if (templated_store)
4264 {
4265 store_op = "Store";
4266 auto vector_type = type;
4267 vector_type.columns = 1;
4268 template_expr = join("<", type_to_glsl(vector_type), ">");
4269 }
4270
4271 for (uint32_t c = 0; c < type.columns; c++)
4272 {
4273 auto store_expr = join(write_access_chain_value(value, composite_chain, true), "[", c, "]");
4274
4275 if (!templated_store)
4276 {
4277 auto bitcast_op = bitcast_glsl_op(target_type, type);
4278 if (!bitcast_op.empty())
4279 store_expr = join(bitcast_op, "(", store_expr, ")");
4280 }
4281
4282 statement(base, ".", store_op, template_expr, "(", chain.dynamic_index,
4283 chain.static_index + c * chain.matrix_stride, ", ", store_expr, ");");
4284 }
4285 }
4286 else
4287 {
4288 if (templated_store)
4289 {
4290 auto scalar_type = type;
4291 scalar_type.vecsize = 1;
4292 scalar_type.columns = 1;
4293 template_expr = join("<", type_to_glsl(scalar_type), ">");
4294 }
4295
4296 for (uint32_t r = 0; r < type.vecsize; r++)
4297 {
4298 for (uint32_t c = 0; c < type.columns; c++)
4299 {
4300 auto store_expr =
4301 join(write_access_chain_value(value, composite_chain, true), "[", c, "].", index_to_swizzle(r));
4302 remove_duplicate_swizzle(store_expr);
4303 auto bitcast_op = bitcast_glsl_op(target_type, type);
4304 if (!bitcast_op.empty())
4305 store_expr = join(bitcast_op, "(", store_expr, ")");
4306 statement(base, ".Store", template_expr, "(", chain.dynamic_index,
4307 chain.static_index + c * (type.width / 8) + r * chain.matrix_stride, ", ", store_expr, ");");
4308 }
4309 }
4310 }
4311
4312 register_write(chain.self);
4313}
4314
4315void CompilerHLSL::emit_store(const Instruction &instruction)
4316{
4317 auto ops = stream(instruction);
4318 auto *chain = maybe_get<SPIRAccessChain>(ops[0]);
4319 if (chain)
4320 write_access_chain(*chain, ops[1], {});
4321 else
4322 CompilerGLSL::emit_instruction(instruction);
4323}
4324
4325void CompilerHLSL::emit_access_chain(const Instruction &instruction)
4326{
4327 auto ops = stream(instruction);
4328 uint32_t length = instruction.length;
4329
4330 bool need_byte_access_chain = false;
4331 auto &type = expression_type(ops[2]);
4332 const auto *chain = maybe_get<SPIRAccessChain>(ops[2]);
4333
4334 if (chain)
4335 {
4336 // Keep tacking on an existing access chain.
4337 need_byte_access_chain = true;
4338 }
4339 else if (type.storage == StorageClassStorageBuffer || has_decoration(type.self, DecorationBufferBlock))
4340 {
4341 // If we are starting to poke into an SSBO, we are dealing with ByteAddressBuffers, and we need
4342 // to emit SPIRAccessChain rather than a plain SPIRExpression.
4343 uint32_t chain_arguments = length - 3;
4344 if (chain_arguments > type.array.size())
4345 need_byte_access_chain = true;
4346 }
4347
4348 if (need_byte_access_chain)
4349 {
4350 // If we have a chain variable, we are already inside the SSBO, and any array type will refer to arrays within a block,
4351 // and not array of SSBO.
4352 uint32_t to_plain_buffer_length = chain ? 0u : static_cast<uint32_t>(type.array.size());
4353
4354 auto *backing_variable = maybe_get_backing_variable(ops[2]);
4355
4356 string base;
4357 if (to_plain_buffer_length != 0)
4358 base = access_chain(ops[2], &ops[3], to_plain_buffer_length, get<SPIRType>(ops[0]));
4359 else if (chain)
4360 base = chain->base;
4361 else
4362 base = to_expression(ops[2]);
4363
4364 // Start traversing type hierarchy at the proper non-pointer types.
4365 auto *basetype = &get_pointee_type(type);
4366
4367 // Traverse the type hierarchy down to the actual buffer types.
4368 for (uint32_t i = 0; i < to_plain_buffer_length; i++)
4369 {
4370 assert(basetype->parent_type);
4371 basetype = &get<SPIRType>(basetype->parent_type);
4372 }
4373
4374 uint32_t matrix_stride = 0;
4375 uint32_t array_stride = 0;
4376 bool row_major_matrix = false;
4377
4378 // Inherit matrix information.
4379 if (chain)
4380 {
4381 matrix_stride = chain->matrix_stride;
4382 row_major_matrix = chain->row_major_matrix;
4383 array_stride = chain->array_stride;
4384 }
4385
4386 auto offsets = flattened_access_chain_offset(*basetype, &ops[3 + to_plain_buffer_length],
4387 length - 3 - to_plain_buffer_length, 0, 1, &row_major_matrix,
4388 &matrix_stride, &array_stride);
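  // offsets.first holds the dynamic (textual) part of the byte offset and offsets.second the
  // constant part; together they become the SPIRAccessChain's dynamic_index and static_index.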
4389
4390 auto &e = set<SPIRAccessChain>(ops[1], ops[0], type.storage, base, offsets.first, offsets.second);
4391 e.row_major_matrix = row_major_matrix;
4392 e.matrix_stride = matrix_stride;
4393 e.array_stride = array_stride;
4394 e.immutable = should_forward(ops[2]);
4395 e.loaded_from = backing_variable ? backing_variable->self : ID(0);
4396
4397 if (chain)
4398 {
4399 e.dynamic_index += chain->dynamic_index;
4400 e.static_index += chain->static_index;
4401 }
4402
4403 for (uint32_t i = 2; i < length; i++)
4404 {
4405 inherit_expression_dependencies(ops[1], ops[i]);
4406 add_implied_read_expression(e, ops[i]);
4407 }
4408 }
4409 else
4410 {
4411 CompilerGLSL::emit_instruction(instruction);
4412 }
4413}
4414
4415void CompilerHLSL::emit_atomic(const uint32_t *ops, uint32_t length, spv::Op op)
4416{
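 // Maps SPIR-V atomics onto Interlocked* intrinsics. For images and plain resources this is a
 // free-function call such as InterlockedAdd(target, value, result); for RWByteAddressBuffer
 // targets the method form is used instead, e.g. _buf.InterlockedAdd(offset, value, result);
 // (names are placeholders). Atomic loads and stores are expressed via InterlockedAdd/Exchange below.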
4417 const char *atomic_op = nullptr;
4418
4419 string value_expr;
4420 if (op != OpAtomicIDecrement && op != OpAtomicIIncrement && op != OpAtomicLoad && op != OpAtomicStore)
4421 value_expr = to_expression(ops[op == OpAtomicCompareExchange ? 6 : 5]);
4422
4423 bool is_atomic_store = false;
4424
4425 switch (op)
4426 {
4427 case OpAtomicIIncrement:
4428 atomic_op = "InterlockedAdd";
4429 value_expr = "1";
4430 break;
4431
4432 case OpAtomicIDecrement:
4433 atomic_op = "InterlockedAdd";
4434 value_expr = "-1";
4435 break;
4436
4437 case OpAtomicLoad:
4438 atomic_op = "InterlockedAdd";
4439 value_expr = "0";
4440 break;
4441
4442 case OpAtomicISub:
4443 atomic_op = "InterlockedAdd";
4444 value_expr = join("-", enclose_expression(value_expr));
4445 break;
4446
4447 case OpAtomicSMin:
4448 case OpAtomicUMin:
4449 atomic_op = "InterlockedMin";
4450 break;
4451
4452 case OpAtomicSMax:
4453 case OpAtomicUMax:
4454 atomic_op = "InterlockedMax";
4455 break;
4456
4457 case OpAtomicAnd:
4458 atomic_op = "InterlockedAnd";
4459 break;
4460
4461 case OpAtomicOr:
4462 atomic_op = "InterlockedOr";
4463 break;
4464
4465 case OpAtomicXor:
4466 atomic_op = "InterlockedXor";
4467 break;
4468
4469 case OpAtomicIAdd:
4470 atomic_op = "InterlockedAdd";
4471 break;
4472
4473 case OpAtomicExchange:
4474 atomic_op = "InterlockedExchange";
4475 break;
4476
4477 case OpAtomicStore:
4478 atomic_op = "InterlockedExchange";
4479 is_atomic_store = true;
4480 break;
4481
4482 case OpAtomicCompareExchange:
4483 if (length < 8)
4484 SPIRV_CROSS_THROW("Not enough data for opcode.");
4485 atomic_op = "InterlockedCompareExchange";
4486 value_expr = join(to_expression(ops[7]), ", ", value_expr);
4487 break;
4488
4489 default:
4490 SPIRV_CROSS_THROW("Unknown atomic opcode.");
4491 }
4492
4493 if (is_atomic_store)
4494 {
4495 auto &data_type = expression_type(ops[0]);
4496 auto *chain = maybe_get<SPIRAccessChain>(ops[0]);
4497
4498 auto &tmp_id = extra_sub_expressions[ops[0]];
4499 if (!tmp_id)
4500 {
4501 tmp_id = ir.increase_bound_by(1);
4502 emit_uninitialized_temporary_expression(get_pointee_type(data_type).self, tmp_id);
4503 }
4504
4505 if (data_type.storage == StorageClassImage || !chain)
4506 {
4507 statement(atomic_op, "(", to_non_uniform_aware_expression(ops[0]), ", ",
4508 to_expression(ops[3]), ", ", to_expression(tmp_id), ");");
4509 }
4510 else
4511 {
4512 string base = chain->base;
4513 if (has_decoration(chain->self, DecorationNonUniform))
4514 convert_non_uniform_expression(base, chain->self);
4515   // RWByteAddressBuffer is always uint in its underlying type.
4516 statement(base, ".", atomic_op, "(", chain->dynamic_index, chain->static_index, ", ",
4517 to_expression(ops[3]), ", ", to_expression(tmp_id), ");");
4518 }
4519 }
4520 else
4521 {
4522 uint32_t result_type = ops[0];
4523 uint32_t id = ops[1];
4524 forced_temporaries.insert(ops[1]);
4525
4526 auto &type = get<SPIRType>(result_type);
4527 statement(variable_decl(type, to_name(id)), ";");
4528
4529 auto &data_type = expression_type(ops[2]);
4530 auto *chain = maybe_get<SPIRAccessChain>(ops[2]);
4531 SPIRType::BaseType expr_type;
4532 if (data_type.storage == StorageClassImage || !chain)
4533 {
4534 statement(atomic_op, "(", to_non_uniform_aware_expression(ops[2]), ", ", value_expr, ", ", to_name(id), ");");
4535 expr_type = data_type.basetype;
4536 }
4537 else
4538 {
4539   // RWByteAddressBuffer is always uint in its underlying type.
4540 string base = chain->base;
4541 if (has_decoration(chain->self, DecorationNonUniform))
4542 convert_non_uniform_expression(base, chain->self);
4543 expr_type = SPIRType::UInt;
4544 statement(base, ".", atomic_op, "(", chain->dynamic_index, chain->static_index, ", ", value_expr,
4545 ", ", to_name(id), ");");
4546 }
4547
4548 auto expr = bitcast_expression(type, expr_type, to_name(id));
4549 set<SPIRExpression>(id, expr, result_type, true);
4550 }
4551 flush_all_atomic_capable_variables();
4552}
4553
4554void CompilerHLSL::emit_subgroup_op(const Instruction &i)
4555{
4556 if (hlsl_options.shader_model < 60)
4557  SPIRV_CROSS_THROW("Wave ops require SM 6.0 or higher.");
4558
4559 const uint32_t *ops = stream(i);
4560 auto op = static_cast<Op>(i.op);
4561
4562 uint32_t result_type = ops[0];
4563 uint32_t id = ops[1];
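 // Subgroup ops map onto the SM 6.0 wave intrinsics, e.g. OpGroupNonUniformBallot ->
 // WaveActiveBallot(), OpGroupNonUniformBroadcastFirst -> WaveReadLaneFirst(), and the
 // arithmetic reductions -> WaveActiveSum()/WaveActiveProduct()/WaveActiveMin()/... below.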
4564
4565 auto scope = static_cast<Scope>(evaluate_constant_u32(ops[2]));
4566 if (scope != ScopeSubgroup)
4567 SPIRV_CROSS_THROW("Only subgroup scope is supported.");
4568
4569 const auto make_inclusive_Sum = [&](const string &expr) -> string {
4570 return join(expr, " + ", to_expression(ops[4]));
4571 };
4572
4573 const auto make_inclusive_Product = [&](const string &expr) -> string {
4574 return join(expr, " * ", to_expression(ops[4]));
4575 };
4576
4577 // If we need to do implicit bitcasts, make sure we do it with the correct type.
4578 uint32_t integer_width = get_integer_width_for_instruction(i);
4579 auto int_type = to_signed_basetype(integer_width);
4580 auto uint_type = to_unsigned_basetype(integer_width);
4581
4582#define make_inclusive_BitAnd(expr) ""
4583#define make_inclusive_BitOr(expr) ""
4584#define make_inclusive_BitXor(expr) ""
4585#define make_inclusive_Min(expr) ""
4586#define make_inclusive_Max(expr) ""
4587
4588 switch (op)
4589 {
4590 case OpGroupNonUniformElect:
4591 emit_op(result_type, id, "WaveIsFirstLane()", true);
4592 break;
4593
4594 case OpGroupNonUniformBroadcast:
4595 emit_binary_func_op(result_type, id, ops[3], ops[4], "WaveReadLaneAt");
4596 break;
4597
4598 case OpGroupNonUniformBroadcastFirst:
4599 emit_unary_func_op(result_type, id, ops[3], "WaveReadLaneFirst");
4600 break;
4601
4602 case OpGroupNonUniformBallot:
4603 emit_unary_func_op(result_type, id, ops[3], "WaveActiveBallot");
4604 break;
4605
4606 case OpGroupNonUniformInverseBallot:
4607 SPIRV_CROSS_THROW("Cannot trivially implement InverseBallot in HLSL.");
4608
4609 case OpGroupNonUniformBallotBitExtract:
4610 SPIRV_CROSS_THROW("Cannot trivially implement BallotBitExtract in HLSL.");
4611
4612 case OpGroupNonUniformBallotFindLSB:
4613 SPIRV_CROSS_THROW("Cannot trivially implement BallotFindLSB in HLSL.");
4614
4615 case OpGroupNonUniformBallotFindMSB:
4616 SPIRV_CROSS_THROW("Cannot trivially implement BallotFindMSB in HLSL.");
4617
4618 case OpGroupNonUniformBallotBitCount:
4619 {
4620 auto operation = static_cast<GroupOperation>(ops[3]);
4621 if (operation == GroupOperationReduce)
4622 {
4623 bool forward = should_forward(ops[4]);
4624 auto left = join("countbits(", to_enclosed_expression(ops[4]), ".x) + countbits(",
4625 to_enclosed_expression(ops[4]), ".y)");
4626 auto right = join("countbits(", to_enclosed_expression(ops[4]), ".z) + countbits(",
4627 to_enclosed_expression(ops[4]), ".w)");
4628 emit_op(result_type, id, join(left, " + ", right), forward);
4629 inherit_expression_dependencies(id, ops[4]);
4630 }
4631 else if (operation == GroupOperationInclusiveScan)
4632 SPIRV_CROSS_THROW("Cannot trivially implement BallotBitCount Inclusive Scan in HLSL.");
4633 else if (operation == GroupOperationExclusiveScan)
4634 SPIRV_CROSS_THROW("Cannot trivially implement BallotBitCount Exclusive Scan in HLSL.");
4635 else
4636 SPIRV_CROSS_THROW("Invalid BitCount operation.");
4637 break;
4638 }
4639
4640 case OpGroupNonUniformShuffle:
4641 emit_binary_func_op(result_type, id, ops[3], ops[4], "WaveReadLaneAt");
4642 break;
4643 case OpGroupNonUniformShuffleXor:
4644 {
4645 bool forward = should_forward(ops[3]);
4646 emit_op(ops[0], ops[1],
4647 join("WaveReadLaneAt(", to_unpacked_expression(ops[3]), ", ",
4648 "WaveGetLaneIndex() ^ ", to_enclosed_expression(ops[4]), ")"), forward);
4649 inherit_expression_dependencies(ops[1], ops[3]);
4650 break;
4651 }
4652 case OpGroupNonUniformShuffleUp:
4653 {
4654 bool forward = should_forward(ops[3]);
4655 emit_op(ops[0], ops[1],
4656 join("WaveReadLaneAt(", to_unpacked_expression(ops[3]), ", ",
4657 "WaveGetLaneIndex() - ", to_enclosed_expression(ops[4]), ")"), forward);
4658 inherit_expression_dependencies(ops[1], ops[3]);
4659 break;
4660 }
4661 case OpGroupNonUniformShuffleDown:
4662 {
4663 bool forward = should_forward(ops[3]);
4664 emit_op(ops[0], ops[1],
4665 join("WaveReadLaneAt(", to_unpacked_expression(ops[3]), ", ",
4666 "WaveGetLaneIndex() + ", to_enclosed_expression(ops[4]), ")"), forward);
4667 inherit_expression_dependencies(ops[1], ops[3]);
4668 break;
4669 }
4670
4671 case OpGroupNonUniformAll:
4672 emit_unary_func_op(result_type, id, ops[3], "WaveActiveAllTrue");
4673 break;
4674
4675 case OpGroupNonUniformAny:
4676 emit_unary_func_op(result_type, id, ops[3], "WaveActiveAnyTrue");
4677 break;
4678
4679 case OpGroupNonUniformAllEqual:
4680 emit_unary_func_op(result_type, id, ops[3], "WaveActiveAllEqual");
4681 break;
4682
4683 // clang-format off
4684#define HLSL_GROUP_OP(op, hlsl_op, supports_scan) \
4685case OpGroupNonUniform##op: \
4686 { \
4687 auto operation = static_cast<GroupOperation>(ops[3]); \
4688 if (operation == GroupOperationReduce) \
4689 emit_unary_func_op(result_type, id, ops[4], "WaveActive" #hlsl_op); \
4690 else if (operation == GroupOperationInclusiveScan && supports_scan) \
4691 { \
4692 bool forward = should_forward(ops[4]); \
4693 emit_op(result_type, id, make_inclusive_##hlsl_op (join("WavePrefix" #hlsl_op, "(", to_expression(ops[4]), ")")), forward); \
4694 inherit_expression_dependencies(id, ops[4]); \
4695 } \
4696 else if (operation == GroupOperationExclusiveScan && supports_scan) \
4697 emit_unary_func_op(result_type, id, ops[4], "WavePrefix" #hlsl_op); \
4698 else if (operation == GroupOperationClusteredReduce) \
4699 SPIRV_CROSS_THROW("Cannot trivially implement ClusteredReduce in HLSL."); \
4700 else \
4701 SPIRV_CROSS_THROW("Invalid group operation."); \
4702 break; \
4703 }
4704
4705#define HLSL_GROUP_OP_CAST(op, hlsl_op, type) \
4706case OpGroupNonUniform##op: \
4707 { \
4708 auto operation = static_cast<GroupOperation>(ops[3]); \
4709 if (operation == GroupOperationReduce) \
4710 emit_unary_func_op_cast(result_type, id, ops[4], "WaveActive" #hlsl_op, type, type); \
4711 else \
4712 SPIRV_CROSS_THROW("Invalid group operation."); \
4713 break; \
4714 }
4715
4716 HLSL_GROUP_OP(FAdd, Sum, true)
4717 HLSL_GROUP_OP(FMul, Product, true)
4718 HLSL_GROUP_OP(FMin, Min, false)
4719 HLSL_GROUP_OP(FMax, Max, false)
4720 HLSL_GROUP_OP(IAdd, Sum, true)
4721 HLSL_GROUP_OP(IMul, Product, true)
4722 HLSL_GROUP_OP_CAST(SMin, Min, int_type)
4723 HLSL_GROUP_OP_CAST(SMax, Max, int_type)
4724 HLSL_GROUP_OP_CAST(UMin, Min, uint_type)
4725 HLSL_GROUP_OP_CAST(UMax, Max, uint_type)
4726 HLSL_GROUP_OP(BitwiseAnd, BitAnd, false)
4727 HLSL_GROUP_OP(BitwiseOr, BitOr, false)
4728 HLSL_GROUP_OP(BitwiseXor, BitXor, false)
4729 HLSL_GROUP_OP_CAST(LogicalAnd, BitAnd, uint_type)
4730 HLSL_GROUP_OP_CAST(LogicalOr, BitOr, uint_type)
4731 HLSL_GROUP_OP_CAST(LogicalXor, BitXor, uint_type)
4732
4733#undef HLSL_GROUP_OP
4734#undef HLSL_GROUP_OP_CAST
4735 // clang-format on
4736
4737 case OpGroupNonUniformQuadSwap:
4738 {
4739 uint32_t direction = evaluate_constant_u32(ops[4]);
4740 if (direction == 0)
4741 emit_unary_func_op(result_type, id, ops[3], "QuadReadAcrossX");
4742 else if (direction == 1)
4743 emit_unary_func_op(result_type, id, ops[3], "QuadReadAcrossY");
4744 else if (direction == 2)
4745 emit_unary_func_op(result_type, id, ops[3], "QuadReadAcrossDiagonal");
4746 else
4747 SPIRV_CROSS_THROW("Invalid quad swap direction.");
4748 break;
4749 }
4750
4751 case OpGroupNonUniformQuadBroadcast:
4752 {
4753 emit_binary_func_op(result_type, id, ops[3], ops[4], "QuadReadLaneAt");
4754 break;
4755 }
4756
4757 default:
4758 SPIRV_CROSS_THROW("Invalid opcode for subgroup.");
4759 }
4760
4761 register_control_dependent_expression(id);
4762}
4763
4764void CompilerHLSL::emit_instruction(const Instruction &instruction)
4765{
4766 auto ops = stream(instruction);
4767 auto opcode = static_cast<Op>(instruction.op);
4768
4769#define HLSL_BOP(op) emit_binary_op(ops[0], ops[1], ops[2], ops[3], #op)
4770#define HLSL_BOP_CAST(op, type) \
4771 emit_binary_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode))
4772#define HLSL_UOP(op) emit_unary_op(ops[0], ops[1], ops[2], #op)
4773#define HLSL_QFOP(op) emit_quaternary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], ops[5], #op)
4774#define HLSL_TFOP(op) emit_trinary_func_op(ops[0], ops[1], ops[2], ops[3], ops[4], #op)
4775#define HLSL_BFOP(op) emit_binary_func_op(ops[0], ops[1], ops[2], ops[3], #op)
4776#define HLSL_BFOP_CAST(op, type) \
4777 emit_binary_func_op_cast(ops[0], ops[1], ops[2], ops[3], #op, type, opcode_is_sign_invariant(opcode))
4779#define HLSL_UFOP(op) emit_unary_func_op(ops[0], ops[1], ops[2], #op)
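// These helper macros expand to the generic emit_*_op helpers with the opcode stringized as an
// HLSL operator or intrinsic, e.g. HLSL_UFOP(ddx) emits "ddx(a)" and HLSL_BOP(==) emits "a == b"
// for the current instruction's operands.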
4780
4781 // If we need to do implicit bitcasts, make sure we do it with the correct type.
4782 uint32_t integer_width = get_integer_width_for_instruction(instruction);
4783 auto int_type = to_signed_basetype(integer_width);
4784 auto uint_type = to_unsigned_basetype(integer_width);
4785
4786 switch (opcode)
4787 {
4788 case OpAccessChain:
4789 case OpInBoundsAccessChain:
4790 {
4791 emit_access_chain(instruction);
4792 break;
4793 }
4794 case OpBitcast:
4795 {
4796 auto bitcast_type = get_bitcast_type(ops[0], ops[2]);
4797 if (bitcast_type == CompilerHLSL::TypeNormal)
4798 CompilerGLSL::emit_instruction(instruction);
4799 else
4800 {
4801 if (!requires_uint2_packing)
4802 {
4803 requires_uint2_packing = true;
4804 force_recompile();
4805 }
4806
4807 if (bitcast_type == CompilerHLSL::TypePackUint2x32)
4808 emit_unary_func_op(ops[0], ops[1], ops[2], "spvPackUint2x32");
4809 else
4810 emit_unary_func_op(ops[0], ops[1], ops[2], "spvUnpackUint2x32");
4811 }
4812
4813 break;
4814 }
4815
4816 case OpSelect:
4817 {
4818 auto &value_type = expression_type(ops[3]);
4819 if (value_type.basetype == SPIRType::Struct || is_array(value_type))
4820 {
4821 // HLSL does not support ternary expressions on composites.
4822 // Cannot use branches, since we might be in a continue block
4823 // where explicit control flow is prohibited.
4824 // Emit a helper function where we can use control flow.
4825 TypeID value_type_id = expression_type_id(ops[3]);
4826 auto itr = std::find(composite_selection_workaround_types.begin(),
4827 composite_selection_workaround_types.end(),
4828 value_type_id);
4829 if (itr == composite_selection_workaround_types.end())
4830 {
4831 composite_selection_workaround_types.push_back(value_type_id);
4832 force_recompile();
4833 }
4834 emit_uninitialized_temporary_expression(ops[0], ops[1]);
4835 statement("spvSelectComposite(",
4836 to_expression(ops[1]), ", ", to_expression(ops[2]), ", ",
4837 to_expression(ops[3]), ", ", to_expression(ops[4]), ");");
4838 }
4839 else
4840 CompilerGLSL::emit_instruction(instruction);
4841 break;
4842 }
4843
4844 case OpStore:
4845 {
4846 emit_store(instruction);
4847 break;
4848 }
4849
4850 case OpLoad:
4851 {
4852 emit_load(instruction);
4853 break;
4854 }
4855
4856 case OpMatrixTimesVector:
4857 {
4858 // Matrices are kept in a transposed state all the time, flip multiplication order always.
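  // e.g. SPIR-V "M * v" is emitted as "mul(v, M)" in HLSL.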
4859 emit_binary_func_op(ops[0], ops[1], ops[3], ops[2], "mul");
4860 break;
4861 }
4862
4863 case OpVectorTimesMatrix:
4864 {
4865 // Matrices are kept in a transposed state all the time, flip multiplication order always.
4866 emit_binary_func_op(ops[0], ops[1], ops[3], ops[2], "mul");
4867 break;
4868 }
4869
4870 case OpMatrixTimesMatrix:
4871 {
4872 // Matrices are kept in a transposed state all the time, flip multiplication order always.
4873 emit_binary_func_op(ops[0], ops[1], ops[3], ops[2], "mul");
4874 break;
4875 }
4876
4877 case OpOuterProduct:
4878 {
4879 uint32_t result_type = ops[0];
4880 uint32_t id = ops[1];
4881 uint32_t a = ops[2];
4882 uint32_t b = ops[3];
4883
4884 auto &type = get<SPIRType>(result_type);
4885 string expr = type_to_glsl_constructor(type);
4886 expr += "(";
4887 for (uint32_t col = 0; col < type.columns; col++)
4888 {
4889 expr += to_enclosed_expression(a);
4890 expr += " * ";
4891 expr += to_extract_component_expression(b, col);
4892 if (col + 1 < type.columns)
4893 expr += ", ";
4894 }
4895 expr += ")";
4896 emit_op(result_type, id, expr, should_forward(a) && should_forward(b));
4897 inherit_expression_dependencies(id, a);
4898 inherit_expression_dependencies(id, b);
4899 break;
4900 }
4901
4902 case OpFMod:
4903 {
4904 if (!requires_op_fmod)
4905 {
4906 requires_op_fmod = true;
4907 force_recompile();
4908 }
4909 CompilerGLSL::emit_instruction(instruction);
4910 break;
4911 }
4912
4913 case OpFRem:
4914 emit_binary_func_op(ops[0], ops[1], ops[2], ops[3], "fmod");
4915 break;
4916
4917 case OpImage:
4918 {
4919 uint32_t result_type = ops[0];
4920 uint32_t id = ops[1];
4921 auto *combined = maybe_get<SPIRCombinedImageSampler>(ops[2]);
4922
4923 if (combined)
4924 {
4925 auto &e = emit_op(result_type, id, to_expression(combined->image), true, true);
4926 auto *var = maybe_get_backing_variable(combined->image);
4927 if (var)
4928 e.loaded_from = var->self;
4929 }
4930 else
4931 {
4932 auto &e = emit_op(result_type, id, to_expression(ops[2]), true, true);
4933 auto *var = maybe_get_backing_variable(ops[2]);
4934 if (var)
4935 e.loaded_from = var->self;
4936 }
4937 break;
4938 }
4939
4940 case OpDPdx:
4941 HLSL_UFOP(ddx);
4942 register_control_dependent_expression(ops[1]);
4943 break;
4944
4945 case OpDPdy:
4946 HLSL_UFOP(ddy);
4947 register_control_dependent_expression(ops[1]);
4948 break;
4949
4950 case OpDPdxFine:
4951 HLSL_UFOP(ddx_fine);
4952 register_control_dependent_expression(ops[1]);
4953 break;
4954
4955 case OpDPdyFine:
4956 HLSL_UFOP(ddy_fine);
4957 register_control_dependent_expression(ops[1]);
4958 break;
4959
4960 case OpDPdxCoarse:
4961 HLSL_UFOP(ddx_coarse);
4962 register_control_dependent_expression(ops[1]);
4963 break;
4964
4965 case OpDPdyCoarse:
4966 HLSL_UFOP(ddy_coarse);
4967 register_control_dependent_expression(ops[1]);
4968 break;
4969
4970 case OpFwidth:
4971 case OpFwidthCoarse:
4972 case OpFwidthFine:
4973 HLSL_UFOP(fwidth);
4974 register_control_dependent_expression(ops[1]);
4975 break;
4976
4977 case OpLogicalNot:
4978 {
4979 auto result_type = ops[0];
4980 auto id = ops[1];
4981 auto &type = get<SPIRType>(result_type);
4982
4983 if (type.vecsize > 1)
4984 emit_unrolled_unary_op(result_type, id, ops[2], "!");
4985 else
4986 HLSL_UOP(!);
4987 break;
4988 }
4989
4990 case OpIEqual:
4991 {
4992 auto result_type = ops[0];
4993 auto id = ops[1];
4994
4995 if (expression_type(ops[2]).vecsize > 1)
4996 emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "==", false, SPIRType::Unknown);
4997 else
4998 HLSL_BOP_CAST(==, int_type);
4999 break;
5000 }
5001
5002 case OpLogicalEqual:
5003 case OpFOrdEqual:
5004 case OpFUnordEqual:
5005 {
5006 // HLSL != operator is unordered.
5007 // https://docs.microsoft.com/en-us/windows/win32/direct3d10/d3d10-graphics-programming-guide-resources-float-rules.
5008 // isnan() is apparently implemented as x != x as well.
5009 // We cannot implement UnordEqual as !(OrdNotEqual), as HLSL cannot express OrdNotEqual.
5010 // HACK: FUnordEqual will be implemented as FOrdEqual.
5011
5012 auto result_type = ops[0];
5013 auto id = ops[1];
5014
5015 if (expression_type(ops[2]).vecsize > 1)
5016 emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "==", false, SPIRType::Unknown);
5017 else
5018 HLSL_BOP(==);
5019 break;
5020 }
5021
5022 case OpINotEqual:
5023 {
5024 auto result_type = ops[0];
5025 auto id = ops[1];
5026
5027 if (expression_type(ops[2]).vecsize > 1)
5028 emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "!=", false, SPIRType::Unknown);
5029 else
5030 HLSL_BOP_CAST(!=, int_type);
5031 break;
5032 }
5033
5034 case OpLogicalNotEqual:
5035 case OpFOrdNotEqual:
5036 case OpFUnordNotEqual:
5037 {
5038 // HLSL != operator is unordered.
5039 // https://docs.microsoft.com/en-us/windows/win32/direct3d10/d3d10-graphics-programming-guide-resources-float-rules.
5040 // isnan() is apparently implemented as x != x as well.
5041
5042 // FIXME: FOrdNotEqual cannot be implemented in a crisp and simple way here.
5043 // We would need to do something like not(UnordEqual), but that cannot be expressed either.
5044 // Adding a lot of NaN checks would be a breaking change from perspective of performance.
5045 // SPIR-V will generally use isnan() checks when this even matters.
5046   // HACK: FOrdNotEqual will be implemented as FUnordNotEqual.
5047
5048 auto result_type = ops[0];
5049 auto id = ops[1];
5050
5051 if (expression_type(ops[2]).vecsize > 1)
5052 emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "!=", false, SPIRType::Unknown);
5053 else
5054 HLSL_BOP(!=);
5055 break;
5056 }
5057
5058 case OpUGreaterThan:
5059 case OpSGreaterThan:
5060 {
5061 auto result_type = ops[0];
5062 auto id = ops[1];
5063 auto type = opcode == OpUGreaterThan ? uint_type : int_type;
5064
5065 if (expression_type(ops[2]).vecsize > 1)
5066 emit_unrolled_binary_op(result_type, id, ops[2], ops[3], ">", false, type);
5067 else
5068 HLSL_BOP_CAST(>, type);
5069 break;
5070 }
5071
5072 case OpFOrdGreaterThan:
5073 {
5074 auto result_type = ops[0];
5075 auto id = ops[1];
5076
5077 if (expression_type(ops[2]).vecsize > 1)
5078 emit_unrolled_binary_op(result_type, id, ops[2], ops[3], ">", false, SPIRType::Unknown);
5079 else
5080 HLSL_BOP(>);
5081 break;
5082 }
5083
5084 case OpFUnordGreaterThan:
5085 {
5086 auto result_type = ops[0];
5087 auto id = ops[1];
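  // For vectors this is unrolled as the negated ordered compare: FUnordGreaterThan(x, y)
  // becomes !(x <= y) per component, which is also true when either operand is NaN.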
5088
5089 if (expression_type(ops[2]).vecsize > 1)
5090 emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "<=", true, SPIRType::Unknown);
5091 else
5092 CompilerGLSL::emit_instruction(instruction);
5093 break;
5094 }
5095
5096 case OpUGreaterThanEqual:
5097 case OpSGreaterThanEqual:
5098 {
5099 auto result_type = ops[0];
5100 auto id = ops[1];
5101
5102 auto type = opcode == OpUGreaterThanEqual ? uint_type : int_type;
5103 if (expression_type(ops[2]).vecsize > 1)
5104 emit_unrolled_binary_op(result_type, id, ops[2], ops[3], ">=", false, type);
5105 else
5106 HLSL_BOP_CAST(>=, type);
5107 break;
5108 }
5109
5110 case OpFOrdGreaterThanEqual:
5111 {
5112 auto result_type = ops[0];
5113 auto id = ops[1];
5114
5115 if (expression_type(ops[2]).vecsize > 1)
5116 emit_unrolled_binary_op(result_type, id, ops[2], ops[3], ">=", false, SPIRType::Unknown);
5117 else
5118 HLSL_BOP(>=);
5119 break;
5120 }
5121
5122 case OpFUnordGreaterThanEqual:
5123 {
5124 auto result_type = ops[0];
5125 auto id = ops[1];
5126
5127 if (expression_type(ops[2]).vecsize > 1)
5128 emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "<", true, SPIRType::Unknown);
5129 else
5130 CompilerGLSL::emit_instruction(instruction);
5131 break;
5132 }
5133
5134 case OpULessThan:
5135 case OpSLessThan:
5136 {
5137 auto result_type = ops[0];
5138 auto id = ops[1];
5139
5140 auto type = opcode == OpULessThan ? uint_type : int_type;
5141 if (expression_type(ops[2]).vecsize > 1)
5142 emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "<", false, type);
5143 else
5144 HLSL_BOP_CAST(<, type);
5145 break;
5146 }
5147
5148 case OpFOrdLessThan:
5149 {
5150 auto result_type = ops[0];
5151 auto id = ops[1];
5152
5153 if (expression_type(ops[2]).vecsize > 1)
5154 emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "<", false, SPIRType::Unknown);
5155 else
5156 HLSL_BOP(<);
5157 break;
5158 }
5159
5160 case OpFUnordLessThan:
5161 {
5162 auto result_type = ops[0];
5163 auto id = ops[1];
5164
5165 if (expression_type(ops[2]).vecsize > 1)
5166 emit_unrolled_binary_op(result_type, id, ops[2], ops[3], ">=", true, SPIRType::Unknown);
5167 else
5168 CompilerGLSL::emit_instruction(instruction);
5169 break;
5170 }
5171
5172 case OpULessThanEqual:
5173 case OpSLessThanEqual:
5174 {
5175 auto result_type = ops[0];
5176 auto id = ops[1];
5177
5178 auto type = opcode == OpULessThanEqual ? uint_type : int_type;
5179 if (expression_type(ops[2]).vecsize > 1)
5180 emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "<=", false, type);
5181 else
5182 HLSL_BOP_CAST(<=, type);
5183 break;
5184 }
5185
5186 case OpFOrdLessThanEqual:
5187 {
5188 auto result_type = ops[0];
5189 auto id = ops[1];
5190
5191 if (expression_type(ops[2]).vecsize > 1)
5192 emit_unrolled_binary_op(result_type, id, ops[2], ops[3], "<=", false, SPIRType::Unknown);
5193 else
5194 HLSL_BOP(<=);
5195 break;
5196 }
5197
5198 case OpFUnordLessThanEqual:
5199 {
5200 auto result_type = ops[0];
5201 auto id = ops[1];
5202
5203 if (expression_type(ops[2]).vecsize > 1)
5204 emit_unrolled_binary_op(result_type, id, ops[2], ops[3], ">", true, SPIRType::Unknown);
5205 else
5206 CompilerGLSL::emit_instruction(instruction);
5207 break;
5208 }
5209
5210 case OpImageQueryLod:
5211 emit_texture_op(instruction, false);
5212 break;
5213
5214 case OpImageQuerySizeLod:
5215 {
5216 auto result_type = ops[0];
5217 auto id = ops[1];
5218
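  // spvTextureSize() is a generated helper (see require_texture_query_variant()); the extra
  // "dummy" out-parameter receives the level/sample count, which this query does not need.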
5219 require_texture_query_variant(ops[2]);
5220 auto dummy_samples_levels = join(get_fallback_name(id), "_dummy_parameter");
5221 statement("uint ", dummy_samples_levels, ";");
5222
5223 auto expr = join("spvTextureSize(", to_non_uniform_aware_expression(ops[2]), ", ",
5224 bitcast_expression(SPIRType::UInt, ops[3]), ", ", dummy_samples_levels, ")");
5225
5226 auto &restype = get<SPIRType>(ops[0]);
5227 expr = bitcast_expression(restype, SPIRType::UInt, expr);
5228 emit_op(result_type, id, expr, true);
5229 break;
5230 }
5231
5232 case OpImageQuerySize:
5233 {
5234 auto result_type = ops[0];
5235 auto id = ops[1];
5236
5237 require_texture_query_variant(ops[2]);
5238 bool uav = expression_type(ops[2]).image.sampled == 2;
5239
5240 if (const auto *var = maybe_get_backing_variable(ops[2]))
5241 if (hlsl_options.nonwritable_uav_texture_as_srv && has_decoration(var->self, DecorationNonWritable))
5242 uav = false;
5243
5244 auto dummy_samples_levels = join(get_fallback_name(id), "_dummy_parameter");
5245 statement("uint ", dummy_samples_levels, ";");
5246
5247 string expr;
5248 if (uav)
5249 expr = join("spvImageSize(", to_non_uniform_aware_expression(ops[2]), ", ", dummy_samples_levels, ")");
5250 else
5251 expr = join("spvTextureSize(", to_non_uniform_aware_expression(ops[2]), ", 0u, ", dummy_samples_levels, ")");
5252
5253 auto &restype = get<SPIRType>(ops[0]);
5254 expr = bitcast_expression(restype, SPIRType::UInt, expr);
5255 emit_op(result_type, id, expr, true);
5256 break;
5257 }
5258
5259 case OpImageQuerySamples:
5260 case OpImageQueryLevels:
5261 {
5262 auto result_type = ops[0];
5263 auto id = ops[1];
5264
5265 require_texture_query_variant(ops[2]);
5266 bool uav = expression_type(ops[2]).image.sampled == 2;
5267 if (opcode == OpImageQueryLevels && uav)
5268 SPIRV_CROSS_THROW("Cannot query levels for UAV images.");
5269
5270 if (const auto *var = maybe_get_backing_variable(ops[2]))
5271 if (hlsl_options.nonwritable_uav_texture_as_srv && has_decoration(var->self, DecorationNonWritable))
5272 uav = false;
5273
5274 // Keep it simple and do not emit special variants to make this look nicer ...
5275 // This stuff is barely, if ever, used.
5276 forced_temporaries.insert(id);
5277 auto &type = get<SPIRType>(result_type);
5278 statement(variable_decl(type, to_name(id)), ";");
5279
5280 if (uav)
5281 statement("spvImageSize(", to_non_uniform_aware_expression(ops[2]), ", ", to_name(id), ");");
5282 else
5283 statement("spvTextureSize(", to_non_uniform_aware_expression(ops[2]), ", 0u, ", to_name(id), ");");
5284
5285 auto &restype = get<SPIRType>(ops[0]);
5286 auto expr = bitcast_expression(restype, SPIRType::UInt, to_name(id));
5287 set<SPIRExpression>(id, expr, result_type, true);
5288 break;
5289 }
5290
5291 case OpImageRead:
5292 {
5293 uint32_t result_type = ops[0];
5294 uint32_t id = ops[1];
5295 auto *var = maybe_get_backing_variable(ops[2]);
5296 auto &type = expression_type(ops[2]);
5297 bool subpass_data = type.image.dim == DimSubpassData;
5298 bool pure = false;
5299
5300 string imgexpr;
5301
5302 if (subpass_data)
5303 {
5304 if (hlsl_options.shader_model < 40)
5305 SPIRV_CROSS_THROW("Subpass loads are not supported in HLSL shader model 2/3.");
5306
5307			// Similar to the GLSL backend, implement subpass loads as a fetch at the current fragment coordinate (Load() in HLSL).
5308 if (type.image.ms)
5309 {
5310 uint32_t operands = ops[4];
5311 if (operands != ImageOperandsSampleMask || instruction.length != 6)
5312 SPIRV_CROSS_THROW("Multisampled image used in OpImageRead, but unexpected operand mask was used.");
5313 uint32_t sample = ops[5];
5314 imgexpr = join(to_non_uniform_aware_expression(ops[2]), ".Load(int2(gl_FragCoord.xy), ", to_expression(sample), ")");
5315 }
5316 else
5317 imgexpr = join(to_non_uniform_aware_expression(ops[2]), ".Load(int3(int2(gl_FragCoord.xy), 0))");
5318
5319 pure = true;
5320 }
5321 else
5322 {
5323 imgexpr = join(to_non_uniform_aware_expression(ops[2]), "[", to_expression(ops[3]), "]");
5324			// Unlike GLSL, where all images behave as "vec4" and the declared format only changes how the data is interpreted,
5325			// the underlying texture type in HLSL depends on the image format, so the loaded value may need swizzle remapping.
5326
5327 bool force_srv =
5328 hlsl_options.nonwritable_uav_texture_as_srv && var && has_decoration(var->self, DecorationNonWritable);
5329 pure = force_srv;
5330
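			// For example, loading from an RWTexture2D<float2> into a 4-component result yields "img[coord].xyyy";
			// the remap below pads the loaded value back out to the width expected by the SPIR-V result type.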
5331 if (var && !subpass_data && !force_srv)
5332 imgexpr = remap_swizzle(get<SPIRType>(result_type),
5333 image_format_to_components(get<SPIRType>(var->basetype).image.format), imgexpr);
5334 }
5335
5336 if (var && var->forwardable)
5337 {
5338 bool forward = forced_temporaries.find(id) == end(forced_temporaries);
5339 auto &e = emit_op(result_type, id, imgexpr, forward);
5340
5341 if (!pure)
5342 {
5343 e.loaded_from = var->self;
5344 if (forward)
5345 var->dependees.push_back(id);
5346 }
5347 }
5348 else
5349 emit_op(result_type, id, imgexpr, false);
5350
5351 inherit_expression_dependencies(id, ops[2]);
5352 if (type.image.ms)
5353 inherit_expression_dependencies(id, ops[5]);
5354 break;
5355 }
5356
5357 case OpImageWrite:
5358 {
5359 auto *var = maybe_get_backing_variable(ops[0]);
5360
5361	// Unlike GLSL, where all images behave as "vec4" and the declared format only changes how the data is interpreted,
5362	// the underlying texture type in HLSL depends on the image format, so the stored value may need to be narrowed.
5363 auto value_expr = to_expression(ops[2]);
5364 if (var)
5365 {
5366 auto &type = get<SPIRType>(var->basetype);
5367 auto narrowed_type = get<SPIRType>(type.image.type);
5368 narrowed_type.vecsize = image_format_to_components(type.image.format);
5369 value_expr = remap_swizzle(narrowed_type, expression_type(ops[2]).vecsize, value_expr);
5370 }
5371
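		// For example, a 4-component value stored to an RWTexture2D<float2> is narrowed above and emitted as
		// "img[coord] = value.xy;".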
5372 statement(to_non_uniform_aware_expression(ops[0]), "[", to_expression(ops[1]), "] = ", value_expr, ";");
5373 if (var && variable_storage_is_aliased(*var))
5374 flush_all_aliased_variables();
5375 break;
5376 }
5377
5378 case OpImageTexelPointer:
5379 {
5380 uint32_t result_type = ops[0];
5381 uint32_t id = ops[1];
5382
5383 auto expr = to_expression(ops[2]);
5384 expr += join("[", to_expression(ops[3]), "]");
5385 auto &e = set<SPIRExpression>(id, expr, result_type, true);
5386
5387 // When using the pointer, we need to know which variable it is actually loaded from.
5388 auto *var = maybe_get_backing_variable(ops[2]);
5389 e.loaded_from = var ? var->self : ID(0);
5390 inherit_expression_dependencies(id, ops[3]);
5391 break;
5392 }
5393
5394 case OpAtomicCompareExchange:
5395 case OpAtomicExchange:
5396 case OpAtomicISub:
5397 case OpAtomicSMin:
5398 case OpAtomicUMin:
5399 case OpAtomicSMax:
5400 case OpAtomicUMax:
5401 case OpAtomicAnd:
5402 case OpAtomicOr:
5403 case OpAtomicXor:
5404 case OpAtomicIAdd:
5405 case OpAtomicIIncrement:
5406 case OpAtomicIDecrement:
5407 case OpAtomicLoad:
5408 case OpAtomicStore:
5409 {
5410 emit_atomic(ops, instruction.length, opcode);
5411 break;
5412 }
5413
5414 case OpControlBarrier:
5415 case OpMemoryBarrier:
5416 {
5417 uint32_t memory;
5418 uint32_t semantics;
5419
5420 if (opcode == OpMemoryBarrier)
5421 {
5422 memory = evaluate_constant_u32(ops[0]);
5423 semantics = evaluate_constant_u32(ops[1]);
5424 }
5425 else
5426 {
5427 memory = evaluate_constant_u32(ops[1]);
5428 semantics = evaluate_constant_u32(ops[2]);
5429 }
5430
5431 if (memory == ScopeSubgroup)
5432 {
5433 // No Wave-barriers in HLSL.
5434 break;
5435 }
5436
5437		// We only care about these flags; acquire/release and friends are not relevant to HLSL.
5438 semantics = mask_relevant_memory_semantics(semantics);
5439
5440 if (opcode == OpMemoryBarrier)
5441 {
5442			// If this is a memory barrier and the next instruction is a control barrier, check whether that control
5443			// barrier's memory semantics already cover what we need, so we can avoid emitting a redundant barrier.
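			// For example, OpMemoryBarrier(Workgroup, WorkgroupMemory) immediately followed by
			// OpControlBarrier(Workgroup, Workgroup, WorkgroupMemory) drops the memory barrier here, and the control
			// barrier alone becomes GroupMemoryBarrierWithGroupSync().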
5444 const Instruction *next = get_next_instruction_in_block(instruction);
5445 if (next && next->op == OpControlBarrier)
5446 {
5447 auto *next_ops = stream(*next);
5448 uint32_t next_memory = evaluate_constant_u32(next_ops[1]);
5449 uint32_t next_semantics = evaluate_constant_u32(next_ops[2]);
5450 next_semantics = mask_relevant_memory_semantics(next_semantics);
5451
5452				// There is no pure execution barrier in HLSL.
5453				// If the next instruction carries no memory semantics, assume group shared memory is synced.
5454 if (next_semantics == 0)
5455 next_semantics = MemorySemanticsWorkgroupMemoryMask;
5456
5457 bool memory_scope_covered = false;
5458 if (next_memory == memory)
5459 memory_scope_covered = true;
5460 else if (next_semantics == MemorySemanticsWorkgroupMemoryMask)
5461 {
5462					// If we only care about workgroup memory, either Device or Workgroup scope is fine;
5463					// the scopes do not have to match.
5464 if ((next_memory == ScopeDevice || next_memory == ScopeWorkgroup) &&
5465 (memory == ScopeDevice || memory == ScopeWorkgroup))
5466 {
5467 memory_scope_covered = true;
5468 }
5469 }
5470 else if (memory == ScopeWorkgroup && next_memory == ScopeDevice)
5471 {
5472 // The control barrier has device scope, but the memory barrier just has workgroup scope.
5473 memory_scope_covered = true;
5474 }
5475
5476 // If we have the same memory scope, and all memory types are covered, we're good.
5477 if (memory_scope_covered && (semantics & next_semantics) == semantics)
5478 break;
5479 }
5480 }
5481
5482 // We are synchronizing some memory or syncing execution,
5483 // so we cannot forward any loads beyond the memory barrier.
5484 if (semantics || opcode == OpControlBarrier)
5485 {
5486 assert(current_emitting_block);
5487 flush_control_dependent_expressions(current_emitting_block->self);
5488 flush_all_active_variables();
5489 }
5490
5491 if (opcode == OpControlBarrier)
5492 {
5493			// We cannot emit a pure execution barrier; when there are no memory semantics, pick the cheapest option.
5494 if (semantics == MemorySemanticsWorkgroupMemoryMask || semantics == 0)
5495 statement("GroupMemoryBarrierWithGroupSync();");
5496 else if (semantics != 0 && (semantics & MemorySemanticsWorkgroupMemoryMask) == 0)
5497 statement("DeviceMemoryBarrierWithGroupSync();");
5498 else
5499 statement("AllMemoryBarrierWithGroupSync();");
5500 }
5501 else
5502 {
5503 if (semantics == MemorySemanticsWorkgroupMemoryMask)
5504 statement("GroupMemoryBarrier();");
5505 else if (semantics != 0 && (semantics & MemorySemanticsWorkgroupMemoryMask) == 0)
5506 statement("DeviceMemoryBarrier();");
5507 else
5508 statement("AllMemoryBarrier();");
5509 }
5510 break;
5511 }
5512
5513 case OpBitFieldInsert:
5514 {
5515 if (!requires_bitfield_insert)
5516 {
5517 requires_bitfield_insert = true;
5518 force_recompile();
5519 }
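		// Forcing a recompile lets the next pass emit the spvBitfieldInsert() helper ahead of the function bodies,
		// so it is declared before this first use.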
5520
5521 auto expr = join("spvBitfieldInsert(", to_expression(ops[2]), ", ", to_expression(ops[3]), ", ",
5522 to_expression(ops[4]), ", ", to_expression(ops[5]), ")");
5523
5524 bool forward =
5525 should_forward(ops[2]) && should_forward(ops[3]) && should_forward(ops[4]) && should_forward(ops[5]);
5526
5527 auto &restype = get<SPIRType>(ops[0]);
5528 expr = bitcast_expression(restype, SPIRType::UInt, expr);
5529 emit_op(ops[0], ops[1], expr, forward);
5530 break;
5531 }
5532
5533 case OpBitFieldSExtract:
5534 case OpBitFieldUExtract:
5535 {
5536 if (!requires_bitfield_extract)
5537 {
5538 requires_bitfield_extract = true;
5539 force_recompile();
5540 }
5541
5542 if (opcode == OpBitFieldSExtract)
5543 HLSL_TFOP(spvBitfieldSExtract);
5544 else
5545 HLSL_TFOP(spvBitfieldUExtract);
5546 break;
5547 }
5548
5549 case OpBitCount:
5550 {
5551 auto basetype = expression_type(ops[2]).basetype;
5552 emit_unary_func_op_cast(ops[0], ops[1], ops[2], "countbits", basetype, basetype);
5553 break;
5554 }
5555
5556 case OpBitReverse:
5557 HLSL_UFOP(reversebits);
5558 break;
5559
5560 case OpArrayLength:
5561 {
5562 auto *var = maybe_get_backing_variable(ops[2]);
5563 if (!var)
5564 SPIRV_CROSS_THROW("Array length must point directly to an SSBO block.");
5565
5566 auto &type = get<SPIRType>(var->basetype);
5567 if (!has_decoration(type.self, DecorationBlock) && !has_decoration(type.self, DecorationBufferBlock))
5568 SPIRV_CROSS_THROW("Array length expression must point to a block type.");
5569
5570 // This must be 32-bit uint, so we're good to go.
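		// The emitted HLSL looks roughly like this (illustrative names and constants):
		//   uint _42;
		//   uBuffer.GetDimensions(_42);
		//   _42 = (_42 - 16) / 8;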
5571 emit_uninitialized_temporary_expression(ops[0], ops[1]);
5572 statement(to_non_uniform_aware_expression(ops[2]), ".GetDimensions(", to_expression(ops[1]), ");");
5573 uint32_t offset = type_struct_member_offset(type, ops[3]);
5574 uint32_t stride = type_struct_member_array_stride(type, ops[3]);
5575 statement(to_expression(ops[1]), " = (", to_expression(ops[1]), " - ", offset, ") / ", stride, ";");
5576 break;
5577 }
5578
5579 case OpIsHelperInvocationEXT:
5580 SPIRV_CROSS_THROW("helperInvocationEXT() is not supported in HLSL.");
5581
5582 case OpBeginInvocationInterlockEXT:
5583 case OpEndInvocationInterlockEXT:
5584 if (hlsl_options.shader_model < 51)
5585 SPIRV_CROSS_THROW("Rasterizer order views require Shader Model 5.1.");
5586 break; // Nothing to do in the body
5587
5588 default:
5589 CompilerGLSL::emit_instruction(instruction);
5590 break;
5591 }
5592}
5593
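// Record which spvTextureSize()/spvImageSize() overload a query needs (dimensionality, sampled type, and for UAVs
// the normalization state and component count), so the matching helper can be emitted on a later pass.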
5594void CompilerHLSL::require_texture_query_variant(uint32_t var_id)
5595{
5596 if (const auto *var = maybe_get_backing_variable(var_id))
5597 var_id = var->self;
5598
5599 auto &type = expression_type(var_id);
5600 bool uav = type.image.sampled == 2;
5601 if (hlsl_options.nonwritable_uav_texture_as_srv && has_decoration(var_id, DecorationNonWritable))
5602 uav = false;
5603
5604 uint32_t bit = 0;
5605 switch (type.image.dim)
5606 {
5607 case Dim1D:
5608 bit = type.image.arrayed ? Query1DArray : Query1D;
5609 break;
5610
5611 case Dim2D:
5612 if (type.image.ms)
5613 bit = type.image.arrayed ? Query2DMSArray : Query2DMS;
5614 else
5615 bit = type.image.arrayed ? Query2DArray : Query2D;
5616 break;
5617
5618 case Dim3D:
5619 bit = Query3D;
5620 break;
5621
5622 case DimCube:
5623 bit = type.image.arrayed ? QueryCubeArray : QueryCube;
5624 break;
5625
5626 case DimBuffer:
5627 bit = QueryBuffer;
5628 break;
5629
5630 default:
5631 SPIRV_CROSS_THROW("Unsupported query type.");
5632 }
5633
5634 switch (get<SPIRType>(type.image.type).basetype)
5635 {
5636 case SPIRType::Float:
5637 bit += QueryTypeFloat;
5638 break;
5639
5640 case SPIRType::Int:
5641 bit += QueryTypeInt;
5642 break;
5643
5644 case SPIRType::UInt:
5645 bit += QueryTypeUInt;
5646 break;
5647
5648 default:
5649 SPIRV_CROSS_THROW("Unsupported query type.");
5650 }
5651
5652 auto norm_state = image_format_to_normalized_state(type.image.format);
5653 auto &variant = uav ? required_texture_size_variants
5654 .uav[uint32_t(norm_state)][image_format_to_components(type.image.format) - 1] :
5655 required_texture_size_variants.srv;
5656
5657 uint64_t mask = 1ull << bit;
5658 if ((variant & mask) == 0)
5659 {
5660 force_recompile();
5661 variant |= mask;
5662 }
5663}
5664
5665void CompilerHLSL::set_root_constant_layouts(std::vector<RootConstants> layout)
5666{
5667	root_constants_layout = std::move(layout);
5668}
5669
5670void CompilerHLSL::add_vertex_attribute_remap(const HLSLVertexAttributeRemap &vertex_attributes)
5671{
5672 remap_vertex_attributes.push_back(vertex_attributes);
5673}
5674
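// HLSL has no counterpart to gl_NumWorkGroups, so a small cbuffer holding the dispatch size is synthesized here.
// The caller is expected to bind the returned variable and fill in the dispatch dimensions at dispatch time.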
5675VariableID CompilerHLSL::remap_num_workgroups_builtin()
5676{
5677 update_active_builtins();
5678
5679 if (!active_input_builtins.get(BuiltInNumWorkgroups))
5680 return 0;
5681
5682 // Create a new, fake UBO.
5683 uint32_t offset = ir.increase_bound_by(4);
5684
5685 uint32_t uint_type_id = offset;
5686 uint32_t block_type_id = offset + 1;
5687 uint32_t block_pointer_type_id = offset + 2;
5688 uint32_t variable_id = offset + 3;
5689
5690 SPIRType uint_type;
5691 uint_type.basetype = SPIRType::UInt;
5692 uint_type.width = 32;
5693 uint_type.vecsize = 3;
5694 uint_type.columns = 1;
5695 set<SPIRType>(uint_type_id, uint_type);
5696
5697 SPIRType block_type;
5698 block_type.basetype = SPIRType::Struct;
5699 block_type.member_types.push_back(uint_type_id);
5700 set<SPIRType>(block_type_id, block_type);
5701 set_decoration(block_type_id, DecorationBlock);
5702 set_member_name(block_type_id, 0, "count");
5703 set_member_decoration(block_type_id, 0, DecorationOffset, 0);
5704
5705 SPIRType block_pointer_type = block_type;
5706 block_pointer_type.pointer = true;
5707 block_pointer_type.storage = StorageClassUniform;
5708 block_pointer_type.parent_type = block_type_id;
5709 auto &ptr_type = set<SPIRType>(block_pointer_type_id, block_pointer_type);
5710
5711 // Preserve self.
5712 ptr_type.self = block_type_id;
5713
5714 set<SPIRVariable>(variable_id, block_pointer_type_id, StorageClassUniform);
5715 ir.meta[variable_id].decoration.alias = "SPIRV_Cross_NumWorkgroups";
5716
5717 num_workgroups_builtin = variable_id;
5718 get_entry_point().interface_variables.push_back(num_workgroups_builtin);
5719 return variable_id;
5720}
5721
5722void CompilerHLSL::set_resource_binding_flags(HLSLBindingFlags flags)
5723{
5724 resource_binding_flags = flags;
5725}
5726
5727void CompilerHLSL::validate_shader_model()
5728{
5729 // Check for nonuniform qualifier.
5730 // Instead of looping over all decorations to find this, just look at capabilities.
5731 for (auto &cap : ir.declared_capabilities)
5732 {
5733 switch (cap)
5734 {
5735 case CapabilityShaderNonUniformEXT:
5736 case CapabilityRuntimeDescriptorArrayEXT:
5737 if (hlsl_options.shader_model < 51)
5738 SPIRV_CROSS_THROW(
5739 "Shader model 5.1 or higher is required to use bindless resources or NonUniformResourceIndex.");
5740 break;
5741
5742 case CapabilityVariablePointers:
5743 case CapabilityVariablePointersStorageBuffer:
5744 SPIRV_CROSS_THROW("VariablePointers capability is not supported in HLSL.");
5745
5746 default:
5747 break;
5748 }
5749 }
5750
5751 if (ir.addressing_model != AddressingModelLogical)
5752 SPIRV_CROSS_THROW("Only Logical addressing model can be used with HLSL.");
5753
5754 if (hlsl_options.enable_16bit_types && hlsl_options.shader_model < 62)
5755 SPIRV_CROSS_THROW("Need at least shader model 6.2 when enabling native 16-bit type support.");
5756}
5757
5758string CompilerHLSL::compile()
5759{
5760 ir.fixup_reserved_names();
5761
5762 // Do not deal with ES-isms like precision, older extensions and such.
5763 options.es = false;
5764 options.version = 450;
5765 options.vulkan_semantics = true;
5766 backend.float_literal_suffix = true;
5767 backend.double_literal_suffix = false;
5768 backend.long_long_literal_suffix = true;
5769 backend.uint32_t_literal_suffix = true;
5770 backend.int16_t_literal_suffix = "";
5771 backend.uint16_t_literal_suffix = "u";
5772 backend.basic_int_type = "int";
5773 backend.basic_uint_type = "uint";
5774 backend.demote_literal = "discard";
5775 backend.boolean_mix_function = "";
5776 backend.swizzle_is_function = false;
5777 backend.shared_is_implied = true;
5778 backend.unsized_array_supported = true;
5779 backend.explicit_struct_type = false;
5780 backend.use_initializer_list = true;
5781 backend.use_constructor_splatting = false;
5782 backend.can_swizzle_scalar = true;
5783 backend.can_declare_struct_inline = false;
5784 backend.can_declare_arrays_inline = false;
5785 backend.can_return_array = false;
5786 backend.nonuniform_qualifier = "NonUniformResourceIndex";
5787 backend.support_case_fallthrough = false;
5788
5789 // SM 4.1 does not support precise for some reason.
5790 backend.support_precise_qualifier = hlsl_options.shader_model >= 50 || hlsl_options.shader_model == 40;
5791
5792 fixup_type_alias();
5793 reorder_type_alias();
5794 build_function_control_flow_graphs_and_analyze();
5795 validate_shader_model();
5796 update_active_builtins();
5797 analyze_image_and_sampler_usage();
5798 analyze_interlocked_resource_usage();
5799
5800 // Subpass input needs SV_Position.
5801 if (need_subpass_input)
5802 active_input_builtins.set(BuiltInFragCoord);
5803
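	// Emission may take several passes: whenever a helper (e.g. spvTextureSize or spvBitfieldInsert) turns out to be
	// required mid-emission, force_recompile() requests another pass so the helper is declared up front.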
5804 uint32_t pass_count = 0;
5805 do
5806 {
5807 reset(pass_count);
5808
5809 // Move constructor for this type is broken on GCC 4.9 ...
5810 buffer.reset();
5811
5812 emit_header();
5813 emit_resources();
5814
5815 emit_function(get<SPIRFunction>(ir.default_entry_point), Bitset());
5816 emit_hlsl_entry_point();
5817
5818 pass_count++;
5819 } while (is_forcing_recompilation());
5820
5821 // Entry point in HLSL is always main() for the time being.
5822 get_entry_point().name = "main";
5823
5824 return buffer.str();
5825}
5826
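// Map SPIR-V selection/loop control hints to the corresponding HLSL attributes, emitted on the line preceding the
// branch or loop they apply to.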
5827void CompilerHLSL::emit_block_hints(const SPIRBlock &block)
5828{
5829 switch (block.hint)
5830 {
5831 case SPIRBlock::HintFlatten:
5832 statement("[flatten]");
5833 break;
5834 case SPIRBlock::HintDontFlatten:
5835 statement("[branch]");
5836 break;
5837 case SPIRBlock::HintUnroll:
5838 statement("[unroll]");
5839 break;
5840 case SPIRBlock::HintDontUnroll:
5841 statement("[loop]");
5842 break;
5843 default:
5844 break;
5845 }
5846}
5847
5848string CompilerHLSL::get_unique_identifier()
5849{
5850 return join("_", unique_identifier_count++, "ident");
5851}
5852
5853void CompilerHLSL::add_hlsl_resource_binding(const HLSLResourceBinding &binding)
5854{
5855 StageSetBinding tuple = { binding.stage, binding.desc_set, binding.binding };
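	// The bool tracks whether the remap is actually consumed while emitting resources;
	// is_hlsl_resource_binding_used() reports this back to the caller.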
5856 resource_bindings[tuple] = { binding, false };
5857}
5858
5859bool CompilerHLSL::is_hlsl_resource_binding_used(ExecutionModel model, uint32_t desc_set, uint32_t binding) const
5860{
5861 StageSetBinding tuple = { model, desc_set, binding };
5862 auto itr = resource_bindings.find(tuple);
5863 return itr != end(resource_bindings) && itr->second.second;
5864}
5865
5866CompilerHLSL::BitcastType CompilerHLSL::get_bitcast_type(uint32_t result_type, uint32_t op0)
5867{
5868 auto &rslt_type = get<SPIRType>(result_type);
5869 auto &expr_type = expression_type(op0);
5870
5871 if (rslt_type.basetype == SPIRType::BaseType::UInt64 && expr_type.basetype == SPIRType::BaseType::UInt &&
5872 expr_type.vecsize == 2)
5873 return BitcastType::TypePackUint2x32;
5874 else if (rslt_type.basetype == SPIRType::BaseType::UInt && rslt_type.vecsize == 2 &&
5875 expr_type.basetype == SPIRType::BaseType::UInt64)
5876 return BitcastType::TypeUnpackUint64;
5877
5878 return BitcastType::TypeNormal;
5879}
5880
5881bool CompilerHLSL::is_hlsl_force_storage_buffer_as_uav(ID id) const
5882{
5883 if (hlsl_options.force_storage_buffer_as_uav)
5884 {
5885 return true;
5886 }
5887
5888 const uint32_t desc_set = get_decoration(id, spv::DecorationDescriptorSet);
5889 const uint32_t binding = get_decoration(id, spv::DecorationBinding);
5890
5891 return (force_uav_buffer_bindings.find({ desc_set, binding }) != force_uav_buffer_bindings.end());
5892}
5893
5894void CompilerHLSL::set_hlsl_force_storage_buffer_as_uav(uint32_t desc_set, uint32_t binding)
5895{
5896 SetBindingPair pair = { desc_set, binding };
5897 force_uav_buffer_bindings.insert(pair);
5898}
5899
5900bool CompilerHLSL::builtin_translates_to_nonarray(spv::BuiltIn builtin) const
5901{
5902 return (builtin == BuiltInSampleMask);
5903}
5904