1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_JIT_SSE41_CONVOLUTION_HPP
18#define CPU_X64_JIT_SSE41_CONVOLUTION_HPP
19
20#include "common/c_types_map.hpp"
21#include "common/dnnl_thread.hpp"
22#include "common/primitive.hpp"
23#include "common/utils.hpp"
24
25#include "cpu/cpu_convolution_pd.hpp"
26
27#include "cpu/x64/jit_primitive_conf.hpp"
28#include "cpu/x64/jit_sse41_conv_kernel_f32.hpp"
29
30namespace dnnl {
31namespace impl {
32namespace cpu {
33namespace x64 {
34
35struct jit_sse41_convolution_fwd_t : public primitive_t {
36 struct pd_t : public cpu_convolution_fwd_pd_t {
37 pd_t(const convolution_desc_t *adesc, const primitive_attr_t *attr,
38 const typename pd_t::base_class *hint_fwd_pd)
39 : cpu_convolution_fwd_pd_t(adesc, attr, hint_fwd_pd), jcp_() {}
40
41 DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit:", sse41, ""),
42 jit_sse41_convolution_fwd_t);
43
44 status_t init(engine_t *engine) {
45 using namespace data_type;
46 bool ok = is_fwd()
47 && set_default_alg_kind(alg_kind::convolution_direct)
48 && expect_data_types(f32, f32, f32, f32, f32)
49 && attr()->has_default_values(
50 primitive_attr_t::skip_mask_t::post_ops, f32)
51 && !has_zero_dim_memory() && set_default_formats()
52 && attr_.set_default_formats(dst_md(0)) == status::success;
53 if (!ok) return status::unimplemented;
54
55 CHECK(jit_sse41_conv_fwd_kernel_f32::init_conf(jcp_, *desc(),
56 *src_md(), *weights_md(), *dst_md(), *attr(),
57 dnnl_get_max_threads()));
58
59 return status::success;
60 }
61
62 jit_conv_conf_t jcp_;
63
64 protected:
65 bool set_default_formats() {
66 using namespace format_tag;
67
68 const memory_desc_wrapper src_d(&src_md_);
69 const memory_desc_wrapper dst_d(&dst_md_);
70
71 const auto dat_tag_nxc = utils::pick(ndims() - 3, nwc, nhwc, ndhwc);
72 const auto dat_tag_ncx = utils::pick(ndims() - 3, ncw, nchw, ncdhw);
73 const auto dat_tag_nCx8c
74 = utils::pick(ndims() - 3, nCw8c, nChw8c, nCdhw8c);
75 const auto curr_src_tag = src_d.matches_one_of_tag(
76 dat_tag_nxc, dat_tag_ncx, dat_tag_nCx8c);
77 const auto curr_dst_tag = dst_d.matches_one_of_tag(
78 dat_tag_nxc, dat_tag_ncx, dat_tag_nCx8c);
79 const auto is_data_layout_nxc
80 = IMPLICATION(curr_src_tag != dat_tag_nxc,
81 src_d.format_kind() == format_kind::any)
82 && IMPLICATION(curr_dst_tag != dat_tag_nxc,
83 dst_d.format_kind() == format_kind::any)
84 && utils::one_of(dat_tag_nxc, curr_src_tag, curr_dst_tag);
85 const bool flat = IC() == 3;
86 auto src_tag = is_data_layout_nxc
87 ? dat_tag_nxc
88 : flat ? dat_tag_ncx : dat_tag_nCx8c;
89 auto dst_tag = is_data_layout_nxc ? dat_tag_nxc : dat_tag_nCx8c;
90 auto wei_tag = with_groups()
91 ? utils::pick(2 * ndims() - 6 + flat, gOIw8i8o, gOwi8o,
92 gOIhw8i8o, gOhwi8o, gOIdhw8i8o, gOdhwi8o)
93 : utils::pick(2 * ndims() - 6 + flat, OIw8i8o, Owi8o,
94 OIhw8i8o, Ohwi8o, OIdhw8i8o, Odhwi8o);
95
96 return set_default_formats_common(src_tag, wei_tag, dst_tag);
97 }
98 };
99
100 jit_sse41_convolution_fwd_t(const pd_t *apd) : primitive_t(apd) {}
101
102 typedef typename prec_traits<data_type::f32>::type data_t;
103
104 status_t init(engine_t *engine) override {
105 CHECK(safe_ptr_assign(kernel_,
106 new jit_sse41_conv_fwd_kernel_f32(
107 pd()->jcp_, *pd()->attr(), *pd()->dst_md(0))));
108 return kernel_->create_kernel();
109 }
110
111 status_t execute(const exec_ctx_t &ctx) const override {
112 execute_forward(ctx);
113 return status::success;
114 }
115
116private:
117 void execute_forward(const exec_ctx_t &ctx) const;
118 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
119 std::unique_ptr<jit_sse41_conv_fwd_kernel_f32> kernel_;
120};
121
122} // namespace x64
123} // namespace cpu
124} // namespace impl
125} // namespace dnnl
126
127#endif
128