1#pragma once
2
3// TODO: unify to C10_MOBILE. In theory this header could be used in OSS.
4#ifdef TEMPLATE_SELECTIVE_BUILD
5#include <ATen/selected_mobile_ops.h>
6#endif
7
8/**
9 * This header implements functionality to build PyTorch with only a certain
10 * set of operators (+ dependencies) included.
11 *
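 * - Build with -DTORCH_OPERATOR_WHITELIST="aten::add.Tensor;aten::sub.Tensor"
 * and only these two operator overloads will be included in your build.
 * Allowlist entries are compared verbatim against the full operator name,
 * overload name included (for a schema, everything before its "("). An
 * entry of "aten::add" therefore does NOT select "aten::add.Tensor", so list
 * every overload you need.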
16 *
17 * Internally, this is done by removing the operator registration calls
18 * using compile time programming, and the linker will then prune all
19 * operator functions that weren't registered.
20 * See Note [Selective build] for more details
21 *
22 * WARNING: The allowlist mechanism doesn't work for all ways you could go about
23 * registering an operator. If the dispatch key / operator name is not
24 * sufficiently obvious at compile time, then the allowlisting mechanism
25 * will fail (and the operator will be included in the binary anyway).
26 */
27
#include <cassert>

#include <c10/core/DispatchKey.h>
#include <c10/macros/Macros.h>
#include <c10/util/string_view.h>
31
32
33#if defined(ENABLE_RECORD_KERNEL_FUNCTION_DTYPE)
34#include <ATen/record_function.h>
35#endif
36
37namespace c10 {
38
39namespace impl {
40
// Forward declaration: is_build_feature_available() below calls
// allowlist_contains() before its definition appears later in this header.
constexpr bool allowlist_contains(string_view allowlist, string_view item);
42
/**
 * Reports whether the named build feature is compiled into this binary.
 *
 * - Selective build with TORCH_BUILD_FEATURE_ALLOWLIST defined: true iff
 *   `name` is one of the ';'-separated entries of that allowlist.
 * - Selective build without an allowlist: every feature counts as available.
 * - Instrumenting (tracing) mode, i.e. ENABLE_RECORD_KERNEL_FUNCTION_DTYPE is
 *   defined: always true, and no side effects are triggered.
 */
constexpr bool is_build_feature_available(const char* name) {
#if defined(TORCH_BUILD_FEATURE_ALLOWLIST) && \
    !defined(ENABLE_RECORD_KERNEL_FUNCTION_DTYPE)
  return allowlist_contains(
      C10_STRINGIZE(TORCH_BUILD_FEATURE_ALLOWLIST), name);
#else
  // Nothing to look up in this configuration; keep `name` "used".
  (void)name;
  return true;
#endif
}
68
// Failure path of BUILD_FEATURE_REQUIRED: invoked with the NAME of a build
// feature that is not available in this build. Declared [[noreturn]], so
// control never returns to the caller; per the BUILD_FEATURE_REQUIRED
// documentation it throws c10::Error. The definition is not in this header.
[[noreturn]] void build_feature_required_feature_not_available(const char* feature);
70
/**
 * Use BUILD_FEATURE_REQUIRED macro in user-code.
 *
 * In selective build mode:
 *  - with TORCH_BUILD_FEATURE_ALLOWLIST defined, it is a no-op if the build
 *    feature passed in is available. If not available, it calls
 *    build_feature_required_feature_not_available(), which is [[noreturn]]
 *    and throws an exception (c10::Error). The compiler is able to perform
 *    dead code elimination for code following this macro if the build
 *    feature is not available.
 *  - without TORCH_BUILD_FEATURE_ALLOWLIST, it expands to nothing (NAME is
 *    not evaluated at all).
 *
 * In instrumenting mode (tracing mode), registers (as a side effect)
 * the presence of this specific build feature being triggered, via
 * RECORD_FUNCTION_WITH_SCOPE with at::RecordScope::BUILD_FEATURE.
 * That expansion already ends in ';'.
 *
 * NOTE(review): the allowlist variant expands to a bare `if (...) { ... }`
 * and evaluates NAME twice. Do not use it as the unbraced body of an if/else,
 * and pass a side-effect-free NAME (e.g. a string literal).
 */
#if !defined(ENABLE_RECORD_KERNEL_FUNCTION_DTYPE) // selective build mode

#if defined(TORCH_BUILD_FEATURE_ALLOWLIST)
// Feature allowlist given: fail (never returns) if NAME is not on it.
#define BUILD_FEATURE_REQUIRED(NAME) \
 if (!c10::impl::is_build_feature_available(NAME)) { \
 ::c10::impl::build_feature_required_feature_not_available(NAME); \
 }
#else // Everything trivially selected
#define BUILD_FEATURE_REQUIRED(NAME)

#endif

#else // trace mode
// Record that this build feature was reached, for later allowlist generation.
#define BUILD_FEATURE_REQUIRED(NAME) \
 RECORD_FUNCTION_WITH_SCOPE( \
 at::RecordScope::BUILD_FEATURE, \
 std::string(NAME), \
 {});
#endif
101
// Use this macro, and not is_build_feature_available, to test whether a build
// feature is available. It expands to a bool expression that calls
// ::c10::impl::is_build_feature_available(NAME).
#define BUILD_FEATURE_AVAILABLE(NAME) ::c10::impl::is_build_feature_available(NAME)
104
// Returns true iff `item` equals one of the ';'-separated entries of
// `allowlist`. Entries are compared whole (no prefix/substring matching) and
// verbatim (no whitespace trimming):
//   allowlist_contains("a;bc;d", "bc") == true
//   allowlist_contains("a;bc;d", "b")  == false
// Empty entries take part like any other entry, so "" contains "", and a
// trailing ';' (as in "a;") yields an empty final entry.
constexpr bool allowlist_contains(string_view allowlist, string_view item) {
  size_t cur = 0;
  size_t sep = allowlist.find(';', cur);
  while (sep != string_view::npos) {
    if (allowlist.substr(cur, sep - cur).compare(item) == 0) {
      return true;
    }
    // Step past the separator. `cur` may reach allowlist.size() here, which
    // makes the final entry below empty.
    cur = sep + 1;
    sep = allowlist.find(';', cur);
  }
  // Last (or only) entry: it runs to the end of the allowlist.
  return allowlist.substr(cur).compare(item) == 0;
}
127
// Returns true iff the given op name is on the allowlist and should be
// registered.
//
// `op_name` must be a namespaced operator name without an argument list,
// e.g. "aten::add.Tensor". Both preconditions are checked with assert(), so
// they are enforced only in builds without NDEBUG.
//
// When TORCH_OPERATOR_WHITELIST is not defined, every op is registered.
// Otherwise `op_name` is matched verbatim, overload name included, against
// the ';'-separated entries of that allowlist.
constexpr bool op_allowlist_check(string_view op_name) {
  // Use assert() instead of throw() due to a gcc bug. See:
  // https://stackoverflow.com/questions/34280729/throw-in-constexpr-function
  // https://github.com/fmtlib/fmt/issues/682
  assert(op_name.find("::") != string_view::npos);
  assert(op_name.find("(") == string_view::npos);
#if !defined(TORCH_OPERATOR_WHITELIST)
  // If the TORCH_OPERATOR_WHITELIST parameter is not defined, all ops are to
  // be registered. With NDEBUG the asserts above compile away, so mark
  // op_name as used to avoid -Wunused-parameter.
  (void)op_name;
  return true;
#else
  // This function is mainly used for mobile selective build with root
  // operators, where the overload is included in the allowlist, so the
  // overload name is deliberately NOT stripped. A variant for allowlists
  // without overloads could match OperatorNameView::parse(op_name).name.
  return allowlist_contains(
      C10_STRINGIZE(TORCH_OPERATOR_WHITELIST),
      op_name);
#endif
}
152
153// Returns true iff the given schema string is on the allowlist
154// and should be registered
155constexpr bool schema_allowlist_check(string_view schema) {
156#if defined(TORCH_FORCE_SCHEMA_REGISTRATION)
157 return true;
158#else
159 return op_allowlist_check(schema.substr(0, schema.find("(")));
160#endif
161}
162
// Returns true iff the given custom class name should be registered.
// With TORCH_CUSTOM_CLASS_ALLOWLIST defined, the name must match one of that
// allowlist's ';'-separated entries exactly. Without it, every custom class
// is registered.
constexpr bool custom_class_allowlist_check(string_view custom_class_name) {
#if defined(TORCH_CUSTOM_CLASS_ALLOWLIST)
  return allowlist_contains(
      C10_STRINGIZE(TORCH_CUSTOM_CLASS_ALLOWLIST), custom_class_name);
#else
  // No allowlist configured: accept everything.
  (void)custom_class_name;
  return true;
#endif
}
177
178// schema_allowlist_check() implicitly depends on a macro, TORCH_OPERATOR_WHITELIST.
179// Add this API to pass arbitrary allowlist.
180constexpr bool op_allowlist_contains_name_in_schema(string_view allowlist, string_view schema) {
181 return allowlist_contains(allowlist, schema.substr(0, schema.find("(")));
182}
183
184// Returns true iff the given dispatch key is on the allowlist
185// and should be registered. When we turn this on, the list of valid
186// mobile dispatch keys is hard coded (but you need to make sure
187// that you have the correct set of dispatch keys for this).
188constexpr bool dispatch_key_allowlist_check(DispatchKey /*k*/) {
189#ifdef C10_MOBILE
190 return true;
191 // Disabled for now: to be enabled later!
192 // return k == DispatchKey::CPU || k == DispatchKey::Vulkan || k == DispatchKey::QuantizedCPU || k == DispatchKey::BackendSelect || k == DispatchKey::CatchAll;
193#else
194 return true;
195#endif
196}
197
198} // namespace impl
199} // namespace c10
200