1 | /******************************************************************************* |
2 | * Copyright 2021-2022 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #include "dnnl_test_common.hpp" |
18 | #include "gtest/gtest.h" |
19 | |
20 | #include "oneapi/dnnl/dnnl.hpp" |
21 | |
22 | #if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL |
23 | #include "oneapi/dnnl/dnnl_ocl.hpp" |
24 | #endif |
25 | |
26 | namespace dnnl { |
27 | |
// Empty gtest fixture grouping the persistent-cache API tests below.
class persistent_cache_api_test_t : public ::testing::Test {};
29 | |
30 | HANDLE_EXCEPTIONS_FOR_TEST( |
31 | persistent_cache_api_test_t, TestPersistentCacheAPI) { |
32 | engine e = get_test_engine(); |
33 | auto pd = convolution_forward::primitive_desc {e, |
34 | prop_kind::forward_training, algorithm::convolution_direct, |
35 | {{2, 16, 16, 16}, memory::data_type::f32, memory::format_tag::nchw}, |
36 | {{16, 16, 3, 3}, memory::data_type::f32, memory::format_tag::oihw}, |
37 | {{2, 16, 14, 14}, memory::data_type::f32, memory::format_tag::nchw}, |
38 | {1, 1}, {0, 0}, {0, 0}}; |
39 | auto p = convolution_forward(pd); |
40 | |
41 | std::vector<uint8_t> cache_blob_id; |
42 | std::vector<uint8_t> cache_blob; |
43 | |
44 | ASSERT_NO_THROW(cache_blob_id = pd.get_cache_blob_id()); |
45 | ASSERT_EQ(cache_blob_id, pd.get_cache_blob_id()); |
46 | |
47 | if (get_test_engine_kind() != engine::kind::gpu |
48 | || (get_test_engine_kind() == engine::kind::gpu |
49 | && DNNL_GPU_RUNTIME != DNNL_RUNTIME_OCL)) { |
50 | ASSERT_EQ(cache_blob_id.empty(), true); |
51 | EXPECT_ANY_THROW(cache_blob = p.get_cache_blob()); |
52 | ASSERT_EQ(cache_blob.empty(), true); |
53 | EXPECT_ANY_THROW(convolution_forward(pd, cache_blob)); |
54 | } else { |
55 | ASSERT_EQ(cache_blob_id.empty(), false); |
56 | ASSERT_NO_THROW(cache_blob = p.get_cache_blob()); |
57 | ASSERT_EQ(cache_blob.empty(), false); |
58 | ASSERT_NO_THROW(p = convolution_forward(pd, cache_blob)); |
59 | ASSERT_EQ(cache_blob, p.get_cache_blob()); |
60 | } |
61 | } |
62 | |
63 | #if DNNL_GPU_RUNTIME == DNNL_RUNTIME_OCL |
64 | HANDLE_EXCEPTIONS_FOR_TEST( |
65 | persistent_cache_api_test_t, TestPersistentCacheAPIEngine) { |
66 | using namespace dnnl::ocl_interop; |
67 | engine test_engine = get_test_engine(); |
68 | |
69 | if (get_test_engine_kind() != engine::kind::gpu) { |
70 | ASSERT_ANY_THROW(get_engine_cache_blob(test_engine)); |
71 | return; |
72 | } |
73 | |
74 | std::vector<uint8_t> cache_blob; |
75 | std::vector<uint8_t> cache_blob_id; |
76 | |
77 | ASSERT_NO_THROW(cache_blob = get_engine_cache_blob(test_engine)); |
78 | ASSERT_NO_THROW( |
79 | cache_blob_id = get_engine_cache_blob_id(get_device(test_engine))); |
80 | |
81 | ASSERT_EQ(get_engine_cache_blob(test_engine), cache_blob); |
82 | ASSERT_EQ(get_engine_cache_blob_id(get_device(test_engine)), cache_blob_id); |
83 | |
84 | ASSERT_TRUE(!cache_blob.empty()); |
85 | ASSERT_TRUE(!cache_blob_id.empty()); |
86 | |
87 | auto eng = make_engine( |
88 | get_device(test_engine), get_context(test_engine), cache_blob); |
89 | |
90 | ASSERT_EQ(get_engine_cache_blob(eng), cache_blob); |
91 | ASSERT_EQ(get_engine_cache_blob_id(get_device(eng)), cache_blob_id); |
92 | } |
93 | #endif |
94 | |
95 | } // namespace dnnl |
96 | |