1/*******************************************************************************
2* Copyright 2019-2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_JIT_UNI_DW_CONV_KERNEL_UTILS_HPP
18#define CPU_X64_JIT_UNI_DW_CONV_KERNEL_UTILS_HPP
19
#include <memory>

#include "common/c_types_map.hpp"
#include "common/memory_tracking.hpp"
#include "common/nstl.hpp"
#include "common/type_helpers.hpp"
#include "common/utils.hpp"

#include "cpu/x64/jit_generator.hpp"
#include "cpu/x64/jit_primitive_conf.hpp"

#include "cpu/x64/jit_avx512_core_bf16_dw_conv_kernel.hpp"
#include "cpu/x64/jit_uni_dw_conv_kernel_f32.hpp"
31
32namespace dnnl {
33namespace impl {
34namespace cpu {
35namespace x64 {
36
37template <cpu_isa_t isa, data_type_t kernel_dt>
38struct jit_uni_dw_conv_fwd_kernel {
39
40 jit_uni_dw_conv_fwd_kernel(
41 const jit_conv_conf_t &ajcp, const memory_desc_t &dst_md) {
42 ker_ = new jit_kernel_t(ajcp, dst_md);
43 }
44
45 status_t create_kernel() { return ker_->create_kernel(); }
46 ~jit_uni_dw_conv_fwd_kernel() { delete ker_; }
47
48 static status_t init_conf(jit_conv_conf_t &jcp,
49 const convolution_desc_t &cd, memory_desc_t &src_md,
50 memory_desc_t &weights_md, memory_desc_t &bias_md,
51 memory_desc_t &dst_md, primitive_attr_t &attr);
52
53 static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
54 const jit_conv_conf_t &jcp);
55
56 jit_generator *ker() const { return ker_; }
57 void operator()(const jit_conv_call_s *p) const { (*ker_)(p); }
58
59private:
60 constexpr static bool ker_condition_
61 = isa == avx512_core && kernel_dt == data_type::bf16;
62 using jit_kernel_t = typename utils::conditional<ker_condition_,
63 jit_avx512_dw_conv_fwd_kernel_bf16,
64 jit_uni_dw_conv_fwd_kernel_f32<isa>>::type;
65 jit_kernel_t *ker_;
66};
67
68template <cpu_isa_t isa, data_type_t kernel_dt>
69struct jit_uni_dw_conv_bwd_data_kernel {
70
71 jit_uni_dw_conv_bwd_data_kernel(const jit_conv_conf_t &ajcp)
72 : ker_(nullptr) {
73 ker_ = new jit_kernel_t(ajcp);
74 }
75
76 status_t create_kernel() { return ker_->create_kernel(); }
77 ~jit_uni_dw_conv_bwd_data_kernel() { delete ker_; }
78
79 static status_t init_conf(jit_conv_conf_t &jcp,
80 const convolution_desc_t &cd, memory_desc_t &diff_src_md,
81 memory_desc_t &weights_md, memory_desc_t &diff_dst_md);
82
83 static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
84 const jit_conv_conf_t &jcp);
85
86 void operator()(const jit_conv_call_s *p) const { (*ker_)(p); }
87
88private:
89 using jit_kernel_t = typename utils::conditional<isa == avx512_core
90 && kernel_dt == data_type::bf16,
91 jit_avx512_dw_conv_bwd_data_kernel_bf16,
92 jit_uni_dw_conv_bwd_data_kernel_f32<isa>>::type;
93 jit_kernel_t *ker_;
94
95 DNNL_DISALLOW_COPY_AND_ASSIGN(jit_uni_dw_conv_bwd_data_kernel);
96};
97
98template <cpu_isa_t isa, data_type_t kernel_dt>
99struct jit_uni_dw_conv_bwd_weights_kernel {
100
101 jit_uni_dw_conv_bwd_weights_kernel(const jit_conv_conf_t &ajcp)
102 : ker_(nullptr) {
103 ker_ = new jit_kernel_t(ajcp);
104 }
105
106 status_t create_kernel() { return ker_->create_kernel(); }
107
108 ~jit_uni_dw_conv_bwd_weights_kernel() { delete ker_; }
109
110 static status_t init_conf(jit_conv_conf_t &jcp,
111 const convolution_desc_t &cd, memory_desc_t &src_md,
112 memory_desc_t &diff_weights_md, memory_desc_t &diff_bias_md,
113 memory_desc_t &diff_dst_md, int nthreads);
114
115 static void init_scratchpad(memory_tracking::registrar_t &scratchpad,
116 const jit_conv_conf_t &jcp);
117
118 static void partition_nthr_nxc(
119 jit_conv_conf_t &jcp, int nthreads, bool prioritize_threading);
120 static void balance(jit_conv_conf_t &jcp, int nthreads);
121
122 void operator()(const jit_dw_conv_call_s *p) const { (*ker_)(p); }
123
124private:
125 using jit_kernel_t = typename utils::conditional<isa == avx512_core
126 && kernel_dt == data_type::bf16,
127 jit_avx512_dw_conv_bwd_weights_kernel_bf16,
128 jit_uni_dw_conv_bwd_weights_kernel_f32<isa>>::type;
129 jit_kernel_t *ker_;
130};
131
132} // namespace x64
133} // namespace cpu
134} // namespace impl
135} // namespace dnnl
136#endif /* CPU_X64_JIT_UNI_DW_CONV_KERNEL_UTILS_HPP */
137