Experimenting with different techniques for reproducibility in deep learning.
The project deals with the following:
- pix2piXAI: Generating class-specific visualizations from input-specific visualizations such as Grad-CAM, Saliency Maps and SHAP, using a pix2pix Generative Adversarial Network (GAN)
- braXAI: Interpreting braai, which performs Real-Bogus classification for the Zwicky Transient Facility (ZTF) using deep learning
- Classification of Periodic Variables in the Catalina Real-Time Transient Survey (CRTS) using interpretable Convolutional Neural Networks
- Python 3
- TensorFlow >= 1.0
- Keras > 2.0
- keras-vis
- shap
- tcav
- scikit-image
- AIX360
- pix2pix
- nolearn
- lasagne
- theano
pip install keras-vis
pip install shap
pip install tcav
pip install nolearn
pip install Lasagne==0.1
git clone https://github.com/Theano/Theano.git
git clone https://github.com/IBM/AIX360
git clone https://github.com/affinelayer/pix2pix-tensorflow.git
git clone https://github.com/amiratag/ACE.git
A light curve is a time series of magnitude, the negative logarithm of the measured flux (so a smaller magnitude implies a brighter object). The measurements available in these light curve datasets are:
- Right Ascension (RA) and Declination (Dec), which give the position of the object on the sky
- Time reference (epoch) as Julian Date
- Magnitude
- An error estimate on the magnitude
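As a minimal sketch of the magnitude definition given before the list above, the snippet below uses the standard Pogson relation m = -2.5 log10(F / F0). The function name and the zero-point flux `zero_point_flux` are illustrative assumptions and are not part of this repository.

```python
import numpy as np

def flux_to_magnitude(flux, zero_point_flux=1.0):
    """Convert flux to magnitude: m = -2.5 * log10(flux / zero_point_flux).

    `zero_point_flux` is a hypothetical reference flux that defines m = 0.
    """
    flux = np.asarray(flux, dtype=float)
    return -2.5 * np.log10(flux / zero_point_flux)

# A source 100 times brighter has a magnitude that is smaller by 5.
faint = flux_to_magnitude(1.0)
bright = flux_to_magnitude(100.0)
```

The negative sign is what makes brighter objects have smaller magnitudes.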
Most of the data collected from astronomical surveys are sparse, far from continuous, mostly irregular and heteroscedastic.
Astronomical objects vary in brightness either because of an intrinsic physical process, such as an explosion or a merger of matter, or because of an extrinsic process, such as an eclipse or rotation. Such objects are called Variables. Variables whose brightness changes by several standard deviations over a very short period of time are called Transients.
A light curve is transformed into a 2D mapping (a dmdt) based on changes in magnitude (dm) and time differences (dt), so that it can be used as input to a Convolutional Neural Network. Note that, to put every bin of a dmdt on an equal footing, all bins are drawn at the same size rather than at a size that depends on the actual dm and dt ranges they cover. To learn more about dmdts, refer to Deep-Learnt Classification of Light Curves.
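The dmdt construction can be sketched roughly as follows. This is not the code in transform.py: the function name, the bin edges and the choice of dm along rows and dt along columns are all assumptions made for illustration. Every pair of observations contributes one (dt, dm) point, and the points are counted in a 2D histogram. Because the result is an array with one cell per bin, each bin occupies one equally sized pixel regardless of how wide its dm or dt range is.

```python
import numpy as np

def light_curve_to_dmdt(times, mags, dt_edges, dm_edges):
    """Count all pairwise (dm, dt) differences of a light curve in a 2D histogram."""
    times = np.asarray(times, dtype=float)
    mags = np.asarray(mags, dtype=float)
    # Sort by time so that dt = t_j - t_i is non-negative for i < j.
    order = np.argsort(times)
    times, mags = times[order], mags[order]
    # Indices of every pair (i, j) with i < j.
    i, j = np.triu_indices(len(times), k=1)
    dt = times[j] - times[i]
    dm = mags[j] - mags[i]
    # Rows index dm bins, columns index dt bins; pairs outside the edges are dropped.
    counts, _, _ = np.histogram2d(dm, dt, bins=[dm_edges, dt_edges])
    return counts

# Hypothetical, unevenly spaced bin edges (days and magnitudes).
dt_edges = np.array([0.0, 1.0, 10.0, 100.0, 1000.0])
dm_edges = np.array([-2.0, -0.5, 0.0, 0.5, 2.0])

# A toy light curve with four observations, giving six pairs.
times = [0.0, 0.5, 20.0, 300.0]
mags = [15.0, 15.2, 14.9, 15.6]
dmdt = light_curve_to_dmdt(times, mags, dt_edges, dm_edges)
```

A light curve with n observations yields n(n-1)/2 pairs, so even sparse and irregular light curves produce a fixed-size image.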
Follow DATA.md for how the light curve .csv files should be placed in the Periodic Variable Classification/data folder.
Once the light curve data is in the appropriate folder configuration, run transform.py to generate the dmdts and their labels in the Periodic Variable Classification/data/all folder.
Grad-CAM visualizes attention over the input using the output of the penultimate Conv layer (the last one before the Dense layers). The intuition is to use the Conv layer nearest the output, so as to exploit spatial information that is completely lost in Dense layers. - Class Activation Maps
Even though Grad-CAM uses class-specific gradient information, the structure in dmdts is not visually apparent, so Grad-CAM visualizations vary depending on the input dmdt.
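The core Grad-CAM computation can be sketched in NumPy, assuming the Conv layer's feature maps and the gradients of the class score with respect to them have already been extracted from the network (in this project that is handled by keras-vis). The function name and the toy arrays are illustrative assumptions.

```python
import numpy as np

def grad_cam(feature_maps, gradients):
    """Combine (H, W, K) feature maps and gradients into an (H, W) heatmap.

    feature_maps: activations of the chosen Conv layer.
    gradients: d(class score) / d(feature_maps), same shape.
    """
    feature_maps = np.asarray(feature_maps, dtype=float)
    gradients = np.asarray(gradients, dtype=float)
    # One weight per channel: average the gradients over the spatial axes.
    weights = gradients.mean(axis=(0, 1))
    # Weighted sum of the channels, then ReLU to keep positive evidence only.
    cam = np.maximum(np.tensordot(feature_maps, weights, axes=([2], [0])), 0.0)
    peak = cam.max()
    return cam / peak if peak > 0 else cam

# Toy example: channel 0 fires top-left and supports the class,
# channel 1 fires bottom-right and opposes it.
toy_maps = np.zeros((2, 2, 2))
toy_maps[0, 0, 0] = 1.0
toy_maps[1, 1, 1] = 1.0
toy_grads = np.stack([np.ones((2, 2)), -np.ones((2, 2))], axis=-1)
heatmap = grad_cam(toy_maps, toy_grads)
```

In practice the heatmap is upsampled to the input size and overlaid on the dmdt.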
Saliency Maps are generated by computing the gradient of the output category with respect to the input image, which tells us how the output category value changes with a small change in the input pixels. A positive gradient value means that a small increase in that pixel will increase the output value. Hence, visualizing these gradients, which have the same shape as the image, should provide some intuition about attention. - Saliency Maps
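To illustrate what the gradient measures, here is a small sketch that approximates it with central finite differences on a black-box `predict` function. This is a stand-in for illustration only: libraries such as keras-vis obtain the same gradient by backpropagation, which is far cheaper. The names `saliency_map` and `toy_predict` and the linear toy model are assumptions.

```python
import numpy as np

def saliency_map(predict, image, class_index, eps=1e-3):
    """Approximate d(output[class_index]) / d(pixel) for every pixel."""
    image = np.asarray(image, dtype=float)
    grad = np.zeros_like(image)
    for idx in np.ndindex(image.shape):
        plus = image.copy()
        minus = image.copy()
        plus[idx] += eps
        minus[idx] -= eps
        # Central difference: how the class output reacts to a small pixel change.
        grad[idx] = (predict(plus)[class_index] - predict(minus)[class_index]) / (2 * eps)
    return grad

# Hypothetical two-class linear model; its true gradient is the weight matrix.
toy_weights = np.array([[[1.0, -2.0], [0.0, 0.5]],
                        [[-1.0, 2.0], [0.0, -0.5]]])

def toy_predict(img):
    return np.tensordot(toy_weights, img, axes=([1, 2], [0, 1]))

toy_image = np.array([[0.1, 0.2], [0.3, 0.4]])
saliency = saliency_map(toy_predict, toy_image, class_index=0)
```

Positive entries mark pixels whose increase raises the class output, matching the interpretation given above.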
The basic approach of this method is to blank out each pixel of the dmdt image in turn and measure the change in the predicted probability that the image belongs to the class it was assigned before blanking.
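A minimal sketch of this blanking experiment, assuming a `predict` callable that returns class probabilities and assuming that blanking means setting the pixel to 0. The toy softmax model is hypothetical and stands in for the trained CNN.

```python
import numpy as np

def blanking_map(predict, image, class_index, blank_value=0.0):
    """Drop in class probability when each pixel is blanked in turn."""
    image = np.asarray(image, dtype=float)
    baseline = predict(image)[class_index]
    change = np.zeros_like(image)
    for idx in np.ndindex(image.shape):
        blanked = image.copy()
        blanked[idx] = blank_value
        # Positive value: the pixel was supporting the class.
        change[idx] = baseline - predict(blanked)[class_index]
    return change

# Hypothetical two-class softmax model used only for illustration.
toy_weights = np.array([[[1.0, -2.0], [0.0, 0.5]],
                        [[-1.0, 2.0], [0.0, -0.5]]])

def toy_predict(img):
    scores = np.tensordot(toy_weights, img, axes=([1, 2], [0, 1]))
    exp = np.exp(scores - scores.max())
    return exp / exp.sum()

toy_image = np.array([[0.1, 0.2], [0.3, 0.4]])
predicted_class = int(np.argmax(toy_predict(toy_image)))
blank_change = blanking_map(toy_predict, toy_image, class_index=0)
```

In the experiment described above, `class_index` would be the class predicted for the unblanked image (`predicted_class` here); class 0 is used in the last line only to keep the toy result easy to read.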
Deep SHAP is a high-speed approximation algorithm for SHAP values in deep learning models that builds on a connection with DeepLIFT. - shap
Note that in the above SHAP visualization of the 44th test dmdt, hotter pixels increase the model's output while cooler pixels decrease it.
The visualizations generated so far, i.e. Grad-CAM, Saliency Maps, the Blanking Experiment and SHAP, are all input-specific; each test sample therefore has its own set of four visualization plots. For interpretability, however, it would be more helpful if these plots were class-specific rather than input-specific.
To generate a class-specific visualization from several input-specific visualizations, we formulated a technique in which the test dmdts and their corresponding visualizations are passed to a pix2pix GAN as training pairs. The motivation is that, after training, the pix2pix generator will have learnt the most relevant features of the visualization plots. The test dmdts with prediction probability greater than 0.95 are then passed to the trained pix2pix GAN, which generates the corresponding visualization plots. Finally, a similarity metric is used to find the visualization plot best learnt by the pix2pix GAN, which gives the class-specific visualization.
To see class-specific interpretations generated by pix2piXAI, click here
The basic approach of this method is to create an image in which just one pixel is lit (i.e. pixel value = 255). This image is passed through the already trained CNN, which outputs prediction probabilities. Each class has its own 2D matrix of the same shape as the image, and the probability predicted for that class is stored at the position of the lit pixel. Repeating this for every pixel of the image gives, for each class, a map from the lit pixel to the predicted probability. This analysis lets us visualize which pixels matter most for classification when they are lit more strongly (i.e. when a larger number of objects lie in that pixel).
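The lighting experiment can be sketched as below, assuming all other pixels of the probe image are 0 and that `predict` returns one probability per class. The function name and the toy softmax model are illustrative assumptions in place of the trained CNN.

```python
import numpy as np

def lighting_maps(predict, image_shape, n_classes, lit_value=255.0):
    """Return an (n_classes, H, W) array of probabilities, one probe per pixel."""
    maps = np.zeros((n_classes,) + tuple(image_shape))
    for idx in np.ndindex(*image_shape):
        # Probe image: a single lit pixel on an otherwise dark background.
        probe = np.zeros(image_shape)
        probe[idx] = lit_value
        # Store each class probability at the position of the lit pixel.
        maps[(slice(None),) + idx] = predict(probe)
    return maps

# Hypothetical two-class softmax model used only for illustration.
base = np.array([[0.01, -0.01], [0.0, 0.02]])
toy_weights = np.stack([base, -base])

def toy_predict(img):
    scores = np.tensordot(toy_weights, img, axes=([1, 2], [0, 1]))
    exp = np.exp(scores - scores.max())
    return exp / exp.sum()

light_maps = lighting_maps(toy_predict, image_shape=(2, 2), n_classes=2)
```

Each slice `light_maps[c]` is then a class-specific map with the same shape as the input image.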
To see class-specific interpretations generated by Lighting Experiments, click here
In a CNN, each Conv layer has several learned template matching filters that maximize their output when a similar template pattern is found in the input image. Activation Maximization generates an input image that maximizes the filter output activations. This allows us to understand what sort of input patterns activate a particular filter. - Activation Maximization
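As a rough illustration of the idea, the sketch below runs gradient ascent on the input to maximize the response of a single linear filter, keeping the image at unit norm so that it cannot grow without bound. This is a toy stand-in, not the keras-vis implementation, which backpropagates through the full network and adds its own regularizers. The function names and the cross-shaped template are assumptions.

```python
import numpy as np

def activation_maximization(activation_grad, shape, steps=100, lr=0.1, seed=0):
    """Gradient ascent on the input image under a unit-norm constraint."""
    rng = np.random.default_rng(seed)
    x = rng.normal(scale=0.01, size=shape)  # start from small random noise
    for _ in range(steps):
        x = x + lr * activation_grad(x)        # step towards higher activation
        x = x / max(np.linalg.norm(x), 1e-12)  # project back to unit norm
    return x

# Hypothetical filter: responds most strongly to a cross-shaped pattern.
template = np.array([[0.0, 1.0, 0.0],
                     [1.0, 1.0, 1.0],
                     [0.0, 1.0, 0.0]])

def filter_activation_grad(x):
    # For a linear filter a(x) = sum(template * x), the gradient is the template.
    return template

image = activation_maximization(filter_activation_grad, template.shape)
cosine = (image * template).sum() / (np.linalg.norm(image) * np.linalg.norm(template))
```

The optimized image converges towards the filter's own template, which is the sense in which Activation Maximization reveals the pattern a filter is looking for.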
To see class-specific interpretations generated by Activation Maximization, click here
braXAI is concerned with interpreting braai, which performs Real-Bogus classification for the Zwicky Transient Facility (ZTF) using deep learning. The only difference between our VGG model and braai's is that our fc_out (output) layer contains two neurons instead of one. These two neurons correspond to the two classes, which allows visualizations for both classes rather than for just one, as is the case with braai's VGG model; two neurons therefore ease interpretation.
See this Jupyter Notebook or
Class-specific interpretations were generated for braXAI using pix2piXAI in the same manner as earlier (for Periodic Variable classification). However, based on the application, the real and bogus classes were further divided into subcategories for more refined interpretation:
- Real Class:
  - real_central: a central roughly round peak
  - real_modulo_noise: a mostly blank image other than the centre (modulo noise)
  - real_mixture: images with both a central peak and modulo noise
- Bogus Class:
  - bogus_blanked: removing artefacts by blanking
  - bogus_non_blanked: a bogus image with no blanked portion in the image
The class-specific interpretations for braXAI generated by pix2piXAI correspond to the Grad-CAM, Saliency Map, Blanking Experiment and SHAP (of DIFF/SUB) visualizations for each of the above subcategories.
To see all class-specific interpretations generated by pix2piXAI for braXAI, click here
Similar to Periodic Variable classification, Activation Maximization maps are generated for real and bogus classes based on fc_out (output) layer. Below are the corresponding two interpretations.
To gain more insight, we also generated visualizations for the misclassifications. In total, 45 of the 1156 test images were misclassified. Each misclassification plot contains the SCI, REF, DIFF, Grad-CAM, saliency, blanking, SHAP of SCI, SHAP of REF and SHAP of DIFF/SUB visualizations. Note that every visualization is with respect to the predicted class (mentioned in the plot's title).
To see all 45 misclassification plots, click here
- Meet Gandhi
- If you encounter any problems, bugs or issues, or have questions or suggestions, please contact me on GitHub or by email at gandhi.meet@btech2015.iitgn.ac.in. I prefer questions and bug reports on GitHub, as that provides visibility to others who might be encountering the same issues or have the same questions.
- Ashish Mahabal - Website
- Raghavendra Kotikalapudi and contributors. keras-vis. https://github.com/raghakot/keras-vis, 2017.
- Phillip Isola, Jun-Yan Zhu, Tinghui Zhou, and Alexei A Efros. Image-to-image translation with conditional adversarial networks. CVPR, 2017.
- Scott M Lundberg and Su-In Lee. A unified approach to interpreting model predictions. In I. Guyon, U. V. Luxburg, S. Bengio, H. Wallach, R. Fergus, S. Vishwanathan, and R. Garnett, editors, Advances in Neural Information Processing Systems 30, pages 4765–4774. Curran Associates, Inc., 2017.
- Ashish Mahabal, et al. Deep-learnt classification of light curves. 2017 IEEE Symposium Series on Computational Intelligence (SSCI). IEEE, 2017.
- Dmitry A Duev, Ashish Mahabal, Frank J Masci, Matthew J Graham, Ben Rusholme, Richard Walters, Ishani Karmarkar, Sara Frederick, Mansi M Kasliwal, Umaa Rebbapragada, et al. Real-bogus classification for the Zwicky Transient Facility using deep learning. Monthly Notices of the Royal Astronomical Society, 489(3):3582–3590, 2019.