1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_CORE_FRAMEWORK_KERNEL_DEF_BUILDER_H_ |
17 | #define TENSORFLOW_CORE_FRAMEWORK_KERNEL_DEF_BUILDER_H_ |
18 | |
19 | #include "tensorflow/core/framework/types.h" |
20 | #include "tensorflow/core/lib/gtl/array_slice.h" |
21 | #include "tensorflow/core/platform/macros.h" |
22 | #include "tensorflow/core/platform/types.h" |
23 | |
24 | namespace tensorflow { |
25 | |
26 | // Forward declare proto so that kernels don't need to depend on it |
27 | class KernelDef; |
28 | |
29 | // Builder class passed to the REGISTER_KERNEL_BUILDER() macro. |
30 | class KernelDefBuilder { |
31 | public: |
32 | // Starts with just the name field set. |
33 | // Caller MUST call Build() and take ownership of the result. |
34 | explicit KernelDefBuilder(const char* op_name); |
35 | ~KernelDefBuilder(); |
36 | |
37 | // Required: specify the type of device this kernel supports. |
38 | // Returns *this. |
39 | KernelDefBuilder& Device(const char* device_type); |
40 | |
41 | // Specify that this kernel supports a limited set of values for a |
42 | // particular type or list(type) attr (a further restriction than |
43 | // what the Op allows). |
44 | // Returns *this. |
45 | template <typename T> |
46 | KernelDefBuilder& AttrConstraint(const char* attr_name, |
47 | gtl::ArraySlice<T> allowed); |
48 | |
49 | // Like AttrConstraint above but supports just a single value. |
50 | template <typename T> |
51 | KernelDefBuilder& AttrConstraint(const char* attr_name, T allowed); |
52 | |
53 | // Specify that this kernel supports a limited set of values for a |
54 | // particular type or list(type) attr (a further restriction than |
55 | // what the Op allows). |
56 | // Returns *this. |
57 | KernelDefBuilder& TypeConstraint(const char* attr_name, |
58 | gtl::ArraySlice<DataType> allowed); |
59 | |
60 | // Like TypeConstraint but supports just a single type. |
61 | KernelDefBuilder& TypeConstraint(const char* attr_name, DataType allowed); |
62 | |
63 | // Like TypeConstraint, but (a) gets the type from a template parameter |
64 | // and (b) only supports a constraint to a single type. |
65 | template <class T> |
66 | KernelDefBuilder& TypeConstraint(const char* attr_name) TF_ATTRIBUTE_NOINLINE; |
67 | // TODO(josh11b): Support other types of attr constraints as needed. |
68 | |
69 | // Specify that this kernel requires/provides an input/output arg |
70 | // in host memory (instead of the default, device memory). |
71 | // Returns *this. |
72 | KernelDefBuilder& HostMemory(const char* arg_name); |
73 | |
74 | // Specify that this kernel requires a particular value for the |
75 | // "_kernel" attr. May only be specified once. Returns *this. |
76 | KernelDefBuilder& Label(const char* label); |
77 | |
78 | // Specify a priority number for this kernel. |
79 | KernelDefBuilder& Priority(int32_t priority); |
80 | |
81 | // Returns a pointer to a KernelDef with fields set based on the |
82 | // above calls to this instance. |
83 | // Caller takes ownership of the result. |
84 | const KernelDef* Build(); |
85 | |
86 | private: |
87 | KernelDef* kernel_def_; |
88 | |
89 | TF_DISALLOW_COPY_AND_ASSIGN(KernelDefBuilder); |
90 | }; |
91 | |
92 | // IMPLEMENTATION |
93 | |
94 | template <class T> |
95 | KernelDefBuilder& KernelDefBuilder::TypeConstraint(const char* attr_name) { |
96 | return this->TypeConstraint(attr_name, DataTypeToEnum<T>::v()); |
97 | } |
98 | |
99 | } // namespace tensorflow |
100 | |
101 | #endif // TENSORFLOW_CORE_FRAMEWORK_KERNEL_DEF_BUILDER_H_ |
102 | |