1 | /* |
2 | * Licensed to the Apache Software Foundation (ASF) under one |
3 | * or more contributor license agreements. See the NOTICE file |
4 | * distributed with this work for additional information |
5 | * regarding copyright ownership. The ASF licenses this file |
6 | * to you under the Apache License, Version 2.0 (the |
7 | * "License"); you may not use this file except in compliance |
8 | * with the License. You may obtain a copy of the License at |
9 | * |
10 | * http://www.apache.org/licenses/LICENSE-2.0 |
11 | * |
12 | * Unless required by applicable law or agreed to in writing, |
13 | * software distributed under the License is distributed on an |
14 | * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
15 | * KIND, either express or implied. See the License for the |
16 | * specific language governing permissions and limitations |
17 | * under the License. |
18 | */ |
19 | |
20 | /*! |
21 | * \file codegen_metal.h |
22 | * \brief Generate Metal device code. |
23 | */ |
24 | #ifndef TVM_TARGET_SOURCE_CODEGEN_METAL_H_ |
25 | #define TVM_TARGET_SOURCE_CODEGEN_METAL_H_ |
26 | |
27 | #include <tvm/target/codegen.h> |
28 | |
29 | #include <string> |
30 | |
31 | #include "codegen_c.h" |
32 | |
33 | namespace tvm { |
34 | namespace codegen { |
35 | |
36 | class CodeGenMetal final : public CodeGenC { |
37 | public: |
38 | explicit CodeGenMetal(Target target); |
39 | // override print thread tag. |
40 | void PrintArgUnionDecl(); |
41 | void AddFunction(const PrimFunc& f); // NOLINT(*) |
42 | void InitFuncState(const PrimFunc& f) final; |
43 | void PrintStorageScope(const std::string& scope, std::ostream& os) final; // NOLINT(*) |
44 | void PrintStorageSync(const CallNode* op) final; // NOLINT(*) |
45 | void PrintType(DataType t, std::ostream& os) final; // NOLINT(*) |
46 | void BindThreadIndex(const IterVar& iv) final; // NOLINT(*) |
47 | // print load of single element |
48 | void PrintVecElemLoad(const std::string& vec, DataType t, int i, |
49 | std::ostream& os) final; // NOLINT(*) |
50 | // print store of single element. |
51 | void PrintVecElemStore(const std::string& vec, DataType t, int i, const std::string& value) final; |
52 | // overload visitor |
53 | void VisitExpr_(const BroadcastNode* op, std::ostream& os) final; // NOLINT(*) |
54 | void VisitExpr_(const CallNode* op, std::ostream& os) final; // NOLINT(*) |
55 | void VisitExpr_(const FloatImmNode* op, std::ostream& os) final; |
56 | // reuse parent's function. |
57 | using CodeGenC::PrintType; |
58 | |
59 | private: |
60 | int thread_index_bits_{32}; |
61 | Target target_; |
62 | }; |
63 | } // namespace codegen |
64 | } // namespace tvm |
65 | |
66 | #endif // TVM_TARGET_SOURCE_CODEGEN_METAL_H_ |
67 | |