1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "tests/gtests/dnnl_test_common.hpp"
18#include "gtest/gtest.h"
19
20#include "dnnl.hpp"
21
22#include "common/primitive_attr.hpp"
23#include "common/type_helpers.hpp"
24
25namespace dnnl {
26
27namespace {
28bool compare(const dnnl::primitive_attr &lhs, const dnnl::primitive_attr &rhs) {
29 return *lhs.get() == *rhs.get();
30}
31
32bool self_compare(const dnnl::primitive_attr &attr) {
33 return *attr.get() == *attr.get();
34}
35
36template <typename T>
37bool self_compare(const T &desc) {
38 return dnnl::impl::operator==(desc, desc);
39}
40
41} // namespace
42
// Fatally asserts that `v` compares equal to itself via self_compare(); on
// failure the message names the expression that broke reflexivity.
#define TEST_SELF_COMPARISON(v) \
    ASSERT_TRUE(self_compare(v)) << "self-comparison failed for: " #v
44
45class comparison_operators_t : public ::testing::Test {};
46
47TEST(comparison_operators_t, TestAttrScales) {
48 dnnl::primitive_attr default_attr;
49
50 const std::vector<int> supported_args
51 = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST};
52 for (auto arg : supported_args) {
53 dnnl::primitive_attr attr;
54 attr.set_scales_mask(arg, 0);
55 ASSERT_EQ(compare(default_attr, attr), false);
56 }
57}
58
59TEST(comparison_operators_t, TestAttrZeroPoints) {
60 dnnl::primitive_attr default_attr;
61
62 const std::vector<int> supported_args
63 = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST};
64 for (auto arg : supported_args) {
65 dnnl::primitive_attr attr;
66 attr.set_zero_points_mask(arg, 0);
67 ASSERT_EQ(compare(default_attr, attr), false);
68 }
69}
70
71TEST(comparison_operators_t, TestAttrDataQparams) {
72 dnnl::primitive_attr attr;
73
74 attr.set_rnn_data_qparams(1.5f, NAN);
75 TEST_SELF_COMPARISON(attr);
76}
77
78HANDLE_EXCEPTIONS_FOR_TEST(comparison_operators_t, TestAttrWeightsQparams) {
79 dnnl::primitive_attr attr;
80
81 attr.set_rnn_weights_qparams(0, {NAN});
82 TEST_SELF_COMPARISON(attr);
83
84 attr.set_rnn_weights_qparams(1 << 1, {1.5f, NAN, 3.5f});
85 TEST_SELF_COMPARISON(attr);
86}
87
88HANDLE_EXCEPTIONS_FOR_TEST(
89 comparison_operators_t, TestAttrWeightsProjectionQparams) {
90 dnnl::primitive_attr attr;
91
92 attr.set_rnn_weights_projection_qparams(0, {NAN});
93 TEST_SELF_COMPARISON(attr);
94
95 attr.set_rnn_weights_projection_qparams(1 << 1, {1.5f, NAN, 3.5f});
96 TEST_SELF_COMPARISON(attr);
97}
98
99TEST(comparison_operators_t, TestSumPostOp) {
100 dnnl::primitive_attr attr;
101 dnnl::post_ops ops;
102
103 ops.append_sum(NAN);
104 attr.set_post_ops(ops);
105 TEST_SELF_COMPARISON(attr);
106}
107
108HANDLE_EXCEPTIONS_FOR_TEST(comparison_operators_t, TestDepthwisePostOp) {
109 dnnl::primitive_attr attr;
110 dnnl::post_ops ops;
111
112 ops.append_dw(memory::data_type::s8, memory::data_type::f32,
113 memory::data_type::u8, 3, 1, 1);
114 attr.set_post_ops(ops);
115 TEST_SELF_COMPARISON(attr);
116}
117
118TEST(comparison_operators_t, TestBatchNormDesc) {
119 auto bnorm_desc = dnnl::impl::batch_normalization_desc_t();
120 bnorm_desc.batch_norm_epsilon = NAN;
121 TEST_SELF_COMPARISON(bnorm_desc);
122}
123
124TEST(comparison_operators_t, TestEltwiseDesc) {
125 auto eltwise_desc = dnnl::impl::eltwise_desc_t();
126 eltwise_desc.alpha = NAN;
127 TEST_SELF_COMPARISON(eltwise_desc);
128}
129
130TEST(comparison_operators_t, TestLayerNormDesc) {
131 auto lnorm_desc = dnnl::impl::layer_normalization_desc_t();
132 lnorm_desc.layer_norm_epsilon = NAN;
133 TEST_SELF_COMPARISON(lnorm_desc);
134}
135
136TEST(comparison_operators_t, TestLRNDesc) {
137 auto lrn_desc = dnnl::impl::lrn_desc_t();
138 lrn_desc.lrn_alpha = NAN;
139 TEST_SELF_COMPARISON(lrn_desc);
140}
141
142TEST(comparison_operators_t, TestReductionDesc) {
143 auto reduction_desc = dnnl::impl::reduction_desc_t();
144 reduction_desc.p = NAN;
145 TEST_SELF_COMPARISON(reduction_desc);
146}
147
148TEST(comparison_operators_t, TestResamplingDesc) {
149 auto resampling_desc = dnnl::impl::resampling_desc_t();
150 resampling_desc.factors[0] = NAN;
151 TEST_SELF_COMPARISON(resampling_desc);
152}
153
154TEST(comparison_operators_t, TestRNNDesc) {
155 auto rnn_desc = dnnl::impl::rnn_desc_t();
156 rnn_desc.alpha = NAN;
157 TEST_SELF_COMPARISON(rnn_desc);
158}
159
160TEST(comparison_operators_t, TestSumDesc) {
161 float scales[2] = {NAN, 2.5f};
162 dnnl::impl::memory_desc_t md {};
163 dnnl_memory_desc_t mds[2] = {&md, &md};
164
165 dnnl::impl::sum_desc_t sum_desc(
166 dnnl::impl::primitive_kind::sum, &md, 2, scales, mds);
167 TEST_SELF_COMPARISON(sum_desc);
168}
169
170} // namespace dnnl
171