Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

added horizontal and vertical flip augmentations #38

Draft
wants to merge 3 commits into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
159 changes: 157 additions & 2 deletions augmentation/augmentation.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@

#include <mlpack/methods/ann/layer/bilinear_interpolation.hpp>
#include <mlpack/core/util/to_lower.hpp>
#include <mlpack/core/data/split_data.hpp>
#include <boost/regex.hpp>

#ifndef MODELS_AUGMENTATION_HPP
Expand Down Expand Up @@ -105,7 +106,49 @@ class Augmentation
const size_t datapointDepth,
const std::string& augmentation);

private:
/**
* Applies horizontal flip transform to the splited dataset.
*
* @tparam DatasetType Datatype on which augmentation will be done.
*
* @param dataset Dataset on which augmentation will be applied.
* @param datapointWidth Width of a single data point i.e.
* Since each column represents a seperate data
* point.
* @param datapointHeight Height of a single data point.
* @param datapointDepth Depth of a single data point. For one 2-dimensional
* data point, set it to 1. Defaults to 1.
* @param augmentation String containing the transform.
*/
template<typename DatasetType>
void HorizontalFlipTransform(DatasetType& dataset,
const size_t datapointWidth,
const size_t datapointHeight,
const size_t datapointDepth,
const std::string& augmentation);

/**
* Applies verticle flip transform to the splited dataset.
*
* @tparam DatasetType Datatype on which augmentation will be done.
*
* @param dataset Dataset on which augmentation will be applied.
* @param datapointWidth Width of a single data point i.e.
* Since each column represents a seperate data
* point.
* @param datapointHeight Height of a single data point.
* @param datapointDepth Depth of a single data point. For one 2-dimensional
* data point, set it to 1. Defaults to 1.
* @param augmentation String containing the transform.
*/
template<typename DatasetType>
void VerticalFlipTransform(DatasetType& dataset,
const size_t datapointWidth,
const size_t datapointHeight,
const size_t datapointDepth,
const std::string& augmentation);

private:
/**
* Function to determine if augmentation has Resize function.
*
Expand Down Expand Up @@ -170,7 +213,119 @@ class Augmentation
}
}

//! Locally held augmentations and transforms that need to be applied.
/**
* Function to determine if augmentation has horizontal-flip function.
*
* @param augmentation Optional argument to check if a string has
* horizontal-flip substring.
*/
bool HasHorizontalFlipParam(const std::string& augmentation = "")
{
if (augmentation.length())
return augmentation.find("horizontal-flip") != std::string::npos;


// Search in augmentation vector.
for(size_t i=0; i<augmentations.size(); i++)
{
if(augmentations[i].find("horizontal-flip") != std::string::npos)
return true;
}
return false;

}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think we need this if we use a map instead.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes


/**
* Function to determine if augmentation has vertical-flip function.
*
* @param augmentation Optional argument to check if a string has
* vertical-flip substring.
*/
bool HasVerticalFlipParam(const std::string& augmentation = "")
{
if (augmentation.length())
return augmentation.find("vertical-flip") != std::string::npos;


// Search in augmentation vector.
for(size_t i=0; i<augmentations.size(); i++)
{
if(augmentations[i].find("vertical-flip") != std::string::npos)
return true;
}
return false;

}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above.


/**
* find if new data should horizontal flipped.
*
* @param ishortiflip Output is horizontal flipped or not.
* @param augmentation String from boolean value is extracted.
*/
void GetHorizontalFlipParam(bool& ishortiflip,
const std::string& augmentation)
{
if (!HasHorizontalFlipParam())
return;

ishortiflip = false;

// Use regex to find true or false.
boost::regex regex{"(?:true|false)"};

// Create an iterator to find matches.
boost::sregex_token_iterator matches(augmentation.begin(),
augmentation.end(), regex, 0), end;

size_t matchesCount = std::distance(matches, end);

if (matchesCount == 1)
{
ishortiflip = (*matches) == "true" ? true:false;
}
else
{
mlpack::Log::Fatal << "Invalid boolean value in " <<
augmentation << std::endl;
}
}
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If we remove the need for true and false, assume we need to flip if there is horizontal flip then we can remove this as well.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

yes


