namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {

// OpenCL C source of the reference matmul kernel, embedded as a sequence of
// adjacent raw-string literals (one per kernel source line, each followed by
// an explicit "\n") that the C++ compiler concatenates into a single string.
//
// Each work-item takes column `n` from global dimension 1 and the flattened
// batch index `mb` from global dimension 2, then loops over all M rows. For
// each (m, n) it accumulates, over k,
//     (A[src_off] - src_zp) * (B[wei_off] - wei_zp)
// with offsets built by offset6D() from the per-tensor strides. Then:
//   * when WITH_BIAS || NON_DEFAULT_ATTRS, it stores
//         TO_DST(post_ops(acc * src_scales[0]
//                             * wei_scales[wei_scale_stride * n] + bias)
//                / dst_scales[0] + dst_zp)
//     (each scale / zero-point / bias term only when its WITH_* macro is on);
//   * otherwise, it stores TO_DST(acc).
//
// The type and feature macros the kernel relies on (SRC_DATA_T, WEI_DATA_T,
// DST_DATA_T, BIA_DATA_T, ACC_DATA_T, POST_OP_DATA_T, WITH_*, DST_NDIMS,
// POST_OP_ARGS, ...) are not defined in this text. Presumably they come from
// the included headers and the host-side kernel build options; confirm
// against the primitive that compiles this kernel.
//
// Comments placed *between* the literals below are C++ comments: they are
// removed before string concatenation and never reach the OpenCL compiler.
const char *ref_matmul_kernel = R"==(/******************************************************************************* )==""\n"
R"==(* Copyright 2019-2022 Intel Corporation )==""\n"
R"==(* )==""\n"
R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==""\n"
R"==(* you may not use this file except in compliance with the License. )==""\n"
R"==(* You may obtain a copy of the License at )==""\n"
R"==(* )==""\n"
R"==(* http://www.apache.org/licenses/LICENSE-2.0 )==""\n"
R"==(* )==""\n"
R"==(* Unless required by applicable law or agreed to in writing, software )==""\n"
R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==""\n"
R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==""\n"
R"==(* See the License for the specific language governing permissions and )==""\n"
R"==(* limitations under the License. )==""\n"
R"==(*******************************************************************************/ )==""\n"
R"==(#include "gpu/ocl/ocl_post_ops.h" )==""\n"
R"==(#include "gpu/ocl/ocl_types.h" )==""\n"
R"==(#define offset6D(d0, d1, d2, d3, d4, d5, s0, s1, s2, s3, s4, s5) \ )==""\n"
R"==(((d0) * (s0) + (d1) * (s1) + (d2) * (s2) + (d3) * (s3) + (d4) * (s4) \ )==""\n"
R"==(+ (d5) * (s5)) )==""\n"
R"==(__kernel void ref_matmul(__global SRC_DATA_T *A, __global WEI_DATA_T *B, )==""\n"
R"==(__global DST_DATA_T *C, __global BIA_DATA_T *bia, __global int *a0, )==""\n"
R"==(__global int *b0, __global int *c0, __global float *src_scales, )==""\n"
R"==(__global float *wei_scales, long wei_scale_stride, )==""\n"
R"==(__global float *dst_scales, long K, long N, long M, long D0, long D1, )==""\n"
R"==(long D2, long bia_stride_d3, long bia_stride_d2, long bia_stride_d1, )==""\n"
R"==(long bia_stride_d0, long bia_stride_m, long bia_stride_n, )==""\n"
R"==(long a_stride_d3, long a_stride_d2, long a_stride_d1, long a_stride_d0, )==""\n"
R"==(long a_stride_m, long a_stride_k, long b_stride_d3, long b_stride_d2, )==""\n"
R"==(long b_stride_d1, long b_stride_d0, long b_stride_k, long b_stride_n, )==""\n"
R"==(long c_stride_d3, long c_stride_d2, long c_stride_d1, long c_stride_d0, )==""\n"
R"==(long c_stride_m, long c_stride_n POST_OP_ARGS) { )==""\n"
// Work-item indices are kept as `long`, like every other index in the
// kernel, so the size_t returned by get_global_id() is not narrowed to int.
R"==(long n = get_global_id(1); )==""\n"
R"==(long mb = get_global_id(2); )==""\n"
R"==(#if WITH_SRC_ZPOINTS )==""\n"
R"==(int src_zp = a0[0]; )==""\n"
R"==(#else )==""\n"
R"==(int src_zp = 0; )==""\n"
R"==(#endif )==""\n"
R"==(#if WITH_WEI_ZPOINTS )==""\n"
R"==(int wei_zp = b0[0]; )==""\n"
R"==(#else )==""\n"
R"==(int wei_zp = 0; )==""\n"
R"==(#endif )==""\n"
R"==(#if WITH_DST_ZPOINTS )==""\n"
R"==(int dst_zp = c0[0]; )==""\n"
R"==(#else )==""\n"
R"==(int dst_zp = 0; )==""\n"
R"==(#endif )==""\n"
// Split the flattened batch index into four batch coordinates; d0 is the
// fastest-varying one and d3 the slowest.
R"==(long d3 = mb / D0 / D1 / D2; )==""\n"
R"==(long d2 = (mb / D0 / D1) % D2; )==""\n"
R"==(long d1 = (mb / D0) % D1; )==""\n"
R"==(long d0 = mb % D0; )==""\n"
R"==(for (long m = 0; m < M; ++m) { )==""\n"
R"==(ACC_DATA_T acc = 0; )==""\n"
R"==(for (long k = 0; k < K; ++k) { )==""\n"
R"==(long src_off )==""\n"
R"==(= offset6D(m, k, d0, d1, d2, d3, a_stride_m, a_stride_k, )==""\n"
R"==(a_stride_d0, a_stride_d1, a_stride_d2, a_stride_d3); )==""\n"
R"==(long wei_off )==""\n"
R"==(= offset6D(k, n, d0, d1, d2, d3, b_stride_k, b_stride_n, )==""\n"
R"==(b_stride_d0, b_stride_d1, b_stride_d2, b_stride_d3); )==""\n"
R"==(acc += TO_ACC(SRC_TO_REF(A[src_off]) - src_zp) )==""\n"
R"==(* TO_ACC(WEI_TO_REF(B[wei_off]) - wei_zp); )==""\n"
R"==(} )==""\n"
R"==(long dst_off = offset6D(m, n, d0, d1, d2, d3, c_stride_m, c_stride_n, )==""\n"
R"==(c_stride_d0, c_stride_d1, c_stride_d2, c_stride_d3); )==""\n"
R"==(#if WITH_BIAS || NON_DEFAULT_ATTRS )==""\n"
R"==(POST_OP_DATA_T temp = (POST_OP_DATA_T)acc; )==""\n"
R"==(#if WITH_SRC_SCALES )==""\n"
R"==(temp *= src_scales[0]; )==""\n"
R"==(#endif )==""\n"
R"==(#if WITH_WEI_SCALES )==""\n"
R"==(temp *= wei_scales[wei_scale_stride * n]; )==""\n"
R"==(#endif )==""\n"
R"==(#if WITH_BIAS )==""\n"
R"==(long bia_off = offset6D(m, n, d0, d1, d2, d3, bia_stride_m, )==""\n"
R"==(bia_stride_n, bia_stride_d0, bia_stride_d1, bia_stride_d2, )==""\n"
R"==(bia_stride_d3); )==""\n"
// NOTE(review): bias is added without a *_TO_REF conversion, unlike the
// src/wei/dst loads; confirm BIA_DATA_T is never a packed storage type
// (e.g. bf16 held as ushort) when this kernel is built.
R"==(temp += bia[bia_off]; )==""\n"
R"==(#endif )==""\n"
// Initialized so the post-op macro never receives an indeterminate value;
// only the WITH_SUM path loads the previous destination into it.
R"==(float dst_data = 0.0f; )==""\n"
R"==(#if WITH_SUM )==""\n"
R"==(dst_data = convert_float(DATA_TO_REF(C[dst_off])); )==""\n"
R"==(#endif )==""\n"
R"==(float po_acc = convert_float(temp); )==""\n"
// DST_NDIMS is not defined here (presumably a build-time constant). The
// APPLY_POST_OPS_SERIAL calls below differ only in where the
// (d3, d2, d1, d0, m, n) coordinates are placed among the macro's six
// (index, 1) pairs; unused pairs are passed as (0, 1).
R"==(if (DST_NDIMS == 2) )==""\n"
R"==(APPLY_POST_OPS_SERIAL(po_acc, float, dst_data, float, m, 1, n, 1, 0, )==""\n"
R"==(1, 0, 1, 0, 1, 0, 1); )==""\n"
R"==(if (DST_NDIMS == 3) )==""\n"
R"==(APPLY_POST_OPS_SERIAL(po_acc, float, dst_data, float, d0, 1, m, 1, )==""\n"
R"==(n, 1, 0, 1, 0, 1, 0, 1); )==""\n"
R"==(if (DST_NDIMS == 4) )==""\n"
R"==(APPLY_POST_OPS_SERIAL(po_acc, float, dst_data, float, d1, 1, d0, 1, )==""\n"
R"==(m, 1, n, 1, 0, 1, 0, 1); )==""\n"
R"==(if (DST_NDIMS == 5) )==""\n"
R"==(APPLY_POST_OPS_SERIAL(po_acc, float, dst_data, float, d2, 1, d1, 1, )==""\n"
R"==(d0, 1, m, 1, n, 1, 0, 1); )==""\n"
R"==(if (DST_NDIMS == 6) )==""\n"
R"==(APPLY_POST_OPS_SERIAL(po_acc, float, dst_data, float, d3, 1, d2, 1, )==""\n"
R"==(d1, 1, d0, 1, m, 1, n, 1); )==""\n"
R"==(#if WITH_DST_SCALES )==""\n"
R"==(po_acc /= dst_scales[0]; )==""\n"
R"==(#endif )==""\n"
R"==(po_acc += dst_zp; )==""\n"
R"==(C[dst_off] = TO_DST(po_acc); )==""\n"
R"==(#else )==""\n"
R"==(C[dst_off] = TO_DST(acc); )==""\n"
R"==(#endif )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==()==";

} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl