1/*******************************************************************************
2* Copyright 2017-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef UTILS_NORM_HPP
18#define UTILS_NORM_HPP
19
#include <cmath>
#include <limits>

#include "common.hpp"
23
/* Running accumulator of L1, L2 and LINF norms over a stream of floats.
 *
 * Call update() once per element, then done() exactly once: done() takes the
 * square root of the L2 slot in place, so calling it again would apply the
 * square root twice. */
struct norm_t {
    /* strictly speaking L0 is not a norm... norm_t::update() never touches
     * it; diff_norm_t uses this slot for the biggest element-wise *relative*
     * difference |a - b| / |a| (falling back to the absolute difference when
     * |a| <= FLT_MIN) */
    enum { L0, L1, L2, LINF, L8 = LINF, L_LAST };

    norm_t() = default;

    /* Accumulates |v| into L1, v^2 into L2 and max |v| into LINF. */
    void update(float v) {
        const double abs_v = std::fabs(v);
        norm_[L1] += abs_v;
        // Square in double: a float `v * v` overflows to inf for
        // |v| > ~1.8e19 and is rounded before reaching the double sum.
        // The product of two floats is exact in double.
        norm_[L2] += abs_v * abs_v;
        // NaN-sticky max so LINF agrees with L1/L2 (which also become NaN);
        // a plain `a > b ? a : b` max loses a NaN once a finite value follows.
        double &linf = norm_[LINF];
        if (!std::isnan(linf) && !(abs_v <= linf)) linf = abs_v;
        num_++;
    }

    /* Turns the accumulated sum of squares into the Euclidean norm. */
    void done() { norm_[L2] = std::sqrt(norm_[L2]); }

    /* Returns the norm selected by `type` (one of the enum values above),
     * narrowed to float. */
    float operator[](int type) const { return static_cast<float>(norm_[type]); }

    double norm_[L_LAST] = {}; // indexed by the enum above
    size_t num_ = 0; // number of update() calls
};
48
49struct diff_norm_t {
50 void update(float a, float b) {
51 float diff = a - b;
52 a_.update(a);
53 b_.update(b);
54 diff_.update(diff);
55 diff_.norm_[norm_t::L0] = MAX2(diff_.norm_[norm_t::L0],
56 ABS(diff) / (ABS(a) > FLT_MIN ? ABS(a) : 1.));
57 }
58 void done() {
59 a_.done();
60 b_.done();
61 diff_.done();
62 }
63
64 float rel_diff(int type) const {
65 if (type == norm_t::L0) return diff_.norm_[type];
66 if (a_.norm_[type] == 0)
67 return diff_.norm_[type] == 0
68 ? 0
69 : std::numeric_limits<float>::infinity();
70 assert(a_.norm_[type] != 0);
71 return diff_.norm_[type] / a_.norm_[type];
72 }
73
74 norm_t a_, b_, diff_;
75};
76
77#endif
78