// Auto-generated translation unit: embeds the OpenCL C source of the oneDNN
// reference "binary" primitive kernel (elementwise add/mul/min/max/div/sub and
// comparison ops with optional per-input scales, sum/post-ops, broadcasting,
// and padded-channel zeroing) as a single string. The string is handed to the
// OpenCL runtime compiler at primitive-creation time; it is NOT compiled as
// part of this C++ build. Generated from gpu/ocl/ref_binary.cl — do not edit
// the kernel text by hand here; regenerate from the .cl source instead.
namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {
const char *ref_binary_kernel = R"==(/******************************************************************************* )==" "\n"
R"==(* Copyright 2019-2022 Intel Corporation )==" "\n"
R"==(* )==" "\n"
R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==" "\n"
R"==(* you may not use this file except in compliance with the License. )==" "\n"
R"==(* You may obtain a copy of the License at )==" "\n"
R"==(* )==" "\n"
// NOTE(review): the URL below was truncated by the generator's comment
// stripping ("//..." removed); original text: http://www.apache.org/licenses/LICENSE-2.0
R"==(* http: )==" "\n"
R"==(* )==" "\n"
R"==(* Unless required by applicable law or agreed to in writing, software )==" "\n"
R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==" "\n"
R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==" "\n"
R"==(* See the License for the specific language governing permissions and )==" "\n"
R"==(* limitations under the License. )==" "\n"
R"==(*******************************************************************************/ )==" "\n"
R"==(#include "gpu/ocl/ocl_post_ops.h" )==" "\n"
R"==(#include "gpu/ocl/ocl_types.h" )==" "\n"
R"==(#undef DST_OFF )==" "\n"
R"==(#define SRC0_OFF(x0, x1, x2, x3, x4, x5) OFF_MD(SRC0, x0, x1, x2, x3, x4, x5) )==" "\n"
R"==(#define SRC1_OFF(x0, x1, x2, x3, x4, x5) OFF_MD(SRC1, x0, x1, x2, x3, x4, x5) )==" "\n"
R"==(#define DST_OFF(x0, x1, x2, x3, x4, x5) OFF_MD(DST, x0, x1, x2, x3, x4, x5) )==" "\n"
// binary_op: the elementwise operation is selected at OpenCL-compile time by
// exactly one IS_* macro passed from the host-side kernel-context setup.
// Comparison variants return 0.0f/1.0f via the implicit int->float conversion.
R"==(float binary_op(float src0, float src1) { )==" "\n"
R"==(#if IS_ADD )==" "\n"
R"==(return src0 + src1; )==" "\n"
R"==(#elif IS_MUL )==" "\n"
R"==(return src0 * src1; )==" "\n"
R"==(#elif IS_MAX )==" "\n"
R"==(return max(src0, src1); )==" "\n"
R"==(#elif IS_MIN )==" "\n"
R"==(return min(src0, src1); )==" "\n"
R"==(#elif IS_DIV )==" "\n"
R"==(return src0 / src1; )==" "\n"
R"==(#elif IS_SUB )==" "\n"
R"==(return src0 - src1; )==" "\n"
R"==(#elif IS_GE )==" "\n"
R"==(return src0 >= src1; )==" "\n"
R"==(#elif IS_GT )==" "\n"
R"==(return src0 > src1; )==" "\n"
R"==(#elif IS_LE )==" "\n"
R"==(return src0 <= src1; )==" "\n"
R"==(#elif IS_LT )==" "\n"
R"==(return src0 < src1; )==" "\n"
R"==(#elif IS_EQ )==" "\n"
R"==(return src0 == src1; )==" "\n"
R"==(#elif IS_NE )==" "\n"
R"==(return src0 != src1; )==" "\n"
R"==(#endif )==" "\n"
R"==(} )==" "\n"
// --- Variant 1 (fast path): both sources dense, same memory descriptor, no
// broadcasting and no binary post-op -> a flat 1-D index addresses all three
// tensors identically.
R"==(#if IS_TENSOR_OP && IS_DENSE && IS_SAME_MD && !WITH_BINARY_POST_OP )==" "\n"
R"==(KERNEL_ATTR )==" "\n"
R"==(__kernel void ref_binary(__global DATA_T *src0, __global DATA_T *src1, )==" "\n"
R"==(__global DST_DATA_T *dst POST_OP_ARGS, __global float *src0_scale, )==" "\n"
R"==(__global float *src1_scale) { )==" "\n"
R"==(int off = GWS_GET_IDX(); )==" "\n"
R"==(float tmp_src0 = CONVERT_FLOAT_T(src0[off]); )==" "\n"
R"==(float tmp_src1 = CONVERT_FLOAT_T(src1[off]); )==" "\n"
R"==(float d = 0; )==" "\n"
R"==(#if WITH_SRC0_SCALE )==" "\n"
R"==(tmp_src0 = tmp_src0 * (*src0_scale); )==" "\n"
R"==(#endif )==" "\n"
R"==(#if WITH_SRC1_SCALE )==" "\n"
R"==(tmp_src1 = tmp_src1 * (*src1_scale); )==" "\n"
R"==(#endif )==" "\n"
R"==(d = binary_op(tmp_src0, tmp_src1); )==" "\n"
R"==(float dst_data; )==" "\n"
R"==(#if WITH_SUM )==" "\n"
R"==(dst_data = CONVERT_FLOAT_T(dst[off]); )==" "\n"
R"==(#endif )==" "\n"
R"==(APPLY_POST_OPS_SERIAL( )==" "\n"
R"==(d, float, dst_data, float, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0); )==" "\n"
R"==(dst[off] = TO_DST(d); )==" "\n"
R"==(} )==" "\n"
// --- Variant 2 (general path): per-dimension offsets, optional broadcasting
// of either source, D1-blocked iteration, and zero-fill of padded regions.
R"==(#else )==" "\n"
R"==(KERNEL_ATTR )==" "\n"
R"==(__kernel void ref_binary(__global SRC0_DATA_T *src0, __global SRC1_DATA_T *src1, )==" "\n"
R"==(__global DST_DATA_T *dst POST_OP_ARGS, __global float *src0_scale, )==" "\n"
R"==(__global float *src1_scale) { )==" "\n"
R"==(int dims0[6] = {0}; )==" "\n"
R"==(dims0[0] = GWS_GET_D0(); )==" "\n"
R"==(dims0[1] = GWS_GET_D1(); )==" "\n"
R"==(dims0[2] = GWS_GET_D2(); )==" "\n"
R"==(dims0[3] = GWS_GET_D3(); )==" "\n"
R"==(dims0[4] = GWS_GET_D4(); )==" "\n"
R"==(dims0[5] = GWS_GET_D5(); )==" "\n"
R"==(int d1_block = GWS_GET_D1_BLOCK(); )==" "\n"
R"==(int dims0_po[6] )==" "\n"
R"==(= {dims0[0], dims0[1], dims0[2], dims0[3], dims0[4], dims0[5]}; )==" "\n"
R"==(int d1_init = GWS_GET_D1(); )==" "\n"
R"==(int dst_off = DST_OFF( )==" "\n"
R"==(dims0[0], dims0[1], dims0[2], dims0[3], dims0[4], dims0[5]); )==" "\n"
// NOTE(review): this tests TENSOR_OP while the fast-path guard above tests
// IS_TENSOR_OP — confirm against the .cl source whether these are meant to be
// the same macro. If TENSOR_OP is never defined, this always takes the #else
// branch, which is still correct (the !SRCx_BCAST_DIMn multipliers are no-ops
// when nothing is broadcast) but bypasses the simpler offset computation.
R"==(#if TENSOR_OP )==" "\n"
R"==(int src0_off = SRC0_OFF( )==" "\n"
R"==(dims0[0], dims0[1], dims0[2], dims0[3], dims0[4], dims0[5]); )==" "\n"
R"==(int src1_off = SRC1_OFF( )==" "\n"
R"==(dims0[0], dims0[1], dims0[2], dims0[3], dims0[4], dims0[5]); )==" "\n"
R"==(#else )==" "\n"
// Broadcast handling: a broadcast dimension's index is forced to 0 by
// multiplying with !SRCx_BCAST_DIMn (compile-time 0/1 flags).
R"==(int src0_off )==" "\n"
R"==(= SRC0_OFF(dims0[0] * !SRC0_BCAST_DIM0, dims0[1] * !SRC0_BCAST_DIM1, )==" "\n"
R"==(dims0[2] * !SRC0_BCAST_DIM2, dims0[3] * !SRC0_BCAST_DIM3, )==" "\n"
R"==(dims0[4] * !SRC0_BCAST_DIM4, dims0[5] * !SRC0_BCAST_DIM5); )==" "\n"
R"==(int src1_off )==" "\n"
R"==(= SRC1_OFF(dims0[0] * !SRC1_BCAST_DIM0, dims0[1] * !SRC1_BCAST_DIM1, )==" "\n"
R"==(dims0[2] * !SRC1_BCAST_DIM2, dims0[3] * !SRC1_BCAST_DIM3, )==" "\n"
R"==(dims0[4] * !SRC1_BCAST_DIM4, dims0[5] * !SRC1_BCAST_DIM5); )==" "\n"
R"==(#endif )==" "\n"
R"==(int block_size = d1_block; )==" "\n"
// Work-items that land entirely in dim-0 padding only zero-fill dst.
R"==(if (dims0[0] >= DST_D0) { )==" "\n"
R"==(for (int ic = 0; ic < block_size; ++ic) { )==" "\n"
R"==(dst[dst_off] = DATA_ZERO; )==" "\n"
R"==(dst_off++; )==" "\n"
R"==(} )==" "\n"
R"==(return; )==" "\n"
R"==(} )==" "\n"
// Full D1 block fits inside the real (unpadded) extent.
R"==(if (d1_init + block_size <= DST_D1) { )==" "\n"
R"==(for (int ic = 0; ic < block_size; ++ic) { )==" "\n"
R"==(float tmp_src0 = CONVERT_FLOAT_T(src0[src0_off]); )==" "\n"
R"==(float tmp_src1 = CONVERT_FLOAT_T(src1[src1_off]); )==" "\n"
R"==(float d = 0; )==" "\n"
R"==(#if WITH_SRC0_SCALE )==" "\n"
R"==(tmp_src0 = tmp_src0 * (*src0_scale); )==" "\n"
R"==(#endif )==" "\n"
R"==(#if WITH_SRC1_SCALE )==" "\n"
R"==(tmp_src1 = tmp_src1 * (*src1_scale); )==" "\n"
R"==(#endif )==" "\n"
R"==(d = binary_op(tmp_src0, tmp_src1); )==" "\n"
R"==(float dst_data; )==" "\n"
R"==(#if WITH_SUM )==" "\n"
R"==(dst_data = CONVERT_FLOAT_T(dst[dst_off]); )==" "\n"
R"==(#endif )==" "\n"
R"==(APPLY_POST_OPS_SERIAL(d, float, dst_data, float, dims0_po[0], 1, )==" "\n"
R"==(dims0_po[1], 1, dims0_po[2], 1, dims0_po[3], 1, dims0_po[4], )==" "\n"
R"==(1, dims0_po[5], 1); )==" "\n"
R"==(dst[dst_off] = TO_DST(d); )==" "\n"
// For 16B-blocked layouts, advance within the D1 block; src1 advances by 1
// (same blocking) or by its D1 stride (plain layout), and only when it is
// not broadcast along D1 (SRC1_D1 > 1).
R"==(#if USE_UNROLL_16B || SRC0_UNROLL_16B )==" "\n"
R"==(src0_off++; )==" "\n"
R"==(dst_off++; )==" "\n"
R"==(++dims0_po[1]; )==" "\n"
R"==(if (USE_UNROLL_16B && (SRC1_D1 > 1)) { )==" "\n"
R"==(src1_off++; )==" "\n"
R"==(} else if (SRC0_UNROLL_16B && (SRC1_D1 > 1)) { )==" "\n"
R"==(src1_off += SRC1_S1_0; )==" "\n"
R"==(} )==" "\n"
R"==(#endif )==" "\n"
R"==(} )==" "\n"
// Tail: D1 block straddles the real/padded boundary — compute only the valid
// part, then zero the padded remainder.
R"==(} else { )==" "\n"
R"==(for (int ic = 0; ic < DST_D1 - d1_init; ic++) { )==" "\n"
R"==(float tmp_src0 = CONVERT_FLOAT_T(src0[src0_off]); )==" "\n"
R"==(float tmp_src1 = CONVERT_FLOAT_T(src1[src1_off]); )==" "\n"
R"==(float d = 0; )==" "\n"
R"==(#if WITH_SRC0_SCALE )==" "\n"
R"==(tmp_src0 = tmp_src0 * (*src0_scale); )==" "\n"
R"==(#endif )==" "\n"
R"==(#if WITH_SRC1_SCALE )==" "\n"
R"==(tmp_src1 = tmp_src1 * (*src1_scale); )==" "\n"
R"==(#endif )==" "\n"
R"==(d = binary_op(tmp_src0, tmp_src1); )==" "\n"
R"==(float dst_data; )==" "\n"
R"==(#if WITH_SUM )==" "\n"
R"==(dst_data = CONVERT_FLOAT_T(dst[dst_off]); )==" "\n"
R"==(#endif )==" "\n"
R"==(APPLY_POST_OPS_SERIAL(d, float, dst_data, float, dims0_po[0], 1, )==" "\n"
R"==(dims0_po[1], 1, dims0_po[2], 1, dims0_po[3], 1, dims0_po[4], )==" "\n"
R"==(1, dims0_po[5], 1); )==" "\n"
R"==(dst[dst_off] = TO_DST(d); )==" "\n"
R"==(#if USE_UNROLL_16B || SRC0_UNROLL_16B )==" "\n"
R"==(src0_off++; )==" "\n"
R"==(dst_off++; )==" "\n"
R"==(++dims0_po[1]; )==" "\n"
R"==(if (USE_UNROLL_16B && (SRC1_D1 > 1)) { )==" "\n"
R"==(src1_off++; )==" "\n"
R"==(} else if (SRC0_UNROLL_16B && (SRC1_D1 > 1)) { )==" "\n"
R"==(src1_off += SRC1_S1_0; )==" "\n"
R"==(} )==" "\n"
R"==(#endif )==" "\n"
R"==(} )==" "\n"
R"==(#if DST_D1 != DST_PD1 )==" "\n"
R"==(for (int ic = 0; ic < min(DST_PD1 - DST_D1, block_size); ic++) { )==" "\n"
R"==(dst[dst_off] = DATA_ZERO; )==" "\n"
R"==(dst_off++; )==" "\n"
R"==(} )==" "\n"
R"==(#endif )==" "\n"
R"==(} )==" "\n"
R"==(} )==" "\n"
R"==(#endif )==" "\n"
R"==()==" ;
} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl