/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
15 | |
16 | #include "tensorflow/core/framework/common_shape_fns.h" |
17 | #include "tensorflow/core/framework/op.h" |
18 | #include "tensorflow/core/framework/shape_inference.h" |
19 | |
20 | namespace tensorflow { |
21 | |
22 | REGISTER_OP("BatchFunction" ) |
23 | .Input("in_tensors: Tin" ) |
24 | .Input("captured_tensors: Tcaptured" ) |
25 | .Output("out_tensors: Tout" ) |
26 | .Attr("f: func" ) |
27 | .Attr("num_batch_threads: int" ) |
28 | // 'max_batch_size' denotes the maximum batch size acceptable, i.e., inputs |
29 | // with larger batch size are simply invalidated. |
30 | // By default, 'max_batch_size' must be equal to max value of |
31 | // 'allowed_batch_sizes'. |
32 | // By setting 'enable_large_batch_splitting' (attribute below) to true, |
33 | // 'max_batch_size' can be greater than or equal to max value of |
34 | // 'allowed_batch_sizes', in other words, |
35 | // 1) input with size > 'max_batch_size' is still invalidated. |
36 | // 2) input with |
37 | // a) size <= 'max_batch_size' |
38 | // b) size > max value of 'allowed_batch_sizes' |
39 | // will automatically be split into multiple batches (with batch size in |
40 | // 'allowed_batch_sizes'), executed, and re-composed (as final output). |
41 | .Attr("max_batch_size: int" ) |
42 | .Attr("batch_timeout_micros: int" ) |
43 | .Attr("max_enqueued_batches: int = 10" ) |
44 | .Attr("allowed_batch_sizes: list(int) = []" ) |
45 | .Attr("container: string = ''" ) |
46 | .Attr("shared_name: string = ''" ) |
47 | .Attr("batching_queue: string = ''" ) |
48 | .Attr("Tin: list(type)" ) |
49 | .Attr("Tcaptured: list(type) >= 0" ) |
50 | .Attr("Tout: list(type)" ) |
51 | // If 'enable_large_batch_splitting' is true, for input batches exceeding |
52 | // the largest value in "allowed_batch_sizes", allow the batch to be split |
53 | // into multiple batches with batch size within "allowed_batch_sizes". |
54 | // NOTE: Support for `enable_large_batch_splitting == true` is still |
55 | // developed in progress. |
56 | .Attr("enable_large_batch_splitting: bool = false" ) |
57 | // TODO(apassos): Fix this shape inference function. It requires shape |
58 | // inference of function calls. |
59 | .SetShapeFn(shape_inference::UnknownShape) |
60 | .SetIsDistributedCommunication(); |
61 | |
62 | REGISTER_OP("Batch" ) |
63 | .Input("in_tensors: T" ) |
64 | .Output("batched_tensors: T" ) |
65 | .Output("batch_index: int64" ) |
66 | .Output("id: int64" ) |
67 | .Attr("num_batch_threads: int" ) |
68 | .Attr("max_batch_size: int" ) |
69 | .Attr("max_enqueued_batches: int = 10" ) |
70 | .Attr("batch_timeout_micros: int" ) |
71 | .Attr("allowed_batch_sizes: list(int) = []" ) |
72 | .Attr("grad_timeout_micros: int" ) |
73 | .Attr("container: string = ''" ) |
74 | .Attr("shared_name: string = ''" ) |
75 | .Attr("batching_queue: string = ''" ) |
76 | .Attr("T: list(type)" ) |
77 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
78 | std::vector<shape_inference::ShapeHandle> in_shapes; |
79 | TF_RETURN_IF_ERROR(c->input("in_tensors" , &in_shapes)); |
80 | std::vector<shape_inference::ShapeHandle> out_shapes(in_shapes.size()); |
81 | for (int i = 0; i < in_shapes.size(); ++i) { |
82 | TF_RETURN_IF_ERROR( |
83 | c->ReplaceDim(in_shapes[i], 0, c->UnknownDim(), &out_shapes[i])); |
84 | } |
85 | TF_RETURN_IF_ERROR(c->set_output("batched_tensors" , out_shapes)); |
86 | TF_RETURN_IF_ERROR(c->set_output("id" , {c->Scalar()})); |
87 | TF_RETURN_IF_ERROR(c->set_output( |
88 | "batch_index" , |
89 | {c->MakeShape({shape_inference::DimensionOrConstant(c->UnknownDim()), |
90 | shape_inference::DimensionOrConstant(3)})})); |
91 | return OkStatus(); |
92 | }) |
93 | .SetIsDistributedCommunication(); |
94 | |
95 | REGISTER_OP("Unbatch" ) |
96 | .Input("batched_tensor: T" ) |
97 | .Input("batch_index: int64" ) |
98 | .Input("id: int64" ) |
99 | .Output("unbatched_tensor: T" ) |
100 | .Attr("timeout_micros: int" ) |
101 | .Attr("container: string = ''" ) |
102 | .Attr("shared_name: string = ''" ) |
103 | .Attr("T: type" ) |
104 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
105 | shape_inference::ShapeHandle out_shape; |
106 | TF_RETURN_IF_ERROR( |
107 | c->ReplaceDim(c->input(0), 0, c->UnknownDim(), &out_shape)); |
108 | c->set_output(0, out_shape); |
109 | return OkStatus(); |
110 | }); |
111 | |
112 | REGISTER_OP("UnbatchGrad" ) |
113 | .Input("original_input: T" ) |
114 | .Input("batch_index: int64" ) |
115 | .Input("grad: T" ) |
116 | .Input("id: int64" ) |
117 | .Output("batched_grad: T" ) |
118 | .Attr("container: string = ''" ) |
119 | .Attr("shared_name: string = ''" ) |
120 | .Attr("T: type" ) |
121 | .SetShapeFn([](shape_inference::InferenceContext* c) { |
122 | c->set_output(0, c->UnknownShapeOfRank(c->Rank(c->input(2)))); |
123 | return OkStatus(); |
124 | }); |
125 | |
126 | } // namespace tensorflow |
127 | |