namespace dnnl {
namespace impl {
namespace gpu {
namespace ocl {
// Auto-generated embedding of the OpenCL kernel header gpu/ocl/rnn/rnn_types.h
// as a single concatenated raw-string literal. The payload below is runtime
// kernel source and must be preserved byte-for-byte (the bare "http:" in the
// license banner is an artifact of the generator stripping //-style comments).
//
// FIX: the original line read `const char * = R"==(...` — the identifier was
// missing, which is a syntax error and leaves the constant unnameable.
// NOTE(review): `rnn_types_kernel` is an assumed name — confirm it matches the
// corresponding `extern const char *` declaration in the kernel list header.
const char *rnn_types_kernel = R"==(/******************************************************************************* )==" "\n"
R"==(* Copyright 2019-2022 Intel Corporation )==" "\n"
R"==(* )==" "\n"
R"==(* Licensed under the Apache License, Version 2.0 (the "License"); )==" "\n"
R"==(* you may not use this file except in compliance with the License. )==" "\n"
R"==(* You may obtain a copy of the License at )==" "\n"
R"==(* )==" "\n"
R"==(* http: )==" "\n"
R"==(* )==" "\n"
R"==(* Unless required by applicable law or agreed to in writing, software )==" "\n"
R"==(* distributed under the License is distributed on an "AS IS" BASIS, )==" "\n"
R"==(* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. )==" "\n"
R"==(* See the License for the specific language governing permissions and )==" "\n"
R"==(* limitations under the License. )==" "\n"
R"==(*******************************************************************************/ )==" "\n"
R"==(#ifndef GPU_OCL_RNN_RNN_TYPES_H )==" "\n"
R"==(#define GPU_OCL_RNN_RNN_TYPES_H )==" "\n"
R"==(#include "gpu/ocl/ocl_types.h" )==" "\n"
R"==(#if OUTPUT_DT_U8 )==" "\n"
R"==(#define TO_OUTPUT(x) convert_uchar_sat_rte(x) )==" "\n"
R"==(#elif OUTPUT_DT_S8 )==" "\n"
R"==(#define TO_OUTPUT(x) convert_char_sat_rte(x) )==" "\n"
R"==(#elif OUTPUT_DT_S32 )==" "\n"
R"==(#define TO_OUTPUT(x) convert_int_sat_rte(x) )==" "\n"
R"==(#else )==" "\n"
R"==(#define TO_OUTPUT(x) (x) )==" "\n"
R"==(#endif )==" "\n"
R"==(#if INPUT_DT_BF16 )==" "\n"
R"==(#define TO_INPUT(x) cvt_f32_to_bf16(x) )==" "\n"
R"==(#define TO_REF(x) cvt_bf16_to_f32(x) )==" "\n"
R"==(#else )==" "\n"
R"==(#define TO_INPUT(x) (x) )==" "\n"
R"==(#define TO_REF(x) (float)(x) )==" "\n"
R"==(#endif )==" "\n"
R"==(#if DT_F16 && !IS_FWD )==" "\n"
R"==(#error "FP16 is not supported for BWD" )==" "\n"
R"==(#endif )==" "\n"
R"==(#define OFFTYPE ulong )==" "\n"
R"==(#define TO_WS_STATE(x) TO_SRC(x) )==" "\n"
R"==(#define OFF6(i0, D0, i1, D1, i2, D2, i3, D3, i4, D4, i5, D5) \ )==" "\n"
R"==(((((((i0) * (D1) + (i1)) * (D2) + (i2)) * (D3) + (i3)) * (D4) + (i4)) \ )==" "\n"
R"==(* (D5) \ )==" "\n"
R"==(+ (i5)) )==" "\n"
R"==(#define OFF5(i0, D0, i1, D1, i2, D2, i3, D3, i4, D4) \ )==" "\n"
R"==((((((i0) * (D1) + (i1)) * (D2) + (i2)) * (D3) + (i3)) * (D4) + (i4)) )==" "\n"
R"==(#define OFF4(i0, D0, i1, D1, i2, D2, i3, D3) \ )==" "\n"
R"==(((((i0) * (D1) + (i1)) * (D2) + (i2)) * (D3) + (i3)) )==" "\n"
R"==(#define OFF3(i0, D0, i1, D1, i2, D2) (((i0) * (D1) + (i1)) * (D2) + (i2)) )==" "\n"
R"==(#define OFF2(i0, D0, i1, D1) ((i0) * (D1) + (i1)) )==" "\n"
R"==(#define OFF_WS_STATE(i0, i1, i2, i3, i4) \ )==" "\n"
R"==(OFF5((i0), N_LAYER + 1, (i1), N_DIR, (i2), N_ITER + 1, (i3), BATCH, (i4), \ )==" "\n"
R"==(STATES_WS_LD) )==" "\n"
R"==(#define OFF_SCRATCH_DIFF_STATES(i0, i1, i2, i3, i4, i5) \ )==" "\n"
R"==(OFF6((i0), N_LAYER + 1, (i1), N_DIR, (i2), N_STATES + 1, (i3), N_ITER + 1, \ )==" "\n"
R"==((i4), BATCH, (i5), SCRATCH_DIFF_STATES_LD) )==" "\n"
R"==(#define OFF_WS_GATES(i0, i1, i2, i3, i4, i5) \ )==" "\n"
R"==((i0) * N_DIR *N_ITER *BATCH *GATES_WS_LD + (i1)*N_ITER *BATCH *GATES_WS_LD \ )==" "\n"
R"==(+ (i2)*BATCH *GATES_WS_LD + (i3)*GATES_WS_LD + (i4)*DHC + (i5) )==" "\n"
R"==(#define OFF_WS_GRID_OFFSET(i0, i1, i2, i3, i4) \ )==" "\n"
R"==(OFF5((i0), N_LAYER + 1, (i1), N_DIR, (i2), N_ITER + 1, (i3), BATCH, (i4), \ )==" "\n"
R"==(DHC) )==" "\n"
R"==(#if N_ITER_SCRATCH_GATES == 1 )==" "\n"
R"==(#define OFF_SCRATCH_MEM(i0, i1, i2, i3) \ )==" "\n"
R"==((i1) * SCRATCH_GATES_LD + (i2)*DHC + (i3) )==" "\n"
R"==(#else )==" "\n"
R"==(#define OFF_SCRATCH_MEM(i0, i1, i2, i3) \ )==" "\n"
R"==((i0) * BATCH *SCRATCH_GATES_LD + (i1)*SCRATCH_GATES_LD + (i2)*DHC + (i3) )==" "\n"
R"==(#endif )==" "\n"
R"==(#define OFF_WS_BIAS(i0, i1, i2, i3) \ )==" "\n"
R"==(OFF4((i0), N_LAYER, (i1), N_DIR, (i2), N_BIAS, (i3), DHC) )==" "\n"
R"==(#define CELL_WS_GATES(i3, i4, i5) OFF_WS_GATES(0, 0, 0, i3, i4, i5) )==" "\n"
R"==(#define CELL_WS_STATE(i4, i5) OFF_WS_STATE(0, 0, 0, i4, i5) )==" "\n"
R"==(#define CELL_SCRATCH_MEM(i1, i2, i3) OFF_SCRATCH_MEM(0, i1, i2, i3) )==" "\n"
R"==(#define CELL_SCRATCH_DIFF_STATES(i2, i4, i5) \ )==" "\n"
R"==(OFF_SCRATCH_DIFF_STATES(0, 0, i2, 0, i4, i5) )==" "\n"
R"==(#define CELL_WS_GRID_COMP(i3, i4) OFF_WS_GRID_OFFSET(0, 0, 0, i3, i4) )==" "\n"
R"==(#define OFF_KER_BIAS(i0, i1) OFF2((i0), N_GATES, (i1), DHC) )==" "\n"
R"==(#define OFF_SCRATCH_DHG1(i0, i1) OFF2((i0), BATCH, (i1), SCRATCH_DIFF_STATES_LD) )==" "\n"
R"==(#define OFF_SCRATCH_CELL(i0, i1) OFF2((i0), BATCH, (i1), STATES_WS_LD) )==" "\n"
R"==(#define SRC_L_OFF(x0, x1, x2) \ )==" "\n"
R"==((((x0) % SRC_L_B0) * SRC_L_SB0 + ((x0) / SRC_L_B0) * SRC_L_S0 \ )==" "\n"
R"==(+ ((x1) % SRC_L_B1) * SRC_L_SB1 + ((x1) / SRC_L_B1) * SRC_L_S1 \ )==" "\n"
R"==(+ ((x2) % SRC_L_B2) * SRC_L_SB2 + ((x2) / SRC_L_B2) * SRC_L_S2) )==" "\n"
R"==(#define SRC_I_OFF(x0, x1, x2, x3) \ )==" "\n"
R"==((((x0) % SRC_I_B0) * SRC_I_SB0 + ((x0) / SRC_I_B0) * SRC_I_S0 \ )==" "\n"
R"==(+ ((x1) % SRC_I_B1) * SRC_I_SB1 + ((x1) / SRC_I_B1) * SRC_I_S1 \ )==" "\n"
R"==(+ ((x2) % SRC_I_B2) * SRC_I_SB2 + ((x2) / SRC_I_B2) * SRC_I_S2 \ )==" "\n"
R"==(+ ((x3) % SRC_I_B3) * SRC_I_SB3 + ((x3) / SRC_I_B3) * SRC_I_S3) )==" "\n"
R"==(#define SRC_I_C_OFF(x0, x1, x2, x3) \ )==" "\n"
R"==((((x0) % SRC_I_C_B0) * SRC_I_C_SB0 + ((x0) / SRC_I_C_B0) * SRC_I_C_S0 \ )==" "\n"
R"==(+ ((x1) % SRC_I_C_B1) * SRC_I_C_SB1 \ )==" "\n"
R"==(+ ((x1) / SRC_I_C_B1) * SRC_I_C_S1 \ )==" "\n"
R"==(+ ((x2) % SRC_I_C_B2) * SRC_I_C_SB2 \ )==" "\n"
R"==(+ ((x2) / SRC_I_C_B2) * SRC_I_C_S2 \ )==" "\n"
R"==(+ ((x3) % SRC_I_C_B3) * SRC_I_C_SB3 \ )==" "\n"
R"==(+ ((x3) / SRC_I_C_B3) * SRC_I_C_S3) )==" "\n"
R"==(#define DST_L_OFF(x0, x1, x2) \ )==" "\n"
R"==((((x0) % DST_L_B0) * DST_L_SB0 + ((x0) / DST_L_B0) * DST_L_S0 \ )==" "\n"
R"==(+ ((x1) % DST_L_B1) * DST_L_SB1 + ((x1) / DST_L_B1) * DST_L_S1 \ )==" "\n"
R"==(+ ((x2) % DST_L_B2) * DST_L_SB2 + ((x2) / DST_L_B2) * DST_L_S2) )==" "\n"
R"==(#define DST_I_OFF(x0, x1, x2, x3) \ )==" "\n"
R"==((((x0) % DST_I_B0) * DST_I_SB0 + ((x0) / DST_I_B0) * DST_I_S0 \ )==" "\n"
R"==(+ ((x1) % DST_I_B1) * DST_I_SB1 + ((x1) / DST_I_B1) * DST_I_S1 \ )==" "\n"
R"==(+ ((x2) % DST_I_B2) * DST_I_SB2 + ((x2) / DST_I_B2) * DST_I_S2 \ )==" "\n"
R"==(+ ((x3) % DST_I_B3) * DST_I_SB3 + ((x3) / DST_I_B3) * DST_I_S3) )==" "\n"
R"==(#define DST_I_C_OFF(x0, x1, x2, x3) \ )==" "\n"
R"==((((x0) % DST_I_C_B0) * DST_I_C_SB0 + ((x0) / DST_I_C_B0) * DST_I_C_S0 \ )==" "\n"
R"==(+ ((x1) % DST_I_C_B1) * DST_I_C_SB1 \ )==" "\n"
R"==(+ ((x1) / DST_I_C_B1) * DST_I_C_S1 \ )==" "\n"
R"==(+ ((x2) % DST_I_C_B2) * DST_I_C_SB2 \ )==" "\n"
R"==(+ ((x2) / DST_I_C_B2) * DST_I_C_S2 \ )==" "\n"
R"==(+ ((x3) % DST_I_C_B3) * DST_I_C_SB3 \ )==" "\n"
R"==(+ ((x3) / DST_I_C_B3) * DST_I_C_S3) )==" "\n"
R"==(#define BIAS_OFF(x0, x1, x2, x3) \ )==" "\n"
R"==((((x0) % BIAS_B0) * BIAS_SB0 + ((x0) / BIAS_B0) * BIAS_S0 \ )==" "\n"
R"==(+ ((x1) % BIAS_B1) * BIAS_SB1 + ((x1) / BIAS_B1) * BIAS_S1 \ )==" "\n"
R"==(+ ((x2) % BIAS_B2) * BIAS_SB2 + ((x2) / BIAS_B2) * BIAS_S2 \ )==" "\n"
R"==(+ ((x3) % BIAS_B3) * BIAS_SB3 + ((x3) / BIAS_B3) * BIAS_S3) )==" "\n"
R"==(#define DIFF_SRC_L_OFF(x0, x1, x2) \ )==" "\n"
R"==((((x0) % DIFF_SRC_L_B0) * DIFF_SRC_L_SB0 \ )==" "\n"
R"==(+ ((x0) / DIFF_SRC_L_B0) * DIFF_SRC_L_S0 \ )==" "\n"
R"==(+ ((x1) % DIFF_SRC_L_B1) * DIFF_SRC_L_SB1 \ )==" "\n"
R"==(+ ((x1) / DIFF_SRC_L_B1) * DIFF_SRC_L_S1 \ )==" "\n"
R"==(+ ((x2) % DIFF_SRC_L_B2) * DIFF_SRC_L_SB2 \ )==" "\n"
R"==(+ ((x2) / DIFF_SRC_L_B2) * DIFF_SRC_L_S2) )==" "\n"
R"==(#define DIFF_DST_L_OFF(x0, x1, x2) \ )==" "\n"
R"==((((x0) % DIFF_DST_L_B0) * DIFF_DST_L_SB0 \ )==" "\n"
R"==(+ ((x0) / DIFF_DST_L_B0) * DIFF_DST_L_S0 \ )==" "\n"
R"==(+ ((x1) % DIFF_DST_L_B1) * DIFF_DST_L_SB1 \ )==" "\n"
R"==(+ ((x1) / DIFF_DST_L_B1) * DIFF_DST_L_S1 \ )==" "\n"
R"==(+ ((x2) % DIFF_DST_L_B2) * DIFF_DST_L_SB2 \ )==" "\n"
R"==(+ ((x2) / DIFF_DST_L_B2) * DIFF_DST_L_S2) )==" "\n"
R"==(#define DIFF_SRC_I_OFF(x0, x1, x2, x3) \ )==" "\n"
R"==((((x0) % DIFF_SRC_I_B0) * DIFF_SRC_I_SB0 \ )==" "\n"
R"==(+ ((x0) / DIFF_SRC_I_B0) * DIFF_SRC_I_S0 \ )==" "\n"
R"==(+ ((x1) % DIFF_SRC_I_B1) * DIFF_SRC_I_SB1 \ )==" "\n"
R"==(+ ((x1) / DIFF_SRC_I_B1) * DIFF_SRC_I_S1 \ )==" "\n"
R"==(+ ((x2) % DIFF_SRC_I_B2) * DIFF_SRC_I_SB2 \ )==" "\n"
R"==(+ ((x2) / DIFF_SRC_I_B2) * DIFF_SRC_I_S2 \ )==" "\n"
R"==(+ ((x3) % DIFF_SRC_I_B3) * DIFF_SRC_I_SB3 \ )==" "\n"
R"==(+ ((x3) / DIFF_SRC_I_B3) * DIFF_SRC_I_S3) )==" "\n"
R"==(#define DIFF_DST_I_OFF(x0, x1, x2, x3) \ )==" "\n"
R"==((((x0) % DIFF_DST_I_B0) * DIFF_DST_I_SB0 \ )==" "\n"
R"==(+ ((x0) / DIFF_DST_I_B0) * DIFF_DST_I_S0 \ )==" "\n"
R"==(+ ((x1) % DIFF_DST_I_B1) * DIFF_DST_I_SB1 \ )==" "\n"
R"==(+ ((x1) / DIFF_DST_I_B1) * DIFF_DST_I_S1 \ )==" "\n"
R"==(+ ((x2) % DIFF_DST_I_B2) * DIFF_DST_I_SB2 \ )==" "\n"
R"==(+ ((x2) / DIFF_DST_I_B2) * DIFF_DST_I_S2 \ )==" "\n"
R"==(+ ((x3) % DIFF_DST_I_B3) * DIFF_DST_I_SB3 \ )==" "\n"
R"==(+ ((x3) / DIFF_DST_I_B3) * DIFF_DST_I_S3) )==" "\n"
R"==(#define DIFF_SRC_I_C_OFF(x0, x1, x2, x3) \ )==" "\n"
R"==((((x0) % DIFF_SRC_I_C_B0) * DIFF_SRC_I_C_SB0 \ )==" "\n"
R"==(+ ((x0) / DIFF_SRC_I_C_B0) * DIFF_SRC_I_C_S0 \ )==" "\n"
R"==(+ ((x1) % DIFF_SRC_I_C_B1) * DIFF_SRC_I_C_SB1 \ )==" "\n"
R"==(+ ((x1) / DIFF_SRC_I_C_B1) * DIFF_SRC_I_C_S1 \ )==" "\n"
R"==(+ ((x2) % DIFF_SRC_I_C_B2) * DIFF_SRC_I_C_SB2 \ )==" "\n"
R"==(+ ((x2) / DIFF_SRC_I_C_B2) * DIFF_SRC_I_C_S2 \ )==" "\n"
R"==(+ ((x3) % DIFF_SRC_I_C_B3) * DIFF_SRC_I_C_SB3 \ )==" "\n"
R"==(+ ((x3) / DIFF_SRC_I_C_B3) * DIFF_SRC_I_C_S3) )==" "\n"
R"==(#define DIFF_DST_I_C_OFF(x0, x1, x2, x3) \ )==" "\n"
R"==((((x0) % DIFF_DST_I_C_B0) * DIFF_DST_I_C_SB0 \ )==" "\n"
R"==(+ ((x0) / DIFF_DST_I_C_B0) * DIFF_DST_I_C_S0 \ )==" "\n"
R"==(+ ((x1) % DIFF_DST_I_C_B1) * DIFF_DST_I_C_SB1 \ )==" "\n"
R"==(+ ((x1) / DIFF_DST_I_C_B1) * DIFF_DST_I_C_S1 \ )==" "\n"
R"==(+ ((x2) % DIFF_DST_I_C_B2) * DIFF_DST_I_C_SB2 \ )==" "\n"
R"==(+ ((x2) / DIFF_DST_I_C_B2) * DIFF_DST_I_C_S2 \ )==" "\n"
R"==(+ ((x3) % DIFF_DST_I_C_B3) * DIFF_DST_I_C_SB3 \ )==" "\n"
R"==(+ ((x3) / DIFF_DST_I_C_B3) * DIFF_DST_I_C_S3) )==" "\n"
R"==(#define DIFF_BIAS_OFF(x0, x1, x2, x3) \ )==" "\n"
R"==((((x0) % DIFF_BIAS_B0) * DIFF_BIAS_SB0 \ )==" "\n"
R"==(+ ((x0) / DIFF_BIAS_B0) * DIFF_BIAS_S0 \ )==" "\n"
R"==(+ ((x1) % DIFF_BIAS_B1) * DIFF_BIAS_SB1 \ )==" "\n"
R"==(+ ((x1) / DIFF_BIAS_B1) * DIFF_BIAS_S1 \ )==" "\n"
R"==(+ ((x2) % DIFF_BIAS_B2) * DIFF_BIAS_SB2 \ )==" "\n"
R"==(+ ((x2) / DIFF_BIAS_B2) * DIFF_BIAS_S2 \ )==" "\n"
R"==(+ ((x3) % DIFF_BIAS_B3) * DIFF_BIAS_SB3 \ )==" "\n"
R"==(+ ((x3) / DIFF_BIAS_B3) * DIFF_BIAS_S3) )==" "\n"
R"==(#endif )==" "\n"
R"==()==" ;
} // namespace ocl
} // namespace gpu
} // namespace impl
} // namespace dnnl