1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "tests/gtests/dnnl_test_common.hpp" |
18 | #include "gtest/gtest.h" |
19 | |
20 | #include "dnnl.hpp" |
21 | |
22 | #include "common/primitive_attr.hpp" |
23 | #include "common/type_helpers.hpp" |
24 | |
25 | namespace dnnl { |
26 | |
27 | namespace { |
28 | bool compare(const dnnl::primitive_attr &lhs, const dnnl::primitive_attr &rhs) { |
29 | return *lhs.get() == *rhs.get(); |
30 | } |
31 | |
32 | bool self_compare(const dnnl::primitive_attr &attr) { |
33 | return *attr.get() == *attr.get(); |
34 | } |
35 | |
36 | template <typename T> |
37 | bool self_compare(const T &desc) { |
38 | return dnnl::impl::operator==(desc, desc); |
39 | } |
40 | |
41 | } // namespace |
42 | |
// Fatally asserts that `v` compares equal to itself via self_compare().
// ASSERT_TRUE prints the evaluated expression on failure. The streamed
// message adds the stringized argument, so the failing object is named.
#define TEST_SELF_COMPARISON(v) \
    ASSERT_TRUE(self_compare(v)) << "self-comparison failed for: " #v
44 | |
45 | class comparison_operators_t : public ::testing::Test {}; |
46 | |
47 | TEST(comparison_operators_t, TestAttrScales) { |
48 | dnnl::primitive_attr default_attr; |
49 | |
50 | const std::vector<int> supported_args |
51 | = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}; |
52 | for (auto arg : supported_args) { |
53 | dnnl::primitive_attr attr; |
54 | attr.set_scales_mask(arg, 0); |
55 | ASSERT_EQ(compare(default_attr, attr), false); |
56 | } |
57 | } |
58 | |
59 | TEST(comparison_operators_t, TestAttrZeroPoints) { |
60 | dnnl::primitive_attr default_attr; |
61 | |
62 | const std::vector<int> supported_args |
63 | = {DNNL_ARG_SRC, DNNL_ARG_WEIGHTS, DNNL_ARG_DST}; |
64 | for (auto arg : supported_args) { |
65 | dnnl::primitive_attr attr; |
66 | attr.set_zero_points_mask(arg, 0); |
67 | ASSERT_EQ(compare(default_attr, attr), false); |
68 | } |
69 | } |
70 | |
71 | TEST(comparison_operators_t, TestAttrDataQparams) { |
72 | dnnl::primitive_attr attr; |
73 | |
74 | attr.set_rnn_data_qparams(1.5f, NAN); |
75 | TEST_SELF_COMPARISON(attr); |
76 | } |
77 | |
78 | HANDLE_EXCEPTIONS_FOR_TEST(comparison_operators_t, TestAttrWeightsQparams) { |
79 | dnnl::primitive_attr attr; |
80 | |
81 | attr.set_rnn_weights_qparams(0, {NAN}); |
82 | TEST_SELF_COMPARISON(attr); |
83 | |
84 | attr.set_rnn_weights_qparams(1 << 1, {1.5f, NAN, 3.5f}); |
85 | TEST_SELF_COMPARISON(attr); |
86 | } |
87 | |
88 | HANDLE_EXCEPTIONS_FOR_TEST( |
89 | comparison_operators_t, TestAttrWeightsProjectionQparams) { |
90 | dnnl::primitive_attr attr; |
91 | |
92 | attr.set_rnn_weights_projection_qparams(0, {NAN}); |
93 | TEST_SELF_COMPARISON(attr); |
94 | |
95 | attr.set_rnn_weights_projection_qparams(1 << 1, {1.5f, NAN, 3.5f}); |
96 | TEST_SELF_COMPARISON(attr); |
97 | } |
98 | |
99 | TEST(comparison_operators_t, TestSumPostOp) { |
100 | dnnl::primitive_attr attr; |
101 | dnnl::post_ops ops; |
102 | |
103 | ops.append_sum(NAN); |
104 | attr.set_post_ops(ops); |
105 | TEST_SELF_COMPARISON(attr); |
106 | } |
107 | |
108 | HANDLE_EXCEPTIONS_FOR_TEST(comparison_operators_t, TestDepthwisePostOp) { |
109 | dnnl::primitive_attr attr; |
110 | dnnl::post_ops ops; |
111 | |
112 | ops.append_dw(memory::data_type::s8, memory::data_type::f32, |
113 | memory::data_type::u8, 3, 1, 1); |
114 | attr.set_post_ops(ops); |
115 | TEST_SELF_COMPARISON(attr); |
116 | } |
117 | |
118 | TEST(comparison_operators_t, TestBatchNormDesc) { |
119 | auto bnorm_desc = dnnl::impl::batch_normalization_desc_t(); |
120 | bnorm_desc.batch_norm_epsilon = NAN; |
121 | TEST_SELF_COMPARISON(bnorm_desc); |
122 | } |
123 | |
124 | TEST(comparison_operators_t, TestEltwiseDesc) { |
125 | auto eltwise_desc = dnnl::impl::eltwise_desc_t(); |
126 | eltwise_desc.alpha = NAN; |
127 | TEST_SELF_COMPARISON(eltwise_desc); |
128 | } |
129 | |
130 | TEST(comparison_operators_t, TestLayerNormDesc) { |
131 | auto lnorm_desc = dnnl::impl::layer_normalization_desc_t(); |
132 | lnorm_desc.layer_norm_epsilon = NAN; |
133 | TEST_SELF_COMPARISON(lnorm_desc); |
134 | } |
135 | |
136 | TEST(comparison_operators_t, TestLRNDesc) { |
137 | auto lrn_desc = dnnl::impl::lrn_desc_t(); |
138 | lrn_desc.lrn_alpha = NAN; |
139 | TEST_SELF_COMPARISON(lrn_desc); |
140 | } |
141 | |
142 | TEST(comparison_operators_t, TestReductionDesc) { |
143 | auto reduction_desc = dnnl::impl::reduction_desc_t(); |
144 | reduction_desc.p = NAN; |
145 | TEST_SELF_COMPARISON(reduction_desc); |
146 | } |
147 | |
148 | TEST(comparison_operators_t, TestResamplingDesc) { |
149 | auto resampling_desc = dnnl::impl::resampling_desc_t(); |
150 | resampling_desc.factors[0] = NAN; |
151 | TEST_SELF_COMPARISON(resampling_desc); |
152 | } |
153 | |
154 | TEST(comparison_operators_t, TestRNNDesc) { |
155 | auto rnn_desc = dnnl::impl::rnn_desc_t(); |
156 | rnn_desc.alpha = NAN; |
157 | TEST_SELF_COMPARISON(rnn_desc); |
158 | } |
159 | |
160 | TEST(comparison_operators_t, TestSumDesc) { |
161 | float scales[2] = {NAN, 2.5f}; |
162 | dnnl::impl::memory_desc_t md {}; |
163 | dnnl_memory_desc_t mds[2] = {&md, &md}; |
164 | |
165 | dnnl::impl::sum_desc_t sum_desc( |
166 | dnnl::impl::primitive_kind::sum, &md, 2, scales, mds); |
167 | TEST_SELF_COMPARISON(sum_desc); |
168 | } |
169 | |
170 | } // namespace dnnl |
171 | |