99 lines
4.6 KiB
Python
99 lines
4.6 KiB
Python
from utils import retrive_data, split
|
|
from model import train, gain_accuracy_train
|
|
from sklearn.metrics import confusion_matrix,matthews_corrcoef,accuracy_score
|
|
import xgboost as xgb
|
|
import pandas as pd
|
|
import pickle
|
|
import argparse
|
|
|
|
def main(args):
    """Run the data -> tuning -> evaluation pipeline selected by the CLI flags.

    The function performs these stages in order:

    1. Calls ``retrive_data`` and pickles the third returned object to
       ``to_remove.pkl``.
    2. Calls ``split`` on either the "small" or the full labeled data
       (chosen by ``args.use_small``) to obtain ``dataset`` and
       ``dataset_test``.
    3. Either runs the hyper-parameter search via ``train`` and stores
       ``[params_final, feat_imp, best_model]`` in ``best_params.pkl``
       (``args.retrain``), or reloads that triple from the same file.
    4. Prints confusion matrix, MCC and accuracy of ``best_model`` on the
       train and validation splits.
    5. Either calls ``gain_accuracy_train`` and stores its three results in
       ``best_params_and_final_model.pkl`` (``args.retrain_last_model``), or
       reloads them from that file.
    6. Prints MCC, accuracy and confusion matrix of the second model on the
       held-out ski-area and ski-season test sets.

    Args:
        args: parsed command-line namespace; the attributes read here are
            ``reload_data``, ``use_small``, ``use_smote``, ``retrain``,
            ``n_trials`` and ``retrain_last_model``.

    Returns:
        None. Results are printed and artifacts are written as pickle files
        in the current working directory.
    """
    # ---- 1. data retrieval -------------------------------------------------
    labeled, labeled_small, to_remove = retrive_data(
        reload_data=args.reload_data,
        threshold_under_represented=0.5,
        path='/home/agobbi/Projects/PID/datanalytics/PID/src',
    )
    with open('to_remove.pkl', 'wb') as handle:
        pickle.dump(to_remove, handle)

    # ---- 2. train/validation/test split ------------------------------------
    chosen_data = labeled_small if args.use_small else labeled
    dataset, dataset_test = split(
        chosen_data,
        SKI_AREA_TEST='Klausberg',
        SEASON_TEST_SKIAREA='Kronplatz',
        SEASON_TEST_YEAR=2023,
        use_smote=args.use_smote,
        weight_type='sqrt',
    )

    # ---- 3. hyper-parameter search, or reuse of a previous one --------------
    if not args.retrain:
        # NOTE(review): assumes best_params.pkl was produced by an earlier
        # --retrain run of this script; pickle must not be fed untrusted files.
        with open('best_params.pkl', 'rb') as handle:
            params_final, feat_imp, best_model = pickle.load(handle)
    else:
        print('OPTUNA hyperparameter tuning, please wait!')
        best_model, params_final = train(
            dataset,
            n_trials=args.n_trials,
            timeout=600,
            num_boost_round=600,
        )
        fscore = pd.Series(best_model.get_fscore())
        feat_imp = fscore.sort_values(ascending=False)
        with open('best_params.pkl', 'wb') as handle:
            pickle.dump([params_final, feat_imp, best_model], handle)

    # ---- 4. train / validation report --------------------------------------
    # Predictions require DMatrix inputs restricted to the model's features.
    tuned_features = best_model.feature_names
    tmp_train = xgb.DMatrix(
        dataset.X_train[tuned_features], dataset.y_train, enable_categorical=True
    )
    tmp_valid = xgb.DMatrix(
        dataset.X_valid[tuned_features], dataset.y_valid, enable_categorical=True
    )
    scores_valid = best_model.predict(tmp_valid)
    scores_train = best_model.predict(tmp_train)

    report_plan = (
        (
            '##################RESULT ON THE TRAIN SET#####################',
            dataset.y_train,
            scores_train,
        ),
        (
            '##################RESULT ON THE VALIDATION SET#####################',
            dataset.y_valid,
            scores_valid,
        ),
    )
    for banner, truth, scores in report_plan:
        print(banner)
        # argmax over axis 1 turns the per-row score vector into a class index.
        predicted = scores.argmax(1)
        print(confusion_matrix(truth, predicted))
        print(f'MCC:{matthews_corrcoef(truth, predicted)}')
        print(f'ACC:{accuracy_score(truth, predicted)}')

    # ---- 5. second model: retrain, or reuse a stored one -------------------
    if not args.retrain_last_model:
        # NOTE(review): same trust assumption as for best_params.pkl above.
        with open('best_params_and_final_model.pkl', 'rb') as handle:
            tot, bst_FS, FS = pickle.load(handle)
    else:
        tot, bst_FS, FS = gain_accuracy_train(
            dataset,
            feat_imp,
            num_boost_round=600,
            params=params_final,
        )
        with open('best_params_and_final_model.pkl', 'wb') as handle:
            pickle.dump([tot, bst_FS, FS], handle)

    # ---- 6. held-out test report -------------------------------------------
    final_features = bst_FS.feature_names
    dtest_FS = xgb.DMatrix(
        dataset_test.X_test_area[final_features],
        dataset_test.y_test_area,
        enable_categorical=True,
    )
    dtest_season_FS = xgb.DMatrix(
        dataset_test.X_test_season[final_features],
        dataset_test.y_test_season,
        enable_categorical=True,
    )
    scores_area = bst_FS.predict(dtest_FS)
    scores_season = bst_FS.predict(dtest_season_FS)

    # The names mcc / acc / cm are part of the printed text through the
    # f-string "=" specifier, so they must not be renamed.
    area_truth = dataset_test.y_test_area
    area_predicted = scores_area.argmax(1)
    mcc = matthews_corrcoef(area_truth, area_predicted)
    acc = accuracy_score(area_truth, area_predicted)
    cm = confusion_matrix(area_truth, area_predicted)
    print(f'RESULT ON THE TEST SKI AREA {mcc=}, {acc=}, \n{cm=}')

    season_truth = dataset_test.y_test_season
    season_predicted = scores_season.argmax(1)
    mcc = matthews_corrcoef(season_truth, season_predicted)
    acc = accuracy_score(season_truth, season_predicted)
    cm = confusion_matrix(season_truth, season_predicted)
    print(f'RESULT ON THE TEST SKI SEASON {mcc=}, {acc=}, {cm=}')
|
|
|
|
if __name__ == "__main__":
    # Command-line entry point: declare the options, parse them, run main().
    parser = argparse.ArgumentParser(description='Train Optuna XGBOOST model')

    # Boolean switches (all default to False), registered in display order.
    boolean_switches = (
        ('--use_small', "Aggregate under represented input classes (es: rare country)"),
        ('--use_smote', 'oversampling underrperesented target labels'),
        ('--retrain', 'Retrain the optuna searcher'),
        ('--reload_data', 'Dowload data from db'),
        ('--retrain_last_model', 'retrain the last model'),
    )
    for option, description in boolean_switches:
        parser.add_argument(option, action='store_true', help=description)

    # Only numeric option: passed to the search as n_trials inside main().
    parser.add_argument(
        '--n_trials',
        type=int,
        default=1000,
        help='number of trials per optuna',
    )

    args = parser.parse_args()
    main(args)
|
|
|
|
# Example invocation: python main.py --use_small --retrain --retrain_last_model --n_trials=10 --reload_data