1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
 * \file tvm/arith/pattern_match.h
22 *
23 * \brief Internal tool for expression-template based pattern matching.
24 *
25 * It helps to simplify pattern matching and rewrites.
26 * All the patterns are generated via expression template during compile time,
27 * so the result code should be as efficient as manually written pattern match code.
28 *
29 * The code below shows how to use the pattern matcher.
30 *
31 * \code
32 *
33 * // max(x + z, y + z) => max(x, y) + z
 * arith::PVar<PrimExpr> x, y, z;
35 *
36 * // The following code tries to match the declared pattern.
37 * // Match will fill the result of match into PVar if successful.
38 * // Note that z occurs twice in the pattern,
 * // an equality check is performed to ensure each occurrence of z
40 * // is equivalent to each other.
41 * if (max(x + z, y + z).Match(expr)) {
42 * // Eval evaluates a pattern with the current matched value.
43 * // The filled value is valid until the next call to Match.
44 * return (max(x, y) + z).Eval();
45 * }
46 *
47 * tvm::tir::Var tx, ty;
48 * arith::PVar<IntImm> c;
49 * arith::PVar<Var> v;
 * // We can match IntImm and Var, both of which are
 * // special-case containers of PrimExpr.
52 * ICHECK((v * c).Match(tx * 3));
53 * ICHECK_EQ(c.Eval()->value, 3);
54 * // cannot match c to ty
55 * ICHECK(!(v * c).Match(tx * ty));
56 *
57 * \endcode
58 *
59 * \note The pattern matcher is not threadsafe,
60 * do not use the same PVar in multiple threads.
61 *
62 * Please be aware that the filled value in a PVar
 * can be overridden in the next call to Match.
64 */
65#ifndef TVM_ARITH_PATTERN_MATCH_H_
66#define TVM_ARITH_PATTERN_MATCH_H_
67
68#include <tvm/tir/analysis.h>
69#include <tvm/tir/builtin.h>
70#include <tvm/tir/expr.h>
71
72#include <cmath>
73#include <tuple>
74
75#include "const_fold.h"
76
77namespace tvm {
78namespace arith {
/*!
 * \brief CRTP base class shared by every pattern node.
 *
 * Each concrete pattern provides two operations:
 *  - Match: test whether a value fits the pattern, filling the PVars it contains.
 *  - Eval: rebuild a value from the PVars filled by the most recent Match.
 *
 * The curiously recurring template pattern lets a whole pattern tree be
 * assembled as an expression template at compile time.
 *
 * \tparam Derived The concrete pattern type deriving from this class.
 */
template <typename Derived>
class Pattern {
 public:
  /*!
   * \brief How this pattern is stored when nested inside a parent pattern.
   *
   * By default a pattern is nested by value (Nested = Derived). PVar-like
   * holes redefine Nested as a const reference so that every occurrence of
   * the same hole in a tree refers to the same storage.
   * The Nested typedef trick originates from Eigen.
   */
  using Nested = Derived;

  /*!
   * \brief Test whether value matches this pattern.
   *
   * Every hole in the pattern is reset first and then filled by the match.
   * Filled values stay valid until the next call to Match.
   *
   * \param value The value to match against.
   * \return true if value matches the pattern.
   */
  template <typename NodeType>
  bool Match(const NodeType& value) const {
    const Derived& self = derived();
    self.InitMatch_();
    return self.Match_(value);
  }

  /*! \return This object viewed as its concrete Derived type. */
  const Derived& derived() const {
    const Derived* self = static_cast<const Derived*>(this);
    return *self;
  }
};
123
/*!
 * \brief Equality predicate consulted when a PVar occurs more than once in a pattern.
 *
 * The primary template simply defers to operator== of T.
 * \tparam T The type of the values being compared.
 */
template <typename T>
class PEqualChecker {
 public:
  bool operator()(const T& lhs, const T& rhs) const {
    const bool equal = lhs == rhs;
    return equal;
  }
};
133
134template <>
135class PEqualChecker<PrimExpr> {
136 public:
137 bool operator()(const PrimExpr& lhs, const PrimExpr& rhs) const {
138 if (lhs.same_as(rhs)) return true;
139 return tir::ExprDeepEqual()(lhs, rhs);
140 }
141};
142
143template <>
144class PEqualChecker<IntImm> {
145 public:
146 bool operator()(const IntImm& lhs, const IntImm& rhs) const { return lhs->value == rhs->value; }
147};
148
149template <>
150class PEqualChecker<FloatImm> {
151 public:
152 bool operator()(const FloatImm& lhs, const FloatImm& rhs) const {
153 return std::fabs(lhs->value - rhs->value) < 1e-20;
154 }
155};
156
157template <>
158class PEqualChecker<tir::Var> {
159 public:
160 bool operator()(const tir::Var& lhs, const tir::Var& rhs) const { return lhs.same_as(rhs); }
161};
162
163/*!
164 * \brief Pattern variable container.
165 *
166 * PVar is used as a "hole" in the pattern that can be matched.
167 *
168 * \tparam T the type of the hole.
169 *
170 * \note PVar is not thread safe.
171 * Do not use the same PVar in multiple threads.
172 */
173template <typename T>
174class PVar : public Pattern<PVar<T>> {
175 public:
176 // Store PVars by reference in the expression.
177 using Nested = const PVar<T>&;
178
179 void InitMatch_() const { filled_ = false; }
180
181 bool Match_(const T& value) const {
182 if (!filled_) {
183 value_ = value;
184 filled_ = true;
185 return true;
186 } else {
187 return PEqualChecker<T>()(value_, value);
188 }
189 }
190
191 template <typename NodeRefType,
192 typename = typename std::enable_if<std::is_base_of<NodeRefType, T>::value>::type>
193 bool Match_(const NodeRefType& value) const {
194 if (const auto* ptr = value.template as<typename T::ContainerType>()) {
195 return Match_(GetRef<T>(ptr));
196 } else {
197 return false;
198 }
199 }
200
201 T Eval() const {
202 ICHECK(filled_);
203 return value_;
204 }
205
206 T EvalOr(const T& default_value) const { return filled_ ? value_ : default_value; }
207
208 protected:
209 /*! \brief The matched value */
210 mutable T value_;
211 /*! \brief whether the variable has been filled */
212 mutable bool filled_{false};
213};
214
215/*!
216 * \brief Wrapper for pattern variable container with extra match logic.
217 *
218 * \tparam Derived the type of derived class.
219 * \tparam T the type of the hole.
220 */
221template <typename Derived, typename T>
222class PVarWithCheck : public arith::Pattern<PVarWithCheck<Derived, T>> {
223 public:
224 // Store by reference in the expression.
225 using Nested = const PVarWithCheck<Derived, T>&;
226
227 void InitMatch_() const { pvar_.InitMatch_(); }
228
229 bool Match_(const T& value) const {
230 if (!static_cast<const Derived*>(this)->Match_(value)) return false;
231 return pvar_.Match_(value);
232 }
233
234 template <typename NodeRefType,
235 typename = typename std::enable_if<std::is_base_of<NodeRefType, T>::value>::type>
236 bool Match_(const NodeRefType& value) const {
237 if (const auto* ptr = value.template as<typename T::ContainerType>()) {
238 return Match_(GetRef<T>(ptr));
239 } else {
240 return false;
241 }
242 }
243
244 T Eval() const { return pvar_.Eval(); }
245
246 protected:
247 arith::PVar<T> pvar_;
248};
249
250/*!
251 * \brief Pattern variable container with expr type check.
252 *
253 * \tparam T the type of the hole.
254 * \tparam DType the Pattern type of dtype.
255 */
256template <typename T, typename DType,
257 typename = std::enable_if<std::is_base_of<T, PrimExpr>::value>>
258class PVarWithDataType : public PVarWithCheck<PVarWithDataType<T, DType>, T> {
259 public:
260 explicit PVarWithDataType(const DType& dtype) : dtype_(dtype) {}
261
262 bool Match_(const T& value) const { return dtype_.Match_(value->dtype); }
263
264 protected:
265 typename DType::Nested dtype_;
266};
267
268/*!
269 * \brief Pattern variable container for data type with lanes.
270 */
271class PVecDataType : public PVarWithCheck<PVecDataType, DataType> {
272 public:
273 /*! \brief construct vector dtype placeholder with element type check */
274 explicit PVecDataType(const DataType& elem_dtype) : elem_dtype_(elem_dtype) {}
275
276 bool Match_(const DataType& dtype) const { return dtype.code() == elem_dtype_.code(); }
277
278 protected:
279 DataType elem_dtype_;
280};
281
282/*!
283 * \brief Constant Pattern variable container.
284 *
285 * \tparam T the type of the hole.
286 */
287template <typename T>
288class PConst : public Pattern<PConst<T>> {
289 public:
290 PConst(T value) // NOLINT(*)
291 : value_(value) {}
292
293 void InitMatch_() const {}
294
295 bool Match_(const T& value) const { return PEqualChecker<T>()(value_, value); }
296
297 T Eval() const { return value_; }
298
299 private:
300 const T value_;
301};
302
303/*!
304 * \brief Pattern binary expression.
305 * \tparam OpType The AST noderef type.
306 * \tparam TA The pattern type of the first operand.
307 * \tparam TB The pattern type of the second operand.
308 */
309template <typename OpType, typename TA, typename TB>
310class PBinaryExpr : public Pattern<PBinaryExpr<OpType, TA, TB>> {
311 public:
312 PBinaryExpr(const TA& a, const TB& b) : a_(a), b_(b) {}
313
314 void InitMatch_() const {
315 a_.InitMatch_();
316 b_.InitMatch_();
317 }
318
319 bool Match_(const ObjectRef& node) const {
320 using NodeType = typename OpType::ContainerType;
321 if (const NodeType* ptr = node.as<NodeType>()) {
322 if (!a_.Match_(ptr->a)) return false;
323 if (!b_.Match_(ptr->b)) return false;
324 return true;
325 } else {
326 return false;
327 }
328 }
329
330 PrimExpr Eval() const {
331 PrimExpr lhs = a_.Eval();
332 PrimExpr rhs = b_.Eval();
333 if (auto ret = TryConstFold<OpType>(lhs, rhs)) return ret.value();
334 return OpType(lhs, rhs);
335 }
336
337 private:
338 typename TA::Nested a_;
339 typename TB::Nested b_;
340};
341
342template <typename TA>
343class PConstWithTypeLike : public Pattern<PConstWithTypeLike<TA>> {
344 public:
345 PConstWithTypeLike(const TA& ref, int64_t value) : ref_(ref), value_(value) {}
346
347 void InitMatch_() const {}
348
349 bool Match_(const ObjectRef& node) const {
350 if (const tir::IntImmNode* ptr = node.as<tir::IntImmNode>()) {
351 return ptr->value == value_;
352 } else {
353 return false;
354 }
355 }
356
357 PrimExpr Eval() const { return tir::make_const(ref_.Eval().dtype(), value_); }
358
359 private:
360 typename TA::Nested ref_;
361 int64_t value_;
362};
363
/*!
 * \brief Define the overload set FuncName(...) that builds PBinaryExpr<NodeName, ...> patterns.
 *
 * Three overloads are generated:
 *  - FuncName(pattern, pattern)
 *  - FuncName(pattern, int64_t)
 *  - FuncName(int64_t, pattern)
 * An int64_t operand is wrapped in a PConstWithTypeLike bound to the pattern operand.
 * The literal overloads then forward to the pattern/pattern overload.
 *
 * CheckStep is inserted as the first statement of every overload. It may refer to
 * the Pattern operand by the name `a`.
 */
#define TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, CheckStep)                               \
  template <typename TA, typename TB>                                                         \
  inline PBinaryExpr<NodeName, TA, TB> FuncName(const Pattern<TA>& a, const Pattern<TB>& b) { \
    CheckStep;                                                                                \
    using Result = PBinaryExpr<NodeName, TA, TB>;                                             \
    return Result(a.derived(), b.derived());                                                  \
  }                                                                                           \
  template <typename TA>                                                                      \
  inline PBinaryExpr<NodeName, TA, PConstWithTypeLike<TA>> FuncName(const Pattern<TA>& a,     \
                                                                    int64_t b) {              \
    CheckStep;                                                                                \
    const PConstWithTypeLike<TA> literal(a.derived(), b);                                     \
    return FuncName(a, literal);                                                              \
  }                                                                                           \
  template <typename TA>                                                                      \
  inline PBinaryExpr<NodeName, PConstWithTypeLike<TA>, TA> FuncName(int64_t b,                \
                                                                    const Pattern<TA>& a) {   \
    CheckStep;                                                                                \
    const PConstWithTypeLike<TA> literal(a.derived(), b);                                     \
    return FuncName(literal, a);                                                              \
  }

/*! \brief TVM_PATTERN_BINARY_OP_EX with an empty CheckStep. */
#define TVM_PATTERN_BINARY_OP(FuncName, NodeName) TVM_PATTERN_BINARY_OP_EX(FuncName, NodeName, )
384
// raise ambiguity error for operator overload of / and %
// NOTE(review): DivAmbiguityError is defined elsewhere; presumably it rejects these
// overloads so callers must spell out div/truncdiv/truncmod/floordiv/floormod — confirm.
TVM_PATTERN_BINARY_OP_EX(operator/, tir::Div, DivAmbiguityError(a));
TVM_PATTERN_BINARY_OP_EX(operator%, tir::Mod, DivAmbiguityError(a));

// arithmetic expressions
// Note: `div` and `truncdiv` both build tir::Div; `truncmod` builds tir::Mod.
TVM_PATTERN_BINARY_OP(operator+, tir::Add);
TVM_PATTERN_BINARY_OP(operator-, tir::Sub);
TVM_PATTERN_BINARY_OP(operator*, tir::Mul);
TVM_PATTERN_BINARY_OP(min, tir::Min);
TVM_PATTERN_BINARY_OP(max, tir::Max);
TVM_PATTERN_BINARY_OP(div, tir::Div);
TVM_PATTERN_BINARY_OP(truncdiv, tir::Div);
TVM_PATTERN_BINARY_OP(truncmod, tir::Mod);
TVM_PATTERN_BINARY_OP(floordiv, tir::FloorDiv);
TVM_PATTERN_BINARY_OP(floormod, tir::FloorMod);

// logical expressions
// These overloads build comparison/logic patterns; they do not evaluate to bool.
TVM_PATTERN_BINARY_OP(operator>, tir::GT);
TVM_PATTERN_BINARY_OP(operator>=, tir::GE);
TVM_PATTERN_BINARY_OP(operator<, tir::LT);
TVM_PATTERN_BINARY_OP(operator<=, tir::LE);
TVM_PATTERN_BINARY_OP(operator==, tir::EQ);
TVM_PATTERN_BINARY_OP(operator!=, tir::NE);
TVM_PATTERN_BINARY_OP(operator&&, tir::And);
TVM_PATTERN_BINARY_OP(operator||, tir::Or);
410
411/*!
412 * \brief Pattern not expression.
413 * \tparam TA The pattern type of the true operand.
414 */
415template <typename TA>
416class PNotExpr : public Pattern<PNotExpr<TA>> {
417 public:
418 explicit PNotExpr(const TA& value) : value_(value) {}
419
420 void InitMatch_() const { value_.InitMatch_(); }
421
422 bool Match_(const ObjectRef& node) const {
423 if (const tir::NotNode* ptr = node.as<tir::NotNode>()) {
424 if (!value_.Match_(ptr->a)) return false;
425 return true;
426 } else {
427 return false;
428 }
429 }
430
431 PrimExpr Eval() const { return tir::Not(value_.Eval()); }
432
433 private:
434 typename TA::Nested value_;
435};
436
437template <typename TA>
438inline PNotExpr<TA> operator!(const Pattern<TA>& value) {
439 return PNotExpr<TA>(value.derived());
440}
441
442// select
443/*!
444 * \brief Pattern select expression.
445 * \tparam TCond The pattern type of the condition.
446 * \tparam TA The pattern type of the true operand.
447 * \tparam TB The pattern type of the false operand.
448 */
449template <typename TCond, typename TA, typename TB>
450class PSelectExpr : public Pattern<PSelectExpr<TCond, TA, TB>> {
451 public:
452 PSelectExpr(const TCond& condition, const TA& true_value, const TB& false_value)
453 : condition_(condition), true_value_(true_value), false_value_(false_value) {}
454
455 void InitMatch_() const {
456 condition_.InitMatch_();
457 true_value_.InitMatch_();
458 false_value_.InitMatch_();
459 }
460
461 bool Match_(const ObjectRef& node) const {
462 if (const tir::SelectNode* ptr = node.as<tir::SelectNode>()) {
463 if (!condition_.Match_(ptr->condition)) return false;
464 if (!true_value_.Match_(ptr->true_value)) return false;
465 if (!false_value_.Match_(ptr->false_value)) return false;
466 return true;
467 } else {
468 return false;
469 }
470 }
471
472 PrimExpr Eval() const {
473 return tir::Select(condition_.Eval(), true_value_.Eval(), false_value_.Eval());
474 }
475
476 private:
477 typename TCond::Nested condition_;
478 typename TA::Nested true_value_;
479 typename TB::Nested false_value_;
480};
481
482/*!
483 * \brief Construct a select pattern.
484 *
485 * \param condition The condition expression.
486 * \param true_value The value when condition is true.
487 * \param true_value The value when condition is false.
488 *
489 * \return The result pattern.
490 *
491 * \tparam TCond The pattern type of the condition.
492 * \tparam TA The pattern type of the true operand.
493 * \tparam TB The pattern type of the false operand.
494 */
495template <typename TCond, typename TA, typename TB>
496inline PSelectExpr<TCond, TA, TB> select(const Pattern<TCond>& condition,
497 const Pattern<TA>& true_value,
498 const Pattern<TB>& false_value) {
499 return PSelectExpr<TCond, TA, TB>(condition.derived(), true_value.derived(),
500 false_value.derived());
501}
502
503/*!
504 * \brief Pattern cast expression.
505 * \tparam DType The Pattern type of dtype.
506 * \tparam TA The pattern type of the first operand.
507 */
508template <typename DType, typename TA>
509class PCastExpr : public Pattern<PCastExpr<DType, TA>> {
510 public:
511 PCastExpr(const DType& dtype, const TA& value) : dtype_(dtype), value_(value) {}
512
513 void InitMatch_() const {
514 dtype_.InitMatch_();
515 value_.InitMatch_();
516 }
517
518 bool Match_(const ObjectRef& node) const {
519 if (const tir::CastNode* ptr = node.as<tir::CastNode>()) {
520 if (!dtype_.Match_(ptr->dtype)) return false;
521 if (!value_.Match_(ptr->value)) return false;
522 return true;
523 } else {
524 return false;
525 }
526 }
527
528 PrimExpr Eval() const { return tir::Cast(dtype_.Eval(), value_.Eval()); }
529
530 private:
531 typename DType::Nested dtype_;
532 typename TA::Nested value_;
533};
534
535/*!
536 * \brief Construct a cast pattern.
537 *
538 * \param dtype The target data type, can be PVar<DataType> or PConst<DataType>.
539 * \param value The input type.
540 *
541 * \return The result pattern.
542 *
543 * \tparam DType The pattern type of type.
544 * \tparam TA The pattern type of value.
545 */
546template <typename DType, typename TA>
547inline PCastExpr<DType, TA> cast(const Pattern<DType>& dtype, const Pattern<TA>& value) {
548 return PCastExpr<DType, TA>(dtype.derived(), value.derived());
549}
550
551/*!
552 * \brief Pattern ramp expression.
553 * \tparam TBase The pattern type of the base.
554 * \tparam TStride The pattern type of the stride.
555 * \tparam TLanes The pattern type of the lanes.
556 */
557template <typename TBase, typename TStride, typename TLanes>
558class PRampExpr : public Pattern<PRampExpr<TBase, TStride, TLanes>> {
559 public:
560 PRampExpr(const TBase& base, const TStride& stride, const TLanes& lanes)
561 : base_(base), stride_(stride), lanes_(lanes) {}
562
563 void InitMatch_() const {
564 base_.InitMatch_();
565 stride_.InitMatch_();
566 lanes_.InitMatch_();
567 }
568
569 bool Match_(const ObjectRef& node) const {
570 if (const tir::RampNode* ptr = node.as<tir::RampNode>()) {
571 if (!base_.Match_(ptr->base)) return false;
572 if (!stride_.Match_(ptr->stride)) return false;
573 if (!lanes_.Match_(ptr->lanes)) return false;
574 return true;
575 } else {
576 return false;
577 }
578 }
579
580 PrimExpr Eval() const { return tir::Ramp(base_.Eval(), stride_.Eval(), lanes_.Eval()); }
581
582 private:
583 typename TBase::Nested base_;
584 typename TStride::Nested stride_;
585 typename TLanes::Nested lanes_;
586};
587
588/*!
589 * \brief Construct a ramp pattern.
590 *
591 * \param base The base pattern.
592 * \param stride The stride pattern.
593 * \param lanes The lanes pattern.
594 *
595 * \return The result pattern.
596 *
597 * \tparam TBase The pattern type of the base.
598 * \tparam TStride The pattern type of the stride.
599 * \tparam TLanes The pattern type of the lanes.
600 */
601template <typename TBase, typename TStride, typename TLanes>
602inline PRampExpr<TBase, TStride, TLanes> ramp(const Pattern<TBase>& base,
603 const Pattern<TStride>& stride,
604 const Pattern<TLanes>& lanes) {
605 return PRampExpr<TBase, TStride, TLanes>(base.derived(), stride.derived(), lanes.derived());
606}
607
608template <typename TBase>
609inline PRampExpr<TBase, PConstWithTypeLike<TBase>, PConst<int>> ramp(const Pattern<TBase>& base,
610 int stride, int lanes) {
611 return PRampExpr<TBase, PConstWithTypeLike<TBase>, PConst<int>>(
612 base.derived(), PConstWithTypeLike<TBase>(base.derived(), stride), PConst<int>(lanes));
613}
614
615/*!
616 * \brief Pattern broadcast expression.
617 * \tparam TA The pattern type of the value.
618 * \tparam TLanes The pattern type of the lanes.
619 */
620template <typename TA, typename TLanes>
621class PBroadcastExpr : public Pattern<PBroadcastExpr<TA, TLanes>> {
622 public:
623 PBroadcastExpr(const TA& value, const TLanes& lanes) : value_(value), lanes_(lanes) {}
624
625 void InitMatch_() const {
626 value_.InitMatch_();
627 lanes_.InitMatch_();
628 }
629
630 bool Match_(const ObjectRef& node) const {
631 if (const tir::BroadcastNode* ptr = node.as<tir::BroadcastNode>()) {
632 if (!value_.Match_(ptr->value)) return false;
633 if (!lanes_.Match_(ptr->lanes)) return false;
634 return true;
635 } else {
636 return false;
637 }
638 }
639
640 PrimExpr Eval() const { return tir::Broadcast(value_.Eval(), lanes_.Eval()); }
641
642 private:
643 typename TA::Nested value_;
644 typename TLanes::Nested lanes_;
645};
646
647/*!
648 * \brief Construct a broadcast pattern.
649 *
650 * \param value The value pattern.
651 * \param lanes The lanes pattern.
652 *
653 * \return The result pattern.
654 *
655 * \tparam TA The pattern type of the value.
656 * \tparam TLanes The pattern type of the lanes.
657 */
658template <typename TA, typename TLanes>
659inline PBroadcastExpr<TA, TLanes> broadcast(const Pattern<TA>& value,
660 const Pattern<TLanes>& lanes) {
661 return PBroadcastExpr<TA, TLanes>(value.derived(), lanes.derived());
662}
663
664// internal namespace
665namespace detail {
666// implementation details for CallExpr
667template <bool stop, std::size_t I, typename F>
668struct tuple_for_each_dispatcher {
669 template <typename TTuple>
670 static void run(F& f, const TTuple& tuple) { // NOLINT(*)
671 f(I, std::get<I>(tuple));
672 tuple_for_each_dispatcher<(I + 1) == std::tuple_size<TTuple>::value, (I + 1), F>::run(f, tuple);
673 }
674};
675
676template <std::size_t I, typename F>
677struct tuple_for_each_dispatcher<true, I, F> {
678 template <typename TTuple>
679 static void run(F& f, const TTuple& tuple) {} // NOLINT(*)
680};
681
682template <typename F, typename TTuple>
683inline void tuple_for_each(F& f, const TTuple& tuple) { // NOLINT(*)
684 tuple_for_each_dispatcher<std::tuple_size<TTuple>::value == 0, 0, F>::run(f, tuple);
685}
686
687struct PCallExprInitMatchFunctor {
688 template <typename T>
689 void operator()(size_t i, const T& pattern) const {
690 pattern.InitMatch_();
691 }
692};
693
694struct PCallExprMatchFunctor {
695 const tir::CallNode* call_;
696 bool matched_{true};
697
698 explicit PCallExprMatchFunctor(const tir::CallNode* call) : call_(call) {}
699
700 template <typename T>
701 void operator()(size_t i, const T& pattern) {
702 matched_ = matched_ && pattern.Match_(call_->args[i]);
703 }
704};
705
706struct PCallExprEvalArgsFunctor {
707 Array<PrimExpr> args_;
708
709 template <typename T>
710 void operator()(size_t i, const T& pattern) {
711 args_.push_back(pattern.Eval());
712 }
713};
714} // namespace detail
715
716/*!
717 * \brief Pattern CallExpr expression.
718 * \tparam Op The operator functor class.
719 * \tparam TArgs The arguments.
720 * \note Op functor contains the name of the function and
721 * the implementation of Eval.
722 */
723template <typename Op, typename... TArgs>
724class PCallExpr : public Pattern<PCallExpr<Op, TArgs...>> {
725 public:
726 explicit PCallExpr(const TArgs&... args) : args_(args...) {}
727
728 void InitMatch_() const {
729 detail::PCallExprInitMatchFunctor finit;
730 detail::tuple_for_each(finit, args_);
731 }
732
733 bool Match_(const ObjectRef& node) const {
734 if (const tir::CallNode* ptr = node.as<tir::CallNode>()) {
735 if (ptr->args.size() != sizeof...(TArgs)) return false;
736 if (!ptr->op.same_as(Op::GetOp())) return false;
737 detail::PCallExprMatchFunctor fmatch(ptr);
738 detail::tuple_for_each(fmatch, args_);
739 return fmatch.matched_;
740 } else {
741 return false;
742 }
743 }
744
745 PrimExpr Eval() const {
746 detail::PCallExprEvalArgsFunctor feval_args;
747 detail::tuple_for_each(feval_args, args_);
748 return Op::Eval(feval_args.args_);
749 }
750
751 private:
752 std::tuple<typename TArgs::Nested...> args_;
753};
754
// arithmetic intrinsics
/*!
 * \brief Define OpName, a PCallExpr operator descriptor for builtin IntrinOpName, and a
 *  two-argument pattern builder FuncName producing PCallExpr<OpName, TA, TB>.
 *  OpName::Eval builds a Call whose dtype is that of the first argument.
 */
#define TVM_PATTERN_BINARY_INTRIN(FuncName, OpName, IntrinOpName)                    \
  struct OpName {                                                                    \
    static PrimExpr Eval(Array<PrimExpr> args) {                                     \
      const auto result_dtype = args[0].dtype();                                     \
      return tir::Call(result_dtype, GetOp(), args);                                 \
    }                                                                                \
    static const Op& GetOp() { return tir::builtin::IntrinOpName(); }                \
  };                                                                                 \
  template <typename TA, typename TB>                                                \
  inline PCallExpr<OpName, TA, TB> FuncName(const Pattern<TA>& a, const Pattern<TB>& b) { \
    using Result = PCallExpr<OpName, TA, TB>;                                        \
    return Result(a.derived(), b.derived());                                         \
  }
767
// Shift and bitwise operators on patterns build calls to the matching tir::builtin intrinsic.
TVM_PATTERN_BINARY_INTRIN(operator<<, PLeftShiftOp, shift_left);
TVM_PATTERN_BINARY_INTRIN(operator>>, PRightShiftOp, shift_right);
TVM_PATTERN_BINARY_INTRIN(operator&, PBitwiseAndOp, bitwise_and);
TVM_PATTERN_BINARY_INTRIN(operator|, PBitwiseOrOp, bitwise_or);
TVM_PATTERN_BINARY_INTRIN(operator^, PBitwiseXorOp, bitwise_xor);
773
// unary intrinsics
/*!
 * \brief Define OpName, a PCallExpr operator descriptor for builtin IntrinOpName, and a
 *  one-argument pattern builder FuncName producing PCallExpr<OpName, TA>.
 *  OpName::Eval builds a Call whose dtype is that of its argument.
 */
#define TVM_PATTERN_UNARY_INTRIN(FuncName, OpName, IntrinOpName)      \
  struct OpName {                                                     \
    static PrimExpr Eval(Array<PrimExpr> args) {                      \
      const auto result_dtype = args[0].dtype();                      \
      return tir::Call(result_dtype, GetOp(), args);                  \
    }                                                                 \
    static const Op& GetOp() { return tir::builtin::IntrinOpName(); } \
  };                                                                  \
  template <typename TA>                                              \
  inline PCallExpr<OpName, TA> FuncName(const Pattern<TA>& a) {       \
    using Result = PCallExpr<OpName, TA>;                             \
    return Result(a.derived());                                       \
  }
786
// `~pattern` builds a call to tir::builtin::bitwise_not.
TVM_PATTERN_UNARY_INTRIN(operator~, PBitwiseNotOp, bitwise_not);
788
789// if_then_else
790struct PIfThenElseOp {
791 static PrimExpr Eval(Array<PrimExpr> args) { return tir::Call(args[1].dtype(), GetOp(), args); }
792 static const Op& GetOp() { return tir::builtin::if_then_else(); }
793};
794
795/*!
796 * \brief Construct a if_then_else pattern.
797 *
798 * \param cond The condition expression.
799 * \param true_value The value when condition is true.
800 * \param true_value The value when condition is false.
801 *
802 * \return The result pattern.
803 *
804 * \tparam TCond The pattern type of the condition.
805 * \tparam TA The pattern type of the true operand.
806 * \tparam TB The pattern type of the false operand.
807 */
808template <typename TCond, typename TA, typename TB>
809inline PCallExpr<PIfThenElseOp, TCond, TA, TB> if_then_else(const Pattern<TCond>& cond,
810 const Pattern<TA>& true_value,
811 const Pattern<TB>& false_value) {
812 return PCallExpr<PIfThenElseOp, TCond, TA, TB>(cond.derived(), true_value.derived(),
813 false_value.derived());
814}
815
816} // namespace arith
817} // namespace tvm
818#endif // TVM_ARITH_PATTERN_MATCH_H_
819