1 | /******************************************************************************* |
2 | * Copyright 2021 Intel Corporation |
3 | * |
4 | * Licensed under the Apache License, Version 2.0 (the "License"); |
5 | * you may not use this file except in compliance with the License. |
6 | * You may obtain a copy of the License at |
7 | * |
8 | * http://www.apache.org/licenses/LICENSE-2.0 |
9 | * |
10 | * Unless required by applicable law or agreed to in writing, software |
11 | * distributed under the License is distributed on an "AS IS" BASIS, |
12 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
13 | * See the License for the specific language governing permissions and |
14 | * limitations under the License. |
15 | *******************************************************************************/ |
16 | |
17 | #ifndef COMMON_CACHE_BLOB_HPP |
18 | #define COMMON_CACHE_BLOB_HPP |
19 | |
20 | #include <cstdint> |
21 | #include <cstring> |
22 | #include <memory> |
23 | |
24 | #include "c_types_map.hpp" |
25 | |
26 | namespace dnnl { |
27 | namespace impl { |
28 | |
29 | struct cache_blob_impl_t { |
30 | cache_blob_impl_t() = delete; |
31 | cache_blob_impl_t(uint8_t *data, size_t size) |
32 | : pos_(0), data_(data), size_(size) {} |
33 | |
34 | status_t add_binary(const uint8_t *binary, size_t binary_size) { |
35 | if (!binary || binary_size == 0) { return status::invalid_arguments; } |
36 | if (pos_ + sizeof(binary_size) + binary_size > size_) { |
37 | return status::invalid_arguments; |
38 | } |
39 | |
40 | std::memcpy(data_ + pos_, &binary_size, sizeof(binary_size)); |
41 | pos_ += sizeof(binary_size); |
42 | std::memcpy(data_ + pos_, binary, binary_size); |
43 | pos_ += binary_size; |
44 | return status::success; |
45 | } |
46 | |
47 | status_t get_binary(const uint8_t **binary, size_t *binary_size) { |
48 | if (!binary || !binary_size) { return status::invalid_arguments; } |
49 | if (pos_ >= size_) { return status::invalid_arguments; } |
50 | (*binary_size) = *reinterpret_cast<size_t *>(data_ + pos_); |
51 | pos_ += sizeof(*binary_size); |
52 | (*binary) = data_ + pos_; |
53 | pos_ += *binary_size; |
54 | return status::success; |
55 | } |
56 | |
57 | private: |
58 | size_t pos_; |
59 | uint8_t *data_; |
60 | size_t size_; |
61 | }; |
62 | |
63 | struct cache_blob_t { |
64 | cache_blob_t() = default; |
65 | cache_blob_t(uint8_t *data, size_t size) |
66 | : impl_(std::make_shared<cache_blob_impl_t>(data, size)) {} |
67 | |
68 | status_t add_binary(const uint8_t *binary, size_t binary_size) { |
69 | if (!impl_) return status::runtime_error; |
70 | return impl_->add_binary(binary, binary_size); |
71 | } |
72 | |
73 | status_t get_binary(const uint8_t **binary, size_t *binary_size) { |
74 | if (!impl_) return status::runtime_error; |
75 | return impl_->get_binary(binary, binary_size); |
76 | } |
77 | |
78 | explicit operator bool() const { return bool(impl_); } |
79 | |
80 | private: |
81 | std::shared_ptr<cache_blob_impl_t> impl_; |
82 | }; |
83 | |
84 | } // namespace impl |
85 | } // namespace dnnl |
86 | |
87 | #endif |
88 | |