1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
/*!
 * \file codegen_cuda.h
 * \brief Code generator that emits CUDA source code.
 */
24 | #ifndef TVM_TARGET_SOURCE_CODEGEN_CUDA_H_ |
25 | #define TVM_TARGET_SOURCE_CODEGEN_CUDA_H_ |
26 | |
27 | #include <tvm/target/codegen.h> |
28 | #include <tvm/tir/expr.h> |
29 | #include <tvm/tir/op.h> |
30 | |
31 | #include <string> |
32 | #include <unordered_map> |
33 | |
34 | #include "codegen_c.h" |
35 | |
36 | namespace tvm { |
37 | namespace codegen { |
38 | |
39 | class CodeGenCUDA final : public CodeGenC { |
40 | public: |
41 | CodeGenCUDA(); |
42 | void Init(bool output_ssa); |
43 | std::string Finish(); |
44 | bool need_include_path() { |
45 | return (enable_fp16_ || enable_bf16_ || enable_int8_ || need_math_constants_h_ || need_mma_h_); |
46 | } |
47 | // override behavior |
48 | void PrintFuncPrefix(std::ostream& os) final; |
49 | void (const PrimFunc& f) final; |
50 | void VisitStmt_(const ForNode* op) final; |
51 | void PrintStorageSync(const CallNode* op) final; |
52 | void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) |
53 | void PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs, |
54 | std::ostream& os) final; // NOLINT(*) |
55 | void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) |
56 | void PrintVecElemLoad(const std::string& vec, DataType t, int i, |
57 | std::ostream& os) final; // NOLINT(*) |
58 | void PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final; |
59 | void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) |
60 | void PrintVecElemLoadExpr(DataType t, int i, const std::string& value, std::ostream& os) final; |
61 | std::string CastFromTo(std::string value, DataType from, DataType target) final; |
62 | // overload visitor |
63 | void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*) |
64 | void VisitExpr_(const ShuffleNode* op, std::ostream& os) final; // NOLINT(*) |
65 | void VisitExpr_(const SelectNode* op, std::ostream& os) final; // NOLINT(*) |
66 | void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) |
67 | void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; |
68 | void VisitExpr_(const CallNode* op, std::ostream& os) final; |
69 | void VisitExpr_(const CastNode* op, std::ostream& os) final; |
70 | void VisitStmt_(const EvaluateNode* op) final; |
71 | void VisitStmt_(const AllocateNode* op) final; |
72 | void VisitStmt_(const AttrStmtNode* op) final; |
73 | |
74 | protected: |
75 | void PrintCallExtern(Type ret_type, String global_symbol, const Array<PrimExpr>& args, |
76 | bool skip_first_arg, std::ostream& os) final; // NOLINT(*) |
77 | |
78 | private: |
79 | // Handle volatile loads |
80 | void HandleVolatileLoads(const std::string& value, const BufferLoadNode* op, |
81 | std::ostream& os) final; |
82 | |
83 | // Whether scope such as "__shared__" or "__constant__" is part of type. |
84 | bool IsScopePartOfType() const final { return false; } |
85 | |
86 | // Whether global barrier is needed. |
87 | bool need_global_barrier_{false}; |
88 | // Global barrier state |
89 | std::string vid_global_barrier_state_; |
90 | // Global barrier expected node. |
91 | std::string vid_global_barrier_expect_; |
92 | // whether enable fp16 |
93 | bool enable_fp16_{false}; |
94 | // whether enable bf16 |
95 | bool enable_bf16_{false}; |
96 | // whether enable int8 |
97 | bool enable_int8_{false}; |
98 | // whether enable warp shuffle intrinsics |
99 | bool enable_warp_shuffle_{false}; |
100 | // whether need math_constants.h |
101 | bool need_math_constants_h_{false}; |
102 | // whether need mma.h |
103 | bool need_mma_h_{false}; |
104 | // Op attribute map |
105 | OpAttrMap<bool> op_need_warp_shuffle_ = Op::GetAttrMap<bool>("cuda.need_warp_shuffle" ); |
106 | |
107 | std::unordered_map<const VarNode*, std::string> fragment_shapes; |
108 | std::unordered_map<const VarNode*, std::string> fragment_layouts; |
109 | friend void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p); |
110 | void PrintWmmaScope(const std::string& scope, DataType t, const VarNode* variable, |
111 | std::ostream& os); |
112 | int32_t GetWmmaFragmentSize(const std::string& scope, const VarNode* variable, int32_t size); |
113 | }; |
114 | |
115 | } // namespace codegen |
116 | } // namespace tvm |
117 | |
118 | #endif // TVM_TARGET_SOURCE_CODEGEN_CUDA_H_ |
119 | |