1 | #pragma once |
2 | |
3 | // TODO: unify to C10_MOBILE. In theory this header could be used in OSS. |
4 | #ifdef TEMPLATE_SELECTIVE_BUILD |
5 | #include <ATen/selected_mobile_ops.h> |
6 | #endif |
7 | |
8 | /** |
9 | * This header implements functionality to build PyTorch with only a certain |
10 | * set of operators (+ dependencies) included. |
11 | * |
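 * - Build with -DTORCH_OPERATOR_WHITELIST="aten::add;aten::sub" and only these
 *   two ops will be included in your build. Entries are matched exactly
 *   against the operator name including its overload name (the part of the
 *   schema before '('), so an overloaded op must be listed once per overload,
 *   e.g. "aten::add.Tensor;aten::add.Scalar". A plain "aten::add" entry only
 *   matches the schema whose overload name is empty.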
16 | * |
17 | * Internally, this is done by removing the operator registration calls |
18 | * using compile time programming, and the linker will then prune all |
19 | * operator functions that weren't registered. |
20 | * See Note [Selective build] for more details |
21 | * |
22 | * WARNING: The allowlist mechanism doesn't work for all ways you could go about |
23 | * registering an operator. If the dispatch key / operator name is not |
24 | * sufficiently obvious at compile time, then the allowlisting mechanism |
25 | * will fail (and the operator will be included in the binary anyway). |
26 | */ |
27 | |
28 | #include <c10/util/string_view.h> |
29 | #include <c10/core/DispatchKey.h> |
30 | #include <c10/macros/Macros.h> |
31 | |
32 | |
33 | #if defined(ENABLE_RECORD_KERNEL_FUNCTION_DTYPE) |
34 | #include <ATen/record_function.h> |
35 | #endif |
36 | |
37 | namespace c10 { |
38 | |
39 | namespace impl { |
40 | |
// Forward declaration: allowlist_contains is defined further down in this
// namespace, but is_build_feature_available (below) already refers to it.
constexpr bool allowlist_contains(string_view allowlist, string_view item);
42 | |
/**
 * Reports whether the build feature called `name` is compiled in.
 *
 * - Selective build with TORCH_BUILD_FEATURE_ALLOWLIST defined: true iff
 *   `name` exactly matches one of the ';'-separated allowlist entries.
 * - Selective build without that macro: always true (everything selected).
 * - Instrumenting / tracing mode (ENABLE_RECORD_KERNEL_FUNCTION_DTYPE):
 *   always true, and the check has no side effects.
 */
constexpr bool is_build_feature_available(const char* name) {
#if defined(TORCH_BUILD_FEATURE_ALLOWLIST) && \
    !defined(ENABLE_RECORD_KERNEL_FUNCTION_DTYPE)
  return allowlist_contains(C10_STRINGIZE(TORCH_BUILD_FEATURE_ALLOWLIST), name);
#else
  // `name` is only consulted when a feature allowlist is active.
  (void)name;
  return true;
#endif
}
68 | |
// Declared [[noreturn]]; BUILD_FEATURE_REQUIRED calls it when a required
// build feature is missing.
[[noreturn]] void build_feature_required_feature_not_available(const char* feature);

/**
 * Use BUILD_FEATURE_REQUIRED macro in user-code.
 *
 * In selective build mode becomes a no-op if the build feature passed
 * in is available. If not available, throws an exception (c10::Error) via
 * build_feature_required_feature_not_available. Because that function is
 * [[noreturn]], the compiler is able to perform dead code elimination for
 * code following this macro if the build feature is not available.
 *
 * In instrumenting mode (tracing mode), registers (as a side effect)
 * the presence of this specific build feature being triggered.
 *
 * Every name used in the expansions is fully qualified (leading `::`), so
 * the macro also expands correctly inside a namespace that declares its own
 * nested `c10` or `at` namespace.
 */
#if !defined(ENABLE_RECORD_KERNEL_FUNCTION_DTYPE) // selective build mode

#if defined(TORCH_BUILD_FEATURE_ALLOWLIST)
// Expands to a bare `if` statement: do not use it as the only statement of an
// un-braced if/else, or the caller's `else` will bind to this `if`.
#define BUILD_FEATURE_REQUIRED(NAME)                                 \
  if (!::c10::impl::is_build_feature_available(NAME)) {              \
    ::c10::impl::build_feature_required_feature_not_available(NAME); \
  }
#else // Everything trivially selected
#define BUILD_FEATURE_REQUIRED(NAME)
#endif

#else // trace mode
// NOTE(review): deliberately not wrapped in do { } while (0) — presumably
// RECORD_FUNCTION_WITH_SCOPE creates an object that must live until the end
// of the caller's scope; confirm before changing the shape of this expansion.
#define BUILD_FEATURE_REQUIRED(NAME)   \
  RECORD_FUNCTION_WITH_SCOPE(          \
      ::at::RecordScope::BUILD_FEATURE, \
      std::string(NAME),               \
      {});
#endif

// Use this macro, and not is_build_feature_available
#define BUILD_FEATURE_AVAILABLE(NAME) ::c10::impl::is_build_feature_available(NAME)
104 | |
// Returns true iff `item` exactly equals one of the ';'-separated entries of
// `allowlist`.
//
//   allowlist_contains("a;bc;d", "bc") == true
//   allowlist_contains("a;bc;d", "b")  == false   (no prefix matching)
//
// Entries are compared verbatim (no whitespace trimming). An empty entry —
// from an empty allowlist, "a;;b", or a trailing ';' — matches an empty item.
constexpr bool allowlist_contains(string_view allowlist, string_view item) {
  size_t entry_begin = 0;
  for (;;) {
    const size_t entry_end = allowlist.find(';', entry_begin);
    if (entry_end == string_view::npos) {
      // Last (or only) entry: everything after the final separator.
      return allowlist.substr(entry_begin) == item;
    }
    if (allowlist.substr(entry_begin, entry_end - entry_begin) == item) {
      return true;
    }
    // entry_end < size(), so entry_begin stays <= size() and substr is valid.
    entry_begin = entry_end + 1;
  }
}
127 | |
// Returns true iff the given op name is on the allowlist and should be
// registered.
//
// `op_name` must be a qualified operator name without an argument list (it
// must contain "::" and must not contain '('), e.g. "aten::add" or
// "aten::add.Tensor". When TORCH_OPERATOR_WHITELIST is defined, the name —
// overload part included — must exactly match one of its ';'-separated
// entries. The overload is not stripped: this is mainly used for mobile
// selective build with root operators, where each allowlist entry carries
// the overload name.
constexpr bool op_allowlist_check(string_view op_name) {
  // Use assert() instead of throw() due to a gcc bug. See:
  // https://stackoverflow.com/questions/34280729/throw-in-constexpr-function
  // https://github.com/fmtlib/fmt/issues/682
  assert(op_name.find("::") != string_view::npos);
  assert(op_name.find('(') == string_view::npos);
#if !defined(TORCH_OPERATOR_WHITELIST)
  // If the TORCH_OPERATOR_WHITELIST parameter is not defined,
  // all ops are to be registered. The asserts above compile away under
  // NDEBUG, so mark op_name as used to avoid -Wunused-parameter.
  (void)op_name;
  return true;
#else
  return allowlist_contains(C10_STRINGIZE(TORCH_OPERATOR_WHITELIST), op_name);
#endif
}
152 | |
153 | // Returns true iff the given schema string is on the allowlist |
154 | // and should be registered |
155 | constexpr bool schema_allowlist_check(string_view schema) { |
156 | #if defined(TORCH_FORCE_SCHEMA_REGISTRATION) |
157 | return true; |
158 | #else |
159 | return op_allowlist_check(schema.substr(0, schema.find("(" ))); |
160 | #endif |
161 | } |
162 | |
// Reports whether the custom class `custom_class_name` should be registered.
// Without TORCH_CUSTOM_CLASS_ALLOWLIST every custom class is accepted; with
// it, the name must exactly match one of the ';'-separated allowlist entries.
constexpr bool custom_class_allowlist_check(string_view custom_class_name) {
#if defined(TORCH_CUSTOM_CLASS_ALLOWLIST)
  return allowlist_contains(
      C10_STRINGIZE(TORCH_CUSTOM_CLASS_ALLOWLIST), custom_class_name);
#else
  // No allowlist configured: register everything, the name is irrelevant.
  (void)custom_class_name;
  return true;
#endif
}
177 | |
178 | // schema_allowlist_check() implicitly depends on a macro, TORCH_OPERATOR_WHITELIST. |
179 | // Add this API to pass arbitrary allowlist. |
180 | constexpr bool op_allowlist_contains_name_in_schema(string_view allowlist, string_view schema) { |
181 | return allowlist_contains(allowlist, schema.substr(0, schema.find("(" ))); |
182 | } |
183 | |
184 | // Returns true iff the given dispatch key is on the allowlist |
185 | // and should be registered. When we turn this on, the list of valid |
186 | // mobile dispatch keys is hard coded (but you need to make sure |
187 | // that you have the correct set of dispatch keys for this). |
188 | constexpr bool dispatch_key_allowlist_check(DispatchKey /*k*/) { |
189 | #ifdef C10_MOBILE |
190 | return true; |
191 | // Disabled for now: to be enabled later! |
192 | // return k == DispatchKey::CPU || k == DispatchKey::Vulkan || k == DispatchKey::QuantizedCPU || k == DispatchKey::BackendSelect || k == DispatchKey::CatchAll; |
193 | #else |
194 | return true; |
195 | #endif |
196 | } |
197 | |
198 | } // namespace impl |
199 | } // namespace c10 |
200 | |