1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
#include "tensorflow/c/tf_buffer.h"

#include <cstring>
#include <limits>

#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/mem.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/status.h"
22 | |
23 | extern "C" { |
24 | |
25 | TF_Buffer* TF_NewBuffer() { return new TF_Buffer{nullptr, 0, nullptr}; } |
26 | |
27 | TF_Buffer* TF_NewBufferFromString(const void* proto, size_t proto_len) { |
28 | void* copy = tensorflow::port::Malloc(proto_len); |
29 | memcpy(copy, proto, proto_len); |
30 | |
31 | TF_Buffer* buf = new TF_Buffer; |
32 | buf->data = copy; |
33 | buf->length = proto_len; |
34 | buf->data_deallocator = [](void* data, size_t length) { |
35 | tensorflow::port::Free(data); |
36 | }; |
37 | return buf; |
38 | } |
39 | |
40 | void TF_DeleteBuffer(TF_Buffer* buffer) { |
41 | if (buffer == nullptr) return; |
42 | if (buffer->data_deallocator != nullptr) { |
43 | (*buffer->data_deallocator)(const_cast<void*>(buffer->data), |
44 | buffer->length); |
45 | } |
46 | delete buffer; |
47 | } |
48 | |
49 | TF_Buffer TF_GetBuffer(TF_Buffer* buffer) { return *buffer; } |
50 | |
51 | } // end extern "C" |
52 | |
53 | namespace tensorflow { |
54 | |
55 | Status MessageToBuffer(const tensorflow::protobuf::MessageLite& in, |
56 | TF_Buffer* out) { |
57 | if (out->data != nullptr) { |
58 | return errors::InvalidArgument("Passing non-empty TF_Buffer is invalid." ); |
59 | } |
60 | const size_t proto_size = in.ByteSizeLong(); |
61 | void* buf = port::Malloc(proto_size); |
62 | if (buf == nullptr) { |
63 | return tensorflow::errors::ResourceExhausted( |
64 | "Failed to allocate memory to serialize message of type '" , |
65 | in.GetTypeName(), "' and size " , proto_size); |
66 | } |
67 | if (!in.SerializeWithCachedSizesToArray(static_cast<uint8*>(buf))) { |
68 | port::Free(buf); |
69 | return errors::InvalidArgument( |
70 | "Unable to serialize " , in.GetTypeName(), |
71 | " protocol buffer, perhaps the serialized size (" , proto_size, |
72 | " bytes) is too large?" ); |
73 | } |
74 | out->data = buf; |
75 | out->length = proto_size; |
76 | out->data_deallocator = [](void* data, size_t length) { port::Free(data); }; |
77 | return OkStatus(); |
78 | } |
79 | |
80 | Status BufferToMessage(const TF_Buffer* in, |
81 | tensorflow::protobuf::MessageLite* out) { |
82 | if (in == nullptr || !out->ParseFromArray(in->data, in->length)) { |
83 | return errors::InvalidArgument("Unparseable " , out->GetTypeName(), |
84 | " proto" ); |
85 | } |
86 | return OkStatus(); |
87 | } |
88 | |
89 | } // namespace tensorflow |
90 | |