1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15* *******************************************************************************/
16
17#ifndef CPU_X64_JIT_UNI_X8S8S32X_1X1_DECONVOLUTION_HPP
18#define CPU_X64_JIT_UNI_X8S8S32X_1X1_DECONVOLUTION_HPP
19
20#include "common/c_types_map.hpp"
21#include "common/dnnl_thread.hpp"
22#include "common/primitive.hpp"
23#include "common/primitive_desc_iterator.hpp"
24#include "common/type_helpers.hpp"
25#include "common/utils.hpp"
26
27#include "cpu/cpu_convolution_pd.hpp"
28#include "cpu/cpu_deconvolution_pd.hpp"
29
30#include "cpu/x64/jit_uni_1x1_conv_utils.hpp"
31#include "cpu/x64/jit_uni_x8s8s32x_1x1_convolution.hpp"
32#include "cpu/zero_point_utils.hpp"
33
34namespace dnnl {
35namespace impl {
36namespace cpu {
37namespace x64 {
38
39template <cpu_isa_t isa>
40struct jit_uni_x8s8s32x_1x1_deconvolution_fwd_t : public primitive_t {
41 struct pd_t : public cpu_deconvolution_fwd_pd_t {
42 pd_t(const deconvolution_desc_t *adesc, const primitive_attr_t *attr,
43 const deconvolution_fwd_pd_t *hint_fwd_pd)
44 : cpu_deconvolution_fwd_pd_t(adesc, attr, hint_fwd_pd) {}
45
46 pd_t(const pd_t &other)
47 : cpu_deconvolution_fwd_pd_t(other)
48 , conv_pd_(other.conv_pd_->clone()) {}
49
50 ~pd_t() = default;
51
52 DECLARE_COMMON_PD_T(
53 conv_pd_->name(), jit_uni_x8s8s32x_1x1_deconvolution_fwd_t);
54
55 status_t init_convolution(engine_t *engine) {
56 convolution_desc_t cd;
57
58 auto dd = desc();
59 CHECK(conv_desc_init(&cd, prop_kind::forward_training,
60 alg_kind::convolution_direct, &(dd->src_desc),
61 &(dd->weights_desc), &(dd->bias_desc), &(dd->dst_desc),
62 dd->strides, dd->dilates, dd->padding[0], dd->padding[1]));
63
64 primitive_attr_t conv_attr(*attr());
65 if (!conv_attr.is_initialized()) return status::out_of_memory;
66 primitive_desc_iterator_t it(
67 engine, (op_desc_t *)&cd, &conv_attr, nullptr);
68 if (!it.is_initialized()) return status::out_of_memory;
69
70 while (++it != it.end()) {
71 conv_pd_ = *it;
72 // XXX: find another way to create required implementation.
73 if (dynamic_cast<conv_pd_t *>(conv_pd_.get()))
74 return set_default_params();
75 }
76
77 return status::unimplemented;
78 };
79
80 status_t init(engine_t *engine) {
81 using namespace data_type;
82 using skip_mask_t = primitive_attr_t::skip_mask_t;
83 bool ok = is_fwd()
84 && desc()->alg_kind == alg_kind::deconvolution_direct
85 && !has_zero_dim_memory()
86 && utils::one_of(src_md(0)->data_type, s8, u8)
87 && weights_md(0)->data_type == s8
88 && IMPLICATION(with_bias(),
89 utils::one_of(
90 weights_md(1)->data_type, f32, s32, s8, u8))
91 && utils::one_of(dst_md(0)->data_type, f32, s32, s8, u8)
92 && desc()->accum_data_type == s32
93 && attr()->has_default_values(skip_mask_t::scales_runtime
94 | skip_mask_t::post_ops
95 | skip_mask_t::zero_points_runtime)
96 && zero_points_valid(
97 attr(), true /*per_oc_bcast_accepted*/);
98 if (!ok) return status::unimplemented;
99
100 CHECK(init_convolution(engine));
101 CHECK(attr_.set_default_formats(dst_md(0)));
102 init_scratchpad();
103
104 return status::success;
105 }
106
107 protected:
108 status_t set_default_params() {
109 auto conv_1x1_pd_ = static_cast<conv_pd_t *>(conv_pd_.get());
110 src_md_ = *conv_1x1_pd_->src_md();
111 dst_md_ = *conv_1x1_pd_->dst_md();
112 weights_md_ = *conv_1x1_pd_->weights_md();
113 if (with_bias()) bias_md_ = *conv_1x1_pd_->weights_md(1);
114 return status::success;
115 }
116
117 using conv_pd_t =
118 typename jit_uni_x8s8s32x_1x1_convolution_fwd_t<isa>::pd_t;
119 friend jit_uni_x8s8s32x_1x1_deconvolution_fwd_t;
120
121 std::shared_ptr<primitive_desc_t> conv_pd_;
122
123 private:
124 void init_scratchpad() {
125 auto scratchpad = scratchpad_registry().registrar();
126 scratchpad.book(memory_tracking::names::key_nested,
127 conv_pd_->scratchpad_registry());
128 }
129 };
130
131 jit_uni_x8s8s32x_1x1_deconvolution_fwd_t(const pd_t *apd)
132 : primitive_t(apd) {}
133
134 status_t init(engine_t *engine) override {
135 pd()->conv_pd_->create_primitive(conv_p_, engine);
136 return status::success;
137 }
138
139 status_t execute(const exec_ctx_t &ctx) const override {
140 nested_scratchpad_t ns(
141 ctx, memory_tracking::names::key_nested, conv_p_);
142 // XXX: create a new ctx for convolution?
143 auto &tmp_ctx = const_cast<exec_ctx_t &>(ctx);
144 tmp_ctx.set_scratchpad_grantor(ns.grantor());
145 return conv_p_->execute(tmp_ctx);
146 }
147
148private:
149 const pd_t *pd() const { return (const pd_t *)primitive_t::pd().get(); }
150 std::shared_ptr<primitive_t> conv_p_;
151};
152
153} // namespace x64
154} // namespace cpu
155} // namespace impl
156} // namespace dnnl
157
158#endif /* CPU_X64_JIT_UNI_X8S8S32X_1X1_DECONVOLUTION_HPP */
159