/*******************************************************************************
* Copyright 2019-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#ifndef COMMON_SUM_PD_HPP
#define COMMON_SUM_PD_HPP

#include <assert.h>
#include "oneapi/dnnl/dnnl.h"

#include "c_types_map.hpp"
#include "primitive_desc.hpp"
#include "type_helpers.hpp"

#include "utils.hpp"

#include "primitive_hashing.hpp"

namespace dnnl {
namespace impl {

struct sum_pd_t : public primitive_desc_t {
    const sum_desc_t *desc() const { return &desc_; }
    const op_desc_t *op_desc() const override {
        return reinterpret_cast<const op_desc_t *>(this->desc());
    }

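    // Runtime arguments of a sum primitive are indexed as
    // DNNL_ARG_MULTIPLE_SRC + i for the i-th source, with i in
    // [0, n_inputs()), plus DNNL_ARG_DST for the destination; arg_usage()
    // and arg_md() below follow this convention.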
    arg_usage_t arg_usage(int arg) const override {
        if (arg >= DNNL_ARG_MULTIPLE_SRC
                && arg < DNNL_ARG_MULTIPLE_SRC + n_inputs())
            return arg_usage_t::input;

        if (arg == DNNL_ARG_DST) return arg_usage_t::output;

        return primitive_desc_t::arg_usage(arg);
    }

    const memory_desc_t *arg_md(int arg) const override {
        int src_index = arg - DNNL_ARG_MULTIPLE_SRC;
        if (src_index >= 0 && src_index < n_inputs()) return src_md(src_index);
        if (arg == DNNL_ARG_DST) return dst_md(0);
        return primitive_desc_t::arg_md(arg);
    }

    const memory_desc_t *src_md(int index = 0) const override {
        return index < n_inputs() ? &src_mds_[index] : &glob_zero_md;
    }
    const memory_desc_t *dst_md(int index = 0) const override {
        return index == 0 ? &dst_md_ : &glob_zero_md;
    }
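    // When the destination type is not f32 (see need_output_reorder()), the
    // result is accumulated into an intermediate f32 memory described by
    // dst_acc_md_ and then converted into dst_md_; otherwise accumulation
    // happens directly in the destination.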
    const memory_desc_t *dst_acc_md() const {
        return need_output_reorder() ? &dst_acc_md_ : &dst_md_;
    }

    int n_inputs() const override { return n_; }
    int n_outputs() const override { return 1; }

    const float *scales() const { return &scales_[0]; }

    bool need_output_reorder() const { return dst_md()->data_type != dnnl_f32; }

    bool has_zero_dim_memory() const {
        return memory_desc_wrapper(dst_md()).has_zero_dim();
    }

protected:
    int n_;
    std::vector<float> scales_;
    memory_desc_t dst_md_, dst_acc_md_;
    std::vector<memory_desc_t> src_mds_;
    memory_desc_t original_dst_md_;

    sum_desc_t desc_;

    sum_pd_t(const primitive_attr_t *attr, const memory_desc_t *dst_md, int n,
            const float *scales, const memory_desc_t *const *src_mds)
        : primitive_desc_t(attr, primitive_kind::sum)
        , n_(n)
        , dst_md_(*dst_md)
        , original_dst_md_(*dst_md) {
        scales_.reserve(n_);
        for (int i = 0; i < n_; ++i)
            scales_.push_back(scales[i]);
        src_mds_.reserve(n_);
        for (int i = 0; i < n_; ++i)
            src_mds_.push_back(*src_mds[i]);

        init_desc();
    }

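    // The copy constructor re-runs init_desc() so that the pointers stored in
    // desc_ (dst_md, scales, src_mds) refer to this object's own members
    // rather than to the members of `other`.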
    sum_pd_t(const sum_pd_t &other) : primitive_desc_t(other) {
        n_ = other.n_;
        scales_ = other.scales_;
        dst_md_ = other.dst_md_;
        dst_acc_md_ = other.dst_acc_md_;
        src_mds_ = other.src_mds_;
        original_dst_md_ = other.original_dst_md_;

        init_desc();
    }

    // backends could redefine the accumulation tensor if required
    virtual void define_dst_acc_md() {
        dst_acc_md_ = dst_md_;
        dst_acc_md_.data_type = dnnl_f32;
    }
    /* Initializes dst_md_ in simple cases; the call may fail. */
    status_t init(engine_t *engine) {
        for (int i = 0; i < n_; ++i) {
            const memory_desc_wrapper src_d(&src_mds_[i]);
            if (!src_d.is_blocking_desc() || src_d.is_additional_buffer())
                return status::unimplemented;
        }
        bool ok = true && set_default_params() == status::success
                && attr()->has_default_values();
        if (!ok) return status::unimplemented;

        // Use an f32 accumulator to handle float scales without accuracy loss.
        if (need_output_reorder()) { define_dst_acc_md(); }

        return status::success;
    }

    status_t set_default_params() {
        if (dst_md_.format_kind != format_kind::any) return status::success;

        /* A simple heuristic:
         * - pick the first non-plain (blocked) source format;
         * - if all source formats are plain, use the format of the first input.
         */
        for (int i = 0; i < n_; ++i) {
            const memory_desc_wrapper src_d(src_mds_[i]);
            if (!src_d.is_plain() && src_d.is_blocking_desc()) {
                return memory_desc_init_by_blocking_desc(
                        dst_md_, src_d.blocking_desc());
            }
        }

        if (src_mds_[0].format_kind != format_kind::blocked)
            return status::unimplemented;

        memory_desc_init_by_md_and_dt(dst_md_, src_mds_[0], dst_md_.data_type);

        return status::success;
    }

private:
    void init_desc() {
        desc_ = sum_desc_t();
        desc_.primitive_kind = primitive_kind::sum;
        desc_.dst_md = &original_dst_md_;
        desc_.n = n_;
        desc_.scales = scales_.data();
        for (const auto &md : src_mds_)
            desc_.src_mds.push_back(&md);
    }
};

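// DECLARE_SUM_PD_t injects the boilerplate a concrete implementation's pd_t
// needs: a create() factory that constructs and initializes the descriptor,
// plus create_primitive(), clone(), and name(). DECLARE_SUM_PD_T forwards to
// it through an extra expansion level so that macro arguments are expanded
// first; implementations use the upper-case variant.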
#define DECLARE_SUM_PD_t(impl_name, ...) \
    static status_t create(sum_pd_t **sum_pd, engine_t *engine, \
            const primitive_attr_t *attr, const memory_desc_t *dst_md, int n, \
            const float *scales, const memory_desc_t *const *src_mds) { \
        using namespace status; \
        auto _pd = new pd_t(attr, dst_md, n, scales, src_mds); \
        if (_pd == nullptr) return out_of_memory; \
        if (_pd->init(engine) != success) { \
            delete _pd; \
            return unimplemented; \
        } \
        _pd->init_scratchpad_md(); \
        return safe_ptr_assign(*sum_pd, _pd); \
    } \
    status_t create_primitive( \
            std::pair<std::shared_ptr<primitive_t>, bool> &primitive, \
            engine_t *engine, const cache_blob_t &cache_blob) const override { \
        return primitive_t::create_primitive_common<__VA_ARGS__, pd_t>( \
                primitive, this, engine, false, cache_blob); \
    } \
    pd_t *clone() const override { \
        auto new_pd = utils::make_unique<pd_t>(*this); \
        if (!new_pd->is_initialized()) return nullptr; \
        return new_pd.release(); \
    } \
    const char *name() const override { return impl_name; }

#define DECLARE_SUM_PD_T(impl_name, ...) \
    DECLARE_SUM_PD_t(impl_name, __VA_ARGS__)
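
// Illustrative sketch (not part of this header): a concrete sum implementation
// typically derives its pd_t from an engine-specific subclass of sum_pd_t and
// uses DECLARE_SUM_PD_T to stamp out the factory boilerplate. The class and
// implementation names below are placeholders, not real library symbols.
//
//     struct my_sum_t : public primitive_t {
//         struct pd_t : public sum_pd_t {
//             using sum_pd_t::sum_pd_t;
//
//             DECLARE_SUM_PD_T("my:any", my_sum_t);
//
//             status_t init(engine_t *engine) {
//                 // Backend-specific checks would go here before or after
//                 // delegating to the base-class initialization.
//                 return sum_pd_t::init(engine);
//             }
//         };
//         // ... primitive implementation ...
//     };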

} // namespace impl
} // namespace dnnl

#endif