namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {
// OpenCL C source text of the gen9 softmax kernels, stored as one
// NUL-terminated string. It is assembled from adjacent raw string literals,
// one per kernel-source line, each followed by an explicit "\n". The C++
// comments placed between the literals are discarded before literal
// concatenation, so they never become part of the kernel text.
//
// The text defines two kernels, selected by preprocessor symbols:
//   - gen9_softmax_fwd (under IS_FWD): softmax / log-softmax forward.
//   - gen9_softmax_bwd (under IS_BWD): softmax / log-softmax backward.
//
// Symbols the text uses but does not define (they must be supplied when the
// kernel is built): GROUP_SIZE, SUB_GROUP_SIZE, SOFTMAX_AXIS_SIZE,
// SOFTMAX_BUF, MB, OC, OC_PADDED, BATCH, IC, IC_PADDED, IS_NHWC, IS_BLOCKED,
// IS_16C, LOGSOFTMAX, WITH_SRC_SCALES, WITH_DST_SCALES, SRC_DATA_T,
// DST_DATA_T, and the conversion / block I/O helpers (DATA_TO_FLOAT,
// FLOAT_TO_DATA, READ_BLOCK8, WRITE_BLOCK8, ...), presumably provided by
// "gpu/ocl/ocl_types.h" -- confirm against that header.
const char *gen9_softmax_kernel = R"==(/******************************************************************************* )==""\n"
R"==(* Copyright 2020-2022 Intel Corporation )==""\n"
R"==(* )==""\n"
R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==""\n"
R"==(* you may not use this file except in compliance with the License. )==""\n"
R"==(* You may obtain a copy of the License at )==""\n"
R"==(* )==""\n"
R"==(* http://www.apache.org/licenses/LICENSE-2.0 )==""\n"
R"==(* )==""\n"
R"==(* Unless required by applicable law or agreed to in writing, software )==""\n"
R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==""\n"
R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==""\n"
R"==(* See the License for the specific language governing permissions and )==""\n"
R"==(* limitations under the License. )==""\n"
R"==(*******************************************************************************/ )==""\n"
R"==(#include "gpu/ocl/ocl_types.h" )==""\n"
// LOAD_FLOAT8: READ_BLOCK8 from `ptr`, then BLOCK_TO_DATA8 and DATA_TO_FLOAT8
// to obtain the float8 value used by the kernels below.
R"==(#define LOAD_FLOAT8(prefix, ptr) \ )==""\n"
R"==(DATA_TO_FLOAT8(prefix, \ )==""\n"
R"==(BLOCK_TO_DATA8(prefix, \ )==""\n"
R"==(READ_BLOCK8(prefix, \ )==""\n"
R"==((__global BLOCK_T(ALIAS(prefix)) *)(ptr)))) )==""\n"
// STORE_FLOAT8: the inverse chain (FLOAT_TO_DATA8, DATA_TO_BLOCK8,
// WRITE_BLOCK8) for writing a float8 value to `ptr`.
R"==(#define STORE_FLOAT8(prefix, ptr, val) \ )==""\n"
R"==(WRITE_BLOCK8(prefix, (__global BLOCK_T(ALIAS(prefix)) *)(ptr), \ )==""\n"
R"==(DATA_TO_BLOCK8(prefix, FLOAT_TO_DATA8(prefix, val))) )==""\n"
// VECT_SIZE: elements handled per work-item per step.
// NUM_BUF: number of float8 buffers each work-item keeps in the
// non-NHWC/non-blocked paths.
R"==(#define VECT_SIZE 8 )==""\n"
R"==(#define NUM_BUF (SOFTMAX_AXIS_SIZE / SUB_GROUP_SIZE / VECT_SIZE) )==""\n"
// ---------------------------------------------------------------------------
// Forward kernel.
// ---------------------------------------------------------------------------
R"==(#if IS_FWD )==""\n"
R"==(__attribute__((reqd_work_group_size(GROUP_SIZE, 1, 1))) )==""\n"
R"==(__attribute__((intel_reqd_sub_group_size(SUB_GROUP_SIZE))) __kernel void )==""\n"
R"==(gen9_softmax_fwd(__global SRC_DATA_T *src, __global DST_DATA_T *dst, )==""\n"
R"==(__global float *src_scale, __global float *dst_scale) { )==""\n"
// Output scale: src_scale[0] / dst_scale[0], each factor applied only when
// the corresponding WITH_*_SCALES symbol is set; 1.0f otherwise.
R"==(float scale = 1.0f; )==""\n"
R"==(#if WITH_SRC_SCALES )==""\n"
R"==(scale *= src_scale[0]; )==""\n"
R"==(#endif )==""\n"
R"==(#if WITH_DST_SCALES )==""\n"
R"==(scale /= dst_scale[0]; )==""\n"
R"==(#endif )==""\n"
// NHWC / blocked path: every work-item reads VECT_SIZE scalars that are OC
// elements apart and keeps them in a private float array.
R"==(#if IS_NHWC || IS_BLOCKED )==""\n"
R"==(const int group = get_global_id(0) / GROUP_SIZE; )==""\n"
R"==(const int mb = group / OC_PADDED; )==""\n"
R"==(const int local_oc = group % OC_PADDED; )==""\n"
R"==(const int local_id = get_local_id(0); )==""\n"
R"==(const int axis_chunk = (local_id / SUB_GROUP_SIZE) * SOFTMAX_BUF; )==""\n"
R"==(const int subgroup_id = get_sub_group_local_id(); )==""\n"
R"==(const int oc_chunk = OC * VECT_SIZE * subgroup_id; )==""\n"
R"==(#if IS_BLOCKED )==""\n"
R"==(const int oc_block_id = local_oc / OC; )==""\n"
R"==(const int oc_in_block = local_oc % OC; )==""\n"
R"==(int data_off = (MB * oc_block_id + mb) * OC * SOFTMAX_AXIS_SIZE + oc_chunk )==""\n"
R"==(+ oc_in_block; )==""\n"
R"==(#else )==""\n"
R"==(int data_off = mb * OC_PADDED * SOFTMAX_AXIS_SIZE + oc_chunk + local_oc; )==""\n"
R"==(#endif )==""\n"
R"==(float d[VECT_SIZE]; )==""\n"
R"==(float max_ = -FLT_MAX; )==""\n"
R"==(float denom_ = 0.f; )==""\n"
R"==(src += data_off; )==""\n"
// Pass 1: load the inputs and track the per-work-item maximum.
R"==(for (int k = 0, axis_channel_id = OC * axis_chunk; k < VECT_SIZE; )==""\n"
R"==(++k, axis_channel_id += OC) { )==""\n"
R"==(d[k] = DATA_TO_FLOAT(SRC, src[axis_channel_id]); )==""\n"
R"==(max_ = max(d[k], max_); )==""\n"
R"==(} )==""\n"
// Combine the maxima: a sub-group reduction suffices when the work-group is
// exactly one sub-group, otherwise reduce over the whole work-group.
R"==(#if GROUP_SIZE == SUB_GROUP_SIZE )==""\n"
R"==(max_ = sub_group_reduce_max(max_); )==""\n"
R"==(#else )==""\n"
R"==(max_ = work_group_reduce_max(max_); )==""\n"
R"==(#endif )==""\n"
// Pass 2: accumulate sum(exp(x - max)). For LOGSOFTMAX d[k] is left
// untouched because the final formula needs the original inputs.
R"==(for (int k = 0; k < VECT_SIZE; ++k) { )==""\n"
R"==(#if LOGSOFTMAX )==""\n"
R"==(denom_ += exp(d[k] - max_); )==""\n"
R"==(#else )==""\n"
R"==(d[k] = exp(d[k] - max_); )==""\n"
R"==(denom_ += d[k]; )==""\n"
R"==(#endif )==""\n"
R"==(} )==""\n"
R"==(#if GROUP_SIZE == SUB_GROUP_SIZE )==""\n"
R"==(denom_ = sub_group_reduce_add(denom_); )==""\n"
R"==(#else )==""\n"
R"==(denom_ = work_group_reduce_add(denom_); )==""\n"
R"==(#endif )==""\n"
// The reciprocal uses a float literal (1.0f) so the expression stays in
// single precision instead of being promoted to double by a `1.0` literal.
R"==(#if LOGSOFTMAX )==""\n"
R"==(denom_ = log(denom_); )==""\n"
R"==(#else )==""\n"
R"==(denom_ = 1.0f / denom_; )==""\n"
R"==(#endif )==""\n"
R"==(dst += data_off; )==""\n"
// Pass 3: normalize, apply the output scale, and store.
R"==(for (int k = 0, axis_channel_id = OC * axis_chunk; k < VECT_SIZE; )==""\n"
R"==(++k, axis_channel_id += OC) { )==""\n"
R"==(#if LOGSOFTMAX )==""\n"
R"==(d[k] = d[k] - max_ - denom_; )==""\n"
R"==(#else )==""\n"
R"==(d[k] = d[k] * denom_; )==""\n"
R"==(#endif )==""\n"
R"==(dst[axis_channel_id] = FLOAT_TO_DATA(DST, d[k] * scale); )==""\n"
R"==(} )==""\n"
// Remaining layouts: each work-group handles SOFTMAX_AXIS_SIZE consecutive
// elements, moved as NUM_BUF float8 values per work-item through
// LOAD_FLOAT8 / STORE_FLOAT8; reductions are sub-group wide only.
R"==(#else )==""\n"
R"==(const int data_off = (get_global_id(0) / GROUP_SIZE) * SOFTMAX_AXIS_SIZE; )==""\n"
R"==(float8 d[NUM_BUF]; )==""\n"
R"==(float max_ = -FLT_MAX; )==""\n"
R"==(float denom_ = 0.f; )==""\n"
R"==(src += data_off; )==""\n"
R"==(for (int k = 0; k < NUM_BUF; ++k) { )==""\n"
R"==(d[k] = LOAD_FLOAT8(SRC, &src[k * VECT_SIZE * SUB_GROUP_SIZE]); )==""\n"
R"==(for (int i = 0; i < VECT_SIZE; ++i) { )==""\n"
R"==(max_ = max(d[k][i], max_); )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==(max_ = sub_group_reduce_max(max_); )==""\n"
R"==(for (int k = 0; k < NUM_BUF; ++k) { )==""\n"
R"==(#if LOGSOFTMAX )==""\n"
R"==(for (int i = 0; i < VECT_SIZE; ++i) )==""\n"
R"==(denom_ += exp(d[k][i] - max_); )==""\n"
R"==(#else )==""\n"
R"==(d[k] = exp(d[k] - max_); )==""\n"
R"==(for (int i = 0; i < VECT_SIZE; ++i) )==""\n"
R"==(denom_ += d[k][i]; )==""\n"
R"==(#endif )==""\n"
R"==(} )==""\n"
R"==(denom_ = sub_group_reduce_add(denom_); )==""\n"
// Same single-precision reciprocal as in the NHWC / blocked path.
R"==(#if LOGSOFTMAX )==""\n"
R"==(denom_ = log(denom_); )==""\n"
R"==(#else )==""\n"
R"==(denom_ = 1.0f / denom_; )==""\n"
R"==(#endif )==""\n"
R"==(dst += data_off; )==""\n"
R"==(for (int k = 0; k < NUM_BUF; ++k) { )==""\n"
R"==(#if LOGSOFTMAX )==""\n"
R"==(d[k] = d[k] - max_ - denom_; )==""\n"
R"==(#else )==""\n"
R"==(d[k] = d[k] * denom_; )==""\n"
R"==(#endif )==""\n"
R"==(STORE_FLOAT8(DST, &dst[k * VECT_SIZE * SUB_GROUP_SIZE], scale * d[k]); )==""\n"
R"==(} )==""\n"
R"==(#endif )==""\n"
R"==(} )==""\n"
R"==(#endif )==""\n"
// ---------------------------------------------------------------------------
// Backward kernel. `sbr` is the reduced sum of diff_dst (LOGSOFTMAX) or of
// dst * diff_dst (softmax); it is then used to form diff_src.
// ---------------------------------------------------------------------------
R"==(#if IS_BWD )==""\n"
R"==(__attribute__((reqd_work_group_size(GROUP_SIZE, 1, 1))) )==""\n"
R"==(__attribute__((intel_reqd_sub_group_size(SUB_GROUP_SIZE))) __kernel void )==""\n"
R"==(gen9_softmax_bwd(__global DST_DATA_T *dst, __global SRC_DATA_T *diff_src, )==""\n"
R"==(__global DST_DATA_T *diff_dst) { )==""\n"
// NHWC / 16c path: scalar accesses IC elements apart, VECT_SIZE per
// work-item.
R"==(#if IS_NHWC || IS_16C )==""\n"
R"==(const int groups = get_global_id(0) / GROUP_SIZE; )==""\n"
R"==(const int batch = groups / IC_PADDED; )==""\n"
R"==(const int ic = groups % IC_PADDED; )==""\n"
R"==(const int sub_grp_id = get_sub_group_local_id(); )==""\n"
R"==(const int local_id = get_local_id(0); )==""\n"
R"==(const int slice = local_id / SUB_GROUP_SIZE; )==""\n"
R"==(const int ic_buff = IC * VECT_SIZE; )==""\n"
R"==(#if IS_16C )==""\n"
R"==(const int ic_blk = ic / IC; )==""\n"
R"==(const int ic_in_blk = ic % IC; )==""\n"
R"==(int data_off = BATCH * IC * SOFTMAX_AXIS_SIZE * ic_blk )==""\n"
R"==(+ batch * IC * SOFTMAX_AXIS_SIZE + ic_buff * sub_grp_id + ic_in_blk; )==""\n"
R"==(#else )==""\n"
R"==(int data_off = batch * IC * SOFTMAX_AXIS_SIZE + ic_buff * sub_grp_id + ic; )==""\n"
R"==(#endif )==""\n"
R"==(float sbr = 0.f; )==""\n"
R"==(float diff_d[VECT_SIZE]; )==""\n"
R"==(float dst_[VECT_SIZE]; )==""\n"
R"==(diff_dst += data_off; )==""\n"
R"==(dst += data_off; )==""\n"
R"==(for (int i = 0, idx = IC * slice * SOFTMAX_BUF; i < VECT_SIZE; )==""\n"
R"==(++i, idx += IC) { )==""\n"
R"==(diff_d[i] = DATA_TO_FLOAT(DST, diff_dst[idx]); )==""\n"
R"==(dst_[i] = DATA_TO_FLOAT(DST, dst[idx]); )==""\n"
R"==(#if LOGSOFTMAX )==""\n"
R"==(sbr += diff_d[i]; )==""\n"
R"==(#else )==""\n"
R"==(sbr += dst_[i] * diff_d[i]; )==""\n"
R"==(#endif )==""\n"
R"==(} )==""\n"
R"==(#if GROUP_SIZE == SUB_GROUP_SIZE )==""\n"
R"==(sbr = sub_group_reduce_add(sbr); )==""\n"
R"==(#else )==""\n"
R"==(sbr = work_group_reduce_add(sbr); )==""\n"
R"==(#endif )==""\n"
R"==(diff_src += data_off; )==""\n"
R"==(for (int i = 0, idx = IC * slice * SOFTMAX_BUF; i < VECT_SIZE; )==""\n"
R"==(++i, idx += IC) { )==""\n"
R"==(#if LOGSOFTMAX )==""\n"
R"==(diff_d[i] = diff_d[i] - exp(dst_[i]) * sbr; )==""\n"
R"==(#else )==""\n"
R"==(diff_d[i] = (diff_d[i] - sbr) * dst_[i]; )==""\n"
R"==(#endif )==""\n"
R"==(diff_src[idx] = FLOAT_TO_DATA(SRC, diff_d[i]); )==""\n"
R"==(} )==""\n"
// Remaining layouts: NUM_BUF float8 values per work-item, sub-group
// reduction only.
R"==(#else )==""\n"
R"==(const int data_off = (get_global_id(0) / GROUP_SIZE) * SOFTMAX_AXIS_SIZE; )==""\n"
R"==(float sbr = 0.f; )==""\n"
R"==(float8 diff_d[NUM_BUF]; )==""\n"
R"==(float8 dst_[NUM_BUF]; )==""\n"
R"==(diff_dst += data_off; )==""\n"
R"==(dst += data_off; )==""\n"
R"==(for (int k = 0; k < NUM_BUF; ++k) { )==""\n"
R"==(diff_d[k] = LOAD_FLOAT8(DST, &diff_dst[k * VECT_SIZE * SUB_GROUP_SIZE]); )==""\n"
R"==(dst_[k] = LOAD_FLOAT8(DST, &dst[k * VECT_SIZE * SUB_GROUP_SIZE]); )==""\n"
R"==(for (int i = 0; i < VECT_SIZE; ++i) { )==""\n"
R"==(#if LOGSOFTMAX )==""\n"
R"==(sbr += diff_d[k][i]; )==""\n"
R"==(#else )==""\n"
R"==(sbr += dst_[k][i] * diff_d[k][i]; )==""\n"
R"==(#endif )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==(sbr = sub_group_reduce_add(sbr); )==""\n"
R"==(diff_src += data_off; )==""\n"
R"==(for (int k = 0; k < NUM_BUF; ++k) { )==""\n"
R"==(#if LOGSOFTMAX )==""\n"
R"==(diff_d[k] = diff_d[k] - exp(dst_[k]) * sbr; )==""\n"
R"==(#else )==""\n"
R"==(diff_d[k] = (diff_d[k] - sbr) * dst_[k]; )==""\n"
R"==(#endif )==""\n"
R"==(STORE_FLOAT8(SRC, &diff_src[k * VECT_SIZE * SUB_GROUP_SIZE], diff_d[k]); )==""\n"
R"==(} )==""\n"
R"==(#endif )==""\n"
R"==(} )==""\n"
R"==(#endif )==""\n"
R"==()==";
}
}
}
}