1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef COMMON_REDUCTION_PD_HPP
18#define COMMON_REDUCTION_PD_HPP
19
20#include "c_types_map.hpp"
21#include "primitive_desc.hpp"
22#include "utils.hpp"
23
24namespace dnnl {
25namespace impl {
26
27status_t reduction_desc_init(reduction_desc_t *reduction_desc,
28 alg_kind_t alg_kind, const memory_desc_t *src_desc,
29 const memory_desc_t *dst_desc, float p, float eps);
30
31struct reduction_pd_t : public primitive_desc_t {
32 static constexpr auto base_pkind = primitive_kind::reduction;
33
34 typedef reduction_pd_t hint_class;
35
36 const reduction_desc_t *desc() const { return &desc_; }
37 const op_desc_t *op_desc() const override {
38 return reinterpret_cast<const op_desc_t *>(this->desc());
39 }
40
41 status_t query(query_t what, int idx, void *result) const override {
42 switch (what) {
43 case query::alg_kind:
44 *(alg_kind_t *)result = desc()->alg_kind;
45 break;
46 case query::p_f32: *(float *)result = desc()->p; break;
47 case query::epsilon_f32: *(float *)result = desc()->eps; break;
48 default: return primitive_desc_t::query(what, idx, result);
49 }
50 return status::success;
51 }
52
53 arg_usage_t arg_usage(int arg) const override {
54 switch (arg) {
55 case DNNL_ARG_SRC: return arg_usage_t::input; break;
56 case DNNL_ARG_DST: return arg_usage_t::output; break;
57 default: return primitive_desc_t::arg_usage(arg);
58 }
59 }
60
61 const memory_desc_t *arg_md(int arg) const override {
62 switch (arg) {
63 case DNNL_ARG_SRC: return src_md(0); break;
64 case DNNL_ARG_DST: return dst_md(0); break;
65 default: return primitive_desc_t::arg_md(arg);
66 }
67 }
68
69 const memory_desc_t *src_md(int index = 0) const override {
70 return index == 0 ? &src_md_ : &glob_zero_md;
71 }
72 const memory_desc_t *dst_md(int index = 0) const override {
73 return index == 0 ? &dst_md_ : &glob_zero_md;
74 }
75
76 int n_inputs() const override { return 1 + n_binary_po_inputs(); }
77 int n_outputs() const override { return 1; }
78
79 static void memory_desc_reduce_dim(memory_desc_t &md, int dim) {
80 if (md.format_kind != format_kind::blocked) return;
81
82 // Update reduced dim
83 md.dims[dim] = 1;
84
85 dims_t blocks = {0};
86 memory_desc_wrapper(md).compute_blocks(blocks);
87
88 // Reduced dim should be padded in case of inner blocks to preserve
89 // layout
90 md.padded_dims[dim] = blocks[dim];
91
92 // Update strides of dimensions which depend on reduced dim
93 int perm[DNNL_MAX_NDIMS];
94 for (int i = 0; i < md.ndims; ++i)
95 perm[i] = i;
96
97 auto &blk_d = md.format_desc.blocking;
98
99 dims_t strides;
100 utils::array_copy(strides, blk_d.strides, md.ndims);
101
102 // compute ou_dims. It is required to get correct perm
103 dims_t ou_dims;
104 for (int i = 0; i < md.ndims; ++i)
105 ou_dims[i] = md.padded_dims[i] / blocks[i];
106
107 utils::simultaneous_sort(strides, ou_dims, perm, md.ndims,
108 [](stride_t a, stride_t b) { return a - b; });
109
110 auto stride = md.padded_dims[dim] / blocks[dim] * blk_d.strides[dim];
111 for (int _d = 0; _d < md.ndims; ++_d) {
112 const auto d = perm[_d];
113 if (strides[_d] > blk_d.strides[dim]) {
114 blk_d.strides[d] = stride;
115 stride *= md.padded_dims[d] / blocks[d];
116 }
117 }
118 }
119
120protected:
121 reduction_desc_t desc_;
122
123 memory_desc_t src_md_;
124 memory_desc_t dst_md_;
125
126 reduction_pd_t(const reduction_desc_t *adesc, const primitive_attr_t *attr,
127 const hint_class *hint_fwd)
128 : primitive_desc_t(attr, base_pkind)
129 , desc_(*adesc)
130 , src_md_(desc_.src_desc)
131 , dst_md_(desc_.dst_desc) {}
132
133 status_t set_default_params() {
134 if (dst_md_.format_kind != format_kind::any) return status::success;
135
136 return set_dst_format();
137 }
138
139 status_t set_dst_format() {
140 memory_desc_t new_dst_md = src_md_;
141 new_dst_md.data_type = dst_md_.data_type;
142 for (int d = 0; d < src_md_.ndims; d++)
143 if (src_md_.dims[d] != dst_md_.dims[d])
144 memory_desc_reduce_dim(new_dst_md, d);
145 dst_md_ = new_dst_md;
146
147 return status::success;
148 }
149};
150
151} // namespace impl
152} // namespace dnnl
153
154#endif
155