1/*******************************************************************************
2* Copyright 2016-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "dnnl_test_common.hpp"
18#include "gtest/gtest.h"
19
20#include "oneapi/dnnl/dnnl.hpp"
21
22#include "tests/test_isa_common.hpp"
23
24namespace dnnl {
25
// Shorthand for memory::format_tag (used below as tag::any).
using tag = memory::format_tag;
27
// Layout class a test case expects the library to choose for a memory
// descriptor created with tag::any (interpreted by the fixture's check):
//  - flat:       blocked format kind with no inner blocks;
//  - blocked_cX: a single inner block over dimension 1 (channels).
enum class data_fmt_t { flat, blocked_cX };

// Short aliases used by the parameter table and the format check. They are
// typed constants rather than macros so they obey scoping rules and cannot
// silently rewrite unrelated tokens elsewhere in the translation unit.
constexpr data_fmt_t FLT = data_fmt_t::flat;
constexpr data_fmt_t BLK = data_fmt_t::blocked_cX;
32
// Parameters of one "any format" convolution test case. The struct is an
// aggregate that is initialized positionally, so field order is significant.
struct conv_any_fmt_test_params_t {
    prop_kind aprop_kind; // propagation kind; the fixture requires forward
    algorithm aalgorithm; // the fixture requires convolution_direct
    data_fmt_t expected_src_fmt; // layout expected for the chosen src md
    data_fmt_t expected_dst_fmt; // layout expected for the chosen dst md
    test_convolution_sizes_t test_cd; // convolution problem sizes
};
40
41template <typename data_t>
42class convolution_any_fmt_test_t
43 : public ::testing::TestWithParam<conv_any_fmt_test_params_t> {
44protected:
45 void SetUp() override {
46#if DNNL_X64
47 // Skip this test if the library cannot select blocked format a priori.
48 // Currently blocking is supported only for sse41 and later CPUs.
49 bool implementation_supports_blocking = dnnl::mayiuse(cpu_isa::sse41);
50 if (!implementation_supports_blocking) return;
51#else
52 return;
53#endif
54
55 auto p = ::testing::TestWithParam<
56 conv_any_fmt_test_params_t>::GetParam();
57
58 ASSERT_EQ(p.aprop_kind, prop_kind::forward);
59 ASSERT_EQ(p.aalgorithm, algorithm::convolution_direct);
60 auto eng = get_test_engine();
61 memory::data_type data_type = data_traits<data_t>::data_type;
62 SKIP_IF_CUDA((p.expected_src_fmt == BLK || p.expected_dst_fmt == BLK),
63 "unsupported format");
64 ASSERT_EQ(data_type, dnnl::memory::data_type::f32);
65
66 test_convolution_sizes_t cd = p.test_cd;
67
68 auto c_src_desc
69 = create_md({cd.mb, cd.ic, cd.ih, cd.iw}, data_type, tag::any);
70 auto c_weights_desc = cd.ng > 1
71 ? create_md({cd.ng, cd.oc / cd.ng, cd.ic / cd.ng, cd.kh, cd.kw},
72 data_type, tag::any)
73 : create_md({cd.oc, cd.ic, cd.kh, cd.kw}, data_type, tag::any);
74 auto c_dst_desc
75 = create_md({cd.mb, cd.oc, cd.oh, cd.ow}, data_type, tag::any);
76
77 auto conv_prim_desc = convolution_forward::primitive_desc(eng,
78 p.aprop_kind, p.aalgorithm, c_src_desc, c_weights_desc,
79 c_dst_desc, {cd.strh, cd.strw}, {cd.padh, cd.padw},
80 {cd.padh, cd.padw});
81
82 auto check_fmt = [&](const memory::desc &md, data_fmt_t expected) {
83 bool ok = false;
84 if (expected == FLT) {
85 ok = true
86 && md.get_format_kind() == memory::format_kind::blocked
87 && md.get_inner_nblks() == 0;
88 } else if (expected == BLK) {
89 ok = true
90 && md.get_format_kind() == memory::format_kind::blocked
91 && md.get_inner_nblks() == 1
92 && md.get_inner_idxs()[0] == 1
93 && (false || md.get_inner_blks()[0] == 8
94 || md.get_inner_blks()[0] == 16);
95 }
96 return ok;
97 };
98
99 ASSERT_TRUE(check_fmt(conv_prim_desc.src_desc(), p.expected_src_fmt));
100 ASSERT_TRUE(check_fmt(conv_prim_desc.dst_desc(), p.expected_dst_fmt));
101 }
102};
103
104using conv_any_fmt_test_float = convolution_any_fmt_test_t<float>;
105
106#define CPARAMS prop_kind::forward, algorithm::convolution_direct
107
108using tf32 = conv_any_fmt_test_params_t;
109
110#define ALEXNET_SUITE(EFMT) \
111 tf32 {CPARAMS, FLT, EFMT, \
112 {2, 1, 3, 227, 227, 96, 55, 55, 11, 11, 0, 0, 4, 4}}, \
113 tf32 {CPARAMS, EFMT, EFMT, \
114 {2, 2, 96, 27, 27, 256, 27, 27, 5, 5, 2, 2, 1, 1}}, \
115 tf32 {CPARAMS, EFMT, EFMT, \
116 {2, 1, 256, 13, 13, 384, 13, 13, 3, 3, 1, 1, 1, 1}}, \
117 tf32 {CPARAMS, EFMT, EFMT, \
118 {2, 2, 384, 13, 13, 384, 13, 13, 3, 3, 1, 1, 1, 1}}, \
119 tf32 { \
120 CPARAMS, EFMT, EFMT, { \
121 2, 2, 384, 13, 13, 256, 13, 13, 3, 3, 1, 1, 1, 1 \
122 } \
123 }
124
#if DNNL_X64
// The fixture's SetUp() builds the primitive descriptor and performs every
// format check, so the test body has nothing left to do.
TEST_P(conv_any_fmt_test_float, TestsConvolutionAnyFmt) {
    // Intentionally empty.
}

// Every case expects blocked (cX) layouts, except the first case's src,
// which the table marks as flat.
CPU_INSTANTIATE_TEST_SUITE_P(TestConvolutionAlexnetAnyFmtForward,
        conv_any_fmt_test_float, ::testing::Values(ALEXNET_SUITE(BLK)));
#endif
131} // namespace dnnl
132