-
Notifications
You must be signed in to change notification settings - Fork 16
/
app.py
100 lines (73 loc) · 5.72 KB
/
app.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
from html import escape

import pandas as pd
from flask import Flask, request, render_template, jsonify
from sklearn.preprocessing import LabelEncoder

from download_model import load_model
# Arguments handed to load_model() by the prediction routes.
# NOTE(review): presumably a cloud storage bucket and the object name of the
# serialized model inside it -- confirm against download_model.load_model.
BUCKET_NAME = 'ml-api-covid-model'
MODEL_FILE_NAME = 'rf_model.joblib'
# Third argument to load_model(); presumably where the fetched model file is
# stored locally -- confirm against download_model.load_model.
MODEL_LOCAL_PATH = f'downloaded_{MODEL_FILE_NAME}'
# Location names accepted by the /predict and /predict_api routes; any other
# submitted value is rejected with HTTP 400 before a prediction is attempted.
available_countries = [
'Afghanistan', 'Albania', 'Algeria', 'Andorra', 'Angola', 'Anguilla', 'Antigua and Barbuda', 'Argentina', 'Armenia', 'Aruba', 'Australia', 'Austria', 'Azerbaijan', 'Bahamas', 'Bahrain', 'Bangladesh', 'Barbados', 'Belarus', 'Belgium', 'Belize', 'Benin', 'Bermuda', 'Bhutan', 'Bolivia', 'Bonaire Sint Eustatius and Saba', 'Bosnia and Herzegovina', 'Botswana', 'Brazil', 'British Virgin Islands', 'Brunei', 'Bulgaria', 'Burkina Faso', 'Burundi', 'Cambodia', 'Cameroon', 'Canada', 'Cape Verde', 'Cayman Islands', 'Central African Republic', 'Chad', 'Chile', 'China', 'Colombia', 'Comoros', 'Congo', 'Costa Rica', "Cote d'Ivoire", 'Croatia', 'Cuba', 'Curacao', 'Cyprus', 'Czech Republic', 'Democratic Republic of Congo', 'Denmark', 'Djibouti', 'Dominica', 'Dominican Republic', 'Ecuador', 'Egypt', 'El Salvador', 'Equatorial Guinea', 'Eritrea', 'Estonia', 'Ethiopia', 'Faeroe Islands', 'Falkland Islands', 'Fiji', 'Finland', 'France', 'French Polynesia', 'Gabon', 'Gambia', 'Georgia', 'Germany', 'Ghana', 'Gibraltar', 'Greece', 'Greenland', 'Grenada', 'Guam', 'Guatemala', 'Guernsey', 'Guinea', 'Guinea-Bissau', 'Guyana', 'Haiti', 'Honduras', 'Hong Kong', 'Hungary', 'Iceland', 'India', 'Indonesia', 'International', 'Iran', 'Iraq', 'Ireland', 'Isle of Man', 'Israel', 'Italy', 'Jamaica', 'Japan', 'Jersey', 'Jordan', 'Kazakhstan', 'Kenya', 'Kosovo', 'Kuwait', 'Kyrgyzstan', 'Laos', 'Latvia', 'Lebanon', 'Lesotho', 'Liberia', 'Libya', 'Liechtenstein', 'Lithuania', 'Luxembourg', 'Macedonia', 'Madagascar', 'Malawi', 'Malaysia', 'Maldives', 'Mali', 'Malta', 'Mauritania', 'Mauritius', 'Mexico', 'Moldova', 'Monaco', 'Mongolia', 'Montenegro', 'Montserrat', 'Morocco', 'Mozambique', 'Myanmar', 'Namibia', 'Nepal', 'Netherlands', 'New Caledonia', 'New Zealand', 'Nicaragua', 'Niger', 'Nigeria', 'Northern Mariana Islands', 'Norway', 'Oman', 'Pakistan', 'Palestine', 'Panama', 'Papua New Guinea', 'Paraguay', 'Peru', 'Philippines', 'Poland', 'Portugal', 'Puerto Rico', 'Qatar', 'Romania', 'Russia', 'Rwanda', 
'Saint Kitts and Nevis', 'Saint Lucia', 'Saint Vincent and the Grenadines', 'San Marino', 'Sao Tome and Principe', 'Saudi Arabia', 'Senegal', 'Serbia', 'Seychelles', 'Sierra Leone', 'Singapore', 'Sint Maarten (Dutch part)', 'Slovakia', 'Slovenia', 'Somalia', 'South Africa', 'South Korea', 'South Sudan', 'Spain', 'Sri Lanka', 'Sudan', 'Suriname', 'Swaziland', 'Sweden', 'Switzerland', 'Syria', 'Taiwan', 'Tajikistan', 'Tanzania', 'Thailand', 'Timor', 'Togo', 'Trinidad and Tobago', 'Tunisia', 'Turkey', 'Turks and Caicos Islands', 'Uganda', 'Ukraine', 'United Arab Emirates', 'United Kingdom', 'United States', 'United States Virgin Islands', 'Uruguay', 'Uzbekistan', 'Vatican', 'Venezuela', 'Vietnam', 'Western Sahara', 'World', 'Yemen', 'Zambia', 'Zimbabwe'
]
# CSV location passed to get_prediction_params(), which reads it with
# pandas.read_csv on every prediction request.
url_to_covid = 'https://covid.ourworldindata.org/data/owid-covid-data.csv'
# Single module-level encoder shared by pre_process() and
# get_prediction_params(); both re-fit it, so its fitted state only reflects
# the most recent fit call.
encoder = LabelEncoder()
# Flask application object on which the routes in this module are registered.
app = Flask(__name__)
def pre_process(df):
    """Clean the raw COVID frame and split it into features and target.

    Processing steps, in order:

    1. Drop the columns listed in ``cols_too_many_missing``.
    2. For every ``object`` column: replace missing values with the string
       ``'MISSING'`` and label-encode the column with the module-level
       ``encoder`` (which is re-fitted for each column).
    3. For every ``float64`` column: replace missing values with that
       column's mean.

    The caller's frame is not modified; all work happens on the frame
    returned by ``DataFrame.drop``.

    Args:
        df: DataFrame that contains every column in
            ``cols_too_many_missing`` as well as ``new_cases``.

    Returns:
        A tuple ``(X, y)`` where ``X`` is the processed frame without the
        ``new_cases`` column and ``y`` is the processed ``new_cases`` column.

    Raises:
        KeyError: if a column to drop, or ``new_cases``, is absent from ``df``.
    """
    cols_too_many_missing = [
        'new_tests',
        'new_tests_per_thousand',
        'total_tests_per_thousand',
        'total_tests',
        'tests_per_case',
        'positive_rate',
        'new_tests_smoothed',
        'new_tests_smoothed_per_thousand',
        'tests_units',
        'handwashing_facilities',
    ]
    df = df.drop(columns=cols_too_many_missing)

    # NOTE: results are assigned back to df[col] rather than using
    # ``df[col].fillna(..., inplace=True)``. The in-place form acts on a
    # temporary Series; pandas warns about it and, with Copy-on-Write
    # enabled, it leaves ``df`` unchanged.
    nominal_cols = df.select_dtypes(include=['object']).columns.tolist()
    for col in nominal_cols:
        df[col] = df[col].fillna('MISSING')
        df[col] = encoder.fit_transform(df[col])

    numerical_cols = df.select_dtypes(include=['float64']).columns.tolist()
    for col in numerical_cols:
        df[col] = df[col].fillna(df[col].mean())

    X = df.drop(columns=['new_cases'])
    y = df['new_cases']
    return X, y
def get_prediction_params(input_val, url_to_covid):
    """Build the single-row feature array used to predict for one location.

    The CSV at ``url_to_covid`` is read, the integer code that label encoding
    assigns to ``input_val`` is looked up, the whole frame is run through
    ``pre_process`` and the last row whose encoded ``location`` equals that
    code is returned.

    Args:
        input_val: Location name exactly as it appears in the ``location``
            column of the CSV.
        url_to_covid: Path or URL accepted by ``pandas.read_csv``.

    Returns:
        Array of shape ``(1, n_features)`` holding the last pre-processed row
        for the requested location.

    Raises:
        ValueError: if ``input_val`` does not occur in the ``location`` column.
    """
    df_orig = pd.read_csv(url_to_covid)

    # The code must be looked up *before* pre_process() runs: pre_process()
    # re-fits the shared module-level encoder on every text column, so
    # afterwards ``encoder.classes_`` no longer describes ``location``.
    encoder.fit(df_orig['location'])
    known_locations = encoder.classes_.tolist()
    try:
        encode_ind = known_locations.index(input_val)
    except ValueError:
        raise ValueError(
            f'Location {input_val!r} does not occur in the data at {url_to_covid}'
        ) from None

    # The frame is deliberately NOT filtered to the requested location before
    # pre-processing; the original code computed such a filter and discarded
    # it, so label encoding and mean imputation run over the full dataset.
    X, _ = pre_process(df_orig)
    to_pred = X[X.location == encode_ind].iloc[-1].values.reshape(1, -1)
    return to_pred
