// Embedded OpenCL C source text for the reference binary kernel.
//
// NOTE(review): every line below begins with what looks like a stray line
// number (for example 1namespace, 5const) - presumably extraction residue;
// confirm against the generated original before building.
1namespace dnnl {
2namespace impl {
3namespace gpu {
4namespace ocl {
// ref_binary_kernel points at one string assembled from adjacent raw string
// literals; each piece carries one line of kernel text (ending in a space)
// followed by an explicit newline literal. Presumably this text is compiled
// at run time by the code that consumes ref_binary_kernel - TODO confirm.
//
// The text contains:
//   * a license header,
//   * two includes and three offset helper macros,
//   * the device function binary_op,
//   * two mutually exclusive definitions of the kernel ref_binary.
//
// Macros such as OFF_MD, KERNEL_ATTR, POST_OP_ARGS, GWS_GET_*, CONVERT_FLOAT_T,
// TO_DST, DATA_ZERO and APPLY_POST_OPS_SERIAL are used but not defined in this
// text; presumably they come from the two included headers or from build
// options - verify against those headers.
//
// NOTE(review): the URL line of the license header reads only http: -
// presumably a comment-stripping step in the generator cut the rest; confirm.
5const char *ref_binary_kernel = R"==(/******************************************************************************* )==""\n"
6R"==(* Copyright 2019-2022 Intel Corporation )==""\n"
7R"==(* )==""\n"
8R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==""\n"
9R"==(* you may not use this file except in compliance with the License. )==""\n"
10R"==(* You may obtain a copy of the License at )==""\n"
11R"==(* )==""\n"
12R"==(* http: )==""\n"
13R"==(* )==""\n"
14R"==(* Unless required by applicable law or agreed to in writing, software )==""\n"
15R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==""\n"
16R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==""\n"
17R"==(* See the License for the specific language governing permissions and )==""\n"
18R"==(* limitations under the License. )==""\n"
19R"==(*******************************************************************************/ )==""\n"
// Kernel-side includes.
20R"==(#include "gpu/ocl/ocl_post_ops.h" )==""\n"
21R"==(#include "gpu/ocl/ocl_types.h" )==""\n"
// Drops any earlier DST_OFF definition, then defines SRC0_OFF, SRC1_OFF and
// DST_OFF as six-coordinate wrappers around OFF_MD for the SRC0, SRC1 and DST
// prefixes respectively.
22R"==(#undef DST_OFF )==""\n"
23R"==(#define SRC0_OFF(x0, x1, x2, x3, x4, x5) OFF_MD(SRC0, x0, x1, x2, x3, x4, x5) )==""\n"
24R"==(#define SRC1_OFF(x0, x1, x2, x3, x4, x5) OFF_MD(SRC1, x0, x1, x2, x3, x4, x5) )==""\n"
25R"==(#define DST_OFF(x0, x1, x2, x3, x4, x5) OFF_MD(DST, x0, x1, x2, x3, x4, x5) )==""\n"
// binary_op(src0, src1): exactly one return statement survives preprocessing,
// chosen by the first nonzero macro in the chain IS_ADD, IS_MUL, IS_MAX,
// IS_MIN, IS_DIV, IS_SUB, IS_GE, IS_GT, IS_LE, IS_LT, IS_EQ, IS_NE.
// The comparison variants return the comparison result through the float
// return type.
// NOTE(review): if none of these macros is nonzero the function body has no
// return statement; presumably the host always sets exactly one - confirm.
26R"==(float binary_op(float src0, float src1) { )==""\n"
27R"==(#if IS_ADD )==""\n"
28R"==(return src0 + src1; )==""\n"
29R"==(#elif IS_MUL )==""\n"
30R"==(return src0 * src1; )==""\n"
31R"==(#elif IS_MAX )==""\n"
32R"==(return max(src0, src1); )==""\n"
33R"==(#elif IS_MIN )==""\n"
34R"==(return min(src0, src1); )==""\n"
35R"==(#elif IS_DIV )==""\n"
36R"==(return src0 / src1; )==""\n"
37R"==(#elif IS_SUB )==""\n"
38R"==(return src0 - src1; )==""\n"
39R"==(#elif IS_GE )==""\n"
40R"==(return src0 >= src1; )==""\n"
41R"==(#elif IS_GT )==""\n"
42R"==(return src0 > src1; )==""\n"
43R"==(#elif IS_LE )==""\n"
44R"==(return src0 <= src1; )==""\n"
45R"==(#elif IS_LT )==""\n"
46R"==(return src0 < src1; )==""\n"
47R"==(#elif IS_EQ )==""\n"
48R"==(return src0 == src1; )==""\n"
49R"==(#elif IS_NE )==""\n"
50R"==(return src0 != src1; )==""\n"
51R"==(#endif )==""\n"
52R"==(} )==""\n"
// Kernel variant 1: selected when IS_TENSOR_OP, IS_DENSE and IS_SAME_MD are
// all nonzero and WITH_BINARY_POST_OP is zero. Both sources are typed DATA_T
// and a single offset from GWS_GET_IDX() indexes src0, src1 and dst alike.
53R"==(#if IS_TENSOR_OP && IS_DENSE && IS_SAME_MD && !WITH_BINARY_POST_OP )==""\n"
54R"==(KERNEL_ATTR )==""\n"
55R"==(__kernel void ref_binary(__global DATA_T *src0, __global DATA_T *src1, )==""\n"
56R"==(__global DST_DATA_T *dst POST_OP_ARGS, __global float *src0_scale, )==""\n"
57R"==(__global float *src1_scale) { )==""\n"
58R"==(int off = GWS_GET_IDX(); )==""\n"
59R"==(float tmp_src0 = CONVERT_FLOAT_T(src0[off]); )==""\n"
60R"==(float tmp_src1 = CONVERT_FLOAT_T(src1[off]); )==""\n"
61R"==(float d = 0; )==""\n"
// Optional input scaling: each converted input is multiplied by the single
// float read through its scale pointer when the matching macro is set.
62R"==(#if WITH_SRC0_SCALE )==""\n"
63R"==(tmp_src0 = tmp_src0 * (*src0_scale); )==""\n"
64R"==(#endif )==""\n"
65R"==(#if WITH_SRC1_SCALE )==""\n"
66R"==(tmp_src1 = tmp_src1 * (*src1_scale); )==""\n"
67R"==(#endif )==""\n"
68R"==(d = binary_op(tmp_src0, tmp_src1); )==""\n"
// dst_data is assigned only under WITH_SUM (from the current dst element);
// otherwise it is handed to APPLY_POST_OPS_SERIAL unassigned.
// NOTE(review): presumably the macro ignores it without a sum post-op -
// confirm in ocl_post_ops.h.
69R"==(float dst_data; )==""\n"
70R"==(#if WITH_SUM )==""\n"
71R"==(dst_data = CONVERT_FLOAT_T(dst[off]); )==""\n"
72R"==(#endif )==""\n"
// This variant passes zero for all twelve trailing macro arguments.
73R"==(APPLY_POST_OPS_SERIAL( )==""\n"
74R"==(d, float, dst_data, float, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0); )==""\n"
75R"==(dst[off] = TO_DST(d); )==""\n"
76R"==(} )==""\n"
// Kernel variant 2: used in every other configuration. Sources are typed
// SRC0_DATA_T and SRC1_DATA_T, and offsets are computed per tensor from six
// coordinates.
77R"==(#else )==""\n"
78R"==(KERNEL_ATTR )==""\n"
79R"==(__kernel void ref_binary(__global SRC0_DATA_T *src0, __global SRC1_DATA_T *src1, )==""\n"
80R"==(__global DST_DATA_T *dst POST_OP_ARGS, __global float *src0_scale, )==""\n"
81R"==(__global float *src1_scale) { )==""\n"
// Six coordinates of this work-item, read through GWS_GET_D0 .. GWS_GET_D5.
82R"==(int dims0[6] = {0}; )==""\n"
83R"==(dims0[0] = GWS_GET_D0(); )==""\n"
84R"==(dims0[1] = GWS_GET_D1(); )==""\n"
85R"==(dims0[2] = GWS_GET_D2(); )==""\n"
86R"==(dims0[3] = GWS_GET_D3(); )==""\n"
87R"==(dims0[4] = GWS_GET_D4(); )==""\n"
88R"==(dims0[5] = GWS_GET_D5(); )==""\n"
// d1_block bounds the per-work-item loops below; dims0_po is a separate copy
// of the coordinates that is passed to the post-op macro (its index 1 entry
// is advanced inside the loops when unrolling is enabled).
89R"==(int d1_block = GWS_GET_D1_BLOCK(); )==""\n"
90R"==(int dims0_po[6] )==""\n"
91R"==(= {dims0[0], dims0[1], dims0[2], dims0[3], dims0[4], dims0[5]}; )==""\n"
92R"==(int d1_init = GWS_GET_D1(); )==""\n"
93R"==(int dst_off = DST_OFF( )==""\n"
94R"==(dims0[0], dims0[1], dims0[2], dims0[3], dims0[4], dims0[5]); )==""\n"
// NOTE(review): this tests TENSOR_OP while the variant selector above tests
// IS_TENSOR_OP; nothing in this text defines either - confirm whether
// TENSOR_OP is ever defined or whether the else branch is always taken.
95R"==(#if TENSOR_OP )==""\n"
96R"==(int src0_off = SRC0_OFF( )==""\n"
97R"==(dims0[0], dims0[1], dims0[2], dims0[3], dims0[4], dims0[5]); )==""\n"
98R"==(int src1_off = SRC1_OFF( )==""\n"
99R"==(dims0[0], dims0[1], dims0[2], dims0[3], dims0[4], dims0[5]); )==""\n"
// Otherwise each coordinate is multiplied by the logical negation of the
// matching SRCn_BCAST_DIMk macro, so it becomes 0 when that macro is nonzero.
100R"==(#else )==""\n"
101R"==(int src0_off )==""\n"
102R"==(= SRC0_OFF(dims0[0] * !SRC0_BCAST_DIM0, dims0[1] * !SRC0_BCAST_DIM1, )==""\n"
103R"==(dims0[2] * !SRC0_BCAST_DIM2, dims0[3] * !SRC0_BCAST_DIM3, )==""\n"
104R"==(dims0[4] * !SRC0_BCAST_DIM4, dims0[5] * !SRC0_BCAST_DIM5); )==""\n"
105R"==(int src1_off )==""\n"
106R"==(= SRC1_OFF(dims0[0] * !SRC1_BCAST_DIM0, dims0[1] * !SRC1_BCAST_DIM1, )==""\n"
107R"==(dims0[2] * !SRC1_BCAST_DIM2, dims0[3] * !SRC1_BCAST_DIM3, )==""\n"
108R"==(dims0[4] * !SRC1_BCAST_DIM4, dims0[5] * !SRC1_BCAST_DIM5); )==""\n"
109R"==(#endif )==""\n"
110R"==(int block_size = d1_block; )==""\n"
// Coordinate 0 at or beyond DST_D0: write DATA_ZERO to block_size consecutive
// dst elements starting at dst_off and return without computing anything.
// Presumably this covers padding of dimension 0 - TODO confirm.
111R"==(if (dims0[0] >= DST_D0) { )==""\n"
112R"==(for (int ic = 0; ic < block_size; ++ic) { )==""\n"
113R"==(dst[dst_off] = DATA_ZERO; )==""\n"
114R"==(dst_off++; )==""\n"
115R"==(} )==""\n"
116R"==(return; )==""\n"
117R"==(} )==""\n"
// Whole block fits below DST_D1: run block_size iterations of
// load, optional scale, binary_op, post-ops, store.
118R"==(if (d1_init + block_size <= DST_D1) { )==""\n"
119R"==(for (int ic = 0; ic < block_size; ++ic) { )==""\n"
120R"==(float tmp_src0 = CONVERT_FLOAT_T(src0[src0_off]); )==""\n"
121R"==(float tmp_src1 = CONVERT_FLOAT_T(src1[src1_off]); )==""\n"
122R"==(float d = 0; )==""\n"
123R"==(#if WITH_SRC0_SCALE )==""\n"
124R"==(tmp_src0 = tmp_src0 * (*src0_scale); )==""\n"
125R"==(#endif )==""\n"
126R"==(#if WITH_SRC1_SCALE )==""\n"
127R"==(tmp_src1 = tmp_src1 * (*src1_scale); )==""\n"
128R"==(#endif )==""\n"
129R"==(d = binary_op(tmp_src0, tmp_src1); )==""\n"
130R"==(float dst_data; )==""\n"
131R"==(#if WITH_SUM )==""\n"
132R"==(dst_data = CONVERT_FLOAT_T(dst[dst_off]); )==""\n"
133R"==(#endif )==""\n"
// Here the trailing macro arguments are six pairs of (dims0_po[k], 1).
134R"==(APPLY_POST_OPS_SERIAL(d, float, dst_data, float, dims0_po[0], 1, )==""\n"
135R"==(dims0_po[1], 1, dims0_po[2], 1, dims0_po[3], 1, dims0_po[4], )==""\n"
136R"==(1, dims0_po[5], 1); )==""\n"
137R"==(dst[dst_off] = TO_DST(d); )==""\n"
// Offsets advance only when USE_UNROLL_16B or SRC0_UNROLL_16B is set:
// src0_off, dst_off and dims0_po[1] step by one; src1_off steps by one
// (USE_UNROLL_16B) or by SRC1_S1_0 (SRC0_UNROLL_16B), and only if SRC1_D1 > 1.
// Without either macro every iteration reuses the same offsets; presumably
// the block size is 1 in that configuration - TODO confirm with the host code.
138R"==(#if USE_UNROLL_16B || SRC0_UNROLL_16B )==""\n"
139R"==(src0_off++; )==""\n"
140R"==(dst_off++; )==""\n"
141R"==(++dims0_po[1]; )==""\n"
142R"==(if (USE_UNROLL_16B && (SRC1_D1 > 1)) { )==""\n"
143R"==(src1_off++; )==""\n"
144R"==(} else if (SRC0_UNROLL_16B && (SRC1_D1 > 1)) { )==""\n"
145R"==(src1_off += SRC1_S1_0; )==""\n"
146R"==(} )==""\n"
147R"==(#endif )==""\n"
148R"==(} )==""\n"
// Block crosses DST_D1: compute only the DST_D1 - d1_init remaining elements
// with the same per-element sequence as above.
149R"==(} else { )==""\n"
150R"==(for (int ic = 0; ic < DST_D1 - d1_init; ic++) { )==""\n"
151R"==(float tmp_src0 = CONVERT_FLOAT_T(src0[src0_off]); )==""\n"
152R"==(float tmp_src1 = CONVERT_FLOAT_T(src1[src1_off]); )==""\n"
153R"==(float d = 0; )==""\n"
154R"==(#if WITH_SRC0_SCALE )==""\n"
155R"==(tmp_src0 = tmp_src0 * (*src0_scale); )==""\n"
156R"==(#endif )==""\n"
157R"==(#if WITH_SRC1_SCALE )==""\n"
158R"==(tmp_src1 = tmp_src1 * (*src1_scale); )==""\n"
159R"==(#endif )==""\n"
160R"==(d = binary_op(tmp_src0, tmp_src1); )==""\n"
161R"==(float dst_data; )==""\n"
162R"==(#if WITH_SUM )==""\n"
163R"==(dst_data = CONVERT_FLOAT_T(dst[dst_off]); )==""\n"
164R"==(#endif )==""\n"
165R"==(APPLY_POST_OPS_SERIAL(d, float, dst_data, float, dims0_po[0], 1, )==""\n"
166R"==(dims0_po[1], 1, dims0_po[2], 1, dims0_po[3], 1, dims0_po[4], )==""\n"
167R"==(1, dims0_po[5], 1); )==""\n"
168R"==(dst[dst_off] = TO_DST(d); )==""\n"
169R"==(#if USE_UNROLL_16B || SRC0_UNROLL_16B )==""\n"
170R"==(src0_off++; )==""\n"
171R"==(dst_off++; )==""\n"
172R"==(++dims0_po[1]; )==""\n"
173R"==(if (USE_UNROLL_16B && (SRC1_D1 > 1)) { )==""\n"
174R"==(src1_off++; )==""\n"
175R"==(} else if (SRC0_UNROLL_16B && (SRC1_D1 > 1)) { )==""\n"
176R"==(src1_off += SRC1_S1_0; )==""\n"
177R"==(} )==""\n"
178R"==(#endif )==""\n"
179R"==(} )==""\n"
// When DST_D1 differs from DST_PD1, continue from the current dst_off and
// write DATA_ZERO to min(DST_PD1 - DST_D1, block_size) consecutive elements.
// Presumably DST_PD1 is the padded size of dimension 1 - TODO confirm.
180R"==(#if DST_D1 != DST_PD1 )==""\n"
181R"==(for (int ic = 0; ic < min(DST_PD1 - DST_D1, block_size); ic++) { )==""\n"
182R"==(dst[dst_off] = DATA_ZERO; )==""\n"
183R"==(dst_off++; )==""\n"
184R"==(} )==""\n"
185R"==(#endif )==""\n"
186R"==(} )==""\n"
187R"==(} )==""\n"
188R"==(#endif )==""\n"
// Empty final piece that terminates the concatenated string and the
// declaration.
189R"==()==";
190}
191}
192}
193}