1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef UTILS_COMPARE_HPP
18#define UTILS_COMPARE_HPP
19
#include <functional>
#include <utility>

#include "dnn_types.hpp"
#include "dnnl_memory.hpp"
24
25namespace compare {
26
// Declaration only; the definition lives outside this header.
// NOTE(review): judging by the name, this presumably decides whether `a` and
// `b` should be treated as matching when extreme values (e.g. infinities or
// NaNs) are involved — confirm against the out-of-line definition.
auto compare_extreme_values(float a, float b) -> bool;
28
29struct compare_t {
30 struct driver_check_func_args_t {
31 driver_check_func_args_t(const dnn_mem_t &exp_mem,
32 const dnn_mem_t &got_f32, const int64_t i,
33 const dnnl_data_type_t data_type, const float trh);
34
35 const dnnl_data_type_t dt = dnnl_data_type_undef;
36 const int64_t idx = 0;
37 const float exp_f32 = 0.f;
38 const float exp = 0.f;
39 const float got = 0.f;
40 const float diff = 0.f;
41 const float rel_diff = 0.f;
42 const float trh = 0.f;
43 };
44
45 compare_t() = default;
46
47 void set_norm_validation_mode(bool un) { use_norm_ = un; }
48 void set_threshold(float trh) { trh_ = trh; }
49 void set_zero_trust_percent(float ztp) { zero_trust_percent_ = ztp; }
50 void set_data_kind(data_kind_t dk) { kind_ = dk; }
51 void set_op_output_has_nans(bool ohn) { op_output_has_nans_ = ohn; }
52
53 // @param idx The index of compared element. Helps to obtain any element
54 // from any reference memory since it's in abx format.
55 // @param got The value of library memory for index `idx`. Can't be obtained
56 // by `idx` directly since could have different memory formats.
57 // @param diff The absolute difference between expected and got values.
58 // @returns true if checks pass and false otherwise.
59 using driver_check_func_t
60 = std::function<bool(const driver_check_func_args_t &)>;
61 void set_driver_check_function(const driver_check_func_t &dcf) {
62 driver_check_func_ = dcf;
63 }
64
65 int compare(const dnn_mem_t &exp_mem, const dnn_mem_t &got_mem,
66 const attr_t &attr, res_t *res) const;
67
68private:
69 // Switch between point-to-point and norm comparison.
70 bool use_norm_ = false;
71 // Threshold for a point-to-point comparison.
72 float trh_ = 0.f;
73 // The default percent value of zeros allowed in the output.
74 float zero_trust_percent_ = 30.f;
75 // Kind specifies what tensor is checked. Not printed if default one.
76 data_kind_t kind_ = DAT_TOTAL;
77 // Driver-specific function that adds additional criteria for a test case to
78 // pass.
79 driver_check_func_t driver_check_func_;
80 // Some operators may legally return NaNs. This member allows to work around
81 // issues involving comparison with NaNs in the op output if additional
82 // post-ops are involved.
83 bool op_output_has_nans_ = false;
84
85 // Internal validation methods under `compare` interface.
86 int compare_p2p(const dnn_mem_t &exp_mem, const dnn_mem_t &got_mem,
87 const attr_t &attr, res_t *res) const;
88 int compare_norm(const dnn_mem_t &exp_mem, const dnn_mem_t &got_mem,
89 const attr_t &attr, res_t *res) const;
90};
91
92} // namespace compare
93
94#endif
95