// Embedded OpenCL C source for the gen9 element-wise (eltwise) primitive
// kernels. NOTE(review): this file appears to be auto-generated from a .cl
// source (one raw-string literal per original line, concatenated with "\n");
// prefer editing the original OpenCL file and regenerating rather than
// patching these string literals by hand.
1 | namespace dnnl { |
2 | namespace impl { |
3 | namespace gpu { |
4 | namespace ocl { |
// Null-terminated OpenCL C program text, handed to the OpenCL runtime for
// JIT compilation at primitive-creation time. Macros such as DATA_T,
// VECT_DT_N, SIMD, NELEMS_OVERFLOW and KERNEL_ATTR are injected via build
// options by the host code (not visible in this file).
5 | const char *gen9_eltwise_kernel = R"==(/******************************************************************************* )==" "\n" |
6 | R"==(* Copyright 2020-2022 Intel Corporation )==" "\n" |
7 | R"==(* )==" "\n" |
8 | R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==" "\n" |
9 | R"==(* you may not use this file except in compliance with the License. )==" "\n" |
10 | R"==(* You may obtain a copy of the License at )==" "\n" |
11 | R"==(* )==" "\n" |
// NOTE(review): the Apache license URL was truncated after "http:" —
// presumably the "//..." remainder was stripped as a line comment by the
// .cl -> .cpp generation step. Harmless at runtime; fix in the generator
// if at all.
12 | R"==(* http: )==" "\n" |
13 | R"==(* )==" "\n" |
14 | R"==(* Unless required by applicable law or agreed to in writing, software )==" "\n" |
15 | R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==" "\n" |
16 | R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==" "\n" |
17 | R"==(* See the License for the specific language governing permissions and )==" "\n" |
18 | R"==(* limitations under the License. )==" "\n" |
19 | R"==(*******************************************************************************/ )==" "\n" |
20 | R"==(#include "gpu/ocl/ocl_eltwise.h" )==" "\n" |
21 | R"==(#include "gpu/ocl/ocl_post_ops.h" )==" "\n" |
22 | R"==(#include "gpu/ocl/ocl_types.h" )==" "\n" |
23 | R"==(#define SIMD GWS_SGS_DEFAULT )==" "\n" |
// Forward kernel: applies fwd_eltwise(src[i], alpha, beta) element-wise
// into dst. Each work-item handles VECT_DT_N elements; a sub-group covers
// SIMD * VECT_DT_N contiguous elements per read/write.
24 | R"==(KERNEL_ATTR )==" "\n" |
25 | R"==(__kernel void gen9_eltwise_fwd(__global DATA_T *src, __global DATA_T *dst, )==" "\n" |
26 | R"==(int nelems, float alpha, float beta) { )==" "\n" |
27 | R"==(const uint grsize = get_local_size(0); )==" "\n" |
28 | R"==(const uint grid = get_group_id(0); )==" "\n" |
29 | R"==(const uint sgid = get_sub_group_id(); )==" "\n" |
30 | R"==(const uint lid = get_sub_group_local_id(); )==" "\n" |
// NOTE(review): `gid` below is computed but never used in this kernel.
31 | R"==(const uint gid = get_global_id(0); )==" "\n" |
// Element offset of this sub-group's chunk. Uses get_max_sub_group_size()
// here while the bwd kernel uses SIMD — presumably equal (SIMD is
// GWS_SGS_DEFAULT); confirm against the host-side build options.
32 | R"==(ptrdiff_t offset )==" "\n" |
33 | R"==(= (grid * grsize + sgid * get_max_sub_group_size()) * VECT_DT_N; )==" "\n" |
34 | R"==(__global BLOCK_DATA_T *read_pos = (__global BLOCK_DATA_T *)src + offset; )==" "\n" |
35 | R"==(__global BLOCK_DATA_T *write_pos = (__global BLOCK_DATA_T *)dst + offset; )==" "\n" |
36 | R"==(VECT_DATA_T val; )==" "\n" |
37 | R"==(const uint nel_per_read = SIMD * VECT_DT_N; )==" "\n" |
// Fast path: a full vector block read when the whole chunk is strictly
// in-bounds (or overflow is statically impossible). Otherwise fall back to
// a per-element tail loop guarded by `pos < nelems`; lanes past the end
// leave val[i] uninitialized, which is fine since they are never written
// back (the same guard is applied on the store side below).
38 | R"==(if (!NELEMS_OVERFLOW || offset + nel_per_read < nelems) { )==" "\n" |
39 | R"==(val = AS_VECT_DATA_T(VECT_BLOCK_READ(read_pos)); )==" "\n" |
40 | R"==(} else { )==" "\n" |
41 | R"==(uint pos = offset + lid; )==" "\n" |
42 | R"==(for (int i = 0; i < VECT_DT_N && pos < nelems; ++i) { )==" "\n" |
43 | R"==(val[i] = src[pos]; )==" "\n" |
44 | R"==(pos += SIMD; )==" "\n" |
45 | R"==(} )==" "\n" |
46 | R"==(} )==" "\n" |
// Apply the forward eltwise op in the reference (float) domain, then
// convert back to the storage type. The trailing 1.0f is the scale.
47 | R"==(for (int i = 0; i < VECT_DT_N; ++i) { )==" "\n" |
48 | R"==(val[i] = CONVERT_DATA_T( )==" "\n" |
49 | R"==(fwd_eltwise(DATA_TO_REF(val[i]), alpha, beta, 1.0f)); )==" "\n" |
50 | R"==(} )==" "\n" |
// Mirror of the load: block write on the fast path, guarded scalar stores
// on the tail.
51 | R"==(if (!NELEMS_OVERFLOW || offset + nel_per_read < nelems) { )==" "\n" |
52 | R"==(VECT_BLOCK_WRITE(write_pos, AS_VECT_BLOCK_DATA_T(val)); )==" "\n" |
53 | R"==(} else { )==" "\n" |
54 | R"==(uint pos = offset + lid; )==" "\n" |
55 | R"==(for (int i = 0; i < VECT_DT_N && pos < nelems; ++i) { )==" "\n" |
56 | R"==(dst[pos] = val[i]; )==" "\n" |
57 | R"==(pos += SIMD; )==" "\n" |
58 | R"==(} )==" "\n" |
59 | R"==(} )==" "\n" |
60 | R"==(} )==" "\n" |
// Backward kernel: diff_src = bwd_eltwise(diff_dst, src, alpha, beta)
// element-wise. Same vectorized fast path / guarded tail structure as the
// forward kernel, but reads two inputs (diff_dst and the original src).
61 | R"==(KERNEL_ATTR )==" "\n" |
62 | R"==(__kernel void gen9_eltwise_bwd(__global DATA_T *src, __global DATA_T *diff_src, )==" "\n" |
63 | R"==(__global DATA_T *diff_dst, int nelems, float alpha, float beta) { )==" "\n" |
64 | R"==(const uint grsize = get_local_size(0); )==" "\n" |
65 | R"==(const uint grid = get_group_id(0); )==" "\n" |
66 | R"==(const uint sgid = get_sub_group_id(); )==" "\n" |
67 | R"==(const uint lid = get_sub_group_local_id(); )==" "\n" |
68 | R"==(ptrdiff_t offset = (grid * grsize + sgid * SIMD) * VECT_DT_N; )==" "\n" |
69 | R"==(__global BLOCK_DATA_T *src_pos = (__global BLOCK_DATA_T *)src + offset; )==" "\n" |
70 | R"==(__global BLOCK_DATA_T *diff_pos )==" "\n" |
71 | R"==(= (__global BLOCK_DATA_T *)diff_dst + offset; )==" "\n" |
72 | R"==(__global BLOCK_DATA_T *write_pos )==" "\n" |
73 | R"==(= (__global BLOCK_DATA_T *)diff_src + offset; )==" "\n" |
74 | R"==(VECT_DATA_T val_dd; )==" "\n" |
75 | R"==(VECT_DATA_T val_src; )==" "\n" |
76 | R"==(const uint nel_per_read = SIMD * VECT_DT_N; )==" "\n" |
77 | R"==(if (!NELEMS_OVERFLOW || offset + nel_per_read < nelems) { )==" "\n" |
78 | R"==(val_src = AS_VECT_DATA_T(VECT_BLOCK_READ(src_pos)); )==" "\n" |
79 | R"==(val_dd = AS_VECT_DATA_T(VECT_BLOCK_READ(diff_pos)); )==" "\n" |
80 | R"==(} else { )==" "\n" |
81 | R"==(uint pos = offset + lid; )==" "\n" |
82 | R"==(for (int i = 0; i < VECT_DT_N && pos < nelems; ++i) { )==" "\n" |
83 | R"==(val_dd[i] = diff_dst[pos]; )==" "\n" |
84 | R"==(val_src[i] = src[pos]; )==" "\n" |
85 | R"==(pos += SIMD; )==" "\n" |
86 | R"==(} )==" "\n" |
87 | R"==(} )==" "\n" |
// Gradient is computed in the reference (float) domain, result overwrites
// val_dd in the storage type.
88 | R"==(for (int i = 0; i < VECT_DT_N; ++i) { )==" "\n" |
89 | R"==(val_dd[i] = CONVERT_DATA_T(bwd_eltwise( )==" "\n" |
90 | R"==(DATA_TO_REF(val_dd[i]), DATA_TO_REF(val_src[i]), alpha, beta)); )==" "\n" |
91 | R"==(} )==" "\n" |
92 | R"==(if (!NELEMS_OVERFLOW || offset + nel_per_read < nelems) { )==" "\n" |
93 | R"==(VECT_BLOCK_WRITE(write_pos, AS_VECT_BLOCK_DATA_T(val_dd)); )==" "\n" |
94 | R"==(} else { )==" "\n" |
95 | R"==(uint pos = offset + lid; )==" "\n" |
96 | R"==(for (int i = 0; i < VECT_DT_N && pos < nelems; ++i) { )==" "\n" |
97 | R"==(diff_src[pos] = val_dd[i]; )==" "\n" |
98 | R"==(pos += SIMD; )==" "\n" |
99 | R"==(} )==" "\n" |
100 | R"==(} )==" "\n" |
101 | R"==(} )==" "\n" |
102 | R"==()==" ; |
103 | } |
104 | } |
105 | } |
106 | } |