1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CC_OPS_CONST_OP_H_ |
17 | #define TENSORFLOW_CC_OPS_CONST_OP_H_ |
18 | |
19 | #include <vector> |
20 | |
21 | #include "tensorflow/cc/framework/ops.h" |
22 | #include "tensorflow/cc/framework/scope.h" |
23 | #include "tensorflow/core/graph/node_builder.h" |
24 | |
25 | namespace tensorflow { |
26 | namespace ops { |
27 | |
28 | /// @defgroup const_op Const Op |
29 | /// @{ |
30 | |
/// Returns an `Output` for a constant built from `val`.
///
/// Defined out of line. Errors are recorded on `scope`; callers are expected
/// to check `scope.ok()` after the call (as the templated `Const<T>` below
/// does).
Output Const(const Scope& scope, const Input::Initializer& val);

/// Returns an `Output` for a constant described by `proto`.
///
/// Defined out of line; presumably the tensor value is decoded from the
/// `TensorProto` — TODO confirm against the definition.
Output ConstFromProto(const Scope& scope, const TensorProto& proto);

/// Converts `inp` into a `NodeBuilder::NodeOut` that can be passed to
/// `NodeBuilder::Input` when wiring a hand-built node into `scope`'s graph.
NodeBuilder::NodeOut AsNodeOut(const Scope& scope, const Input& inp);
36 | |
37 | template <typename T> |
38 | Output Const(const Scope& scope, const Input::Initializer& val) { |
39 | auto orig_const_output = Const(scope, val); |
40 | if (!scope.ok()) return Output(); |
41 | |
42 | typedef typename Input::Initializer::RealType<T>::type DstT; |
43 | |
44 | if (val.tensor.dtype() == DataTypeToEnum<DstT>::v()) { |
45 | return orig_const_output; |
46 | } |
47 | if (val.tensor.NumElements() == 0) { |
48 | Tensor t(DataTypeToEnum<DstT>::v(), val.tensor.shape()); |
49 | return Const(scope, Input::Initializer(t)); |
50 | } |
51 | |
52 | // TODO(keveman): Refactor Cast op's kernel implementation such that the code |
53 | // can be directly called here instead of adding the Cast op to the graph. |
54 | auto orig_const = AsNodeOut(scope, orig_const_output); |
55 | const auto cast_op_name = scope.GetUniqueNameForOp("Cast" ); |
56 | |
57 | auto cast_builder = NodeBuilder(cast_op_name, "Cast" ) |
58 | .Input(orig_const) |
59 | .Attr("DstT" , DataTypeToEnum<DstT>::v()); |
60 | scope.UpdateBuilder(&cast_builder); |
61 | Node* ret; |
62 | scope.UpdateStatus(cast_builder.Finalize(scope.graph(), &ret)); |
63 | if (!scope.ok()) return Output(); |
64 | scope.UpdateStatus(scope.DoShapeInference(ret)); |
65 | return Output(ret, 0); |
66 | } |
67 | |
68 | template <typename T> |
69 | Output Const(const Scope& scope, const T& v, const TensorShape shape) { |
70 | return Const(scope, Input::Initializer(v, shape)); |
71 | } |
72 | |
73 | template <typename T> |
74 | Output Const(const Scope& scope, const std::initializer_list<T>& v, |
75 | const TensorShape shape) { |
76 | return Const(scope, Input::Initializer(v, shape)); |
77 | } |
78 | |
/// List counterpart of `AsNodeOut`: converts the inputs in `inp` into
/// `NodeBuilder::NodeOut` values, presumably one per element and in order —
/// defined out of line, TODO confirm.
std::vector<NodeBuilder::NodeOut> AsNodeOutList(const Scope& scope,
                                                const InputList& inp);
81 | |
/// @}
83 | |
84 | } // namespace ops |
85 | } // namespace tensorflow |
86 | |
87 | #endif // TENSORFLOW_CC_OPS_CONST_OP_H_ |
88 | |