/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/tensor_coding.h"

#include "google/protobuf/any.pb.h"

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
namespace tensorflow {

TensorResponse::Source::~Source() {}

void TensorResponse::Clear() {
  on_host_ = false;
  device_ = nullptr;
  alloc_attrs_ = AllocatorAttributes();
  allocator_ = nullptr;
  already_used_ = false;
  ClearTensor();
}

void TensorResponse::ClearTensor() {
  meta_.Clear();
  tensor_ = Tensor();
}

void TensorResponse::InitAlloc(DeviceBase* d, const AllocatorAttributes& aa) {
  Clear();
  device_ = d;
  alloc_attrs_ = aa;
  const DeviceAttributes& da = d->attributes();
  if (alloc_attrs_.on_host() || da.device_type() == "CPU") {
    on_host_ = true;
  }
  allocator_ = device_->GetAllocator(alloc_attrs_);
}
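
// A minimal usage sketch of the allocate-then-parse flow (illustrative only;
// "resp" and "src" are hypothetical names, and src stands for any Source
// implementation wrapping the incoming RPC payload):
//
//   TensorResponse resp;
//   resp.InitAlloc(device, alloc_attrs);  // choose allocator and host-ness
//   Status s = resp.ParseFrom(&src);      // decode payload into resp.tensor()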

Status TensorResponse::InitFrom(RecvTensorResponse* response) {
  Status s;
  meta_.Swap(response);
  if (on_host_) {
    if (!tensor_.FromProto(allocator_, meta_.tensor())) {
      s = errors::InvalidArgument("Cannot parse tensor from response");
    }
  } else {
    s = device_->MakeTensorFromProto(meta_.tensor(), alloc_attrs_, &tensor_);
  }
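  // Swapping with a default-constructed TensorProto releases the tensor
  // proto's heap storage immediately, whereas clear_tensor() alone marks the
  // field as unset but can leave its buffers cached for reuse.  The idiom in
  // isolation (names illustrative):
  //
  //   TensorProto empty;
  //   msg.mutable_tensor()->Swap(&empty);
  //   // "empty" now owns the large buffers and frees them at end of scope.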
  {
    TensorProto empty;
    meta_.mutable_tensor()->Swap(&empty);
  }
  meta_.clear_tensor();
  return s;
}

void TensorResponse::InitPartial(const RecvTensorResponse& response,
                                 const AllocationAttributes& allocation_attr) {
  // Everything except content is present in *response.  Content will
  // arrive later; allocate a Tensor with appropriate storage for that
  // content.
  meta_ = response;
  TensorShape shape(meta_.tensor().tensor_shape());
  Tensor t(allocator_, meta_.tensor().dtype(), shape, allocation_attr);
  tensor_ = std::move(t);
}
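
// After InitPartial() the tensor's storage exists but holds no payload yet;
// the transport is expected to copy the bytes in once they arrive.  A
// hypothetical sketch of that final step ("payload" and "payload_size" are
// assumptions; real transports do this in their own receive paths):
//
//   resp.InitPartial(meta_only_response, AllocationAttributes());
//   StringPiece dst = resp.tensor().tensor_data();
//   if (payload_size == dst.size()) {
//     memcpy(const_cast<char*>(dst.data()), payload, payload_size);
//   }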

Status TensorResponse::ParseFrom(Source* source) {
  if (!on_host_) {
    protobuf::io::CodedInputStream input(source->contents());

    // Pre-parse into local storage, then delegate to device.
    if (!meta_.ParseFromCodedStream(&input) || !input.ConsumedEntireMessage()) {
      return errors::InvalidArgument("Cannot parse tensor from response");
    }
    Status s =
        device_->MakeTensorFromProto(meta_.tensor(), alloc_attrs_, &tensor_);
    // Reduce memory usage for big tensors.
    {
      TensorProto empty;
      meta_.mutable_tensor()->Swap(&empty);
    }
    meta_.clear_tensor();
    return s;
  }
  if (already_used_) {
    ClearTensor();
  }
  already_used_ = true;
  if (ParseFast(source)) return OkStatus();
  meta_.Clear();
  if (ParseSlow(source)) return OkStatus();
  return errors::InvalidArgument("Cannot parse tensor from response");
}
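
// Note on the fast/slow fallback above: ParseSlow() calls source->contents()
// again after ParseFast() has already consumed bytes, so a Source
// implementation must be able to hand out a stream that re-reads the payload
// from the beginning.  A hypothetical in-memory Source, for illustration:
//
//   class StringSource : public TensorResponse::Source {
//    public:
//     explicit StringSource(const string* s) : s_(s) {}
//     protobuf::io::ZeroCopyInputStream* contents() override {
//       stream_.reset(new protobuf::io::ArrayInputStream(
//           s_->data(), static_cast<int>(s_->size())));
//       return stream_.get();
//     }
//
//    private:
//     const string* s_;  // not owned
//     std::unique_ptr<protobuf::io::ArrayInputStream> stream_;
//   };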

// Define some helper routines for decoding protocol buffer wire format data.
namespace {
// We only need some of the wiretype values for this code.
enum WireType {
  WIRETYPE_VARINT = 0,
  WIRETYPE_LENGTH_DELIMITED = 2,
};
inline int GetTagFieldNumber(uint32 tag) { return tag >> 3; }
inline WireType GetTagWireType(uint32 tag) {
  return static_cast<WireType>(tag & 0x7);
}
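
// Worked example: a wire-format tag packs (field_number << 3) | wire_type
// into a single varint.  TensorProto's tensor_content is field 4 and
// length-delimited, so its tag byte is (4 << 3) | 2 == 0x22, and
// GetTagFieldNumber(0x22) == 4 while GetTagWireType(0x22) ==
// WIRETYPE_LENGTH_DELIMITED.  Likewise dtype (field 1, a varint) is encoded
// with tag byte (1 << 3) | 0 == 0x08.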

bool ReadVarintSizeAsInt(protobuf::io::CodedInputStream* input, int* result) {
  protobuf_uint64 v;
  if (input->ReadVarint64(&v) && v <= static_cast<uint64>(INT_MAX)) {
    *result = static_cast<int>(v);
    return true;
  } else {
    return false;
  }
}

bool ReadNestedMessage(protobuf::io::CodedInputStream* input,
                       protobuf::Message* value) {
  int length;
  if (!ReadVarintSizeAsInt(input, &length)) return false;
  std::pair<protobuf::io::CodedInputStream::Limit, int> p =
      input->IncrementRecursionDepthAndPushLimit(length);
  if (p.second < 0 || !value->MergePartialFromCodedStream(input)) return false;
  // Make sure that parsing stopped when the limit was hit, not at an endgroup
  // tag.
  return input->DecrementRecursionDepthAndPopLimit(p.first);
}
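
// How the limit bookkeeping above plays out, for illustration: with the
// stream positioned at the start of an embedded message of length 12,
// IncrementRecursionDepthAndPushLimit(12) caps reads at the next 12 bytes,
// MergePartialFromCodedStream() parses up to that cap, and
// DecrementRecursionDepthAndPopLimit() is roughly equivalent to checking
// ConsumedEntireMessage() before popping the limit, so it returns true only
// if the submessage ended exactly at the cap rather than at an endgroup tag
// or early EOF.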

}  // namespace

bool TensorResponse::ParseTensorSubmessage(
    protobuf::io::CodedInputStream* input, TensorProto* tensor_meta) {
  bool seen_tensor_content = false;
  while (true) {
    auto p = input->ReadTagWithCutoff(127);
    int tag = GetTagFieldNumber(p.first);
    WireType wt = GetTagWireType(p.first);
    if (!p.second) {
      bool ok = (tag == 0);
      if (ok && !seen_tensor_content) {
        // No tensor content: could be because it's a zero-length tensor.
        TensorShape shape(tensor_meta->tensor_shape());
        Tensor t(allocator_, tensor_meta->dtype(), shape);
        tensor_ = std::move(t);
      }
      return ok;
    }
    switch (tag) {
      case TensorProto::kDtypeFieldNumber: {
        uint32 v;
        if ((wt != WIRETYPE_VARINT) || !input->ReadVarint32(&v)) return false;
        if (seen_tensor_content) return false;
        tensor_meta->set_dtype(static_cast<DataType>(static_cast<int>(v)));
        if (!DataTypeCanUseMemcpy(tensor_meta->dtype())) return false;
        break;
      }
      case TensorProto::kTensorShapeFieldNumber: {
        if ((wt != WIRETYPE_LENGTH_DELIMITED) ||
            !ReadNestedMessage(input, tensor_meta->mutable_tensor_shape()))
          return false;
        if (seen_tensor_content) return false;
        break;
      }
      case TensorProto::kVersionNumberFieldNumber: {
        uint32 v;
        if ((wt != WIRETYPE_VARINT) || !input->ReadVarint32(&v)) return false;
        if (seen_tensor_content) return false;
        tensor_meta->set_version_number(static_cast<int32>(v));
        break;
      }
      case TensorProto::kTensorContentFieldNumber: {
        // If we haven't seen the dtype and tensor_shape data first, we can't
        // deal with this in the fast path.
        if (seen_tensor_content) return false;
        if (wt != WIRETYPE_LENGTH_DELIMITED ||
            !tensor_meta->has_tensor_shape()) {
          return false;
        }
        int num_bytes;
        if (!ReadVarintSizeAsInt(input, &num_bytes)) return false;
        seen_tensor_content = true;
        TensorShape shape(tensor_meta->tensor_shape());
        Tensor t(allocator_, tensor_meta->dtype(), shape);
        StringPiece buf = t.tensor_data();
        if (static_cast<size_t>(num_bytes) != buf.size()) return false;
        // TODO(jeff,sanjay): Figure out a way to avoid this copy if
        // the underlying ZeroCopyInputStream data is properly aligned
        // and compatible with what allocator_ wants.
        if (!input->ReadRaw(const_cast<char*>(buf.data()), num_bytes))
          return false;
        tensor_ = std::move(t);
        break;
      }
      default: {
        // Some other tag that our fast path code is not prepared to handle:
        // bail out so the caller can fall back to the slow path.
        return false;
      }
    }
  }
}
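
// The fast path above therefore accepts only encodings in which dtype and
// tensor_shape precede tensor_content.  For example, a scalar float tensor
// may arrive as the byte sequence (illustrative):
//
//   08 01        dtype: tag 0x08, value DT_FLOAT (= 1)
//   12 00        tensor_shape: tag 0x12, zero-length submessage (scalar)
//   22 04 ...    tensor_content: tag 0x22, 4 payload bytes
//
// Any other field ordering, or a dtype that can't be memcpy'd, makes the
// caller fall back to ParseSlow().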

bool TensorResponse::ParseFast(Source* source) {
  protobuf::io::CodedInputStream input(source->contents());
  while (true) {
    auto p = input.ReadTagWithCutoff(127);
    int tag = GetTagFieldNumber(p.first);
    WireType wt = GetTagWireType(p.first);
    if (!p.second) {
      return (tag == 0);
    }
    switch (tag) {
      case RecvTensorResponse::kTensorFieldNumber: {
        if (wt != WIRETYPE_LENGTH_DELIMITED) return false;

        int length;
        if (!ReadVarintSizeAsInt(&input, &length)) return false;
        std::pair<protobuf::io::CodedInputStream::Limit, int> p =
            input.IncrementRecursionDepthAndPushLimit(length);
        if (p.second < 0 ||
            !ParseTensorSubmessage(&input, meta_.mutable_tensor())) {
          return false;
        }
        if (!input.DecrementRecursionDepthAndPopLimit(p.first)) {
          return false;
        }
        break;
      }
      case RecvTensorResponse::kIsDeadFieldNumber: {
        uint32 v;
        if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) return false;
        meta_.set_is_dead(v != 0);
        break;
      }
      case RecvTensorResponse::kSendStartMicrosFieldNumber: {
        protobuf_uint64 v;
        if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) return false;
        meta_.set_send_start_micros(static_cast<int64_t>(v));
        break;
      }
      case RecvTensorResponse::kTransportOptionsFieldNumber: {
        if ((wt != WIRETYPE_LENGTH_DELIMITED) ||
            !ReadNestedMessage(&input, meta_.mutable_transport_options()))
          return false;
        break;
      }
      case RecvTensorResponse::kRequireAckFieldNumber: {
        uint32 v;
        if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) return false;
        meta_.set_require_ack(v != 0);
        break;
      }
      default: {
        // Unknown tag that we can't handle on the fast path.
        return false;
      }
    }
  }

  return false;  // Unreachable; the loop above only exits via return.
}

bool TensorResponse::ParseSlow(Source* source) {
  if (!meta_.ParseFromZeroCopyStream(source->contents())) {
    return false;
  }

  Tensor parsed(meta_.tensor().dtype());
  if (!parsed.FromProto(allocator_, meta_.tensor())) {
    return false;
  }
  tensor_ = std::move(parsed);

  // Reduce memory usage for big tensors.
  {
    TensorProto empty;
    meta_.mutable_tensor()->Swap(&empty);
  }
  meta_.clear_tensor();

  return true;
}

}  // namespace tensorflow