1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "oneapi/dnnl/dnnl.h" |
18 | #include "opdesc.hpp" |
19 | #include "primitive_desc_iface.hpp" |
20 | |
21 | #include "c_types_map.hpp" |
22 | #include "utils.hpp" |
23 | |
24 | using namespace dnnl::impl; |
25 | using namespace dnnl::impl::status; |
26 | using namespace dnnl::impl::utils; |
27 | using namespace dnnl::impl::alg_kind; |
28 | |
29 | namespace dnnl { |
30 | namespace impl { |
31 | |
32 | status_t reduction_desc_init(reduction_desc_t *reduction_desc, |
33 | alg_kind_t alg_kind, const memory_desc_t *src_desc, |
34 | const memory_desc_t *dst_desc, float p, float eps) { |
35 | |
36 | bool args_ok = !any_null(src_desc, dst_desc) |
37 | && src_desc->format_kind != format_kind::any |
38 | && one_of(alg_kind, reduction_max, reduction_min, reduction_sum, |
39 | reduction_mul, reduction_mean, reduction_norm_lp_max, |
40 | reduction_norm_lp_sum, reduction_norm_lp_power_p_max, |
41 | reduction_norm_lp_power_p_sum) |
42 | && IMPLICATION(one_of(alg_kind, reduction_norm_lp_max, |
43 | reduction_norm_lp_sum, |
44 | reduction_norm_lp_power_p_max, |
45 | reduction_norm_lp_power_p_sum), |
46 | p >= 1.0f) |
47 | && IMPLICATION(one_of(alg_kind, reduction_norm_lp_max, |
48 | reduction_norm_lp_sum, |
49 | reduction_norm_lp_power_p_max, |
50 | reduction_norm_lp_power_p_sum), |
51 | one_of(src_desc->data_type, data_type::f32, data_type::bf16, |
52 | data_type::f16)); |
53 | if (!args_ok) return invalid_arguments; |
54 | |
55 | if (src_desc->ndims != dst_desc->ndims) return invalid_arguments; |
56 | |
57 | for (auto d = 0; d < src_desc->ndims; ++d) { |
58 | const auto dst_dim_d = dst_desc->dims[d]; |
59 | if (!one_of(dst_dim_d, 1, src_desc->dims[d])) return invalid_arguments; |
60 | } |
61 | |
62 | // reduction primitive doesn't support identity operation |
63 | if (array_cmp(src_desc->dims, dst_desc->dims, src_desc->ndims)) |
64 | return invalid_arguments; |
65 | |
66 | if (src_desc->format_kind != format_kind::blocked |
67 | || !one_of(dst_desc->format_kind, format_kind::blocked, |
68 | format_kind::any)) |
69 | return invalid_arguments; |
70 | |
71 | if (src_desc->extra.flags != 0 |
72 | || !IMPLICATION(dst_desc->format_kind == format_kind::blocked, |
73 | dst_desc->extra.flags == 0)) |
74 | return invalid_arguments; |
75 | |
76 | auto rd = reduction_desc_t(); |
77 | rd.primitive_kind = primitive_kind::reduction; |
78 | rd.alg_kind = alg_kind; |
79 | |
80 | rd.src_desc = *src_desc; |
81 | rd.dst_desc = *dst_desc; |
82 | rd.p = p; |
83 | rd.eps = eps; |
84 | |
85 | (*reduction_desc) = rd; |
86 | return success; |
87 | } |
88 | |
89 | } // namespace impl |
90 | } // namespace dnnl |
91 | |
92 | dnnl_status_t dnnl_reduction_primitive_desc_create( |
93 | primitive_desc_iface_t **primitive_desc_iface, engine_t *engine, |
94 | alg_kind_t alg_kind, const memory_desc_t *src_desc, |
95 | const memory_desc_t *dst_desc, float p, float eps, |
96 | const primitive_attr_t *attr) { |
97 | |
98 | auto reduction_desc = reduction_desc_t(); |
99 | CHECK(reduction_desc_init( |
100 | &reduction_desc, alg_kind, src_desc, dst_desc, p, eps)); |
101 | return primitive_desc_create(primitive_desc_iface, engine, |
102 | (const op_desc_t *)&reduction_desc, nullptr, attr); |
103 | } |
104 | |