1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef GPU_JIT_JIT_ELTWISE_INJECTOR_HPP
18#define GPU_JIT_JIT_ELTWISE_INJECTOR_HPP
19
20#include <assert.h>
21
22#include "common/c_types_map.hpp"
23#include "common/utils.hpp"
24#include "gpu/jit/jit_generator.hpp"
25
26namespace dnnl {
27namespace impl {
28namespace gpu {
29namespace jit {
30
31inline bool jit_eltwise_injector_f32_is_supported(alg_kind_t alg) {
32 using namespace alg_kind;
33 return utils::one_of(alg, eltwise_elu, eltwise_elu_use_dst_for_bwd,
34 eltwise_exp, eltwise_exp_use_dst_for_bwd, eltwise_gelu_tanh,
35 eltwise_gelu_erf, eltwise_hardsigmoid, eltwise_hardswish,
36 eltwise_log, eltwise_mish, eltwise_pow, eltwise_relu,
37 eltwise_relu_use_dst_for_bwd, eltwise_soft_relu, eltwise_sqrt,
38 eltwise_sqrt_use_dst_for_bwd, eltwise_square, eltwise_swish,
39 eltwise_tanh, eltwise_tanh_use_dst_for_bwd, eltwise_abs,
40 eltwise_round, eltwise_linear, eltwise_clip, eltwise_clip_v2,
41 eltwise_clip_v2_use_dst_for_bwd, eltwise_logistic,
42 eltwise_logistic_use_dst_for_bwd);
43}
44
45template <gpu_gen_t hw>
46struct jit_eltwise_injector_f32 {
47 jit_eltwise_injector_f32(jit_generator<hw> *host, alg_kind_t alg,
48 float alpha, float beta, float scale, int eu_count,
49 const ngen::GRFRange &scratch = ngen::GRFRange(),
50 bool is_fwd = true)
51 : alg_(alg)
52 , alpha_(alpha)
53 , beta_(beta)
54 , scale_(scale)
55 , is_fwd_(is_fwd)
56 , eu_count_(eu_count)
57 , h(host)
58 , scratch_(scratch) {
59
60 assert(jit_eltwise_injector_f32_is_supported(alg_));
61 assert(scratch_.isEmpty() || (scratch_.getLen() >= min_scratch_regs()));
62 }
63
64 int min_scratch_regs();
65 int preferred_scratch_regs();
66 void set_scratch(const ngen::GRFRange &scratch) { scratch_ = scratch; }
67
68 void prepare();
69 void compute(const ngen::GRF &reg) { compute(reg - reg); }
70 void compute(const ngen::GRFRange &regs);
71
72private:
73 const alg_kind_t alg_;
74 const float alpha_;
75 const float beta_;
76 const float scale_;
77 const bool is_fwd_;
78
79 const int eu_count_;
80
81 jit_generator<hw> *h;
82
83 ngen::GRFRange scratch_;
84
85 bool is_gpu(ngen::HW arg_hw, int arg_eu_count) const {
86 return (hw == arg_hw) && (eu_count_ == arg_eu_count);
87 }
88 bool use_tanh_compat() const { return false; }
89
90 int max_batch_size();
91 int phase_count(alg_kind_t alg);
92
93 void relu_prepare_bwd();
94 void abs_prepare_bwd();
95 void clip_prepare_bwd();
96 void tanh_prepare_fwd();
97 void tanh_prepare_fwd_compat();
98
99 void relu_zero_ns_compute_fwd(int simd, const ngen::GRF &r);
100 void relu_compute_fwd(int simd, const ngen::GRF &r, int phase, int off);
101 void abs_compute_fwd(int simd, const ngen::GRF &r);
102 void exp_compute_fwd(int simd, const ngen::GRF &r, int phase);
103 void elu_compute_fwd(int simd, const ngen::GRF &r, int phase, int off);
104 void gelu_erf_compute_fwd(
105 int simd, const ngen::GRF &r, int phase, int off, int batch);
106 void hardsigmoid_compute_fwd(
107 int simd, const ngen::GRF &r, int phase, int off);
108 void hardswish_compute_fwd(
109 int simd, const ngen::GRF &r, int phase, int off);
110 void log_compute_fwd(int simd, const ngen::GRF &r, int phase);
111 void mish_compute_fwd(
112 int simd, const ngen::GRF &r, int phase, int off, int batch);
113 void pow_compute_fwd(int simd, const ngen::GRF &r, int phase, int off);
114 void soft_relu_compute_fwd_inner(int simd, const ngen::GRF &input,
115 const ngen::GRF &temp, const ngen::GRF &dest, int phase, int off,
116 float alpha);
117 void soft_relu_compute_fwd(
118 int simd, const ngen::GRF &r, int phase, int off);
119 void sqrt_compute_fwd(int simd, const ngen::GRF &r);
120 void square_compute_fwd(int simd, const ngen::GRF &r);
121 void round_compute_fwd(int simd, const ngen::GRF &r);
122 void swish_compute_fwd(int simd, const ngen::GRF &r, int phase, int off);
123 void tanh_compute_fwd(
124 int simd, const ngen::GRF &r, int phase, int off, int batch);
125 void tanh_compute_fwd_compat(
126 int simd, const ngen::GRF &r, int phase, int off, int batch);
127 void linear_compute_fwd(int simd, const ngen::GRF &r, int phase);
128 void clip_compute_fwd(
129 int simd, const ngen::GRF &r, int phase, float alpha, float beta);
130 void gelu_tanh_compute_fwd(
131 int simd, const ngen::GRF &r, int phase, int off);
132 void logistic_compute_fwd(int simd, const ngen::GRF &r, int phase);
133
134 void relu_compute_bwd(int simd, const ngen::GRF &r);
135 void abs_compute_bwd(int simd, const ngen::GRF &r, int phase);
136 void square_compute_bwd(int simd, const ngen::GRF &r);
137 void linear_compute_bwd(int simd, const ngen::GRF &r);
138 void clip_compute_bwd(
139 int simd, const ngen::GRF &r, int phase, float alpha, float beta);
140 void gelu_tanh_compute_bwd(
141 int simd, const ngen::GRF &r, int phase, int off, int batch);
142
143 const ngen::InstructionModifier le = jit_generator<hw>::le;
144 const ngen::InstructionModifier lt = jit_generator<hw>::lt;
145 const ngen::InstructionModifier ge = jit_generator<hw>::ge;
146 const ngen::InstructionModifier gt = jit_generator<hw>::gt;
147 const ngen::InstructionModifier eq = jit_generator<hw>::eq;
148 const ngen::InstructionModifier sat = jit_generator<hw>::sat;
149 const ngen::FlagRegister f0 = jit_generator<hw>::f0;
150};
151
152} // namespace jit
153} // namespace gpu
154} // namespace impl
155} // namespace dnnl
156
157#endif // GPU_JIT_JIT_ELTWISE_INJECTOR_HPP
158