1 | /* Copyright 2015 The TensorFlow Authors. All Rights Reserved. |
---|---|
2 | |
3 | Licensed under the Apache License, Version 2.0 (the "License"); |
4 | you may not use this file except in compliance with the License. |
5 | You may obtain a copy of the License at |
6 | |
7 | http://www.apache.org/licenses/LICENSE-2.0 |
8 | |
9 | Unless required by applicable law or agreed to in writing, software |
10 | distributed under the License is distributed on an "AS IS" BASIS, |
11 | WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | See the License for the specific language governing permissions and |
13 | limitations under the License. |
14 | ==============================================================================*/ |
15 | |
16 | #include "tensorflow/core/common_runtime/costmodel_manager.h" |
17 | #include "tensorflow/core/lib/gtl/map_util.h" |
18 | |
19 | namespace tensorflow { |
20 | |
21 | namespace { |
22 | |
23 | static const string kCostModelLogTag = "COST_MODEL"; |
24 | |
25 | } // namespace |
26 | |
27 | CostModelManager::~CostModelManager() { |
28 | for (auto it : cost_models_) { |
29 | delete it.second; |
30 | } |
31 | } |
32 | |
33 | CostModel* CostModelManager::FindOrCreateCostModel(const Graph* graph) { |
34 | mutex_lock l(mu_); |
35 | auto it = cost_models_.find(graph); |
36 | if (it != cost_models_.end()) { |
37 | return it->second; |
38 | } |
39 | CostModel* cost_model = new CostModel(false); |
40 | cost_model->InitFromGraph(*graph); |
41 | cost_models_.emplace(graph, cost_model); |
42 | return cost_model; |
43 | } |
44 | |
45 | bool CostModelManager::RemoveCostModelForGraph(const Graph* graph) { |
46 | mutex_lock l(mu_); |
47 | auto itr = cost_models_.find(graph); |
48 | if (itr == cost_models_.end()) { |
49 | return false; |
50 | } |
51 | delete itr->second; |
52 | cost_models_.erase(graph); |
53 | return true; |
54 | } |
55 | |
56 | Status CostModelManager::AddToCostGraphDef(const Graph* graph, |
57 | CostGraphDef* cost_graph) { |
58 | mutex_lock l(mu_); |
59 | // Get the cost model for the graph. |
60 | auto it = cost_models_.find(graph); |
61 | if (it == cost_models_.end()) { |
62 | return errors::InvalidArgument("The cost model graph doesn't exist."); |
63 | } |
64 | CostModel* cost_model = it->second; |
65 | cost_model->AddToCostGraphDef(graph, cost_graph); |
66 | return OkStatus(); |
67 | } |
68 | |
69 | } // namespace tensorflow |
70 |