1/*******************************************************************************
2* Copyright 2021-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "dnnl_test_common.hpp"
18#include "gtest/gtest.h"
19
20#include "oneapi/dnnl/dnnl.hpp"
21
22#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
23#include "oneapi/dnnl/dnnl_ocl.hpp"
24#endif
25
26namespace dnnl {
27
28class persistent_cache_api_test_t : public ::testing::Test {};
29
30HANDLE_EXCEPTIONS_FOR_TEST(
31 persistent_cache_api_test_t, TestPersistentCacheAPI) {
32 engine e = get_test_engine();
33 auto pd = convolution_forward::primitive_desc {e,
34 prop_kind::forward_training, algorithm::convolution_direct,
35 {{2, 16, 16, 16}, memory::data_type::f32, memory::format_tag::nchw},
36 {{16, 16, 3, 3}, memory::data_type::f32, memory::format_tag::oihw},
37 {{2, 16, 14, 14}, memory::data_type::f32, memory::format_tag::nchw},
38 {1, 1}, {0, 0}, {0, 0}};
39 auto p = convolution_forward(pd);
40
41 std::vector<uint8_t> cache_blob_id;
42 std::vector<uint8_t> cache_blob;
43
44 ASSERT_NO_THROW(cache_blob_id = pd.get_cache_blob_id());
45 ASSERT_EQ(cache_blob_id, pd.get_cache_blob_id());
46
47 if (get_test_engine_kind() != engine::kind::gpu
48 || (get_test_engine_kind() == engine::kind::gpu
49 && DNNL_GPU_RUNTIME != DNNL_RUNTIME_OCL)) {
50 ASSERT_EQ(cache_blob_id.empty(), true);
51 EXPECT_ANY_THROW(cache_blob = p.get_cache_blob());
52 ASSERT_EQ(cache_blob.empty(), true);
53 EXPECT_ANY_THROW(convolution_forward(pd, cache_blob));
54 } else {
55 ASSERT_EQ(cache_blob_id.empty(), false);
56 ASSERT_NO_THROW(cache_blob = p.get_cache_blob());
57 ASSERT_EQ(cache_blob.empty(), false);
58 ASSERT_NO_THROW(p = convolution_forward(pd, cache_blob));
59 ASSERT_EQ(cache_blob, p.get_cache_blob());
60 }
61}
62
63#if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL
64HANDLE_EXCEPTIONS_FOR_TEST(
65 persistent_cache_api_test_t, TestPersistentCacheAPIEngine) {
66 using namespace dnnl::ocl_interop;
67 engine test_engine = get_test_engine();
68
69 if (get_test_engine_kind() != engine::kind::gpu) {
70 ASSERT_ANY_THROW(get_engine_cache_blob(test_engine));
71 return;
72 }
73
74 std::vector<uint8_t> cache_blob;
75 std::vector<uint8_t> cache_blob_id;
76
77 ASSERT_NO_THROW(cache_blob = get_engine_cache_blob(test_engine));
78 ASSERT_NO_THROW(
79 cache_blob_id = get_engine_cache_blob_id(get_device(test_engine)));
80
81 ASSERT_EQ(get_engine_cache_blob(test_engine), cache_blob);
82 ASSERT_EQ(get_engine_cache_blob_id(get_device(test_engine)), cache_blob_id);
83
84 ASSERT_TRUE(!cache_blob.empty());
85 ASSERT_TRUE(!cache_blob_id.empty());
86
87 auto eng = make_engine(
88 get_device(test_engine), get_context(test_engine), cache_blob);
89
90 ASSERT_EQ(get_engine_cache_blob(eng), cache_blob);
91 ASSERT_EQ(get_engine_cache_blob_id(get_device(eng)), cache_blob_id);
92}
93#endif
94
95} // namespace dnnl
96