-
Notifications
You must be signed in to change notification settings - Fork 40
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
added horizontal and vertical flip augmentations #38
base: master
Are you sure you want to change the base?
Changes from 1 commit
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -12,6 +12,7 @@ | |
|
||
#include <mlpack/methods/ann/layer/bilinear_interpolation.hpp> | ||
#include <mlpack/core/util/to_lower.hpp> | ||
#include <mlpack/core/data/split_data.hpp> | ||
#include <boost/regex.hpp> | ||
|
||
#ifndef MODELS_AUGMENTATION_HPP | ||
|
@@ -105,7 +106,49 @@ class Augmentation | |
const size_t datapointDepth, | ||
const std::string& augmentation); | ||
|
||
private: | ||
/** | ||
* Applies horizontal flip transform to the splited dataset. | ||
* | ||
* @tparam DatasetType Datatype on which augmentation will be done. | ||
* | ||
* @param dataset Dataset on which augmentation will be applied. | ||
* @param datapointWidth Width of a single data point i.e. | ||
* Since each column represents a seperate data | ||
* point. | ||
* @param datapointHeight Height of a single data point. | ||
* @param datapointDepth Depth of a single data point. For one 2-dimensional | ||
* data point, set it to 1. Defaults to 1. | ||
* @param augmentation String containing the transform. | ||
*/ | ||
template<typename DatasetType> | ||
void HorizontalFlipTransform(DatasetType& dataset, | ||
const size_t datapointWidth, | ||
const size_t datapointHeight, | ||
const size_t datapointDepth, | ||
const std::string& augmentation); | ||
|
||
/** | ||
* Applies verticle flip transform to the splited dataset. | ||
* | ||
* @tparam DatasetType Datatype on which augmentation will be done. | ||
* | ||
* @param dataset Dataset on which augmentation will be applied. | ||
* @param datapointWidth Width of a single data point i.e. | ||
* Since each column represents a seperate data | ||
* point. | ||
* @param datapointHeight Height of a single data point. | ||
* @param datapointDepth Depth of a single data point. For one 2-dimensional | ||
* data point, set it to 1. Defaults to 1. | ||
* @param augmentation String containing the transform. | ||
*/ | ||
template<typename DatasetType> | ||
void VerticalFlipTransform(DatasetType& dataset, | ||
const size_t datapointWidth, | ||
const size_t datapointHeight, | ||
const size_t datapointDepth, | ||
const std::string& augmentation); | ||
|
||
private: | ||
/** | ||
* Function to determine if augmentation has Resize function. | ||
* | ||
|
@@ -170,7 +213,119 @@ class Augmentation | |
} | ||
} | ||
|
||
//! Locally held augmentations and transforms that need to be applied. | ||
/** | ||
* Function to determine if augmentation has horizontal-flip function. | ||
* | ||
* @param augmentation Optional argument to check if a string has | ||
* horizontal-flip substring. | ||
*/ | ||
bool HasHorizontalFlipParam(const std::string& augmentation = "") | ||
{ | ||
if (augmentation.length()) | ||
return augmentation.find("horizontal-flip") != std::string::npos; | ||
|
||
|
||
// Search in augmentation vector. | ||
for(size_t i=0; i<augmentations.size(); i++) | ||
{ | ||
if(augmentations[i].find("horizontal-flip") != std::string::npos) | ||
return true; | ||
} | ||
return false; | ||
|
||
} | ||
|
||
/** | ||
* Function to determine if augmentation has vertical-flip function. | ||
* | ||
* @param augmentation Optional argument to check if a string has | ||
* vertical-flip substring. | ||
*/ | ||
bool HasVerticalFlipParam(const std::string& augmentation = "") | ||
{ | ||
if (augmentation.length()) | ||
return augmentation.find("vertical-flip") != std::string::npos; | ||
|
||
|
||
// Search in augmentation vector. | ||
for(size_t i=0; i<augmentations.size(); i++) | ||
{ | ||
if(augmentations[i].find("vertical-flip") != std::string::npos) | ||
return true; | ||
} | ||
return false; | ||
|
||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as above. |
||
|
||
/** | ||
* find if new data should horizontal flipped. | ||
* | ||
* @param ishortiflip Output is horizontal flipped or not. | ||
* @param augmentation String from boolean value is extracted. | ||
*/ | ||
void GetHorizontalFlipParam(bool& ishortiflip, | ||
const std::string& augmentation) | ||
{ | ||
if (!HasHorizontalFlipParam()) | ||
return; | ||
|
||
ishortiflip = false; | ||
|
||
// Use regex to find true or false. | ||
boost::regex regex{"(?:true|false)"}; | ||
|
||
// Create an iterator to find matches. | ||
boost::sregex_token_iterator matches(augmentation.begin(), | ||
augmentation.end(), regex, 0), end; | ||
|
||
size_t matchesCount = std::distance(matches, end); | ||
|
||
if (matchesCount == 1) | ||
{ | ||
ishortiflip = (*matches) == "true" ? true:false; | ||
} | ||
else | ||
{ | ||
mlpack::Log::Fatal << "Invalid boolean value in " << | ||
augmentation << std::endl; | ||
} | ||
} | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. If we remove the need for true and false, assume we need to flip if there is horizontal flip then we can remove this as well. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. yes |
||
|
||
/** | ||
* find if new data should vertical flipped. | ||
* | ||
* @param isvertiflip Output is verticalr flipped or not. | ||
* @param augmentation String from boolean value is extracted. | ||
*/ | ||
void GetVerticalFlipParam(bool& isvertiflip, | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Same as above. |
||
const std::string& augmentation) | ||
{ | ||
if (!HasVerticalFlipParam()) | ||
return; | ||
|
||
isvertiflip = false; | ||
|
||
// Use regex to find true or false. | ||
boost::regex regex{"^(?i)(true|false)$"}; | ||
|
||
// Create an iterator to find matches. | ||
boost::sregex_token_iterator matches(augmentation.begin(), | ||
augmentation.end(), regex, 0), end; | ||
|
||
size_t matchesCount = std::distance(matches, end); | ||
|
||
if (matchesCount == 1) | ||
{ | ||
isvertiflip = (*matches) == "true" ? true:false; | ||
} | ||
else | ||
{ | ||
mlpack::Log::Fatal << "Invalid boolean value in " << | ||
augmentation << std::endl; | ||
} | ||
} | ||
|
||
//! Locally held augmentations and transforms that need to be applied. | ||
std::vector<std::string> augmentations; | ||
|
||
//! Locally held value of augmentation probability. | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -38,6 +38,16 @@ void Augmentation::Transform(DatasetType& dataset, | |
this->ResizeTransform(dataset, datapointWidth, datapointHeight, | ||
datapointDepth, augmentations[i]); | ||
} | ||
else if(this->HasHorizontalFlipParam(augmentations[i])) | ||
{ | ||
this->HorizontalFlipTransform(dataset, datapointWidth, datapointHeight, | ||
datapointDepth, augmentations[i]); | ||
} | ||
else if(this->HasVerticalFlipParam(augmentations[i])) | ||
{ | ||
this->VerticalFlipTransform(dataset, datapointWidth, datapointHeight, | ||
datapointDepth, augmentations[i]); | ||
} | ||
else | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Here I am adding directly in the https://stackoverflow.com/questions/14419202/c-map-of-string-and-member-function-pointer but I am getting error. So if can suggest How should I use it? How should I store the function in the map? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Hmm, I will give detailed review in a day or two. There are some more changes that we would need to make. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. okay There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you also share the exact error. I will take a look. Thanks. |
||
{ | ||
mlpack::Log::Warn << "Unknown augmentation : \'" << | ||
|
@@ -70,4 +80,48 @@ void Augmentation::ResizeTransform( | |
dataset = std::move(output); | ||
} | ||
|
||
template<typename DatasetType> | ||
void Augmentation::HorizontalFlipTransform( | ||
DatasetType& dataset, | ||
const size_t datapointWidth, | ||
const size_t datapointHeight, | ||
const size_t datapointDepth, | ||
const std::string& augmentation) | ||
{ | ||
bool ishortiflip = false; | ||
// Get ishortiflip. | ||
GetHorizontalFlipParam(ishortiflip, augmentation); | ||
// if(!ishortiflip) return ; | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I will uncomment it. I have added the comment by mistake. |
||
|
||
// We will use mlpack's split to split the dataset. | ||
auto splitResult = mlpack::data::Split(dataset, augmentationProbability); | ||
// We will use arma's fliplr to flip the columns. | ||
std::get<1>(splitResult) = (arma::fliplr(std::get<1>(splitResult))); | ||
dataset = arma::join_rows( std::get<0>(splitResult), std::get<1>(splitResult) ); | ||
dataset = std::move(dataset); | ||
|
||
} | ||
|
||
template<typename DatasetType> | ||
void Augmentation::VerticalFlipTransform( | ||
DatasetType& dataset, | ||
const size_t datapointWidth, | ||
const size_t datapointHeight, | ||
const size_t datapointDepth, | ||
const std::string& augmentation) | ||
{ | ||
bool isvertiflip = false; | ||
// Get isvertiflip. | ||
GetVerticalFlipParam(isvertiflip, augmentation); | ||
if(!isvertiflip) return ; | ||
|
||
// We will use mlpack's split to split the dataset. | ||
auto splitResult = mlpack::data::Split(dataset, augmentationProbability); | ||
// We will use arma's flipud to flip the rows. | ||
std::get<1>(splitResult) = (arma::flipud(std::get<1>(splitResult))); | ||
dataset = arma::join_rows( std::get<0>(splitResult), std::get<1>(splitResult) ); | ||
dataset = std::move(dataset); | ||
|
||
} | ||
|
||
#endif |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think we need this if we use a map instead.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
yes