/*******************************************************************************
* Copyright 2017-2021 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/
16
17#ifndef CPU_X64_JIT_UNI_I8I8_POOLING_HPP
18#define CPU_X64_JIT_UNI_I8I8_POOLING_HPP
19
20#include "common/c_types_map.hpp"
21#include "common/primitive.hpp"
22#include "common/type_helpers.hpp"
23
24#include "cpu/cpu_pooling_pd.hpp"
25
26#include "cpu/x64/cpu_isa_traits.hpp"
27#include "cpu/x64/jit_primitive_conf.hpp"
28
29namespace dnnl {
30namespace impl {
31namespace cpu {
32namespace x64 {
33
34template <cpu_isa_t isa>
35struct jit_uni_i8i8_pooling_fwd_ker_t;
36
37template <cpu_isa_t isa>
38struct jit_uni_i8i8_pooling_fwd_t : public primitive_t {
39 struct pd_t : public cpu_pooling_fwd_pd_t {
40 using cpu_pooling_fwd_pd_t::cpu_pooling_fwd_pd_t;
41
42 DECLARE_COMMON_PD_T(JIT_IMPL_NAME_HELPER("jit_int:", isa, ""),
43 jit_uni_i8i8_pooling_fwd_t);
44
45 status_t init(engine_t *engine) {
46 using namespace format_tag;
47 bool ok = mayiuse(isa) && utils::one_of(ndims(), 3, 4, 5)
48 && desc()->prop_kind == prop_kind::forward_inference
49 && utils::one_of(desc()->alg_kind, alg_kind::pooling_max,
50 alg_kind::pooling_avg_include_padding,
51 alg_kind::pooling_avg_exclude_padding)
52 && utils::one_of(src_md()->data_type, data_type::s32,
53 data_type::s8, data_type::u8)
54 && src_md()->data_type == dst_md()->data_type
55 && !is_dilated()
56 && attr()->has_default_values(
57 primitive_attr_t::skip_mask_t::post_ops)
58 && set_default_params() == status::success
59 && memory_desc_matches_one_of_tag(
60 *src_md(), nwc, nhwc, ndhwc)
61 != format_tag::undef
62 && memory_desc_matches_one_of_tag(
63 *dst_md(), nwc, nhwc, ndhwc)
64 != format_tag::undef
65 && attr_.set_default_formats(dst_md(0)) == status::success;
66 if (!ok) return status::unimplemented;
67
68 CHECK(jit_conf());
69
70 return status::success;
71 }
72
73 jit_pool_conf_t jpp_;
74
75 protected:
76 status_t jit_conf();
77 };
78
79 jit_uni_i8i8_pooling_fwd_t(const pd_t *apd);
80 ~jit_uni_i8i8_pooling_fwd_t();
81
82 status_t init(engine_t *engine) override;
83
84 status_t execute(const exec_ctx_t &ctx) const override {
85 return execute_forward(ctx);
86 }
87
88private:
89 status_t execute_forward(const exec_ctx_t &ctx) const;
90 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
91
92 std::unique_ptr<jit_uni_i8i8_pooling_fwd_ker_t<isa>> ker_;
93};
94
95} // namespace x64
96} // namespace cpu
97} // namespace impl
98} // namespace dnnl
99
100#endif
101