1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/attr_value.pb.h"
17#include "tensorflow/core/framework/common_shape_fns.h"
18#include "tensorflow/core/framework/op.h"
19#include "tensorflow/core/framework/shape_inference.h"
20#include "tensorflow/core/framework/tensor_shape.pb.h"
21#include "tensorflow/core/lib/core/errors.h"
22
23namespace tensorflow {
24
25REGISTER_SYSTEM_OP("_Arg")
26 .Output("output: T")
27 .Attr("T: type")
28 .Attr("index: int >= 0")
29 .SetIsStateful()
30 .SetShapeFn([](shape_inference::InferenceContext* context) {
31 const AttrValue* dtype_attr = context->GetAttr("T");
32 if (!dtype_attr) {
33 return errors::InvalidArgument(
34 "_Arg node does not have attribute \"T\"");
35 }
36
37 const AttrValue* shape_attr = context->GetAttr("_output_shapes");
38 if (shape_attr && shape_attr->has_list()) {
39 if (shape_attr->list().shape().empty()) {
40 return errors::InvalidArgument(
41 "Invalid \"_output_shapes\" attribute value for _Arg node: ",
42 shape_attr->DebugString());
43 }
44 const TensorShapeProto& shape_proto = shape_attr->list().shape(0);
45 shape_inference::ShapeHandle shape_handle;
46 TF_RETURN_IF_ERROR(
47 context->MakeShapeFromShapeProto(shape_proto, &shape_handle));
48 context->set_output(0, shape_handle);
49 } else {
50 context->set_output(0, context->UnknownShape());
51 }
52
53 if (dtype_attr->type() != DT_RESOURCE) {
54 return OkStatus();
55 }
56
57 // If the argument is for a resource type, then also try to infer the
58 // type of the tensor store in the resource type.
59 dtype_attr = context->GetAttr("_handle_dtypes");
60 shape_attr = context->GetAttr("_handle_shapes");
61 // If either the shape or type attribute is not set then simply return
62 // with unknown output set above.
63 if (!dtype_attr || !shape_attr) {
64 return OkStatus();
65 }
66
67 if (dtype_attr->list().type().empty()) {
68 return errors::InvalidArgument(
69 "Invalid \"_handle_dtypes\" attribute value for _Arg node: ",
70 dtype_attr->DebugString());
71 }
72 if (shape_attr->list().shape().empty()) {
73 return errors::InvalidArgument(
74 "Invalid \"_handle_shapes\" attribute value for _Arg node: ",
75 shape_attr->DebugString());
76 }
77 DataType dtype = dtype_attr->list().type(0);
78 const TensorShapeProto& shape_proto = shape_attr->list().shape(0);
79 shape_inference::ShapeHandle shape_handle;
80 TF_RETURN_IF_ERROR(
81 context->MakeShapeFromShapeProto(shape_proto, &shape_handle));
82 context->set_output_handle_shapes_and_types(
83 0, std::vector<shape_inference::ShapeAndType>{{shape_handle, dtype}});
84 return OkStatus();
85 })
86 .Doc(R"doc(
87A graph node which represents an argument to a function.
88
89output: The argument.
90index: This argument is the index-th argument of the function.
91
92Attributes for shape inference:
931. _output_shapes: this attribute should contain a list of TensorShapeProto
94 describing the shape(s) of the tensor(s) this _Arg node will produce. If set,
95 _Arg node's shape inference function will use it as the node's output shapes.
962. _handle_dtypes and _handle_shapes: these attributes can be set on an _Arg
97 node producing resource output(s). If set, value of _handle_dtypes should
98 contain the dtype(s) of the resource(s) and value of _handle_shapes should
99 contain the shape(s) of the resource(s). If both attributes are set, _Arg
100 node's shape inference function will use their values as the node's output
101 handle's type(s) and shape(s).
102)doc");
103
104REGISTER_SYSTEM_OP("_DeviceArg")
105 .Output("output: T")
106 .Attr("T: type")
107 .Attr("index: int >= 0")
108 .SetIsStateful()
109 .SetShapeFn([](shape_inference::InferenceContext* context) {
110 context->set_output(0, context->UnknownShape());
111 return OkStatus();
112 })
113 .Doc(R"doc(
114A graph node which represents an argument to a function.
115
116output: The argument.
117index: This argument is the index-th argument of the function.
118)doc");
119
120REGISTER_SYSTEM_OP("_Retval")
121 .Input("input: T")
122 .Attr("T: type")
123 .Attr("index: int >= 0")
124 .SetIsStateful()
125 .SetShapeFn([](shape_inference::InferenceContext* context) {
126 return OkStatus();
127 })
128 .Doc(R"doc(
129A graph node which represents a return value of a function.
130
131input: The return value.
132index: This return value is the index-th return value of the function.
133)doc");
134
135REGISTER_SYSTEM_OP("_DeviceRetval")
136 .Input("input: T")
137 .Attr("T: type")
138 .Attr("index: int >= 0")
139 .SetIsStateful()
140 .SetShapeFn([](shape_inference::InferenceContext* context) {
141 return OkStatus();
142 })
143 .Doc(R"doc(
144A graph node which represents a return value of a function.
145
146input: The return value.
147index: This return value is the index-th return value of the function.
148)doc");
149
150REGISTER_SYSTEM_OP("_ListToArray")
151 .Input("input: Tin")
152 .Output("output: N * T")
153 .Attr("Tin: list(type)")
154 .Attr("T: type")
155 .Attr("N: int >= 1")
156 .SetShapeFn(shape_inference::UnknownShape)
157 .Doc(R"doc(
158Converts a list of tensors to an array of tensors.
159)doc");
160
161REGISTER_SYSTEM_OP("_ArrayToList")
162 .Input("input: N * T")
163 .Output("output: out_types")
164 .Attr("T: type")
165 .Attr("N: int >= 1")
166 .Attr("out_types: list(type)")
167 .SetShapeFn(shape_inference::UnknownShape)
168 .Doc(R"doc(
169Converts an array of tensors to a list of tensors.
170)doc");
171
172} // namespace tensorflow
173