1 | /******************************************************************************* |
2 | * Copyright 2020-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "dnnl_test_common.hpp" |
18 | #include "gtest/gtest.h" |
19 | |
20 | #include "oneapi/dnnl/dnnl.hpp" |
21 | |
22 | #if DNNL_CPU_RUNTIME != DNNL_RUNTIME_NONE |
23 | #include "tests/test_isa_common.hpp" |
24 | #endif |
25 | namespace dnnl { |
26 | |
27 | // short names for brevity |
28 | using data_type = memory::data_type; |
29 | using tag = memory::format_tag; |
30 | |
31 | class wino_conv_test_t : public ::testing::Test { |
32 | protected: |
33 | engine eng = get_test_engine(); |
34 | struct input_data_t { |
35 | data_type dat_dt; |
36 | data_type wei_dt; |
37 | bool wino_supported = false; |
38 | bool backward_supported = false; |
39 | } input_f32, input_f16, input_int8; |
40 | |
41 | void SetUp() override { |
42 | input_f32.dat_dt = data_type::f32; |
43 | input_f32.wei_dt = data_type::f32; |
44 | |
45 | input_f16.dat_dt = data_type::f16; |
46 | input_f16.wei_dt = data_type::f16; |
47 | |
48 | input_int8.dat_dt = data_type::u8; |
49 | input_int8.wei_dt = data_type::s8; |
50 | |
51 | #if DNNL_CPU_RUNTIME != DNNL_RUNTIME_NONE |
52 | #if DNNL_X64 |
53 | const bool is_cpu = get_test_engine_kind() == engine::kind::cpu; |
54 | const bool is_gpu = get_test_engine_kind() == engine::kind::gpu; |
55 | static const auto isa = get_effective_cpu_isa(); |
56 | static const bool has_avx512_core |
57 | = dnnl::is_superset(isa, cpu_isa::avx512_core); |
58 | input_f32.wino_supported = is_gpu || (is_cpu && has_avx512_core); |
59 | input_f16.wino_supported = is_gpu; |
60 | input_f32.backward_supported = is_cpu && impl::dnnl_thr_syncable(); |
61 | #elif DNNL_AARCH64 && DNNL_AARCH64_USE_ACL |
62 | const bool is_cpu = get_test_engine_kind() == engine::kind::cpu; |
63 | input_f32.wino_supported = is_cpu; |
64 | input_f16.wino_supported = is_cpu; |
65 | #endif |
66 | |
67 | #else |
68 | const bool is_gpu = get_test_engine_kind() == engine::kind::gpu; |
69 | input_f32.wino_supported = is_gpu; |
70 | input_f16.wino_supported = is_gpu; |
71 | #endif |
72 | } |
73 | }; |
74 | |
75 | TEST_F(wino_conv_test_t, TestSmallPadding) { |
76 | for (const auto &input : {input_f32, input_f16, input_int8}) { |
77 | if (unsupported_data_type(input.dat_dt) |
78 | || unsupported_data_type(input.wei_dt)) |
79 | continue; |
80 | |
81 | memory::desc src_md {{1, 16, 7, 7}, input.dat_dt, tag::any}; |
82 | memory::desc wei_md {{32, 16, 3, 3}, input.wei_dt, tag::any}; |
83 | memory::desc dst_md {{1, 32, 7, 7}, input.dat_dt, tag::any}; |
84 | |
85 | if (input.wino_supported) { |
86 | convolution_forward::primitive_desc fwd_hint; |
87 | EXPECT_NO_THROW( |
88 | fwd_hint = convolution_forward::primitive_desc(eng, |
89 | prop_kind::forward, algorithm::convolution_winograd, |
90 | src_md, wei_md, dst_md, {1, 1}, {1, 1}, {1, 1})); |
91 | #if DNNL_X64 |
92 | if (input.wei_dt == data_type::s8) { |
93 | EXPECT_EQ(fwd_hint.weights_desc().get_format_kind(), |
94 | memory::format_kind::opaque); |
95 | } |
96 | #endif |
97 | if (input.backward_supported) { |
98 | EXPECT_NO_THROW(convolution_backward_data::primitive_desc(eng, |
99 | algorithm::convolution_winograd, src_md, wei_md, dst_md, |
100 | {1, 1}, {1, 1}, {1, 1}, fwd_hint)); |
101 | |
102 | EXPECT_NO_THROW(convolution_backward_weights::primitive_desc( |
103 | eng, algorithm::convolution_winograd, src_md, wei_md, |
104 | dst_md, {1, 1}, {1, 1}, {1, 1}, fwd_hint)); |
105 | } |
106 | } else { |
107 | EXPECT_ANY_THROW(convolution_forward::primitive_desc(eng, |
108 | prop_kind::forward, algorithm::convolution_winograd, src_md, |
109 | wei_md, dst_md, {1, 1}, {1, 1}, {1, 1})); |
110 | } |
111 | } |
112 | } |
113 | |
114 | TEST_F(wino_conv_test_t, TestLargePadding) { |
115 | for (const auto &input : {input_f32, input_f16, input_int8}) { |
116 | if (unsupported_data_type(input.dat_dt) |
117 | || unsupported_data_type(input.wei_dt)) |
118 | continue; |
119 | |
120 | memory::desc src_md {{1, 16, 7, 7}, input.dat_dt, tag::any}; |
121 | memory::desc wei_md {{32, 16, 3, 3}, input.wei_dt, tag::any}; |
122 | memory::desc dst_md {{1, 32, 9, 9}, input.dat_dt, tag::any}; |
123 | |
124 | bool large_pad_is_supported |
125 | = (get_test_engine_kind() == engine::kind::gpu); |
126 | if (input.wino_supported && large_pad_is_supported) { |
127 | EXPECT_NO_THROW(convolution_forward::primitive_desc(eng, |
128 | prop_kind::forward, algorithm::convolution_winograd, src_md, |
129 | wei_md, dst_md, {1, 1}, {2, 2}, {2, 2})); |
130 | } else { |
131 | EXPECT_ANY_THROW(convolution_forward::primitive_desc(eng, |
132 | prop_kind::forward, algorithm::convolution_winograd, src_md, |
133 | wei_md, dst_md, {1, 1}, {2, 2}, {2, 2})); |
134 | } |
135 | } |
136 | } |
137 | |
138 | TEST_F(wino_conv_test_t, TestUnsupportedKernel) { |
139 | for (const auto &input : {input_f32, input_f16, input_int8}) { |
140 | if (unsupported_data_type(input.dat_dt) |
141 | || unsupported_data_type(input.wei_dt)) |
142 | continue; |
143 | |
144 | memory::desc src_md {{1, 16, 5, 5}, input.dat_dt, tag::any}; |
145 | memory::desc wei_md {{32, 16, 2, 2}, input.wei_dt, tag::any}; |
146 | memory::desc dst_md {{1, 32, 6, 6}, input.dat_dt, tag::any}; |
147 | |
148 | EXPECT_ANY_THROW(convolution_forward::primitive_desc(eng, |
149 | prop_kind::forward, algorithm::convolution_winograd, src_md, |
150 | wei_md, dst_md, {1, 1}, {1, 1}, {1, 1})); |
151 | } |
152 | } |
153 | |
154 | } // namespace dnnl |
155 | |