/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/distributed_runtime/tensor_coding.h"

#include "google/protobuf/any.pb.h"

#include "tensorflow/core/common_runtime/device.h"
#include "tensorflow/core/framework/tensor.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"

namespace tensorflow {

TensorResponse::Source::~Source() {}

void TensorResponse::Clear() {
  on_host_ = false;
  device_ = nullptr;
  alloc_attrs_ = AllocatorAttributes();
  allocator_ = nullptr;
  already_used_ = false;
  ClearTensor();
}

void TensorResponse::ClearTensor() {
  meta_.Clear();
  tensor_ = Tensor();
}

void TensorResponse::InitAlloc(DeviceBase* d, const AllocatorAttributes& aa) {
  Clear();
  device_ = d;
  alloc_attrs_ = aa;
  const DeviceAttributes& da = d->attributes();
  if (alloc_attrs_.on_host() || da.device_type() == "CPU") {
    on_host_ = true;
  }
  allocator_ = device_->GetAllocator(alloc_attrs_);
}
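
// Illustrative call sequence (a sketch only; `device` and `source` stand in
// for a real DeviceBase* and a caller-supplied Source implementation):
//   TensorResponse response;
//   response.InitAlloc(device, AllocatorAttributes());
//   Status s = response.ParseFrom(&source);
//   if (s.ok()) { const Tensor& t = response.tensor(); ... }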

Status TensorResponse::InitFrom(RecvTensorResponse* response) {
  Status s;
  meta_.Swap(response);
  if (on_host_) {
    if (!tensor_.FromProto(allocator_, meta_.tensor())) {
      s = errors::InvalidArgument("Cannot parse tensor from response");
    }
  } else {
    s = device_->MakeTensorFromProto(meta_.tensor(), alloc_attrs_, &tensor_);
  }
  {
    // Reduce memory usage for big tensors: swapping with a default-constructed
    // proto releases the tensor_content storage that Clear() alone would keep
    // around for reuse.
    TensorProto empty;
    meta_.mutable_tensor()->Swap(&empty);
  }
  meta_.clear_tensor();
  return s;
}

void TensorResponse::InitPartial(const RecvTensorResponse& response,
                                 const AllocationAttributes& allocation_attr) {
  // Everything except content is present in *response. Content will
  // arrive later; allocate a Tensor with appropriate storage for that
  // content.
  meta_ = response;
  TensorShape shape(meta_.tensor().tensor_shape());
  Tensor t(allocator_, meta_.tensor().dtype(), shape, allocation_attr);
  tensor_ = std::move(t);
}
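
// The transport is then expected to deposit the incoming content bytes
// directly into tensor_'s backing buffer (reachable via tensor_data() on the
// allocated Tensor); only the metadata lives in meta_ at this point.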

Status TensorResponse::ParseFrom(Source* source) {
  if (!on_host_) {
    protobuf::io::CodedInputStream input(source->contents());

    // Pre-parse into local storage, then delegate to device.
    if (!meta_.ParseFromCodedStream(&input) || !input.ConsumedEntireMessage()) {
      return errors::InvalidArgument("Cannot parse tensor from response");
    }
    Status s =
        device_->MakeTensorFromProto(meta_.tensor(), alloc_attrs_, &tensor_);
    // Reduce memory usage for big tensors.
    {
      TensorProto empty;
      meta_.mutable_tensor()->Swap(&empty);
    }
    meta_.clear_tensor();
    return s;
  }
  if (already_used_) {
    ClearTensor();
  }
  already_used_ = true;
  if (ParseFast(source)) return OkStatus();
  meta_.Clear();
  if (ParseSlow(source)) return OkStatus();
  return errors::InvalidArgument("Cannot parse tensor from response");
}

// Define some helper routines for decoding protocol buffer wire format data
namespace {
// We only need some of the wiretype values for this code
enum WireType {
  WIRETYPE_VARINT = 0,
  WIRETYPE_LENGTH_DELIMITED = 2,
};
inline int GetTagFieldNumber(uint32 tag) { return tag >> 3; }
inline WireType GetTagWireType(uint32 tag) {
  return static_cast<WireType>(tag & 0x7);
}
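
// Example: a wire-format tag packs (field_number << 3) | wire_type into a
// varint, so field number 1 with WIRETYPE_LENGTH_DELIMITED encodes as
// (1 << 3) | 2 == 0x0A; GetTagFieldNumber(0x0A) == 1 and
// GetTagWireType(0x0A) == WIRETYPE_LENGTH_DELIMITED.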

// Reads a length varint, rejecting values that do not fit in an int, since
// CodedInputStream limits are expressed as ints.
bool ReadVarintSizeAsInt(protobuf::io::CodedInputStream* input, int* result) {
  protobuf_uint64 v;
  if (input->ReadVarint64(&v) && v <= static_cast<uint64>(INT_MAX)) {
    *result = static_cast<int>(v);
    return true;
  } else {
    return false;
  }
}

bool ReadNestedMessage(protobuf::io::CodedInputStream* input,
                       protobuf::Message* value) {
  int length;
  if (!ReadVarintSizeAsInt(input, &length)) return false;
  std::pair<protobuf::io::CodedInputStream::Limit, int> p =
      input->IncrementRecursionDepthAndPushLimit(length);
  if (p.second < 0 || !value->MergePartialFromCodedStream(input)) return false;
  // Make sure that parsing stopped when the limit was hit, not at an endgroup
  // tag.
  return input->DecrementRecursionDepthAndPopLimit(p.first);
}

}  // namespace

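// Hand-rolled fast-path parse of the TensorProto submessage. Walking the
// fields ourselves lets the tensor_content bytes be read straight into the
// Tensor allocated from allocator_, instead of first being materialized in a
// TensorProto string field and copied from there.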
bool TensorResponse::ParseTensorSubmessage(
    protobuf::io::CodedInputStream* input, TensorProto* tensor_meta) {
  bool seen_tensor_content = false;
  while (true) {
    auto p = input->ReadTagWithCutoff(127);
    int tag = GetTagFieldNumber(p.first);
    WireType wt = GetTagWireType(p.first);
    if (!p.second) {
      bool ok = (tag == 0);
      if (ok && !seen_tensor_content) {
        // No tensor content: could be because it's a zero-length tensor
        TensorShape shape(tensor_meta->tensor_shape());
        Tensor t(allocator_, tensor_meta->dtype(), shape);
        tensor_ = std::move(t);
      }
      return ok;
    }
    switch (tag) {
      case TensorProto::kDtypeFieldNumber: {
        uint32 v;
        if ((wt != WIRETYPE_VARINT) || !input->ReadVarint32(&v)) return false;
        if (seen_tensor_content) return false;
        tensor_meta->set_dtype(static_cast<DataType>(static_cast<int>(v)));
        if (!DataTypeCanUseMemcpy(tensor_meta->dtype())) return false;
        break;
      }
      case TensorProto::kTensorShapeFieldNumber: {
        if ((wt != WIRETYPE_LENGTH_DELIMITED) ||
            !ReadNestedMessage(input, tensor_meta->mutable_tensor_shape()))
          return false;
        if (seen_tensor_content) return false;
        break;
      }
      case TensorProto::kVersionNumberFieldNumber: {
        uint32 v;
        if ((wt != WIRETYPE_VARINT) || !input->ReadVarint32(&v)) return false;
        if (seen_tensor_content) return false;
        tensor_meta->set_version_number(static_cast<int32>(v));
        break;
      }
      case TensorProto::kTensorContentFieldNumber: {
        // If we haven't seen the dtype and tensor_shape data first, we can't
        // deal with this in the fast path.
        if (seen_tensor_content) return false;
        if (wt != WIRETYPE_LENGTH_DELIMITED ||
            !tensor_meta->has_tensor_shape()) {
          return false;
        }
        int num_bytes;
        if (!ReadVarintSizeAsInt(input, &num_bytes)) return false;
        seen_tensor_content = true;
        TensorShape shape(tensor_meta->tensor_shape());
        Tensor t(allocator_, tensor_meta->dtype(), shape);
        StringPiece buf = t.tensor_data();
        if (static_cast<size_t>(num_bytes) != buf.size()) return false;
        // TODO(jeff,sanjay): Figure out a way to avoid this copy if
        // the underlying ZeroCopyInputStream data is properly aligned
        // and compatible with what allocator_ wants.
        if (!input->ReadRaw(const_cast<char*>(buf.data()), num_bytes))
          return false;
        tensor_ = std::move(t);
        break;
      }
      default: {
        // Some other tag that our fast-path code is not prepared to handle.
        return false;
      }
    }
  }
}

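// Fast path for the whole RecvTensorResponse: scalar fields are decoded
// inline, and the tensor submessage is handed to ParseTensorSubmessage so
// its content lands directly in the pre-allocated Tensor buffer.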
bool TensorResponse::ParseFast(Source* source) {
  protobuf::io::CodedInputStream input(source->contents());
  while (true) {
    auto p = input.ReadTagWithCutoff(127);
    int tag = GetTagFieldNumber(p.first);
    WireType wt = GetTagWireType(p.first);
    if (!p.second) {
      return (tag == 0);
    }
    switch (tag) {
      case RecvTensorResponse::kTensorFieldNumber: {
        if (wt != WIRETYPE_LENGTH_DELIMITED) return false;

        int length;
        if (!ReadVarintSizeAsInt(&input, &length)) return false;
        std::pair<protobuf::io::CodedInputStream::Limit, int> p =
            input.IncrementRecursionDepthAndPushLimit(length);
        if (p.second < 0 ||
            !ParseTensorSubmessage(&input, meta_.mutable_tensor())) {
          return false;
        }
        if (!input.DecrementRecursionDepthAndPopLimit(p.first)) {
          return false;
        }
        break;
      }
      case RecvTensorResponse::kIsDeadFieldNumber: {
        uint32 v;
        if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) return false;
        meta_.set_is_dead(v != 0);
        break;
      }
      case RecvTensorResponse::kSendStartMicrosFieldNumber: {
        protobuf_uint64 v;
        if ((wt != WIRETYPE_VARINT) || !input.ReadVarint64(&v)) return false;
        meta_.set_send_start_micros(static_cast<int64_t>(v));
        break;
      }
      case RecvTensorResponse::kTransportOptionsFieldNumber: {
        if ((wt != WIRETYPE_LENGTH_DELIMITED) ||
            !ReadNestedMessage(&input, meta_.mutable_transport_options()))
          return false;
        break;
      }
      case RecvTensorResponse::kRequireAckFieldNumber: {
        uint32 v;
        if ((wt != WIRETYPE_VARINT) || !input.ReadVarint32(&v)) return false;
        meta_.set_require_ack(v != 0);
        break;
      }
      default: {
        // Unknown tag that we can't handle on the fast path.
        return false;
      }
    }
  }

  return false;  // Unreachable: the loop above always returns.
}
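
// Slow path: parse the entire RecvTensorResponse into meta_ (including the
// tensor bytes) and then copy into a Tensor via Tensor::FromProto. This
// handles anything the fast path rejected, at the cost of an extra copy of
// the tensor content.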

bool TensorResponse::ParseSlow(Source* source) {
  if (!meta_.ParseFromZeroCopyStream(source->contents())) {
    return false;
  }

  Tensor parsed(meta_.tensor().dtype());
  if (!parsed.FromProto(allocator_, meta_.tensor())) {
    return false;
  }
  tensor_ = std::move(parsed);

  // Reduce memory usage for big tensors.
  {
    TensorProto empty;
    meta_.mutable_tensor()->Swap(&empty);
  }
  meta_.clear_tensor();

  return true;
}

}  // namespace tensorflow