namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {
// OpenCL C source text of the `gemm_post_ops_inner_product` kernel, embedded
// as a sequence of adjacent raw-string pieces that the compiler concatenates
// (one piece per kernel source line, each followed by a "\n" piece).
//
// NOTE(review): the layout (one raw string per line, indentation removed,
// a comment tail apparently stripped from the license URL) suggests this
// text is generated from a standalone .cl source -- presumably
// gemm_post_ops_inner_product.cl. Confirm, and make changes there rather
// than here.
//
// What the kernel text below does (one work-item per output element, with
// gid = get_global_id(0)):
//   - mb = gid / OC, oc = gid % OC, data_idx = mb * OC + oc (equal to gid).
//   - Loads the value from `scratchpad` when USE_TEMP_DST == 1, otherwise
//     from `src`, converting it with SRC_TO_ACC (identity for SRC_DT_F32,
//     TO_DEF_ACC_DATA_T otherwise).
//   - Adds bias[oc] when WITH_BIAS == 1 (converted with BIAS_TO_ACC).
//   - When WITH_SCALES, multiplies by scales[0] (SCALES_COMMON) or
//     scales[oc] (SCALES_PER_OC). Any other mode triggers a compile-time
//     #error.
//   - When WITH_SUM, reads the existing dst[data_idx] as the sum operand.
//   - Runs the post-op chain through APPLY_POST_OPS_SERIAL_BINARY_2D on a
//     float accumulator, then stores TO_DST(accumulator) into dst[data_idx].
// OC, USE_TEMP_DST, WITH_BIAS, WITH_SCALES, WITH_SUM, POST_OP_ARGS and the
// *_DATA_T types are not defined in this text. Presumably they come from the
// included headers or from kernel build options set by the host code --
// verify against the caller.
const char *gemm_post_ops_inner_product_kernel = R"==(/******************************************************************************* )==" "\n"
R"==(* Copyright 2019-2020 Intel Corporation )==" "\n"
R"==(* )==" "\n"
R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==" "\n"
R"==(* you may not use this file except in compliance with the License. )==" "\n"
R"==(* You may obtain a copy of the License at )==" "\n"
R"==(* )==" "\n"
// NOTE(review): the license URL below appears truncated at "http:" (likely a
// "//..." tail stripped when this text was produced). It sits inside a kernel
// comment, so it does not affect the compiled kernel.
R"==(* http: )==" "\n"
R"==(* )==" "\n"
R"==(* Unless required by applicable law or agreed to in writing, software )==" "\n"
R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==" "\n"
R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==" "\n"
R"==(* See the License for the specific language governing permissions and )==" "\n"
R"==(* limitations under the License. )==" "\n"
R"==(*******************************************************************************/ )==" "\n"
R"==(#include "gpu/ocl/ocl_post_ops.h" )==" "\n"
R"==(#include "gpu/ocl/ocl_types.h" )==" "\n"
// Conversion helpers: identity for f32 data, TO_DEF_ACC_DATA_T otherwise.
R"==(#ifdef DST_DT_F32 )==" "\n"
R"==(#define DST_TO_ACC(x) (x) )==" "\n"
R"==(#else )==" "\n"
R"==(#define DST_TO_ACC(x) TO_DEF_ACC_DATA_T(x) )==" "\n"
R"==(#endif )==" "\n"
R"==(#ifdef BIAS_DT_F32 )==" "\n"
R"==(#define BIAS_TO_ACC(x) (x) )==" "\n"
R"==(#else )==" "\n"
R"==(#define BIAS_TO_ACC(x) TO_DEF_ACC_DATA_T(x) )==" "\n"
R"==(#endif )==" "\n"
R"==(#ifdef SRC_DT_F32 )==" "\n"
R"==(#define SRC_TO_ACC(x) (x) )==" "\n"
R"==(#else )==" "\n"
R"==(#define SRC_TO_ACC(x) TO_DEF_ACC_DATA_T(x) )==" "\n"
R"==(#endif )==" "\n"
// Kernel entry point; POST_OP_ARGS is expanded into extra parameters by the
// post-ops header/build options (not visible here).
R"==(__kernel void gemm_post_ops_inner_product(__global SRC_DATA_T *src, )==" "\n"
R"==(__global BIAS_DATA_T *bias, __global DST_DATA_T *dst POST_OP_ARGS, )==" "\n"
R"==(__global SPAD_DATA_T *scratchpad, global float *scales) { )==" "\n"
R"==(const size_t mb = get_global_id(0) / OC; )==" "\n"
R"==(const size_t oc = get_global_id(0) % OC; )==" "\n"
R"==(const size_t data_idx = mb * OC + oc; )==" "\n"
R"==(#if USE_TEMP_DST == 1 )==" "\n"
R"==(ACC_DATA_T acc = SRC_TO_ACC(scratchpad[data_idx]); )==" "\n"
R"==(#else )==" "\n"
R"==(ACC_DATA_T acc = SRC_TO_ACC(src[data_idx]); )==" "\n"
R"==(#endif )==" "\n"
R"==(#if WITH_BIAS == 1 )==" "\n"
R"==(acc += BIAS_TO_ACC(bias[oc]); )==" "\n"
R"==(#endif )==" "\n"
R"==(#if WITH_SCALES )==" "\n"
R"==(#if SCALES_COMMON )==" "\n"
R"==(const float scale = scales[0]; )==" "\n"
R"==(#elif SCALES_PER_OC )==" "\n"
R"==(const float scale = scales[oc]; )==" "\n"
R"==(#else )==" "\n"
R"==(#error "Unsupported scale type" )==" "\n"
R"==(#endif )==" "\n"
R"==(acc *= scale; )==" "\n"
R"==(#endif )==" "\n"
// NOTE(review): sum_src stays uninitialized unless WITH_SUM is set. This
// assumes APPLY_POST_OPS_SERIAL_BINARY_2D only reads it when a sum post-op
// is present -- confirm in gpu/ocl/ocl_post_ops.h.
R"==(float sum_src; )==" "\n"
R"==(#if WITH_SUM )==" "\n"
R"==(sum_src = DST_TO_ACC(dst[data_idx]); )==" "\n"
R"==(#endif )==" "\n"
R"==(float accumulator = acc; )==" "\n"
R"==(APPLY_POST_OPS_SERIAL_BINARY_2D( )==" "\n"
R"==(accumulator, float, sum_src, float, mb, 1, oc, 1); )==" "\n"
R"==(dst[data_idx] = TO_DST(accumulator); )==" "\n"
R"==(} )==" "\n"
R"==()==" ;
} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl