
Commit

update
MeiruiJiang committed Dec 10, 2021
1 parent 0ef5ec3 commit 1caf83c
Showing 2 changed files with 47 additions and 2 deletions.
7 changes: 6 additions & 1 deletion README.md
@@ -32,7 +32,12 @@ conda activate fedbn
cd ./snapshots
unzip digit_model.zip
```
For the original data, please download [here](https://drive.google.com/file/d/1P8g7uHyVxQJPcBKE8TAzfdKbimpRbj0I/view?usp=sharing), and find the processing steps in `utils/data_preprocess.py`.
For the original data, please download it [here](https://drive.google.com/file/d/1P8g7uHyVxQJPcBKE8TAzfdKbimpRbj0I/view?usp=sharing), or run the following commands to download and process the data automatically:
```bash
cd ./utils
python data_preprocess.py
```

**office-caltech10**
- Please download our pre-processed datasets [here](https://drive.google.com/file/d/1gxhV5xRXQgC9AL4XexduH7hdxDng7bJ3/view?usp=sharing), put them under the `data/` directory, and run the following commands (sketched below):
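  As a rough sketch only — the archive name `office_caltech_10_dataset.zip` is an assumption for illustration, not the actual filename from the download link — the setup mirrors the digit-dataset step above:

  ```bash
  # Hypothetical sketch: the archive name below is an assumption, not the actual filename.
  cd ./data
  unzip office_caltech_10_dataset.zip
  ```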
42 changes: 41 additions & 1 deletion utils/data_preprocess.py
@@ -12,10 +12,42 @@
import numpy as np
from sklearn.model_selection import StratifiedShuffleSplit, StratifiedKFold
from collections import Counter
from matplotlib import pylab as plt
import requests
from tqdm import tqdm
import zipfile

def download_file_from_google_drive(id, destination):
    """Download a (possibly large) file from Google Drive to `destination`."""
    URL = "https://docs.google.com/uc?export=download"

    session = requests.Session()

    response = session.get(URL, params={'id': id}, stream=True)
    token = get_confirm_token(response)

    if token:
        # Large files trigger a virus-scan warning page; confirm the token to get the real file.
        params = {'id': id, 'confirm': token}
        response = session.get(URL, params=params, stream=True)
    total_length = response.headers.get('content-length')
    print('Downloading...')
    save_response_content(response, destination, total_length)

def get_confirm_token(response):
    # Google Drive sets a 'download_warning' cookie when it requires confirmation.
    for key, value in response.cookies.items():
        if key.startswith('download_warning'):
            return value

    return None

def save_response_content(response, destination, total_length):
    CHUNK_SIZE = 32768

    with open(destination, "wb") as f:
        # The content-length header may be absent; fall back to an unsized progress bar.
        total = int(int(total_length) / CHUNK_SIZE) if total_length else None
        for chunk in tqdm(response.iter_content(CHUNK_SIZE), total=total):
            if chunk:  # filter out keep-alive new chunks
                f.write(chunk)



def stratified_split(X,y):
    sss = StratifiedShuffleSplit(n_splits=1, test_size=0.2, random_state=0)
@@ -249,6 +281,14 @@ def split(data_path, percentage=0.1):


if __name__ == '__main__':
    file_id = '1P8g7uHyVxQJPcBKE8TAzfdKbimpRbj0I'
    destination = '../data/data.zip'
    download_file_from_google_drive(file_id, destination)
    print('Extracting...')
    with zipfile.ZipFile(destination, 'r') as zip_ref:
        for file in tqdm(iterable=zip_ref.namelist(), total=len(zip_ref.namelist())):
            zip_ref.extract(member=file, path=os.path.dirname(destination))
    print('Processing...')
    print('--------MNIST---------')
    process_mnist()
    print('--------SVHN---------')
