1/*******************************************************************************
2* Copyright 2020-2022 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#include "dnnl_test_common.hpp"
18#include "gtest/gtest.h"
19
20#include "oneapi/dnnl/dnnl.h"
21#include "oneapi/dnnl/dnnl.hpp"
22
23#include <limits>
24#include <new>
25
26#ifdef DNNL_WITH_SYCL
27#include "oneapi/dnnl/dnnl_sycl.hpp"
28#endif
29
30namespace dnnl {
31
32class memory_test_c_t : public ::testing::TestWithParam<dnnl_engine_kind_t> {
33protected:
34 void SetUp() override {
35 eng_kind = GetParam();
36
37 if (dnnl_engine_get_count(eng_kind) == 0) return;
38
39 DNNL_CHECK(dnnl_engine_create(&engine, eng_kind, 0));
40 }
41
42 void TearDown() override {
43 if (engine) { DNNL_CHECK(dnnl_engine_destroy(engine)); }
44 }
45
46 dnnl_engine_kind_t eng_kind;
47 dnnl_engine_t engine = nullptr;
48};
49
50class memory_test_cpp_t : public ::testing::TestWithParam<dnnl_engine_kind_t> {
51};
52
53TEST_P(memory_test_c_t, OutOfMemory) {
54 SKIP_IF(!engine, "Engine is not found.");
55 SKIP_IF(is_sycl_engine(static_cast<engine::kind>(eng_kind)),
56 "Do not test C API with SYCL.");
57
58 dnnl_dim_t sz = std::numeric_limits<memory::dim>::max();
59 dnnl_dims_t dims = {sz};
60 dnnl_memory_desc_t md;
61 DNNL_CHECK(dnnl_memory_desc_create_with_tag(&md, 1, dims, dnnl_u8, dnnl_x));
62
63 dnnl_data_type_t data_type;
64 DNNL_CHECK(dnnl_memory_desc_query(md, dnnl_query_data_type, &data_type));
65 ASSERT_EQ(dnnl_data_type_size(data_type), sizeof(uint8_t));
66
67 dnnl_memory_t mem;
68 dnnl_status_t s
69 = dnnl_memory_create(&mem, md, engine, DNNL_MEMORY_ALLOCATE);
70 ASSERT_EQ(s, dnnl_out_of_memory);
71
72 DNNL_CHECK(dnnl_memory_desc_destroy(md));
73}
74
75TEST_P(memory_test_cpp_t, OutOfMemory) {
76 dnnl_engine_kind_t eng_kind_c = GetParam();
77 engine::kind eng_kind = static_cast<engine::kind>(eng_kind_c);
78 SKIP_IF(engine::get_count(eng_kind) == 0, "Engine is not found.");
79
80 engine eng(eng_kind, 0);
81
82 bool is_sycl = is_sycl_engine(eng_kind);
83
84 auto sz = std::numeric_limits<memory::dim>::max();
85#ifdef DNNL_WITH_SYCL
86 if (is_sycl) {
87 auto dev = sycl_interop::get_device(eng);
88 const memory::dim max_alloc_size
89 = dev.get_info<::sycl::info::device::max_mem_alloc_size>();
90 sz = (max_alloc_size < sz) ? max_alloc_size + 1 : sz;
91 }
92#endif
93
94 auto dt = memory::data_type::u8;
95 auto tag = memory::format_tag::x;
96 memory::desc md({sz}, dt, tag);
97 ASSERT_EQ(memory::data_type_size(dt), sizeof(uint8_t));
98 try {
99 auto mem = test::make_memory(md, eng);
100 ASSERT_NE(mem.get_data_handle(), nullptr);
101 } catch (const dnnl::error &e) {
102 ASSERT_EQ(e.status, dnnl_out_of_memory);
103 return;
104 } catch (const std::bad_alloc &) {
105 // Expect bad_alloc only with SYCL.
106 if (is_sycl) return;
107 throw;
108 }
109
110 // XXX: SYCL does not always throw, even when allocating
111 // > max_mem_alloc_size bytes.
112 if (!is_sycl) FAIL() << "Expected exception.";
113}
114
115namespace {
116struct print_to_string_param_name_t {
117 template <class ParamType>
118 std::string operator()(
119 const ::testing::TestParamInfo<ParamType> &info) const {
120 return to_string(info.param);
121 }
122};
123
124auto all_engine_kinds = ::testing::Values(dnnl_cpu, dnnl_gpu);
125
126} // namespace
127
128INSTANTIATE_TEST_SUITE_P(AllEngineKinds, memory_test_c_t, all_engine_kinds,
129 print_to_string_param_name_t());
130INSTANTIATE_TEST_SUITE_P(AllEngineKinds, memory_test_cpp_t, all_engine_kinds,
131 print_to_string_param_name_t());
132
133} // namespace dnnl
134