1/*******************************************************************************
2* Copyright 2020-2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_PRELU_JIT_PRELU_FORWARD_HPP
18#define CPU_X64_PRELU_JIT_PRELU_FORWARD_HPP
19
20#include <memory>
21#include <set>
22
23#include "common/primitive.hpp"
24#include "cpu/cpu_prelu_pd.hpp"
25
26namespace dnnl {
27namespace impl {
28namespace cpu {
29namespace x64 {
30
31class jit_prelu_forward_kernel_t;
32
33class jit_prelu_fwd_t : public primitive_t {
34public:
35 struct pd_t : public cpu_prelu_fwd_pd_t {
36 public:
37 using cpu_prelu_fwd_pd_t::cpu_prelu_fwd_pd_t;
38 DECLARE_COMMON_PD_T("jit_uni", jit_prelu_fwd_t);
39 status_t init(engine_t *engine);
40
41 private:
42 bool bcast_supported(const memory_desc_wrapper &src_d,
43 const memory_desc_wrapper &weights_d,
44 const memory_desc_wrapper &dst_d) const;
45 };
46
47 jit_prelu_fwd_t(const pd_t *apd);
48 ~jit_prelu_fwd_t();
49 status_t init(engine_t *engine) override;
50 status_t execute(const exec_ctx_t &ctx) const override;
51
52private:
53 const pd_t *pd() const;
54 std::unique_ptr<jit_prelu_forward_kernel_t> kernel_;
55};
56
57} // namespace x64
58} // namespace cpu
59} // namespace impl
60} // namespace dnnl
61
62#endif
63