Commit cba5d66
Merge pull request #228 from deephealthproject/develop
Fix numerical stability issues
salvacarrion authored Dec 9, 2020
2 parents b858d9a + 306193f commit cba5d66
Showing 22 changed files with 497 additions and 221 deletions.
2 changes: 1 addition & 1 deletion examples/nn/1_mnist/7_mnist_conv.cpp
@@ -28,7 +28,7 @@ int main(int argc, char **argv) {
download_mnist();

// Settings
-int epochs = 5;
+int epochs = 10;
int batch_size = 100;
int num_classes = 10;

19 changes: 7 additions & 12 deletions examples/tensor/eddl_tests_dev.cpp
@@ -32,20 +32,15 @@ using namespace eddl;
int main(int argc, char **argv) {
cout << "Tests for development. Ignore." << endl;

-Tensor* t1 = new Tensor({12, INFINITY, NAN, -INFINITY, 0.0f, +INFINITY}, {2,3});
-// [
-// [12.00 inf nan]
-// [-inf 0.00 inf]
-// ]
+// Tensor* t1 = Tensor::load("/home/salvacarrion/Documents/Programming/C++/eddl/nan_tensor_lout_input.bin");
+Tensor *t1 = new Tensor({-114.67 ,-153.77 ,-122.57 ,-113.86 ,-141.96 ,-119.93 ,-116.40 ,-135.25 ,-105.31 ,-117.21}, {1, 10}, DEV_CPU);
t1->print(2);

-Tensor* r1 = t1->isfinite(); // returns new tensor
-
-r1->print(2); // Temp.
-// [
-// [1.00 0.00 0.00]
-// [0.00 1.00 0.00]
-// ]
+Tensor* t2 = Tensor::zeros_like(t1);
+// t2 = t1->exp();
+tensorNN::FullSoftmax(t1, t2, 1);
+
+t2->print(2);

cout << "Done!" << endl;

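The new dev test above is the commit's motivating case: all ten logits are large-magnitude negatives, so a naive softmax underflows every exponential to 0.0f and ends up dividing 0 by 0. A standalone sketch of the failure and of the max-shift cure (not part of the commit; the logits are copied from the test above):

// --- Editor's sketch, not part of this commit ---
#include <cmath>
#include <cstdio>

int main() {
    const float logits[10] = {-114.67f, -153.77f, -122.57f, -113.86f, -141.96f,
                              -119.93f, -116.40f, -135.25f, -105.31f, -117.21f};

    // Naive softmax: std::exp(-105.31f) ~ 2e-46 underflows to 0.0f,
    // so the denominator is 0.0f and every probability is 0/0 = NaN.
    float naive_den = 0.0f;
    for (float x : logits) naive_den += std::exp(x);
    std::printf("naive: den=%g p0=%g\n", naive_den, std::exp(logits[0]) / naive_den);

    // Stable softmax: shift by the row max (-105.31f) so the largest exponent is 0.
    float max_v = logits[0];
    for (float x : logits) if (x > max_v) max_v = x;
    float den = 0.0f;
    for (float x : logits) den += std::exp(x - max_v);
    std::printf("stable: p0=%g\n", std::exp(logits[0] - max_v) / den);
    return 0;
}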
6 changes: 6 additions & 0 deletions include/eddl/hardware/gpu/gpu_kernels.h
@@ -18,6 +18,12 @@
#include <curand.h>
#include <curand_kernel.h>

+// Same as in tensor.h
+#define GPU_MIN_FLOAT 1.17549e-38f; // Minimum finite value
+#define GPU_MAX_FLOAT 3.40282e+38f; // Maximum finite value
+#define GPU_EPS_FLOAT 1.19209e-07f; // Machine epsilon (the difference between 1 and the least value greater than 1 that is representable).
+#define GPU_LOWEST_FLOAT -3.40282e+38f; // For floating-point types: implementation-dependent; generally, the negative of max()
+

// GPU: Core (static)
//void gpu_transpose(Tensor *A, Tensor *B);
3 changes: 0 additions & 3 deletions include/eddl/hardware/gpu/nn/gpu_tensor_nn_kernels.h
@@ -15,9 +15,6 @@
#include <cuda.h>
#include <cstdio>

-// todo
-#define GPU_MAX_FLOAT 1000000.0f
-#define GPU_MIN_FLOAT -10000000.0f

// GPU: Activations
__global__ void relu(float *a,float *b,long int size);
14 changes: 13 additions & 1 deletion include/eddl/tensor/tensor.h
@@ -56,6 +56,17 @@
#define MAX_GPUS 8
#define MAX_FPGAS 8

+#define CPU_MIN_FLOAT 1.17549e-38f; // Minimum finite value
+#define CPU_MAX_FLOAT 3.40282e+38f; // Maximum finite value
+#define CPU_EPS_FLOAT 1.19209e-07f; // Machine epsilon (the difference between 1 and the least value greater than 1 that is representable).
+#define CPU_LOWEST_FLOAT -3.40282e+38f; // For floating-point types: implementation-dependent; generally, the negative of max()
+
+//const float CPU_MIN_FLOAT = std::numeric_limits<float>::min(); // Minimum finite value
+//const float CPU_MAX_FLOAT = std::numeric_limits<float>::max(); // Maximum finite value
+//const float CPU_EPS_FLOAT = std::numeric_limits<float>::epsilon(); // Machine epsilon (the difference between 1 and the least value greater than 1 that is representable).
+//const float CPU_LOWEST_FLOAT = -CPU_MAX_FLOAT; // For floating-point types: implementation-dependent; generally, the negative of max()
+
+
using namespace std;

// TODO: Remove this. Don't like here
@@ -2422,7 +2433,8 @@ class Tensor {
*/
Tensor* isnan();
static void isnan(Tensor *A, Tensor* B);


+bool anynan();

/**
* @brief Test element-wise for negative infinity.
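The new anynan() declaration complements the existing isnan() pair; by name and placement it reports whether any element is NaN without materializing the mask tensor that isnan() returns. Its implementation is not shown in this hunk; a hypothetical free-function sketch of such a check:

// --- Editor's sketch, not part of this commit; Tensor::anynan()'s real body is not in this diff ---
#include <cmath>

bool any_nan(const float* ptr, long size) {
    for (long i = 0; i < size; ++i) {
        if (std::isnan(ptr[i])) return true;  // early exit on the first NaN
    }
    return false;
}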
2 changes: 1 addition & 1 deletion src/hardware/cpu/cpu_comparison.cpp
@@ -10,7 +10,7 @@

#include "eddl/hardware/cpu/cpu_tensor.h"
#include "eddl/system_info.h"
-#include <limits>


// CPU: Logic functions: Truth value testing
bool cpu_all(Tensor *A){
3 changes: 2 additions & 1 deletion src/hardware/cpu/cpu_indexing.cpp
@@ -9,7 +9,8 @@


#include "eddl/hardware/cpu/cpu_tensor.h"
-#include <limits>



std::pair<unsigned int*, int> cpu_nonzero(Tensor *A){
// This can be improved:
12 changes: 5 additions & 7 deletions src/hardware/cpu/nn/cpu_activations.cpp
@@ -274,15 +274,15 @@ void cpu_full_softmax_batched_2d(Tensor *A, Tensor *B, bool stable){

// Numerical stability (opt.)
// stable => first value, no stable => 0.0f
-float max_value = 0.0f;
+float max_value = CPU_LOWEST_FLOAT;
if(stable){
for(int j=start; j<end; j++){
if (A->ptr[j] > max_value) { max_value = A->ptr[j]; }
}
}

// Numerator
-float denominator = 0.0f;
+float denominator = CPU_EPS_FLOAT;
for(int j=start; j<end; j++){
float value = ::expf(A->ptr[j] - max_value);
B->ptr[j] = value;
@@ -318,19 +318,17 @@ void cpu_full_softmax_nd(Tensor *A, Tensor *B, int axis, bool stable){

// Numerical stability (opt.)
// stable => first value, no stable => 0.0f
-float max_value = 0.0f;
+float max_value = CPU_LOWEST_FLOAT;
if (stable) {
for (int i = start_b; i <= end_b; i += inner_stride) {
if (A->ptr[i] > max_value) { max_value = A->ptr[i]; }
}
}

// Numerator
-float denominator = 0.0f;
+float denominator = CPU_EPS_FLOAT;
for (int i = start_b; i <= end_b; i += inner_stride) {
-// cout << i << endl;
-// cout << A->ptr[i] << " exp(x-max)=" << value << endl;
-float value = ::expf(A->ptr[i] - max_value);
+float value = ::expf(A->ptr[i] - max_value); // Highest number should be zero
B->ptr[i] = value;
denominator += value;
}
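Both CPU softmax variants receive the same two-part fix. Seeding the running maximum with 0.0f was wrong for all-negative rows (the shift stayed 0, every expf underflowed to 0.0f, and the division produced NaN); CPU_LOWEST_FLOAT makes the seed smaller than any input. Seeding the denominator with CPU_EPS_FLOAT instead of 0.0f guards the final division against 0/0. A condensed sketch of the fixed logic for one contiguous row, using the <cfloat> constants the macros approximate (not the exact EDDL function):

// --- Editor's sketch, not part of this commit ---
#include <cfloat>
#include <cmath>

void softmax_row(const float* a, float* b, int n) {
    float max_value = -FLT_MAX;                   // was 0.0f: wrong for all-negative rows
    for (int j = 0; j < n; ++j)
        if (a[j] > max_value) max_value = a[j];

    float denominator = FLT_EPSILON;              // was 0.0f: tiny guard against 0/0
    for (int j = 0; j < n; ++j) {
        float value = std::exp(a[j] - max_value); // largest exponent is exactly 0
        b[j] = value;
        denominator += value;
    }
    for (int j = 0; j < n; ++j)
        b[j] /= denominator;                      // normalization step, elided in the hunk above
}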
2 changes: 1 addition & 1 deletion src/hardware/cpu/nn/cpu_pool.cpp
@@ -31,7 +31,7 @@ void cpu_mpool2D(PoolDescriptor *D){
for(int j=-D->padcl; j<=D->ic+D->padcr-D->kc; j+=D->sc, p++) { // cols: left-right

// Get max value in window
-float max = std::numeric_limits<float>::min();
+float max = CPU_LOWEST_FLOAT;
for(int ki=0; ki<D->kr; ki++){ // rows (kernel): top-bottom
for(int kj=0; kj<D->kc; kj++) { // cols (kernel): left-right

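The old seed is the classic numeric_limits trap: std::numeric_limits<float>::min() is the smallest *positive* normal float (about 1.17e-38), not the most negative value, so a pooling window containing only negative activations never updated max and reported ~1.17e-38 instead of the true maximum. A standalone demonstration (not EDDL code):

// --- Editor's sketch, not part of this commit ---
#include <cstdio>
#include <limits>

int main() {
    float bad_seed  = std::numeric_limits<float>::min();     // ~1.17549e-38 (positive!)
    float good_seed = std::numeric_limits<float>::lowest();  // ~-3.40282e+38

    const float window[4] = {-0.5f, -2.0f, -1.25f, -0.75f};  // all-negative pool window
    float bad_max = bad_seed, good_max = good_seed;
    for (float v : window) {
        if (v > bad_max)  bad_max  = v;   // never fires: every v is negative
        if (v > good_max) good_max = v;
    }
    std::printf("bad=%g good=%g\n", bad_max, good_max);      // bad=1.17549e-38, good=-0.5
    return 0;
}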
1 change: 0 additions & 1 deletion src/hardware/gpu/nn/gpu_activations.cu
@@ -20,7 +20,6 @@
#include "eddl/tensor/tensor.h"
#include "eddl/descriptors/descriptors.h"

-#define PRECISION_FLOAT -std::numeric_limits<float>::max()


void gpu_relu(Tensor *A,Tensor *B){
12 changes: 6 additions & 6 deletions src/hardware/gpu/nn/gpu_activations_kernels.cu
@@ -274,17 +274,17 @@ __global__ void full_softmax_batched(float *A, float *B, bool stable, unsigned i

// Numerical stability (opt.)
// stable => first value, no stable => 0.0f
-float max_value = 0.0f;
+float max_value = GPU_LOWEST_FLOAT;
if(stable){
for(unsigned int j=start; j<end; j++){
if (A[j] > max_value) { max_value = A[j]; }
}
}

// Numerator
-float denominator = 0.0f;
+float denominator = GPU_EPS_FLOAT;
for(unsigned int j=start; j<end; j++){
-float value = expf(A[j] - max_value);
+float value = expf(A[j] - max_value); // Highest number should be zero
B[j] = value;
denominator += value;
}
@@ -334,17 +334,17 @@ __global__ void full_softmax_nd(float *A, float *B, bool stable, int n_samples,

// Numerical stability (opt.)
// stable => first value, no stable => 0.0f
-float max_value = 0.0f;
+float max_value = GPU_LOWEST_FLOAT;
if(stable){
for (int i = start_b; i <= end_b; i += inner_stride) {
if (A[i] > max_value) { max_value = A[i]; }
}
}

// Numerator
-float denominator = 0.0f;
+float denominator = GPU_EPS_FLOAT;
for (int i = start_b; i <= end_b; i += inner_stride) {
-float value = expf(A[i] - max_value);
+float value = expf(A[i] - max_value); // Highest number should be zero
B[i] = value;
denominator += value;
}
2 changes: 1 addition & 1 deletion src/hardware/gpu/nn/gpu_pool_kernels.cu
@@ -78,7 +78,7 @@ __global__ void maxpool2d(float* I, int batch,int irows,int icols, int idepth, i
// Check bounds
if (i <= max_i && j <= max_j){

-float max = GPU_MIN_FLOAT;
+float max = GPU_LOWEST_FLOAT;
//float max = I[i,j];
for (int ki = 0; ki < kr; ki++){ // rows (kernel): top-bottom
for (int kj = 0; kj < kc; kj++) { // cols (kernel): left-right
4 changes: 2 additions & 2 deletions src/initializers/initializer_glorot_uniform.cpp
@@ -38,7 +38,7 @@ void IGlorotUniform::apply(Tensor* params) {

params->fill_rand_signed_uniform_(1.0);

-float limits=sqrtf(6.0 / (float)(fin+fout));
+float limits=sqrtf(6.0f / (float)(fin+fout));

params->mult_(limits);
}
@@ -55,7 +55,7 @@ void IGlorotUniform::apply(Tensor* params) {

params->fill_rand_signed_uniform_(1.0);

-float limits=sqrtf(6.0 / (float)(fin+fout));
+float limits=sqrtf(6.0f / (float)(fin+fout));

params->mult_(limits);

2 changes: 1 addition & 1 deletion src/initializers/initializer_he_uniform.cpp
@@ -46,7 +46,7 @@ void IHeUniform::apply(Tensor* params) {

params->fill_rand_signed_uniform_(1.0);

-float limits=sqrtf(6.0 / (float)(fin));
+float limits=sqrtf(6.0f / (float)(fin));

params->mult_(limits);

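Both initializers scale a signed uniform draw in [-1, 1] by a limit: sqrt(6/(fan_in+fan_out)) for Glorot, sqrt(6/fan_in) for He. The only change is the 6.0f suffix: with 6.0 the division runs in double precision and is then narrowed for sqrtf, so the suffix keeps the whole expression in float. A worked sketch with hypothetical fan values:

// --- Editor's sketch, not part of this commit ---
#include <cmath>
#include <cstdio>

int main() {
    int fin = 784, fout = 128;  // hypothetical fan-in / fan-out of a dense layer

    float glorot = std::sqrt(6.0f / (float)(fin + fout));  // all-float arithmetic
    float he     = std::sqrt(6.0f / (float)fin);
    std::printf("glorot limit=%f  he limit=%f\n", glorot, he);
    // EDDL then draws uniform in [-1, 1] and multiplies by the limit:
    //   params->fill_rand_signed_uniform_(1.0); params->mult_(limit);
    return 0;
}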
2 changes: 1 addition & 1 deletion src/losses/loss_soft_cross_entropy.cpp
@@ -17,7 +17,7 @@
using namespace std;


-LSoftCrossEntropy::LSoftCrossEntropy() : Loss("soft_cross_entropy"){
+LSoftCrossEntropy::LSoftCrossEntropy() : Loss("softmax_cross_entropy"){
}


16 changes: 8 additions & 8 deletions src/net/net_api.cpp
@@ -586,10 +586,10 @@ void Net::print_loss(int b)

fprintf(stdout, "%s ( ", name.c_str());
if (losses.size()>=(k+1)) {
fprintf(stdout, "loss[%s]=%1.3f ", losses[k]->name.c_str(), total_loss[k] / (length*inferenced_samples));
fprintf(stdout, "loss[%s]=%1.4f ", losses[k]->name.c_str(), total_loss[k] / (length*inferenced_samples));
}
if (metrics.size()>=(k+1)) {
fprintf(stdout, "metric[%s]=%1.3f ", metrics[k]->name.c_str(), total_metric[k] / (length*inferenced_samples));
fprintf(stdout, "metric[%s]=%1.4f ", metrics[k]->name.c_str(), total_metric[k] / (length*inferenced_samples));
}

fprintf(stdout, ") -- ");
@@ -598,11 +598,11 @@
if ((flog_tr!=nullptr)&&(trmode)) {
fprintf(flog_tr, "%s ", name.c_str());
if (losses.size()>=(k+1)) {
fprintf(flog_tr, "loss[%s]=%1.3f ", losses[k]->name.c_str(), total_loss[k] / inferenced_samples);
fprintf(flog_tr, "loss[%s]=%1.4f ", losses[k]->name.c_str(), total_loss[k] / inferenced_samples);
}
if (metrics.size()>=(k+1)) {
if (metrics[k]->name!="none")
fprintf(flog_tr, "metric[%s]=%1.3f ", metrics[k]->name.c_str(), total_metric[k] / inferenced_samples);
fprintf(flog_tr, "metric[%s]=%1.4f ", metrics[k]->name.c_str(), total_metric[k] / inferenced_samples);
}

fprintf(flog_tr, " -- ");
@@ -612,11 +612,11 @@
if ((flog_ts!=nullptr)&&(!trmode)) {
fprintf(flog_ts, "%s ", name.c_str());
if (losses.size()>=(k+1)) {
fprintf(flog_ts, "loss[%s]=%1.3f ", losses[k]->name.c_str(), total_loss[k] / inferenced_samples);
fprintf(flog_ts, "loss[%s]=%1.4f ", losses[k]->name.c_str(), total_loss[k] / inferenced_samples);
}
if (metrics.size()>=(k+1)) {
if (metrics[k]->name!="none")
fprintf(flog_ts, "metric[%s]=%1.3f ", metrics[k]->name.c_str(), total_metric[k] / inferenced_samples);
fprintf(flog_ts, "metric[%s]=%1.4f ", metrics[k]->name.c_str(), total_metric[k] / inferenced_samples);
}

fprintf(flog_ts, " -- ");
@@ -839,14 +839,14 @@ void Net::fit(vtensor tin, vtensor tout, int batch, int epochs) {

high_resolution_clock::time_point e2 = high_resolution_clock::now();
duration<double> epoch_time_span = e2 - e1;
fprintf(stdout, "%1.3f secs/batch\r", epoch_time_span.count()/(j+1));
fprintf(stdout, "%1.4f secs/batch\r", epoch_time_span.count()/(j+1));
fflush(stdout);


}
high_resolution_clock::time_point e2 = high_resolution_clock::now();
duration<double> epoch_time_span = e2 - e1;
fprintf(stdout, "\n%1.3f secs/epoch\n", epoch_time_span.count());
fprintf(stdout, "\n%1.4f secs/epoch\n", epoch_time_span.count());
}
fflush(stdout);
}
2 changes: 1 addition & 1 deletion src/net/net_build.cpp
@@ -245,7 +245,7 @@ void Net::build(Optimizer *opt, vloss lo, vmetrics me, bool initialize) {
else losses = vloss(lo);

for (int i = 0; i < losses.size(); i++) {
if (losses[i]->name == "soft_cross_entropy") lout[i]->delta_bp = 1;
if (losses[i]->name == "softmax_cross_entropy") lout[i]->delta_bp = 1;
lout[i]->target = new Tensor(lout[i]->output->getShape(), dev);
}
// set metrics
11 changes: 7 additions & 4 deletions src/net/net_func.cpp
@@ -127,11 +127,14 @@ void Net::do_compute_loss() {
int p = 0;
for (int i = 0; i < lout.size(); i++, p += 2) {
// loss value
-if (losses.size()>=(i+1))
-fiterr[p] = losses[i]->value(lout[i]->target, lout[i]->output);
+if (losses.size()>=(i+1)){
+fiterr[p] = losses[i]->value(lout[i]->target, lout[i]->output);
+}

// metric value
-if (metrics.size()>=(i+1))
-fiterr[p + 1] = metrics[i]->value(lout[i]->target, lout[i]->output);
+if (metrics.size()>=(i+1)){
+fiterr[p + 1] = metrics[i]->value(lout[i]->target, lout[i]->output);
+}
}

if (VERBOSE) {
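Besides making the one-line if bodies explicit with braces, the loop documents the fiterr layout: one (loss, metric) pair per output layer, interleaved in a flat array with stride 2. A small sketch of that layout (an assumption read off the loop above, with stand-in values):

// --- Editor's sketch, not part of this commit ---
#include <cstdio>
#include <vector>

int main() {
    int n_out = 2;                                // hypothetical output-layer count
    std::vector<float> fiterr(2 * n_out, 0.0f);   // [loss0, metric0, loss1, metric1]
    for (int i = 0, p = 0; i < n_out; ++i, p += 2) {
        fiterr[p]     = 0.693f;                   // stand-in loss value
        fiterr[p + 1] = 0.500f;                   // stand-in metric value
    }
    for (int i = 0; i < n_out; ++i)
        std::printf("output %d: loss=%.3f metric=%.3f\n", i, fiterr[2 * i], fiterr[2 * i + 1]);
    return 0;
}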
(Diff truncated by the page: the remaining changed files were not loaded.)