60 lines
2.0 KiB
Python
60 lines
2.0 KiB
Python
|
|
|
|
import json
|
|
import xgboost as xgb
|
|
import pickle
|
|
import pandas as pd
|
|
import numpy as np
|
|
def predict(json_input:str):
    """Score one JSON-encoded record with the stored XGBoost model.

    The record becomes a one-row DataFrame. Its ``evacuation_vehicles`` value
    is expanded into one boolean column per name listed in the pickled
    metadata, the categorical columns are encoded with the pickled encoders,
    and the columns named by the booster's ``feature_names`` are scored.

    Args:
        json_input: JSON text surrounded by one extra leading and one extra
            trailing character; both are stripped before parsing (presumably
            enclosing quotes or brackets -- TODO confirm with the caller).
            The decoded record must contain ``age`` and
            ``evacuation_vehicles``.

    Returns:
        The first element of the booster's prediction, multiplied by 100.
    """
    # NOTE(security): pickle.load can execute arbitrary code, so
    # app/metadata.pkl must only ever come from a trusted build artifact.
    with open('app/metadata.pkl', 'rb') as f:
        to_remove, use_small, evacuations, encoders = pickle.load(f)

    data = pd.DataFrame(json.loads(json_input[1:-1]), index=[0])

    # These columns are dropped before scoring; absent ones are ignored.
    data.drop(
        columns=['dateandtime', 'skiarea_id', 'day_of_year', 'minute_of_day', 'year'],
        inplace=True,
        errors='ignore',
    )

    # Expand evacuation_vehicles into one boolean column per known vehicle.
    evacuation = data.evacuation_vehicles.values[0]
    if isinstance(evacuation, str):
        evacuation = [evacuation]
    for known_vehicle in evacuations:
        data[known_vehicle] = False
    for vehicle in evacuation:
        if vehicle in evacuations:
            data[vehicle] = True
        else:
            # Unknown vehicles are folded into 'other'. Creating a column
            # named after the unknown value (as was done previously) would
            # later be looked up in the encoder table, which has no entry
            # for it.
            data['other'] = True

    data.drop(
        columns=['town', 'province', 'evacuation_vehicles'],
        inplace=True,
        errors='ignore',
    )

    data['age'] = data['age'].astype(np.float32)

    # Everything except these columns is handled as a string category.
    keep_dtype = ['india', 'age', 'season', 'skiarea_name', 'destination']
    for c in data.columns:
        if c not in keep_dtype:
            data[c] = data[c].astype('str')

    # With the "small" setup, the values listed in to_remove collapse to
    # 'other' before encoding.
    if use_small:
        for c in to_remove.keys():
            for k in to_remove[c]:
                data.loc[data[c] == k, c] = 'other'

    # Pick the encoder table once instead of branching for every column.
    column_encoders = encoders['small'] if use_small else encoders['normal']
    not_encoded = ['age', 'season', 'skiarea_name', 'india']
    for c in data.columns:
        if c not in not_encoded:
            data[c] = data[c].fillna('None')
            encoder = column_encoders[c]
            # Declare the full code range so the categorical carries every
            # class of the encoder, not just the one present in this row.
            data[c] = pd.Categorical(
                encoder.transform(data[c]),
                categories=list(range(len(encoder.classes_))),
                ordered=False,
            )

    bst_FS = xgb.Booster()
    bst_FS.load_model("app/model.json")

    # Select and order the columns exactly as the model expects them.
    dtest_FS = xgb.DMatrix(data[bst_FS.feature_names], enable_categorical=True)
    preds = bst_FS.predict(dtest_FS)

    return preds[0] * 100