1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "oneapi/dnnl/dnnl.h"
18#include "opdesc.hpp"
19#include "primitive_desc_iface.hpp"
20
21#include "c_types_map.hpp"
22#include "utils.hpp"
23
24using namespace dnnl::impl;
25using namespace dnnl::impl::status;
26using namespace dnnl::impl::utils;
27using namespace dnnl::impl::alg_kind;
28
29namespace dnnl {
30namespace impl {
31
32status_t reduction_desc_init(reduction_desc_t *reduction_desc,
33 alg_kind_t alg_kind, const memory_desc_t *src_desc,
34 const memory_desc_t *dst_desc, float p, float eps) {
35
36 bool args_ok = !any_null(src_desc, dst_desc)
37 && src_desc->format_kind != format_kind::any
38 && one_of(alg_kind, reduction_max, reduction_min, reduction_sum,
39 reduction_mul, reduction_mean, reduction_norm_lp_max,
40 reduction_norm_lp_sum, reduction_norm_lp_power_p_max,
41 reduction_norm_lp_power_p_sum)
42 && IMPLICATION(one_of(alg_kind, reduction_norm_lp_max,
43 reduction_norm_lp_sum,
44 reduction_norm_lp_power_p_max,
45 reduction_norm_lp_power_p_sum),
46 p >= 1.0f)
47 && IMPLICATION(one_of(alg_kind, reduction_norm_lp_max,
48 reduction_norm_lp_sum,
49 reduction_norm_lp_power_p_max,
50 reduction_norm_lp_power_p_sum),
51 one_of(src_desc->data_type, data_type::f32, data_type::bf16,
52 data_type::f16));
53 if (!args_ok) return invalid_arguments;
54
55 if (src_desc->ndims != dst_desc->ndims) return invalid_arguments;
56
57 for (auto d = 0; d < src_desc->ndims; ++d) {
58 const auto dst_dim_d = dst_desc->dims[d];
59 if (!one_of(dst_dim_d, 1, src_desc->dims[d])) return invalid_arguments;
60 }
61
62 // reduction primitive doesn't support identity operation
63 if (array_cmp(src_desc->dims, dst_desc->dims, src_desc->ndims))
64 return invalid_arguments;
65
66 if (src_desc->format_kind != format_kind::blocked
67 || !one_of(dst_desc->format_kind, format_kind::blocked,
68 format_kind::any))
69 return invalid_arguments;
70
71 if (src_desc->extra.flags != 0
72 || !IMPLICATION(dst_desc->format_kind == format_kind::blocked,
73 dst_desc->extra.flags == 0))
74 return invalid_arguments;
75
76 auto rd = reduction_desc_t();
77 rd.primitive_kind = primitive_kind::reduction;
78 rd.alg_kind = alg_kind;
79
80 rd.src_desc = *src_desc;
81 rd.dst_desc = *dst_desc;
82 rd.p = p;
83 rd.eps = eps;
84
85 (*reduction_desc) = rd;
86 return success;
87}
88
89} // namespace impl
90} // namespace dnnl
91
92dnnl_status_t dnnl_reduction_primitive_desc_create(
93 primitive_desc_iface_t **primitive_desc_iface, engine_t *engine,
94 alg_kind_t alg_kind, const memory_desc_t *src_desc,
95 const memory_desc_t *dst_desc, float p, float eps,
96 const primitive_attr_t *attr) {
97
98 auto reduction_desc = reduction_desc_t();
99 CHECK(reduction_desc_init(
100 &reduction_desc, alg_kind, src_desc, dst_desc, p, eps));
101 return primitive_desc_create(primitive_desc_iface, engine,
102 (const op_desc_t *)&reduction_desc, nullptr, attr);
103}
104