Skip to content

Commit

Permalink
conflict (#29498)
Browse files Browse the repository at this point in the history
  • Loading branch information
cryoco authored Dec 8, 2020
1 parent 6bfc572 commit d5ff367
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 6 deletions.
20 changes: 14 additions & 6 deletions paddle/fluid/inference/tensorrt/convert/hard_swish_op.cc
Original file line number Diff line number Diff line change
Expand Up @@ -65,13 +65,21 @@ class HardSwishOpConverter : public OpConverter {
const float offset = op_desc.HasAttr("offset")
? BOOST_GET_CONST(float, op_desc.GetAttr("offset"))
: 3.0f;

nvinfer1::ILayer* layer = nullptr;

plugin::HardSwishPlugin* plugin =
new plugin::HardSwishPlugin(threshold, scale, offset);
layer = engine_->AddPlugin(&input, input_num, plugin);

if (threshold == scale) {
auto* hsig_layer = TRT_ENGINE_ADD_LAYER(
engine_, Activation, *input, nvinfer1::ActivationType::kHARD_SIGMOID);
hsig_layer->setAlpha(1.0 / scale);
hsig_layer->setBeta(offset / scale);
nvinfer1::IElementWiseLayer* eltwise_layer = TRT_ENGINE_ADD_LAYER(
engine_, ElementWise, *input, *(hsig_layer->getOutput(0)),
nvinfer1::ElementWiseOperation::kPROD);
layer = eltwise_layer;
} else {
plugin::HardSwishPlugin* plugin =
new plugin::HardSwishPlugin(threshold, scale, offset);
layer = engine_->AddPlugin(&input, input_num, plugin);
}
auto output_name = op_desc.Output("Out")[0];
RreplenishLayerAndOutput(layer, "hard_swish", {output_name}, test_mode);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -341,6 +341,12 @@ def append_act(self, x):
return fluid.layers.hard_swish(x)


class TensorRTSubgraphPassHardSwishPluginTest(
TensorRTSubgraphPassActivationTest):
def append_act(self, x):
return fluid.layers.hard_swish(x, threshold=4.0, scale=8.0)


class TensorRTSubgraphPassHardSigmoidTest(TensorRTSubgraphPassActivationTest):
def append_act(self, x):
return fluid.layers.hard_sigmoid(x)
Expand Down

0 comments on commit d5ff367

Please sign in to comment.