1/*******************************************************************************
2* Copyright 2021 Intel Corporation
3*
4* Licensed under the Apache License, Version 2.0 (the "License");
5* you may not use this file except in compliance with the License.
6* You may obtain a copy of the License at
7*
8* http://www.apache.org/licenses/LICENSE-2.0
9*
10* Unless required by applicable law or agreed to in writing, software
11* distributed under the License is distributed on an "AS IS" BASIS,
12* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13* See the License for the specific language governing permissions and
14* limitations under the License.
15*******************************************************************************/
16
17#ifndef COMMON_CACHE_BLOB_HPP
18#define COMMON_CACHE_BLOB_HPP
19
20#include <cstdint>
21#include <cstring>
22#include <memory>
23
24#include "c_types_map.hpp"
25
26namespace dnnl {
27namespace impl {
28
29struct cache_blob_impl_t {
30 cache_blob_impl_t() = delete;
31 cache_blob_impl_t(uint8_t *data, size_t size)
32 : pos_(0), data_(data), size_(size) {}
33
34 status_t add_binary(const uint8_t *binary, size_t binary_size) {
35 if (!binary || binary_size == 0) { return status::invalid_arguments; }
36 if (pos_ + sizeof(binary_size) + binary_size > size_) {
37 return status::invalid_arguments;
38 }
39
40 std::memcpy(data_ + pos_, &binary_size, sizeof(binary_size));
41 pos_ += sizeof(binary_size);
42 std::memcpy(data_ + pos_, binary, binary_size);
43 pos_ += binary_size;
44 return status::success;
45 }
46
47 status_t get_binary(const uint8_t **binary, size_t *binary_size) {
48 if (!binary || !binary_size) { return status::invalid_arguments; }
49 if (pos_ >= size_) { return status::invalid_arguments; }
50 (*binary_size) = *reinterpret_cast<size_t *>(data_ + pos_);
51 pos_ += sizeof(*binary_size);
52 (*binary) = data_ + pos_;
53 pos_ += *binary_size;
54 return status::success;
55 }
56
57private:
58 size_t pos_;
59 uint8_t *data_;
60 size_t size_;
61};
62
63struct cache_blob_t {
64 cache_blob_t() = default;
65 cache_blob_t(uint8_t *data, size_t size)
66 : impl_(std::make_shared<cache_blob_impl_t>(data, size)) {}
67
68 status_t add_binary(const uint8_t *binary, size_t binary_size) {
69 if (!impl_) return status::runtime_error;
70 return impl_->add_binary(binary, binary_size);
71 }
72
73 status_t get_binary(const uint8_t **binary, size_t *binary_size) {
74 if (!impl_) return status::runtime_error;
75 return impl_->get_binary(binary, binary_size);
76 }
77
78 explicit operator bool() const { return bool(impl_); }
79
80private:
81 std::shared_ptr<cache_blob_impl_t> impl_;
82};
83
84} // namespace impl
85} // namespace dnnl
86
87#endif
88