Skip to content
This repository has been archived by the owner on Jan 3, 2023. It is now read-only.

Commit

Permalink
Merge branch 'master' into r0.12
Browse files Browse the repository at this point in the history
  • Loading branch information
avijit-nervana committed Apr 10, 2019
2 parents cbbdf34 + a39cd8f commit a8d6db8
Show file tree
Hide file tree
Showing 31 changed files with 4,004 additions and 12 deletions.
10 changes: 10 additions & 0 deletions build_ngtf.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,11 @@ def main():
action="store")

parser.add_argument(
'--enable_variables_and_optimizers',
help="Ops like variable and optimizers are supported by nGraph in this version of the bridge\n",
action="store_true")

parser.add_argument(
'--use_grappler_optimizer',
help="Use Grappler optimizer instead of the optimization passes\n",
action="store_true")
Expand Down Expand Up @@ -264,6 +269,11 @@ def main():
else:
ngraph_tf_cmake_flags.extend(["-DNGRAPH_DISTRIBUTED_ENABLE=FALSE"])

if (arguments.enable_variables_and_optimizers):
ngraph_tf_cmake_flags.extend(["-DNGRAPH_TF_ENABLE_VARIABLES_AND_OPTIMIZERS=TRUE"])
else:
ngraph_tf_cmake_flags.extend(["-DNGRAPH_TF_ENABLE_VARIABLES_AND_OPTIMIZERS=FALSE"])

if (arguments.use_grappler_optimizer):
ngraph_tf_cmake_flags.extend(
["-DNGRAPH_TF_USE_GRAPPLER_OPTIMIZER=TRUE"])
Expand Down
39 changes: 36 additions & 3 deletions src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -43,18 +43,51 @@ set(SRC
ngraph_freshness_tracker.cc
ngraph_mark_for_clustering.cc
ngraph_rewrite_for_tracking.cc
ngraph_rewrite_pass.cc
ngraph_tracked_variable.cc
ngraph_utils.cc
tf_graphcycles.cc
tf_deadness_analysis.cc
version.cc
)

if(NGRAPH_TF_USE_GRAPPLER_OPTIMIZER)
message(STATUS "NGRAPH_TF_ENABLE_VARIABLES_AND_OPTIMIZERS: ${NGRAPH_TF_ENABLE_VARIABLES_AND_OPTIMIZERS}")

if(NGRAPH_TF_ENABLE_VARIABLES_AND_OPTIMIZERS)
# common files
list(REMOVE_ITEM SRC ngraph_capture_variables.cc)
list(APPEND SRC enable_variable_ops/ngraph_capture_variables.cc)

list(REMOVE_ITEM SRC ngraph_encapsulate_op.cc)
list(APPEND SRC enable_variable_ops/ngraph_encapsulate_op.cc)

list(REMOVE_ITEM SRC ngraph_rewrite_for_tracking.cc)
list(APPEND SRC enable_variable_ops/ngraph_rewrite_for_tracking.cc)

list(REMOVE_ITEM SRC ngraph_rewrite_pass.cc)
list(APPEND SRC enable_variable_ops/ngraph_rewrite_pass.cc)

list(REMOVE_ITEM SRC ngraph_tracked_variable.cc)
list(APPEND SRC enable_variable_ops/ngraph_tracked_variable.cc)

list(REMOVE_ITEM SRC ngraph_utils.cc)
list(APPEND SRC enable_variable_ops/ngraph_utils.cc)

# new files
list(APPEND SRC enable_variable_ops/ngraph_assign_op.cc)
list(APPEND SRC enable_variable_ops/ngraph_catalog.cc)
list(APPEND SRC enable_variable_ops/ngraph_enter_in_catalog.cc)
list(APPEND SRC enable_variable_ops/ngraph_replace_op_utilities.cc)
list(APPEND SRC enable_variable_ops/ngraph_replace_variable_modifiers.cc)
list(APPEND SRC enable_variable_ops/ngraph_variable_modifiers.cc)

endif()


if(NGRAPH_TF_USE_GRAPPLER_OPTIMIZER)
list(REMOVE_ITEM SRC ngraph_rewrite_pass.cc)
list(APPEND SRC grappler/ngraph_optimizer.cc)
add_definitions(-DNGRAPH_TF_USE_GRAPPLER_OPTIMIZER)
else()
list(APPEND SRC ngraph_rewrite_pass.cc)
endif()