/**
* find if new data should vertical flipped.
*
* @param isvertiflip Output is verticalr flipped or not.
* @param augmentation String from boolean value is extracted.
*/
void GetVerticalFlipParam(bool& isvertiflip,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Same as above.

const std::string& augmentation)
{
if (!HasVerticalFlipParam())
return;

isvertiflip = false;

// Use regex to find true or false.
boost::regex regex{"^(?i)(true|false)$"};

// Create an iterator to find matches.
boost::sregex_token_iterator matches(augmentation.begin(),
augmentation.end(), regex, 0), end;

size_t matchesCount = std::distance(matches, end);

if (matchesCount == 1)
{
isvertiflip = (*matches) == "true" ? true:false;
}
else
{
mlpack::Log::Fatal << "Invalid boolean value in " <<
augmentation << std::endl;
}
}

//! Locally held augmentations and transforms that need to be applied.
std::vector<std::string> augmentations;

//! Locally held value of augmentation probability.
Expand Down
54 changes: 54 additions & 0 deletions augmentation/augmentation_impl.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,16 @@ void Augmentation::Transform(DatasetType& dataset,
this->ResizeTransform(dataset, datapointWidth, datapointHeight,
datapointDepth, augmentations[i]);
}
else if(this->HasHorizontalFlipParam(augmentations[i]))
{
this->HorizontalFlipTransform(dataset, datapointWidth, datapointHeight,
datapointDepth, augmentations[i]);
}
else if(this->HasVerticalFlipParam(augmentations[i]))
{
this->VerticalFlipTransform(dataset, datapointWidth, datapointHeight,
datapointDepth, augmentations[i]);
}
else
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Here I am adding directly in the for loop because I am not able to use augmentationMap. I tried to use like -

https://stackoverflow.com/questions/14419202/c-map-of-string-and-member-function-pointer but I am getting error. So if can suggest How should I use it? How should I store the function in the map?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hmm, I will give detailed review in a day or two. There are some more changes that we would need to make.

Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

okay

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Could you also share the exact error. I will take a look. Thanks.

{
mlpack::Log::Warn << "Unknown augmentation : \'" <<
Expand Down Expand Up @@ -70,4 +80,48 @@ void Augmentation::ResizeTransform(
dataset = std::move(output);
}

template<typename DatasetType>
void Augmentation::HorizontalFlipTransform(
DatasetType& dataset,
const size_t datapointWidth,
const size_t datapointHeight,
const size_t datapointDepth,
const std::string& augmentation)
{
bool ishortiflip = false;
// Get ishortiflip.
GetHorizontalFlipParam(ishortiflip, augmentation);
// if(!ishortiflip) return ;
Copy link
Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I will uncomment it. I have added the comment by mistake.


// We will use mlpack's split to split the dataset.
auto splitResult = mlpack::data::Split(dataset, augmentationProbability);
// We will use arma's fliplr to flip the columns.
std::get<1>(splitResult) = (arma::fliplr(std::get<1>(splitResult)));
dataset = arma::join_rows( std::get<0>(splitResult), std::get<1>(splitResult) );
dataset = std::move(dataset);

}

template<typename DatasetType>
void Augmentation::VerticalFlipTransform(
DatasetType& dataset,
const size_t datapointWidth,
const size_t datapointHeight,
const size_t datapointDepth,
const std::string& augmentation)
{
bool isvertiflip = false;
// Get isvertiflip.
GetVerticalFlipParam(isvertiflip, augmentation);
if(!isvertiflip) return ;

// We will use mlpack's split to split the dataset.
auto splitResult = mlpack::data::Split(dataset, augmentationProbability);
// We will use arma's flipud to flip the rows.
std::get<1>(splitResult) = (arma::flipud(std::get<1>(splitResult)));
dataset = arma::join_rows( std::get<0>(splitResult), std::get<1>(splitResult) );
dataset = std::move(dataset);

}

#endif
4 changes: 2 additions & 2 deletions tests/augmentation_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -42,8 +42,8 @@ BOOST_AUTO_TEST_CASE(ResizeAugmentationTest)
input.zeros(inputWidth * inputHeight * depth, 2);

// Rectangular input to sqaure output.
std::vector<std::string> augmentationVector = {"horizontal-flip",
"resize : 8"};
std::vector<std::string> augmentationVector = {"horizontal-flip : true",
"resize : 8", "resize : 8"};
Augmentation augmentation2(augmentationVector, 0.2);

// Resize function called.
Expand Down