namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {

// OpenCL C source for the reference inner-product kernels, stored as one
// null-terminated string built from concatenated raw string literals (one
// literal per source line, each followed by "\n").
//
// The source defines up to three kernels, each enabled by a preprocessor flag
// that must be supplied when the program is built (the flags and the helper
// macros such as GWS_GET_*, SRC_OFF, WEI_OFF, DST_OFF, *_TO_REF, TO_DST are
// not defined here; presumably they come from the host-side kernel context
// and the included gpu/ocl/*.h headers -- TODO confirm):
//   * IS_FWD   == 1 -> ref_inner_product_fwd: dst = post_ops(src x wei
//                      [* src/wei scales] [+ bias]) [/ dst scale].
//   * IS_BWD_D == 1 -> ref_inner_product_bwd_data: diff_src = diff_dst x wei.
//   * IS_BWD_W == 1 -> ref_inner_product_bwd_weights: diff_wei = diff_dst x
//                      src, plus diff_bias = sum over MB of diff_dst.
//
// NOTE(review): this file looks generated from a .cl source; mirror any edit
// there so it is not lost on regeneration.
//
// Why the bwd_weights bias guard also checks kd/kh/kw == 0: one work item is
// launched per (oc, ic, kd, kh, kw). Guarding on ic == 0 alone made every
// spatial position recompute the same bias reduction and issue concurrent,
// unsynchronized stores to diff_bias[oc]. With the extra conditions exactly
// one work item per oc writes the bias. Without spatial dims kd/kh/kw are
// always 0, so behaviour there is unchanged.
//
// Why dest_data is zero-initialized in fwd: it is only loaded under WITH_SUM
// but is always handed to APPLY_POST_OPS_SERIAL_BINARY_2D. Initializing it
// avoids passing an indeterminate value when WITH_SUM is off.
const char *ref_inner_product_kernel = R"==(/******************************************************************************* )==""\n"
R"==(* Copyright 2019-2022 Intel Corporation )==""\n"
R"==(* )==""\n"
R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==""\n"
R"==(* you may not use this file except in compliance with the License. )==""\n"
R"==(* You may obtain a copy of the License at )==""\n"
R"==(* )==""\n"
R"==(* http://www.apache.org/licenses/LICENSE-2.0 )==""\n"
R"==(* )==""\n"
R"==(* Unless required by applicable law or agreed to in writing, software )==""\n"
R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==""\n"
R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==""\n"
R"==(* See the License for the specific language governing permissions and )==""\n"
R"==(* limitations under the License. )==""\n"
R"==(*******************************************************************************/ )==""\n"
R"==(#include "gpu/ocl/ocl_post_ops.h" )==""\n"
R"==(#include "gpu/ocl/ocl_types.h" )==""\n"
R"==(#if IS_FWD == 1 )==""\n"
R"==(KERNEL_ATTR )==""\n"
R"==(__kernel void ref_inner_product_fwd(__global SRC_DATA_T *src, )==""\n"
R"==(__global WEI_DATA_T *wei, __global BIA_DATA_T *bias, )==""\n"
R"==(__global DST_DATA_T *dst POST_OP_ARGS, __global float *src_scales, )==""\n"
R"==(__global float *wei_scales, __global float *dst_scales) { )==""\n"
R"==(const int mb = GWS_GET_MB(); )==""\n"
R"==(const int oc = GWS_GET_OC(); )==""\n"
R"==(ACC_DATA_T d = 0; )==""\n"
R"==(#if HAS_SPATIAL == 1 )==""\n"
R"==(for (int ic = 0; ic < IC; ++ic) )==""\n"
R"==(for (int kd = 0; kd < KD; ++kd) )==""\n"
R"==(for (int kh = 0; kh < KH; ++kh) )==""\n"
R"==(for (int kw = 0; kw < KW; ++kw) { )==""\n"
R"==(const uint src_off = SRC_OFF(mb, ic, kd, kh, kw); )==""\n"
R"==(const uint wei_off = WEI_OFF(0, oc, ic, kd, kh, kw); )==""\n"
R"==(#else )==""\n"
R"==(for (int ic = 0; ic < IC_TOTAL; ++ic) { )==""\n"
R"==(const uint src_off = mb * IC_TOTAL + ic; )==""\n"
R"==(const uint wei_off = oc * IC_TOTAL + ic; )==""\n"
R"==(#endif )==""\n"
R"==(d += SRC_TO_REF(src[src_off]) * WEI_TO_REF(wei[wei_off]); )==""\n"
R"==(} )==""\n"
R"==(DATA_T tmp = d; )==""\n"
R"==(#if WITH_SRC_SCALES )==""\n"
R"==(tmp *= src_scales[0]; )==""\n"
R"==(#endif )==""\n"
R"==(#if WITH_WEI_SCALES )==""\n"
R"==(#if WEI_SCALES_MASK == 0 )==""\n"
R"==(tmp *= wei_scales[0]; )==""\n"
R"==(#else )==""\n"
R"==(tmp *= wei_scales[oc]; )==""\n"
R"==(#endif )==""\n"
R"==(#endif )==""\n"
R"==(#if WITH_BIAS )==""\n"
R"==(tmp += BIA_TO_REF(bias[oc]); )==""\n"
R"==(#endif )==""\n"
R"==(float dest_data = 0.0f; )==""\n"
R"==(#if WITH_SUM )==""\n"
R"==(dest_data = DST_TO_REF(dst[mb * OC + oc]); )==""\n"
R"==(#endif )==""\n"
R"==(APPLY_POST_OPS_SERIAL_BINARY_2D( )==""\n"
R"==(tmp, DATA_T, dest_data, float, mb, 1, oc, 1); )==""\n"
R"==(#if WITH_DST_SCALES )==""\n"
R"==(tmp /= dst_scales[0]; )==""\n"
R"==(#endif )==""\n"
R"==(dst[mb * OC + oc] = TO_DST(tmp); )==""\n"
R"==(} )==""\n"
R"==(#endif )==""\n"
R"==(#if IS_BWD_D == 1 )==""\n"
R"==(KERNEL_ATTR )==""\n"
R"==(__kernel void ref_inner_product_bwd_data(__global SRC_DATA_T *diff_src, )==""\n"
R"==(__global WEI_DATA_T *wei, __global DST_DATA_T *diff_dst) { )==""\n"
R"==(const int mb = GWS_GET_MB_IC() / IC; )==""\n"
R"==(const int ic = GWS_GET_MB_IC() % IC; )==""\n"
R"==(const int kd = GWS_GET_KD(); )==""\n"
R"==(const int kh = GWS_GET_KH(); )==""\n"
R"==(const int kw = GWS_GET_KW(); )==""\n"
R"==(float ds = 0.0f; )==""\n"
R"==(for (int oc = 0; oc < OC; ++oc) { )==""\n"
R"==(const uint diff_dst_off = DST_OFF(mb, oc, 0, 0, 0); )==""\n"
R"==(const uint wei_off = WEI_OFF(0, oc, ic, kd, kh, kw); )==""\n"
R"==(ds += DST_TO_REF(diff_dst[diff_dst_off]) * WEI_TO_REF(wei[wei_off]); )==""\n"
R"==(} )==""\n"
R"==(const uint diff_src_off = SRC_OFF(mb, ic, kd, kh, kw); )==""\n"
R"==(diff_src[diff_src_off] = REF_TO_SRC(ds); )==""\n"
R"==(} )==""\n"
R"==(#endif )==""\n"
R"==(#if IS_BWD_W == 1 )==""\n"
R"==(KERNEL_ATTR )==""\n"
R"==(__kernel void ref_inner_product_bwd_weights(__global SRC_DATA_T *src, )==""\n"
R"==(__global WEI_DATA_T *diff_wei, __global BIA_DATA_T *diff_bias, )==""\n"
R"==(__global DST_DATA_T *diff_dst) { )==""\n"
R"==(const int oc = GWS_GET_OC(); )==""\n"
R"==(const int ic = GWS_GET_IC(); )==""\n"
R"==(const int kd = GWS_GET_KD(); )==""\n"
R"==(const int kh = GWS_GET_KH(); )==""\n"
R"==(const int kw = GWS_GET_KW(); )==""\n"
R"==(float ds = 0.0f; )==""\n"
R"==(for (int mb = 0; mb < MB; ++mb) { )==""\n"
R"==(const uint diff_dst_off = DST_OFF(mb, oc, 0, 0, 0); )==""\n"
R"==(const uint src_off = SRC_OFF(mb, ic, kd, kh, kw); )==""\n"
R"==(ds += DST_TO_REF(diff_dst[diff_dst_off]) * SRC_TO_REF(src[src_off]); )==""\n"
R"==(} )==""\n"
R"==(const uint diff_wei_off = WEI_OFF(0, oc, ic, kd, kh, kw); )==""\n"
R"==(diff_wei[diff_wei_off] = REF_TO_WEI(ds); )==""\n"
R"==(#if WITH_BIAS == 1 )==""\n"
R"==(if (ic == 0 && kd == 0 && kh == 0 && kw == 0) { )==""\n"
R"==(float db = 0.0f; )==""\n"
R"==(for (int mb = 0; mb < MB; ++mb) { )==""\n"
R"==(const uint diff_dst_off = DST_OFF(mb, oc, 0, 0, 0); )==""\n"
R"==(db += DST_TO_REF(diff_dst[diff_dst_off]); )==""\n"
R"==(} )==""\n"
R"==(diff_bias[oc] = REF_TO_BIA(db); )==""\n"
R"==(} )==""\n"
R"==(#endif )==""\n"
R"==(} )==""\n"
R"==(#endif )==""\n"
R"==()==";

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl