1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_DTENSOR_CC_TPU_SYSTEM_INTERFACE_H_ |
17 | #define TENSORFLOW_DTENSOR_CC_TPU_SYSTEM_INTERFACE_H_ |
18 | |
#include <memory>
#include <vector>

#include "absl/time/time.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/status.h"
26 | |
// Forward declare TFE_Context so this interface does not depend on the C API
// headers. Only pointers to TFE_Context are used in this header, so an
// incomplete type is sufficient. The C-style `typedef struct X X;` spelling is
// kept, presumably to match how the C API itself declares the type —
// NOTE(review): confirm before changing it to a `using` alias.
typedef struct TFE_Context TFE_Context;
29 | |
30 | namespace tensorflow { |
31 | namespace dtensor { |
32 | |
// DTensor TPU ops by default use the stream_executor-based TPU runtime.
// This class defines what an alternative runtime (e.g. TFRT) needs to be
// capable of to replace the default runtime.
//
// An implementation is installed with SetPreferredTpuSystem() and looked up
// with GetPreferredTpuSystem(). Every member is pure virtual. Changing the
// order or signatures of the virtual members below is an ABI change for
// implementations that are built separately.
class TpuSystemInterface {
 public:
  virtual ~TpuSystemInterface() = default;

  // Initializes the TPU system and returns the resulting Status.
  //
  // `ctx`: kernel context of the op that drives initialization.
  // `rmgr`: resource manager made available to the implementation.
  //   NOTE(review): presumably where TPU runtime state is registered —
  //   confirm against the implementations.
  // `retry_timeout`: time budget for retrying initialization. The exact retry
  //   semantics are defined by the implementation.
  // `core_id_output_vec`: output parameter. Presumably filled with the TPU
  //   core IDs discovered during initialization — TODO confirm.
  virtual Status Initialize(OpKernelContext* ctx, ResourceMgr* rmgr,
                            absl::Duration retry_timeout,
                            std::vector<int32>* core_id_output_vec) = 0;

  // Shuts down the TPU system and returns the resulting Status.
  virtual Status Shutdown() = 0;

  // Converts TPU core IDs into core locations. Each location is returned as a
  // vector of ints. This method has no error channel (it returns a plain
  // vector, not a Status).
  // NOTE(review): the result is presumably parallel to `tpu_core_ids` (one
  // location per ID, same order) — confirm against the implementations.
  virtual std::vector<std::vector<int>> TPUCoreIDsToLocations(
      TFE_Context* context, const std::vector<int>& tpu_core_ids) = 0;

  // Converts TPU core locations (one vector of ints per core) into core IDs.
  // This method has no error channel.
  // NOTE(review): presumably the inverse of TPUCoreIDsToLocations, with the
  // result parallel to `tpu_core_locations` — confirm.
  virtual std::vector<int> TPUCoreLocationsToIDs(
      TFE_Context* context,
      const std::vector<std::vector<int>>& tpu_core_locations) = 0;
};
53 | |
54 | // Sets a TPU system for DTensor to initialize and shut down the TPU mesh. |
55 | // This function takes over the ownership of `tpu_system`. |
56 | void SetPreferredTpuSystem(TpuSystemInterface* tpu_system); |
57 | |
// Returns the currently set preferred TPU system, nullptr if none.
//
// The returned pointer is non-owning: ownership was handed over through
// SetPreferredTpuSystem() and is not passed back to the caller, so the caller
// must not delete it. NOTE(review): the pointer's lifetime (e.g. whether a
// later SetPreferredTpuSystem call invalidates it) and the thread-safety of
// this accessor are defined by the implementation, not here — confirm before
// caching the result.
TpuSystemInterface* GetPreferredTpuSystem();
60 | |
61 | } // namespace dtensor |
62 | } // namespace tensorflow |
63 | |
64 | #endif // TENSORFLOW_DTENSOR_CC_TPU_SYSTEM_INTERFACE_H_ |
65 | |