1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "dnnl_test_common.hpp"
18#include "gtest/gtest.h"
19
20#include "dnnl.hpp"
21
22namespace dnnl {
23
24TEST(primitive_cache_mt_test, TestGeneralCase) {
25 using tag = memory::format_tag;
26 using dt = memory::data_type;
27
28 engine eng(get_test_engine_kind(), 0);
29
30 // Flush the cache
31 set_primitive_cache_capacity(0);
32 set_primitive_cache_capacity(1024);
33
34 memory::dim n_primitives = 12;
35
36 dnnl::impl::parallel_nd(n_primitives, [&](memory::dim np) {
37 auto md = memory::desc({{np, 1, 1, 1}, dt::f32, tag::nchw});
38 auto relu_pd = eltwise_forward::primitive_desc(eng,
39 prop_kind::forward_inference, algorithm::eltwise_relu, md, md,
40 0.f);
41 auto relu = eltwise_forward(relu_pd);
42 });
43
44 ASSERT_EQ(get_primitive_cache_size(), n_primitives);
45}
46
47TEST(primitive_cache_mt_test, TestNestedCase) {
48 using tag = memory::format_tag;
49 using dt = memory::data_type;
50
51 engine eng(get_test_engine_kind(), 0);
52
53 // Flush the cache
54 set_primitive_cache_capacity(0);
55 set_primitive_cache_capacity(1024);
56
57 memory::dim n_primitives = 12;
58 memory::dim n_srcs = 32;
59
60 dnnl::impl::parallel_nd(n_primitives, [&](memory::dim np) {
61 std::vector<memory::desc> src_mds(n_srcs);
62 std::vector<float> scales(n_srcs, 1.0);
63
64 for (memory::dim ns = 0; ns < n_srcs; ++ns) {
65 src_mds[ns] = memory::desc({{128, 128}, dt::f32, tag::nc});
66 }
67 auto sum_pd = sum::primitive_desc(eng, scales, src_mds);
68 auto sum_prim = sum(sum_pd);
69 });
70}
71
72TEST(primitive_cache_mt_test, TestMTCacheHit) {
73 using tag = memory::format_tag;
74 using dt = memory::data_type;
75
76 engine eng(get_test_engine_kind(), 0);
77
78 // Flush the cache
79 dnnl::set_primitive_cache_capacity(0);
80 dnnl::set_primitive_cache_capacity(1024);
81
82 int n_primitives = 10;
83
84 auto create_eltwise_primitive = [&](int np) {
85 auto md = memory::desc({{np, 1, 1, 1}, dt::f32, tag::nchw});
86 auto relu_pd = eltwise_forward::primitive_desc(eng,
87 prop_kind::forward_inference, algorithm::eltwise_relu, md, md,
88 0.f);
89 auto relu = eltwise_forward(relu_pd);
90 };
91
92 // Fill the cache with n_primitives (cache_miss)
93 for (int i = 0; i < n_primitives; i++)
94 create_eltwise_primitive(i);
95
96 // This section should only perform cache_hits
97 dnnl::impl::parallel(0, [&](int ithr, int nthr) {
98 for (int i = 0; i < n_primitives; i++)
99 create_eltwise_primitive(i);
100 });
101
102 ASSERT_EQ(get_primitive_cache_size(), n_primitives);
103}
104
105} // namespace dnnl
106