1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef GPU_JIT_JIT_POST_OP_INJECTOR_HPP |
18 | #define GPU_JIT_JIT_POST_OP_INJECTOR_HPP |
19 | |
20 | #include "common/primitive_attr.hpp" |
21 | #include "gpu/jit/jit_eltwise_injector.hpp" |
22 | #include "gpu/jit/jit_generator.hpp" |
23 | |
24 | namespace dnnl { |
25 | namespace impl { |
26 | namespace gpu { |
27 | namespace jit { |
28 | |
29 | inline bool jit_post_op_injector_is_supported( |
30 | const post_ops_t &post_ops, bool skip_sum) { |
31 | bool is_supported = true; |
32 | for (int idx = 0; idx < post_ops.len(); ++idx) { |
33 | const auto &po = post_ops.entry_[idx]; |
34 | if (po.is_binary()) |
35 | is_supported &= false; |
36 | else if (po.is_convolution()) |
37 | is_supported &= false; |
38 | else if (po.is_eltwise()) |
39 | is_supported |
40 | &= jit_eltwise_injector_f32_is_supported(po.eltwise.alg); |
41 | else if (po.is_sum(false, false)) |
42 | is_supported &= skip_sum; |
43 | } |
44 | return is_supported; |
45 | } |
46 | |
47 | template <gpu_gen_t hw> |
48 | struct jit_post_op_injector { |
49 | jit_post_op_injector(jit_generator<hw> *host, data_type_t accumulator_type, |
50 | const post_ops_t &post_ops, int eu_count, |
51 | const ngen::GRFRange &scratch = ngen::GRFRange(), |
52 | bool is_fwd = true) |
53 | : post_ops_(post_ops), is_fwd_(is_fwd), scratch_(scratch) { |
54 | assert(accumulator_type == data_type_t::dnnl_f32); |
55 | workers_.reserve(post_ops.len()); |
56 | for (int idx = 0; idx < post_ops.len(); ++idx) { |
57 | const auto &po = post_ops.entry_[idx]; |
58 | if (po.is_eltwise()) |
59 | workers_.emplace_back(host, po.eltwise.alg, po.eltwise.alpha, |
60 | po.eltwise.beta, po.eltwise.scale, eu_count, scratch, |
61 | is_fwd); |
62 | } |
63 | } |
64 | |
65 | int min_scratch_regs(); |
66 | int preferred_scratch_regs(); |
67 | void set_scratch(const ngen::GRFRange &scratch); |
68 | |
69 | void compute(const ngen::GRF ®) { compute(reg - reg); } |
70 | void compute(const ngen::GRFRange ®s); |
71 | |
72 | private: |
73 | post_ops_t post_ops_; |
74 | std::vector<jit_eltwise_injector_f32<hw>> workers_; |
75 | bool is_fwd_; |
76 | ngen::GRFRange scratch_; |
77 | }; |
78 | |
79 | } // namespace jit |
80 | } // namespace gpu |
81 | } // namespace impl |
82 | } // namespace dnnl |
83 | |
84 | #endif // GPU_JIT_JIT_POST_OP_INJECTOR_HPP |
85 | |