message(STATUS "NGRAPH_TF_USE_GRAPPLER_OPTIMIZER: ${NGRAPH_TF_USE_GRAPPLER_OPTIMIZER}")
Expand Down
185 changes: 185 additions & 0 deletions src/enable_variable_ops/ngraph_assign_op.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,185 @@
/*******************************************************************************
* Copyright 2019 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "tensorflow/core/common_runtime/dma_helper.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
#include "tensorflow/core/lib/strings/strcat.h"

#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/tensor_types.h"
#include "tensorflow/core/platform/default/logging.h"

#include "ngraph/event_tracing.hpp"
#include "ngraph/runtime/backend.hpp"
#include "ngraph_catalog.h"
#include "ngraph_freshness_tracker.h"
#include "ngraph_timer.h"
#include "ngraph_utils.h"
#include "ngraph_var.h"

using namespace std;
namespace ng = ngraph;

namespace tensorflow {

namespace ngraph_bridge {

/* -------------------------------------------------
//
// NGraphAssignOp
//
---------------------------------------------------*/

// Computes *input[0] = input[1]
// Computes *input[0] = input[1]
//
// Bridge replacement for TF's Assign op. The variable's backing store is an
// nGraph tensor (held by NGraphVar); this kernel writes the new value into
// that tensor and optionally mirrors it back into the TF tensor when
// downstream TF ops need it (copy_to_tf_).
class NGraphAssignOp : public OpKernel {
 private:
  // True when downstream consumers only read the value; in that case a copy
  // back to TF does not require marking the nGraph tensor for re-sync.
  bool just_looking_;
  // True when the assigned value must also be copied into the TF tensor.
  bool copy_to_tf_;
  // Id of the encapsulated graph this variable belongs to; used as part of
  // the NGraphCatalog lookup keys.
  int ng_graph_id_;
  static int s_instance_count;
  // Unique per-instance id, used only to label trace events.
  int my_instance_id{0};

  // TODO(malikshr): Do we need these attributes, exist in TF Assign ops
  // use_exclusive_lock_, validate_shape_, relax_constraints_;

 public:
  explicit NGraphAssignOp(OpKernelConstruction* context)
      : OpKernel(context), just_looking_(false), copy_to_tf_(false) {
    OP_REQUIRES_OK(context, context->GetAttr("just_looking", &just_looking_));
    OP_REQUIRES_OK(context, context->GetAttr("copy_to_tf", &copy_to_tf_));
    OP_REQUIRES_OK(context,
                   context->GetAttr("ngraph_graph_id", &ng_graph_id_));

    NGRAPH_VLOG(4) << "NGraphAssign:: Constructor called for: " << def().name()
                   << ",just looking " << PrintBool(just_looking_)
                   << ",copy-to-tf " << PrintBool(copy_to_tf_) << " ,Graph ID "
                   << ng_graph_id_;

    // input(0) must be the variable ref, as in native TF Assign.
    OP_REQUIRES(context, IsRefType(context->input_type(0)),
                errors::InvalidArgument("lhs input needs to be a ref type"));
    my_instance_id = s_instance_count;
    s_instance_count++;
  }

