/*******************************************************************************
* Copyright 2021-2022 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
*     http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "dnnl_test_common.hpp"
#include "gtest/gtest.h"

#include "oneapi/dnnl/dnnl.hpp"

#include <string>
#include <vector>

namespace dnnl {

// short names for brevity
using data_type = memory::data_type;
using tag = memory::format_tag;

class weights_format_test_t : public ::testing::Test {
protected:
    const engine eng = get_test_engine();

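    // Helper that describes an inner product problem; the spatial dimensions
    // default to 1, i.e. a shape with no spatial dimensions.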
    struct inner_product_shape_t {
    public:
        inner_product_shape_t(memory::dim mb, memory::dim ic, memory::dim oc,
                memory::dim kw = 1, memory::dim kh = 1, memory::dim kd = 1)
            : mb_(mb), ic_(ic), oc_(oc), kw_(kw), kh_(kh), kd_(kd) {}

        memory::dims src_dims() const { return maybe_spatial(mb_, ic_); }
        memory::dims wei_dims() const { return maybe_spatial(oc_, ic_); }

        memory::dims dst_dims() const { return {mb_, oc_}; }
        memory::dims bia_dims() const { return {oc_}; }

    private:
        memory::dim mb_, ic_, oc_, kw_, kh_, kd_;
        bool is_1d() const { return kd_ == 1 && kh_ == 1 && kw_ != 1; }
        bool is_2d() const { return kd_ == 1 && kh_ != 1; }
        bool is_3d() const { return kd_ != 1; }

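        // Build the full tensor dims by appending the spatial dimensions,
        // if any, after {param1, param2}; the spatial rank is deduced from
        // kd, kh and kw.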
        memory::dims maybe_spatial(
                memory::dim param1, memory::dim param2) const {
            if (is_3d())
                return {param1, param2, kd_, kh_, kw_};
            else if (is_2d())
                return {param1, param2, kh_, kw_};
            else if (is_1d())
                return {param1, param2, kw_};
            else
                return {param1, param2};
        }
    };

    // Iterate over the primitive descriptor implementations until either
    // - a brgemm kernel implementation is found, or
    // - the end of the primitive descriptor iterator is reached.
    // Returns `true` iff a brgemm kernel is found.
    template <typename PD>
    bool seek_brgemm_impl(PD &pd) {
        const std::string brgemm("brgemm");
        bool brgemm_ker_found = false, seek_next_impl = true;
        do {
            const std::string impl_info(pd.impl_info_str());
            brgemm_ker_found = impl_info.find(brgemm) != std::string::npos;

            // next_impl() advances to the next available implementation and
            // returns false once the last one has been reached.
            seek_next_impl = !brgemm_ker_found && pd.next_impl();
        } while (seek_next_impl);

        return brgemm_ker_found;
    }

    std::vector<inner_product_shape_t> inner_product_shapes;
    std::vector<data_type> inner_product_data_types;

    void SetUp() override {
        for (auto dt : {data_type::f32, data_type::bf16}) {
            if (!unsupported_data_type(dt))
                inner_product_data_types.push_back(dt);
        }

        // inner product shapes with no spatial dimensions [majority case]
        // dims format: {mb, ic, oc}
        inner_product_shapes.insert(inner_product_shapes.end(),
                {{2, 16, 16}, {2, 16, 32}, {2, 16, 64}, {2, 32, 16},
                        {2, 32, 32}, {2, 32, 64}, {2, 64, 16}, {2, 64, 32},
                        {2, 64, 64}, {2, 512, 16}, {2, 512, 32}, {2, 512, 64},
                        {2, 512, 512}, {2, 512, 1024}, {2, 1024, 512}});

        // non-spatial inner product shapes with channel tails
        for (auto sz : {1, 3, 15, 17, 31, 33, 63, 65, 127, 129})
            inner_product_shapes.emplace_back(
                    inner_product_shape_t {sz, sz, sz});

        // non-spatial inner product regression shape
        inner_product_shapes.emplace_back(
                inner_product_shape_t {2, 1024, 30522});

        // inner product shapes with spatial dimensions
        // dims format: either of {mb, ic, oc, kw}, {mb, ic, oc, kw, kh},
        // or {mb, ic, oc, kw, kh, kd}
        inner_product_shapes.insert(inner_product_shapes.end(),
                {{2, 16, 16, 2}, {2, 16, 32, 2, 3}, {2, 16, 64, 4, 3, 2},
                        {2, 32, 16, 2}, {2, 32, 32, 2, 3}, {2, 32, 64, 4, 3, 2},
                        {2, 64, 16, 2}, {2, 64, 32, 2, 3},
                        {2, 64, 64, 4, 3, 2}});
    }
};

// Check weights consistency in inner product, that is, check that the weights
// format is the same across the forward and backward passes.
// TODO: Enable similar tests for convolution once brgemm kernel's support
// is complete
HANDLE_EXCEPTIONS_FOR_TEST_F(weights_format_test_t, InnerProductWeightsCheck) {
    const bool do_skip = !DNNL_X64 || (DNNL_CPU_RUNTIME == DNNL_RUNTIME_NONE)
            || (get_test_engine_kind() != engine::kind::cpu);
    SKIP_IF(do_skip,
            "Inner Product weight check is applicable only for x64 CPU");

    for_(const auto &input_shape : inner_product_shapes)
    for (const auto &input_dt : inner_product_data_types) {
        // Note: For inner product with mixed data types, e.g. with bf16 src
        // and f32 dst, we do not require weights consistency.
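        // `tag::any` lets the library pick the memory format (possibly a
        // blocked layout) preferred by the selected implementation.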
        memory::desc src_md {input_shape.src_dims(), input_dt, tag::any};
        memory::desc wei_md {input_shape.wei_dims(), input_dt, tag::any};
        memory::desc bia_md {input_shape.bia_dims(), input_dt, tag::any};
        memory::desc dst_md {input_shape.dst_dims(), input_dt, tag::any};

        auto fwd_pd = inner_product_forward::primitive_desc(eng,
                prop_kind::forward_training, src_md, wei_md, bia_md, dst_md);

        bool fwd_brgemm_ker_found = false, bwdd_brgemm_ker_found = false,
             bwdw_brgemm_ker_found = false;
        // Currently only the brgemm kernel supports the same weight tags for
        // forward and backward data/weights inner product, therefore skip
        // the case if the forward impl kernel is not brgemm.
        ASSERT_NO_THROW(fwd_brgemm_ker_found = seek_brgemm_impl(fwd_pd));
        if (!fwd_brgemm_ker_found) continue;

        // Since `seek_brgemm_impl` modifies the forward primitive descriptor
        // above, bwdd_pd and bwdw_pd need to be initialized only after
        // fwd_pd is fixed.
        auto bwdd_pd = inner_product_backward_data::primitive_desc(
                eng, src_md, wei_md, dst_md, fwd_pd);
        auto bwdw_pd = inner_product_backward_weights::primitive_desc(
                eng, src_md, wei_md, bia_md, dst_md, fwd_pd);
        // If the forward inner product can be handled by brgemm, then so
        // should the backward data/weights ones.
        ASSERT_NO_THROW(bwdd_brgemm_ker_found = seek_brgemm_impl(bwdd_pd));
        ASSERT_NO_THROW(bwdw_brgemm_ker_found = seek_brgemm_impl(bwdw_pd));

        ASSERT_TRUE(bwdd_brgemm_ker_found);
        ASSERT_TRUE(bwdw_brgemm_ker_found);

        // Check for weights consistency
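        // memory::desc equality compares the complete descriptor, including
        // the physical (possibly blocked) layout, not just dims and data type.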
        const auto &fwd_wei
                = fwd_pd.query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS);
        const auto &bwdd_wei
                = bwdd_pd.query_md(query::exec_arg_md, DNNL_ARG_WEIGHTS);
        const auto &bwdw_wei
                = bwdw_pd.query_md(query::exec_arg_md, DNNL_ARG_DIFF_WEIGHTS);

        ASSERT_TRUE(fwd_wei == bwdd_wei && fwd_wei == bwdw_wei);
    }
}

} // namespace dnnl