1/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15// Support for API dispatch at the Python level.
16//
// The dispatcher is implemented in C++ for efficiency.
18//
19// * PythonAPIDispatcher: Class that handles dispatch for a single Python API.
20// Contains a mapping from PySignatureCheckers to dispatch targets (python
21// functions).
22//
23// * PySignatureChecker: Class to efficiently check whether dispatch should be
24// invoked for a given set of parameters. Contains a collection of
25// PyTypeCheckers.
26//
27// * PyTypeChecker: Class to efficiently check whether a Python value matches
28// a type annotation. Three subclasses (PyInstanceChecker, PyListChecker,
29// and PyUnionChecker) handle the different kinds of type annotation.
30
31#ifndef TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_API_DISPATCHER_H_
32#define TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_API_DISPATCHER_H_
33
#include <Python.h>

#include <memory>
#include <string>
#include <utility>
#include <vector>

#include "absl/container/flat_hash_map.h"
#include "tensorflow/python/lib/core/safe_pyobject_ptr.h"
#include "tensorflow/python/util/function_parameter_canonicalizer.h"
42
43namespace tensorflow {
44
45namespace py_dispatch {
46
// Forward declarations: both classes are referred to by declarations that
// appear earlier in this header than their full definitions.
class PySignatureChecker;
class PyTypeChecker;
49
50// Dispatcher for a single TensorFlow Python API (e.g. `tf.add` or `tf.concat`).
51//
52// A separate `PythonAPIDispatcher` object is created for each API, and handles
53// dispatch for that API. The `Register` method can be used to add new
54// "dispatch targets", which override the default behavior of the API when it
55// is called with parameters matching a given signature. The `Dispatch` method
56// checks if any registered target matches parameters, and if so, then calls
57// that target.
58//
59// This class is *not* thread-safe. It is assumed that the Python Global
60// Interpreter Lock (GIL) will be held when any method is called.
61class PythonAPIDispatcher {
62 // TODO(b/196369143) Add benchmarking for this class.
63 public:
64 // Creates a new PythonAPIDispatcher for the named API.
65 //
66 // Args:
67 // api_name: The name of the API (used for error messages).
68 // arg_names: The argument names (used for parameter canonicalization).
69 // defaults: The argument defaults, as returned by `inspect.getargspec`
70 // (used for parameter canonicalization).
71 PythonAPIDispatcher(const std::string& api_name,
72 absl::Span<const char*> arg_names,
73 absl::Span<PyObject*> defaults);
74
75 // Registers a new dispatch target for this dispatcher. If the API is
76 // called with parameters that match `signature_checker`, then
77 // `dispatch_target` will be called instead of the default API implementation.
78 void Register(PySignatureChecker signature_checker,
79 PyObject* dispatch_target);
80
81 // Performs dispatch with the given set of parameters.
82 //
83 // * If a single target matches the parameters, then that target is called.
84 // * If multiple targets match the parameters, then an exception is raised.
85 // * If no targets match the parameters, then returns `Py_NotImplemented`.
86 //
87 // On error, returns nullptr and sets a Python exception.
88 Safe_PyObjectPtr Dispatch(PyObject* args, PyObject* kwargs);
89
90 // Remove a dispatch target from this dispatcher. If the target was
91 // registered with multiple signatures, then all entries will be removed.
92 // (This method is primarily intended for regression tests.)
93 void Unregister(PyObject* func);
94
95 std::string DebugString() const;
96
97 private:
98 // Name of the API.
99 std::string api_name_;
100
101 // Mapping from signature checkers to dispatch targets.
102 std::vector<std::pair<PySignatureChecker, Safe_PyObjectPtr>> targets_;
103
104 // Parameter canonicalizer.
105 FunctionParameterCanonicalizer canonicalizer_;
106
107 // Target storage for canonicalization. (Note: for efficiency, `Dispatch`
108 // writes to this pre-allocated storage, rather than allocating new storage
109 // each time it is called.)
110 std::vector<PyObject*> canonicalized_args_storage_;
111 absl::Span<PyObject*> canonicalized_args_span_;
112};
113
114// Registers a type for use with dispatch. Dispatch will only occur if at least
115// one parameter value matches an annotation corresponding to a registered
116// dispatchable type.
117//
118// Returns true on success; or sets a Python exception and returns false
119// on error.
120//
121// Must be called before any PyInstanceChecker object is created from py_class.
122//
123// (Note: the CompositeTensor class is automatically registered for dispatch,
124// so you do not need to use this method for any class that is a subclass of
125// CompositeTensor or ExtensionType.)
126bool RegisterDispatchableType(PyObject* py_class);
127
128// Class used by dispatch to check if parameters' values match a signature.
129//
130// Currently only supports checking parameters with kind POSITIONAL_ONLY or
131// POSITIONAL_OR_KEYWORD. (Does not support checking parameters with kind
132// VAR_POSITIONAL, VAR_KEYWORD, or KEYWORD_ONLY.)
133class PySignatureChecker {
134 public:
135 // A parameter index and a TypeChecker for the parameter at that index.
136 using ParamChecker = std::pair<int, std::shared_ptr<PyTypeChecker>>;
137
138 // Constructs a signature checker that will check the specified positional
139 // parameters.
140 explicit PySignatureChecker(std::vector<ParamChecker> parameter_checkers);
141
142 // Returns true if the given canonicalized arguments match this signature
143 // checker.
144 bool CheckCanonicalizedArgs(absl::Span<PyObject*> canon_args) const;
145
146 std::string DebugString() const;
147
148 private:
149 // Type checkers for individual parameters. Only annotated parameters will
150 // be checked. This list is sorted to perform less expensive checks first.
151 // E.g., we check simple values before list values.
152 std::vector<ParamChecker> positional_parameter_checkers_;
153};
154
155// Abstract base class that checks if a Python value matches a type annotation.
156//
157// Subclasses of PyTypeChecker are defined for different annotations (List,
158// Union, etc). Currently, we support the minimum set of type checkers that are
159// required for CompositeTensor dispatch -- namely, `List`, `Union`, and simple
160// types (`IsInstance`). Support for additional annotations may be added in the
161// future.
162class PyTypeChecker {
163 public:
164 using PyTypeChecker_ptr = std::shared_ptr<PyTypeChecker>;
165 PyTypeChecker() = default;
166 PyTypeChecker(const PyTypeChecker&) = delete;
167 PyTypeChecker(PyTypeChecker&&) = delete;
168 virtual ~PyTypeChecker() {}
169
170 // Enumeration used to indicate whether a Python value matches a type
171 // annotation. MATCH and NO_MATCH simply indicate whether a value matches the
172 // annotation.
173 //
174 // MATCH_DISPATCHABLE indicates that a value matches the annotation, and
175 // additionally that the value (or one of its nested values) matched a type
176 // that has been registered for dispatch. This is important information
177 // because we only want to perform dispatch if at least one such value
178 // matches. Otherwise, we would end up using dispatch in undesirable cases.
179 // Examples:
180 //
181 // @tf.dispatch_for(tf.concat)(x=List[MyType])
182 //
183 // We should not dispatch to `my_concat` when the user calls
184 // `tf.concat([])` (even though it's technically true that the empty
185 // list satisfies the type annotation `List[MyType]`).
186 //
187 // @tf.dispatch_for(tf.add)(x=Union[MyType, Tensor], y=Union[MyType, Tensor])
188 //
189 // We should not dispatch to `my_add` when the user calls
190 // `tf.add(tf.constant(1), tf.constant(2))` (even though this technically
191 // matches the annotated types).
192 enum class MatchType { NO_MATCH, MATCH, MATCH_DISPATCHABLE };
193
194 // Returns a value indicating how this type checker matched with the given
195 // value.
196 virtual MatchType Check(PyObject* value) = 0;
197
198 // Approximate cost of calling this type checker, so we can perform less
199 // expensive checks first. (E.g., checking if every element in a list has a
200 // given type is more costly than checking a single value.)
201 virtual int cost() const = 0;
202
203 virtual std::string DebugString() const = 0;
204};
205
206// PyTypeChecker that checks if a value is an instance of a given Python type.
207class PyInstanceChecker : public PyTypeChecker {
208 public:
209 explicit PyInstanceChecker(const std::vector<PyObject*>& py_classes);
210 ~PyInstanceChecker() override;
211 MatchType Check(PyObject* value) override;
212 int cost() const override;
213 std::string DebugString() const override;
214
215 // Size of the cache (for regression testing).
216 size_t cache_size() const { return py_class_cache_.size(); }
217
218 private:
219 // Python class to check values against.
220 std::vector<Safe_PyObjectPtr> py_classes_;
221
222 // Cache to avoid having to call PyObject_IsInstance. Note: we rely on the
223 // Python GIL (global interpreter lock) to avoid concurrent writes to this
224 // cache, since `Check()` is always called from Python (via pybind11).
225 absl::flat_hash_map<PyTypeObject*, MatchType> py_class_cache_;
226
227 // Maximum cache size. In typical user programs, the cache will never become
228 // full, but we use a maximum size in case the user creates types dynamically,
229 // to avoid having an unbounded number of items in the cache.
230 // TODO(b/194903203) Consider switching to an LRU cache.
231 static constexpr int kMaxItemsInCache = 1024;
232};
233
234// PyTypeChecker that checks if a value is a list whose elements all match a
235// nested PyTypeChecker.
236class PyListChecker : public PyTypeChecker {
237 public:
238 explicit PyListChecker(PyTypeChecker_ptr element_type)
239 : element_type_(element_type) {}
240 MatchType Check(PyObject* value) override;
241 int cost() const override;
242 std::string DebugString() const override;
243
244 private:
245 PyTypeChecker_ptr element_type_;
246};
247
248// PyTypeChecker that checks if a value matches at least one nested
249// PyTypeChecker.
250class PyUnionChecker : public PyTypeChecker {
251 public:
252 explicit PyUnionChecker(std::vector<PyTypeChecker_ptr> options)
253 : options_(options) {}
254 MatchType Check(PyObject* value) override;
255 int cost() const override;
256 std::string DebugString() const override;
257
258 private:
259 std::vector<PyTypeChecker_ptr> options_;
260};
261
262} // namespace py_dispatch
263} // namespace tensorflow
264
265#endif // TENSORFLOW_PYTHON_FRAMEWORK_PYTHON_API_DISPATCHER_H_
266