  void Compute(OpKernelContext* context) override {
    std::ostringstream oss;
    oss << "Execute: Assign_" << my_instance_id << ": " << name();
    ngraph::Event event_compute(oss.str(), name(), "");

    NGRAPH_VLOG(4) << "NGraphAssign:: Compute called for: " << def().name()
                   << " ,just looking " << PrintBool(just_looking_)
                   << " ,copy-to-tf " << PrintBool(copy_to_tf_) << " ,Graph ID "
                   << ng_graph_id_;

    bool log_copies = false;
    OP_REQUIRES_OK(context, IsCopyLogEnabled(ng_graph_id_, log_copies));
    std::stringstream copy_log_str;
    copy_log_str << "KERNEL[" << type_string() << "]: " << name()
                 << " ,Copy_TF " << PrintBool(copy_to_tf_) << " ,Just_Looking "
                 << PrintBool(just_looking_) << "\n";
    int number_of_copies = 0;

    // OP_REQUIRES already performs the early return when the condition is
    // false, so no extra if-wrapper is needed around it.
    bool ref_exists = NGraphCatalog::ExistsInInputVariableSharedNameMap(
        ng_graph_id_, def().name(), 0);
    OP_REQUIRES(context, ref_exists,
                errors::Internal(
                    "Caught exception : RefInput to NGAssign not found \n"));
    string get_ref_var_name = NGraphCatalog::GetInputVariableSharedName(
        ng_graph_id_, def().name(), 0);

    // Look up the NGraphVar resource that owns the variable's nGraph tensor.
    NGraphVar* var;
    OP_REQUIRES_OK(context,
                   context->resource_manager()->Lookup<NGraphVar>(
                       context->resource_manager()->default_container(),
                       get_ref_var_name, &var));

    const Tensor& rhs = context->input(1);

    // We always return the input ref.
    context->forward_ref_input_to_ref_output(0, 0);

    // get the nGraphTensor
    shared_ptr<ngraph::runtime::Tensor> ng_tensor_to_assign = var->ng_tensor();

    // DO NOT CARE ABOUT SYNCING AS WE ARE ALWAYS SETTING THE NGTENSOR

    // Get input[1]: if the value was produced by an encapsulate op it is
    // already an nGraph tensor in the catalog — copy device-to-device;
    // otherwise copy the raw bytes of the TF tensor into the nGraph tensor.
    string valkey = to_string(ng_graph_id_) + "_" + def().input(1);
    bool valref_exists = NGraphCatalog::ExistsInEncapOutputTensorMap(valkey);
    if (valref_exists) {
      // Value is from encap
      NGRAPH_VLOG(4) << "NGraphAssign::Getting from catalog: " << valkey;
      auto ng_val = NGraphCatalog::GetTensorFromEncapOutputTensorMap(valkey);
      ng_tensor_to_assign->copy_from(*ng_val);
    } else {
      number_of_copies++;
      copy_log_str << " COPY_INP_VAL[0]";
      NGRAPH_VLOG(4) << "NGraphAssign::Getting from TF : " << valkey;
      void* tf_src_ptr = (void*)DMAHelper::base(&rhs);
      ng_tensor_to_assign->write(
          tf_src_ptr, 0, ng_tensor_to_assign->get_element_count() *
                             ng_tensor_to_assign->get_element_type().size());
    }

    // Hold the ref-input mutex while touching the TF-side tensor, matching
    // native Assign's locking discipline.
    mutex_lock l(*context->input_ref_mutex(0));
    Tensor old_lhs = context->mutable_input(0, /* lock_held */ true);

    if (copy_to_tf_) {
      number_of_copies++;
      copy_log_str << " COPY_TF ";
      ReadNGTensor(ng_tensor_to_assign, &old_lhs);

      if (!just_looking_) {
        // Some tf op might update the ng-tensor value so mark it stale
        copy_log_str << " SET_SYNC ";
        var->sync_ng_tensor(true);
      }
    }

    copy_log_str << " Number of copies " << number_of_copies << "\n";
    if (log_copies) {
      cout << copy_log_str.str();
    }

    // Unref Var: balances the refcount taken by resource_manager()->Lookup.
    var->Unref();
    event_compute.Stop();
    ngraph::Event::write_trace(event_compute);
  }
};

// Process-wide counter giving each NGraphAssignOp instance a unique id,
// used only to label event-tracing output (see my_instance_id).
int NGraphAssignOp::s_instance_count = 0;

// Interface mirrors TF's native Assign op (ref/value in, ref out, plus
// validate_shape/use_locking), extended with bridge-specific attributes:
// just_looking and copy_to_tf control TF<->nGraph tensor synchronization,
// and ngraph_graph_id identifies the owning encapsulated graph.
REGISTER_OP("NGraphAssign")
    .Input("ref: Ref(T)")
    .Input("value: T")
    .Output("output_ref: Ref(T)")
    .Attr("T: type")
    .Attr("validate_shape: bool = true")
    .Attr("use_locking: bool = true")
    .Attr("just_looking: bool = false")
    .Attr("copy_to_tf: bool = false")
    .Attr("ngraph_graph_id: int");

