1/*******************************************************************************
2* Copyright 2020-2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_PRELU_JIT_PRELU_BACKWARD_HPP
18#define CPU_X64_PRELU_JIT_PRELU_BACKWARD_HPP
19
20#include <memory>
21#include <set>
22
23#include "common/primitive.hpp"
24#include "cpu/cpu_prelu_pd.hpp"
25
26#include "cpu/x64/prelu/jit_prelu_utils.hpp"
27
28namespace dnnl {
29namespace impl {
30namespace cpu {
31namespace x64 {
32
33class jit_prelu_backward_kernel_t;
34class jit_prelu_reduction_kernel_t;
35
36class jit_prelu_bwd_t : public primitive_t {
37public:
38 struct pd_t : public cpu_prelu_bwd_pd_t {
39 public:
40 using cpu_prelu_bwd_pd_t::cpu_prelu_bwd_pd_t;
41 DECLARE_COMMON_PD_T("jit_uni", jit_prelu_bwd_t);
42 status_t init(engine_t *engine);
43 int nthr_; // To not exceed the limit in execute used for set up.
44
45 private:
46 bool dt_supported(const std::set<data_type_t> &tensor_data_types) const
47 noexcept;
48 bool bcast_supported(const prelu::bcast &bcast,
49 const memory_desc_wrapper &src_diff_d,
50 const memory_desc_wrapper &weights_diff_d, int simd_w) const;
51 };
52
53 jit_prelu_bwd_t(const pd_t *apd);
54 ~jit_prelu_bwd_t();
55 status_t init(engine_t *engine) override;
56 status_t execute(const exec_ctx_t &ctx) const override;
57
58private:
59 using byte = unsigned char;
60 void fill_scratchpad_zeros(float *const scratchpad,
61 size_t thread_scratchpad_size, int nthr) const;
62 void scratchpad_to_diff_weights_reduction(float *scratchpad,
63 byte *weights_diff, size_t weights_diff_dt, dim_t C,
64 size_t reduction_blocks) const;
65 const pd_t *pd() const;
66 std::unique_ptr<jit_prelu_backward_kernel_t> kernel_;
67 std::unique_ptr<jit_prelu_reduction_kernel_t> reduction_kernel_;
68};
69
70} // namespace x64
71} // namespace cpu
72} // namespace impl
73} // namespace dnnl
74
75#endif
76