1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include "tensorflow/core/framework/kernel_def_builder.h"
17
18#include "tensorflow/core/framework/attr_value.pb.h"
19#include "tensorflow/core/framework/kernel_def.pb.h"
20
21namespace tensorflow {
22
23KernelDefBuilder::KernelDefBuilder(const char* op_name) {
24 kernel_def_ = new KernelDef;
25 kernel_def_->set_op(op_name);
26}
27
28KernelDefBuilder::~KernelDefBuilder() {
29 DCHECK(kernel_def_ == nullptr) << "Did not call Build()";
30}
31
32KernelDefBuilder& KernelDefBuilder::Device(const char* device_type) {
33 kernel_def_->set_device_type(device_type);
34 return *this;
35}
36
37template <>
38KernelDefBuilder& KernelDefBuilder::AttrConstraint<int64_t>(
39 const char* attr_name, gtl::ArraySlice<int64_t> allowed) {
40 auto* constraint = kernel_def_->add_constraint();
41 constraint->set_name(attr_name);
42 auto* allowed_values = constraint->mutable_allowed_values()->mutable_list();
43 for (const int64_t integer : allowed) {
44 allowed_values->add_i(integer);
45 }
46 return *this;
47}
48
49template <>
50KernelDefBuilder& KernelDefBuilder::AttrConstraint<int64_t>(
51 const char* attr_name, int64_t allowed) {
52 return AttrConstraint(
53 attr_name,
54 gtl::ArraySlice<int64_t>(std::initializer_list<int64_t>({allowed})));
55}
56
57template <>
58KernelDefBuilder& KernelDefBuilder::AttrConstraint<string>(
59 const char* attr_name, gtl::ArraySlice<string> allowed) {
60 auto* constraint = kernel_def_->add_constraint();
61 constraint->set_name(attr_name);
62 auto* allowed_values = constraint->mutable_allowed_values()->mutable_list();
63 for (const auto& str : allowed) {
64 allowed_values->add_s(str);
65 }
66 return *this;
67}
68
69template <>
70KernelDefBuilder& KernelDefBuilder::AttrConstraint<string>(
71 const char* attr_name, string allowed) {
72 return AttrConstraint(
73 attr_name,
74 gtl::ArraySlice<string>(std::initializer_list<string>({allowed})));
75}
76
77template <>
78KernelDefBuilder& KernelDefBuilder::AttrConstraint<const char*>(
79 const char* attr_name, gtl::ArraySlice<const char*> allowed) {
80 auto* constraint = kernel_def_->add_constraint();
81 constraint->set_name(attr_name);
82 auto* allowed_values = constraint->mutable_allowed_values()->mutable_list();
83 for (const auto& str : allowed) {
84 allowed_values->add_s(str);
85 }
86 return *this;
87}
88
89template <>
90KernelDefBuilder& KernelDefBuilder::AttrConstraint<const char*>(
91 const char* attr_name, const char* allowed) {
92 return AttrConstraint(attr_name,
93 gtl::ArraySlice<const char*>(
94 std::initializer_list<const char*>({allowed})));
95}
96
97template <>
98KernelDefBuilder& KernelDefBuilder::AttrConstraint<bool>(const char* attr_name,
99 bool allowed) {
100 auto* constraint = kernel_def_->add_constraint();
101 constraint->set_name(attr_name);
102 auto* allowed_values = constraint->mutable_allowed_values()->mutable_list();
103 allowed_values->add_b(allowed);
104 return *this;
105}
106
107KernelDefBuilder& KernelDefBuilder::TypeConstraint(
108 const char* attr_name, gtl::ArraySlice<DataType> allowed) {
109 auto* constraint = kernel_def_->add_constraint();
110 constraint->set_name(attr_name);
111 auto* allowed_values = constraint->mutable_allowed_values()->mutable_list();
112 for (DataType dt : allowed) {
113 allowed_values->add_type(dt);
114 }
115 return *this;
116}
117
118KernelDefBuilder& KernelDefBuilder::TypeConstraint(const char* attr_name,
119 DataType allowed) {
120 auto* constraint = kernel_def_->add_constraint();
121 constraint->set_name(attr_name);
122 constraint->mutable_allowed_values()->mutable_list()->add_type(allowed);
123 return *this;
124}
125
126KernelDefBuilder& KernelDefBuilder::HostMemory(const char* arg_name) {
127 kernel_def_->add_host_memory_arg(arg_name);
128 return *this;
129}
130
// Sets the label of the KernelDef under construction to `label`. The label may
// only be set once: if the KernelDef already has a non-empty label, the
// CHECK_EQ below fails, reporting the rejected label and the current
// KernelDef contents. Returns *this so calls can be chained.
KernelDefBuilder& KernelDefBuilder::Label(const char* label) {
  // An empty current label is taken to mean "not set yet".
  CHECK_EQ(kernel_def_->label(), "")
      << "Trying to set a kernel's label a second time: '" << label
      << "' in: " << kernel_def_->DebugString();
  kernel_def_->set_label(label);
  return *this;
}
138
139KernelDefBuilder& KernelDefBuilder::Priority(int32_t priority) {
140 kernel_def_->set_priority(priority);
141 return *this;
142}
143
144const KernelDef* KernelDefBuilder::Build() {
145 KernelDef* r = kernel_def_;
146 kernel_def_ = nullptr;
147 return r;
148}
149
150} // namespace tensorflow
151