1/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#ifndef TENSORFLOW_CORE_FRAMEWORK_KERNEL_DEF_BUILDER_H_
17#define TENSORFLOW_CORE_FRAMEWORK_KERNEL_DEF_BUILDER_H_
18
19#include "tensorflow/core/framework/types.h"
20#include "tensorflow/core/lib/gtl/array_slice.h"
21#include "tensorflow/core/platform/macros.h"
22#include "tensorflow/core/platform/types.h"
23
24namespace tensorflow {
25
// Forward-declare the KernelDef proto so that kernels don't need to depend on it.
27class KernelDef;
28
29// Builder class passed to the REGISTER_KERNEL_BUILDER() macro.
30class KernelDefBuilder {
31 public:
32 // Starts with just the name field set.
33 // Caller MUST call Build() and take ownership of the result.
34 explicit KernelDefBuilder(const char* op_name);
35 ~KernelDefBuilder();
36
37 // Required: specify the type of device this kernel supports.
38 // Returns *this.
39 KernelDefBuilder& Device(const char* device_type);
40
41 // Specify that this kernel supports a limited set of values for a
42 // particular type or list(type) attr (a further restriction than
43 // what the Op allows).
44 // Returns *this.
45 template <typename T>
46 KernelDefBuilder& AttrConstraint(const char* attr_name,
47 gtl::ArraySlice<T> allowed);
48
49 // Like AttrConstraint above but supports just a single value.
50 template <typename T>
51 KernelDefBuilder& AttrConstraint(const char* attr_name, T allowed);
52
53 // Specify that this kernel supports a limited set of values for a
54 // particular type or list(type) attr (a further restriction than
55 // what the Op allows).
56 // Returns *this.
57 KernelDefBuilder& TypeConstraint(const char* attr_name,
58 gtl::ArraySlice<DataType> allowed);
59
60 // Like TypeConstraint but supports just a single type.
61 KernelDefBuilder& TypeConstraint(const char* attr_name, DataType allowed);
62
63 // Like TypeConstraint, but (a) gets the type from a template parameter
64 // and (b) only supports a constraint to a single type.
65 template <class T>
66 KernelDefBuilder& TypeConstraint(const char* attr_name) TF_ATTRIBUTE_NOINLINE;
67 // TODO(josh11b): Support other types of attr constraints as needed.
68
69 // Specify that this kernel requires/provides an input/output arg
70 // in host memory (instead of the default, device memory).
71 // Returns *this.
72 KernelDefBuilder& HostMemory(const char* arg_name);
73
74 // Specify that this kernel requires a particular value for the
75 // "_kernel" attr. May only be specified once. Returns *this.
76 KernelDefBuilder& Label(const char* label);
77
78 // Specify a priority number for this kernel.
79 KernelDefBuilder& Priority(int32_t priority);
80
81 // Returns a pointer to a KernelDef with fields set based on the
82 // above calls to this instance.
83 // Caller takes ownership of the result.
84 const KernelDef* Build();
85
86 private:
87 KernelDef* kernel_def_;
88
89 TF_DISALLOW_COPY_AND_ASSIGN(KernelDefBuilder);
90};
91
92// IMPLEMENTATION
93
94template <class T>
95KernelDefBuilder& KernelDefBuilder::TypeConstraint(const char* attr_name) {
96 return this->TypeConstraint(attr_name, DataTypeToEnum<T>::v());
97}
98
99} // namespace tensorflow
100
101#endif // TENSORFLOW_CORE_FRAMEWORK_KERNEL_DEF_BUILDER_H_
102