1/*******************************************************************************
2 * Copyright 2021-2022 Intel Corporation
3 *
4 * Licensed under the Apache License, Version 2.0 (the "License");
5 * you may not use this file except in compliance with the License.
6 * You may obtain a copy of the License at
7 *
8 * http://www.apache.org/licenses/LICENSE-2.0
9 *
10 * Unless required by applicable law or agreed to in writing, software
11 * distributed under the License is distributed on an "AS IS" BASIS,
12 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13 * See the License for the specific language governing permissions and
14 * limitations under the License.
15 *******************************************************************************/
16
17#ifndef GPU_JIT_JIT_POST_OP_INJECTOR_HPP
18#define GPU_JIT_JIT_POST_OP_INJECTOR_HPP
19
20#include "common/primitive_attr.hpp"
21#include "gpu/jit/jit_eltwise_injector.hpp"
22#include "gpu/jit/jit_generator.hpp"
23
24namespace dnnl {
25namespace impl {
26namespace gpu {
27namespace jit {
28
29inline bool jit_post_op_injector_is_supported(
30 const post_ops_t &post_ops, bool skip_sum) {
31 bool is_supported = true;
32 for (int idx = 0; idx < post_ops.len(); ++idx) {
33 const auto &po = post_ops.entry_[idx];
34 if (po.is_binary())
35 is_supported &= false;
36 else if (po.is_convolution())
37 is_supported &= false;
38 else if (po.is_eltwise())
39 is_supported
40 &= jit_eltwise_injector_f32_is_supported(po.eltwise.alg);
41 else if (po.is_sum(false, false))
42 is_supported &= skip_sum;
43 }
44 return is_supported;
45}
46
47template <gpu_gen_t hw>
48struct jit_post_op_injector {
49 jit_post_op_injector(jit_generator<hw> *host, data_type_t accumulator_type,
50 const post_ops_t &post_ops, int eu_count,
51 const ngen::GRFRange &scratch = ngen::GRFRange(),
52 bool is_fwd = true)
53 : post_ops_(post_ops), is_fwd_(is_fwd), scratch_(scratch) {
54 assert(accumulator_type == data_type_t::dnnl_f32);
55 workers_.reserve(post_ops.len());
56 for (int idx = 0; idx < post_ops.len(); ++idx) {
57 const auto &po = post_ops.entry_[idx];
58 if (po.is_eltwise())
59 workers_.emplace_back(host, po.eltwise.alg, po.eltwise.alpha,
60 po.eltwise.beta, po.eltwise.scale, eu_count, scratch,
61 is_fwd);
62 }
63 }
64
65 int min_scratch_regs();
66 int preferred_scratch_regs();
67 void set_scratch(const ngen::GRFRange &scratch);
68
69 void compute(const ngen::GRF &reg) { compute(reg - reg); }
70 void compute(const ngen::GRFRange &regs);
71
72private:
73 post_ops_t post_ops_;
74 std::vector<jit_eltwise_injector_f32<hw>> workers_;
75 bool is_fwd_;
76 ngen::GRFRange scratch_;
77};
78
79} // namespace jit
80} // namespace gpu
81} // namespace impl
82} // namespace dnnl
83
84#endif // GPU_JIT_JIT_POST_OP_INJECTOR_HPP
85