1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_JIT_JIT_ELTWISE_INJECTOR_HPP |
18 | #define GPU_JIT_JIT_ELTWISE_INJECTOR_HPP |
19 | |
20 | #include <assert.h> |
21 | |
22 | #include "common/c_types_map.hpp" |
23 | #include "common/utils.hpp" |
24 | #include "gpu/jit/jit_generator.hpp" |
25 | |
26 | namespace dnnl { |
27 | namespace impl { |
28 | namespace gpu { |
29 | namespace jit { |
30 | |
31 | inline bool jit_eltwise_injector_f32_is_supported(alg_kind_t alg) { |
32 | using namespace alg_kind; |
33 | return utils::one_of(alg, eltwise_elu, eltwise_elu_use_dst_for_bwd, |
34 | eltwise_exp, eltwise_exp_use_dst_for_bwd, eltwise_gelu_tanh, |
35 | eltwise_gelu_erf, eltwise_hardsigmoid, eltwise_hardswish, |
36 | eltwise_log, eltwise_mish, eltwise_pow, eltwise_relu, |
37 | eltwise_relu_use_dst_for_bwd, eltwise_soft_relu, eltwise_sqrt, |
38 | eltwise_sqrt_use_dst_for_bwd, eltwise_square, eltwise_swish, |
39 | eltwise_tanh, eltwise_tanh_use_dst_for_bwd, eltwise_abs, |
40 | eltwise_round, eltwise_linear, eltwise_clip, eltwise_clip_v2, |
41 | eltwise_clip_v2_use_dst_for_bwd, eltwise_logistic, |
42 | eltwise_logistic_use_dst_for_bwd); |
43 | } |
44 | |
45 | template <gpu_gen_t hw> |
46 | struct jit_eltwise_injector_f32 { |
47 | jit_eltwise_injector_f32(jit_generator<hw> *host, alg_kind_t alg, |
48 | float alpha, float beta, float scale, int eu_count, |
49 | const ngen::GRFRange &scratch = ngen::GRFRange(), |
50 | bool is_fwd = true) |
51 | : alg_(alg) |
52 | , alpha_(alpha) |
53 | , beta_(beta) |
54 | , scale_(scale) |
55 | , is_fwd_(is_fwd) |
56 | , eu_count_(eu_count) |
57 | , h(host) |
58 | , scratch_(scratch) { |
59 | |
60 | assert(jit_eltwise_injector_f32_is_supported(alg_)); |
61 | assert(scratch_.isEmpty() || (scratch_.getLen() >= min_scratch_regs())); |
62 | } |
63 | |
64 | int min_scratch_regs(); |
65 | int preferred_scratch_regs(); |
66 | void set_scratch(const ngen::GRFRange &scratch) { scratch_ = scratch; } |
67 | |
68 | void prepare(); |
69 | void compute(const ngen::GRF ®) { compute(reg - reg); } |
70 | void compute(const ngen::GRFRange ®s); |
71 | |
72 | private: |
73 | const alg_kind_t alg_; |
74 | const float alpha_; |
75 | const float beta_; |
76 | const float scale_; |
77 | const bool is_fwd_; |
78 | |
79 | const int eu_count_; |
80 | |
81 | jit_generator<hw> *h; |
82 | |
83 | ngen::GRFRange scratch_; |
84 | |
85 | bool is_gpu(ngen::HW arg_hw, int arg_eu_count) const { |
86 | return (hw == arg_hw) && (eu_count_ == arg_eu_count); |
87 | } |
88 | bool use_tanh_compat() const { return false; } |
89 | |
90 | int max_batch_size(); |
91 | int phase_count(alg_kind_t alg); |
92 | |
93 | void relu_prepare_bwd(); |
94 | void abs_prepare_bwd(); |
95 | void clip_prepare_bwd(); |
96 | void tanh_prepare_fwd(); |
97 | void tanh_prepare_fwd_compat(); |
98 | |
99 | void relu_zero_ns_compute_fwd(int simd, const ngen::GRF &r); |
100 | void relu_compute_fwd(int simd, const ngen::GRF &r, int phase, int off); |
101 | void abs_compute_fwd(int simd, const ngen::GRF &r); |
102 | void exp_compute_fwd(int simd, const ngen::GRF &r, int phase); |
103 | void elu_compute_fwd(int simd, const ngen::GRF &r, int phase, int off); |
104 | void gelu_erf_compute_fwd( |
105 | int simd, const ngen::GRF &r, int phase, int off, int batch); |
106 | void hardsigmoid_compute_fwd( |
107 | int simd, const ngen::GRF &r, int phase, int off); |
108 | void hardswish_compute_fwd( |
109 | int simd, const ngen::GRF &r, int phase, int off); |
110 | void log_compute_fwd(int simd, const ngen::GRF &r, int phase); |
111 | void mish_compute_fwd( |
112 | int simd, const ngen::GRF &r, int phase, int off, int batch); |
113 | void pow_compute_fwd(int simd, const ngen::GRF &r, int phase, int off); |
114 | void soft_relu_compute_fwd_inner(int simd, const ngen::GRF &input, |
115 | const ngen::GRF &temp, const ngen::GRF &dest, int phase, int off, |
116 | float alpha); |
117 | void soft_relu_compute_fwd( |
118 | int simd, const ngen::GRF &r, int phase, int off); |
119 | void sqrt_compute_fwd(int simd, const ngen::GRF &r); |
120 | void square_compute_fwd(int simd, const ngen::GRF &r); |
121 | void round_compute_fwd(int simd, const ngen::GRF &r); |
122 | void swish_compute_fwd(int simd, const ngen::GRF &r, int phase, int off); |
123 | void tanh_compute_fwd( |
124 | int simd, const ngen::GRF &r, int phase, int off, int batch); |
125 | void tanh_compute_fwd_compat( |
126 | int simd, const ngen::GRF &r, int phase, int off, int batch); |
127 | void linear_compute_fwd(int simd, const ngen::GRF &r, int phase); |
128 | void clip_compute_fwd( |
129 | int simd, const ngen::GRF &r, int phase, float alpha, float beta); |
130 | void gelu_tanh_compute_fwd( |
131 | int simd, const ngen::GRF &r, int phase, int off); |
132 | void logistic_compute_fwd(int simd, const ngen::GRF &r, int phase); |
133 | |
134 | void relu_compute_bwd(int simd, const ngen::GRF &r); |
135 | void abs_compute_bwd(int simd, const ngen::GRF &r, int phase); |
136 | void square_compute_bwd(int simd, const ngen::GRF &r); |
137 | void linear_compute_bwd(int simd, const ngen::GRF &r); |
138 | void clip_compute_bwd( |
139 | int simd, const ngen::GRF &r, int phase, float alpha, float beta); |
140 | void gelu_tanh_compute_bwd( |
141 | int simd, const ngen::GRF &r, int phase, int off, int batch); |
142 | |
143 | const ngen::InstructionModifier le = jit_generator<hw>::le; |
144 | const ngen::InstructionModifier lt = jit_generator<hw>::lt; |
145 | const ngen::InstructionModifier ge = jit_generator<hw>::ge; |
146 | const ngen::InstructionModifier gt = jit_generator<hw>::gt; |
147 | const ngen::InstructionModifier eq = jit_generator<hw>::eq; |
148 | const ngen::InstructionModifier sat = jit_generator<hw>::sat; |
149 | const ngen::FlagRegister f0 = jit_generator<hw>::f0; |
150 | }; |
151 | |
152 | } // namespace jit |
153 | } // namespace gpu |
154 | } // namespace impl |
155 | } // namespace dnnl |
156 | |
157 | #endif // GPU_JIT_JIT_ELTWISE_INJECTOR_HPP |
158 | |