1 | /* Copyright 2022 The TensorFlow Authors. All Rights Reserved. |
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #ifndef TENSORFLOW_DTENSOR_CC_DTENSOR_DEVICE_H_ |
17 | #define TENSORFLOW_DTENSOR_CC_DTENSOR_DEVICE_H_ |
18 | |
19 | #include <string> |
20 | #include <unordered_map> |
21 | #include <vector> |
22 | |
23 | #include "absl/strings/string_view.h" |
24 | #include "tensorflow/c/eager/c_api_experimental.h" |
25 | |
26 | namespace tensorflow { |
27 | namespace dtensor { |
28 | |
// Configures a custom device which runs dtensor while executing
// operations on `underlying_devices`. Allocates `device_info` and fills
// `device`, which should then be passed to TFE_RegisterCustomDevice.
// This only affects eager execution.
//
// `device_name` should match the `device_name` argument later passed to
// TFE_RegisterCustomDevice, and is the name of the custom device itself
// (e.g. pass it to `tf.device` to place operations on it from Python).
void AllocateDTensorDevice(absl::string_view device_name,
                           TFE_CustomDevice* device, void** device_info);
39 | |
// Adds a mesh to the layout propagator indicated by `device_info`.
//
// `serialized_mesh` is a serialized Mesh proto.
//
// `is_async` indicates whether DTensor operations on this mesh will return
// immediately (with "non-ready" handles) or block until executed. This is
// exposed as an option for ease of debugging, and will typically be on.
//
// `is_host_mesh` indicates this is a CPU mesh used only for sea-of-donuts-style
// host collectives.
//
// Errors (e.g. an unparseable mesh) are reported through `status`.
void AddMesh(const std::string& serialized_mesh, void* device_info,
             bool is_async, bool is_host_mesh, TF_Status* status);
52 | |
// Sets a requested layout for outputs of all operations on this device.
// `serialized_layout` is a string serialization of the requested Layout.
// Errors are reported through `status`.
void ExperimentalSetDefaultLayout(const std::string& serialized_layout,
                                  void* device_info, TF_Status* status);
// Clears any default output layout previously requested via
// `ExperimentalSetDefaultLayout`.
void ExperimentalClearDefaultLayout(void* device_info, TF_Status* status);
57 | |
// TODO(b/175928457): remove once the bug is fixed.
// Sets a requested default mesh. `serialized_mesh` is a serialized Mesh
// proto. Errors are reported through `status`.
void ExperimentalSetDefaultMesh(const std::string& serialized_mesh,
                                void* device_info, TF_Status* status);
// Clears any default mesh previously requested via
// `ExperimentalSetDefaultMesh`.
void ExperimentalClearDefaultMesh(void* device_info, TF_Status* status);
63 | |
// Determines whether tensors with a shape previously associated with only one
// layout use that layout if nothing else can be inferred. `enabled` turns the
// policy on or off for the device indicated by `device_info`.
void SetSameShapePolicy(void* device_info, bool enabled);
67 | |
// Sets the global device ID-to-core ID mapping for the mesh named
// `mesh_name`. Global device IDs are equal to XLA replica IDs for the single
// XLA computation used by DTensor. Errors are reported through `status`.
//
// See the comment above Mesh::tpu_core_ids() for some nuances.
void SetTPUCoreIDs(const std::string& mesh_name,
                   const std::vector<int>& tpu_core_ids, void* device_info,
                   TF_Status* status);
75 | |
// Clears all mappings previously installed with `SetTPUCoreIDs`.
// TODO(b/187112276): Delete once we have the TPUCoreIDs live with Device.
void ClearTPUCoreIDs(void* device_info);
78 | |
// Returns TPU core locations when given a list of TPU core IDs.
// NOTE(review): presumably each inner vector is the location of the
// corresponding entry of `tpu_core_ids` — confirm against the implementation.
std::vector<std::vector<int>> TPUCoreIDsToLocations(
    TFE_Context* context, const std::vector<int>& tpu_core_ids,
    void* device_info);
83 | |
// Returns TPU core IDs when given a list of TPU core locations.
// NOTE(review): presumably each returned ID corresponds positionally to an
// entry of `tpu_core_locations` — confirm against the implementation.
std::vector<int> TPUCoreLocationsToIDs(
    TFE_Context* context,
    const std::vector<std::vector<int>>& tpu_core_locations, void* device_info);
88 | |
// Packs `num_inputs` `inputs` tensors into a single parallel tensor handle
// on the DTensor device. `string_layout` is a string serialization of the
// layout to attach to the packed tensor. Errors are reported through
// `status`.
TFE_TensorHandle* Pack(TFE_Context* context, int num_inputs,
                       TFE_TensorHandle** inputs,
                       const std::string& string_layout, void* device_info,
                       TF_Status* status);
94 | |
// Returns the raw components placed on each device of `input`'s mesh.
// Inverse of `Pack`. Errors are reported through `status`.
std::vector<TFE_TensorHandle*> Unpack(TFE_Context* context,
                                      TFE_TensorHandle* input,
                                      void* device_info, TF_Status* status);
99 | |
// Returns the layout of the dtensor `input`, serialized as a string.
// Errors are reported through `status`.
std::string FetchLayout(TFE_Context* context, TFE_TensorHandle* input,
                        void* device_info, TF_Status* status);
103 | |
// Packs `indices`, `values`, `shapes` tensors into a SparseTensorWithLayout.
// `string_layout` is a string serialization of the layout to attach.
// NOTE(review): presumably `num_inputs` is the element count of each of the
// `indices`, `values`, and `shapes` arrays — confirm against the
// implementation. Errors are reported through `status`.
TFE_TensorHandle* SparsePack(TFE_Context* context, int num_inputs,
                             TFE_TensorHandle** indices,
                             TFE_TensorHandle** values,
                             TFE_TensorHandle** shapes,
                             const std::string& string_layout,
                             void* device_info, TF_Status* status);
111 | |
// Returns whether `input` is a sparse dtensor. Used in `Unpack` at the python
// level to determine whether we should wrap component tensors back into a
// SparseTensor. Errors are reported through `status`.
bool IsSparseDTensor(TFE_Context* context, TFE_TensorHandle* input,
                     void* device_info, TF_Status* status);
117 | |
// Returns a dictionary with cache hit and cache miss information for the
// device's function cache. The hit count is mapped under key "hit", and the
// miss count is mapped under key "miss". Errors are reported through
// `status`.
std::unordered_map<std::string, int> GetFunctionCacheHitAndMissCount(
    TFE_Context* context, void* device_info, TF_Status* status);
123 | } // namespace dtensor |
124 | } // namespace tensorflow |
125 | |
126 | #endif // TENSORFLOW_DTENSOR_CC_DTENSOR_DEVICE_H_ |
127 | |