1 | /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | // Support for API dispatch at the Python level. |
16 | // |
// The dispatcher is implemented in C++ for efficiency.
18 | // |
19 | // * PythonAPIDispatcher: Class that handles dispatch for a single Python API. |
20 | // Contains a mapping from PySignatureCheckers to dispatch targets (python |
21 | // functions). |
22 | // |
23 | // * PySignatureChecker: Class to efficiently check whether dispatch should be |
24 | // invoked for a given set of parameters. Contains a collection of |
25 | // PyTypeCheckers. |
26 | // |
27 | // * PyTypeChecker: Class to efficiently check whether a Python value matches |
28 | // a type annotation. Three subclasses (PyInstanceChecker, PyListChecker, |
29 | // and PyUnionChecker) handle the different kinds of type annotation. |
30 | |
31 | #ifndef TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_API_DISPATCHER_H_ |
32 | #define TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_API_DISPATCHER_H_ |
33 | |
#include <Python.h>

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
#include "tensorflow/python/util/function_parameter_canonicalizer.h"
42 | |
43 | namespace tensorflow { |
44 | |
45 | namespace py_dispatch { |
46 | |
47 | class PyTypeChecker; |
48 | class PySignatureChecker; |
49 | |
50 | // Dispatcher for a single TensorFlow Python API (e.g. `tf.add` or `tf.concat`). |
51 | // |
52 | // A separate `PythonAPIDispatcher` object is created for each API, and handles |
53 | // dispatch for that API. The `Register` method can be used to add new |
54 | // "dispatch targets", which override the default behavior of the API when it |
55 | // is called with parameters matching a given signature. The `Dispatch` method |
56 | // checks if any registered target matches parameters, and if so, then calls |
57 | // that target. |
58 | // |
59 | // This class is *not* thread-safe. It is assumed that the Python Global |
60 | // Interpreter Lock (GIL) will be held when any method is called. |
61 | class PythonAPIDispatcher { |
62 | // TODO(b/196369143) Add benchmarking for this class. |
63 | public: |
64 | // Creates a new PythonAPIDispatcher for the named API. |
65 | // |
66 | // Args: |
67 | // api_name: The name of the API (used for error messages). |
68 | // arg_names: The argument names (used for parameter canonicalization). |
69 | // defaults: The argument defaults, as returned by `inspect.getargspec` |
70 | // (used for parameter canonicalization). |
71 | PythonAPIDispatcher(const std::string& api_name, |
72 | absl::Span<const char*> arg_names, |
73 | absl::Span<PyObject*> defaults); |
74 | |
75 | // Registers a new dispatch target for this dispatcher. If the API is |
76 | // called with parameters that match `signature_checker`, then |
77 | // `dispatch_target` will be called instead of the default API implementation. |
78 | void Register(PySignatureChecker signature_checker, |
79 | PyObject* dispatch_target); |
80 | |
81 | // Performs dispatch with the given set of parameters. |
82 | // |
83 | // * If a single target matches the parameters, then that target is called. |
84 | // * If multiple targets match the parameters, then an exception is raised. |
85 | // * If no targets match the parameters, then returns `Py_NotImplemented`. |
86 | // |
87 | // On error, returns nullptr and sets a Python exception. |
88 | Safe_PyObjectPtr Dispatch(PyObject* args, PyObject* kwargs); |
89 | |
90 | // Remove a dispatch target from this dispatcher. If the target was |
91 | // registered with multiple signatures, then all entries will be removed. |
92 | // (This method is primarily intended for regression tests.) |
93 | void Unregister(PyObject* func); |
94 | |
95 | std::string DebugString() const; |
96 | |
97 | private: |
98 | // Name of the API. |
99 | std::string api_name_; |
100 | |
101 | // Mapping from signature checkers to dispatch targets. |
102 | std::vector<std::pair<PySignatureChecker, Safe_PyObjectPtr>> targets_; |
103 | |
104 | // Parameter canonicalizer. |
105 | FunctionParameterCanonicalizer canonicalizer_; |
106 | |
107 | // Target storage for canonicalization. (Note: for efficiency, `Dispatch` |
108 | // writes to this pre-allocated storage, rather than allocating new storage |
109 | // each time it is called.) |
110 | std::vector<PyObject*> canonicalized_args_storage_; |
111 | absl::Span<PyObject*> canonicalized_args_span_; |
112 | }; |
113 | |
114 | // Registers a type for use with dispatch. Dispatch will only occur if at least |
115 | // one parameter value matches an annotation corresponding to a registered |
116 | // dispatchable type. |
117 | // |
118 | // Returns true on success; or sets a Python exception and returns false |
119 | // on error. |
120 | // |
121 | // Must be called before any PyInstanceChecker object is created from py_class. |
122 | // |
123 | // (Note: the CompositeTensor class is automatically registered for dispatch, |
124 | // so you do not need to use this method for any class that is a subclass of |
125 | // CompositeTensor or ExtensionType.) |
126 | bool RegisterDispatchableType(PyObject* py_class); |
127 | |
128 | // Class used by dispatch to check if parameters' values match a signature. |
129 | // |
130 | // Currently only supports checking parameters with kind POSITIONAL_ONLY or |
131 | // POSITIONAL_OR_KEYWORD. (Does not support checking parameters with kind |
132 | // VAR_POSITIONAL, VAR_KEYWORD, or KEYWORD_ONLY.) |
133 | class PySignatureChecker { |
134 | public: |
135 | // A parameter index and a TypeChecker for the parameter at that index. |
136 | using ParamChecker = std::pair<int, std::shared_ptr<PyTypeChecker>>; |
137 | |
138 | // Constructs a signature checker that will check the specified positional |
139 | // parameters. |
140 | explicit PySignatureChecker(std::vector<ParamChecker> parameter_checkers); |
141 | |
142 | // Returns true if the given canonicalized arguments match this signature |
143 | // checker. |
144 | bool CheckCanonicalizedArgs(absl::Span<PyObject*> canon_args) const; |
145 | |
146 | std::string DebugString() const; |
147 | |
148 | private: |
149 | // Type checkers for individual parameters. Only annotated parameters will |
150 | // be checked. This list is sorted to perform less expensive checks first. |
151 | // E.g., we check simple values before list values. |
152 | std::vector<ParamChecker> positional_parameter_checkers_; |
153 | }; |
154 | |
155 | // Abstract base class that checks if a Python value matches a type annotation. |
156 | // |
157 | // Subclasses of PyTypeChecker are defined for different annotations (List, |
158 | // Union, etc). Currently, we support the minimum set of type checkers that are |
159 | // required for CompositeTensor dispatch -- namely, `List`, `Union`, and simple |
160 | // types (`IsInstance`). Support for additional annotations may be added in the |
161 | // future. |
162 | class PyTypeChecker { |
163 | public: |
164 | using PyTypeChecker_ptr = std::shared_ptr<PyTypeChecker>; |
165 | PyTypeChecker() = default; |
166 | PyTypeChecker(const PyTypeChecker&) = delete; |
167 | PyTypeChecker(PyTypeChecker&&) = delete; |
168 | virtual ~PyTypeChecker() {} |
169 | |
170 | // Enumeration used to indicate whether a Python value matches a type |
171 | // annotation. MATCH and NO_MATCH simply indicate whether a value matches the |
172 | // annotation. |
173 | // |
174 | // MATCH_DISPATCHABLE indicates that a value matches the annotation, and |
175 | // additionally that the value (or one of its nested values) matched a type |
176 | // that has been registered for dispatch. This is important information |
177 | // because we only want to perform dispatch if at least one such value |
178 | // matches. Otherwise, we would end up using dispatch in undesirable cases. |
179 | // Examples: |
180 | // |
181 | // @tf.dispatch_for(tf.concat)(x=List[MyType]) |
182 | // |
183 | // We should not dispatch to `my_concat` when the user calls |
184 | // `tf.concat([])` (even though it's technically true that the empty |
185 | // list satisfies the type annotation `List[MyType]`). |
186 | // |
187 | // @tf.dispatch_for(tf.add)(x=Union[MyType, Tensor], y=Union[MyType, Tensor]) |
188 | // |
189 | // We should not dispatch to `my_add` when the user calls |
190 | // `tf.add(tf.constant(1), tf.constant(2))` (even though this technically |
191 | // matches the annotated types). |
192 | enum class MatchType { NO_MATCH, MATCH, MATCH_DISPATCHABLE }; |
193 | |
194 | // Returns a value indicating how this type checker matched with the given |
195 | // value. |
196 | virtual MatchType Check(PyObject* value) = 0; |
197 | |
198 | // Approximate cost of calling this type checker, so we can perform less |
199 | // expensive checks first. (E.g., checking if every element in a list has a |
200 | // given type is more costly than checking a single value.) |
201 | virtual int cost() const = 0; |
202 | |
203 | virtual std::string DebugString() const = 0; |
204 | }; |
205 | |
206 | // PyTypeChecker that checks if a value is an instance of a given Python type. |
207 | class PyInstanceChecker : public PyTypeChecker { |
208 | public: |
209 | explicit PyInstanceChecker(const std::vector<PyObject*>& py_classes); |
210 | ~PyInstanceChecker() override; |
211 | MatchType Check(PyObject* value) override; |
212 | int cost() const override; |
213 | std::string DebugString() const override; |
214 | |
215 | // Size of the cache (for regression testing). |
216 | size_t cache_size() const { return py_class_cache_.size(); } |
217 | |
218 | private: |
219 | // Python class to check values against. |
220 | std::vector<Safe_PyObjectPtr> py_classes_; |
221 | |
222 | // Cache to avoid having to call PyObject_IsInstance. Note: we rely on the |
223 | // Python GIL (global interpreter lock) to avoid concurrent writes to this |
224 | // cache, since `Check()` is always called from Python (via pybind11). |
225 | absl::flat_hash_map<PyTypeObject*, MatchType> py_class_cache_; |
226 | |
227 | // Maximum cache size. In typical user programs, the cache will never become |
228 | // full, but we use a maximum size in case the user creates types dynamically, |
229 | // to avoid having an unbounded number of items in the cache. |
230 | // TODO(b/194903203) Consider switching to an LRU cache. |
231 | static constexpr int kMaxItemsInCache = 1024; |
232 | }; |
233 | |
234 | // PyTypeChecker that checks if a value is a list whose elements all match a |
235 | // nested PyTypeChecker. |
236 | class PyListChecker : public PyTypeChecker { |
237 | public: |
238 | explicit PyListChecker(PyTypeChecker_ptr element_type) |
239 | : element_type_(element_type) {} |
240 | MatchType Check(PyObject* value) override; |
241 | int cost() const override; |
242 | std::string DebugString() const override; |
243 | |
244 | private: |
245 | PyTypeChecker_ptr element_type_; |
246 | }; |
247 | |
248 | // PyTypeChecker that checks if a value matches at least one nested |
249 | // PyTypeChecker. |
250 | class PyUnionChecker : public PyTypeChecker { |
251 | public: |
252 | explicit PyUnionChecker(std::vector<PyTypeChecker_ptr> options) |
253 | : options_(options) {} |
254 | MatchType Check(PyObject* value) override; |
255 | int cost() const override; |
256 | std::string DebugString() const override; |
257 | |
258 | private: |
259 | std::vector<PyTypeChecker_ptr> options_; |
260 | }; |
261 | |
262 | } // namespace py_dispatch |
263 | } // namespace tensorflow |
264 | |
265 | #endif // TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_API_DISPATCHER_H_ |
266 | |