1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | #include <tvm/meta_schedule/schedule/cuda/thread_bind.h> |
20 | |
21 | #include <algorithm> |
22 | #include <limits> |
23 | |
24 | #include "../utils.h" |
25 | |
26 | namespace tvm { |
27 | namespace meta_schedule { |
28 | |
29 | class AutoBindNode : public ScheduleRuleNode { |
30 | public: |
31 | // Inherited from ScheduleRuleNode |
32 | void InitializeWithTuneContext(const TuneContext& context) final { |
33 | CHECK(context->target.defined()) << "ValueError: target is not defined" ; |
34 | Optional<Integer> max_threads_per_block = |
35 | context->target.value()->GetAttr<Integer>("max_threads_per_block" ); |
36 | CHECK(max_threads_per_block.defined()) |
37 | << "ValueError: missing attribute `max_threads_per_block` in the target" ; |
38 | this->max_threads_per_block_ = max_threads_per_block.value().IntValue(); |
39 | } |
40 | |
41 | // Inherited from ScheduleRuleNode |
42 | Array<tir::Schedule> Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) final; |
43 | |
44 | // Inherited from ScheduleRuleNode |
45 | ScheduleRule Clone() const final { |
46 | ObjectPtr<AutoBindNode> n = make_object<AutoBindNode>(*this); |
47 | return ScheduleRule(n); |
48 | } |
49 | |
50 | public: |
51 | /*! \brief The max number of threads per block from Target */ |
52 | int64_t max_threads_per_block_ = -1; |
53 | /*! \brief The max number of threadblocks in the cuda device */ |
54 | int64_t max_threadblocks_ = -1; |
55 | /*! \brief thread_extents Candidates of thread axis extent. */ |
56 | Array<Integer> thread_extents_; |
57 | |
58 | void VisitAttrs(tvm::AttrVisitor* v) { |
59 | // `max_threads_per_block_` is not visited |
60 | // `max_threadblocks_` is not visited |
61 | // `thread_extents_` is not visited |
62 | } |
63 | |
64 | static constexpr const char* _type_key = "meta_schedule.AutoBind" ; |
65 | TVM_DECLARE_FINAL_OBJECT_INFO(AutoBindNode, ScheduleRuleNode); |
66 | }; |
67 | |
68 | Array<tir::Schedule> AutoBindNode::Apply(const tir::Schedule& sch, const tir::BlockRV& block_rv) { |
69 | ICHECK_NE(this->max_threads_per_block_, -1); |
70 | auto get_factor = MakeFactorSampler(sch, this->thread_extents_); |
71 | BindBlockThreadIdx(sch, block_rv, max_threadblocks_, max_threads_per_block_, get_factor); |
72 | return {sch}; |
73 | } |
74 | |
75 | ScheduleRule ScheduleRule::AutoBind(int max_threadblocks, Array<Integer> thread_extents, |
76 | int max_threads_per_block) { |
77 | ObjectPtr<AutoBindNode> n = make_object<AutoBindNode>(); |
78 | n->max_threadblocks_ = max_threadblocks; |
79 | n->max_threads_per_block_ = max_threads_per_block; |
80 | n->thread_extents_ = std::move(thread_extents); |
81 | return ScheduleRule(n); |
82 | } |
83 | |
84 | TVM_REGISTER_NODE_TYPE(AutoBindNode); |
85 | TVM_REGISTER_GLOBAL("meta_schedule.ScheduleRuleAutoBind" ).set_body_typed(ScheduleRule::AutoBind); |
86 | |
87 | } // namespace meta_schedule |
88 | } // namespace tvm |
89 | |