1namespace dnnl {
2namespace impl {
3namespace gpu {
4namespace ocl {
// Source text of the OpenCL C kernel gemm_post_ops, stored as one C string.
//
// The string is assembled from adjacent raw string literals, one per line of
// kernel text, each followed by an explicit newline escape. Every piece keeps
// a trailing space before its closing delimiter. Everything inside the raw
// pieces is program text for the device compiler, not host code: changing a
// single character there changes the kernel.
//
// Summary of the embedded kernel, per work-item:
//   1. Reads coordinates d0..d3 through GWS_GET_D0..GWS_GET_D3 and computes
//      data_idx with DST_OFF (argument layout depends on NDIMS).
//   2. Loads the starting value from scratchpad when USE_TEMP_DST == 1,
//      otherwise from src, converting through SRC_TO_ACC.
//   3. Only when the coordinates are either all equal to, or all below, the
//      D*_WO_PADDING values:
//        - multiplies by a_scales[0] and b_scales[scale_stride * <last dim
//          index>] when A_SCALES or B_SCALES is set;
//        - adds bias[BIA_OFF(...)] when WITH_BIAS == 1;
//        - reads dst[data_idx] as the sum source when WITH_SUM is set;
//        - applies APPLY_POST_OPS_SERIAL to the accumulator.
//   4. Regardless of that check, divides by c_scales[0] when C_SCALES is set
//      and writes TO_DST(accumulator) to dst[data_idx].
//
// NOTE(review): SRC_DATA_T, BIA_DATA_T, DST_DATA_T, SPAD_DATA_T, ACC_DATA_T,
// DST_OFF, GWS_GET_D*, POST_OP_ARGS, APPLY_POST_OPS_SERIAL, TO_DST,
// cvt_bf16_to_f32 and the NDIMS / BIA_* / *_SCALES / WITH_* switches are not
// defined in this text; presumably they come from the included headers or
// from build options supplied by the host -- confirm against the caller.
// NOTE(review): the layout suggests this file is generated from a .cl source;
// if so, edits belong in that source -- confirm.
5const char *gemm_with_post_ops_kernel = R"==(/******************************************************************************* )==""\n"
6R"==(* Copyright 2021-2022 Intel Corporation )==""\n"
7R"==(* )==""\n"
8R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==""\n"
9R"==(* you may not use this file except in compliance with the License. )==""\n"
10R"==(* You may obtain a copy of the License at )==""\n"
11R"==(* )==""\n"
12R"==(* http: )==""\n"
13R"==(* )==""\n"
14R"==(* Unless required by applicable law or agreed to in writing, software )==""\n"
15R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==""\n"
16R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==""\n"
17R"==(* See the License for the specific language governing permissions and )==""\n"
18R"==(* limitations under the License. )==""\n"
19R"==(*******************************************************************************/ )==""\n"
20R"==(#include "gpu/ocl/ocl_math_utils.h" )==""\n"
21R"==(#include "gpu/ocl/ocl_post_ops.h" )==""\n"
22R"==(#include "gpu/ocl/ocl_types.h" )==""\n"
// Kernel text: DST_TO_ACC / BIA_TO_ACC / SRC_TO_ACC call cvt_bf16_to_f32 when
// the matching *_DT_BF16 symbol is defined and pass the value through
// unchanged otherwise.
23R"==(#if defined(DST_DT_BF16) )==""\n"
24R"==(#define DST_TO_ACC(x) cvt_bf16_to_f32(x) )==""\n"
25R"==(#else )==""\n"
26R"==(#define DST_TO_ACC(x) (x) )==""\n"
27R"==(#endif )==""\n"
28R"==(#if defined(BIA_DT_BF16) )==""\n"
29R"==(#define BIA_TO_ACC(x) cvt_bf16_to_f32(x) )==""\n"
30R"==(#else )==""\n"
31R"==(#define BIA_TO_ACC(x) (x) )==""\n"
32R"==(#endif )==""\n"
33R"==(#if defined(SRC_DT_BF16) )==""\n"
34R"==(#define SRC_TO_ACC(x) cvt_bf16_to_f32(x) )==""\n"
35R"==(#else )==""\n"
36R"==(#define SRC_TO_ACC(x) (x) )==""\n"
37R"==(#endif )==""\n"
// Kernel text: BIA_D2 and BIA_D3 fall back to 1 when not predefined. BIA_OFF
// then maps coordinates to a bias element offset for BIA_NDIMS of 4, 3, 2 or
// 1; each coordinate is first reduced modulo the bias dimension and then
// split by the BIA_B* / BIA_SB* / BIA_S* constants. In the one-dimensional
// variant the first two parameter names are swapped, so the macro yields its
// SECOND argument. No BIA_OFF is defined for any other BIA_NDIMS value.
38R"==(#ifndef BIA_D2 )==""\n"
39R"==(#define BIA_D2 1 )==""\n"
40R"==(#endif )==""\n"
41R"==(#ifndef BIA_D3 )==""\n"
42R"==(#define BIA_D3 1 )==""\n"
43R"==(#endif )==""\n"
44R"==(#if BIA_NDIMS == 4 )==""\n"
45R"==(#define BIA_OFF(x0, x1, d, h, w) \ )==""\n"
46R"==((((x0 % BIA_D0) % BIA_B0) * BIA_SB0 + ((x0 % BIA_D0) / BIA_B0) * BIA_S0 \ )==""\n"
47R"==(+ ((x1 % BIA_D1) % BIA_B1) * BIA_SB1 \ )==""\n"
48R"==(+ ((x1 % BIA_D1) / BIA_B1) * BIA_S1 \ )==""\n"
49R"==(+ ((h % BIA_D2) % BIA_B2) * BIA_SB2 \ )==""\n"
50R"==(+ ((h % BIA_D2) / BIA_B2) * BIA_S2 \ )==""\n"
51R"==(+ ((w % BIA_D3) % BIA_B3) * BIA_SB3 \ )==""\n"
52R"==(+ ((w % BIA_D3) / BIA_B3) * BIA_S3) )==""\n"
53R"==(#elif BIA_NDIMS == 3 )==""\n"
54R"==(#define BIA_OFF(x0, x1, d, h, w) \ )==""\n"
55R"==((((x0 % BIA_D0) % BIA_B0) * BIA_SB0 + ((x0 % BIA_D0) / BIA_B0) * BIA_S0 \ )==""\n"
56R"==(+ ((x1 % BIA_D1) % BIA_B1) * BIA_SB1 \ )==""\n"
57R"==(+ ((x1 % BIA_D1) / BIA_B1) * BIA_S1 \ )==""\n"
58R"==(+ ((w % BIA_D2) % BIA_B2) * BIA_SB2 \ )==""\n"
59R"==(+ ((w % BIA_D2) / BIA_B2) * BIA_S2) )==""\n"
60R"==(#elif BIA_NDIMS == 2 )==""\n"
61R"==(#define BIA_OFF(x0, x1, d, h, w) \ )==""\n"
62R"==((((x0 % BIA_D0) % BIA_B0) * BIA_SB0 + ((x0 % BIA_D0) / BIA_B0) * BIA_S0 \ )==""\n"
63R"==(+ ((x1 % BIA_D1) % BIA_B1) * BIA_SB1 \ )==""\n"
64R"==(+ ((x1 % BIA_D1) / BIA_B1) * BIA_S1) )==""\n"
65R"==(#elif BIA_NDIMS == 1 )==""\n"
66R"==(#define BIA_OFF(x1, x0, d, h, w) (x0) )==""\n"
67R"==(#endif )==""\n"
// Kernel text: the gemm_post_ops entry point. The value written to dst is
// computed in an ACC_DATA_T variable (acc) and mirrored into a float
// (accumulator) that the post-ops macro and the final store operate on.
// Note that b_scale is declared only for NDIMS of 2, 3 or 4, while the line
// that uses it is not guarded by NDIMS.
68R"==(__kernel void gemm_post_ops(__global SRC_DATA_T *src, __global BIA_DATA_T *bias, )==""\n"
69R"==(__global DST_DATA_T *dst POST_OP_ARGS, __global SPAD_DATA_T *scratchpad, )==""\n"
70R"==(global float *a_scales, global float *b_scales, global float *c_scales, )==""\n"
71R"==(int scale_stride) { )==""\n"
72R"==(const uint d0 = GWS_GET_D0(); )==""\n"
73R"==(const uint d1 = GWS_GET_D1(); )==""\n"
74R"==(const uint d2 = GWS_GET_D2(); )==""\n"
75R"==(const uint d3 = GWS_GET_D3(); )==""\n"
76R"==(#if NDIMS == 4 )==""\n"
77R"==(size_t data_idx = DST_OFF(d0, d1, 0, d2, d3); )==""\n"
78R"==(#elif NDIMS == 3 )==""\n"
79R"==(size_t data_idx = DST_OFF(d0, d1, 0, 0, d2); )==""\n"
80R"==(#else )==""\n"
81R"==(size_t data_idx = DST_OFF(d0, d1, 0, 0, 0); )==""\n"
82R"==(#endif )==""\n"
83R"==(#if USE_TEMP_DST == 1 )==""\n"
84R"==(ACC_DATA_T acc = SRC_TO_ACC(scratchpad[data_idx]); )==""\n"
85R"==(#else )==""\n"
86R"==(ACC_DATA_T acc = SRC_TO_ACC(src[data_idx]); )==""\n"
87R"==(#endif )==""\n"
88R"==(float accumulator = acc; )==""\n"
89R"==(if ((d0 == D0_WO_PADDING && d1 == D1_WO_PADDING && d2 == D2_WO_PADDING )==""\n"
90R"==(&& d3 == D3_WO_PADDING) )==""\n"
91R"==(|| (d0 < D0_WO_PADDING && d1 < D1_WO_PADDING && d2 < D2_WO_PADDING )==""\n"
92R"==(&& d3 < D3_WO_PADDING)) { )==""\n"
93R"==(#if A_SCALES || B_SCALES )==""\n"
94R"==(#define A_SCALE (A_SCALES ? a_scales[0] : 1) )==""\n"
95R"==(#if NDIMS == 2 )==""\n"
96R"==(const float b_scale = B_SCALES ? b_scales[scale_stride * d1] : 1; )==""\n"
97R"==(#elif NDIMS == 3 )==""\n"
98R"==(const float b_scale = B_SCALES ? b_scales[scale_stride * d2] : 1; )==""\n"
99R"==(#elif NDIMS == 4 )==""\n"
100R"==(const float b_scale = B_SCALES ? b_scales[scale_stride * d3] : 1; )==""\n"
101R"==(#endif )==""\n"
102R"==(acc *= A_SCALE * b_scale; )==""\n"
103R"==(#endif )==""\n"
104R"==(#if WITH_BIAS == 1 )==""\n"
105R"==(#if NDIMS == 4 )==""\n"
106R"==(size_t bia_idx = BIA_OFF(d0, d1, 0, d2, d3); )==""\n"
107R"==(#elif NDIMS == 3 )==""\n"
108R"==(size_t bia_idx = BIA_OFF(d0, d1, 0, 0, d2); )==""\n"
109R"==(#else )==""\n"
110R"==(size_t bia_idx = BIA_OFF(d0, d1, 0, 0, 0); )==""\n"
111R"==(#endif )==""\n"
112R"==(acc += BIA_TO_ACC(bias[bia_idx]); )==""\n"
113R"==(#endif )==""\n"
114R"==(float sum_src = 0.0f; )==""\n"
115R"==(#if WITH_SUM )==""\n"
116R"==(sum_src = DST_TO_ACC(dst[data_idx]); )==""\n"
117R"==(#endif )==""\n"
118R"==(accumulator = acc; )==""\n"
119R"==(#if NDIMS == 2 )==""\n"
120R"==(APPLY_POST_OPS_SERIAL(accumulator, float, sum_src, float, d0, 1, d1, 1, )==""\n"
121R"==(0, 1, 0, 1, 0, 1, 0, 1); )==""\n"
122R"==(#elif NDIMS == 3 )==""\n"
123R"==(APPLY_POST_OPS_SERIAL(accumulator, float, sum_src, float, d0, 1, d1, 1, )==""\n"
124R"==(d2, 1, 0, 1, 0, 1, 0, 1); )==""\n"
125R"==(#elif NDIMS == 4 )==""\n"
126R"==(APPLY_POST_OPS_SERIAL(accumulator, float, sum_src, float, d0, 1, d1, 1, )==""\n"
127R"==(d2, 1, d3, 1, 0, 1, 0, 1); )==""\n"
128R"==(#endif )==""\n"
129R"==(} )==""\n"
130R"==(#if C_SCALES )==""\n"
131R"==(accumulator /= c_scales[0]; )==""\n"
132R"==(#endif )==""\n"
133R"==(dst[data_idx] = TO_DST(accumulator); )==""\n"
134R"==(} )==""\n"
135R"==()==";
136} // namespace ocl
137} // namespace gpu
138} // namespace impl
139} // namespace dnnl