1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#ifndef TVM_TIR_SCHEDULE_INSTRUCTION_TRAITS_H_
20#define TVM_TIR_SCHEDULE_INSTRUCTION_TRAITS_H_
21
22#include <tvm/tir/schedule/instruction.h>
23#include <tvm/tir/schedule/schedule.h>
24
25#include <algorithm>
26#include <sstream>
27#include <string>
28#include <utility>
29#include <vector>
30
31namespace tvm {
32namespace tir {
33
/*!
 * \brief Register an InstructionKind using a trait class
 * \param InstructionKindTraits A traits class of an InstructionKind. It must provide the static
 *        members `kName`, `kIsPure`, `ApplyToSchedule`, `AsPython`, `AttrsAsJSON` and
 *        `AttrsFromJSON`. The two JSON hooks may be `nullptr` when no customized
 *        (de)serialization is needed, as is the default in `UnpackedInstTraits`.
 *
 * Example:
 *
 * \code
 *
 * struct SomeInstructionKindTraits {
 *   static constexpr const char* kName = "name-of-the-instruction";
 *   static constexpr bool kIsPure = false;
 *
 *   // Convertible to `InstructionKindNode::FInstructionApply`
 *   static Array<ObjectRef> ApplyToSchedule(
 *       const tir::Schedule& sch,
 *       const Array<ObjectRef>& inputs,
 *       const Array<ObjectRef>& attrs,
 *       const Optional<ObjectRef>& decision);
 *
 *   // Convertible to `InstructionKindNode::FInstructionAsPython`
 *   static String AsPython(
 *       const Array<ObjectRef>& inputs,
 *       const Array<ObjectRef>& attrs,
 *       const Optional<ObjectRef>& decision,
 *       const Array<String>& outputs);
 *
 *   // Convertible to `InstructionKindNode::FInstructionAttrsAsJSON`
 *   static ObjectRef AttrsAsJSON(
 *       const Array<ObjectRef>& attrs);
 *
 *   // Convertible to `InstructionKindNode::FInstructionAttrsFromJSON`
 *   static Array<ObjectRef> AttrsFromJSON(
 *       const ObjectRef& attrs_record);
 * };
 *
 * TVM_REGISTER_INST_KIND_TRAITS(SomeInstructionKindTraits);
 *
 * \endcode
 */
#define TVM_REGISTER_INST_KIND_TRAITS(InstructionKindTraits)          \
  TVM_REGISTER_INST_KIND(InstructionKindTraits::kName)                \
      .set_is_pure(InstructionKindTraits::kIsPure)                    \
      .set_apply_to_schedule(InstructionKindTraits::ApplyToSchedule)  \
      .set_attrs_as_json(InstructionKindTraits::AttrsAsJSON)          \
      .set_attrs_from_json(InstructionKindTraits::AttrsFromJSON)      \
      .set_as_python(InstructionKindTraits::AsPython)
80
81/*!
82 * \brief A helper to conveniently register an InstructionKind. When inherited in curiously
83 * recursive template pattern, the derived class `TTraits` only needs to define two functions on the
84 * unpacked inputs, and the helper handles unpacking and downcasting. See the example for more
85 * details.
86 *
87 * \tparam TTraits The derived class
88 *
89 * Example:
90 *
91 * \code
92 *
93 * struct SamplePerfectTileTraits : public UnpackedInstTraits<SamplePerfectTileTraits> {
94 * // The name of this kind of instruction
95 * static constexpr const char* kName = "SamplePerfectTile";
96 * // A boolean indicating if the instruction is pure, i.e. change nothing in the schedule state
97 * static constexpr bool kIsPure = true;
98 * // The number of inputs in this kind of instruction
99 * static constexpr size_t kNumInputs = 1;
100 * // The number of attributes in this kind of instruction
101 * static constexpr size_t kNumAttrs = 2;
102 * // The number of decisions in this kind of instruction (only 0 or 1 is allowed)
103 * static constexpr size_t kNumDecisions = 1;
104 *
105 * // Calling convention:
106 * // - All the arguments must be ObjectRef
107 * // - The 1st argument is Schedule
108 * // - The next `kNumInputs` arguments are input random variables
109 * // - The next `kNumAttrs` arguments are attributes
110 * // - The next argument is decision, if `kNumDecisions == 1`
111 * static Array<Var> UnpackedApplyToSchedule(
112 * Schedule sch,
113 * LoopRV loop_rv,
114 * Integer n,
115 * Integer max_innermost_factor,
116 * Optional<Array<Integer>> decision) {
117 * return sch->SamplePerfectTile(loop_rv, n->value, max_innermost_factor->value, decision);
118 * }
119 *
120 * // Calling convention:
121 * // - All the arguments must be ObjectRef
122 * // - The 1st argument is an array containing names of output random variables
123 * // - The next `kNumInputs` arguments are names of input random variables
124 * // - The next `kNumAttrs` arguments are attributes
125 * // - The next argument is decision, if `kNumDecisions == 1`
126 * static String UnpackedAsPython(
127 * Array<String> outputs,
128 * String loop_rv,
129 * Integer n,
130 * Integer max_innermost_factor,
131 * Optional<Array<Integer>> decision) {
132 * PythonAPICall py("sample_perfect_tile");
133 * py.Input("loop", loop_rv);
134 * py.Input("n", n->value);
135 * py.Input("max_innermost_factor", max_innermost_factor->value);
136 * py.Decision(decision);
137 * py.OutputList(outputs);
138 * return py.Str();
139 * }
140 *
141 * template <typename>
142 * friend struct UnpackedInstTraits;
143 * };
144 *
145 * TVM_REGISTER_INST_KIND(SamplePerfectTileTraits);
146 * \endcode
147 */
148template <class TTraits>
149struct UnpackedInstTraits {
150 /*!
151 * \brief Unpack the arguments in the calling convention, and feed them into
152 * `TTraits::UnpackedApplyToSchedule`
153 * \sa InstructionKindNode::f_apply_to_schedule
154 */
155 static Array<ObjectRef> ApplyToSchedule(const Schedule& sch, const Array<ObjectRef>& inputs,
156 const Array<ObjectRef>& attrs,
157 const Optional<ObjectRef>& decision);
158
159 /*!
160 * \brief Unpack the arguments in the calling convention, and feed them into
161 * `TTraits::UnpackedAsPython`
162 * \sa InstructionKindNode::f_as_python
163 */
164 static String AsPython(const Array<ObjectRef>& inputs, const Array<ObjectRef>& attrs,
165 const Optional<ObjectRef>& decision, const Array<String>& outputs);
166
167 /*! \brief No customized serializer by default */
168 static constexpr std::nullptr_t AttrsAsJSON = nullptr;
169
170 /*! \brief No customized deserializer by default */
171 static constexpr std::nullptr_t AttrsFromJSON = nullptr;
172
173 protected:
174 template <size_t index_offset>
175 static TVM_ALWAYS_INLINE void _SetInputs(const runtime::TVMArgsSetter& setter,
176 const Array<ObjectRef>& inputs);
177 template <size_t index_offset>
178 static TVM_ALWAYS_INLINE void _SetAttrs(const runtime::TVMArgsSetter& setter,
179 const Array<ObjectRef>& attrs);
180 template <size_t index_offset>
181 static TVM_ALWAYS_INLINE void _SetDecision(const runtime::TVMArgsSetter& setter,
182 const Optional<ObjectRef>& decision);
183 static TVM_ALWAYS_INLINE Array<ObjectRef> _ConvertOutputs(const TVMRetValue& rv);
184};
185
186/*!
187 * \brief A helper class that constructs schedule API call in python syntax,
188 * which helps convert an Inst to a python statement.
189 * \sa InstructionKindNode::f_as_python
190 */
191class PythonAPICall {
192 public:
193 /*!
194 * \brief Constructor
195 * \param method_name The name of the schedule API to be called
196 */
197 explicit PythonAPICall(String method_name) : method_name_(method_name), output_(NullOpt) {}
198 /*! \brief Add an integer input */
199 inline void Input(String arg_name, int arg);
200 /*! \brief Add an integer input */
201 inline void Input(String arg_name, int64_t arg);
202 /*! \brief Add a bool input */
203 inline void Input(String arg_name, bool arg);
204 /*! \brief Add a double input */
205 inline void Input(String arg_name, double arg);
206 /*! \brief Add an input random variable */
207 inline void Input(String arg_name, String arg);
208 /*! \brief Add an input, dispatched to different implementations according to the object's type */
209 inline void Input(String arg_name, ObjectRef arg);
210 /*! \brief Add the decision */
211 inline void Decision(ObjectRef decision);
212 /*!
213 * \brief Add a single output random variable
214 * \param unit_array An array containing only one element
215 */
216 inline void SingleOutput(Array<String> unit_array);
217 /*! \brief Add a list of output random variables */
218 inline void OutputList(Array<String> outputs);
219 /*! \returns The schedule API call in python syntax */
220 inline String Str() const;
221
222 private:
223 /*! \brief Converts a TVM object to python string and print to the output stream */
224 inline void AsPythonString(const ObjectRef& obj, std::ostream& os);
225
226 private:
227 /*! \brief The name of the API to call */
228 String method_name_;
229 /*! \brief The output of the instruction */
230 Optional<String> output_;
231 /*! \brief The names of input arguments */
232 std::vector<String> arg_names_;
233 /*! \brief The values of input arguments */
234 std::vector<String> args_;
235};
236
237/********** implementation details **********/
238
239// forward declaration
240namespace details {
241
242template <typename... Args>
243struct _ArgsPacker;
244
245template <>
246struct _ArgsPacker<> {
247 static constexpr bool checked = true;
248};
249
250template <typename TObjectRef, typename... Args>
251struct _ArgsPacker<TObjectRef, Args...> {
252 static constexpr bool checked =
253 std::is_base_of<ObjectRef, TObjectRef>::value && _ArgsPacker<Args...>::checked;
254};
255
256template <typename T>
257struct _MethodType {};
258
259template <typename TReturn, typename... Args>
260struct _MethodType<TReturn(Args...)> {
261 using return_type = TReturn;
262 using argument_type = _ArgsPacker<Args...>;
263};
264
265template <typename T>
266struct _NumArgs {};
267
268template <typename TReturn, typename... Args>
269struct _NumArgs<TReturn(Args...)> {
270 static constexpr size_t value = sizeof...(Args);
271};
272
273template <typename>
274struct _IsTVMArray : std::false_type {};
275
276template <typename T>
277struct _IsTVMArray<runtime::Array<T>> : std::true_type {};
278
279template <typename T>
280struct _IsSingleObject
281 : std::integral_constant<bool, std::is_base_of<ObjectRef, T>::value && !_IsTVMArray<T>::value> {
282};
283
284template <class T>
285using ReturnType = typename _MethodType<std::remove_cv_t<T>>::return_type;
286
287template <class T>
288static constexpr bool ArgumentAreAllObjects =
289 _MethodType<std::remove_cv_t<T>>::argument_type::checked;
290
291template <class T>
292static constexpr size_t NumArgs = _NumArgs<std::remove_cv_t<T>>::value;
293
294template <class T>
295static constexpr int IsTVMArray = _IsTVMArray<std::remove_cv_t<T>>::value;
296
297template <class T>
298static constexpr int IsSingleObject = _IsSingleObject<std::remove_cv_t<T>>::value;
299
300}; // namespace details
301
302template <class TTraits>
303Array<ObjectRef> UnpackedInstTraits<TTraits>::ApplyToSchedule(const Schedule& sch,
304 const Array<ObjectRef>& inputs,
305 const Array<ObjectRef>& attrs,
306 const Optional<ObjectRef>& decision) {
307 using method_type = decltype(TTraits::UnpackedApplyToSchedule);
308 using return_type = details::ReturnType<method_type>;
309 static_assert(details::ArgumentAreAllObjects<method_type>,
310 "All arguments to `UnpackedApplyToSchedule` must be subclasses of ObjectRef");
311 constexpr size_t kNumArgs = details::NumArgs<method_type>;
312 constexpr size_t kNumInputs = TTraits::kNumInputs;
313 constexpr size_t kNumAttrs = TTraits::kNumAttrs;
314 constexpr size_t kNumDecisions = TTraits::kNumDecisions;
315 static_assert(kNumArgs == 1 + kNumInputs + kNumAttrs + kNumDecisions,
316 "length of argument list mismatch");
317 TVMValue tvm_values[kNumArgs];
318 int tvm_type_codes[kNumArgs];
319 runtime::TVMArgsSetter setter(tvm_values, tvm_type_codes);
320 setter(0, sch);
321 TTraits::template _SetInputs<1>(setter, inputs);
322 TTraits::template _SetAttrs<1 + kNumInputs>(setter, attrs);
323 TTraits::template _SetDecision<1 + kNumInputs + kNumAttrs>(setter, decision);
324 PackedFunc pf([](const TVMArgs& args, TVMRetValue* rv) -> void {
325 using runtime::detail::unpack_call;
326 constexpr size_t kNumArgs = details::NumArgs<method_type>;
327 ICHECK_EQ(args.size(), kNumArgs);
328 unpack_call<return_type, kNumArgs>(nullptr, TTraits::UnpackedApplyToSchedule, args, rv);
329 });
330 TVMRetValue rv;
331 pf.CallPacked(TVMArgs(tvm_values, tvm_type_codes, kNumArgs), &rv);
332 return TTraits::_ConvertOutputs(rv);
333}
334
335template <class TTraits>
336String UnpackedInstTraits<TTraits>::AsPython(const Array<ObjectRef>& inputs,
337 const Array<ObjectRef>& attrs,
338 const Optional<ObjectRef>& decision,
339 const Array<String>& outputs) {
340 using method_type = decltype(TTraits::UnpackedAsPython);
341 using return_type = details::ReturnType<method_type>;
342 static_assert(details::ArgumentAreAllObjects<method_type>,
343 "All arguments to `UnpackedAsPython` must be subclasses of ObjectRef");
344 constexpr size_t kNumArgs = details::NumArgs<method_type>;
345 constexpr size_t kNumInputs = TTraits::kNumInputs;
346 constexpr size_t kNumAttrs = TTraits::kNumAttrs;
347 constexpr size_t kNumDecisions = TTraits::kNumDecisions;
348 static_assert(kNumArgs == 1 + kNumInputs + kNumAttrs + kNumDecisions,
349 "length of argument list mismatch");
350 TVMValue tvm_values[kNumArgs];
351 int tvm_type_codes[kNumArgs];
352 runtime::TVMArgsSetter setter(tvm_values, tvm_type_codes);
353 setter(0, outputs);
354 TTraits::template _SetInputs<1>(setter, inputs);
355 TTraits::template _SetAttrs<1 + kNumInputs>(setter, attrs);
356 TTraits::template _SetDecision<1 + kNumInputs + kNumAttrs>(setter, decision);
357 PackedFunc pf([](const TVMArgs& args, TVMRetValue* rv) -> void {
358 using runtime::detail::unpack_call;
359 constexpr size_t kNumArgs = details::NumArgs<method_type>;
360 ICHECK_EQ(args.size(), kNumArgs);
361 unpack_call<return_type, kNumArgs>(nullptr, TTraits::UnpackedAsPython, args, rv);
362 });
363 TVMRetValue rv;
364 pf.CallPacked(TVMArgs(tvm_values, tvm_type_codes, kNumArgs), &rv);
365 String result = rv;
366 return result;
367}
368
369template <class TTraits>
370template <size_t index_offset>
371TVM_ALWAYS_INLINE void UnpackedInstTraits<TTraits>::_SetInputs(const runtime::TVMArgsSetter& setter,
372 const Array<ObjectRef>& inputs) {
373 constexpr size_t kNumInputs = TTraits::kNumInputs;
374 ICHECK_EQ(kNumInputs, inputs.size())
375 << "ValueError: Incorrect kNumInputs for instruction: " << TTraits::kName;
376 const ObjectRef* ptr = inputs.template as<ArrayNode>()->begin();
377 for (size_t i = 0; i < kNumInputs; ++i) {
378 setter(i + index_offset, *(ptr + i));
379 }
380}
381
382template <class TTraits>
383template <size_t index_offset>
384TVM_ALWAYS_INLINE void UnpackedInstTraits<TTraits>::_SetAttrs(const runtime::TVMArgsSetter& setter,
385 const Array<ObjectRef>& attrs) {
386 constexpr size_t kNumAttrs = TTraits::kNumAttrs;
387 ICHECK_EQ(kNumAttrs, attrs.size())
388 << "ValueError: Incorrect kNumAttrs for instruction: " << TTraits::kName;
389 const ObjectRef* ptr = attrs.as<ArrayNode>()->begin();
390 for (size_t i = 0; i < kNumAttrs; ++i) {
391 setter(i + index_offset, *(ptr + i));
392 }
393}
394
395template <class TTraits>
396template <size_t index_offset>
397TVM_ALWAYS_INLINE void UnpackedInstTraits<TTraits>::_SetDecision(
398 const runtime::TVMArgsSetter& setter, const Optional<ObjectRef>& decision) {
399 constexpr size_t kNumDecisions = TTraits::kNumDecisions;
400 static_assert(kNumDecisions <= 1, "an instruction is supposed to have at most 1 decision");
401 if (kNumDecisions == 1) {
402 setter(index_offset, decision);
403 } else {
404 ICHECK(!decision.defined());
405 }
406}
407
408template <class TTraits>
409TVM_ALWAYS_INLINE Array<ObjectRef> UnpackedInstTraits<TTraits>::_ConvertOutputs(
410 const TVMRetValue& rv) {
411 using method_type = decltype(TTraits::UnpackedApplyToSchedule);
412 using return_type = details::ReturnType<method_type>;
413 constexpr int is_array = details::IsTVMArray<return_type>;
414 constexpr int is_single_obj = details::IsSingleObject<return_type>;
415 constexpr int is_void = std::is_void<return_type>::value;
416 static_assert(is_array || is_single_obj || is_void, "return type not supported");
417 static_assert(is_array + is_single_obj + is_void == 1, "internal template error");
418 if (is_void) {
419 return {};
420 } else if (is_single_obj) {
421 ObjectRef obj = rv;
422 return {obj};
423 } else if (is_array) {
424 ObjectRef obj = rv;
425 const ArrayNode* array = obj.as<ArrayNode>();
426 return GetRef<Array<ObjectRef>>(array);
427 }
428}
429
430/********** PythonAPICall **********/
431
432inline void PythonAPICall::AsPythonString(const ObjectRef& obj, std::ostream& os) {
433 if (!obj.defined()) {
434 os << "None";
435 } else if (const auto* str = obj.as<runtime::StringObj>()) {
436 os << str->data;
437 } else if (const auto* int_imm = obj.as<IntImmNode>()) {
438 os << int_imm->value;
439 } else if (const auto* float_imm = obj.as<FloatImmNode>()) {
440 os.precision(17);
441 os << float_imm->value;
442 } else if (const auto* array = obj.as<ArrayNode>()) {
443 os << '[';
444 bool is_first = true;
445 for (const ObjectRef& e : *array) {
446 if (is_first) {
447 is_first = false;
448 } else {
449 os << ", ";
450 }
451 AsPythonString(e, os);
452 }
453 os << ']';
454 } else if (const auto* dict = obj.as<MapNode>()) {
455 os << '{';
456 bool is_first = true;
457 std::vector<std::pair<std::string, std::string>> dict_items;
458 for (auto it = dict->begin(); it != dict->end(); ++it) {
459 std::ostringstream ks;
460 AsPythonString(it->first, ks);
461 std::ostringstream vs;
462 AsPythonString(it->second, vs);
463 dict_items.emplace_back(ks.str(), vs.str());
464 }
465 std::sort(dict_items.begin(), dict_items.end(),
466 [](const auto& p1, const auto& p2) { return p1.first < p2.first; });
467 for (const auto& kv : dict_items) {
468 if (is_first) {
469 is_first = false;
470 } else {
471 os << ", ";
472 }
473 os << '\"' << kv.first << "\": " << kv.second;
474 }
475 os << '}';
476 } else {
477 LOG(FATAL) << "ValueError: Cannot translate type '" << obj->GetTypeKey()
478 << "' to python. Its value is: " << obj;
479 throw;
480 }
481}
482
483void PythonAPICall::Input(String arg_name, int arg) {
484 arg_names_.emplace_back(std::move(arg_name));
485 args_.push_back(std::to_string(arg));
486}
487
488void PythonAPICall::Input(String arg_name, int64_t arg) {
489 arg_names_.emplace_back(std::move(arg_name));
490 args_.push_back(std::to_string(arg));
491}
492
493void PythonAPICall::Input(String arg_name, bool arg) {
494 static const char* true_str = "True";
495 static const char* false_str = "False";
496 arg_names_.emplace_back(std::move(arg_name));
497 if (arg) {
498 args_.push_back(true_str);
499 } else {
500 args_.push_back(false_str);
501 }
502}
503
504void PythonAPICall::Input(String arg_name, double arg) {
505 arg_names_.emplace_back(std::move(arg_name));
506 std::ostringstream os;
507 os.precision(17);
508 os << arg;
509 args_.push_back(os.str());
510}
511
512void PythonAPICall::Input(String arg_name, String arg) {
513 arg_names_.emplace_back(std::move(arg_name));
514 args_.emplace_back(std::move(arg));
515}
516
517void PythonAPICall::Input(String arg_name, ObjectRef arg) {
518 arg_names_.emplace_back(std::move(arg_name));
519 std::ostringstream os;
520 AsPythonString(arg, os);
521 args_.push_back(os.str());
522}
523
524void PythonAPICall::Decision(ObjectRef decision) {
525 if (decision.defined()) {
526 this->Input("decision", decision);
527 }
528}
529
530void PythonAPICall::SingleOutput(Array<String> unit_array) {
531 ICHECK_EQ(unit_array.size(), 1);
532 this->output_ = unit_array[0];
533}
534
535void PythonAPICall::OutputList(Array<String> outputs) {
536 if (outputs.empty()) {
537 return;
538 }
539 if (outputs.size() == 1) {
540 this->output_ = outputs[0] + ",";
541 return;
542 }
543 std::ostringstream os;
544 os << outputs[0];
545 for (int i = 1, n = outputs.size(); i < n; ++i) {
546 os << ", " << outputs[i];
547 }
548 this->output_ = os.str();
549}
550
551String PythonAPICall::Str() const {
552 std::ostringstream os;
553 if (output_.defined()) {
554 os << output_.value() << " = ";
555 }
556 os << "sch." << method_name_ << '(';
557 int n = args_.size();
558 for (int i = 0; i < n; ++i) {
559 if (i > 0) {
560 os << ", ";
561 }
562 if (arg_names_[i].empty()) {
563 os << args_[i];
564 } else {
565 os << arg_names_[i] << '=' << args_[i];
566 }
567 }
568 os << ')';
569 return os.str();
570}
571
572} // namespace tir
573} // namespace tvm
574
575#endif // TVM_TIR_SCHEDULE_INSTRUCTION_TRAITS_H_
576