/*******************************************************************************
* Copyright 2020-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

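// benchdnn reference implementation of the reduction primitive: init_acc,
// accumulate, and finalize model one reduction as init / fold / post-process
// steps, and compute_ref drives them over the whole tensor.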
#include <limits>
#include <math.h>

#include "utils/parallel.hpp"

#include "common.hpp"
#include "dnnl_memory.hpp"

#include "reduction.hpp"

namespace reduction {

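// Sets the accumulator to the identity element of the given reduction
// algorithm: -FLT_MAX for max, FLT_MAX for min, 0 for the additive
// algorithms, and 1 for multiplication.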
void init_acc(float &acc, alg_t alg) {
    switch (alg) {
        case max: acc = std::numeric_limits<float>::lowest(); break;
        case min: acc = std::numeric_limits<float>::max(); break;
        case sum: acc = 0.0f; break;
        case mul: acc = 1.0f; break;
        case mean:
        case norm_lp_max:
        case norm_lp_sum:
        case norm_lp_power_p_max:
        case norm_lp_power_p_sum: acc = 0.0f; break;
        default: assert(!"unknown algorithm");
    }
}

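// Folds one source element into the accumulator. The Lp-norm family shares a
// single accumulation step, sum(|x|^p); those algorithms only differ in how
// the result is finalized below.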
void accumulate(float &dst, const float src, alg_t alg, float p, float eps) {
    switch (alg) {
        case max: dst = MAX2(dst, src); break;
        case min: dst = MIN2(dst, src); break;
        case mean:
        case sum: dst += src; break;
        case mul: dst *= src; break;
        case norm_lp_max:
        case norm_lp_sum:
        case norm_lp_power_p_max:
        case norm_lp_power_p_sum: dst += pow(fabs(src), p); break;
        default: assert(!"unknown algorithm");
    }
}

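// Applies the final transformation to the accumulated value: mean divides by
// the number of reduced elements; the norm_lp variants guard the accumulator
// with eps (via max or sum, as encoded in the algorithm name) and, for the
// non-power_p variants, take the 1/p-th root. The remaining algorithms need
// no finalization.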
void finalize(float &dst, alg_t alg, float p, float eps, dnnl_dim_t n) {
    switch (alg) {
        case mean: dst /= n; break;
        case norm_lp_max:
            dst = MAX2(dst, eps);
            dst = pow(dst, 1.0f / p);
            break;
        case norm_lp_sum:
            dst += eps;
            dst = pow(dst, 1.0f / p);
            break;
        case norm_lp_power_p_max: dst = MAX2(dst, eps); break;
        case norm_lp_power_p_sum: dst += eps; break;
        default: break;
    }
}

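// Reference implementation: for every output point, reduces the
// corresponding subspace of the source tensor and applies the attribute
// post-ops.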
void compute_ref(
        const prb_t *prb, const args_t &args, dnnl_primitive_t prim_ref) {
    const dnn_mem_t &src = args.find(DNNL_ARG_SRC);
    const dnn_mem_t &dst = args.find(DNNL_ARG_DST);

    float *dst_ptr = (float *)dst;

    const auto &ndims = prb->ndims;
    const auto &src_dims = prb->vdims[0];
    const auto &dst_dims = prb->vdims[1];

    const auto alg = prb->alg;
    const auto p = prb->p;
    const auto eps = prb->eps;

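    // Split the logical dimensions into reduced ones (src and dst sizes
    // differ) and idle ones that are kept. reduce_dims holds the shape of
    // the reduction subspace, with idle dimensions collapsed to 1.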
    dims_t reduce_dims(ndims, 1);
    int64_t reduce_size {1}, idle_size {1};

    for (int d = 0; d < ndims; ++d) {
        const bool is_reduction_dim = src_dims[d] != dst_dims[d];
        if (is_reduction_dim) {
            reduce_dims[d] = src_dims[d];
            reduce_size *= reduce_dims[d];
        } else {
            idle_size *= dst_dims[d];
        }
    }

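    // Degenerate case: no dimension is reduced, so the reference bails out.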
    if (reduce_size == 1) return;

    auto v_po_masks = prb->attr.post_ops.get_po_masks();
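    // Each output element is independent, so the idle space is traversed in
    // parallel; 'f' is the flat index of one destination point.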
    benchdnn_parallel_nd(idle_size, [&](int64_t f) {
        dims_t idle_pos = off2dims_idx(dst_dims, f);
        const int64_t dst_off = md_off_v(dst, idle_pos.data());
        const int64_t src_idle_off = md_off_v(src, idle_pos.data());
        float acc {0.0f};
        init_acc(acc, alg);
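        // Accumulate every point of the reduction subspace. Since idle_pos
        // is zero along reduced dimensions and reduce_pos is zero along idle
        // ones, the two memory offsets simply add up.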
        for (int64_t r = 0; r < reduce_size; ++r) {
            dims_t reduce_pos = off2dims_idx(reduce_dims, r);
            const int64_t src_reduce_off = md_off_v(src, reduce_pos.data());
            const int64_t src_off = src_idle_off + src_reduce_off;
            accumulate(acc, src.get_elem(src_off), alg, p, eps);
        }
        finalize(acc, alg, p, eps, reduce_size);

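        // Apply attribute post-ops on top of the finalized value before
        // storing it to the destination.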
        const auto v_po_vals = prepare_po_vals(dst, args, v_po_masks, dst_off);

        maybe_post_ops(prb->attr, acc, dst_ptr[dst_off], v_po_vals);
        dst_ptr[dst_off] = acc;
    });
}

} // namespace reduction