@app.route('/')
def home():
    """Serve the landing page by rendering the ``home.html`` template."""
    landing_page = render_template("home.html")
    return landing_page
@app.route('/predict',methods=['POST'])
def predict():
    """Handle the HTML form submission and render the prediction.

    The first value of the submitted form is taken as the location name.
    If it is not in ``available_countries`` (or the form is empty) a plain
    message is returned with HTTP 400. Otherwise the model is loaded, the
    feature row for the location is built and ``home.html`` is rendered with
    the predicted value.

    Returns:
        The rendered ``home.html`` page on success, or a
        ``(message, 400)`` tuple when the location is not accepted.
    """
    # First submitted form value; None when the form carries no fields, so an
    # empty submission becomes a 400 response instead of an IndexError (500).
    input_val = next(iter(request.form.values()), None)

    # Validate before loading the model so that rejected requests do not pay
    # for the model load.
    if input_val not in available_countries:
        # The value is untrusted and is echoed into a response Flask serves as
        # HTML, so it is escaped to prevent markup/script injection.
        return f'Country {escape(str(input_val))} is not in available list. Try one from the list! Go back in your browser', 400

    rf = load_model(BUCKET_NAME, MODEL_FILE_NAME, MODEL_LOCAL_PATH)
    to_pred = get_prediction_params(input_val, url_to_covid)
    # predict() returns one value per input row; a single row was passed in.
    prediction = rf.predict(to_pred)[0]
    return render_template('home.html',pred=f'New cases will be {prediction}')
@app.route('/predict_api',methods=['POST'])
def predict_api():
    """Return the prediction for a location as JSON.

    The location name is read from the ``location`` form field. If it is
    missing or not in ``available_countries`` a plain message is returned
    with HTTP 400. Otherwise the model is loaded, the feature row for the
    location is built and the predicted value is returned as
    ``{"result": <prediction>}``.

    Returns:
        A JSON response on success, or a ``(message, 400)`` tuple when the
        location is not accepted.
    """
    # None when the ``location`` field is absent; handled by the check below.
    input_val = request.form.get('location')

    # Validate before loading the model so that rejected requests do not pay
    # for the model load.
    if input_val not in available_countries:
        # The value is untrusted and is echoed into a response Flask serves as
        # HTML, so it is escaped to prevent markup/script injection.
        return f'Country {escape(str(input_val))} is not in available list. Try one from the list! Go back in your browser', 400

    rf = load_model(BUCKET_NAME, MODEL_FILE_NAME, MODEL_LOCAL_PATH)
    to_pred = get_prediction_params(input_val, url_to_covid)
    # predict() returns one value per input row; a single row was passed in.
    prediction = rf.predict(to_pred)[0]
    result = {
        "result": prediction
    }
    return jsonify(result)
# Start Flask's built-in server only when this file is executed directly,
# not when it is imported as a module.
if __name__ == '__main__':
    app.run(debug=False)
    # Alternative invocation kept for reference (explicit host and port):
    # app.run(host='0.0.0.0', port=5000)