1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "dnnl_test_common.hpp" |
18 | #include "gtest/gtest.h" |
19 | |
20 | #include "oneapi/dnnl/dnnl.hpp" |
21 | |
22 | #include <string> |
23 | #include <vector> |
24 | |
25 | namespace dnnl { |
26 | |
27 | // short names for brevity |
28 | using data_type = memory::data_type; |
29 | using tag = memory::format_tag; |
30 | |
31 | class weights_format_test_t : public ::testing::Test { |
32 | protected: |
33 | const engine eng = get_test_engine(); |
34 | |
35 | struct inner_product_shape_t { |
36 | public: |
37 | inner_product_shape_t(memory::dim mb, memory::dim ic, memory::dim oc, |
38 | memory::dim kw = 1, memory::dim kh = 1, memory::dim kd = 1) |
39 | : mb_(mb), ic_(ic), oc_(oc), kw_(kw), kh_(kh), kd_(kd) {} |
40 | |
41 | memory::dims src_dims() const { return maybe_spatial(mb_, ic_); } |
42 | memory::dims wei_dims() const { return maybe_spatial(oc_, ic_); } |
43 | |
44 | memory::dims dst_dims() const { return {mb_, oc_}; } |
45 | memory::dims bia_dims() const { return {oc_}; } |
46 | |
47 | private: |
48 | memory::dim mb_, ic_, oc_, kw_, kh_, kd_; |
49 | bool is_1d() const { return kd_ == 1 && kh_ == 1 && kw_ != 1; } |
50 | bool is_2d() const { return kd_ == 1 && kh_ != 1; } |
51 | bool is_3d() const { return kd_ != 1; } |
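        // Note: a shape whose spatial sizes are all 1 is treated as having
        // no spatial dimensions.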
52 | |
53 | memory::dims maybe_spatial( |
54 | memory::dim param1, memory::dim param2) const { |
55 | if (is_3d()) |
56 | return {param1, param2, kd_, kh_, kw_}; |
57 | else if (is_2d()) |
58 | return {param1, param2, kh_, kw_}; |
59 | else if (is_1d()) |
60 | return {param1, param2, kw_}; |
61 | else |
62 | return {param1, param2}; |
63 | } |
64 | }; |
65 | |
    // Advance the primitive descriptor iterator until either:
    // - a brgemm kernel implementation is found, or
    // - the end of the primitive descriptor iterator is reached.
    // Returns `true` iff a brgemm kernel is found.
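    // On success, `pd` is left pointing at the brgemm implementation, so
    // memory descriptors queried from it reflect that kernel's formats.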
70 | template <typename PD> |
71 | bool seek_brgemm_impl(PD &pd) { |
        const std::string brgemm("brgemm");
74 | bool brgemm_ker_found = false, seek_next_impl = true; |
75 | do { |
76 | std::string impl_info(pd.impl_info_str()); |
77 | brgemm_ker_found = impl_info.find(brgemm) != std::string::npos; |
78 | |
79 | seek_next_impl = !brgemm_ker_found && pd.next_impl(); |
80 | } while (seek_next_impl); |
81 | |
82 | return brgemm_ker_found; |
83 | } |
84 | |
85 | std::vector<inner_product_shape_t> inner_product_shapes; |
86 | std::vector<data_type> inner_product_data_types; |
87 | |
88 | void SetUp() override { |
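        // Collect the data types supported by the test engine; bf16 is
        // skipped on platforms without bf16 support.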
89 | for (auto dt : {data_type::f32, data_type::bf16}) { |
90 | if (!unsupported_data_type(dt)) |
91 | inner_product_data_types.push_back(dt); |
92 | } |
93 | |
        // inner product shapes with no spatial dimensions [majority case]
95 | // dims format: {mb, ic, oc} |
96 | inner_product_shapes.insert(inner_product_shapes.end(), |
97 | {{2, 16, 16}, {2, 16, 32}, {2, 16, 64}, {2, 32, 16}, |
98 | {2, 32, 32}, {2, 32, 64}, {2, 64, 16}, {2, 64, 32}, |
99 | {2, 64, 64}, {2, 512, 16}, {2, 512, 32}, {2, 512, 64}, |
100 | {2, 512, 512}, {2, 512, 1024}, {2, 1024, 512}}); |
101 | |
        // inner product shapes with no spatial dimensions and channel tails
103 | for (auto sz : {1, 3, 15, 17, 31, 33, 63, 65, 127, 129}) |
104 | inner_product_shapes.emplace_back( |
105 | inner_product_shape_t {sz, sz, sz}); |
106 | |
        // inner product regression shapes with no spatial dimensions
108 | inner_product_shapes.emplace_back( |
109 | inner_product_shape_t {2, 1024, 30522}); |
110 | |
111 | // inner product shapes of higher dimensions |
112 | // dims format: either of {mb, ic, oc, kw}, {mb, ic, oc, kw, kh}, |
113 | // or {mb, ic, oc, kw, kh, kd} |
114 | inner_product_shapes.insert(inner_product_shapes.end(), |
115 | {{2, 16, 16, 2}, {2, 16, 32, 2, 3}, {2, 16, 64, 4, 3, 2}, |
116 | {2, 32, 16, 2}, {2, 32, 32, 2, 3}, {2, 32, 64, 4, 3, 2}, |
117 | {2, 64, 16, 2}, {2, 64, 32, 2, 3}, |
118 | {2, 64, 64, 4, 3, 2}}); |
119 | } |
120 | }; |
121 | |
// Check weights consistency in inner product, i.e., that the weights
// format is the same across the forward and backward passes.
// TODO: Enable similar tests for convolution once the brgemm kernel's
// support is complete
126 | HANDLE_EXCEPTIONS_FOR_TEST_F(weights_format_test_t, InnerProductWeightsCheck) { |
127 | const bool do_skip = !DNNL_X64 || (DNNL_CPU_RUNTIME == DNNL_RUNTIME_NONE) |
128 | || (get_test_engine_kind() != engine::kind::cpu); |
129 | SKIP_IF(do_skip, |
130 | "Inner Product weight check is applicable only for x64 CPU" ); |
131 | |
132 | for_(const auto &input_shape : inner_product_shapes) |
133 | for (const auto &input_dt : inner_product_data_types) { |
134 | // Note: For inner product with mixed data types, e.g. with bf16 src |
135 | // and f32 dst, we do not require weights consistency. |
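        // `tag::any` lets the implementation choose its preferred memory
        // format for each tensor.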
136 | memory::desc src_md {input_shape.src_dims(), input_dt, tag::any}; |
137 | memory::desc wei_md {input_shape.wei_dims(), input_dt, tag::any}; |
138 | memory::desc bia_md {input_shape.bia_dims(), input_dt, tag::any}; |
139 | memory::desc dst_md {input_shape.dst_dims(), input_dt, tag::any}; |
140 | |
141 | auto fwd_pd = inner_product_forward::primitive_desc(eng, |
142 | prop_kind::forward_training, src_md, wei_md, bia_md, dst_md); |
143 | |
144 | bool fwd_brgemm_ker_found = false, bwdd_brgemm_ker_found = false, |
145 | bwdw_brgemm_ker_found = false; |
        // Currently only the brgemm kernel supports the same weight tags
        // for forward and backward data/weights inner product, so skip
        // the check if the forward implementation is not brgemm
149 | ASSERT_NO_THROW(fwd_brgemm_ker_found = seek_brgemm_impl(fwd_pd)); |
150 | if (!fwd_brgemm_ker_found) continue; |
151 | |
        // Since `seek_brgemm_impl` modifies the forward primitive desc
        // above, bwdd_pd and bwdw_pd need to be initialized only after
        // fwd_pd is fixed.
155 | auto bwdd_pd = inner_product_backward_data::primitive_desc( |
156 | eng, src_md, wei_md, dst_md, fwd_pd); |
157 | auto bwdw_pd = inner_product_backward_weights::primitive_desc( |
158 | eng, src_md, wei_md, bia_md, dst_md, fwd_pd); |
        // If the forward inner product can be handled by brgemm, then so
        // should the backward data/weights ones
161 | ASSERT_NO_THROW(bwdd_brgemm_ker_found = seek_brgemm_impl(bwdd_pd)); |
162 | ASSERT_NO_THROW(bwdw_brgemm_ker_found = seek_brgemm_impl(bwdw_pd)); |
163 | |
164 | ASSERT_TRUE(bwdd_brgemm_ker_found); |
165 | ASSERT_TRUE(bwdw_brgemm_ker_found); |
166 | |
167 | // Check for weights consistency |
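        // (identical memory descriptors imply the same weights format, so no
        // reorder is needed between the forward and backward passes)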
168 | const auto &fwd_wei |
169 | = fwd_pd.query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS); |
170 | const auto &bwdd_wei |
171 | = bwdd_pd.query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS); |
172 | const auto &bwdw_wei |
173 | = bwdw_pd.query_md(query::exec_arg_md, DNNL_ARG_DIFF_WEIGHTS); |
174 | |
175 | ASSERT_TRUE(fwd_wei == bwdd_wei && fwd_wei == bwdw_wei); |
176 | } |
177 | } |
178 | |
179 | } // namespace dnnl |
180 | |