1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "gpu/jit/jit_eltwise_injector.hpp" |
18 | #include "common/impl_registration.hpp" |
19 | |
20 | #include <limits> |
21 | |
22 | namespace dnnl { |
23 | namespace impl { |
24 | namespace gpu { |
25 | namespace jit { |
26 | |
27 | using namespace ngen; |
28 | |
29 | template <gpu_gen_t hw> |
30 | int jit_eltwise_injector_f32<hw>::min_scratch_regs() { |
31 | using namespace alg_kind; |
32 | if (is_fwd_) { |
33 | switch (alg_) { |
34 | case eltwise_elu: |
35 | case eltwise_elu_use_dst_for_bwd: return 1; |
36 | case eltwise_exp: |
37 | case eltwise_exp_use_dst_for_bwd: return 0; |
38 | case eltwise_gelu_erf: return 4; |
39 | case eltwise_hardsigmoid: return 0; |
40 | case eltwise_hardswish: return 1; |
41 | case eltwise_log: return 0; |
42 | case eltwise_mish: return 4; |
43 | case eltwise_pow: return 1; |
44 | case eltwise_relu: |
45 | case eltwise_relu_use_dst_for_bwd: return (alpha_ == 0.f) ? 0 : 1; |
46 | case eltwise_abs: return 0; |
47 | case eltwise_soft_relu: return 1; |
48 | case eltwise_sqrt: |
49 | case eltwise_sqrt_use_dst_for_bwd: return 0; |
50 | case eltwise_square: return 0; |
51 | case eltwise_swish: return 1; |
52 | case eltwise_tanh: |
53 | case eltwise_tanh_use_dst_for_bwd: return 2; |
54 | case eltwise_round: return 0; |
55 | case eltwise_linear: return 0; |
56 | case eltwise_clip: |
57 | case eltwise_clip_v2: |
58 | case eltwise_clip_v2_use_dst_for_bwd: return 0; |
59 | case eltwise_gelu_tanh: return 2; |
60 | case eltwise_logistic: |
61 | case eltwise_logistic_use_dst_for_bwd: return 0; |
62 | default: assert(!"unsupported eltwise algorithm" ); |
63 | } |
64 | } else { |
65 | switch (alg_) { |
66 | case eltwise_relu: return 1; |
67 | case eltwise_abs: return 1; |
68 | case eltwise_square: return 0; |
69 | case eltwise_linear: return 0; |
70 | case eltwise_clip: return 1; |
71 | case eltwise_gelu_tanh: return 2; |
72 | default: assert(!"unsupported eltwise algorithm" ); |
73 | } |
74 | } |
75 | return 0; |
76 | } |
77 | |
78 | template <gpu_gen_t hw> |
79 | int jit_eltwise_injector_f32<hw>::preferred_scratch_regs() { |
80 | using namespace alg_kind; |
81 | if (is_fwd_) { |
82 | switch (alg_) { |
83 | case eltwise_elu: |
84 | case eltwise_elu_use_dst_for_bwd: return 8; |
85 | case eltwise_gelu_erf: return 8; |
86 | case eltwise_hardswish: return 8; |
87 | case eltwise_mish: return 8; |
88 | case eltwise_relu: |
89 | case eltwise_relu_use_dst_for_bwd: return (alpha_ == 0.f) ? 0 : 8; |
90 | case eltwise_tanh: return 8; |
91 | case eltwise_gelu_tanh: return 8; |
92 | case eltwise_soft_relu: return 8; |
93 | case eltwise_swish: return 8; |
94 | default: break; |
95 | } |
96 | } else { |
97 | switch (alg_) { |
98 | case eltwise_gelu_tanh: return 8; |
99 | default: break; |
100 | } |
101 | } |
102 | return min_scratch_regs(); |
103 | } |
104 | |
105 | template <gpu_gen_t hw> |
106 | int jit_eltwise_injector_f32<hw>::max_batch_size() { |
107 | using namespace alg_kind; |
108 | auto ss = scratch_.getLen(); |
109 | |
110 | if (is_fwd_) { |
111 | switch (alg_) { |
112 | case eltwise_relu: |
113 | case eltwise_relu_use_dst_for_bwd: |
114 | if (alpha_ == 0.) |
115 | break; |
116 | else |
117 | return ss; |
118 | case eltwise_elu: |
119 | case eltwise_elu_use_dst_for_bwd: |
120 | case eltwise_hardswish: |
121 | case eltwise_pow: |
122 | case eltwise_soft_relu: |
123 | case eltwise_swish: return ss; |
124 | case eltwise_tanh: |
125 | case eltwise_mish: |
126 | case eltwise_gelu_erf: return ss / min_scratch_regs(); |
127 | case eltwise_gelu_tanh: return ss & ~1; |
128 | default: break; |
129 | } |
130 | } else { |
131 | switch (alg_) { |
132 | case eltwise_gelu_tanh: return ss / 2; |
133 | default: break; |
134 | } |
135 | } |
136 | |
137 | return 128; |
138 | } |
139 | |
140 | template <gpu_gen_t hw> |
141 | int jit_eltwise_injector_f32<hw>::phase_count(alg_kind_t alg) { |
142 | using namespace alg_kind; |
143 | |
144 | if (is_fwd_) { |
145 | switch (alg) { |
146 | case eltwise_elu: |
147 | case eltwise_elu_use_dst_for_bwd: return 5; |
148 | case eltwise_exp: |
149 | case eltwise_exp_use_dst_for_bwd: return 2; |
150 | case eltwise_gelu_erf: return 25; |
151 | case eltwise_hardsigmoid: return 4; |
152 | case eltwise_hardswish: return 5; |
153 | case eltwise_log: return 2; |
154 | case eltwise_mish: |
155 | return phase_count(alg_kind::eltwise_soft_relu) |
156 | + phase_count(alg_kind::eltwise_tanh) + 1; |
157 | case eltwise_pow: return 6; |
158 | case eltwise_relu: |
159 | case eltwise_relu_use_dst_for_bwd: return (alpha_ == 0) ? 1 : 2; |
160 | case eltwise_soft_relu: return 10; |
161 | case eltwise_swish: return 5; |
162 | case eltwise_tanh: |
163 | case eltwise_tanh_use_dst_for_bwd: |
164 | return (use_tanh_compat()) ? 9 : 6; |
165 | case eltwise_linear: return 2; |
166 | case eltwise_clip: |
167 | case eltwise_clip_v2: |
168 | case eltwise_clip_v2_use_dst_for_bwd: return 2; |
169 | case eltwise_gelu_tanh: return 8; |
170 | case eltwise_logistic: |
171 | case eltwise_logistic_use_dst_for_bwd: return 4; |
172 | default: break; |
173 | } |
174 | } else { |
175 | switch (alg) { |
176 | case eltwise_abs: return 2; |
177 | case eltwise_clip: return 4; |
178 | case eltwise_gelu_tanh: return 14; |
179 | default: break; |
180 | } |
181 | } |
182 | |
183 | return 1; |
184 | } |
185 | |
186 | template <gpu_gen_t hw> |
187 | void jit_eltwise_injector_f32<hw>::relu_zero_ns_compute_fwd( |
188 | int simd, const ngen::GRF &r) { |
189 | h->max_(simd, r, r, 0.f); |
190 | } |
191 | |
192 | template <gpu_gen_t hw> |
193 | void jit_eltwise_injector_f32<hw>::relu_compute_fwd( |
194 | int simd, const ngen::GRF &r, int phase, int off) { |
195 | auto temp = scratch_[off].f(); |
196 | switch (phase) { |
197 | case 0: h->mul(simd, temp, r, alpha_); break; |
198 | case 1: h->csel(simd | le | f0[0], r, temp, r, r); break; |
199 | default: assert(!"invalid phase" ); |
200 | } |
201 | } |
202 | |
203 | template <gpu_gen_t hw> |
204 | void jit_eltwise_injector_f32<hw>::abs_compute_fwd( |
205 | int simd, const ngen::GRF &r) { |
206 | h->mov(simd, r, abs(r)); |
207 | } |
208 | |
209 | template <gpu_gen_t hw> |
210 | void jit_eltwise_injector_f32<hw>::soft_relu_compute_fwd_inner(int simd, |
211 | const ngen::GRF &input, const ngen::GRF &temp, const ngen::GRF &dest, |
212 | int phase, int off, float alpha) { |
213 | const float exp_overflow_bound = 88.72283172607421875; |
214 | const float log2e = 1.44269502162933349609375f; |
215 | const float reciproc_log2e = 1.f / log2e; // 1 / log_2(e) |
216 | switch (phase) { |
217 | case 0: h->mul(simd, temp, input, alpha); break; |
218 | case 1: h->add(simd, dest, input, -exp_overflow_bound); break; |
219 | case 2: h->csel(simd | le | f0[0], dest, dest, temp, dest); break; |
220 | case 3: h->mul(simd, temp, temp, log2e); break; |
221 | case 4: h->eexp(simd, temp, temp); break; |
222 | case 5: h->add(simd, temp, temp, 1.f); break; |
223 | case 6: h->log(simd, temp, temp); break; |
224 | case 7: h->mul(simd, temp, temp, reciproc_log2e); break; |
225 | case 8: h->csel(simd | le | f0[0], temp, temp, dest, dest); break; |
226 | case 9: h->mul(simd, dest, temp, 1.f / alpha); break; |
227 | default: assert(!"invalid phase" ); |
228 | } |
229 | } |
230 | |
231 | template <gpu_gen_t hw> |
232 | void jit_eltwise_injector_f32<hw>::soft_relu_compute_fwd( |
233 | int simd, const ngen::GRF &r, int phase, int off) { |
234 | auto temp = scratch_[off].f(); |
235 | soft_relu_compute_fwd_inner(simd, r, temp, r, phase, off, alpha_); |
236 | } |
237 | |
238 | template <gpu_gen_t hw> |
239 | void jit_eltwise_injector_f32<hw>::sqrt_compute_fwd( |
240 | int simd, const ngen::GRF &r) { |
241 | h->sqt(simd, r, r); |
242 | } |
243 | |
244 | template <gpu_gen_t hw> |
245 | void jit_eltwise_injector_f32<hw>::square_compute_fwd( |
246 | int simd, const ngen::GRF &r) { |
247 | h->mul(simd, r, r, r); |
248 | } |
249 | |
250 | template <gpu_gen_t hw> |
251 | void jit_eltwise_injector_f32<hw>::tanh_compute_fwd( |
252 | int simd, const ngen::GRF &r, int phase, int off, int batch) { |
253 | const float log2e = 1.44269502162933349609375f; // log_2(e) |
254 | auto one_half = scratch_[0].f(7); |
255 | auto a = scratch_[off + batch].f(); |
256 | switch (phase) { |
257 | case 0: h->mul(simd, a, abs(r), 2.f * log2e); break; |
258 | case 1: h->exp(simd, a, a); break; |
259 | case 2: h->mad(simd, a, one_half, a, one_half); break; |
260 | case 3: h->inv(simd, a, a); break; |
261 | case 4: h->add(simd, a, -a, 1.f); break; |
262 | case 5: h->csel(simd | ge | f0[0], r, a, -a, r); break; |
263 | default: assert(!"invalid phase" ); |
264 | } |
265 | } |
266 | |
267 | template <gpu_gen_t hw> |
268 | void jit_eltwise_injector_f32<hw>::tanh_compute_fwd_compat( |
269 | int simd, const ngen::GRF &r, int phase, int off, int batch) { |
270 | // This approximation of tanh(x) does not use the math.exp instruction |
271 | // that seems to be faulty on DG2-128; the exact formula is as follows: |
272 | // R = max(min(0.0519867*x*((x^2 + k)^2 + l)/((x^2 + m)^2 + n), 1), -1) |
273 | // Both absolute and relative errors are <7*10^-5 \forall x \in \mathbb R |
274 | auto k = scratch_[0].f(4); |
275 | auto l = scratch_[0].f(5); |
276 | auto m = scratch_[0].f(6); |
277 | auto n = scratch_[0].f(7); |
278 | auto a = scratch_[off + batch].f(); |
279 | switch (phase) { |
280 | case 0: h->mad(simd, a, m, r, r); break; |
281 | case 1: h->mad(simd, a, n, a, a); break; |
282 | case 2: h->inv(simd, a, a); break; |
283 | case 3: h->mul(simd, a, a, r); break; |
284 | case 4: h->mad(simd, r, k, r, r); break; |
285 | case 5: h->mad(simd, r, l, r, r); break; |
286 | case 6: h->mul(simd, r, r, 0.0519867f); break; // 0.051986694f |
287 | case 7: h->mul(simd | sat, r, r, abs(a)); break; |
288 | case 8: h->csel(simd | ge | f0[0], r, r, -r, a); break; |
289 | default: assert(!"invalid phase" ); |
290 | } |
291 | } |
292 | |
293 | template <gpu_gen_t hw> |
294 | void jit_eltwise_injector_f32<hw>::round_compute_fwd( |
295 | int simd, const ngen::GRF &r) { |
296 | h->rnde(simd, r, r); |
297 | } |
298 | |
299 | template <gpu_gen_t hw> |
300 | void jit_eltwise_injector_f32<hw>::swish_compute_fwd( |
301 | int simd, const ngen::GRF &r, int phase, int off) { |
302 | const float log2e = 1.442695f; // log_2(e) |
303 | auto temp = scratch_[off].f(); |
304 | switch (phase) { |
305 | case 0: h->mul(simd, temp, r, -1.f * log2e * alpha_); break; |
306 | case 1: h->exp(simd, temp, temp); break; |
307 | case 2: h->add(simd, temp, temp, 1.f); break; |
308 | case 3: h->inv(simd, temp, temp); break; |
309 | case 4: h->mul(simd, r, r, temp); break; |
310 | default: assert(!"invalid phase" ); |
311 | } |
312 | } |
313 | |
314 | template <gpu_gen_t hw> |
315 | void jit_eltwise_injector_f32<hw>::linear_compute_fwd( |
316 | int simd, const ngen::GRF &r, int phase) { |
317 | switch (phase) { |
318 | case 0: h->mul(simd, r, r, alpha_); break; |
319 | case 1: h->add(simd, r, r, beta_); break; |
320 | default: assert(!"invalid phase" ); |
321 | } |
322 | } |
323 | |
324 | template <gpu_gen_t hw> |
325 | void jit_eltwise_injector_f32<hw>::clip_compute_fwd( |
326 | int simd, const ngen::GRF &r, int phase, float alpha, float beta) { |
327 | switch (phase) { |
328 | case 0: h->max_(simd, r, r, alpha); break; |
329 | case 1: h->min_(simd, r, r, beta); break; |
330 | default: assert(!"invalid phase" ); |
331 | } |
332 | } |
333 | |
334 | template <gpu_gen_t hw> |
335 | void jit_eltwise_injector_f32<hw>::gelu_tanh_compute_fwd( |
336 | int simd, const ngen::GRF &r, int phase, int off) { |
337 | |
338 | const float k = 0.044715f; |
339 | const float sqrt_2_over_pi = 0.7978845f; // sqrt(2/pi) |
340 | const float log2e = 1.442695f; // log_2(e) |
341 | |
342 | int msimd = simd; |
343 | if (hw == gpu_xe_hp) |
344 | msimd = 16; // workaround for intermittent hang with DPAS+EM |
345 | |
346 | auto a = scratch_[off].f(); |
347 | switch (phase) { |
348 | case 0: h->mul(simd, a, r, r); break; |
349 | case 1: h->mul(simd, a, a, k); break; |
350 | case 2: h->mad(simd, a, r, a, r); break; |
351 | case 3: h->mul(simd, a, a, -2 * sqrt_2_over_pi * log2e); break; |
352 | case 4: h->exp(msimd, a, a); break; |
353 | case 5: h->add(simd, a, a, 1.0f); break; |
354 | case 6: h->inv(msimd, a, a); break; |
355 | case 7: h->mul(simd, r, a, r); break; |
356 | default: assert(!"invalid phase" ); |
357 | } |
358 | } |
359 | |
360 | template <gpu_gen_t hw> |
361 | void jit_eltwise_injector_f32<hw>::logistic_compute_fwd( |
362 | int simd, const ngen::GRF &r, int phase) { |
363 | const float log2e = 1.442695f; // log_2(e) |
364 | switch (phase) { |
365 | case 0: h->mul(simd, r, r, -1.f * log2e); break; |
366 | case 1: h->exp(simd, r, r); break; |
367 | case 2: h->add(simd, r, r, 1.f); break; |
368 | case 3: h->inv(simd, r, r); break; |
369 | default: assert(!"invalid phase" ); |
370 | } |
371 | } |
372 | |
373 | template <gpu_gen_t hw> |
374 | void jit_eltwise_injector_f32<hw>::relu_prepare_bwd() { |
375 | auto neg_slope = scratch_[0].f(0); |
376 | auto pos_slope = scratch_[0].f(4); |
377 | h->mov(1, neg_slope, alpha_); |
378 | h->mov(1, pos_slope, 1.f); |
379 | } |
380 | |
381 | template <gpu_gen_t hw> |
382 | void jit_eltwise_injector_f32<hw>::relu_compute_bwd( |
383 | int simd, const ngen::GRF &r) { |
384 | auto neg_slope = scratch_[0].f(0); |
385 | auto pos_slope = scratch_[0].f(4); |
386 | h->csel(simd | le | f0[0], r, neg_slope, pos_slope, r); |
387 | } |
388 | |
389 | template <gpu_gen_t hw> |
390 | void jit_eltwise_injector_f32<hw>::abs_prepare_bwd() { |
391 | auto neg_one = scratch_[0].f(0); |
392 | auto pos_one = scratch_[0].f(4); |
393 | h->mov(1, neg_one, -1.f); |
394 | h->mov(1, pos_one, 1.f); |
395 | } |
396 | |
397 | template <gpu_gen_t hw> |
398 | void jit_eltwise_injector_f32<hw>::clip_prepare_bwd() { |
399 | auto pos_inf_imm = Immediate(std::numeric_limits<float>::infinity()); |
400 | auto zero = scratch_[0].f(0); |
401 | auto one = scratch_[0].f(1); |
402 | auto pos_inf = scratch_[0].f(2); |
403 | h->mov(1, zero, 0.f); |
404 | h->mov(1, one, 1.f); |
405 | h->mov(1, pos_inf, pos_inf_imm); |
406 | } |
407 | |
408 | template <gpu_gen_t hw> |
409 | void jit_eltwise_injector_f32<hw>::tanh_prepare_fwd() { |
410 | auto one_half = scratch_[0].f(7); |
411 | h->mov(1, one_half, 0.5f); |
412 | } |
413 | |
414 | template <gpu_gen_t hw> |
415 | void jit_eltwise_injector_f32<hw>::tanh_prepare_fwd_compat() { |
416 | auto k = scratch_[0].f(4); |
417 | auto l = scratch_[0].f(5); |
418 | auto m = scratch_[0].f(6); |
419 | auto n = scratch_[0].f(7); |
420 | h->mov(1, k, 77.0954f); // 77.095392909578f |
421 | h->mov(1, l, -4435.55f); // -4435.54623970169f |
422 | h->mov(1, m, 17.06396f); // 17.06396485f |
423 | h->mov(1, n, -212.7724f); // -212.772646402036f |
424 | } |
425 | |
426 | template <gpu_gen_t hw> |
427 | void jit_eltwise_injector_f32<hw>::abs_compute_bwd( |
428 | int simd, const ngen::GRF &r, int phase) { |
429 | auto neg_one = scratch_[0].f(0); |
430 | auto pos_one = scratch_[0].f(4); |
431 | switch (phase) { |
432 | case 0: h->csel(simd | lt | f0[0], r, neg_one, r, r); break; |
433 | case 1: h->csel(simd | gt | f0[0], r, pos_one, r, r); break; |
434 | default: break; |
435 | } |
436 | } |
437 | |
438 | template <gpu_gen_t hw> |
439 | void jit_eltwise_injector_f32<hw>::square_compute_bwd( |
440 | int simd, const ngen::GRF &r) { |
441 | h->add(simd, r, r, r); |
442 | } |
443 | |
444 | template <gpu_gen_t hw> |
445 | void jit_eltwise_injector_f32<hw>::linear_compute_bwd( |
446 | int simd, const ngen::GRF &r) { |
447 | h->mov(simd, r, alpha_); |
448 | } |
449 | |
450 | template <gpu_gen_t hw> |
451 | void jit_eltwise_injector_f32<hw>::clip_compute_bwd( |
452 | int simd, const ngen::GRF &r, int phase, float alpha, float beta) { |
453 | auto zero = scratch_[0].f(0); |
454 | auto one = scratch_[0].f(1); |
455 | auto pos_inf = scratch_[0].f(2); |
456 | switch (phase) { |
457 | // r[i] = r[i] - alpha |
458 | case 0: h->add(simd, r, r, -alpha); break; |
459 | // r[i] <= 0 => r[i] = infinity |
460 | case 1: h->csel(simd | le | f0[0], r, pos_inf, r, r); break; |
461 | // r[i] = (r[i] + alpha) - beta |
462 | case 2: h->add(simd, r, r, alpha - beta); break; |
463 | // r[i] = (r[i] <= 0 ? 1 : 0) |
464 | case 3: h->csel(simd | le | f0[0], r, one, zero, r); break; |
465 | default: assert(!"invalid phase" ); |
466 | } |
467 | } |
468 | |
469 | template <gpu_gen_t hw> |
470 | void jit_eltwise_injector_f32<hw>::gelu_tanh_compute_bwd( |
471 | int simd, const ngen::GRF &r, int phase, int off, int batch) { |
472 | |
473 | const float k = 0.044715f; |
474 | const float sqrt_2_over_pi = 0.7978845f; // sqrt(2/pi) |
475 | const float log2e = 1.442695f; // log_2(e) |
476 | |
477 | int msimd = simd; |
478 | if (hw == gpu_xe_hp) msimd = 16; |
479 | |
480 | auto a = scratch_[off].f(); |
481 | auto b = scratch_[off + batch].f(); |
482 | switch (phase) { |
483 | case 0: h->mul(simd, a, r, r); break; |
484 | case 1: h->mul(simd, b, a, 3.0f * k); break; |
485 | case 2: h->mul(simd, a, a, k); break; |
486 | case 3: h->mad(simd, a, r, a, r); break; |
487 | case 4: h->mad(simd, b, r, b, r); break; |
488 | case 5: h->mul(simd, a, a, -2 * sqrt_2_over_pi * log2e); break; |
489 | case 6: h->mul(simd, b, b, 2 * sqrt_2_over_pi); break; |
490 | case 7: h->exp(msimd, a, a); break; |
491 | case 8: h->add(simd, r, a, 1.0f); break; |
492 | case 9: h->inv(msimd, r, r); break; |
493 | case 10: h->mul(simd, a, a, r); break; |
494 | case 11: h->mul(simd, a, a, b); break; |
495 | case 12: h->add(simd, a, a, 1.0f); break; |
496 | case 13: h->mul(simd, r, r, a); break; |
497 | default: assert(!"invalid phase" ); |
498 | } |
499 | } |
500 | |
501 | template <gpu_gen_t hw> |
502 | void jit_eltwise_injector_f32<hw>::elu_compute_fwd( |
503 | int simd, const ngen::GRF &r, int phase, int off) { |
504 | auto temp = scratch_[off].f(); |
505 | const float log2e = 1.442695f; // log_2(e) |
506 | switch (phase) { |
507 | case 0: h->mul(simd, temp, r, log2e); break; |
508 | case 1: h->exp(simd, temp, temp); break; |
509 | case 2: h->add(simd, temp, temp, -1.f); break; |
510 | case 3: h->mul(simd, temp, temp, alpha_); break; |
511 | case 4: h->csel(simd | le | f0[0], r, temp, r, r); break; |
512 | default: assert(!"invalid phase" ); |
513 | } |
514 | } |
515 | |
516 | template <gpu_gen_t hw> |
517 | void jit_eltwise_injector_f32<hw>::exp_compute_fwd( |
518 | int simd, const ngen::GRF &r, int phase) { |
519 | const float log2e = 1.442695f; // log_2(e) |
520 | switch (phase) { |
521 | case 0: h->mul(simd, r, r, log2e); break; |
522 | case 1: h->exp(simd, r, r); break; |
523 | default: assert(!"invalid phase" ); |
524 | } |
525 | } |
526 | |
527 | template <gpu_gen_t hw> |
528 | void jit_eltwise_injector_f32<hw>::gelu_erf_compute_fwd( |
529 | int simd, const ngen::GRF &r, int phase, int off, int batch) { |
530 | auto temp = scratch_[off].f(); |
531 | auto at_accum = scratch_[off + batch].f(); |
532 | auto tpow = scratch_[off + 2 * batch].f(); |
533 | auto temp2 = scratch_[off + 3 * batch].f(); |
534 | const float log2e = 1.442695f; // log_2(e) |
535 | const float reciproc_sqrt_2 = 0.707106769084930419921875f; // 1/sqrt(2) |
536 | const float p = 0.3275911f; |
537 | const float a1 = 0.254829592f; |
538 | const float a2 = -0.284496736f; |
539 | const float a3 = 1.421413741f; |
540 | const float a4 = -1.453152027f; |
541 | const float a5 = 1.061405429f; |
542 | switch (phase) { |
543 | case 0: h->mul(simd, temp, abs(r), reciproc_sqrt_2); break; |
544 | case 1: h->mul(simd, temp, temp, p); break; |
545 | case 2: h->add(simd, temp, temp, 1.f); break; |
546 | case 3: h->inv(simd, temp, temp); break; |
547 | case 4: h->mul(simd, at_accum, temp, a1); break; |
548 | case 5: h->mul(simd, tpow, temp, temp); break; |
549 | case 6: h->mul(simd, temp2, tpow, a2); break; |
550 | case 7: h->add(simd, at_accum, temp2, at_accum); break; |
551 | case 8: h->mul(simd, tpow, tpow, temp); break; |
552 | case 9: h->mul(simd, temp2, tpow, a3); break; |
553 | case 10: h->add(simd, at_accum, temp2, at_accum); break; |
554 | case 11: h->mul(simd, tpow, tpow, temp); break; |
555 | case 12: h->mul(simd, temp2, tpow, a4); break; |
556 | case 13: h->add(simd, at_accum, temp2, at_accum); break; |
557 | case 14: h->mul(simd, tpow, tpow, temp); break; |
558 | case 15: h->mul(simd, temp2, tpow, a5); break; |
559 | case 16: h->add(simd, at_accum, temp2, at_accum); break; |
560 | case 17: h->mul(simd, temp, r, r); break; |
561 | case 18: h->mul(simd, temp, temp, -log2e * 0.5f); break; |
562 | case 19: h->exp(simd, temp, temp); break; |
563 | case 20: h->mul(simd, temp, temp, at_accum); break; |
564 | case 21: h->mul(simd, temp, temp, r); break; |
565 | case 22: h->mul(simd, temp, temp, 0.5f); break; |
566 | case 23: h->add(simd, temp2, r, -temp); break; |
567 | case 24: h->csel(simd | le | f0[0], r, temp, temp2, r); break; |
568 | default: assert(!"invalid phase" ); |
569 | } |
570 | } |
571 | |
572 | template <gpu_gen_t hw> |
573 | void jit_eltwise_injector_f32<hw>::hardsigmoid_compute_fwd( |
574 | int simd, const ngen::GRF &r, int phase, int off) { |
575 | switch (phase) { |
576 | case 0: h->mul(simd, r, r, alpha_); break; |
577 | case 1: h->add(simd, r, r, beta_); break; |
578 | case 2: h->min_(simd, r, r, 1.f); break; |
579 | case 3: h->max_(simd, r, r, 0.f); break; |
580 | default: assert(!"invalid phase" ); |
581 | } |
582 | } |
583 | |
584 | template <gpu_gen_t hw> |
585 | void jit_eltwise_injector_f32<hw>::hardswish_compute_fwd( |
586 | int simd, const ngen::GRF &r, int phase, int off) { |
587 | auto temp = scratch_[off].f(); |
588 | switch (phase) { |
589 | case 0: h->mul(simd, temp, r, alpha_); break; |
590 | case 1: h->add(simd, temp, temp, beta_); break; |
591 | case 2: h->min_(simd, temp, temp, 1.f); break; |
592 | case 3: h->max_(simd, temp, temp, 0.f); break; |
593 | case 4: h->mul(simd, r, r, temp); break; |
594 | default: assert(!"invalid phase" ); |
595 | } |
596 | } |
597 | |
598 | template <gpu_gen_t hw> |
599 | void jit_eltwise_injector_f32<hw>::log_compute_fwd( |
600 | int simd, const ngen::GRF &r, int phase) { |
601 | const float reciproc_log2e = 1.f / 1.442695f; // 1 / log_2(e) |
602 | switch (phase) { |
603 | case 0: h->log(simd, r, r); break; |
604 | case 1: h->mul(simd, r, r, reciproc_log2e); break; |
605 | default: assert(!"invalid phase" ); |
606 | } |
607 | } |
608 | |
609 | template <gpu_gen_t hw> |
610 | void jit_eltwise_injector_f32<hw>::mish_compute_fwd( |
611 | int simd, const ngen::GRF &r, int phase, int off, int batch) { |
612 | auto temp = scratch_[off + batch].f(); |
613 | auto temp2 = scratch_[off + 2 * batch].f(); |
614 | const int srelu_phases = phase_count(alg_kind::eltwise_soft_relu); |
615 | const int tanh_phases = phase_count(alg_kind::eltwise_tanh); |
616 | // note tanh_compute_fwd_* clobbers scratch_[off] and scratch_[off + batch] |
617 | if (phase < srelu_phases) |
618 | soft_relu_compute_fwd_inner(simd, r, temp, temp2, phase, off, 1.f); |
619 | if (phase >= srelu_phases && phase < srelu_phases + tanh_phases) { |
620 | if (use_tanh_compat()) |
621 | tanh_compute_fwd_compat( |
622 | simd, temp2, phase - srelu_phases, off, batch); |
623 | else |
624 | tanh_compute_fwd(simd, temp2, phase - srelu_phases, off, batch); |
625 | } |
626 | if (phase == srelu_phases + tanh_phases) h->mul(simd, r, r, temp2); |
627 | if (phase > srelu_phases + tanh_phases) assert(!"invalid phase" ); |
628 | } |
629 | |
630 | template <gpu_gen_t hw> |
631 | void jit_eltwise_injector_f32<hw>::pow_compute_fwd( |
632 | int simd, const ngen::GRF &r, int phase, int off) { |
633 | auto temp = scratch_[off].f(); |
634 | switch (phase) { |
635 | case 0: |
636 | if ((long long int)beta_ == beta_) { |
637 | h->mov(simd, temp, abs(r)); |
638 | } else { |
639 | h->mov(simd, temp, r); |
640 | } |
641 | break; |
642 | case 1: h->log(simd, temp, temp); break; |
643 | case 2: h->mul(simd, temp, temp, beta_); break; |
644 | case 3: h->exp(simd, temp, temp); break; |
645 | case 4: |
646 | if (((long long int)beta_) & 0x1) |
647 | h->csel(simd | lt | f0[0], temp, -temp, temp, r); |
648 | break; |
649 | case 5: h->mul(simd, r, temp, alpha_); break; |
650 | default: assert(!"invalid phase" ); |
651 | } |
652 | } |
653 | |
654 | template <gpu_gen_t hw> |
655 | void jit_eltwise_injector_f32<hw>::compute(const ngen::GRFRange ®s) { |
656 | using namespace alg_kind; |
657 | |
658 | auto bmax = max_batch_size(); |
659 | auto phases = phase_count(alg_); |
660 | |
661 | for (int idx0 = 0; idx0 < regs.getLen(); idx0 += bmax) { |
662 | auto batch = nstl::min(regs.getLen() - idx0, bmax); |
663 | |
664 | for (int phase = 0; phase < phases; phase++) { |
665 | for (int ii = 0; ii < batch; ii += 2) { |
666 | int nreg = nstl::min(2, batch - ii); |
667 | int simd = nreg * GRF::bytes(hw) / sizeof(float); |
668 | auto base = regs[idx0 + ii].f(); |
669 | |
670 | if (is_fwd_) { |
671 | switch (alg_) { |
672 | case eltwise_elu: |
673 | case eltwise_elu_use_dst_for_bwd: |
674 | elu_compute_fwd(simd, base, phase, ii); |
675 | break; |
676 | case eltwise_exp: |
677 | case eltwise_exp_use_dst_for_bwd: |
678 | exp_compute_fwd(simd, base, phase); |
679 | break; |
680 | case eltwise_gelu_erf: |
681 | gelu_erf_compute_fwd(simd, base, phase, ii, batch); |
682 | break; |
683 | case eltwise_hardsigmoid: |
684 | hardsigmoid_compute_fwd(simd, base, phase, ii); |
685 | break; |
686 | case eltwise_hardswish: |
687 | hardswish_compute_fwd(simd, base, phase, ii); |
688 | break; |
689 | case eltwise_log: |
690 | log_compute_fwd(simd, base, phase); |
691 | break; |
692 | case eltwise_mish: |
693 | mish_compute_fwd(simd, base, phase, ii, batch); |
694 | break; |
695 | case eltwise_pow: |
696 | pow_compute_fwd(simd, base, phase, ii); |
697 | break; |
698 | case eltwise_relu: |
699 | case eltwise_relu_use_dst_for_bwd: |
700 | if (alpha_ == 0.f) |
701 | relu_zero_ns_compute_fwd(simd, base); |
702 | else |
703 | relu_compute_fwd(simd, base, phase, ii); |
704 | break; |
705 | case eltwise_abs: abs_compute_fwd(simd, base); break; |
706 | case eltwise_soft_relu: |
707 | soft_relu_compute_fwd(simd, base, phase, ii); |
708 | break; |
709 | case eltwise_sqrt: |
710 | case eltwise_sqrt_use_dst_for_bwd: |
711 | sqrt_compute_fwd(simd, base); |
712 | break; |
713 | case eltwise_square: |
714 | square_compute_fwd(simd, base); |
715 | break; |
716 | case eltwise_tanh: |
717 | case eltwise_tanh_use_dst_for_bwd: |
718 | if (use_tanh_compat()) |
719 | tanh_compute_fwd_compat( |
720 | simd, base, phase, ii, batch); |
721 | else |
722 | tanh_compute_fwd(simd, base, phase, ii, batch); |
723 | break; |
724 | case eltwise_round: |
725 | round_compute_fwd(simd, base); |
726 | break; |
727 | case eltwise_swish: |
728 | swish_compute_fwd(simd, base, phase, ii); |
729 | break; |
730 | case eltwise_linear: |
731 | linear_compute_fwd(simd, base, phase); |
732 | break; |
733 | case eltwise_clip: |
734 | case eltwise_clip_v2: |
735 | case eltwise_clip_v2_use_dst_for_bwd: |
736 | clip_compute_fwd(simd, base, phase, alpha_, beta_); |
737 | break; |
738 | case eltwise_gelu_tanh: |
739 | gelu_tanh_compute_fwd(simd, base, phase, ii); |
740 | break; |
741 | case eltwise_logistic: |
742 | case eltwise_logistic_use_dst_for_bwd: |
743 | logistic_compute_fwd(simd, base, phase); |
744 | break; |
745 | default: assert(!"unsupported eltwise algorithm" ); |
746 | } |
747 | } else { |
748 | switch (alg_) { |
749 | case eltwise_relu: relu_compute_bwd(simd, base); break; |
750 | case eltwise_abs: |
751 | abs_compute_bwd(simd, base, phase); |
752 | break; |
753 | case eltwise_square: |
754 | square_compute_bwd(simd, base); |
755 | break; |
756 | case eltwise_linear: |
757 | linear_compute_bwd(simd, base); |
758 | break; |
759 | case eltwise_clip: |
760 | clip_compute_bwd(simd, base, phase, alpha_, beta_); |
761 | break; |
762 | case eltwise_gelu_tanh: |
763 | gelu_tanh_compute_bwd(simd, base, phase, ii, batch); |
764 | break; |
765 | default: assert(!"unsupported eltwise algorithm" ); |
766 | } |
767 | } |
768 | // Apply scale. |
769 | if (phase == phases - 1 && scale_ != 1.f) { |
770 | h->mul(simd, base, base, scale_); |
771 | } |
772 | } |
773 | } |
774 | } |
775 | } |
776 | |
777 | template <gpu_gen_t hw> |
778 | void jit_eltwise_injector_f32<hw>::prepare() { |
779 | using namespace alg_kind; |
780 | |
781 | assert(scratch_.getLen() >= min_scratch_regs()); |
782 | |
783 | if (is_fwd_) { |
784 | switch (alg_) { |
785 | case eltwise_mish: |
786 | case eltwise_tanh: |
787 | if (use_tanh_compat()) |
788 | tanh_prepare_fwd_compat(); |
789 | else |
790 | tanh_prepare_fwd(); |
791 | break; |
792 | default: break; |
793 | } |
794 | } else { |
795 | switch (alg_) { |
796 | case eltwise_relu: relu_prepare_bwd(); break; |
797 | case eltwise_abs: abs_prepare_bwd(); break; |
798 | case eltwise_clip: clip_prepare_bwd(); break; |
799 | default: break; |
800 | } |
801 | } |
802 | } |
803 | |
// Explicit instantiations of the injector, one per GPU generation. Each one
// is wrapped in a REG_*_ISA macro (presumably provided by
// common/impl_registration.hpp and assumed to drop the instantiation when the
// corresponding ISA is disabled in the build -- TODO confirm).
REG_GEN9_ISA(template struct jit_eltwise_injector_f32<gpu_gen9>);
REG_GEN11_ISA(template struct jit_eltwise_injector_f32<gpu_gen11>);
REG_XELP_ISA(template struct jit_eltwise_injector_f32<gpu_xe_lp>);
REG_XEHP_ISA(template struct jit_eltwise_injector_f32<gpu_xe_hp>);
REG_XEHPG_ISA(template struct jit_eltwise_injector_f32<gpu_xe_hpg>);
REG_XEHPC_ISA(template struct jit_eltwise_injector_f32<gpu_xe_hpc>);
810 | |
811 | } // namespace jit |
812 | } // namespace gpu |
813 | } // namespace impl |
814 | } // namespace dnnl |
815 | |