namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {
// ref_gemm_kernel
//
// NUL-terminated text of an OpenCL C program that defines one helper,
// `get_strides`, and one kernel, `ref_gemm`. The consumer of this string is
// not visible in this file; presumably it is handed to the OpenCL runtime
// compiler -- TODO confirm against the caller.
//
// Layout of the literal: the value is built from adjacent raw string literals
// (delimiter `==`), one per kernel source line. Every piece keeps a single
// trailing space and is followed by an explicit "\n" piece, so each line of
// the resulting text ends in " \n". The C++ comments placed between the
// pieces are removed before the literals are concatenated and therefore do
// not become part of the string.
//
// Names such as A_DATA_T, B_DATA_T, C_DATA_T, BIA_DATA_T, ACC_DATA_T,
// POST_OP_DATA_T, TO_ACC, TO_C, A_TO_REF, B_TO_REF, C_TO_REF, BIA_TO_REF,
// ATTR_A0, ATTR_B0, fwd_eltwise and the WITH_* / NON_DEFAULT_ATTRS switches
// are not defined in this string; they presumably come from the three
// included headers or from build-time options.
const char *ref_gemm_kernel = R"==(/******************************************************************************* )==""\n"
R"==(* Copyright 2019-2022 Intel Corporation )==""\n"
R"==(* )==""\n"
R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==""\n"
R"==(* you may not use this file except in compliance with the License. )==""\n"
R"==(* You may obtain a copy of the License at )==""\n"
R"==(* )==""\n"
// Full Apache-2.0 license URL. The "//" is harmless here because this line
// sits inside the kernel's /* ... */ block comment.
R"==(* http://www.apache.org/licenses/LICENSE-2.0 )==""\n"
R"==(* )==""\n"
R"==(* Unless required by applicable law or agreed to in writing, software )==""\n"
R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==""\n"
R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==""\n"
R"==(* See the License for the specific language governing permissions and )==""\n"
R"==(* limitations under the License. )==""\n"
R"==(*******************************************************************************/ )==""\n"
R"==(#include "gpu/ocl/gemm/ocl_gemm_attrs.h" )==""\n"
R"==(#include "gpu/ocl/ocl_post_ops.h" )==""\n"
R"==(#include "gpu/ocl/ocl_types.h" )==""\n"
// get_strides(mask, dim0, dim1, dim2, str0, str1, str2)
//
// Writes three element strides for a tensor whose dimensions may be
// broadcast. Dimension i keeps its size only when bit i of `mask` is set
// (dimension 0 additionally requires dim0 > 1); otherwise it is treated as
// size 1. A dimension of effective size 1 gets stride 0, so indexing along
// it always hits the same element. For the remaining dimensions the strides
// are: dimension 1 -> 1, dimension 2 -> size of dimension 1, dimension 0 ->
// size of dimension 1 times size of dimension 2.
R"==(void get_strides(int mask, long dim0, long dim1, long dim2, long *str0, )==""\n"
R"==(long *str1, long *str2) { )==""\n"
R"==(int is_3d = dim0 > 1; )==""\n"
R"==(long dims[3]; )==""\n"
// `&` binds tighter than `&&`, so this reads: is_3d && (mask & 1).
R"==(dims[0] = (is_3d && mask & (1 << 0)) ? dim0 : 1; )==""\n"
R"==(dims[1] = mask & (1 << 1) ? dim1 : 1; )==""\n"
R"==(dims[2] = mask & (1 << 2) ? dim2 : 1; )==""\n"
R"==(*str0 = dims[0] == 1 ? 0 : dims[1] * dims[2]; )==""\n"
R"==(*str1 = dims[1] == 1 ? 0 : 1; )==""\n"
R"==(*str2 = dims[2] == 1 ? 0 : dims[1]; )==""\n"
R"==(} )==""\n"
// ref_gemm kernel
//
// Each work-item reads its column index `n` from global id 1 and its batch
// index `mb` from global id 2 (global id 0 is not read), then loops serially
// over all rows m in [0, M) and, for each, over k in [0, K).
//
// Element offsets used by the kernel (after adding offset_a0 / offset_b0 /
// offset_c0 / offset_bias0 to the base pointers):
//   a: mb * stride_a + (transa ? m * lda + k : k * lda + m)
//   b: mb * stride_b + (transb ? k * ldb + n : n * ldb + k)
//   c: mb * stride_c + n * ldc + m
//
// `ao` and `bo` are not referenced by name in the visible body; ATTR_A0 /
// ATTR_B0 may expand to uses of them -- TODO confirm in ocl_gemm_attrs.h.
R"==(__kernel void ref_gemm(__global A_DATA_T *a, __global B_DATA_T *b, )==""\n"
R"==(__global C_DATA_T *c, __global BIA_DATA_T *bias, long offset_a0, )==""\n"
R"==(long offset_b0, long offset_c0, long offset_bias0, int transa, )==""\n"
R"==(int transb, long MB, long M, long N, long K, long stride_a, )==""\n"
R"==(long stride_b, long stride_c, long lda, long ldb, long ldc, )==""\n"
R"==(float eltwise_alpha, float eltwise_beta, float eltwise_scale, )==""\n"
R"==(int bias_mask, __global int *ao, __global int *bo, __global int *c0, )==""\n"
R"==(int c0_mask, __global float *scales, long scale_stride, float beta) { )==""\n"
R"==(int n = get_global_id(1); )==""\n"
R"==(int mb = get_global_id(2); )==""\n"
// Bias strides are derived from bias_mask over the (MB, M, N) shape.
R"==(#if WITH_BIAS )==""\n"
R"==(bias += offset_bias0; )==""\n"
R"==(long b_strides[3]; )==""\n"
R"==(get_strides( )==""\n"
R"==(bias_mask, MB, M, N, &b_strides[0], &b_strides[1], &b_strides[2]); )==""\n"
R"==(#endif )==""\n"
R"==(a += offset_a0; )==""\n"
R"==(b += offset_b0; )==""\n"
R"==(c += offset_c0; )==""\n"
// Strides for the c0 array are derived from c0_mask over the same shape.
R"==(#if WITH_DST_ZPOINTS )==""\n"
R"==(long c0_strides[3]; )==""\n"
R"==(get_strides( )==""\n"
R"==(c0_mask, MB, M, N, &c0_strides[0], &c0_strides[1], &c0_strides[2]); )==""\n"
R"==(#endif )==""\n"
R"==(for (long m = 0; m < M; ++m) { )==""\n"
R"==(ACC_DATA_T acc = 0; )==""\n"
// Dot product over k; ATTR_A0 / ATTR_B0 are subtracted from the converted
// a / b values before each is passed through TO_ACC and multiplied.
R"==(for (long k = 0; k < K; ++k) { )==""\n"
R"==(long off_a = mb * stride_a + (transa ? m * lda + k : k * lda + m); )==""\n"
R"==(long off_b = mb * stride_b + (transb ? k * ldb + n : n * ldb + k); )==""\n"
R"==(acc += TO_ACC(A_TO_REF(a[off_a]) - ATTR_A0) )==""\n"
R"==(* TO_ACC(B_TO_REF(b[off_b]) - ATTR_B0); )==""\n"
R"==(} )==""\n"
R"==(long off_c = mb * stride_c + n * ldc + m; )==""\n"
// Post-processing path. The steps are applied in this fixed order, each only
// when its switch is enabled: bias add, per-n scale, sum with beta times the
// value already stored in c, eltwise, then addition of the c0 entry.
R"==(#if WITH_BIAS || NON_DEFAULT_ATTRS )==""\n"
R"==(POST_OP_DATA_T temp = (POST_OP_DATA_T)acc; )==""\n"
R"==(#if WITH_BIAS )==""\n"
R"==(long off_bias = mb * b_strides[0] + m * b_strides[1] + n * b_strides[2]; )==""\n"
R"==(temp += BIA_TO_REF(bias[off_bias]); )==""\n"
R"==(#endif )==""\n"
R"==(#if WITH_SCALES )==""\n"
R"==(temp *= scales[scale_stride * n]; )==""\n"
R"==(#endif )==""\n"
R"==(#if WITH_SUM )==""\n"
R"==(temp += (POST_OP_DATA_T)(beta * C_TO_REF(c[off_c])); )==""\n"
R"==(#endif )==""\n"
R"==(#if WITH_ELTWISE )==""\n"
R"==(temp = fwd_eltwise(temp, eltwise_alpha, eltwise_beta, eltwise_scale); )==""\n"
R"==(#endif )==""\n"
R"==(#if WITH_DST_ZPOINTS )==""\n"
R"==(long off_c0 )==""\n"
R"==(= mb * c0_strides[0] + m * c0_strides[1] + n * c0_strides[2]; )==""\n"
R"==(temp += c0[off_c0]; )==""\n"
R"==(#endif )==""\n"
R"==(c[off_c] = TO_C(temp); )==""\n"
// Plain path: no bias and default attributes, so the accumulator is stored
// directly after conversion.
R"==(#else )==""\n"
R"==(c[off_c] = TO_C(acc); )==""\n"
R"==(#endif )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==()==";
} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl