1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file detect_common_subexpr.cc
22 * \brief Utility to detect common sub expressions.
23 */
24#include <tvm/tir/expr.h>
25
26#include <limits>
27
28#include "../tir/transforms/common_subexpr_elim_tools.h"
29
30namespace tvm {
31namespace arith {
32
33using namespace tir;
34
35Map<PrimExpr, Integer> DetectCommonSubExpr(const PrimExpr& e, int thresh) {
36 // Check the threshold in the range of size_t
37 CHECK_GE(thresh, std::numeric_limits<size_t>::min());
38 CHECK_LE(thresh, std::numeric_limits<size_t>::max());
39 size_t repeat_thr = static_cast<size_t>(thresh);
40 auto IsEligibleComputation = [](const PrimExpr& expr) {
41 return (SideEffect(expr) <= CallEffectKind::kPure && CalculateExprComplexity(expr) > 1 &&
42 (expr.as<RampNode>() == nullptr) && (expr.as<BroadcastNode>() == nullptr));
43 };
44
45 // Analyze the sub expressions
46 ComputationTable table_syntactic_comp_done_by_expr = ComputationsDoneBy::GetComputationsDoneBy(
47 e, IsEligibleComputation, [](const PrimExpr& expr) { return true; });
48
49 std::vector<std::pair<PrimExpr, size_t>> semantic_comp_done_by_expr =
50 SyntacticToSemanticComputations(table_syntactic_comp_done_by_expr, true);
51
52 // Find eligible sub expr if occurrence is under thresh
53 for (size_t i = 0; i < semantic_comp_done_by_expr.size(); i++) {
54 std::pair<PrimExpr, size_t>& computation_and_nb = semantic_comp_done_by_expr[i];
55 if (computation_and_nb.second < repeat_thr) {
56 std::vector<PrimExpr> direct_subexprs =
57 DirectSubexpr::GetDirectSubexpressions(computation_and_nb.first, IsEligibleComputation,
58 [](const PrimExpr& expr) { return true; });
59 InsertVectorToSortedSemanticComputations(&semantic_comp_done_by_expr, direct_subexprs, true,
60 computation_and_nb.second);
61 }
62 }
63
64 // Return the common sub expr that occur more than thresh times
65 Map<PrimExpr, Integer> results;
66 for (auto& it : semantic_comp_done_by_expr) {
67 if (it.second >= repeat_thr) results.Set(it.first, it.second);
68 }
69 return results;
70}
71
72TVM_REGISTER_GLOBAL("arith.DetectCommonSubExpr").set_body_typed(DetectCommonSubExpr);
73} // namespace arith
74} // namespace tvm
75