1#pragma once
2#include <iostream>
3
4// note: windows build doesn't find symbols in operator files unless
5// this is a header file
6
7namespace c10 {
8
9inline std::ostream& operator<<(std::ostream& out, const FunctionSchema& schema) {
10 // eventually this should look almost identical to python arg parser, but
11 // it is simpler for now to work directly on this schema
12
13 out << schema.name();
14 if (!schema.overload_name().empty()) {
15 out << "." << schema.overload_name();
16 }
17 out << "(";
18
19 bool seen_kwarg_only = false;
20 for (const auto i : c10::irange(schema.arguments().size())) {
21 if (i > 0) out << ", ";
22 if (schema.arguments()[i].kwarg_only() && !seen_kwarg_only) {
23 out << "*, ";
24 seen_kwarg_only = true;
25 }
26 out << schema.arguments()[i];
27 }
28
29 if(schema.is_vararg()) {
30 if(!schema.arguments().empty())
31 out << ", ";
32 out << "...";
33 }
34
35 out << ") -> ";
36
37 const auto& returns = schema.returns();
38
39 /*
40 * We should skip parenthesis if we return a single item and it's not varret,
41 * or we return nothing but varret.
42 *
43 * Need special handling for schema
44 * aten::items.str(Dict(str, t) self) -> (str,t)[]
45 * Even though this schema returns a single item, we need add parenthesis.
46 * The is necessary so the printed schema can be parsed by the C++ SchemaParser
47 * Without the extra parenthesis, the parser sees the first parenthesis in '(str,t)' and mistakenly
48 * treat the return type as a tuple. An alternative is to enhance the Lexer
49 * to lookahead multiple tokens to accurately decide if the return type is
50 * a tuple.
51 */
52 bool need_paren = !(
53 (returns.size() == 1 && !schema.is_varret()) ||
54 (returns.empty() && schema.is_varret()));
55
56 if (returns.size() == 1 && !schema.is_varret()) {
57 std::stringstream return_ss;
58 return_ss << returns.at(0);
59 auto return_str = return_ss.str();
60
61 // enclosing the single return item with parenthesis if the return type
62 // starts with a left parenthesis.
63 //
64 // There are 2 cases
65 // 1. something like 'aten::items.str(Dict(str, t) self) -> ((str, t)[])'.
66 // without the extra parenthesis, the c++ schem parser can not parse it.
67 // 2. something like '-> ((str, str))'. Need extra parenthesis so the return
68 // type is a single tuple rather than two strings.
69 // PR (https://github.com/pytorch/pytorch/pull/23204) has more context about
70 // this. test_serialize_and_deserialize (https://github.com/pytorch/pytorch/blob/master/test/test_function_schema.py#L15)
71 // also covers this case.
72 if (!return_str.empty() && return_str.front() == '(') {
73 need_paren = true;
74 }
75 }
76
77 if (need_paren) {
78 out << "(";
79 }
80 for (const auto i : c10::irange(returns.size())) {
81 if (i > 0) {
82 out << ", ";
83 }
84 out << returns.at(i);
85 }
86 if (schema.is_varret()) {
87 if (!returns.empty()) {
88 out << ", ";
89 }
90 out << "...";
91 }
92 if (need_paren) {
93 out << ")";
94 }
95 return out;
96}
97
98inline size_t findFirstOutArg(const std::vector<Argument>& args) {
99 // find the start of out args in the schema
100 for (const auto out_start_idx : c10::irange(args.size())) {
101 if (args.at(out_start_idx).is_out()) {
102 return out_start_idx;
103 }
104 }
105 return args.size();
106}
107
108inline bool Argument::isBackwardCompatibleWith(
109 const Argument& old,
110 std::ostream* why_not) const {
111 const Argument* lhs = this;
112 const Argument* rhs = &old;
113 if (!(lhs->name() == rhs->name()
114 && lhs->N() == rhs->N()
115 && (lhs->alias_info() == rhs->alias_info()
116 || (lhs->alias_info() != nullptr && rhs->alias_info() != nullptr
117 && *lhs->alias_info() == *rhs->alias_info())))) {
118 return false;
119 }
120 if (lhs->kwarg_only() && !rhs->kwarg_only()) {
121 return false;
122 }
123 if (!rhs->type()->isSubtypeOfExt(*lhs->type(), why_not)) {
124 return false;
125 }
126 if (rhs->default_value().has_value() &&
127 lhs->default_value() != rhs->default_value()) {
128 return false;
129 }
130 return true;
131}
132
133inline bool Argument::isForwardCompatibleWith(
134 const Argument& old,
135 std::ostream* why_not) const {
136 const Argument* lhs = this;
137 const Argument* rhs = &old;
138 if (!(lhs->name() == rhs->name()
139 && lhs->N() == rhs->N()
140 && (lhs->alias_info() == rhs->alias_info()
141 || (lhs->alias_info() != nullptr && rhs->alias_info() != nullptr
142 && *lhs->alias_info() == *rhs->alias_info())))) {
143 return false;
144 }
145 if (lhs->kwarg_only() && !rhs->kwarg_only()) {
146 return false;
147 }
148 if (!lhs->type()->isSubtypeOfExt(rhs->type(), why_not)) {
149 return false;
150 }
151 if (rhs->default_value().has_value() &&
152 lhs->default_value() != rhs->default_value()) {
153 return false;
154 }
155 if (lhs->default_value().has_value() && !rhs->default_value().has_value()) {
156 return false;
157 }
158 return true;
159}
160
161inline std::string FunctionSchema::formatTypeMismatchMsg(
162 const Argument& expected,
163 const std::string& actual_type,
164 c10::optional<size_t> position,
165 c10::optional<std::string> value) const {
166 std::string position_str;
167 if (position) {
168 position_str = c10::str("Position: ", *position, "\n");
169 }
170 std::string value_str;
171 if (value) {
172 value_str = c10::str("Value: ", *value, "\n");
173 }
174 return c10::str(
175 name(),
176 "() ",
177 expected.formatTypeMismatchMsg(actual_type),
178 position_str,
179 value_str,
180 "Declaration: ",
181 *this);
182}
183
184inline bool FunctionSchema::isBackwardCompatibleWith(
185 const FunctionSchema& old,
186 std::ostream* why_not) const {
187 if (!(name() == old.name()
188 && overload_name() == old.overload_name()
189 // we are conservative on is_vararg and is_varret,
190 // since they are only used by internal operators
191 && is_vararg() == old.is_vararg()
192 && is_varret() == old.is_varret()
193 && returns().size() == old.returns().size()
194 && arguments().size() >= old.arguments().size())) {
195 return false;
196 }
197 for (const auto i : c10::irange(returns().size())) {
198 // Backwards compatibility requires covariance on argument types
199 // (i.e. more generic), and contravariance on return types (i.e.
200 // more specific).
201 if (!old.returns().at(i).isBackwardCompatibleWith(
202 returns().at(i),
203 why_not)) {
204 return false;
205 }
206 }
207
208 // we want to test both out and default args separately
209 size_t old_out_start_idx = findFirstOutArg(old.arguments());
210 size_t new_out_start_idx = findFirstOutArg(arguments());
211
212 // make sure among the default args, they are backward compatible
213 for (const auto i : c10::irange(old_out_start_idx)) {
214 if (!arguments().at(i).isBackwardCompatibleWith(
215 old.arguments().at(i), why_not)) {
216 return false;
217 }
218 }
219
220 // Validate that all new arguments provided has a default value
221 for (const auto i : c10::irange(old_out_start_idx, new_out_start_idx)) {
222 if (!arguments().at(i).default_value()) {
223 if (why_not) {
224 *why_not
225 << "Function schema not backward compatible since the new argument '"
226 << arguments().at(i).name() << "' of type "
227 << arguments().at(i).type()->str()
228 << " did not provide a default value.";
229 }
230 return false;
231 }
232 }
233
234 // now compare the out args
235 for (const auto i : c10::irange(old_out_start_idx, old.arguments().size())) {
236 if (!arguments()
237 .at(i - old_out_start_idx + new_out_start_idx)
238 .isBackwardCompatibleWith(old.arguments().at(i), why_not)) {
239 return false;
240 }
241 }
242
243 return true;
244}
245
246inline bool FunctionSchema::isForwardCompatibleWith(
247 const FunctionSchema& old,
248 std::ostringstream& why_not) const {
249 if (!(name() == old.name() &&
250 overload_name() == old.overload_name()
251 // we are conservative on is_vararg and is_varret,
252 // since they are only used by internal operators
253 && is_vararg() == old.is_vararg() && is_varret() == old.is_varret() &&
254 returns().size() == old.returns().size())) {
255 return false;
256 }
257
258 // we want to test both out and default args separately
259 size_t old_out_start_idx = findFirstOutArg(old.arguments());
260 size_t new_out_start_idx = findFirstOutArg(arguments());
261
262 if (old.arguments().size() - old_out_start_idx !=
263 arguments().size() - new_out_start_idx) {
264 if (why_not) {
265 why_not << "Function schema should have the "
266 << "same number of out arguments";
267 }
268 return false;
269 }
270
271 // make sure among the default args, they are forward compatible
272 for (size_t i = 0; i < std::min(old_out_start_idx, new_out_start_idx); i++) {
273 if (!arguments().at(i).isForwardCompatibleWith(old.arguments().at(i))) {
274 if (why_not) {
275 why_not
276 << "'" << arguments().at(i).name() << "'"
277 << " is not forward compatible with the older version of the schema";
278 }
279 return false;
280 }
281 }
282
283 // Validate that all new arguments provided has a default value
284 for (size_t i = old_out_start_idx; i < new_out_start_idx; ++i) {
285 if (!arguments().at(i).default_value()) {
286 if (why_not) {
287 why_not
288 << "Function schema is not forward compatible since the new argument '"
289 << arguments().at(i).name() << "' of type "
290 << arguments().at(i).type()->str()
291 << " did not provide a default value.";
292 }
293 return false;
294 }
295
296 auto default_val = arguments().at(i).default_value().value();
297 if (default_val.isList() || default_val.isGenericDict()) {
298 if (why_not) {
299 why_not
300 << "Function schema is not forward compatible since the new argument '"
301 << arguments().at(i).name() << "' of type "
302 << arguments().at(i).type()->str() << " has a container type "
303 << "as its default value.";
304 }
305 return false;
306 }
307 }
308
309 // now compare the out args
310 for (size_t i = old_out_start_idx; i < old.arguments().size(); i++) {
311 if (!arguments()
312 .at(i - old_out_start_idx + new_out_start_idx)
313 .isForwardCompatibleWith(old.arguments().at(i))) {
314 if (why_not) {
315 why_not << "Out argument '"
316 << "'" << arguments().at(i).name()
317 << " is not FC with the older version of the schema";
318 }
319 return false;
320 }
321 }
322
323 return true;
324}
325
326template<typename T>
327inline void FunctionSchema::checkArg(
328 const IValue& value,
329 const Argument& argument,
330 optional<size_t> pos) const {
331 if (value.isTensor() && argument.type() == TensorType::get()) {
332 // Fast-path for the common case
333 return;
334 }
335 if (!value.type<T>()->isSubtypeOf(*argument.type())) {
336 TORCH_CHECK(
337 false,
338 formatTypeMismatchMsg(
339 argument, value.type<T>()->repr_str(), pos));
340 }
341}
342
343inline std::string FunctionSchema::findErrorInKwargs(const std::vector<std::string>& kwargs) const {
344 // First check if any of the kwargs are unknown, i.e. don't match the name of
345 // any argument in the schema.
346 for (const auto& kwarg : kwargs) {
347 if (!std::count_if(
348 arguments().begin(),
349 arguments().end(),
350 [&kwarg](const Argument& argument) {
351 return argument.name() == kwarg;
352 })) {
353 return c10::str(
354 "Unknown keyword argument '",
355 kwarg,
356 "' for operator '",
357 name(),
358 "'. Schema: ",
359 *this);
360 }
361 }
362 // If there are unconsumed kwargs but none of them were unknown, the first
363 // positional argument present in the kwargs is duplicated.
364 for (const auto& argument : arguments()) {
365 if (std::find(kwargs.begin(), kwargs.end(), argument.name()) != kwargs.end()) {
366 AT_ASSERT(!argument.default_value());
367 return c10::str(
368 "Argument '",
369 argument.name(),
370 "' specified both as positional and ",
371 "keyword argument. Schema: ",
372 *this);
373 }
374 }
375 return "";
376}
377
378template <typename T>
379inline void FunctionSchema::checkAndNormalizeInputs(
380 std::vector<IValue>& inputs,
381 const std::unordered_map<std::string, IValue>& kwargs) const {
382 // Do we have more inputs than the schema accepts?
383 TORCH_CHECK(
384 inputs.size() <= arguments().size(),
385 "Expected at most ",
386 arguments().size(),
387 " argument(s) for operator '",
388 name(),
389 "', but received ",
390 inputs.size(),
391 " argument(s). Declaration: ",
392 *this);
393
394 size_t consumed_kwargs = 0;
395 for (const auto pos : c10::irange(arguments().size())) {
396 const auto& argument = arguments()[pos];
397 if (pos < inputs.size()) {
398 checkArg<T>(inputs[pos], argument, pos);
399 continue;
400 }
401 auto it = kwargs.find(argument.name());
402 if (it != kwargs.end()) {
403 checkArg<T>(it->second, argument, nullopt);
404 inputs.push_back(it->second);
405 consumed_kwargs++;
406 continue;
407 }
408 if (argument.default_value()) {
409 inputs.push_back(*argument.default_value());
410 continue;
411 }
412 AT_ERROR(
413 name(),
414 "() is missing value for argument '",
415 argument.name(),
416 "'. Declaration: ",
417 *this);
418 }
419 if (consumed_kwargs != kwargs.size()) {
420 std::vector<std::string> names;
421 names.reserve(kwargs.size());
422 for(const auto& k : kwargs) {
423 names.emplace_back(k.first);
424 }
425 throw std::runtime_error(findErrorInKwargs(names));
426 }
427}
428
429inline FunctionSchema FunctionSchema::cloneWithRemappedTypes(
430 const std::function<TypePtr(TypePtr)> type_map) const {
431 auto update_args = [&](const std::vector<Argument>& args) {
432 std::vector<Argument> new_args;
433 new_args.reserve(args.size());
434 for(const Argument& arg : args) {
435 new_args.emplace_back(arg.cloneWithType(type_map(arg.type())));
436 }
437 return new_args;
438 };
439 return FunctionSchema(
440 name(),
441 overload_name(),
442 update_args(arguments()),
443 update_args(returns()),
444 is_vararg(),
445 is_varret());
446}
447
448// covariant subtyping of list of Arguments
449inline bool isSubtypeOfList(
450 ArrayRef<Argument> child,
451 ArrayRef<Argument> parent,
452 std::ostream* why_not) {
453 if (child.size() != parent.size()) {
454 return false;
455 }
456 for (const auto i : c10::irange(child.size())) {
457 const Argument& c = child[i];
458 const Argument& p = parent[i];
459 if (c.name() != p.name()) {
460 return false;
461 }
462 if (!c.type()->isSubtypeOfExt(*p.type(), why_not)) {
463 return false;
464 }
465 }
466 return true;
467}
468
469inline bool FunctionSchema::isSubtypeOf(
470 const FunctionSchema& rhs,
471 bool as_method,
472 std::ostream* why_not) const {
473 size_t start = as_method ? 1 : 0;
474 // functions are contravariant in arguments but covariant in returns
475 return isSubtypeOfList(
476 ArrayRef<Argument>(rhs.arguments()).slice(start),
477 ArrayRef<Argument>(arguments()).slice(start),
478 why_not) &&
479 isSubtypeOfList(returns(), rhs.returns(), why_not);
480}
481
482} // namespace c10
483