added service API using docker and fastAPI
README.md
@@ -8,7 +8,7 @@ The file `main.py` contains the complete workflow; the paragraphs below describe

It downloads the data from the db, flattens the nested fields (the type of evacuation vehicle used) and, for each column, computes the percentage of each categorical value. Since some samples are fairly unique, these values can be aggregated to avoid overfitting: by default, every class with less than 0.5% representation is merged into the class `other`. It returns two datasets: one with the merged classes (called small) and one without (to make sure they carry the same information).

## split

This splits the dataset (small or normal) into train, validation and test. By specifying `SKI_AREA_TEST` it is possible to remove a ski area completely from the dataset, to simulate the safe index on a new area. The pair `SEASON_TEST_SKIAREA`-`SEASON_TEST_YEAR` instead removes some seasons of a specific ski area from the dataset: this simulates how the model behaves on new seasons of an area whose past data it has already seen. Once the test data has been removed, the remaining dataset is split into train and validation (66%-33%), stratified on india (so that the ratio between the classes stays roughly constant). There are two ways to help the model with the imbalanced dataset (very few india 3 and 4): the first is to oversample the small classes (which I don't like); the alternative is to weight the error made on the small classes differently. I implemented two weightings; one uses the square root of the total number of cases divided by the number of elements in the class, a bit less strict than dividing the total by the number of elements. It returns two Dataset objects (one normal and one for test), helper classes that make the training phase smoother.

This splits the dataset (small or normal) into train, validation and test. By specifying `skiarea_test` it is possible to remove a ski area completely from the dataset, to simulate the safe index on a new area. The pair `season_test_skiarea`-`season_test_year` instead removes some seasons of a specific ski area from the dataset: this simulates how the model behaves on new seasons of an area whose past data it has already seen. Once the test data has been removed, the remaining dataset is split into train and validation (66%-33%), stratified on india (so that the ratio between the classes stays roughly constant). There are two ways to help the model with the imbalanced dataset (very few india 3 and 4): the first is to oversample the small classes (which I don't like); the alternative is to weight the error made on the small classes differently. I implemented two weightings; one uses the square root of the total number of cases divided by the number of elements in the class, a bit less strict than dividing the total by the number of elements. It returns two Dataset objects (one normal and one for test), helper classes that make the training phase smoother.
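As a rough illustration of the 'sqrt' weighting described above (a minimal sketch only; the real implementation lives in `src/utils.py`, and `y_train` here is just a stand-in for the encoded india labels of the training rows):

```python
# Sketch of the 'sqrt' class weighting: weight_c = sqrt(total samples) / samples in class c.
import numpy as np
import pandas as pd

def sqrt_class_weights(y_train: pd.Series) -> pd.Series:
    counts = y_train.value_counts()            # samples per india class
    class_w = np.sqrt(counts.sum()) / counts   # rare classes get larger weights
    return y_train.map(class_w)                # one weight per training row

# Example: very few india 3 and 4 -> those rows get the largest weights
weights = sqrt_class_weights(pd.Series([0, 0, 0, 1, 1, 2, 3]))
```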
## train

This is the core of the program: I defined a grid of hyperparameters with the usual ranges. An XGBoost model is trained to maximize MCC (not accuracy, which is not appropriate with imbalanced classes). You set the number of trials (I suggest at least 1000) and a timeout (useful with limited resources). `num_boost` is the maximum number of boosting steps; an overfitting-detection mechanism can stop training earlier.
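A compressed sketch of this tuning step (the actual implementation, with the full hyperparameter grid, is in `src/model.py` further down in this commit; `dtrain`, `dvalid` and `y_valid` are assumed to come from the split step):

```python
# Sketch: Optuna proposes hyperparameters, XGBoost trains with early stopping,
# and the trial score is the MCC on the validation set.
import optuna
import xgboost as xgb
from sklearn.metrics import matthews_corrcoef

def make_objective(dtrain, dvalid, y_valid, num_boost_round=2500):
    def objective(trial):
        params = {
            "learning_rate": trial.suggest_float("learning_rate", 0.01, 0.2),
            "max_depth": trial.suggest_int("max_depth", 5, 15),
            "objective": "multi:softprob",
            "num_class": 5,
        }
        bst = xgb.train(params, dtrain, num_boost_round=num_boost_round,
                        evals=[(dtrain, "train"), (dvalid, "valid")],
                        early_stopping_rounds=100, verbose_eval=False)
        preds = bst.predict(dvalid)
        return matthews_corrcoef(y_valid, preds.argmax(1))
    return objective

# study = optuna.create_study(direction="maximize")
# study.optimize(make_objective(dtrain, dvalid, y_valid), n_trials=1000, timeout=600)
```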
@@ -26,9 +26,9 @@ The features are sorted by importance using the score (`best_model.ge

## Final model

Once all the tests are done, we can also put the test set back into the training data:

```
SKI_AREA_TEST = None
SEASON_TEST_SKIAREA = None
SEASON_TEST_YEAR = None
skiarea_test: None
season_test_skiarea: None
season_test_year: None
```

and also increase the number of points in the training set:

```
@@ -40,3 +40,34 @@ There are a few notebooks; TRAIN contains more or less what `main.py` does, o m

## Training the model

```
cd src
python main.py
python main.py model.n_trials=200 ## if I want to change some parameters
```

This automatically reads the parameters from the `config.yaml` file. The script creates a folder named after `model.name` and puts there everything that the main generates (the config that was used, all the logs, the saved models and some metafeatures used later by the service).
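Any other value of the config can be overridden from the command line in the same way, for example (the values below are purely illustrative):

```
python main.py model.name=run_sum processing.weight_type=sum
python main.py processing.skiarea_test=null processing.season_test_skiarea=null
```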
## Model service

The `service` folder implements an endpoint to query the model, based on Docker and FastAPI.

Once a model has been chosen based on its performance, the files `metadata.pkl` and `model.json` must be moved from the `src/<model.name>` folder to the `service/app` folder.
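For example, if the model was trained with `model.name: test` (adjust the folder name to the run you want to serve):

```
cp src/test/metadata.pkl service/app/
cp src/test/model.json service/app/
```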
Enter the `service` folder and run the following commands:

```
docker build -t myimage .
docker run -d --name mycontainer -p 80:80 myimage
```

This copies the whole content of `app` into the image and builds the container (no volumes are mounted).

To test:
```
0.0.0.0/predict/'{"dateandtime":1231754520000,"skiarea_id":null,"skiarea_name":"Pampeago","day_of_year":12,"minute_of_day":602,"year":2009,"season":2009,"difficulty":"novice","cause":"fall_alone","town":"SIKLOS","province":"","gender":"F","equipment":"ski","helmet":null,"destination":"hospital_emergency_room","diagnosis":"distortion","india":null,"age":32.0,"country":"Ungheria","injury_side":"L","injury_general_location":"lower_limbs","evacuation_vehicles":["akja"]}'

0.0.0.0/predict/'{"dateandtime":1512294600000,"skiarea_id":13.0,"skiarea_name":"Kronplatz","day_of_year":337,"minute_of_day":590,"year":2017,"season":2018,"difficulty":"intermediate","cause":"fall_alone","town":"Pieve di Soligo","province":"Treviso","gender":"M","equipment":"ski","helmet":true,"destination":"hospital_emergency_room","diagnosis":"other","india":"i1","age":43.0,"country":"Italia","injury_side":"L","injury_general_location":"lower_limbs","evacuation_vehicles":["akja"]}'

0.0.0.0/predict/'{"dateandtime":1512726900000,"skiarea_id":13.0,"skiarea_name":"Kronplatz","day_of_year":342,"minute_of_day":595,"year":2017,"season":2018,"difficulty":"easy","cause":"fall_alone","town":"Vigo Novo","province":"Venezia","gender":"M","equipment":"ski","helmet":true,"destination":"hospital_emergency_room","diagnosis":"distortion","india":"i2","age":23.0,"country":"Italia","injury_side":"L","injury_general_location":"lower_limbs","evacuation_vehicles":["snowmobile"]}'
```

I am not sure how the data will be passed to the endpoint in production; if needed, the part of the service that parses the data passed in the URL has to be adapted. It complains when odd strings show up (this happens with Bolzano, where there are some backslashes).
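A possible client-side workaround, shown only as a sketch of one way to call the endpoint from Python (it reuses the first example record above and assumes the container is listening on port 80), is to URL-encode the quoted JSON before putting it in the path:

```python
# Sketch: query the running service from Python, URL-encoding the quoted JSON
# record so that special characters survive the trip in the URL path.
import urllib.parse
import requests

record = '{"dateandtime":1231754520000,"skiarea_id":null,"skiarea_name":"Pampeago","day_of_year":12,"minute_of_day":602,"year":2009,"season":2009,"difficulty":"novice","cause":"fall_alone","town":"SIKLOS","province":"","gender":"F","equipment":"ski","helmet":null,"destination":"hospital_emergency_room","diagnosis":"distortion","india":null,"age":32.0,"country":"Ungheria","injury_side":"L","injury_general_location":"lower_limbs","evacuation_vehicles":["akja"]}'

url = "http://0.0.0.0/predict/" + urllib.parse.quote("'" + record + "'", safe="")
print(requests.get(url).json())
```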
environment.yml (deleted)
@@ -1,154 +0,0 @@
|
||||
name: pid
|
||||
channels:
|
||||
- defaults
|
||||
dependencies:
|
||||
- _libgcc_mutex=0.1=main
|
||||
- _openmp_mutex=5.1=1_gnu
|
||||
- bzip2=1.0.8=h7b6447c_0
|
||||
- ca-certificates=2023.12.12=h06a4308_0
|
||||
- expat=2.5.0=h6a678d5_0
|
||||
- ld_impl_linux-64=2.38=h1181459_1
|
||||
- libffi=3.4.4=h6a678d5_0
|
||||
- libgcc-ng=11.2.0=h1234567_1
|
||||
- libgomp=11.2.0=h1234567_1
|
||||
- libstdcxx-ng=11.2.0=h1234567_1
|
||||
- libuuid=1.41.5=h5eee18b_0
|
||||
- ncurses=6.4=h6a678d5_0
|
||||
- openssl=1.1.1w=h7f8727e_0
|
||||
- pip=23.3.1=py311h06a4308_0
|
||||
- python=3.11.0=h7a1cb2a_3
|
||||
- readline=8.2=h5eee18b_0
|
||||
- setuptools=68.2.2=py311h06a4308_0
|
||||
- sqlite=3.41.2=h5eee18b_0
|
||||
- tk=8.6.12=h1ccaba5_0
|
||||
- wheel=0.41.2=py311h06a4308_0
|
||||
- xz=5.4.5=h5eee18b_0
|
||||
- zlib=1.2.13=h5eee18b_0
|
||||
- pip:
|
||||
- alembic==1.13.1
|
||||
- anyio==4.2.0
|
||||
- argon2-cffi==23.1.0
|
||||
- argon2-cffi-bindings==21.2.0
|
||||
- arrow==1.3.0
|
||||
- asttokens==2.4.1
|
||||
- async-lru==2.0.4
|
||||
- attrs==23.2.0
|
||||
- babel==2.14.0
|
||||
- beautifulsoup4==4.12.3
|
||||
- bleach==6.1.0
|
||||
- catboost==1.2.2
|
||||
- certifi==2024.2.2
|
||||
- cffi==1.16.0
|
||||
- charset-normalizer==3.3.2
|
||||
- cloudpickle==3.0.0
|
||||
- colorlog==6.8.2
|
||||
- comm==0.2.1
|
||||
- contourpy==1.2.0
|
||||
- cycler==0.12.1
|
||||
- debugpy==1.8.0
|
||||
- decorator==5.1.1
|
||||
- defusedxml==0.7.1
|
||||
- executing==2.0.1
|
||||
- fastjsonschema==2.19.1
|
||||
- fonttools==4.47.2
|
||||
- fqdn==1.5.1
|
||||
- greenlet==3.0.3
|
||||
- idna==3.6
|
||||
- imageio==2.33.1
|
||||
- imbalanced-learn==0.12.0
|
||||
- imblearn==0.0
|
||||
- ipykernel==6.29.0
|
||||
- ipython==8.21.0
|
||||
- isoduration==20.11.0
|
||||
- jedi==0.19.1
|
||||
- jinja2==3.1.3
|
||||
- joblib==1.3.2
|
||||
- json5==0.9.14
|
||||
- jsonpointer==2.4
|
||||
- jsonschema==4.21.1
|
||||
- jsonschema-specifications==2023.12.1
|
||||
- jupyter-client==8.6.0
|
||||
- jupyter-core==5.7.1
|
||||
- jupyter-events==0.9.0
|
||||
- jupyter-lsp==2.2.2
|
||||
- jupyter-server==2.12.5
|
||||
- jupyter-server-terminals==0.5.2
|
||||
- jupyterlab==4.0.12
|
||||
- jupyterlab-pygments==0.3.0
|
||||
- jupyterlab-server==2.25.2
|
||||
- kiwisolver==1.4.5
|
||||
- lazy-loader==0.3
|
||||
- lime==0.2.0.1
|
||||
- llvmlite==0.42.0
|
||||
- mako==1.3.2
|
||||
- markupsafe==2.1.4
|
||||
- matplotlib==3.8.2
|
||||
- matplotlib-inline==0.1.6
|
||||
- mistune==3.0.2
|
||||
- nbclient==0.9.0
|
||||
- nbconvert==7.14.2
|
||||
- nbformat==5.9.2
|
||||
- nest-asyncio==1.6.0
|
||||
- networkx==3.2.1
|
||||
- notebook-shim==0.2.3
|
||||
- numba==0.59.0
|
||||
- numpy==1.26.3
|
||||
- optuna==3.5.0
|
||||
- overrides==7.7.0
|
||||
- packaging==23.2
|
||||
- pandas==2.2.0
|
||||
- pandocfilters==1.5.1
|
||||
- parso==0.8.3
|
||||
- pexpect==4.9.0
|
||||
- pillow==10.2.0
|
||||
- platformdirs==4.2.0
|
||||
- plotly==5.18.0
|
||||
- prometheus-client==0.19.0
|
||||
- prompt-toolkit==3.0.43
|
||||
- psutil==5.9.8
|
||||
- psycopg2-binary==2.9.9
|
||||
- ptyprocess==0.7.0
|
||||
- pure-eval==0.2.2
|
||||
- pycparser==2.21
|
||||
- pygments==2.17.2
|
||||
- pyparsing==3.1.1
|
||||
- python-dateutil==2.8.2
|
||||
- python-graphviz==0.20.1
|
||||
- python-json-logger==2.0.7
|
||||
- pytz==2024.1
|
||||
- pyyaml==6.0.1
|
||||
- pyzmq==25.1.2
|
||||
- referencing==0.33.0
|
||||
- requests==2.31.0
|
||||
- rfc3339-validator==0.1.4
|
||||
- rfc3986-validator==0.1.1
|
||||
- rpds-py==0.17.1
|
||||
- scikit-image==0.22.0
|
||||
- scikit-learn==1.4.0
|
||||
- scipy==1.12.0
|
||||
- send2trash==1.8.2
|
||||
- shap==0.44.1
|
||||
- six==1.16.0
|
||||
- slicer==0.0.7
|
||||
- sniffio==1.3.0
|
||||
- soupsieve==2.5
|
||||
- sqlalchemy==2.0.25
|
||||
- stack-data==0.6.3
|
||||
- tenacity==8.2.3
|
||||
- terminado==0.18.0
|
||||
- threadpoolctl==3.2.0
|
||||
- tifffile==2024.1.30
|
||||
- tinycss2==1.2.1
|
||||
- tornado==6.4
|
||||
- tqdm==4.66.1
|
||||
- traitlets==5.14.1
|
||||
- types-python-dateutil==2.8.19.20240106
|
||||
- typing-extensions==4.9.0
|
||||
- tzdata==2023.4
|
||||
- uri-template==1.3.0
|
||||
- urllib3==2.2.0
|
||||
- wcwidth==0.2.13
|
||||
- webcolors==1.13
|
||||
- webencodings==0.5.1
|
||||
- websocket-client==1.7.0
|
||||
- xgboost==2.0.3
|
||||
@@ -2,26 +2,21 @@
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"execution_count": 1,
|
||||
"id": "7c5d059b-ed8a-4e2e-9420-25890f648895",
|
||||
"metadata": {
|
||||
"scrolled": true
|
||||
},
|
||||
"outputs": [
|
||||
{
|
||||
"name": "stderr",
|
||||
"output_type": "stream",
|
||||
"text": [
|
||||
"/tmp/ipykernel_20879/1378035245.py:1: DeprecationWarning: \n",
|
||||
"Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),\n",
|
||||
"(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)\n",
|
||||
"but was not found to be installed on your system.\n",
|
||||
"If this would cause problems for you,\n",
|
||||
"please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466\n",
|
||||
" \n",
|
||||
" import pandas as pd\n",
|
||||
"/home/agobbi/miniconda3/envs/pid/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
||||
" from .autonotebook import tqdm as notebook_tqdm\n"
|
||||
"ename": "ModuleNotFoundError",
|
||||
"evalue": "No module named 'psycopg2'",
|
||||
"output_type": "error",
|
||||
"traceback": [
|
||||
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
||||
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
|
||||
"Cell \u001b[0;32mIn[1], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mpandas\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mpd\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mpsycopg2\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mpg\u001b[39;00m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpyplot\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mplt\u001b[39;00m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mnumpy\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mnp\u001b[39;00m\n",
|
||||
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'psycopg2'"
|
||||
]
|
||||
}
|
||||
],
|
||||
@@ -9894,7 +9889,7 @@
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.11.0"
|
||||
"version": "3.12.2"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
|
||||
notebooks/explore_results.ipynb (new file; diff suppressed because one or more lines are too long)
requirements.txt (new file)
@@ -0,0 +1,134 @@
|
||||
alembic==1.13.1
|
||||
annotated-types==0.6.0
|
||||
antlr4-python3-runtime==4.9.3
|
||||
anyio==4.3.0
|
||||
argon2-cffi==23.1.0
|
||||
argon2-cffi-bindings==21.2.0
|
||||
arrow==1.3.0
|
||||
asttokens==2.4.1
|
||||
async-lru==2.0.4
|
||||
attrs==23.2.0
|
||||
Babel==2.14.0
|
||||
beautifulsoup4==4.12.3
|
||||
bleach==6.1.0
|
||||
certifi==2024.2.2
|
||||
cffi==1.16.0
|
||||
charset-normalizer==3.3.2
|
||||
click==8.1.7
|
||||
colorlog==6.8.2
|
||||
comm==0.2.2
|
||||
contourpy==1.2.0
|
||||
cycler==0.12.1
|
||||
debugpy==1.8.1
|
||||
decorator==5.1.1
|
||||
defusedxml==0.7.1
|
||||
executing==2.0.1
|
||||
fastapi==0.110.0
|
||||
fastjsonschema==2.19.1
|
||||
fonttools==4.50.0
|
||||
fqdn==1.5.1
|
||||
greenlet==3.0.3
|
||||
h11==0.14.0
|
||||
httpcore==1.0.4
|
||||
httptools==0.6.1
|
||||
httpx==0.27.0
|
||||
hydra-core==1.3.2
|
||||
idna==3.6
|
||||
imbalanced-learn==0.12.0
|
||||
imblearn==0.0
|
||||
ipykernel==6.29.3
|
||||
ipython==8.22.2
|
||||
isoduration==20.11.0
|
||||
jedi==0.19.1
|
||||
Jinja2==3.1.3
|
||||
joblib==1.3.2
|
||||
json5==0.9.22
|
||||
jsonpointer==2.4
|
||||
jsonschema==4.21.1
|
||||
jsonschema-specifications==2023.12.1
|
||||
jupyter-events==0.9.1
|
||||
jupyter-lsp==2.2.4
|
||||
jupyter_client==8.6.1
|
||||
jupyter_core==5.7.2
|
||||
jupyter_server==2.13.0
|
||||
jupyter_server_terminals==0.5.3
|
||||
jupyterlab==4.1.4
|
||||
jupyterlab_pygments==0.3.0
|
||||
jupyterlab_server==2.25.4
|
||||
kiwisolver==1.4.5
|
||||
Mako==1.3.2
|
||||
MarkupSafe==2.1.5
|
||||
matplotlib==3.8.3
|
||||
matplotlib-inline==0.1.6
|
||||
mistune==3.0.2
|
||||
nbclient==0.10.0
|
||||
nbconvert==7.16.2
|
||||
nbformat==5.10.2
|
||||
nest-asyncio==1.6.0
|
||||
notebook_shim==0.2.4
|
||||
numpy==1.26.4
|
||||
omegaconf==2.3.0
|
||||
optuna==3.5.0
|
||||
overrides==7.7.0
|
||||
packaging==24.0
|
||||
pandas==2.2.1
|
||||
pandocfilters==1.5.1
|
||||
parso==0.8.3
|
||||
pexpect==4.9.0
|
||||
pillow==10.2.0
|
||||
platformdirs==4.2.0
|
||||
prometheus_client==0.20.0
|
||||
prompt-toolkit==3.0.43
|
||||
psutil==5.9.8
|
||||
psycopg2==2.9.9
|
||||
psycopg2-binary==2.9.9
|
||||
ptyprocess==0.7.0
|
||||
pure-eval==0.2.2
|
||||
pycparser==2.21
|
||||
pydantic==2.6.4
|
||||
pydantic_core==2.16.3
|
||||
Pygments==2.17.2
|
||||
pyparsing==3.1.2
|
||||
python-dateutil==2.9.0.post0
|
||||
python-dotenv==1.0.1
|
||||
python-json-logger==2.0.7
|
||||
pytz==2024.1
|
||||
PyYAML==6.0.1
|
||||
pyzmq==25.1.2
|
||||
referencing==0.33.0
|
||||
requests==2.31.0
|
||||
rfc3339-validator==0.1.4
|
||||
rfc3986-validator==0.1.1
|
||||
rpds-py==0.18.0
|
||||
scikit-learn==1.4.1.post1
|
||||
scipy==1.12.0
|
||||
seaborn==0.13.2
|
||||
Send2Trash==1.8.2
|
||||
setuptools==68.2.2
|
||||
six==1.16.0
|
||||
sniffio==1.3.1
|
||||
soupsieve==2.5
|
||||
SQLAlchemy==2.0.28
|
||||
stack-data==0.6.3
|
||||
starlette==0.36.3
|
||||
terminado==0.18.1
|
||||
threadpoolctl==3.3.0
|
||||
tinycss2==1.2.1
|
||||
tornado==6.4
|
||||
tqdm==4.66.2
|
||||
traitlets==5.14.2
|
||||
types-python-dateutil==2.8.19.20240311
|
||||
typing_extensions==4.10.0
|
||||
tzdata==2024.1
|
||||
uri-template==1.3.0
|
||||
urllib3==2.2.1
|
||||
uvicorn==0.28.0
|
||||
uvloop==0.19.0
|
||||
watchfiles==0.21.0
|
||||
wcwidth==0.2.13
|
||||
webcolors==1.13
|
||||
webencodings==0.5.1
|
||||
websocket-client==1.7.0
|
||||
websockets==12.0
|
||||
wheel==0.41.2
|
||||
xgboost==2.0.3
|
||||
service/Dockerfile (new file)
@@ -0,0 +1,18 @@

# Base image
FROM python:3.9

# Work inside /code in the container
WORKDIR /code

# Copy the requirements first so the dependency layer is cached
COPY ./requirements.txt /code/requirements.txt

# Install the Python dependencies
RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt

# Copy the application code (FastAPI app, model.json, metadata.pkl)
COPY ./app /code/app

# Serve the FastAPI app with uvicorn on port 80
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "80"]
service/app/__init__.py (new, empty file)
service/app/main.py (new file)
@@ -0,0 +1,27 @@
from typing import Union
from .predict import predict
from fastapi import FastAPI
import numpy as np

app = FastAPI()


@app.get("/")
def read_root():
    return {"Hello": "World"}


@app.get("/predict/{data}")
def read_item(data: str):
    # data is the quoted JSON record passed in the URL path
    final_json = {'probabilities': {}}

    try:
        res = predict(data)
        # res holds the percentage predicted for each of the 5 india classes
        for i in range(5):
            final_json['probabilities'][f"india_{i}"] = round(float(res[i]), 2)
        final_json['success'] = True
    except Exception as e:
        final_json['success'] = False
        final_json['error'] = repr(e)

    return final_json
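For reference, a successful call returns a payload shaped like the one below (the numbers are purely illustrative percentages for the five india classes); on failure `success` is false and an `error` field carries the exception:

```
{"probabilities": {"india_0": 12.3, "india_1": 55.0, "india_2": 20.1, "india_3": 8.4, "india_4": 4.2}, "success": true}
```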
service/app/metadata.pkl (new binary file, not shown)
service/app/model.json (new file; diff suppressed because one or more lines are too long)
service/app/predict.py (new file)
@@ -0,0 +1,60 @@
import json
import xgboost as xgb
import pickle
import pandas as pd
import numpy as np


def predict(json_input: str):

    # metadata saved by the training script: classes to condense, small/normal
    # flag, list of evacuation vehicles and the fitted label encoders
    with open('app/metadata.pkl', 'rb') as f:
        to_remove, use_small, evacuations, encoders = pickle.load(f)

    # the payload arrives wrapped in quotes, so strip the first and last character
    data = pd.DataFrame(json.loads(json_input[1:-1]), index=[0])

    # these columns are not used by the model (they can lead to overfitting)
    data.drop(columns=['dateandtime', 'skiarea_id', 'day_of_year', 'minute_of_day', 'year'], inplace=True, errors='ignore')

    ## evacuation_vehicles must be expanded into one boolean column per vehicle,
    ## with a workaround mapping unseen vehicles to 'other'
    evacuation = data.evacuation_vehicles.values[0]
    if isinstance(evacuation, str):
        evacuation = [evacuation]
    for c in evacuations:
        data[c] = False
    for c in evacuation:
        data[c] = True

    for c in evacuation:
        if c not in evacuations:
            data['other'] = True
            break

    data.drop(columns=['town', 'province', 'evacuation_vehicles'], inplace=True, errors='ignore')

    data['age'] = data['age'].astype(np.float32)

    # cast the categorical columns to str, exactly as done at training time
    for c in data.columns:
        if c not in ['india', 'age', 'season', 'skiarea_name', 'destination']:
            data[c] = data[c].astype('str')
    # condense the under-represented classes if the 'small' dataset was used
    if use_small:
        for c in to_remove.keys():
            for k in to_remove[c]:
                data.loc[data[c] == k, c] = 'other'
    # encode the categorical columns with the encoders fitted during training
    for c in data.columns:
        if c not in ['age', 'season', 'skiarea_name', 'india']:
            data[c] = data[c].fillna('None')
            if use_small:
                data[c] = pd.Categorical(encoders['small'][c].transform(data[c]), categories=list(range(len(encoders['small'][c].classes_))), ordered=False)
            else:
                data[c] = pd.Categorical(encoders['normal'][c].transform(data[c]), categories=list(range(len(encoders['normal'][c].classes_))), ordered=False)

    # load the trained booster and predict the probability of each india class
    bst_FS = xgb.Booster()
    bst_FS.load_model("app/model.json")

    dtest_FS = xgb.DMatrix(data[bst_FS.feature_names], enable_categorical=True)
    preds = bst_FS.predict(dtest_FS)

    return preds[0]*100
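A quick way to smoke-test this function outside FastAPI (a sketch only: it assumes you run it from the `service` folder, so that `app/metadata.pkl` and `app/model.json` resolve, and `example_record.json` is a hypothetical file containing one of the JSON records shown in the README):

```python
# Sketch: call predict() directly, without going through the HTTP layer.
from app.predict import predict

record = open("example_record.json").read().strip()  # hypothetical file with one example record
probs = predict("'" + record + "'")                  # predict() strips the surrounding quotes itself
print({f"india_{i}": round(float(p), 2) for i, p in enumerate(probs)})
```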
service/requirements.txt (new file)
@@ -0,0 +1,31 @@
annotated-types==0.6.0
anyio==4.3.0
click==8.1.7
fastapi==0.110.0
h11==0.14.0
httptools==0.6.1
idna==3.6
joblib==1.3.2
numpy==1.26.4
pandas==2.2.1
pydantic==2.6.4
pydantic_core==2.16.3
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
pytz==2024.1
PyYAML==6.0.1
scikit-learn==1.4.1.post1
scipy==1.12.0
setuptools==68.2.2
six==1.16.0
sniffio==1.3.1
starlette==0.36.3
threadpoolctl==3.3.0
typing_extensions==4.10.0
tzdata==2024.1
uvicorn==0.28.0
uvloop==0.19.0
watchfiles==0.21.0
websockets==12.0
wheel==0.41.2
xgboost==2.0.3
src/config.yaml (new file)
@@ -0,0 +1,23 @@
processing:
  skiarea_test: 'Klausberg' ## you can set it to null (None)
  season_test_skiarea: 'Kronplatz' ## you can set it to null (None)
  season_test_year: 2023 ## you can set it to null (None)
  weight_type: 'sqrt'
  use_small: True ## condense the under-represented classes (destination is never condensed!)
  reload_data: False
  use_smote: False ## I don't like to use it, leave it False
  undersampling: False ## I don't like to use it, leave it False
  test_size: 0.33

model:
  name: test
  num_boost_round: 2500
  retrain: True
  retrain_last_model: True
  n_trials: 2000


hydra:
  output_subdir: null
  run:
    dir: .
src/main.py
@@ -4,49 +4,66 @@ from sklearn.metrics import confusion_matrix,matthews_corrcoef,accuracy_score
|
||||
import xgboost as xgb
|
||||
import pandas as pd
|
||||
import pickle
|
||||
import argparse
|
||||
from omegaconf import DictConfig,OmegaConf
|
||||
import hydra
|
||||
import logging
|
||||
|
||||
def main(args):
|
||||
#you can put these parameters in the args but here I keep it simpler
|
||||
num_boost_round = 600
|
||||
SKI_AREA_TEST= 'Klausberg' ##you can put it to None
|
||||
SEASON_TEST_SKIAREA = 'Kronplatz'##you can put it to None
|
||||
SEASON_TEST_YEAR= 2023 ##you can put it to None
|
||||
weight_type = 'sqrt'
|
||||
import os
|
||||
|
||||
|
||||
|
||||
|
||||
@hydra.main(config_name='config.yaml')
|
||||
def main(conf: DictConfig) -> None:
|
||||
|
||||
skiarea_test= conf.processing.skiarea_test
|
||||
season_test_skiarea = conf.processing.season_test_skiarea
|
||||
season_test_year= conf.processing.season_test_year
|
||||
weight_type = conf.processing.weight_type
|
||||
reload_data = conf.processing.reload_data
|
||||
use_smote = conf.processing.use_smote
|
||||
undersampling = conf.processing.undersampling
|
||||
test_size = conf.processing.test_size
|
||||
use_small= conf.processing.use_small
|
||||
num_boost_round = conf.model.num_boost_round
|
||||
retrain_last_model = conf.model.retrain_last_model
|
||||
retrain = conf.model.retrain
|
||||
n_trials = conf.model.n_trials
|
||||
name = conf.model.name
|
||||
os.makedirs(name,exist_ok=True)
|
||||
with open(os.path.join(name,"conf.yaml"),'w') as f:
|
||||
OmegaConf.save(conf, f)
|
||||
logging.basicConfig(level=logging.INFO, handlers=[logging.FileHandler(os.path.join(name,"debug.log")),logging.StreamHandler() ])
|
||||
|
||||
##these are passed
|
||||
reload_data = args.reload_data
|
||||
use_smote = args.use_smote ##I don't like to use it, leave to False
|
||||
undersampling = args.undersampling ##I don't like to use it, leave to False
|
||||
retrain = args.retrain
|
||||
retrain_last_model = args.retrain_last_model
|
||||
test_size = args.test_size
|
||||
|
||||
## get the data
|
||||
labeled,labeled_small,to_remove = retrive_data(reload_data=reload_data,threshold_under_represented=0.5,path='/home/agobbi/Projects/PID/datanalytics/PID/src')
|
||||
with open('to_remove.pkl','wb') as f:
|
||||
pickle.dump(to_remove,f)
|
||||
labeled,labeled_small,to_remove,evacuations,encoders = retrive_data(reload_data=reload_data,threshold_under_represented=0.5,path='/home/agobbi/Projects/PID/datanalytics/PID/src')
|
||||
|
||||
|
||||
with open(os.path.join(name,'metadata.pkl'),'wb') as f:
|
||||
pickle.dump([to_remove,use_small,evacuations,encoders],f)
|
||||
|
||||
#split the data
|
||||
dataset,dataset_test = split(labeled_small if args.use_small else labeled ,
|
||||
SKI_AREA_TEST= SKI_AREA_TEST,
|
||||
SEASON_TEST_SKIAREA = SEASON_TEST_SKIAREA,
|
||||
SEASON_TEST_YEAR= SEASON_TEST_YEAR,
|
||||
|
||||
dataset,dataset_test = split(labeled_small if use_small else labeled ,
|
||||
skiarea_test= skiarea_test,
|
||||
season_test_skiarea = season_test_skiarea,
|
||||
season_test_year= season_test_year,
|
||||
use_smote = use_smote,
|
||||
undersampling = undersampling,
|
||||
test_size = test_size,
|
||||
weight_type = weight_type )
|
||||
#if you changed something you may want to retrain the model and save the best model
|
||||
if retrain:
|
||||
print('OPTUNA hyperparameter tuning, please wait!')
|
||||
best_model,params_final,study = train(dataset,n_trials=args.n_trials,timeout=600,num_boost_round=num_boost_round)
|
||||
logging.info('OPTUNA hyperparameter tuning, please wait!')
|
||||
best_model,params_final,study = train(dataset,n_trials=n_trials,timeout=6000,num_boost_round=num_boost_round)
|
||||
feat_imp = pd.Series(best_model.get_fscore()).sort_values(ascending=False)
|
||||
|
||||
with open('best_params.pkl','wb') as f:
|
||||
with open(os.path.join(name,'best_params.pkl'),'wb') as f:
|
||||
pickle.dump([params_final,feat_imp,best_model,study],f)
|
||||
|
||||
else:
|
||||
with open('best_params.pkl','rb') as f:
|
||||
with open(os.path.join(name,'best_params.pkl'),'rb') as f:
|
||||
params_final,feat_imp,best_model,study = pickle.load(f)
|
||||
|
||||
|
||||
@@ -59,25 +76,33 @@ def main(args):
|
||||
##get the scores
|
||||
preds_class_valid = best_model.predict(tmp_valid)
|
||||
preds_class_train= best_model.predict(tmp_train)
|
||||
print('##################RESULT ON THE TRAIN SET#####################')
|
||||
print(confusion_matrix(dataset.y_train,preds_class_train.argmax(1)))
|
||||
print(f'MCC:{matthews_corrcoef(dataset.y_train,preds_class_train.argmax(1))}')
|
||||
print(f'ACC:{accuracy_score(dataset.y_train,preds_class_train.argmax(1))}')
|
||||
print('##################RESULT ON THE VALIDATION SET#####################')
|
||||
print(confusion_matrix(dataset.y_valid,preds_class_valid.argmax(1)))
|
||||
print(f'MCC:{matthews_corrcoef(dataset.y_valid,preds_class_valid.argmax(1))}')
|
||||
print(f'ACC:{accuracy_score(dataset.y_valid,preds_class_valid.argmax(1))}')
|
||||
logging.info('##################RESULT ON THE TRAIN SET#####################')
|
||||
logging.info(confusion_matrix(dataset.y_train,preds_class_train.argmax(1)))
|
||||
logging.info(f'MCC:{matthews_corrcoef(dataset.y_train,preds_class_train.argmax(1))}')
|
||||
logging.info(f'ACC:{accuracy_score(dataset.y_train,preds_class_train.argmax(1))}')
|
||||
logging.info('##################RESULT ON THE VALIDATION SET#####################')
|
||||
logging.info(confusion_matrix(dataset.y_valid,preds_class_valid.argmax(1)))
|
||||
logging.info(f'MCC:{matthews_corrcoef(dataset.y_valid,preds_class_valid.argmax(1))}')
|
||||
logging.info(f'ACC:{accuracy_score(dataset.y_valid,preds_class_valid.argmax(1))}')
|
||||
|
||||
|
||||
|
||||
#now you can train the final model, for example using gain_accuracy_train for reducing the number of features used
|
||||
if retrain_last_model:
|
||||
tot,bst_FS,FS = gain_accuracy_train(dataset,feat_imp,num_boost_round=num_boost_round,params=params_final)
|
||||
with open('best_params_and_final_model.pkl','wb') as f:
|
||||
with open(os.path.join(name,'best_params_and_final_model.pkl'),'wb') as f:
|
||||
pickle.dump([tot,bst_FS,FS],f)
|
||||
bst_FS.save_model(os.path.join(name,"model.json"))
|
||||
else:
|
||||
with open('best_params_and_final_model.pkl','rb') as f:
|
||||
with open(os.path.join(name,'best_params_and_final_model.pkl'),'rb') as f:
|
||||
tot,bst_FS,FS = pickle.load(f)
|
||||
bst_FS = xgb.Booster()
|
||||
bst_FS.load_model(os.path.join(name,"model.json"))
|
||||
|
||||
|
||||
## save the model in json format, maybe it is better
|
||||
|
||||
|
||||
|
||||
if dataset_test.X_test_area is not None:
|
||||
dtest_FS = xgb.DMatrix(dataset_test.X_test_area[bst_FS.feature_names],dataset_test.y_test_area,enable_categorical=True,)
|
||||
@@ -85,7 +110,7 @@ def main(args):
|
||||
mcc = matthews_corrcoef(dataset_test.y_test_area,preds_class_test.argmax(1))
|
||||
acc = accuracy_score(dataset_test.y_test_area,preds_class_test.argmax(1))
|
||||
cm = confusion_matrix(dataset_test.y_test_area,preds_class_test.argmax(1))
|
||||
print(f'RESULT ON THE TEST SKI AREA {mcc=}, {acc=}, \n{cm=}')
|
||||
logging.info(f'RESULT ON THE TEST SKI AREA {mcc=}, {acc=}, \n{cm=}')
|
||||
|
||||
if dataset_test.X_test_season is not None:
|
||||
dtest_season_FS = xgb.DMatrix(dataset_test.X_test_season[bst_FS.feature_names],dataset_test.y_test_season,enable_categorical=True,)
|
||||
@@ -94,26 +119,10 @@ def main(args):
|
||||
acc = accuracy_score(dataset_test.y_test_season,preds_class_test_season.argmax(1))
|
||||
cm = confusion_matrix(dataset_test.y_test_season,preds_class_test_season.argmax(1))
|
||||
|
||||
print(f'RESULT ON THE TEST SKI SEASON {mcc=}, {acc=}, {cm=}')
|
||||
logging.info(f'RESULT ON THE TEST SKI SEASON {mcc=}, {acc=}, {cm=}')
|
||||
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
|
||||
|
||||
parser = argparse.ArgumentParser(description='Train Optuna XGBOOST model')
|
||||
parser.add_argument('--use_small', action='store_true', help="Aggregate under-represented input classes (e.g. a rare country)")
parser.add_argument('--use_smote', action='store_true', help='oversampling under-represented target labels')
parser.add_argument('--retrain', action='store_true', help='Retrain the optuna searcher')
parser.add_argument('--reload_data', action='store_true', help='Download data from the db')
|
||||
parser.add_argument('--retrain_last_model', action='store_true', help='retrain the last model')
|
||||
parser.add_argument('--n_trials', type=int,default=1000, help='number of trials per optuna')
|
||||
parser.add_argument('--undersampling', action='store_true', help='Undersample the training dataset')
|
||||
parser.add_argument('--test_size', type=float,default=0.33, help='Percentage of dataset to use as validation')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
|
||||
main(args)
|
||||
|
||||
#python main.py --use_small --retrain --retrain_last_model --n_trials=10 --reload_data
|
||||
src/model.py
@@ -1,11 +1,11 @@
|
||||
|
||||
import xgboost as xgb
|
||||
import optuna
|
||||
from sklearn.metrics import matthews_corrcoef, accuracy_score
|
||||
import optuna
|
||||
from utils import Dataset
|
||||
import pandas as pd
|
||||
|
||||
|
||||
import logging
|
||||
|
||||
def objective(trial,dataset:Dataset,num_boost_round:int)->float:
|
||||
"""function to maximize during the tuning phase
|
||||
@@ -22,21 +22,21 @@ def objective(trial,dataset:Dataset,num_boost_round:int)->float:
|
||||
params = dict(
|
||||
learning_rate = trial.suggest_float("learning_rate", 0.01, 0.2),
|
||||
max_depth= trial.suggest_int("max_depth",5, 15),
|
||||
min_child_weight = trial.suggest_int("min_child_weight", 1, 8),
|
||||
min_child_weight = trial.suggest_int("min_child_weight", 2, 8),
|
||||
gamma = trial.suggest_float("gamma",0, 10),
|
||||
subsample = trial.suggest_float("subsample", 0.01,1),
|
||||
colsample_bytree = trial.suggest_float("colsample_bytree", 0.01,1),
|
||||
alpha = trial.suggest_float("alpha", 0, 10),
|
||||
alpha = trial.suggest_float("alpha", 1, 10),
|
||||
objective= 'multi:softprob',
|
||||
nthread=4,
|
||||
num_class= 5,
|
||||
seed=27)
|
||||
params['lambda'] = trial.suggest_float("lambda", 0, 10)
|
||||
params['lambda'] = trial.suggest_float("lambda", 1, 10)
|
||||
|
||||
|
||||
dtrain = xgb.DMatrix(dataset.X_train,dataset.y_train,
|
||||
enable_categorical=True,
|
||||
weight=dataset.weight_train)
|
||||
weight=dataset.weight_train)#np.power(dataset.weight_train,trial.suggest_float("power", 0.1, 2)))
|
||||
dvalid = xgb.DMatrix(dataset.X_valid,dataset.y_valid,
|
||||
enable_categorical=True,
|
||||
)
|
||||
@@ -45,7 +45,7 @@ def objective(trial,dataset:Dataset,num_boost_round:int)->float:
|
||||
bst = xgb.train(params, dtrain,verbose_eval=False, num_boost_round=num_boost_round,
|
||||
evals = [(dtrain, "train"), (dvalid, "valid")],
|
||||
early_stopping_rounds=100)
|
||||
|
||||
logging.info(bst.best_iteration)
|
||||
preds = bst.predict(dvalid)
|
||||
##MCC is more robust than accuracy for imbalanced classes
|
||||
mcc = matthews_corrcoef(dataset.y_valid,preds.argmax(1))
|
||||
@@ -119,7 +119,7 @@ def gain_accuracy_train(dataset:Dataset,feat_imp:pd.DataFrame,num_boost_round:in
|
||||
|
||||
tot = pd.DataFrame(tot)
|
||||
FS = int(tot.loc[tot.acc.argmax()].FS) ## get best
|
||||
print(f'Best model with {FS} features, retraining....')
|
||||
logging.info(f'Best model with {FS} features, retraining....')
|
||||
|
||||
dtrain_FS = xgb.DMatrix(dataset.X_train[list(feat_imp.head(FS).index)],dataset.y_train, enable_categorical=True, weight=dataset.weight_train)
|
||||
dvalid_FS = xgb.DMatrix(dataset.X_valid[list(feat_imp.head(FS).index)],dataset.y_valid,enable_categorical=True, )
|
||||
|
||||
src/utils.py
@@ -7,7 +7,10 @@ import pickle
|
||||
from dataclasses import dataclass
|
||||
from typing import Union
|
||||
import os
|
||||
from imblearn.under_sampling import RandomUnderSampler,RandomOverSampler
|
||||
from imblearn.under_sampling import RandomUnderSampler
|
||||
from imblearn.over_sampling import RandomOverSampler
|
||||
from sklearn.preprocessing import LabelEncoder
|
||||
import logging
|
||||
|
||||
##AUXILIARY CLASSES
|
||||
@dataclass
|
||||
@@ -69,7 +72,8 @@ def prepare_new_data(dataset:pd.DataFrame,to_remove:dict)->(pd.DataFrame,pd.Data
|
||||
|
||||
|
||||
|
||||
def retrive_data(reload_data:bool,threshold_under_represented:float,path:str)->(pd.DataFrame,pd.DataFrame):
|
||||
|
||||
def retrive_data(reload_data:bool,threshold_under_represented:float,path:str)->(pd.DataFrame,pd.DataFrame,list):
|
||||
"""Get data
|
||||
|
||||
Args:
|
||||
@@ -79,6 +83,7 @@ def retrive_data(reload_data:bool,threshold_under_represented:float,path:str)->(
|
||||
|
||||
Returns:
|
||||
two pandas dataframes, the original one and a second one with condensed classes, plus a dictionary of the condensed classes
and a list of all the evacuation vehicles and the encoders for the categorical features
|
||||
"""
|
||||
if reload_data:
|
||||
engine = pg.connect("dbname='safeidx' user='fbk_mpba' host='172.104.247.67' port='5432' password='fbk2024$'")
|
||||
@@ -88,8 +93,11 @@ def retrive_data(reload_data:bool,threshold_under_represented:float,path:str)->(
|
||||
else:
|
||||
with open(os.path.join(path,'data.pkl'),'rb') as f:
|
||||
df = pickle.load(f)
|
||||
|
||||
#import pdb
|
||||
#pdb.set_trace()
|
||||
df = df[df.year>2011]
|
||||
## these columns can lead to overfit!
|
||||
|
||||
df.drop(columns=['dateandtime','skiarea_id','day_of_year','minute_of_day','year'], inplace=True)
|
||||
|
||||
##evacuation_vehicles must be expanded into one column per vehicle
|
||||
@@ -112,7 +120,7 @@ def retrive_data(reload_data:bool,threshold_under_represented:float,path:str)->(
|
||||
## maybe it is possible to obtain a more stable model removing such classes
|
||||
to_remove = {}
|
||||
for c in labeled.columns:
|
||||
if c not in ['india','age','season','skiarea_name']:
|
||||
if c not in ['india','age','season','skiarea_name','destination']:
|
||||
labeled[c] = labeled[c].astype('str')
|
||||
tmp = labeled.groupby(c)[c].count()
|
||||
tmp = 100*tmp/tmp.max()
|
||||
@@ -125,13 +133,20 @@ def retrive_data(reload_data:bool,threshold_under_represented:float,path:str)->(
|
||||
|
||||
## keep the datasets
|
||||
labeled_small = labeled.copy()
|
||||
encoders = {'small':{},'normal':{}}
|
||||
for c in to_remove.keys():
|
||||
for k in to_remove[c]:
|
||||
labeled_small.loc[labeled_small[c]==k,c] = 'other'
|
||||
for c in labeled_small.columns:
|
||||
if c not in ['age','season','skiarea_name']:
|
||||
labeled_small[c] = labeled_small[c].fillna('None').astype('category')
|
||||
labeled[c] = labeled[c].fillna('None').astype('category')
|
||||
if c not in ['age','season','skiarea_name','india']:
|
||||
le = LabelEncoder()
|
||||
labeled_small[c] = le.fit_transform(labeled_small[c].fillna('None'))
|
||||
labeled_small[c] = labeled_small[c].astype('category')
|
||||
encoders['small'][c] = le
|
||||
le = LabelEncoder()
|
||||
labeled[c] = le.fit_transform(labeled[c].fillna('None'))
|
||||
labeled[c] = labeled[c].astype('category')
|
||||
encoders['normal'][c] = le
|
||||
labeled.dropna(inplace=True)
|
||||
labeled_small.dropna(inplace=True)
|
||||
|
||||
@@ -139,29 +154,29 @@ def retrive_data(reload_data:bool,threshold_under_represented:float,path:str)->(
|
||||
labeled.india = labeled.india.apply(lambda x: x.replace('i','')).astype(int)
|
||||
labeled_small.india = labeled_small.india.apply(lambda x: x.replace('i','')).astype(int)
|
||||
|
||||
return labeled,labeled_small,to_remove
|
||||
return labeled,labeled_small,to_remove,list(ev),encoders
|
||||
|
||||
|
||||
|
||||
def split(labeled:pd.DataFrame,
|
||||
SKI_AREA_TEST: str = 'Klausberg',
|
||||
SEASON_TEST_SKIAREA:str = 'Kronplatz',
|
||||
SEASON_TEST_YEAR:int = 2023,
|
||||
skiarea_test: str = 'Klausberg',
|
||||
season_test_skiarea:str = 'Kronplatz',
|
||||
season_test_year:int = 2023,
|
||||
use_smote:bool = False,
|
||||
undersampling:bool=False,
|
||||
test_size:float=0.33,
|
||||
weight_type:str = 'sqrt' )->(Dataset, Dataset_test):
|
||||
"""Split the dataset into train,validation test. From the initial dataset we remove a single skiarea (SKI_AREA_TEST)
|
||||
generating the first test set. Then we select a skieare and a starting season (SEASON_TEST_SKIAREA,SEASON_TEST_YEAR)
|
||||
"""Split the dataset into train,validation test. From the initial dataset we remove a single skiarea (skiarea_test)
|
||||
generating the first test set. Then we select a skieare and a starting season (season_test_skiarea,season_test_year)
|
||||
and generate the seconda test set. The rest of the data are splitted 66-33 stratified on the target column (india).
|
||||
It is possible to specify the weight of eact sample. There are two strategies implemented: using the sum or the square root
|
||||
of the sum. This is used for mitigating the class umbalance. Another alternative is to use an oversampling procedure (use_smote)
|
||||
|
||||
Args:
|
||||
labeled (pd.DataFrame): dataset
|
||||
SKI_AREA_TEST (str, optional): skiarea to remove from the train and use in test. Defaults to 'Klausberg'.
|
||||
SEASON_TEST_SKIAREA (str, optional): skiarea to remove from the dataset if the season is greater than SEASON_TEST_YEAR. Defaults to 'Kronplatz'.
|
||||
SEASON_TEST_YEAR (int, optional): see SEASON_TEST_SKIAREA . Defaults to 2023.
|
||||
skiarea_test (str, optional): skiarea to remove from the train and use in test. Defaults to 'Klausberg'.
|
||||
season_test_skiarea (str, optional): skiarea to remove from the dataset if the season is greater than season_test_year. Defaults to 'Kronplatz'.
|
||||
season_test_year (int, optional): see season_test_skiarea . Defaults to 2023.
|
||||
use_smote (bool, optional): use oversampling for class imbalance. Defaults to False.
undersampling (bool, optional): use undersampling for class imbalance. Defaults to False.
|
||||
test_size (float, optional): percentage of dataset to use as validation. Defaults to 0.33.
|
||||
@@ -174,15 +189,15 @@ def split(labeled:pd.DataFrame,
|
||||
labeled_tmp = labeled.copy()
|
||||
##remove from dataset the corresponding test rows
|
||||
|
||||
if SKI_AREA_TEST is not None:
|
||||
test_area = labeled[labeled.skiarea_name==SKI_AREA_TEST]
|
||||
labeled_tmp = labeled_tmp[labeled_tmp.skiarea_name!=SKI_AREA_TEST]
|
||||
if skiarea_test is not None:
|
||||
test_area = labeled[labeled.skiarea_name==skiarea_test]
|
||||
labeled_tmp = labeled_tmp[labeled_tmp.skiarea_name!=skiarea_test]
|
||||
else:
|
||||
test_area = None
|
||||
|
||||
if SEASON_TEST_SKIAREA is not None and SEASON_TEST_YEAR is not None:
|
||||
test_area_season = labeled[(labeled.skiarea_name==SEASON_TEST_SKIAREA)&(labeled.season>=SEASON_TEST_YEAR)]
|
||||
labeled_tmp = labeled_tmp[(labeled_tmp.skiarea_name!=SEASON_TEST_SKIAREA)|(labeled_tmp.season<SEASON_TEST_YEAR) ]
|
||||
if season_test_skiarea is not None and season_test_year is not None:
|
||||
test_area_season = labeled[(labeled.skiarea_name==season_test_skiarea)&(labeled.season>=season_test_year)]
|
||||
labeled_tmp = labeled_tmp[(labeled_tmp.skiarea_name!=season_test_skiarea)|(labeled_tmp.season<season_test_year) ]
|
||||
else:
|
||||
test_area_season = None
|
||||
|
||||
@@ -210,12 +225,12 @@ def split(labeled:pd.DataFrame,
|
||||
##when computing the error, these are the weights used for each class: you can penalize errors on the most severe classes more heavily
|
||||
if weight_type == 'sqrt':
|
||||
w.p = np.sqrt(w.p.sum())/w.p
|
||||
print(w)
|
||||
logging.info(w)
|
||||
elif weight_type == 'sum':
|
||||
w.p = w.p.sum()/w.p/w.shape[0]
|
||||
print(w)
|
||||
logging.info(w)
|
||||
else:
|
||||
print(f'{weight_type=} not implemented, please use a valid one (sqrt or sum); I will set all the weights to 1')
logging.info(f'{weight_type=} not implemented, please use a valid one (sqrt or sum); I will set all the weights to 1')
|
||||
w.p = 1
|
||||
|
||||
if use_smote is False and undersampling is False:
|
||||
|
||||