1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef UTILS_COMPARE_HPP |
18 | #define UTILS_COMPARE_HPP |
19 | |
#include <functional>
#include <utility>

#include "dnn_types.hpp"
#include "dnnl_memory.hpp"
24 | |
25 | namespace compare { |
26 | |
27 | bool compare_extreme_values(float a, float b); |
28 | |
29 | struct compare_t { |
30 | struct driver_check_func_args_t { |
31 | driver_check_func_args_t(const dnn_mem_t &exp_mem, |
32 | const dnn_mem_t &got_f32, const int64_t i, |
33 | const dnnl_data_type_t data_type, const float trh); |
34 | |
35 | const dnnl_data_type_t dt = dnnl_data_type_undef; |
36 | const int64_t idx = 0; |
37 | const float exp_f32 = 0.f; |
38 | const float exp = 0.f; |
39 | const float got = 0.f; |
40 | const float diff = 0.f; |
41 | const float rel_diff = 0.f; |
42 | const float trh = 0.f; |
43 | }; |
44 | |
45 | compare_t() = default; |
46 | |
47 | void set_norm_validation_mode(bool un) { use_norm_ = un; } |
48 | void set_threshold(float trh) { trh_ = trh; } |
49 | void set_zero_trust_percent(float ztp) { zero_trust_percent_ = ztp; } |
50 | void set_data_kind(data_kind_t dk) { kind_ = dk; } |
51 | void set_op_output_has_nans(bool ohn) { op_output_has_nans_ = ohn; } |
52 | |
53 | // @param idx The index of compared element. Helps to obtain any element |
54 | // from any reference memory since it's in abx format. |
55 | // @param got The value of library memory for index `idx`. Can't be obtained |
56 | // by `idx` directly since could have different memory formats. |
57 | // @param diff The absolute difference between expected and got values. |
58 | // @returns true if checks pass and false otherwise. |
59 | using driver_check_func_t |
60 | = std::function<bool(const driver_check_func_args_t &)>; |
61 | void set_driver_check_function(const driver_check_func_t &dcf) { |
62 | driver_check_func_ = dcf; |
63 | } |
64 | |
65 | int compare(const dnn_mem_t &exp_mem, const dnn_mem_t &got_mem, |
66 | const attr_t &attr, res_t *res) const; |
67 | |
68 | private: |
69 | // Switch between point-to-point and norm comparison. |
70 | bool use_norm_ = false; |
71 | // Threshold for a point-to-point comparison. |
72 | float trh_ = 0.f; |
73 | // The default percent value of zeros allowed in the output. |
74 | float zero_trust_percent_ = 30.f; |
75 | // Kind specifies what tensor is checked. Not printed if default one. |
76 | data_kind_t kind_ = DAT_TOTAL; |
77 | // Driver-specific function that adds additional criteria for a test case to |
78 | // pass. |
79 | driver_check_func_t driver_check_func_; |
80 | // Some operators may legally return NaNs. This member allows to work around |
81 | // issues involving comparison with NaNs in the op output if additional |
82 | // post-ops are involved. |
83 | bool op_output_has_nans_ = false; |
84 | |
85 | // Internal validation methods under `compare` interface. |
86 | int compare_p2p(const dnn_mem_t &exp_mem, const dnn_mem_t &got_mem, |
87 | const attr_t &attr, res_t *res) const; |
88 | int compare_norm(const dnn_mem_t &exp_mem, const dnn_mem_t &got_mem, |
89 | const attr_t &attr, res_t *res) const; |
90 | }; |
91 | |
92 | } // namespace compare |
93 | |
94 | #endif |
95 | |