// NOTE(review): the layout (one raw string literal per kernel line, no
// indentation, the truncated "* http: " license line) suggests this file is
// generated from an OpenCL source file; if so, edit that source and
// regenerate rather than editing this file by hand.
namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {

// OpenCL C source of the reference softmax / logsoftmax kernels
// (ref_softmax_fwd_generic and ref_softmax_bwd_generic) as a single
// NUL-terminated string. Each kernel line is a raw string literal
// R"==(...)==" followed by "\n"; the compiler concatenates the adjacent
// literals, so the C++ comments interleaved below are not part of the
// kernel text.
//
// The kernel uses symbols that are not defined in this string (IS_FWD,
// IS_BWD, SOFTMAX_AXIS_IDX, SOFTMAX_AXIS, GROUP_SIZE, SUB_GROUP_SIZE,
// BLOCK_0..2, LOGSOFTMAX, WITH_SRC_SCALES, WITH_DST_SCALES, DATA_B<i>,
// DATA_SB<i>, DATA_S<i>, DST_D<i>, SRC_D<i>, the *_TO_REF / REF_TO_*
// converters, ...). They presumably come from gpu/ocl/ocl_types.h or from
// build options set by the host code -- TODO confirm against the caller.
const char *ref_softmax_kernel = R"==(/******************************************************************************* )==""\n"
R"==(* Copyright 2019-2022 Intel Corporation )==""\n"
R"==(* )==""\n"
R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==""\n"
R"==(* you may not use this file except in compliance with the License. )==""\n"
R"==(* You may obtain a copy of the License at )==""\n"
R"==(* )==""\n"
R"==(* http: )==""\n"
R"==(* )==""\n"
R"==(* Unless required by applicable law or agreed to in writing, software )==""\n"
R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==""\n"
R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==""\n"
R"==(* See the License for the specific language governing permissions and )==""\n"
R"==(* limitations under the License. )==""\n"
R"==(*******************************************************************************/ )==""\n"
R"==(#include "gpu/ocl/ocl_types.h" )==""\n"
// CONCAT2 pastes its arguments after macro-expanding them. IS_HALF(dt)
// yields 1 only when dt expands to the token `half`; for any other type it
// yields an undefined identifier, which an #if evaluates as 0.
R"==(#define CONCAt2(a, b) a##b )==""\n"
R"==(#define CONCAT2(a, b) CONCAt2(a, b) )==""\n"
R"==(#define IS_HALF_half 1 )==""\n"
R"==(#define IS_HALF(dt) CONCAT2(IS_HALF_, dt) )==""\n"
// DD(i): logical size of dimension i (DST_D<i> forward, SRC_D<i> backward).
R"==(#if IS_FWD )==""\n"
R"==(#define DD(i) CONCAt2(DST_D, i) )==""\n"
R"==(#elif IS_BWD )==""\n"
R"==(#define DD(i) CONCAt2(SRC_D, i) )==""\n"
R"==(#else )==""\n"
R"==(#error unsupported data parameter )==""\n"
R"==(#endif )==""\n"
// OFF(dim, idx): offset contribution of index `dim` along dimension `idx`,
// split into an inner part (dim % DATA_B<idx>) * DATA_SB<idx> and an outer
// part (dim / DATA_B<idx>) * DATA_S<idx>. Presumably DATA_B<idx> is the
// inner block size and DATA_SB / DATA_S the inner / outer strides -- confirm.
R"==(#define OFF(dim, idx) \ )==""\n"
R"==((dim % CONCAT2(DATA_B, idx)) * CONCAT2(DATA_SB, idx) \ )==""\n"
R"==(+ (dim / CONCAT2(DATA_B, idx)) * CONCAT2(DATA_S, idx) )==""\n"
// One variant per softmax axis (SOFTMAX_AXIS_IDX = 0..5): DATA_OFF places
// softmax_dim at position SOFTMAX_AXIS_IDX and dim0..dim4 at the remaining
// positions in order; NEEDS_PADDING is true when any index is >= the DD()
// size of the dimension it is placed in.
R"==(#if SOFTMAX_AXIS_IDX == 0 )==""\n"
R"==(#define DATA_OFF(dim0, dim1, dim2, dim3, dim4, softmax_dim) \ )==""\n"
R"==(OFF(softmax_dim, 0) + OFF(dim0, 1) + OFF(dim1, 2) + OFF(dim2, 3) \ )==""\n"
R"==(+ OFF(dim3, 4) + OFF(dim4, 5) )==""\n"
R"==(#define NEEDS_PADDING(dim0, dim1, dim2, dim3, dim4, softmax_dim) \ )==""\n"
R"==(softmax_dim >= DD(0) || dim0 >= DD(1) || dim1 >= DD(2) || dim2 >= DD(3) \ )==""\n"
R"==(|| dim3 >= DD(4) || dim4 >= DD(5) )==""\n"
R"==(#elif SOFTMAX_AXIS_IDX == 1 )==""\n"
R"==(#define DATA_OFF(dim0, dim1, dim2, dim3, dim4, softmax_dim) \ )==""\n"
R"==(OFF(dim0, 0) + OFF(softmax_dim, 1) + OFF(dim1, 2) + OFF(dim2, 3) \ )==""\n"
R"==(+ OFF(dim3, 4) + OFF(dim4, 5) )==""\n"
R"==(#define NEEDS_PADDING(dim0, dim1, dim2, dim3, dim4, softmax_dim) \ )==""\n"
R"==(dim0 >= DD(0) || softmax_dim >= DD(1) || dim1 >= DD(2) || dim2 >= DD(3) \ )==""\n"
R"==(|| dim3 >= DD(4) || dim4 >= DD(5) )==""\n"
R"==(#elif SOFTMAX_AXIS_IDX == 2 )==""\n"
R"==(#define DATA_OFF(dim0, dim1, dim2, dim3, dim4, softmax_dim) \ )==""\n"
R"==(OFF(dim0, 0) + OFF(dim1, 1) + OFF(softmax_dim, 2) + OFF(dim2, 3) \ )==""\n"
R"==(+ OFF(dim3, 4) + OFF(dim4, 5) )==""\n"
R"==(#define NEEDS_PADDING(dim0, dim1, dim2, dim3, dim4, softmax_dim) \ )==""\n"
R"==(dim0 >= DD(0) || dim1 >= DD(1) || softmax_dim >= DD(2) || dim2 >= DD(3) \ )==""\n"
R"==(|| dim3 >= DD(4) || dim4 >= DD(5) )==""\n"
R"==(#elif SOFTMAX_AXIS_IDX == 3 )==""\n"
R"==(#define DATA_OFF(dim0, dim1, dim2, dim3, dim4, softmax_dim) \ )==""\n"
R"==(OFF(dim0, 0) + OFF(dim1, 1) + OFF(dim2, 2) + OFF(softmax_dim, 3) \ )==""\n"
R"==(+ OFF(dim3, 4) + OFF(dim4, 5) )==""\n"
R"==(#define NEEDS_PADDING(dim0, dim1, dim2, dim3, dim4, softmax_dim) \ )==""\n"
R"==(dim0 >= DD(0) || dim1 >= DD(1) || dim2 >= DD(2) || softmax_dim >= DD(3) \ )==""\n"
R"==(|| dim3 >= DD(4) || dim4 >= DD(5) )==""\n"
R"==(#elif SOFTMAX_AXIS_IDX == 4 )==""\n"
R"==(#define DATA_OFF(dim0, dim1, dim2, dim3, dim4, softmax_dim) \ )==""\n"
R"==(OFF(dim0, 0) + OFF(dim1, 1) + OFF(dim2, 2) + OFF(dim3, 3) \ )==""\n"
R"==(+ OFF(softmax_dim, 4) + OFF(dim4, 5) )==""\n"
R"==(#define NEEDS_PADDING(dim0, dim1, dim2, dim3, dim4, softmax_dim) \ )==""\n"
R"==(dim0 >= DD(0) || dim1 >= DD(1) || dim2 >= DD(2) || dim3 >= DD(3) \ )==""\n"
R"==(|| softmax_dim >= DD(4) || dim4 >= DD(5) )==""\n"
R"==(#elif SOFTMAX_AXIS_IDX == 5 )==""\n"
R"==(#define DATA_OFF(dim0, dim1, dim2, dim3, dim4, softmax_dim) \ )==""\n"
R"==(OFF(dim0, 0) + OFF(dim1, 1) + OFF(dim2, 2) + OFF(dim3, 3) + OFF(dim4, 4) \ )==""\n"
R"==(+ OFF(softmax_dim, 5) )==""\n"
R"==(#define NEEDS_PADDING(dim0, dim1, dim2, dim3, dim4, softmax_dim) \ )==""\n"
R"==(dim0 >= DD(0) || dim1 >= DD(1) || dim2 >= DD(2) || dim3 >= DD(3) \ )==""\n"
R"==(|| dim4 >= DD(4) || softmax_dim >= DD(5) )==""\n"
R"==(#else )==""\n"
R"==(#error unsupported softmax dimension )==""\n"
R"==(#endif )==""\n"
// ---- Forward kernel (compiled when IS_FWD) ----
R"==(#if IS_FWD )==""\n"
R"==(#if SUB_GROUP_SIZE == 16 )==""\n"
R"==(__attribute__((reqd_work_group_size(GROUP_SIZE, 1, 1))) )==""\n"
R"==(__attribute__((intel_reqd_sub_group_size(SUB_GROUP_SIZE))) )==""\n"
R"==(#endif )==""\n"
R"==(__kernel void )==""\n"
R"==(ref_softmax_fwd_generic(__global SRC_DATA_T *src, __global DATA_T *dst, )==""\n"
R"==(__global float *src_scale, __global float *dst_scale) { )==""\n"
// dim[0..4]: indices of the non-softmax dimensions, decoded from the global
// id. Global id 0 is first divided by GROUP_SIZE, so each aligned run of
// GROUP_SIZE work-items gets the same indices.
// NOTE(review): dim[5] is computed but never read below.
R"==(const int dim[] = { )==""\n"
R"==((get_global_id(0) / GROUP_SIZE) % BLOCK_0, )==""\n"
R"==(get_global_id(1) % BLOCK_1, )==""\n"
R"==(get_global_id(2) % BLOCK_2, )==""\n"
R"==((get_global_id(0) / GROUP_SIZE) / BLOCK_0, )==""\n"
R"==(get_global_id(1) / BLOCK_1, )==""\n"
R"==(get_global_id(2) / BLOCK_2, )==""\n"
R"==(}; )==""\n"
// Output scale: multiplied by the src scale, divided by the dst scale.
R"==(float scale = 1.0f; )==""\n"
R"==(#if WITH_SRC_SCALES )==""\n"
R"==(scale *= src_scale[0]; )==""\n"
R"==(#endif )==""\n"
R"==(#if WITH_DST_SCALES )==""\n"
R"==(scale /= dst_scale[0]; )==""\n"
R"==(#endif )==""\n"
// With SUB_GROUP_SIZE == 16 the softmax axis is split across the GROUP_SIZE
// work-items: each handles [begin, end) of length SOFTMAX_AXIS / GROUP_SIZE,
// the last one also takes the remainder, and buf_size is the longest slice.
// Otherwise a single work-item handles the whole axis.
R"==(#if SUB_GROUP_SIZE == 16 )==""\n"
R"==(const int local_id = get_local_id(0); )==""\n"
R"==(const int begin = local_id * (SOFTMAX_AXIS / GROUP_SIZE); )==""\n"
R"==(const int end = (local_id == GROUP_SIZE - 1) )==""\n"
R"==(? SOFTMAX_AXIS )==""\n"
R"==(: (local_id + 1) * (SOFTMAX_AXIS / GROUP_SIZE); )==""\n"
R"==(#if SOFTMAX_AXIS - (GROUP_SIZE - 1) * (SOFTMAX_AXIS / GROUP_SIZE) \ )==""\n"
R"==(> SOFTMAX_AXIS / GROUP_SIZE )==""\n"
R"==(const int buf_size )==""\n"
R"==(= SOFTMAX_AXIS - (GROUP_SIZE - 1) * (SOFTMAX_AXIS / GROUP_SIZE); )==""\n"
R"==(#else )==""\n"
R"==(const int buf_size = SOFTMAX_AXIS / GROUP_SIZE; )==""\n"
R"==(#endif )==""\n"
R"==(#else )==""\n"
R"==(const int begin = 0; )==""\n"
R"==(const int end = SOFTMAX_AXIS; )==""\n"
R"==(const int buf_size = SOFTMAX_AXIS; )==""\n"
R"==(#endif )==""\n"
// Accumulator type: half when both IS_HALF checks give 1, float otherwise.
// NOTE(review): this tests DST_DATA_TYPE, while the backward kernel's
// signature uses DST_DATA_T; if DST_DATA_TYPE is not defined anywhere the
// check is always 0 and acc_t is always float -- confirm intent.
R"==(#if IS_HALF(SRC_DATA_T) == 1 && IS_HALF(DST_DATA_TYPE) == 1 )==""\n"
R"==(typedef half acc_t; )==""\n"
R"==(const acc_t acc_max = HALF_MAX; )==""\n"
R"==(const acc_t acc_zero = 0.h; )==""\n"
R"==(#else )==""\n"
R"==(typedef float acc_t; )==""\n"
R"==(const acc_t acc_max = FLT_MAX; )==""\n"
R"==(const acc_t acc_zero = 0.f; )==""\n"
R"==(#endif )==""\n"
// Pass 1: load this work-item's slice into d[] (skipped when index `begin`
// is padding) and take the local maximum; with SUB_GROUP_SIZE == 16 the
// maximum is then reduced across the sub-group or work-group.
R"==(acc_t d[buf_size]; )==""\n"
R"==(acc_t max_ = -acc_max; )==""\n"
R"==(acc_t denom_ = acc_zero; )==""\n"
R"==(if (!(NEEDS_PADDING(dim[0], dim[1], dim[2], dim[3], dim[4], begin))) { )==""\n"
R"==(for (int i = begin; i < end && i < DD(SOFTMAX_AXIS_IDX); ++i) { )==""\n"
R"==(size_t data_off )==""\n"
R"==(= DATA_OFF(dim[0], dim[1], dim[2], dim[3], dim[4], i); )==""\n"
R"==(d[i - begin] = SRC_TO_REF(src[data_off]); )==""\n"
R"==(max_ = max(max_, d[i - begin]); )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==(#if SUB_GROUP_SIZE == 16 )==""\n"
R"==(#if GROUP_SIZE == SUB_GROUP_SIZE )==""\n"
R"==(max_ = sub_group_reduce_max(max_); )==""\n"
R"==(#else )==""\n"
R"==(max_ = work_group_reduce_max(max_); )==""\n"
R"==(#endif )==""\n"
R"==(#endif )==""\n"
// Pass 2: sum exp(x - max). Softmax overwrites d[] with the exponentials;
// logsoftmax keeps the inputs. The sum is reduced like the maximum.
R"==(if (!(NEEDS_PADDING(dim[0], dim[1], dim[2], dim[3], dim[4], begin))) { )==""\n"
R"==(for (int i = begin; i < end && i < DD(SOFTMAX_AXIS_IDX); ++i) { )==""\n"
R"==(#if LOGSOFTMAX )==""\n"
R"==(denom_ += exp(d[i - begin] - max_); )==""\n"
R"==(#else )==""\n"
R"==(d[i - begin] = exp(d[i - begin] - max_); )==""\n"
R"==(denom_ += d[i - begin]; )==""\n"
R"==(#endif )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==(#if SUB_GROUP_SIZE == 16 )==""\n"
R"==(#if GROUP_SIZE == SUB_GROUP_SIZE )==""\n"
R"==(denom_ = sub_group_reduce_add(denom_); )==""\n"
R"==(#else )==""\n"
R"==(denom_ = work_group_reduce_add(denom_); )==""\n"
R"==(#endif )==""\n"
R"==(#endif )==""\n"
// denom_ becomes log(sum) for logsoftmax or 1 / sum for softmax.
R"==(#if LOGSOFTMAX )==""\n"
R"==(denom_ = log(denom_); )==""\n"
R"==(#else )==""\n"
R"==(denom_ = 1.0 / denom_; )==""\n"
R"==(#endif )==""\n"
// Pass 3: write the scaled result. Padded positions are written as zero,
// and the value is rounded before conversion when DT_S8 or DT_U8 is 1.
R"==(for (int i = begin; i < end; ++i) { )==""\n"
R"==(size_t data_off = DATA_OFF(dim[0], dim[1], dim[2], dim[3], dim[4], i); )==""\n"
R"==(if (NEEDS_PADDING(dim[0], dim[1], dim[2], dim[3], dim[4], i)) { )==""\n"
R"==(dst[data_off] = REF_TO_DST(acc_zero); )==""\n"
R"==(} else { )==""\n"
R"==(float unscaled; )==""\n"
R"==(#if LOGSOFTMAX )==""\n"
R"==(unscaled = d[i - begin] - max_ - denom_; )==""\n"
R"==(#else )==""\n"
R"==(unscaled = d[i - begin] * denom_; )==""\n"
R"==(#endif )==""\n"
R"==(#if DT_S8 == 1 || DT_U8 == 1 )==""\n"
R"==(dst[data_off] = REF_TO_DST(round(scale * unscaled)); )==""\n"
R"==(#else )==""\n"
R"==(dst[data_off] = REF_TO_DST(scale * unscaled); )==""\n"
R"==(#endif )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==(#endif )==""\n"
// ---- Backward kernel (compiled when IS_BWD) ----
// Always requires a GROUP_SIZE work-group and splits the softmax axis
// across its work-items the same way as the forward sub-group path.
R"==(#if IS_BWD )==""\n"
R"==(__attribute__((reqd_work_group_size(GROUP_SIZE, 1, 1))) )==""\n"
R"==(__attribute__((intel_reqd_sub_group_size(SUB_GROUP_SIZE))) )==""\n"
R"==(__kernel void )==""\n"
R"==(ref_softmax_bwd_generic(__global DST_DATA_T *dst, __global SRC_DATA_T *diff_src, )==""\n"
R"==(__global DST_DATA_T *diff_dst) { )==""\n"
R"==(const int dim[] = { )==""\n"
R"==((get_global_id(0) / GROUP_SIZE) % BLOCK_0, )==""\n"
R"==(get_global_id(1) % BLOCK_1, )==""\n"
R"==(get_global_id(2) % BLOCK_2, )==""\n"
R"==((get_global_id(0) / GROUP_SIZE) / BLOCK_0, )==""\n"
R"==(get_global_id(1) / BLOCK_1, )==""\n"
R"==(get_global_id(2) / BLOCK_2, )==""\n"
R"==(}; )==""\n"
R"==(const int local_id = get_local_id(0); )==""\n"
R"==(const int begin = local_id * (SOFTMAX_AXIS / GROUP_SIZE); )==""\n"
R"==(const int end = (local_id == GROUP_SIZE - 1) )==""\n"
R"==(? SOFTMAX_AXIS )==""\n"
R"==(: (local_id + 1) * (SOFTMAX_AXIS / GROUP_SIZE); )==""\n"
R"==(#if SOFTMAX_AXIS - (GROUP_SIZE - 1) * (SOFTMAX_AXIS / GROUP_SIZE) \ )==""\n"
R"==(> SOFTMAX_AXIS / GROUP_SIZE )==""\n"
R"==(const int buf_size )==""\n"
R"==(= SOFTMAX_AXIS - (GROUP_SIZE - 1) * (SOFTMAX_AXIS / GROUP_SIZE); )==""\n"
R"==(#else )==""\n"
R"==(const int buf_size = SOFTMAX_AXIS / GROUP_SIZE; )==""\n"
R"==(#endif )==""\n"
// NOTE(review): this tests DST_DATA_TYPE although this kernel's signature
// uses DST_DATA_T -- confirm which macro is intended.
R"==(#if IS_HALF(SRC_DATA_T) == 1 && IS_HALF(DST_DATA_TYPE) == 1 )==""\n"
R"==(typedef half acc_t; )==""\n"
R"==(const acc_t acc_zero = 0.h; )==""\n"
R"==(#else )==""\n"
R"==(typedef float acc_t; )==""\n"
R"==(const acc_t acc_zero = 0.f; )==""\n"
R"==(#endif )==""\n"
// Load dst and diff_dst; sbr accumulates sum(diff_dst) for logsoftmax or
// sum(diff_dst * dst) for softmax, and is then reduced across the group.
R"==(acc_t diff_d[buf_size]; )==""\n"
R"==(acc_t d[buf_size]; )==""\n"
R"==(acc_t sbr = acc_zero; )==""\n"
R"==(if (!(NEEDS_PADDING(dim[0], dim[1], dim[2], dim[3], dim[4], begin))) { )==""\n"
R"==(for (int i = begin; i < end && i < DD(SOFTMAX_AXIS_IDX); ++i) { )==""\n"
R"==(size_t idx = DATA_OFF(dim[0], dim[1], dim[2], dim[3], dim[4], i); )==""\n"
R"==(diff_d[i - begin] = DST_TO_REF(diff_dst[idx]); )==""\n"
R"==(d[i - begin] = DST_TO_REF(dst[idx]); )==""\n"
R"==(#if LOGSOFTMAX )==""\n"
R"==(sbr += diff_d[i - begin]; )==""\n"
R"==(#else )==""\n"
R"==(sbr += diff_d[i - begin] * d[i - begin]; )==""\n"
R"==(#endif )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==(#if GROUP_SIZE == SUB_GROUP_SIZE )==""\n"
R"==(sbr = sub_group_reduce_add(sbr); )==""\n"
R"==(#else )==""\n"
R"==(sbr = work_group_reduce_add(sbr); )==""\n"
R"==(#endif )==""\n"
// diff_src = diff_dst - exp(dst) * sbr (logsoftmax) or
// dst * (diff_dst - sbr) (softmax); padded positions are written as zero.
R"==(for (int i = begin; i < end; ++i) { )==""\n"
R"==(size_t idx = DATA_OFF(dim[0], dim[1], dim[2], dim[3], dim[4], i); )==""\n"
R"==(if (NEEDS_PADDING(dim[0], dim[1], dim[2], dim[3], dim[4], i)) { )==""\n"
R"==(diff_src[idx] = REF_TO_SRC(acc_zero); )==""\n"
R"==(} else { )==""\n"
R"==(#if LOGSOFTMAX )==""\n"
R"==(diff_src[idx] )==""\n"
R"==(= REF_TO_SRC(diff_d[i - begin] - exp(d[i - begin]) * sbr); )==""\n"
R"==(#else )==""\n"
R"==(acc_t inner_data = diff_d[i - begin] - sbr; )==""\n"
R"==(diff_src[idx] = REF_TO_SRC(d[i - begin] * inner_data); )==""\n"
R"==(#endif )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==(#endif )==""\n"
R"==()==";
} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl