1/*
2 * Licensed to the Apache Software Foundation (ASF) under one
3 * or more contributor license agreements. See the NOTICE file
4 * distributed with this work for additional information
5 * regarding copyright ownership. The ASF licenses this file
6 * to you under the Apache License, Version 2.0 (the
7 * "License"); you may not use this file except in compliance
8 * with the License. You may obtain a copy of the License at
9 *
10 * http://www.apache.org/licenses/LICENSE-2.0
11 *
12 * Unless required by applicable law or agreed to in writing,
13 * software distributed under the License is distributed on an
14 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
15 * KIND, either express or implied. See the License for the
16 * specific language governing permissions and limitations
17 * under the License.
18 */
19
20/*!
21 * \file codegen_cuda.h
22 * \brief Utility to generate cuda code
23 */
24#ifndef TVM_TARGET_SOURCE_CODEGEN_CUDA_H_
25#define TVM_TARGET_SOURCE_CODEGEN_CUDA_H_
26
27#include <tvm/target/codegen.h>
28#include <tvm/tir/expr.h>
29#include <tvm/tir/op.h>
30
31#include <string>
32#include <unordered_map>
33
34#include "codegen_c.h"
35
36namespace tvm {
37namespace codegen {
38
39class CodeGenCUDA final : public CodeGenC {
40 public:
41 CodeGenCUDA();
42 void Init(bool output_ssa);
43 std::string Finish();
44 bool need_include_path() {
45 return (enable_fp16_ || enable_bf16_ || enable_int8_ || need_math_constants_h_ || need_mma_h_);
46 }
47 // override behavior
48 void PrintFuncPrefix(std::ostream& os) final;
49 void PrintExtraAttrs(const PrimFunc& f) final;
50 void VisitStmt_(const ForNode* op) final;
51 void PrintStorageSync(const CallNode* op) final;
52 void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*)
53 void PrintVecBinaryOp(const std::string& op, DataType t, PrimExpr lhs, PrimExpr rhs,
54 std::ostream& os) final; // NOLINT(*)
55 void PrintType(DataType t, std::ostream& os) final; // NOLINT(*)
56 void PrintVecElemLoad(const std::string& vec, DataType t, int i,
57 std::ostream& os) final; // NOLINT(*)
58 void PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final;
59 void BindThreadIndex(const IterVar& iv) final; // NOLINT(*)
60 void PrintVecElemLoadExpr(DataType t, int i, const std::string& value, std::ostream& os) final;
61 std::string CastFromTo(std::string value, DataType from, DataType target) final;
62 // overload visitor
63 void VisitExpr_(const RampNode* op, std::ostream& os) final; // NOLINT(*)
64 void VisitExpr_(const ShuffleNode* op, std::ostream& os) final; // NOLINT(*)
65 void VisitExpr_(const SelectNode* op, std::ostream& os) final; // NOLINT(*)
66 void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*)
67 void VisitExpr_(const FloatImmNode* op, std::ostream& os) final;
68 void VisitExpr_(const CallNode* op, std::ostream& os) final;
69 void VisitExpr_(const CastNode* op, std::ostream& os) final;
70 void VisitStmt_(const EvaluateNode* op) final;
71 void VisitStmt_(const AllocateNode* op) final;
72 void VisitStmt_(const AttrStmtNode* op) final;
73
74 protected:
75 void PrintCallExtern(Type ret_type, String global_symbol, const Array<PrimExpr>& args,
76 bool skip_first_arg, std::ostream& os) final; // NOLINT(*)
77
78 private:
79 // Handle volatile loads
80 void HandleVolatileLoads(const std::string& value, const BufferLoadNode* op,
81 std::ostream& os) final;
82
83 // Whether scope such as "__shared__" or "__constant__" is part of type.
84 bool IsScopePartOfType() const final { return false; }
85
86 // Whether global barrier is needed.
87 bool need_global_barrier_{false};
88 // Global barrier state
89 std::string vid_global_barrier_state_;
90 // Global barrier expected node.
91 std::string vid_global_barrier_expect_;
92 // whether enable fp16
93 bool enable_fp16_{false};
94 // whether enable bf16
95 bool enable_bf16_{false};
96 // whether enable int8
97 bool enable_int8_{false};
98 // whether enable warp shuffle intrinsics
99 bool enable_warp_shuffle_{false};
100 // whether need math_constants.h
101 bool need_math_constants_h_{false};
102 // whether need mma.h
103 bool need_mma_h_{false};
104 // Op attribute map
105 OpAttrMap<bool> op_need_warp_shuffle_ = Op::GetAttrMap<bool>("cuda.need_warp_shuffle");
106
107 std::unordered_map<const VarNode*, std::string> fragment_shapes;
108 std::unordered_map<const VarNode*, std::string> fragment_layouts;
109 friend void PrintConst(const FloatImmNode* op, std::ostream& os, CodeGenCUDA* p);
110 void PrintWmmaScope(const std::string& scope, DataType t, const VarNode* variable,
111 std::ostream& os);
112 int32_t GetWmmaFragmentSize(const std::string& scope, const VarNode* variable, int32_t size);
113};
114
115} // namespace codegen
116} // namespace tvm
117
118#endif // TVM_TARGET_SOURCE_CODEGEN_CUDA_H_
119