1 | /* Copyright 2017 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | /// \file |
16 | /// Memory management for TF Lite. |
17 | #ifndef TENSORFLOW_LITE_ALLOCATION_H_ |
18 | #define TENSORFLOW_LITE_ALLOCATION_H_ |
19 | |
20 | #include <stddef.h> |
21 | |
22 | #include <cstdio> |
23 | #include <cstdlib> |
24 | #include <memory> |
25 | |
26 | #include "tensorflow/lite/core/api/error_reporter.h" |
27 | |
28 | namespace tflite { |
29 | |
30 | /// A memory allocation handle. This could be a mmap or shared memory. |
31 | class Allocation { |
32 | public: |
33 | virtual ~Allocation() {} |
34 | |
35 | enum class Type { |
36 | kMMap, |
37 | kFileCopy, |
38 | kMemory, |
39 | }; |
40 | |
41 | /// Base pointer of this allocation |
42 | virtual const void* base() const = 0; |
43 | /// Size in bytes of the allocation |
44 | virtual size_t bytes() const = 0; |
45 | /// Whether the allocation is valid |
46 | virtual bool valid() const = 0; |
47 | /// Return the type of the Allocation. |
48 | Type type() const { return type_; } |
49 | |
50 | protected: |
51 | Allocation(ErrorReporter* error_reporter, Type type) |
52 | : error_reporter_(error_reporter), type_(type) {} |
53 | ErrorReporter* error_reporter_; |
54 | |
55 | private: |
56 | const Type type_; |
57 | }; |
58 | |
59 | /// Note that not all platforms support MMAP-based allocation. |
60 | /// Use `IsSupported()` to check. |
61 | class MMAPAllocation : public Allocation { |
62 | public: |
63 | /// Loads and maps the provided file to a memory region. |
64 | MMAPAllocation(const char* filename, ErrorReporter* error_reporter); |
65 | |
66 | /// Maps the provided file descriptor to a memory region. |
67 | /// Note: The provided file descriptor will be dup'ed for usage; the caller |
68 | /// retains ownership of the provided descriptor and should close accordingly. |
69 | MMAPAllocation(int fd, ErrorReporter* error_reporter); |
70 | |
71 | /// Maps the provided file descriptor, with the given offset and length (both |
72 | /// in bytes), to a memory region. |
73 | /// Note: The provided file descriptor will be dup'ed for usage; the caller |
74 | /// retains ownership of the provided descriptor and should close accordingly. |
75 | MMAPAllocation(int fd, size_t offset, size_t length, |
76 | ErrorReporter* error_reporter); |
77 | |
78 | ~MMAPAllocation() override; |
79 | const void* base() const override; |
80 | size_t bytes() const override; |
81 | bool valid() const override; |
82 | |
83 | int fd() const { return mmap_fd_; } |
84 | |
85 | static bool IsSupported(); |
86 | |
87 | protected: |
88 | // Data required for mmap. |
89 | int mmap_fd_ = -1; // mmap file descriptor |
90 | const void* mmapped_buffer_; |
91 | size_t buffer_size_bytes_ = 0; |
92 | // Used when the address to mmap is not page-aligned. |
93 | size_t offset_in_buffer_ = 0; |
94 | |
95 | private: |
96 | // Assumes ownership of the provided `owned_fd` instance. |
97 | MMAPAllocation(ErrorReporter* error_reporter, int owned_fd); |
98 | |
99 | // Assumes ownership of the provided `owned_fd` instance, and uses the given |
100 | // offset and length (both in bytes) for memory mapping. |
101 | MMAPAllocation(ErrorReporter* error_reporter, int owned_fd, size_t offset, |
102 | size_t length); |
103 | }; |
104 | |
105 | class FileCopyAllocation : public Allocation { |
106 | public: |
107 | /// Loads the provided file into a heap memory region. |
108 | FileCopyAllocation(const char* filename, ErrorReporter* error_reporter); |
109 | ~FileCopyAllocation() override; |
110 | const void* base() const override; |
111 | size_t bytes() const override; |
112 | bool valid() const override; |
113 | |
114 | private: |
115 | std::unique_ptr<const char[]> copied_buffer_; |
116 | size_t buffer_size_bytes_ = 0; |
117 | }; |
118 | |
119 | class MemoryAllocation : public Allocation { |
120 | public: |
121 | /// Provides a (read-only) view of the provided buffer region as an |
122 | /// allocation. |
123 | /// Note: The caller retains ownership of `ptr`, and must ensure it remains |
124 | /// valid for the lifetime of the class instance. |
125 | MemoryAllocation(const void* ptr, size_t num_bytes, |
126 | ErrorReporter* error_reporter); |
127 | ~MemoryAllocation() override; |
128 | const void* base() const override; |
129 | size_t bytes() const override; |
130 | bool valid() const override; |
131 | |
132 | private: |
133 | const void* buffer_; |
134 | size_t buffer_size_bytes_ = 0; |
135 | }; |
136 | |
137 | } // namespace tflite |
138 | |
139 | #endif // TENSORFLOW_LITE_ALLOCATION_H_ |
140 | |