diff --git a/README.md b/README.md
index 46d210b..417cc19 100644
--- a/README.md
+++ b/README.md
@@ -32,7 +32,12 @@ conda activate fedbn
 cd ./snapshots
 unzip digit_model.zip
 ```
-For the original data, please download [here](https://drive.google.com/file/d/1P8g7uHyVxQJPcBKE8TAzfdKbimpRbj0I/view?usp=sharing), and find the processing steps in `utils/data_preprocess.py`.
+For the original data, please download [here](https://drive.google.com/file/d/1P8g7uHyVxQJPcBKE8TAzfdKbimpRbj0I/view?usp=sharing)
+or you can directly run the following to download and process data.
+```bash
+cd ./utils
+python data_preprocess.py
+```
 
 **office-caltech10**
 - Please download our pre-processed datasets [here](https://drive.google.com/file/d/1gxhV5xRXQgC9AL4XexduH7hdxDng7bJ3/view?usp=sharing), put under `data/` directory and perform following commands:
diff --git a/utils/data_preprocess.py b/utils/data_preprocess.py
index 83d9555..61e89a5 100755
--- a/utils/data_preprocess.py
+++ b/utils/data_preprocess.py
@@ -12,10 +12,42 @@ import numpy as np
 from sklearn.model_selection import StratifiedShuffleSplit, StratifiedKFold
 from collections import Counter
-from matplotlib import pylab as plt
+import requests
+from tqdm import tqdm
+import zipfile
+
+def download_file_from_google_drive(id, destination):
+    URL = "https://docs.google.com/uc?export=download"
+    session = requests.Session()
+    response = session.get(URL, params = { 'id' : id }, stream = True)
+    token = get_confirm_token(response)
+
+    if token:
+        params = { 'id' : id, 'confirm' : token }
+        response = session.get(URL, params = params, stream = True)
+    total_length = response.headers.get('content-length')
+    print('Downloading...')
+    save_response_content(response, destination, total_length)
+
+def get_confirm_token(response):
+    for key, value in response.cookies.items():
+        if key.startswith('download_warning'):
+            return value
+
+    return None
+
+def save_response_content(response, destination, total_length):
+    CHUNK_SIZE = 32768
+
+    with open(destination, "wb") as f:
+        total_length = int(total_length)
+        for chunk in tqdm(response.iter_content(CHUNK_SIZE),total=int(total_length/CHUNK_SIZE)):
+            if chunk: # filter out keep-alive new chunks
+                f.write(chunk)
+
+
 def stratified_split(X,y):
     sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=0)
@@ -249,6 +281,14 @@ def split(data_path, percentage=0.1):
 
 
 if __name__ == '__main__':
+    file_id = '1P8g7uHyVxQJPcBKE8TAzfdKbimpRbj0I'
+    destination = '../data/data.zip'
+    download_file_from_google_drive(file_id, destination)
+    print('Extracting...')
+    with zipfile.ZipFile(destination, 'r') as zip_ref:
+        for file in tqdm(iterable=zip_ref.namelist(), total=len(zip_ref.namelist())):
+            zip_ref.extract(member=file, path=os.path.dirname(destination))
+    print('Processing...')
     print('--------MNIST---------')
     process_mnist()
     print('--------SVHN---------')