namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {
// OpenCL C source text of the combined_reduce kernel, held as one C string.
// The string is assembled from adjacent raw string literals; every piece is
// followed by an explicit newline literal so the kernel keeps its line
// structure (the preprocessor directives inside depend on that).
// NOTE(review): this looks like a build-time copy of a .cl file - if so, any
// change made here should be mirrored in that kernel source. TODO confirm.
const char *combined_reduction_kernel = R"==(/******************************************************************************* )==""\n"
R"==(* Copyright 2021-2022 Intel Corporation )==""\n"
R"==(* )==""\n"
R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==""\n"
R"==(* you may not use this file except in compliance with the License. )==""\n"
R"==(* You may obtain a copy of the License at )==""\n"
R"==(* )==""\n"
R"==(* http://www.apache.org/licenses/LICENSE-2.0 )==""\n"
R"==(* )==""\n"
R"==(* Unless required by applicable law or agreed to in writing, software )==""\n"
R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==""\n"
R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==""\n"
R"==(* See the License for the specific language governing permissions and )==""\n"
R"==(* limitations under the License. )==""\n"
R"==(*******************************************************************************/ )==""\n"
// Kernel-side includes. Names such as TO_DEF_ACC_DATA_T, DATA_MIN, BLOCK_READ
// and KERNEL_ATTR are not defined in this string; presumably they come from
// these headers or from build options - TODO confirm.
R"==(#include "gpu/ocl/ocl_post_ops.h" )==""\n"
R"==(#include "gpu/ocl/ocl_types.h" )==""\n"
// INIT_ACC: starting accumulator value per algorithm - DATA_MIN for max,
// DATA_MAX for min, DATA_ONE for mul, DATA_ZERO for everything else.
R"==(#if defined(IS_MAX) )==""\n"
R"==(#define INIT_ACC TO_DEF_ACC_DATA_T(DATA_MIN) )==""\n"
R"==(#elif defined(IS_MIN) )==""\n"
R"==(#define INIT_ACC TO_DEF_ACC_DATA_T(DATA_MAX) )==""\n"
R"==(#elif defined(IS_MUL) )==""\n"
R"==(#define INIT_ACC TO_DEF_ACC_DATA_T(DATA_ONE) )==""\n"
R"==(#else )==""\n"
R"==(#define INIT_ACC TO_DEF_ACC_DATA_T(DATA_ZERO) )==""\n"
R"==(#endif )==""\n"
// MAX_FUNC / MIN_FUNC: max/min for s8, u8 and s32 sources, fmax/fmin
// for every other source type.
R"==(#if defined(SRC_DT_S8) || defined(SRC_DT_U8) || defined(SRC_DT_S32) )==""\n"
R"==(#define MAX_FUNC max )==""\n"
R"==(#define MIN_FUNC min )==""\n"
R"==(#else )==""\n"
R"==(#define MAX_FUNC fmax )==""\n"
R"==(#define MIN_FUNC fmin )==""\n"
R"==(#endif )==""\n"
// ACCUMULATE_INITIAL(x, y): folds one source value y into accumulator x.
// The fallback branch (none of max/min/mean/sum/mul) adds pow(fabs(y), POWER).
R"==(#if defined(IS_MAX) )==""\n"
R"==(#define ACCUMULATE_INITIAL(x, y) MAX_FUNC(x, y) )==""\n"
R"==(#elif defined(IS_MIN) )==""\n"
R"==(#define ACCUMULATE_INITIAL(x, y) MIN_FUNC(x, y) )==""\n"
R"==(#elif defined(IS_MEAN) || defined(IS_SUM) )==""\n"
R"==(#define ACCUMULATE_INITIAL(x, y) (x + y) )==""\n"
R"==(#elif defined(IS_MUL) )==""\n"
R"==(#define ACCUMULATE_INITIAL(x, y) (x * y) )==""\n"
R"==(#else )==""\n"
R"==(#define ACCUMULATE_INITIAL(x, y) (x + pow(fabs(y), POWER)) )==""\n"
R"==(#endif )==""\n"
// ACCUMULATE_FURTHER(x, y): merges two partial accumulators. It reuses the
// initial rule for max/min/mean/mul and is a plain addition otherwise, so
// already-powered partial sums are not raised to POWER a second time.
R"==(#if defined(IS_MAX) || defined(IS_MIN) || defined(IS_MEAN) || defined(IS_MUL) )==""\n"
R"==(#define ACCUMULATE_FURTHER ACCUMULATE_INITIAL )==""\n"
R"==(#else )==""\n"
R"==(#define ACCUMULATE_FURTHER(x, y) (x + y) )==""\n"
R"==(#endif )==""\n"
// ACCUMULATE: the initial rule when IS_FIRST is non-zero, the further rule
// otherwise.
R"==(#if IS_FIRST )==""\n"
R"==(#define ACCUMULATE ACCUMULATE_INITIAL )==""\n"
R"==(#else )==""\n"
R"==(#define ACCUMULATE ACCUMULATE_FURTHER )==""\n"
R"==(#endif )==""\n"
// FINALIZE(x): post-processing used on the IS_FINAL store path - divide by
// DIV for mean, rootn and/or EPS adjustment for the LP and P variants,
// identity for the remaining algorithms.
R"==(#if defined(IS_MEAN) )==""\n"
R"==(#define FINALIZE(x) (x / DIV) )==""\n"
R"==(#elif defined(IS_LP_MAX) )==""\n"
R"==(#define FINALIZE(x) rootn(fmax(x, EPS), POWER) )==""\n"
R"==(#elif defined(IS_LP_SUM) )==""\n"
R"==(#define FINALIZE(x) rootn(x + EPS, POWER) )==""\n"
R"==(#elif defined(IS_P_MAX) )==""\n"
R"==(#define FINALIZE(x) fmax(x, EPS) )==""\n"
R"==(#elif defined(IS_P_SUM) )==""\n"
R"==(#define FINALIZE(x) (x + EPS) )==""\n"
R"==(#else )==""\n"
R"==(#define FINALIZE(x) (x) )==""\n"
R"==(#endif )==""\n"
// Linear offsets. src is addressed as [outer][reduction][inner] with
// REDUCTION_SIZE x INNER_DIM_SIZE elements per outer index; dst is addressed
// as [outer][reduction_chunk][inner] with REDUCTION_END_SIZE chunks per
// outer index. Every macro argument is parenthesized before use.
R"==(#define _SRC_OFF(outer, reduction, inner) \ )==""\n"
R"==((outer) * REDUCTION_SIZE *INNER_DIM_SIZE + (reduction)*INNER_DIM_SIZE \ )==""\n"
R"==(+ (inner) )==""\n"
R"==(#define _DST_OFF(outer, reduction_chunk, inner) \ )==""\n"
R"==((outer) * INNER_DIM_SIZE *REDUCTION_END_SIZE \ )==""\n"
R"==(+ (reduction_chunk)*INNER_DIM_SIZE + (inner) )==""\n"
// Kernel entry point: reads from src, writes one value per surviving lane
// to dst.
R"==(KERNEL_ATTR )==""\n"
R"==(__kernel void combined_reduce( )==""\n"
R"==(__global SRC_DATA_T *src, __global DST_DATA_T *dst) { )==""\n"
// Split the 1D global id into an outer index, a reduction chunk and a
// position inside the padded inner dimension.
R"==(const int outer_idx = (get_global_id(0) / OUTER_DIM_STRIDE); )==""\n"
R"==(const int reduction_chunk )==""\n"
R"==(= (get_global_id(0) / PADDED_INNER_DIM_SIZE) % REDUCTION_CHUNK_SIZE; )==""\n"
R"==(const int inner = (get_global_id(0) % PADDED_INNER_DIM_SIZE); )==""\n"
R"==(if (outer_idx >= OUTER_DIM_SIZE) return; )==""\n"
// The padded inner position is split again: inner_idx is the real inner
// coordinate, red_off is an extra reduction offset carried by this lane.
R"==(const int inner_idx = inner % INNER_DIM_SIZE; )==""\n"
R"==(const int red_off = inner / INNER_DIM_SIZE; )==""\n"
R"==(const int reduction_idx )==""\n"
R"==(= reduction_chunk * REDUCTIONS_PER_WI * INNER_DIMS_PER_WI + red_off; )==""\n"
// The *_start values strip the per-lane part (red_off, sub-group local id);
// they are only used by the block-read path below.
R"==(const int reduction_idx_start = reduction_idx - red_off; )==""\n"
R"==(const int inner_idx_start = inner - get_sub_group_local_id(); )==""\n"
// Lanes at or beyond INNER_DIM_SIZE * INNER_DIMS_PER_WI in the padded inner
// range have no data to process.
R"==(if (inner / INNER_DIMS_PER_WI >= INNER_DIM_SIZE) { return; } )==""\n"
R"==(const int dst_off )==""\n"
R"==(= _DST_OFF(outer_idx, reduction_chunk + red_off, inner_idx); )==""\n"
// Per-work-item accumulation: at most REDUCTIONS_PER_WI source values,
// advancing INNER_DIMS_PER_WI reduction indices per iteration and stopping
// once the reduction index reaches REDUCTION_SIZE.
R"==(DEF_ACC_DATA_T acc = INIT_ACC; )==""\n"
R"==(for (int off = 0; off < REDUCTIONS_PER_WI )==""\n"
R"==(&& off * INNER_DIMS_PER_WI + reduction_idx < REDUCTION_SIZE; )==""\n"
R"==(off++) { )==""\n"
// Block-read path: the offset is built from the *_start values and the
// value is fetched through BLOCK_READ; otherwise each lane loads its own
// element directly.
R"==(#if WITH_BLOCK_READ )==""\n"
R"==(const int src_off = _SRC_OFF(outer_idx, )==""\n"
R"==(off * INNER_DIMS_PER_WI + reduction_idx_start, inner_idx_start); )==""\n"
R"==(const SRC_DATA_T src_val = AS_DATA_T( )==""\n"
R"==(BLOCK_READ((const __global BLOCK_DATA_T *)&src[src_off])); )==""\n"
R"==(#else )==""\n"
R"==(const int src_off = _SRC_OFF( )==""\n"
R"==(outer_idx, off * INNER_DIMS_PER_WI + reduction_idx, inner_idx); )==""\n"
R"==(const SRC_DATA_T src_val = src[src_off]; )==""\n"
R"==(#endif )==""\n"
R"==(acc = ACCUMULATE(acc, TO_DEF_ACC_DATA_T(src_val)); )==""\n"
R"==(} )==""\n"
// Fold partial results across the sub-group. Each of the
// INNER_DIMS_PER_WI - 1 passes shuffles accumulators down by INNER_DIM_SIZE
// lanes, with INIT_ACC passed as the second operand of the shuffle. Lanes
// below INNER_DIM_SIZE merge the received value; the other lanes just
// forward it so it can travel further down on the next pass.
R"==(for (int i = 1; i < INNER_DIMS_PER_WI; i++) { )==""\n"
R"==(const DEF_ACC_DATA_T other )==""\n"
R"==(= intel_sub_group_shuffle_down(acc, INIT_ACC, INNER_DIM_SIZE); )==""\n"
R"==(if (get_sub_group_local_id() < INNER_DIM_SIZE) { )==""\n"
R"==(acc = ACCUMULATE_FURTHER(acc, other); )==""\n"
R"==(} else { )==""\n"
R"==(acc = other; )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
// Only lanes with a sub-group local id below INNER_DIM_SIZE store a result.
// With IS_FINAL the accumulator is converted to float, passed through
// FINALIZE and converted with TO_DST; otherwise the raw accumulator is
// stored.
R"==(if (get_sub_group_local_id() < INNER_DIM_SIZE) { )==""\n"
R"==(#if IS_FINAL )==""\n"
R"==(float res = convert_float(acc); )==""\n"
R"==(res = FINALIZE(res); )==""\n"
R"==(dst[dst_off] = TO_DST(res); )==""\n"
R"==(#else )==""\n"
R"==(dst[dst_off] = acc; )==""\n"
R"==(#endif )==""\n"
R"==(} )==""\n"
R"==(} )==""\n"
R"==()==";
}
}
}
}