1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
#include "gpu/jit/jit_eltwise_injector.hpp"
#include "common/impl_registration.hpp"

#include <cmath>
#include <limits>
21
22namespace dnnl {
23namespace impl {
24namespace gpu {
25namespace jit {
26
27using namespace ngen;
28
29template <gpu_gen_t hw>
30int jit_eltwise_injector_f32<hw>::min_scratch_regs() {
31 using namespace alg_kind;
32 if (is_fwd_) {
33 switch (alg_) {
34 case eltwise_elu:
35 case eltwise_elu_use_dst_for_bwd: return 1;
36 case eltwise_exp:
37 case eltwise_exp_use_dst_for_bwd: return 0;
38 case eltwise_gelu_erf: return 4;
39 case eltwise_hardsigmoid: return 0;
40 case eltwise_hardswish: return 1;
41 case eltwise_log: return 0;
42 case eltwise_mish: return 4;
43 case eltwise_pow: return 1;
44 case eltwise_relu:
45 case eltwise_relu_use_dst_for_bwd: return (alpha_ == 0.f) ? 0 : 1;
46 case eltwise_abs: return 0;
47 case eltwise_soft_relu: return 1;
48 case eltwise_sqrt:
49 case eltwise_sqrt_use_dst_for_bwd: return 0;
50 case eltwise_square: return 0;
51 case eltwise_swish: return 1;
52 case eltwise_tanh:
53 case eltwise_tanh_use_dst_for_bwd: return 2;
54 case eltwise_round: return 0;
55 case eltwise_linear: return 0;
56 case eltwise_clip:
57 case eltwise_clip_v2:
58 case eltwise_clip_v2_use_dst_for_bwd: return 0;
59 case eltwise_gelu_tanh: return 2;
60 case eltwise_logistic:
61 case eltwise_logistic_use_dst_for_bwd: return 0;
62 default: assert(!"unsupported eltwise algorithm");
63 }
64 } else {
65 switch (alg_) {
66 case eltwise_relu: return 1;
67 case eltwise_abs: return 1;
68 case eltwise_square: return 0;
69 case eltwise_linear: return 0;
70 case eltwise_clip: return 1;
71 case eltwise_gelu_tanh: return 2;
72 default: assert(!"unsupported eltwise algorithm");
73 }
74 }
75 return 0;
76}
77
78template <gpu_gen_t hw>
79int jit_eltwise_injector_f32<hw>::preferred_scratch_regs() {
80 using namespace alg_kind;
81 if (is_fwd_) {
82 switch (alg_) {
83 case eltwise_elu:
84 case eltwise_elu_use_dst_for_bwd: return 8;
85 case eltwise_gelu_erf: return 8;
86 case eltwise_hardswish: return 8;
87 case eltwise_mish: return 8;
88 case eltwise_relu:
89 case eltwise_relu_use_dst_for_bwd: return (alpha_ == 0.f) ? 0 : 8;
90 case eltwise_tanh: return 8;
91 case eltwise_gelu_tanh: return 8;
92 case eltwise_soft_relu: return 8;
93 case eltwise_swish: return 8;
94 default: break;
95 }
96 } else {
97 switch (alg_) {
98 case eltwise_gelu_tanh: return 8;
99 default: break;
100 }
101 }
102 return min_scratch_regs();
103}
104
105template <gpu_gen_t hw>
106int jit_eltwise_injector_f32<hw>::max_batch_size() {
107 using namespace alg_kind;
108 auto ss = scratch_.getLen();
109
110 if (is_fwd_) {
111 switch (alg_) {
112 case eltwise_relu:
113 case eltwise_relu_use_dst_for_bwd:
114 if (alpha_ == 0.)
115 break;
116 else
117 return ss;
118 case eltwise_elu:
119 case eltwise_elu_use_dst_for_bwd:
120 case eltwise_hardswish:
121 case eltwise_pow:
122 case eltwise_soft_relu:
123 case eltwise_swish: return ss;
124 case eltwise_tanh:
125 case eltwise_mish:
126 case eltwise_gelu_erf: return ss / min_scratch_regs();
127 case eltwise_gelu_tanh: return ss & ~1;
128 default: break;
129 }
130 } else {
131 switch (alg_) {
132 case eltwise_gelu_tanh: return ss / 2;
133 default: break;
134 }
135 }
136
137 return 128;
138}
139
140template <gpu_gen_t hw>
141int jit_eltwise_injector_f32<hw>::phase_count(alg_kind_t alg) {
142 using namespace alg_kind;
143
144 if (is_fwd_) {
145 switch (alg) {
146 case eltwise_elu:
147 case eltwise_elu_use_dst_for_bwd: return 5;
148 case eltwise_exp:
149 case eltwise_exp_use_dst_for_bwd: return 2;
150 case eltwise_gelu_erf: return 25;
151 case eltwise_hardsigmoid: return 4;
152 case eltwise_hardswish: return 5;
153 case eltwise_log: return 2;
154 case eltwise_mish:
155 return phase_count(alg_kind::eltwise_soft_relu)
156 + phase_count(alg_kind::eltwise_tanh) + 1;
157 case eltwise_pow: return 6;
158 case eltwise_relu:
159 case eltwise_relu_use_dst_for_bwd: return (alpha_ == 0) ? 1 : 2;
160 case eltwise_soft_relu: return 10;
161 case eltwise_swish: return 5;
162 case eltwise_tanh:
163 case eltwise_tanh_use_dst_for_bwd:
164 return (use_tanh_compat()) ? 9 : 6;
165 case eltwise_linear: return 2;
166 case eltwise_clip:
167 case eltwise_clip_v2:
168 case eltwise_clip_v2_use_dst_for_bwd: return 2;
169 case eltwise_gelu_tanh: return 8;
170 case eltwise_logistic:
171 case eltwise_logistic_use_dst_for_bwd: return 4;
172 default: break;
173 }
174 } else {
175 switch (alg) {
176 case eltwise_abs: return 2;
177 case eltwise_clip: return 4;
178 case eltwise_gelu_tanh: return 14;
179 default: break;
180 }
181 }
182
183 return 1;
184}
185
186template <gpu_gen_t hw>
187void jit_eltwise_injector_f32<hw>::relu_zero_ns_compute_fwd(
188 int simd, const ngen::GRF &r) {
189 h->max_(simd, r, r, 0.f);
190}
191
192template <gpu_gen_t hw>
193void jit_eltwise_injector_f32<hw>::relu_compute_fwd(
194 int simd, const ngen::GRF &r, int phase, int off) {
195 auto temp = scratch_[off].f();
196 switch (phase) {
197 case 0: h->mul(simd, temp, r, alpha_); break;
198 case 1: h->csel(simd | le | f0[0], r, temp, r, r); break;
199 default: assert(!"invalid phase");
200 }
201}
202
203template <gpu_gen_t hw>
204void jit_eltwise_injector_f32<hw>::abs_compute_fwd(
205 int simd, const ngen::GRF &r) {
206 h->mov(simd, r, abs(r));
207}
208
209template <gpu_gen_t hw>
210void jit_eltwise_injector_f32<hw>::soft_relu_compute_fwd_inner(int simd,
211 const ngen::GRF &input, const ngen::GRF &temp, const ngen::GRF &dest,
212 int phase, int off, float alpha) {
213 const float exp_overflow_bound = 88.72283172607421875;
214 const float log2e = 1.44269502162933349609375f;
215 const float reciproc_log2e = 1.f / log2e; // 1 / log_2(e)
216 switch (phase) {
217 case 0: h->mul(simd, temp, input, alpha); break;
218 case 1: h->add(simd, dest, input, -exp_overflow_bound); break;
219 case 2: h->csel(simd | le | f0[0], dest, dest, temp, dest); break;
220 case 3: h->mul(simd, temp, temp, log2e); break;
221 case 4: h->eexp(simd, temp, temp); break;
222 case 5: h->add(simd, temp, temp, 1.f); break;
223 case 6: h->log(simd, temp, temp); break;
224 case 7: h->mul(simd, temp, temp, reciproc_log2e); break;
225 case 8: h->csel(simd | le | f0[0], temp, temp, dest, dest); break;
226 case 9: h->mul(simd, dest, temp, 1.f / alpha); break;
227 default: assert(!"invalid phase");
228 }
229}
230
231template <gpu_gen_t hw>
232void jit_eltwise_injector_f32<hw>::soft_relu_compute_fwd(
233 int simd, const ngen::GRF &r, int phase, int off) {
234 auto temp = scratch_[off].f();
235 soft_relu_compute_fwd_inner(simd, r, temp, r, phase, off, alpha_);
236}
237
238template <gpu_gen_t hw>
239void jit_eltwise_injector_f32<hw>::sqrt_compute_fwd(
240 int simd, const ngen::GRF &r) {
241 h->sqt(simd, r, r);
242}
243
244template <gpu_gen_t hw>
245void jit_eltwise_injector_f32<hw>::square_compute_fwd(
246 int simd, const ngen::GRF &r) {
247 h->mul(simd, r, r, r);
248}
249
250template <gpu_gen_t hw>
251void jit_eltwise_injector_f32<hw>::tanh_compute_fwd(
252 int simd, const ngen::GRF &r, int phase, int off, int batch) {
253 const float log2e = 1.44269502162933349609375f; // log_2(e)
254 auto one_half = scratch_[0].f(7);
255 auto a = scratch_[off + batch].f();
256 switch (phase) {
257 case 0: h->mul(simd, a, abs(r), 2.f * log2e); break;
258 case 1: h->exp(simd, a, a); break;
259 case 2: h->mad(simd, a, one_half, a, one_half); break;
260 case 3: h->inv(simd, a, a); break;
261 case 4: h->add(simd, a, -a, 1.f); break;
262 case 5: h->csel(simd | ge | f0[0], r, a, -a, r); break;
263 default: assert(!"invalid phase");
264 }
265}
266
267template <gpu_gen_t hw>
268void jit_eltwise_injector_f32<hw>::tanh_compute_fwd_compat(
269 int simd, const ngen::GRF &r, int phase, int off, int batch) {
270 // This approximation of tanh(x) does not use the math.exp instruction
271 // that seems to be faulty on DG2-128; the exact formula is as follows:
272 // R = max(min(0.0519867*x*((x^2 + k)^2 + l)/((x^2 + m)^2 + n), 1), -1)
273 // Both absolute and relative errors are <7*10^-5 \forall x \in \mathbb R
274 auto k = scratch_[0].f(4);
275 auto l = scratch_[0].f(5);
276 auto m = scratch_[0].f(6);
277 auto n = scratch_[0].f(7);
278 auto a = scratch_[off + batch].f();
279 switch (phase) {
280 case 0: h->mad(simd, a, m, r, r); break;
281 case 1: h->mad(simd, a, n, a, a); break;
282 case 2: h->inv(simd, a, a); break;
283 case 3: h->mul(simd, a, a, r); break;
284 case 4: h->mad(simd, r, k, r, r); break;
285 case 5: h->mad(simd, r, l, r, r); break;
286 case 6: h->mul(simd, r, r, 0.0519867f); break; // 0.051986694f
287 case 7: h->mul(simd | sat, r, r, abs(a)); break;
288 case 8: h->csel(simd | ge | f0[0], r, r, -r, a); break;
289 default: assert(!"invalid phase");
290 }
291}
292
293template <gpu_gen_t hw>
294void jit_eltwise_injector_f32<hw>::round_compute_fwd(
295 int simd, const ngen::GRF &r) {
296 h->rnde(simd, r, r);
297}
298
299template <gpu_gen_t hw>
300void jit_eltwise_injector_f32<hw>::swish_compute_fwd(
301 int simd, const ngen::GRF &r, int phase, int off) {
302 const float log2e = 1.442695f; // log_2(e)
303 auto temp = scratch_[off].f();
304 switch (phase) {
305 case 0: h->mul(simd, temp, r, -1.f * log2e * alpha_); break;
306 case 1: h->exp(simd, temp, temp); break;
307 case 2: h->add(simd, temp, temp, 1.f); break;
308 case 3: h->inv(simd, temp, temp); break;
309 case 4: h->mul(simd, r, r, temp); break;
310 default: assert(!"invalid phase");
311 }
312}
313
314template <gpu_gen_t hw>
315void jit_eltwise_injector_f32<hw>::linear_compute_fwd(
316 int simd, const ngen::GRF &r, int phase) {
317 switch (phase) {
318 case 0: h->mul(simd, r, r, alpha_); break;
319 case 1: h->add(simd, r, r, beta_); break;
320 default: assert(!"invalid phase");
321 }
322}
323
324template <gpu_gen_t hw>
325void jit_eltwise_injector_f32<hw>::clip_compute_fwd(
326 int simd, const ngen::GRF &r, int phase, float alpha, float beta) {
327 switch (phase) {
328 case 0: h->max_(simd, r, r, alpha); break;
329 case 1: h->min_(simd, r, r, beta); break;
330 default: assert(!"invalid phase");
331 }
332}
333
334template <gpu_gen_t hw>
335void jit_eltwise_injector_f32<hw>::gelu_tanh_compute_fwd(
336 int simd, const ngen::GRF &r, int phase, int off) {
337
338 const float k = 0.044715f;
339 const float sqrt_2_over_pi = 0.7978845f; // sqrt(2/pi)
340 const float log2e = 1.442695f; // log_2(e)
341
342 int msimd = simd;
343 if (hw == gpu_xe_hp)
344 msimd = 16; // workaround for intermittent hang with DPAS+EM
345
346 auto a = scratch_[off].f();
347 switch (phase) {
348 case 0: h->mul(simd, a, r, r); break;
349 case 1: h->mul(simd, a, a, k); break;
350 case 2: h->mad(simd, a, r, a, r); break;
351 case 3: h->mul(simd, a, a, -2 * sqrt_2_over_pi * log2e); break;
352 case 4: h->exp(msimd, a, a); break;
353 case 5: h->add(simd, a, a, 1.0f); break;
354 case 6: h->inv(msimd, a, a); break;
355 case 7: h->mul(simd, r, a, r); break;
356 default: assert(!"invalid phase");
357 }
358}
359
360template <gpu_gen_t hw>
361void jit_eltwise_injector_f32<hw>::logistic_compute_fwd(
362 int simd, const ngen::GRF &r, int phase) {
363 const float log2e = 1.442695f; // log_2(e)
364 switch (phase) {
365 case 0: h->mul(simd, r, r, -1.f * log2e); break;
366 case 1: h->exp(simd, r, r); break;
367 case 2: h->add(simd, r, r, 1.f); break;
368 case 3: h->inv(simd, r, r); break;
369 default: assert(!"invalid phase");
370 }
371}
372
373template <gpu_gen_t hw>
374void jit_eltwise_injector_f32<hw>::relu_prepare_bwd() {
375 auto neg_slope = scratch_[0].f(0);
376 auto pos_slope = scratch_[0].f(4);
377 h->mov(1, neg_slope, alpha_);
378 h->mov(1, pos_slope, 1.f);
379}
380
381template <gpu_gen_t hw>
382void jit_eltwise_injector_f32<hw>::relu_compute_bwd(
383 int simd, const ngen::GRF &r) {
384 auto neg_slope = scratch_[0].f(0);
385 auto pos_slope = scratch_[0].f(4);
386 h->csel(simd | le | f0[0], r, neg_slope, pos_slope, r);
387}
388
389template <gpu_gen_t hw>
390void jit_eltwise_injector_f32<hw>::abs_prepare_bwd() {
391 auto neg_one = scratch_[0].f(0);
392 auto pos_one = scratch_[0].f(4);
393 h->mov(1, neg_one, -1.f);
394 h->mov(1, pos_one, 1.f);
395}
396
397template <gpu_gen_t hw>
398void jit_eltwise_injector_f32<hw>::clip_prepare_bwd() {
399 auto pos_inf_imm = Immediate(std::numeric_limits<float>::infinity());
400 auto zero = scratch_[0].f(0);
401 auto one = scratch_[0].f(1);
402 auto pos_inf = scratch_[0].f(2);
403 h->mov(1, zero, 0.f);
404 h->mov(1, one, 1.f);
405 h->mov(1, pos_inf, pos_inf_imm);
406}
407
408template <gpu_gen_t hw>
409void jit_eltwise_injector_f32<hw>::tanh_prepare_fwd() {
410 auto one_half = scratch_[0].f(7);
411 h->mov(1, one_half, 0.5f);
412}
413
414template <gpu_gen_t hw>
415void jit_eltwise_injector_f32<hw>::tanh_prepare_fwd_compat() {
416 auto k = scratch_[0].f(4);
417 auto l = scratch_[0].f(5);
418 auto m = scratch_[0].f(6);
419 auto n = scratch_[0].f(7);
420 h->mov(1, k, 77.0954f); // 77.095392909578f
421 h->mov(1, l, -4435.55f); // -4435.54623970169f
422 h->mov(1, m, 17.06396f); // 17.06396485f
423 h->mov(1, n, -212.7724f); // -212.772646402036f
424}
425
426template <gpu_gen_t hw>
427void jit_eltwise_injector_f32<hw>::abs_compute_bwd(
428 int simd, const ngen::GRF &r, int phase) {
429 auto neg_one = scratch_[0].f(0);
430 auto pos_one = scratch_[0].f(4);
431 switch (phase) {
432 case 0: h->csel(simd | lt | f0[0], r, neg_one, r, r); break;
433 case 1: h->csel(simd | gt | f0[0], r, pos_one, r, r); break;
434 default: break;
435 }
436}
437
438template <gpu_gen_t hw>
439void jit_eltwise_injector_f32<hw>::square_compute_bwd(
440 int simd, const ngen::GRF &r) {
441 h->add(simd, r, r, r);
442}
443
444template <gpu_gen_t hw>
445void jit_eltwise_injector_f32<hw>::linear_compute_bwd(
446 int simd, const ngen::GRF &r) {
447 h->mov(simd, r, alpha_);
448}
449
450template <gpu_gen_t hw>
451void jit_eltwise_injector_f32<hw>::clip_compute_bwd(
452 int simd, const ngen::GRF &r, int phase, float alpha, float beta) {
453 auto zero = scratch_[0].f(0);
454 auto one = scratch_[0].f(1);
455 auto pos_inf = scratch_[0].f(2);
456 switch (phase) {
457 // r[i] = r[i] - alpha
458 case 0: h->add(simd, r, r, -alpha); break;
459 // r[i] <= 0 => r[i] = infinity
460 case 1: h->csel(simd | le | f0[0], r, pos_inf, r, r); break;
461 // r[i] = (r[i] + alpha) - beta
462 case 2: h->add(simd, r, r, alpha - beta); break;
463 // r[i] = (r[i] <= 0 ? 1 : 0)
464 case 3: h->csel(simd | le | f0[0], r, one, zero, r); break;
465 default: assert(!"invalid phase");
466 }
467}
468
template <gpu_gen_t hw>
void jit_eltwise_injector_f32<hw>::gelu_tanh_compute_bwd(
        int simd, const ngen::GRF &r, int phase, int off, int batch) {
    // Derivative of tanh-approximated GELU, x * s with s = 1 / (1 + e^(-2u))
    // and u = sqrt(2/pi) (x + k x^3), in 14 phases:
    //   d/dx = s * (1 + (1 - s) * 2 sqrt(2/pi) (x + 3 k x^3))
    // Working registers: a = scratch_[off], b = scratch_[off + batch].

    const float k = 0.044715f;
    const float sqrt_2_over_pi = 0.7978845f; // sqrt(2/pi)
    const float log2e = 1.442695f; // log_2(e)

    // Math-unit instructions are limited to SIMD16 on XeHP (same workaround
    // as in the forward pass).
    int msimd = simd;
    if (hw == gpu_xe_hp) msimd = 16;

    auto a = scratch_[off].f();
    auto b = scratch_[off + batch].f();
    switch (phase) {
        case 0: h->mul(simd, a, r, r); break; // a = x^2
        case 1: h->mul(simd, b, a, 3.0f * k); break; // b = 3k x^2
        case 2: h->mul(simd, a, a, k); break; // a = k x^2
        case 3: h->mad(simd, a, r, a, r); break; // a = x + k x^3
        case 4: h->mad(simd, b, r, b, r); break; // b = x + 3k x^3
        case 5: h->mul(simd, a, a, -2 * sqrt_2_over_pi * log2e); break; // -2u log2(e)
        case 6: h->mul(simd, b, b, 2 * sqrt_2_over_pi); break; // b = 2 sqrt(2/pi) (x + 3k x^3)
        case 7: h->exp(msimd, a, a); break; // a = e^(-2u)
        case 8: h->add(simd, r, a, 1.0f); break; // x is no longer needed
        case 9: h->inv(msimd, r, r); break; // r = s
        case 10: h->mul(simd, a, a, r); break; // a = e^(-2u) * s = 1 - s
        case 11: h->mul(simd, a, a, b); break;
        case 12: h->add(simd, a, a, 1.0f); break;
        case 13: h->mul(simd, r, r, a); break; // r = s * (1 + (1 - s) b)
        default: assert(!"invalid phase");
    }
}
500
501template <gpu_gen_t hw>
502void jit_eltwise_injector_f32<hw>::elu_compute_fwd(
503 int simd, const ngen::GRF &r, int phase, int off) {
504 auto temp = scratch_[off].f();
505 const float log2e = 1.442695f; // log_2(e)
506 switch (phase) {
507 case 0: h->mul(simd, temp, r, log2e); break;
508 case 1: h->exp(simd, temp, temp); break;
509 case 2: h->add(simd, temp, temp, -1.f); break;
510 case 3: h->mul(simd, temp, temp, alpha_); break;
511 case 4: h->csel(simd | le | f0[0], r, temp, r, r); break;
512 default: assert(!"invalid phase");
513 }
514}
515
516template <gpu_gen_t hw>
517void jit_eltwise_injector_f32<hw>::exp_compute_fwd(
518 int simd, const ngen::GRF &r, int phase) {
519 const float log2e = 1.442695f; // log_2(e)
520 switch (phase) {
521 case 0: h->mul(simd, r, r, log2e); break;
522 case 1: h->exp(simd, r, r); break;
523 default: assert(!"invalid phase");
524 }
525}
526
template <gpu_gen_t hw>
void jit_eltwise_injector_f32<hw>::gelu_erf_compute_fwd(
        int simd, const ngen::GRF &r, int phase, int off, int batch) {
    // GELU, erf form, gelu(x) = 0.5 x (1 + erf(x / sqrt(2))), in 25 phases.
    // erf is evaluated with the polynomial approximation
    //   1 - erf(z) ~= (a1 t + a2 t^2 + a3 t^3 + a4 t^4 + a5 t^5) e^(-z^2),
    //   t = 1 / (1 + p z),  z = |x| / sqrt(2)
    // (the constants below match Abramowitz & Stegun 7.1.26).
    // Working registers: scratch_[off + i * batch] for i = 0..3.
    auto temp = scratch_[off].f();
    auto at_accum = scratch_[off + batch].f();
    auto tpow = scratch_[off + 2 * batch].f();
    auto temp2 = scratch_[off + 3 * batch].f();
    const float log2e = 1.442695f; // log_2(e)
    const float reciproc_sqrt_2 = 0.707106769084930419921875f; // 1/sqrt(2)
    const float p = 0.3275911f;
    const float a1 = 0.254829592f;
    const float a2 = -0.284496736f;
    const float a3 = 1.421413741f;
    const float a4 = -1.453152027f;
    const float a5 = 1.061405429f;
    switch (phase) {
        // temp = t = 1 / (1 + p z)
        case 0: h->mul(simd, temp, abs(r), reciproc_sqrt_2); break;
        case 1: h->mul(simd, temp, temp, p); break;
        case 2: h->add(simd, temp, temp, 1.f); break;
        case 3: h->inv(simd, temp, temp); break;
        // at_accum = a1 t + a2 t^2 + ... + a5 t^5 (tpow holds t^k)
        case 4: h->mul(simd, at_accum, temp, a1); break;
        case 5: h->mul(simd, tpow, temp, temp); break;
        case 6: h->mul(simd, temp2, tpow, a2); break;
        case 7: h->add(simd, at_accum, temp2, at_accum); break;
        case 8: h->mul(simd, tpow, tpow, temp); break;
        case 9: h->mul(simd, temp2, tpow, a3); break;
        case 10: h->add(simd, at_accum, temp2, at_accum); break;
        case 11: h->mul(simd, tpow, tpow, temp); break;
        case 12: h->mul(simd, temp2, tpow, a4); break;
        case 13: h->add(simd, at_accum, temp2, at_accum); break;
        case 14: h->mul(simd, tpow, tpow, temp); break;
        case 15: h->mul(simd, temp2, tpow, a5); break;
        case 16: h->add(simd, at_accum, temp2, at_accum); break;
        // temp = e^(-x^2 / 2) (exp is base 2)
        case 17: h->mul(simd, temp, r, r); break;
        case 18: h->mul(simd, temp, temp, -log2e * 0.5f); break;
        case 19: h->exp(simd, temp, temp); break;
        // temp = 0.5 x (1 - erf(|x| / sqrt(2)))
        case 20: h->mul(simd, temp, temp, at_accum); break;
        case 21: h->mul(simd, temp, temp, r); break;
        case 22: h->mul(simd, temp, temp, 0.5f); break;
        // For x > 0: gelu = x - temp. For x <= 0, erf is odd, so gelu = temp.
        case 23: h->add(simd, temp2, r, -temp); break;
        case 24: h->csel(simd | le | f0[0], r, temp, temp2, r); break;
        default: assert(!"invalid phase");
    }
}
571
572template <gpu_gen_t hw>
573void jit_eltwise_injector_f32<hw>::hardsigmoid_compute_fwd(
574 int simd, const ngen::GRF &r, int phase, int off) {
575 switch (phase) {
576 case 0: h->mul(simd, r, r, alpha_); break;
577 case 1: h->add(simd, r, r, beta_); break;
578 case 2: h->min_(simd, r, r, 1.f); break;
579 case 3: h->max_(simd, r, r, 0.f); break;
580 default: assert(!"invalid phase");
581 }
582}
583
584template <gpu_gen_t hw>
585void jit_eltwise_injector_f32<hw>::hardswish_compute_fwd(
586 int simd, const ngen::GRF &r, int phase, int off) {
587 auto temp = scratch_[off].f();
588 switch (phase) {
589 case 0: h->mul(simd, temp, r, alpha_); break;
590 case 1: h->add(simd, temp, temp, beta_); break;
591 case 2: h->min_(simd, temp, temp, 1.f); break;
592 case 3: h->max_(simd, temp, temp, 0.f); break;
593 case 4: h->mul(simd, r, r, temp); break;
594 default: assert(!"invalid phase");
595 }
596}
597
598template <gpu_gen_t hw>
599void jit_eltwise_injector_f32<hw>::log_compute_fwd(
600 int simd, const ngen::GRF &r, int phase) {
601 const float reciproc_log2e = 1.f / 1.442695f; // 1 / log_2(e)
602 switch (phase) {
603 case 0: h->log(simd, r, r); break;
604 case 1: h->mul(simd, r, r, reciproc_log2e); break;
605 default: assert(!"invalid phase");
606 }
607}
608
609template <gpu_gen_t hw>
610void jit_eltwise_injector_f32<hw>::mish_compute_fwd(
611 int simd, const ngen::GRF &r, int phase, int off, int batch) {
612 auto temp = scratch_[off + batch].f();
613 auto temp2 = scratch_[off + 2 * batch].f();
614 const int srelu_phases = phase_count(alg_kind::eltwise_soft_relu);
615 const int tanh_phases = phase_count(alg_kind::eltwise_tanh);
616 // note tanh_compute_fwd_* clobbers scratch_[off] and scratch_[off + batch]
617 if (phase < srelu_phases)
618 soft_relu_compute_fwd_inner(simd, r, temp, temp2, phase, off, 1.f);
619 if (phase >= srelu_phases && phase < srelu_phases + tanh_phases) {
620 if (use_tanh_compat())
621 tanh_compute_fwd_compat(
622 simd, temp2, phase - srelu_phases, off, batch);
623 else
624 tanh_compute_fwd(simd, temp2, phase - srelu_phases, off, batch);
625 }
626 if (phase == srelu_phases + tanh_phases) h->mul(simd, r, r, temp2);
627 if (phase > srelu_phases + tanh_phases) assert(!"invalid phase");
628}
629
630template <gpu_gen_t hw>
631void jit_eltwise_injector_f32<hw>::pow_compute_fwd(
632 int simd, const ngen::GRF &r, int phase, int off) {
633 auto temp = scratch_[off].f();
634 switch (phase) {
635 case 0:
636 if ((long long int)beta_ == beta_) {
637 h->mov(simd, temp, abs(r));
638 } else {
639 h->mov(simd, temp, r);
640 }
641 break;
642 case 1: h->log(simd, temp, temp); break;
643 case 2: h->mul(simd, temp, temp, beta_); break;
644 case 3: h->exp(simd, temp, temp); break;
645 case 4:
646 if (((long long int)beta_) & 0x1)
647 h->csel(simd | lt | f0[0], temp, -temp, temp, r);
648 break;
649 case 5: h->mul(simd, r, temp, alpha_); break;
650 default: assert(!"invalid phase");
651 }
652}
653
template <gpu_gen_t hw>
void jit_eltwise_injector_f32<hw>::compute(const ngen::GRFRange &regs) {
    // Emits code that applies the configured eltwise operation in place to
    // every register in `regs`, followed by multiplication by scale_ when
    // scale_ != 1.
    //
    // Registers are processed in batches of at most max_batch_size(). Within
    // a batch, code is emitted phase-major: phase p is issued for every
    // register of the batch before phase p + 1. Presumably this interleaves
    // independent instructions to hide latency — TODO confirm.
    // Registers are handled two at a time (one instruction spanning two GRFs),
    // with a single-GRF tail when the batch length is odd. `ii`, the register
    // offset within the batch, is also the scratch offset passed to the
    // phase helpers.
    using namespace alg_kind;

    auto bmax = max_batch_size();
    auto phases = phase_count(alg_);

    for (int idx0 = 0; idx0 < regs.getLen(); idx0 += bmax) {
        auto batch = nstl::min(regs.getLen() - idx0, bmax);

        for (int phase = 0; phase < phases; phase++) {
            for (int ii = 0; ii < batch; ii += 2) {
                // One or two GRFs' worth of f32 lanes.
                int nreg = nstl::min(2, batch - ii);
                int simd = nreg * GRF::bytes(hw) / sizeof(float);
                auto base = regs[idx0 + ii].f();

                if (is_fwd_) {
                    switch (alg_) {
                        case eltwise_elu:
                        case eltwise_elu_use_dst_for_bwd:
                            elu_compute_fwd(simd, base, phase, ii);
                            break;
                        case eltwise_exp:
                        case eltwise_exp_use_dst_for_bwd:
                            exp_compute_fwd(simd, base, phase);
                            break;
                        case eltwise_gelu_erf:
                            gelu_erf_compute_fwd(simd, base, phase, ii, batch);
                            break;
                        case eltwise_hardsigmoid:
                            hardsigmoid_compute_fwd(simd, base, phase, ii);
                            break;
                        case eltwise_hardswish:
                            hardswish_compute_fwd(simd, base, phase, ii);
                            break;
                        case eltwise_log:
                            log_compute_fwd(simd, base, phase);
                            break;
                        case eltwise_mish:
                            mish_compute_fwd(simd, base, phase, ii, batch);
                            break;
                        case eltwise_pow:
                            pow_compute_fwd(simd, base, phase, ii);
                            break;
                        case eltwise_relu:
                        case eltwise_relu_use_dst_for_bwd:
                            // alpha_ == 0: single-phase max(r, 0).
                            if (alpha_ == 0.f)
                                relu_zero_ns_compute_fwd(simd, base);
                            else
                                relu_compute_fwd(simd, base, phase, ii);
                            break;
                        case eltwise_abs: abs_compute_fwd(simd, base); break;
                        case eltwise_soft_relu:
                            soft_relu_compute_fwd(simd, base, phase, ii);
                            break;
                        case eltwise_sqrt:
                        case eltwise_sqrt_use_dst_for_bwd:
                            sqrt_compute_fwd(simd, base);
                            break;
                        case eltwise_square:
                            square_compute_fwd(simd, base);
                            break;
                        case eltwise_tanh:
                        case eltwise_tanh_use_dst_for_bwd:
                            if (use_tanh_compat())
                                tanh_compute_fwd_compat(
                                        simd, base, phase, ii, batch);
                            else
                                tanh_compute_fwd(simd, base, phase, ii, batch);
                            break;
                        case eltwise_round:
                            round_compute_fwd(simd, base);
                            break;
                        case eltwise_swish:
                            swish_compute_fwd(simd, base, phase, ii);
                            break;
                        case eltwise_linear:
                            linear_compute_fwd(simd, base, phase);
                            break;
                        case eltwise_clip:
                        case eltwise_clip_v2:
                        case eltwise_clip_v2_use_dst_for_bwd:
                            clip_compute_fwd(simd, base, phase, alpha_, beta_);
                            break;
                        case eltwise_gelu_tanh:
                            gelu_tanh_compute_fwd(simd, base, phase, ii);
                            break;
                        case eltwise_logistic:
                        case eltwise_logistic_use_dst_for_bwd:
                            logistic_compute_fwd(simd, base, phase);
                            break;
                        default: assert(!"unsupported eltwise algorithm");
                    }
                } else {
                    switch (alg_) {
                        case eltwise_relu: relu_compute_bwd(simd, base); break;
                        case eltwise_abs:
                            abs_compute_bwd(simd, base, phase);
                            break;
                        case eltwise_square:
                            square_compute_bwd(simd, base);
                            break;
                        case eltwise_linear:
                            linear_compute_bwd(simd, base);
                            break;
                        case eltwise_clip:
                            clip_compute_bwd(simd, base, phase, alpha_, beta_);
                            break;
                        case eltwise_gelu_tanh:
                            gelu_tanh_compute_bwd(simd, base, phase, ii, batch);
                            break;
                        default: assert(!"unsupported eltwise algorithm");
                    }
                }
                // Apply scale once per register, right after its last phase.
                if (phase == phases - 1 && scale_ != 1.f) {
                    h->mul(simd, base, base, scale_);
                }
            }
        }
    }
}
776
777template <gpu_gen_t hw>
778void jit_eltwise_injector_f32<hw>::prepare() {
779 using namespace alg_kind;
780
781 assert(scratch_.getLen() >= min_scratch_regs());
782
783 if (is_fwd_) {
784 switch (alg_) {
785 case eltwise_mish:
786 case eltwise_tanh:
787 if (use_tanh_compat())
788 tanh_prepare_fwd_compat();
789 else
790 tanh_prepare_fwd();
791 break;
792 default: break;
793 }
794 } else {
795 switch (alg_) {
796 case eltwise_relu: relu_prepare_bwd(); break;
797 case eltwise_abs: abs_prepare_bwd(); break;
798 case eltwise_clip: clip_prepare_bwd(); break;
799 default: break;
800 }
801 }
802}
803
// Explicit instantiations of the injector for each supported GPU generation.
// The REG_*_ISA wrappers come from common/impl_registration.hpp. Presumably
// they drop an instantiation when that ISA is disabled in the build
// configuration — confirm there.
REG_GEN9_ISA(template struct jit_eltwise_injector_f32<gpu_gen9>);
REG_GEN11_ISA(template struct jit_eltwise_injector_f32<gpu_gen11>);
REG_XELP_ISA(template struct jit_eltwise_injector_f32<gpu_xe_lp>);
REG_XEHP_ISA(template struct jit_eltwise_injector_f32<gpu_xe_hp>);
REG_XEHPG_ISA(template struct jit_eltwise_injector_f32<gpu_xe_hpg>);
REG_XEHPC_ISA(template struct jit_eltwise_injector_f32<gpu_xe_hpc>);
810
811} // namespace jit
812} // namespace gpu
813} // namespace impl
814} // namespace dnnl
815