1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19#include <tvm/meta_schedule/schedule/cuda/thread_bind.h>
20
21#include <algorithm>
22#include <limits>
23
24#include "../utils.h"
25
26namespace tvm {
27namespace meta_schedule {
28
29class AutoBindNode : public ScheduleRuleNode {
30 public:
31 // Inherited from ScheduleRuleNode
32 void InitializeWithTuneContext(const TuneContext& context) final {
33 CHECK(context->target.defined()) << "ValueError: target is not defined";
34 Optional<Integer> max_threads_per_block =
35 context->target.value()->GetAttr<Integer>("max_threads_per_block");
36 CHECK(max_threads_per_block.defined())
37 << "ValueError: missing attribute `max_threads_per_block` in the target";
38 this->max_threads_per_block_ = max_threads_per_block.value().IntValue();
39 }
40
41 // Inherited from ScheduleRuleNode
42 Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final;
43
44 // Inherited from ScheduleRuleNode
45 ScheduleRule Clone() const final {
46 ObjectPtr<AutoBindNode> n = make_object<AutoBindNode>(*this);
47 return ScheduleRule(n);
48 }
49
50 public:
51 /*! \brief The max number of threads per block from Target */
52 int64_t max_threads_per_block_ = -1;
53 /*! \brief The max number of threadblocks in the cuda device */
54 int64_t max_threadblocks_ = -1;
55 /*! \brief thread_extents Candidates of thread axis extent. */
56 Array<Integer> thread_extents_;
57
58 void VisitAttrs(tvm::AttrVisitor* v) {
59 // `max_threads_per_block_` is not visited
60 // `max_threadblocks_` is not visited
61 // `thread_extents_` is not visited
62 }
63
64 static constexpr const char* _type_key = "meta_schedule.AutoBind";
65 TVM_DECLARE_FINAL_OBJECT_INFO(AutoBindNode, ScheduleRuleNode);
66};
67
68Array<tir::Schedule> AutoBindNode::Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) {
69 ICHECK_NE(this->max_threads_per_block_, -1);
70 auto get_factor = MakeFactorSampler(sch, this->thread_extents_);
71 BindBlockThreadIdx(sch, block_rv, max_threadblocks_, max_threads_per_block_, get_factor);
72 return {sch};
73}
74
75ScheduleRule ScheduleRule::AutoBind(int max_threadblocks, Array<Integer> thread_extents,
76 int max_threads_per_block) {
77 ObjectPtr<AutoBindNode> n = make_object<AutoBindNode>();
78 n->max_threadblocks_ = max_threadblocks;
79 n->max_threads_per_block_ = max_threads_per_block;
80 n->thread_extents_ = std::move(thread_extents);
81 return ScheduleRule(n);
82}
83
84TVM_REGISTER_NODE_TYPE(AutoBindNode);
85TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAutoBind").set_body_typed(ScheduleRule::AutoBind);
86
87} // namespace meta_schedule
88} // namespace tvm
89