1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/framework/kernel_def_builder.h" |
17 | |
18 | #include "tensorflow/core/framework/attr_value.pb.h" |
19 | #include "tensorflow/core/framework/kernel_def.pb.h" |
20 | |
21 | namespace tensorflow { |
22 | |
23 | KernelDefBuilder::KernelDefBuilder(const char* op_name) { |
24 | kernel_def_ = new KernelDef; |
25 | kernel_def_->set_op(op_name); |
26 | } |
27 | |
28 | KernelDefBuilder::~KernelDefBuilder() { |
29 | DCHECK(kernel_def_ == nullptr) << "Did not call Build()" ; |
30 | } |
31 | |
32 | KernelDefBuilder& KernelDefBuilder::Device(const char* device_type) { |
33 | kernel_def_->set_device_type(device_type); |
34 | return *this; |
35 | } |
36 | |
37 | template <> |
38 | KernelDefBuilder& KernelDefBuilder::AttrConstraint<int64_t>( |
39 | const char* attr_name, gtl::ArraySlice<int64_t> allowed) { |
40 | auto* constraint = kernel_def_->add_constraint(); |
41 | constraint->set_name(attr_name); |
42 | auto* allowed_values = constraint->mutable_allowed_values()->mutable_list(); |
43 | for (const int64_t integer : allowed) { |
44 | allowed_values->add_i(integer); |
45 | } |
46 | return *this; |
47 | } |
48 | |
49 | template <> |
50 | KernelDefBuilder& KernelDefBuilder::AttrConstraint<int64_t>( |
51 | const char* attr_name, int64_t allowed) { |
52 | return AttrConstraint( |
53 | attr_name, |
54 | gtl::ArraySlice<int64_t>(std::initializer_list<int64_t>({allowed}))); |
55 | } |
56 | |
57 | template <> |
58 | KernelDefBuilder& KernelDefBuilder::AttrConstraint<string>( |
59 | const char* attr_name, gtl::ArraySlice<string> allowed) { |
60 | auto* constraint = kernel_def_->add_constraint(); |
61 | constraint->set_name(attr_name); |
62 | auto* allowed_values = constraint->mutable_allowed_values()->mutable_list(); |
63 | for (const auto& str : allowed) { |
64 | allowed_values->add_s(str); |
65 | } |
66 | return *this; |
67 | } |
68 | |
69 | template <> |
70 | KernelDefBuilder& KernelDefBuilder::AttrConstraint<string>( |
71 | const char* attr_name, string allowed) { |
72 | return AttrConstraint( |
73 | attr_name, |
74 | gtl::ArraySlice<string>(std::initializer_list<string>({allowed}))); |
75 | } |
76 | |
77 | template <> |
78 | KernelDefBuilder& KernelDefBuilder::AttrConstraint<const char*>( |
79 | const char* attr_name, gtl::ArraySlice<const char*> allowed) { |
80 | auto* constraint = kernel_def_->add_constraint(); |
81 | constraint->set_name(attr_name); |
82 | auto* allowed_values = constraint->mutable_allowed_values()->mutable_list(); |
83 | for (const auto& str : allowed) { |
84 | allowed_values->add_s(str); |
85 | } |
86 | return *this; |
87 | } |
88 | |
89 | template <> |
90 | KernelDefBuilder& KernelDefBuilder::AttrConstraint<const char*>( |
91 | const char* attr_name, const char* allowed) { |
92 | return AttrConstraint(attr_name, |
93 | gtl::ArraySlice<const char*>( |
94 | std::initializer_list<const char*>({allowed}))); |
95 | } |
96 | |
97 | template <> |
98 | KernelDefBuilder& KernelDefBuilder::AttrConstraint<bool>(const char* attr_name, |
99 | bool allowed) { |
100 | auto* constraint = kernel_def_->add_constraint(); |
101 | constraint->set_name(attr_name); |
102 | auto* allowed_values = constraint->mutable_allowed_values()->mutable_list(); |
103 | allowed_values->add_b(allowed); |
104 | return *this; |
105 | } |
106 | |
107 | KernelDefBuilder& KernelDefBuilder::TypeConstraint( |
108 | const char* attr_name, gtl::ArraySlice<DataType> allowed) { |
109 | auto* constraint = kernel_def_->add_constraint(); |
110 | constraint->set_name(attr_name); |
111 | auto* allowed_values = constraint->mutable_allowed_values()->mutable_list(); |
112 | for (DataType dt : allowed) { |
113 | allowed_values->add_type(dt); |
114 | } |
115 | return *this; |
116 | } |
117 | |
118 | KernelDefBuilder& KernelDefBuilder::TypeConstraint(const char* attr_name, |
119 | DataType allowed) { |
120 | auto* constraint = kernel_def_->add_constraint(); |
121 | constraint->set_name(attr_name); |
122 | constraint->mutable_allowed_values()->mutable_list()->add_type(allowed); |
123 | return *this; |
124 | } |
125 | |
126 | KernelDefBuilder& KernelDefBuilder::HostMemory(const char* arg_name) { |
127 | kernel_def_->add_host_memory_arg(arg_name); |
128 | return *this; |
129 | } |
130 | |
// Sets the label of the kernel being defined to `label`.
// The label may be set only once: the CHECK_EQ below fails if the kernel
// definition already carries a non-empty label, reporting the rejected label
// and the current definition. Returns *this so that calls can be chained.
KernelDefBuilder& KernelDefBuilder::Label(const char* label) {
  // Require that no label has been stored yet.
  CHECK_EQ(kernel_def_->label(), "")
      << "Trying to set a kernel's label a second time: '" << label
      << "' in: " << kernel_def_->DebugString();
  kernel_def_->set_label(label);
  return *this;
}
138 | |
139 | KernelDefBuilder& KernelDefBuilder::Priority(int32_t priority) { |
140 | kernel_def_->set_priority(priority); |
141 | return *this; |
142 | } |
143 | |
144 | const KernelDef* KernelDefBuilder::Build() { |
145 | KernelDef* r = kernel_def_; |
146 | kernel_def_ = nullptr; |
147 | return r; |
148 | } |
149 | |
150 | } // namespace tensorflow |
151 | |