1 | /******************************************************************************* |
2 | * Copyright 2016-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "dnnl_test_common.hpp" |
18 | #include "gtest/gtest.h" |
19 | |
20 | #include "oneapi/dnnl/dnnl.hpp" |
21 | |
22 | #include "tests/test_isa_common.hpp" |
23 | |
24 | namespace dnnl { |
25 | |
26 | using tag = memory::format_tag; |
27 | |
28 | enum class data_fmt_t { flat, blocked_cX }; |
29 | |
30 | #define FLT data_fmt_t::flat |
31 | #define BLK data_fmt_t::blocked_cX |
32 | |
33 | struct conv_any_fmt_test_params_t { |
34 | prop_kind aprop_kind; |
35 | algorithm aalgorithm; |
36 | data_fmt_t expected_src_fmt; |
37 | data_fmt_t expected_dst_fmt; |
38 | test_convolution_sizes_t test_cd; |
39 | }; |
40 | |
41 | template <typename data_t> |
42 | class convolution_any_fmt_test_t |
43 | : public ::testing::TestWithParam<conv_any_fmt_test_params_t> { |
44 | protected: |
45 | void SetUp() override { |
46 | #if DNNL_X64 |
47 | // Skip this test if the library cannot select blocked format a priori. |
48 | // Currently blocking is supported only for sse41 and later CPUs. |
49 | bool implementation_supports_blocking = dnnl::mayiuse(cpu_isa::sse41); |
50 | if (!implementation_supports_blocking) return; |
51 | #else |
52 | return; |
53 | #endif |
54 | |
55 | auto p = ::testing::TestWithParam< |
56 | conv_any_fmt_test_params_t>::GetParam(); |
57 | |
58 | ASSERT_EQ(p.aprop_kind, prop_kind::forward); |
59 | ASSERT_EQ(p.aalgorithm, algorithm::convolution_direct); |
60 | auto eng = get_test_engine(); |
61 | memory::data_type data_type = data_traits<data_t>::data_type; |
62 | SKIP_IF_CUDA((p.expected_src_fmt == BLK || p.expected_dst_fmt == BLK), |
63 | "unsupported format" ); |
64 | ASSERT_EQ(data_type, dnnl::memory::data_type::f32); |
65 | |
66 | test_convolution_sizes_t cd = p.test_cd; |
67 | |
68 | auto c_src_desc |
69 | = create_md({cd.mb, cd.ic, cd.ih, cd.iw}, data_type, tag::any); |
70 | auto c_weights_desc = cd.ng > 1 |
71 | ? create_md({cd.ng, cd.oc / cd.ng, cd.ic / cd.ng, cd.kh, cd.kw}, |
72 | data_type, tag::any) |
73 | : create_md({cd.oc, cd.ic, cd.kh, cd.kw}, data_type, tag::any); |
74 | auto c_dst_desc |
75 | = create_md({cd.mb, cd.oc, cd.oh, cd.ow}, data_type, tag::any); |
76 | |
77 | auto conv_prim_desc = convolution_forward::primitive_desc(eng, |
78 | p.aprop_kind, p.aalgorithm, c_src_desc, c_weights_desc, |
79 | c_dst_desc, {cd.strh, cd.strw}, {cd.padh, cd.padw}, |
80 | {cd.padh, cd.padw}); |
81 | |
82 | auto check_fmt = [&](const memory::desc &md, data_fmt_t expected) { |
83 | bool ok = false; |
84 | if (expected == FLT) { |
85 | ok = true |
86 | && md.get_format_kind() == memory::format_kind::blocked |
87 | && md.get_inner_nblks() == 0; |
88 | } else if (expected == BLK) { |
89 | ok = true |
90 | && md.get_format_kind() == memory::format_kind::blocked |
91 | && md.get_inner_nblks() == 1 |
92 | && md.get_inner_idxs()[0] == 1 |
93 | && (false || md.get_inner_blks()[0] == 8 |
94 | || md.get_inner_blks()[0] == 16); |
95 | } |
96 | return ok; |
97 | }; |
98 | |
99 | ASSERT_TRUE(check_fmt(conv_prim_desc.src_desc(), p.expected_src_fmt)); |
100 | ASSERT_TRUE(check_fmt(conv_prim_desc.dst_desc(), p.expected_dst_fmt)); |
101 | } |
102 | }; |
103 | |
104 | using conv_any_fmt_test_float = convolution_any_fmt_test_t<float>; |
105 | |
106 | #define CPARAMS prop_kind::forward, algorithm::convolution_direct |
107 | |
108 | using tf32 = conv_any_fmt_test_params_t; |
109 | |
110 | #define ALEXNET_SUITE(EFMT) \ |
111 | tf32 {CPARAMS, FLT, EFMT, \ |
112 | {2, 1, 3, 227, 227, 96, 55, 55, 11, 11, 0, 0, 4, 4}}, \ |
113 | tf32 {CPARAMS, EFMT, EFMT, \ |
114 | {2, 2, 96, 27, 27, 256, 27, 27, 5, 5, 2, 2, 1, 1}}, \ |
115 | tf32 {CPARAMS, EFMT, EFMT, \ |
116 | {2, 1, 256, 13, 13, 384, 13, 13, 3, 3, 1, 1, 1, 1}}, \ |
117 | tf32 {CPARAMS, EFMT, EFMT, \ |
118 | {2, 2, 384, 13, 13, 384, 13, 13, 3, 3, 1, 1, 1, 1}}, \ |
119 | tf32 { \ |
120 | CPARAMS, EFMT, EFMT, { \ |
121 | 2, 2, 384, 13, 13, 256, 13, 13, 3, 3, 1, 1, 1, 1 \ |
122 | } \ |
123 | } |
124 | |
125 | #if DNNL_X64 |
126 | TEST_P(conv_any_fmt_test_float, TestsConvolutionAnyFmt) {} |
127 | |
128 | CPU_INSTANTIATE_TEST_SUITE_P(TestConvolutionAlexnetAnyFmtForward, |
129 | conv_any_fmt_test_float, ::testing::Values(ALEXNET_SUITE(BLK))); |
130 | #endif |
131 | } // namespace dnnl |
132 | |