1 | #include <torch/csrc/distributed/c10d/Utils.hpp> |
---|---|
2 | |
3 | #include <algorithm> |
4 | #include <cstring> |
5 | #include <memory> |
6 | #include <string> |
7 | #include <thread> |
8 | |
9 | namespace c10d { |
10 | |
// Returns the value of the environment variable `env_var_name`, or the
// placeholder string "N/A" when the variable is not set. A variable that is
// set to an empty value yields an empty string, not "N/A".
std::string parse_env(const char* env_var_name) {
  // std::getenv returns nullptr when the variable is absent from the
  // environment; the pointed-to storage is copied into the returned string.
  const char* const raw = std::getenv(env_var_name);
  return raw == nullptr ? std::string("N/A") : std::string(raw);
}
19 | |
20 | std::vector<at::Tensor> getTensorShapes( |
21 | const std::vector<at::Tensor>& tensors) { |
22 | std::vector<at::Tensor> shapeTensors; |
23 | shapeTensors.reserve(tensors.size()); |
24 | for (const auto& tensor : tensors) { |
25 | // Use `at::tensor()` to copy the data underlying `sizes()` since it may be |
26 | // released elsewhere. |
27 | at::Tensor shapesTensor = |
28 | at::tensor(tensor.sizes(), at::TensorOptions().dtype(at::kLong)); |
29 | shapeTensors.emplace_back(std::move(shapesTensor)); |
30 | } |
31 | return shapeTensors; |
32 | } |
33 | |
34 | } // namespace c10d |
35 |