1/*******************************************************************************
2* Copyright 2018-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef CPU_X64_JIT_AVX512_CORE_X8S8S32X_1X1_DECONVOLUTION_HPP
18#define CPU_X64_JIT_AVX512_CORE_X8S8S32X_1X1_DECONVOLUTION_HPP
19
20#include "common/c_types_map.hpp"
21#include "common/dnnl_thread.hpp"
22#include "common/primitive.hpp"
23#include "common/primitive_desc_iterator.hpp"
24#include "common/type_helpers.hpp"
25#include "common/utils.hpp"
26
27#include "cpu/cpu_convolution_pd.hpp"
28#include "cpu/cpu_deconvolution_pd.hpp"
29#include "cpu/zero_point_utils.hpp"
30
31#include "cpu/x64/jit_avx512_core_x8s8s32x_1x1_convolution.hpp"
32#include "cpu/x64/jit_uni_1x1_conv_utils.hpp"
33
34namespace dnnl {
35namespace impl {
36namespace cpu {
37namespace x64 {
38
39struct jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t : public primitive_t {
40 struct pd_t : public cpu_deconvolution_fwd_pd_t {
41 pd_t(const deconvolution_desc_t *adesc, const primitive_attr_t *attr,
42 const deconvolution_fwd_pd_t *hint_fwd_pd)
43 : cpu_deconvolution_fwd_pd_t(adesc, attr, hint_fwd_pd) {}
44
45 pd_t(const pd_t &other)
46 : cpu_deconvolution_fwd_pd_t(other)
47 , conv_pd_(other.conv_pd_->clone()) {}
48
49 ~pd_t() = default;
50
51 DECLARE_COMMON_PD_T(conv_pd_->name(),
52 jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t);
53
54 status_t init_convolution(engine_t *engine) {
55 convolution_desc_t cd;
56
57 auto dd = desc();
58 CHECK(conv_desc_init(&cd, prop_kind::forward_training,
59 alg_kind::convolution_direct, &(dd->src_desc),
60 &(dd->weights_desc), &(dd->bias_desc), &(dd->dst_desc),
61 dd->strides, dd->dilates, dd->padding[0], dd->padding[1]));
62
63 primitive_attr_t conv_attr(*attr());
64 if (!conv_attr.is_initialized()) return status::out_of_memory;
65 primitive_desc_iterator_t it(
66 engine, (op_desc_t *)&cd, &conv_attr, nullptr);
67 if (!it.is_initialized()) return status::out_of_memory;
68
69 while (++it != it.end()) {
70 conv_pd_ = *it;
71 // XXX: find another way to create required implementation.
72 if (dynamic_cast<conv_pd_t *>(conv_pd_.get()))
73 return set_default_params();
74 }
75
76 return status::unimplemented;
77 };
78
79 status_t init(engine_t *engine) {
80 using namespace data_type;
81 using skip_mask_t = primitive_attr_t::skip_mask_t;
82 bool ok = is_fwd()
83 && desc()->alg_kind == alg_kind::deconvolution_direct
84 && !has_zero_dim_memory()
85 && utils::one_of(src_md(0)->data_type, s8, u8)
86 && weights_md(0)->data_type == s8
87 && IMPLICATION(with_bias(),
88 utils::one_of(
89 weights_md(1)->data_type, f32, s32, s8, u8))
90 && utils::one_of(dst_md(0)->data_type, f32, s32, s8, u8)
91 && desc()->accum_data_type == s32
92 && attr()->has_default_values(skip_mask_t::scales_runtime
93 | skip_mask_t::post_ops
94 | skip_mask_t::zero_points_runtime)
95 && zero_points_valid(
96 attr(), true /*per_oc_bcast_accepted*/);
97
98 if (!ok) return status::unimplemented;
99
100 CHECK(init_convolution(engine));
101 CHECK(attr_.set_default_formats(dst_md(0)));
102 init_scratchpad();
103
104 return status::success;
105 }
106
107 protected:
108 status_t set_default_params() {
109 auto conv_1x1_pd_ = static_cast<conv_pd_t *>(conv_pd_.get());
110 src_md_ = *conv_1x1_pd_->src_md();
111 dst_md_ = *conv_1x1_pd_->dst_md();
112 weights_md_ = *conv_1x1_pd_->weights_md();
113 if (with_bias()) bias_md_ = *conv_1x1_pd_->weights_md(1);
114 return status::success;
115 }
116
117 using conv_pd_t =
118 typename jit_avx512_core_x8s8s32x_1x1_convolution_fwd_t::pd_t;
119 friend jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t;
120
121 std::shared_ptr<primitive_desc_t> conv_pd_;
122
123 private:
124 void init_scratchpad() {
125 auto scratchpad = scratchpad_registry().registrar();
126 scratchpad.book(memory_tracking::names::key_nested,
127 conv_pd_->scratchpad_registry());
128 }
129 };
130
131 jit_avx512_core_x8s8s32x_1x1_deconvolution_fwd_t(const pd_t *apd)
132 : primitive_t(apd) {}
133
134 status_t init(engine_t *engine) override {
135 pd()->conv_pd_->create_primitive(conv_p_, engine);
136 return status::success;
137 }
138
139 status_t execute(const exec_ctx_t &ctx) const override {
140 nested_scratchpad_t ns(
141 ctx, memory_tracking::names::key_nested, conv_p_);
142 // XXX: create a new ctx for convolution?
143 auto &tmp_ctx = const_cast<exec_ctx_t &>(ctx);
144 tmp_ctx.set_scratchpad_grantor(ns.grantor());
145 return conv_p_->execute(tmp_ctx);
146 }
147
148private:
149 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
150 std::shared_ptr<primitive_t> conv_p_;
151};
152
153} // namespace x64
154} // namespace cpu
155} // namespace impl
156} // namespace dnnl
157
158#endif /* CPU_X64_JIT_AVX512_CORE_X8S8S32X_1X1_DECONVOLUTION_HPP */
159