// CPU-only kernel registration; the actual data lives in the nGraph tensor
// held by the NGraphVar resource.
REGISTER_KERNEL_BUILDER(Name("NGraphAssign").Device(DEVICE_CPU),
                        NGraphAssignOp);

} // namespace ngraph_bridge

} // namespace tensorflow
97 changes: 97 additions & 0 deletions src/enable_variable_ops/ngraph_capture_variables.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
/*******************************************************************************
* Copyright 2017-2019 Intel Corporation
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*******************************************************************************/

#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/graph/node_builder.h"

#include "ngraph_api.h"
#include "ngraph_capture_variables.h"
#include "ngraph_replace_op_utilities.h"
#include "ngraph_utils.h"

using namespace std;

namespace tensorflow {

namespace ngraph_bridge {

//
// Utility function to check if placement on the NGRAPH device has been
// requested.
//
// FIXME(amprocte): stubbed out for now because NGRAPH device is gone.
// Currently returns true unconditionally, so every node in the graph is
// considered a capture candidate.
//
static bool NGraphPlacementRequested(const Node* node) { return true; }

//
// Main entry point for the variable-capture.
//
//
// Main entry point for the variable-capture.
//
// Walks the graph and replaces each TF variable-related op (Assign,
// AssignAdd/Sub, ApplyGradientDescent, VariableV2) with its nGraph
// counterpart, rewiring all incoming control edges and outgoing edges to
// the replacement node, then removes the originals.
//
// NOTE(review): skip_these_nodes is currently unused here — confirm whether
// callers expect these nodes to be exempt from capture.
//
Status CaptureVariables(Graph* graph, std::vector<string> skip_these_nodes) {
  // Maps a TF op type to (replacement op type, builder function that creates
  // the replacement node).
  const static std::map<
      const string,
      const pair<string,
                 function<Status(
                     Graph * graph, Node * node, Node * *replacement,
                     const string replacement_node_name,
                     const string replacement_op_type, const bool just_looking,
                     const bool outputs_ng_supported, const int graph_id,
                     const bool is_backend_set)>>>
      CAPTURE_REPLACE_OP_MAP{
          {"ApplyGradientDescent", std::make_pair("NGraphApplyGradientDescent",
                                                  ReplaceApplyGradientDescent)},
          {"Assign", std::make_pair("NGraphAssign", ReplaceAssign)},
          {"AssignAdd", std::make_pair("NGraphAssignAdd", ReplaceAssign)},
          {"AssignSub", std::make_pair("NGraphAssignSub", ReplaceAssign)},
          {"VariableV2", std::make_pair("NGraphVariable", ReplaceVariable)}};

  // Deletion is deferred until after the scan: removing nodes while
  // iterating op_nodes() would invalidate the traversal.
  std::vector<Node*> replaced_nodes;
  for (auto node : graph->op_nodes()) {
    if (NGraphPlacementRequested(node)) {
      auto itr = CAPTURE_REPLACE_OP_MAP.find(node->type_string());
      if (itr != CAPTURE_REPLACE_OP_MAP.end()) {
        NGRAPH_VLOG(1) << "Capturing: " << node->name();
        Node* replacement;

        // Create the replacement node
        TF_RETURN_IF_ERROR((itr->second.second)(graph, node, &replacement,
                                                node->name(), itr->second.first,
                                                false, false, 0, false));

        NGRAPH_VLOG(4) << "Replacing Node " << node->DebugString() << " with "
                       << replacement->DebugString();

        // Move all incoming control edges and all outgoing edges from the
        // original node onto the replacement.
        TF_RETURN_IF_ERROR(ReplaceInputControlEdges(graph, node, replacement));
        TF_RETURN_IF_ERROR(ReplaceOutputEdges(graph, node, replacement));

        replaced_nodes.push_back(node);
      }

    }  // end of checking NGraphPlacementRequested
  }    // end of looping through nodes in the graph

  for (auto node : replaced_nodes) {
    NGRAPH_VLOG(4) << "Removing: " << node->name();
    graph->RemoveNode(node);
  }

  return Status::OK();
}

} // namespace ngraph_bridge

} // namespace tensorflow
Loading

0 comments on commit a8d6db8

Please sign in to comment.