1 | /* Copyright 2016 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/cc/ops/const_op.h" |
17 | #include "tensorflow/core/framework/types.h" |
18 | |
19 | namespace tensorflow { |
20 | namespace ops { |
21 | |
22 | namespace { |
23 | template <typename T> |
24 | Output ConstHelper(const Scope& scope, const T& value, DataType dtype) { |
25 | if (!scope.ok()) return Output(); |
26 | |
27 | Node* ret; |
28 | Graph* graph = scope.graph(); |
29 | const string unique_name = scope.GetUniqueNameForOp("Const" ); |
30 | auto builder = NodeBuilder(unique_name, "Const" ) |
31 | .Attr("value" , value) |
32 | .Attr("dtype" , dtype); |
33 | scope.UpdateBuilder(&builder); |
34 | scope.UpdateStatus(builder.Finalize(graph, &ret)); |
35 | if (!scope.ok()) return Output(); |
36 | |
37 | scope.UpdateStatus(scope.DoShapeInference(ret)); |
38 | if (!scope.ok()) return Output(); |
39 | |
40 | return Output(ret); |
41 | } |
42 | } // namespace |
43 | |
44 | Output Const(const Scope& scope, const Input::Initializer& val) { |
45 | if (!val.status.ok()) { |
46 | scope.UpdateStatus(val.status); |
47 | return Output(); |
48 | } |
49 | return ConstHelper(scope, val.tensor, val.tensor.dtype()); |
50 | } |
51 | |
52 | Output ConstFromProto(const Scope& scope, const TensorProto& proto) { |
53 | return ConstHelper(scope, proto, proto.dtype()); |
54 | } |
55 | |
56 | NodeBuilder::NodeOut AsNodeOut(const Scope& scope, const Input& inp) { |
57 | if (!inp.status().ok()) { |
58 | scope.UpdateStatus(inp.status()); |
59 | return NodeBuilder::NodeOut(inp.node(), inp.index()); |
60 | } |
61 | if (inp.node()) { |
62 | return NodeBuilder::NodeOut(inp.node(), inp.index()); |
63 | } |
64 | if (!inp.node_name().empty()) { |
65 | return NodeBuilder::NodeOut(inp.node_name(), inp.index(), inp.data_type()); |
66 | } |
67 | auto transformed = Input{ |
68 | Const(scope.NewSubScope("Const" ), Input::Initializer(inp.tensor()))}; |
69 | return NodeBuilder::NodeOut{transformed.node(), transformed.index()}; |
70 | } |
71 | |
72 | std::vector<NodeBuilder::NodeOut> AsNodeOutList(const Scope& scope, |
73 | const InputList& inp) { |
74 | std::vector<NodeBuilder::NodeOut> out; |
75 | for (const auto& i : inp) { |
76 | const auto node_out = AsNodeOut(scope, i); |
77 | if (!scope.ok()) { |
78 | return {}; |
79 | } |
80 | out.push_back(node_out); |
81 | } |
82 | return out; |
83 | } |
84 | |
85 | } // namespace ops |
86 | } // namespace tensorflow |
87 | |