1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file src/relay/collage/mock_cost_estimator.cc
22 * \brief A mock CostEstimator to support unit tests.
23 */
24
#include "./mock_cost_estimator.h"

#include <tvm/relay/expr_functor.h>

#include <cmath>
28
29namespace tvm {
30namespace relay {
31namespace collage {
32
// Register MockCostEstimatorNode with TVM's object type system.
TVM_REGISTER_OBJECT_TYPE(MockCostEstimatorNode);
34
35namespace {
36
37/*!
38 * \brief Visitor to accumulate the costs of all calls to operators in an expression.
39 */
40class MockEstimationVisitor : private ExprVisitor {
41 public:
42 MockEstimationVisitor(double op_cost, double fusion_benefit)
43 : op_cost_(op_cost), fusion_benefit_(fusion_benefit) {}
44
45 double EstimateCost(const Expr& body) {
46 VisitExpr(body);
47 return cost_;
48 }
49
50 private:
51 /*! \brief The assumed baseline cost of each operator call. */
52 double op_cost_;
53 /*!
54 * \brief The factor by which each operator call cost is to be changed for every other
55 * operator call in the same group.
56 */
57 double fusion_benefit_;
58 /*! \brief The number of operator calls seen so far. */
59 size_t num_ops_ = 0;
60 /*! \brief Accumulate overall cost. */
61 double cost_ = 0.0;
62
63 void VisitExpr_(const CallNode* call_node) final {
64 if (call_node->op->IsInstance<OpNode>()) {
65 // Account for number of ops seens os far.
66 cost_ += op_cost_ * pow(fusion_benefit_, static_cast<double>(num_ops_));
67 num_ops_++;
68 }
69 ExprVisitor::VisitExpr_(call_node);
70 }
71
72 void VisitExpr_(const FunctionNode* function_node) final {
73 // No "Compiler" functions can be inlined.
74 ICHECK(!function_node->GetAttr<String>(attr::kCompiler).defined())
75 << "All Compiler functions should have been outlined when preparing to estimate costs";
76 ExprVisitor::VisitExpr_(function_node);
77 }
78};
79
80} // namespace
81
82Cost MockCostEstimatorNode::Estimate(const IRModule& mod, const Target& target) const {
83 // Limit the number of estimations.
84 ICHECK(max_estimates_->value == 0 || num_estimates_ < static_cast<size_t>(max_estimates_->value))
85 << "At most " << max_estimates_->value
86 << " non-trivial distinct candidates should have been generated.";
87 ++num_estimates_;
88 double op_cost = static_cast<double>(target_costs_.at(target->kind->name)->value);
89 double cost = 0.0;
90 for (const auto& kv : mod->functions) {
91 if (const auto* function_node = kv.second.as<FunctionNode>()) {
92 auto function = GetRef<Function>(function_node);
93 if (kv.first->name_hint == "main") {
94 // Only tensor args are allowed to main.
95 for (const auto& param : function->params) {
96 ICHECK(param->type_annotation->IsInstance<TensorTypeNode>())
97 << "Any tuple-of-tensor arguments should have been eta-exanded when preparing to "
98 "estimate costs";
99 }
100 }
101 cost += MockEstimationVisitor(op_cost, /*fusion_benefit=*/0.9).EstimateCost(function->body);
102 }
103 }
104 return Cost::Value(cost);
105}
106
107MockCostEstimator::MockCostEstimator(Map<String, Integer> target_costs, Integer max_estimates) {
108 auto node = make_object<MockCostEstimatorNode>();
109 node->target_costs_ = std::move(target_costs);
110 node->max_estimates_ = std::move(max_estimates);
111 data_ = std::move(node);
112}
113
114TVM_REGISTER_GLOBAL("relay.collage.MockCostEstimator")
115 .set_body_typed([](Map<String, Integer> target_costs, Integer max_estimates) {
116 return MockCostEstimator(std::move(target_costs), std::move(max_estimates));
117 });
118
119} // namespace collage
120} // namespace relay
121} // namespace tvm
122