ag added service API using Docker and FastAPI

This commit is contained in:
2024-03-20 16:30:30 +01:00
parent 759f4c50d0
commit 08c5f672be
16 changed files with 3450 additions and 264 deletions

View File

@@ -8,7 +8,7 @@ The complete workflow is in the file `main.py`; the paragraphs below describe
Downloads the data from the db, flattens the nested fields (type of vehicle used) and, for each column, computes the percentage of each categorical value. Since some samples are fairly unique, these values can be aggregated to avoid overfitting. By default, all classes with less than 0.5% representation are merged into the `other` class. Returns two datasets: one with the merged classes (called small) and one without (to make sure they carry the same information).
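Roughly, the aggregation step works like the sketch below (illustrative column name and a flat percentage threshold; the real implementation lives in `retrive_data` in `src/utils.py`, which measures the share relative to the most frequent class):
```
# Minimal sketch of the rare-class aggregation described above (illustrative only).
import pandas as pd

def condense_rare(df: pd.DataFrame, column: str, threshold: float = 0.5) -> pd.DataFrame:
    # share of each category as a percentage of all rows
    share = 100 * df[column].value_counts() / len(df)
    rare = share[share < threshold].index
    df.loc[df[column].isin(rare), column] = "other"
    return df

# toy example: 'Islanda' is below 0.5% and gets merged into 'other'
df = pd.DataFrame({"country": ["Italia"] * 400 + ["Ungheria"] * 10 + ["Islanda"] * 1})
print(condense_rare(df, "country")["country"].value_counts())
```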
## split
This splits the dataset (small or normal) into train, validation and test. Specifying `skiarea_test` completely removes one ski area from the dataset, to simulate computing the safe index on a new area. The pair `season_test_skiarea`-`season_test_year` instead removes some seasons of a specific ski area from the dataset: this simulates how the model behaves on new seasons of an area whose past data it has already seen. Once the test rows are removed, the remaining dataset is split into train and validation (66%-33%), stratified on india (so that the class ratio stays roughly constant). There are two ways to help the model with the imbalanced dataset (very few india 3 and 4): the first is to oversample the small classes (which I don't like); alternatively, the error made on the small classes is weighted differently. I implemented two weighting schemes; one takes the square root of the total number of cases and divides it by the number of elements in the class, which is a bit less strict than dividing the sum by the number of elements. Returns two Dataset objects (a normal one and a test one); these are helper classes that make the training phase easier.
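For reference, the two weighting schemes boil down to the following (hypothetical per-class counts for india 0..4; the same formulas appear in `split` in `src/utils.py`):
```
# Illustration of the two class-weighting schemes with made-up counts.
import numpy as np
import pandas as pd

counts = pd.Series({0: 5000, 1: 3000, 2: 800, 3: 60, 4: 10}, name="p")

w_sqrt = np.sqrt(counts.sum()) / counts            # 'sqrt': milder boost for rare classes
w_sum = counts.sum() / counts / counts.shape[0]    # 'sum': plain inverse frequency
print(pd.DataFrame({"sqrt": w_sqrt, "sum": w_sum}).round(2))
```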
## train
This is the core of the program: I set up a grid of hyperparameters with commonly used ranges. An XGBoost model is trained to maximize MCC (not accuracy, which is not appropriate with imbalanced classes). You set the number of trials (I suggest at least 1000) and a timeout (in case of limited resources). `num_boost` is the maximum number of boosting steps; an overfitting-detection mechanism can stop training earlier.
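A minimal, self-contained sketch of this loop (dummy data and a reduced hyperparameter grid; the real objective, full grid and class weights are in the src module):
```
# Sketch: Optuna maximizes validation MCC of an XGBoost multiclass model with early stopping.
import numpy as np
import optuna
import xgboost as xgb
from sklearn.metrics import matthews_corrcoef

rng = np.random.default_rng(0)
X = rng.normal(size=(600, 8))
y = rng.integers(0, 5, size=600)
dtrain = xgb.DMatrix(X[:400], label=y[:400])
dvalid = xgb.DMatrix(X[400:], label=y[400:])

def objective(trial):
    params = {
        "learning_rate": trial.suggest_float("learning_rate", 0.01, 0.2),
        "max_depth": trial.suggest_int("max_depth", 5, 15),
        "objective": "multi:softprob",
        "num_class": 5,
    }
    bst = xgb.train(params, dtrain, num_boost_round=2500,
                    evals=[(dvalid, "valid")],
                    early_stopping_rounds=100, verbose_eval=False)
    preds = bst.predict(dvalid)              # shape (n_valid, 5)
    return matthews_corrcoef(y[400:], preds.argmax(1))

study = optuna.create_study(direction="maximize")
study.optimize(objective, n_trials=20, timeout=600)  # use >=1000 trials for real runs
print(study.best_params)
```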
@@ -26,9 +26,9 @@ The features are ranked by importance using the score (`best_model.ge
## Final model
Once all the tests are done, we can also reintegrate the test set into the training data:
```
skiarea_test: None
season_test_skiarea: None
season_test_year: None
```
and also increase the amount of points in the training set:
```
@@ -40,3 +40,34 @@ There are a few notebooks; TRAIN contains more or less what `main.py` does, or m
![Sample](img/sample.png)
![Interpretability](img/Interpretability.png)
## Model training
```
cd src
python main.py
python main.py model.n_trials=200 ## if I want to change some parameters
```
This automatically reads the configuration from `config.yaml`. The script creates a folder named after `model.name` and puts everything the main generates into it (the config used, all the logs, the saved models, and some metafeatures used later by the service).
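For reference, a minimal sketch of how those artifacts can be read back afterwards (run from the repository root; `test` is the default `model.name` in `config.yaml`, and the metadata layout matches what `main.py` pickles; the service does essentially the same thing):
```
# Sketch: reading back the artifacts dropped into src/<model.name> by main.py.
# Requires xgboost and scikit-learn installed (the encoders are sklearn objects).
import os
import pickle

import xgboost as xgb

name = "test"  # model.name from config.yaml
with open(os.path.join("src", name, "metadata.pkl"), "rb") as f:
    to_remove, use_small, evacuations, encoders = pickle.load(f)
bst = xgb.Booster()
bst.load_model(os.path.join("src", name, "model.json"))
print(use_small, len(evacuations), bst.feature_names)
```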
## Model service
The `service` folder implements an endpoint to query the model, based on Docker and FastAPI.
Once a model has been chosen based on its performance, move the files `metadata.pkl` and `model.json` from the folder `src/<model.name>` to the folder `service/app`.
Enter the `service` folder and run the following commands:
```
docker build -t myimage .
docker run -d --name mycontainer -p 80:80 myimage
```
This copies the whole content of app and builds the container (no volumes are mounted).
To test:
```
0.0.0.0/predict/'{"dateandtime":1231754520000,"skiarea_id":null,"skiarea_name":"Pampeago","day_of_year":12,"minute_of_day":602,"year":2009,"season":2009,"difficulty":"novice","cause":"fall_alone","town":"SIKLOS","province":"","gender":"F","equipment":"ski","helmet":null,"destination":"hospital_emergency_room","diagnosis":"distortion","india":null,"age":32.0,"country":"Ungheria","injury_side":"L","injury_general_location":"lower_limbs","evacuation_vehicles":["akja"]}'
0.0.0.0/predict/'{"dateandtime":1512294600000,"skiarea_id":13.0,"skiarea_name":"Kronplatz","day_of_year":337,"minute_of_day":590,"year":2017,"season":2018,"difficulty":"intermediate","cause":"fall_alone","town":"Pieve di Soligo","province":"Treviso","gender":"M","equipment":"ski","helmet":true,"destination":"hospital_emergency_room","diagnosis":"other","india":"i1","age":43.0,"country":"Italia","injury_side":"L","injury_general_location":"lower_limbs","evacuation_vehicles":["akja"]}'
0.0.0.0/predict/'{"dateandtime":1512726900000,"skiarea_id":13.0,"skiarea_name":"Kronplatz","day_of_year":342,"minute_of_day":595,"year":2017,"season":2018,"difficulty":"easy","cause":"fall_alone","town":"Vigo Novo","province":"Venezia","gender":"M","equipment":"ski","helmet":true,"destination":"hospital_emergency_room","diagnosis":"distortion","india":"i2","age":23.0,"country":"Italia","injury_side":"L","injury_general_location":"lower_limbs","evacuation_vehicles":["snowmobile"]}'
```
I am not sure how the data will be passed to the endpoint in production; if needed, the part of the serving code that parses the data passed in the URL will have to be adapted. It fails on odd strings (this happens with Bolzano, where there are backslashes).
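For a quick check from Python, something like the sketch below can be used: it percent-encodes the payload before calling the endpoint, which also sidesteps the odd-character problem (it assumes the container from the `docker run` command above is up and that the `requests` package is available; the payload is the second example above):
```
# Sketch: querying the running container with a URL-encoded payload.
from urllib.parse import quote

import requests

payload = '{"dateandtime":1512294600000,"skiarea_id":13.0,"skiarea_name":"Kronplatz","day_of_year":337,"minute_of_day":590,"year":2017,"season":2018,"difficulty":"intermediate","cause":"fall_alone","town":"Pieve di Soligo","province":"Treviso","gender":"M","equipment":"ski","helmet":true,"destination":"hospital_emergency_room","diagnosis":"other","india":"i1","age":43.0,"country":"Italia","injury_side":"L","injury_general_location":"lower_limbs","evacuation_vehicles":["akja"]}'
# The service strips the first and last character before parsing, so keep the wrapping quotes.
url = "http://localhost:80/predict/" + quote("'" + payload + "'", safe="")
print(requests.get(url).json())  # {'probabilities': {'india_0': ..., ...}, 'success': True}
```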

View File

@@ -1,154 +0,0 @@
name: pid
channels:
- defaults
dependencies:
- _libgcc_mutex=0.1=main
- _openmp_mutex=5.1=1_gnu
- bzip2=1.0.8=h7b6447c_0
- ca-certificates=2023.12.12=h06a4308_0
- expat=2.5.0=h6a678d5_0
- ld_impl_linux-64=2.38=h1181459_1
- libffi=3.4.4=h6a678d5_0
- libgcc-ng=11.2.0=h1234567_1
- libgomp=11.2.0=h1234567_1
- libstdcxx-ng=11.2.0=h1234567_1
- libuuid=1.41.5=h5eee18b_0
- ncurses=6.4=h6a678d5_0
- openssl=1.1.1w=h7f8727e_0
- pip=23.3.1=py311h06a4308_0
- python=3.11.0=h7a1cb2a_3
- readline=8.2=h5eee18b_0
- setuptools=68.2.2=py311h06a4308_0
- sqlite=3.41.2=h5eee18b_0
- tk=8.6.12=h1ccaba5_0
- wheel=0.41.2=py311h06a4308_0
- xz=5.4.5=h5eee18b_0
- zlib=1.2.13=h5eee18b_0
- pip:
- alembic==1.13.1
- anyio==4.2.0
- argon2-cffi==23.1.0
- argon2-cffi-bindings==21.2.0
- arrow==1.3.0
- asttokens==2.4.1
- async-lru==2.0.4
- attrs==23.2.0
- babel==2.14.0
- beautifulsoup4==4.12.3
- bleach==6.1.0
- catboost==1.2.2
- certifi==2024.2.2
- cffi==1.16.0
- charset-normalizer==3.3.2
- cloudpickle==3.0.0
- colorlog==6.8.2
- comm==0.2.1
- contourpy==1.2.0
- cycler==0.12.1
- debugpy==1.8.0
- decorator==5.1.1
- defusedxml==0.7.1
- executing==2.0.1
- fastjsonschema==2.19.1
- fonttools==4.47.2
- fqdn==1.5.1
- greenlet==3.0.3
- idna==3.6
- imageio==2.33.1
- imbalanced-learn==0.12.0
- imblearn==0.0
- ipykernel==6.29.0
- ipython==8.21.0
- isoduration==20.11.0
- jedi==0.19.1
- jinja2==3.1.3
- joblib==1.3.2
- json5==0.9.14
- jsonpointer==2.4
- jsonschema==4.21.1
- jsonschema-specifications==2023.12.1
- jupyter-client==8.6.0
- jupyter-core==5.7.1
- jupyter-events==0.9.0
- jupyter-lsp==2.2.2
- jupyter-server==2.12.5
- jupyter-server-terminals==0.5.2
- jupyterlab==4.0.12
- jupyterlab-pygments==0.3.0
- jupyterlab-server==2.25.2
- kiwisolver==1.4.5
- lazy-loader==0.3
- lime==0.2.0.1
- llvmlite==0.42.0
- mako==1.3.2
- markupsafe==2.1.4
- matplotlib==3.8.2
- matplotlib-inline==0.1.6
- mistune==3.0.2
- nbclient==0.9.0
- nbconvert==7.14.2
- nbformat==5.9.2
- nest-asyncio==1.6.0
- networkx==3.2.1
- notebook-shim==0.2.3
- numba==0.59.0
- numpy==1.26.3
- optuna==3.5.0
- overrides==7.7.0
- packaging==23.2
- pandas==2.2.0
- pandocfilters==1.5.1
- parso==0.8.3
- pexpect==4.9.0
- pillow==10.2.0
- platformdirs==4.2.0
- plotly==5.18.0
- prometheus-client==0.19.0
- prompt-toolkit==3.0.43
- psutil==5.9.8
- psycopg2-binary==2.9.9
- ptyprocess==0.7.0
- pure-eval==0.2.2
- pycparser==2.21
- pygments==2.17.2
- pyparsing==3.1.1
- python-dateutil==2.8.2
- python-graphviz==0.20.1
- python-json-logger==2.0.7
- pytz==2024.1
- pyyaml==6.0.1
- pyzmq==25.1.2
- referencing==0.33.0
- requests==2.31.0
- rfc3339-validator==0.1.4
- rfc3986-validator==0.1.1
- rpds-py==0.17.1
- scikit-image==0.22.0
- scikit-learn==1.4.0
- scipy==1.12.0
- send2trash==1.8.2
- shap==0.44.1
- six==1.16.0
- slicer==0.0.7
- sniffio==1.3.0
- soupsieve==2.5
- sqlalchemy==2.0.25
- stack-data==0.6.3
- tenacity==8.2.3
- terminado==0.18.0
- threadpoolctl==3.2.0
- tifffile==2024.1.30
- tinycss2==1.2.1
- tornado==6.4
- tqdm==4.66.1
- traitlets==5.14.1
- types-python-dateutil==2.8.19.20240106
- typing-extensions==4.9.0
- tzdata==2023.4
- uri-template==1.3.0
- urllib3==2.2.0
- wcwidth==0.2.13
- webcolors==1.13
- webencodings==0.5.1
- websocket-client==1.7.0
- xgboost==2.0.3

View File

@@ -2,26 +2,21 @@
"cells": [ "cells": [
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": 1,
"id": "7c5d059b-ed8a-4e2e-9420-25890f648895", "id": "7c5d059b-ed8a-4e2e-9420-25890f648895",
"metadata": { "metadata": {
"scrolled": true "scrolled": true
}, },
"outputs": [ "outputs": [
{ {
"name": "stderr", "ename": "ModuleNotFoundError",
"output_type": "stream", "evalue": "No module named 'psycopg2'",
"text": [ "output_type": "error",
"/tmp/ipykernel_20879/1378035245.py:1: DeprecationWarning: \n", "traceback": [
"Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),\n", "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)\n", "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
"but was not found to be installed on your system.\n", "Cell \u001b[0;32mIn[1], line 2\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mpandas\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mpd\u001b[39;00m\n\u001b[0;32m----> 2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mpsycopg2\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mpg\u001b[39;00m\n\u001b[1;32m 3\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mmatplotlib\u001b[39;00m\u001b[38;5;21;01m.\u001b[39;00m\u001b[38;5;21;01mpyplot\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mplt\u001b[39;00m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mnumpy\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mnp\u001b[39;00m\n",
"If this would cause problems for you,\n", "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'psycopg2'"
"please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466\n",
" \n",
" import pandas as pd\n",
"/home/agobbi/miniconda3/envs/pid/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
" from .autonotebook import tqdm as notebook_tqdm\n"
] ]
} }
], ],
@@ -9894,7 +9889,7 @@
"name": "python", "name": "python",
"nbconvert_exporter": "python", "nbconvert_exporter": "python",
"pygments_lexer": "ipython3", "pygments_lexer": "ipython3",
"version": "3.11.0" "version": "3.12.2"
} }
}, },
"nbformat": 4, "nbformat": 4,

File diff suppressed because one or more lines are too long

134
requirements.txt Normal file
View File

@@ -0,0 +1,134 @@
alembic==1.13.1
annotated-types==0.6.0
antlr4-python3-runtime==4.9.3
anyio==4.3.0
argon2-cffi==23.1.0
argon2-cffi-bindings==21.2.0
arrow==1.3.0
asttokens==2.4.1
async-lru==2.0.4
attrs==23.2.0
Babel==2.14.0
beautifulsoup4==4.12.3
bleach==6.1.0
certifi==2024.2.2
cffi==1.16.0
charset-normalizer==3.3.2
click==8.1.7
colorlog==6.8.2
comm==0.2.2
contourpy==1.2.0
cycler==0.12.1
debugpy==1.8.1
decorator==5.1.1
defusedxml==0.7.1
executing==2.0.1
fastapi==0.110.0
fastjsonschema==2.19.1
fonttools==4.50.0
fqdn==1.5.1
greenlet==3.0.3
h11==0.14.0
httpcore==1.0.4
httptools==0.6.1
httpx==0.27.0
hydra-core==1.3.2
idna==3.6
imbalanced-learn==0.12.0
imblearn==0.0
ipykernel==6.29.3
ipython==8.22.2
isoduration==20.11.0
jedi==0.19.1
Jinja2==3.1.3
joblib==1.3.2
json5==0.9.22
jsonpointer==2.4
jsonschema==4.21.1
jsonschema-specifications==2023.12.1
jupyter-events==0.9.1
jupyter-lsp==2.2.4
jupyter_client==8.6.1
jupyter_core==5.7.2
jupyter_server==2.13.0
jupyter_server_terminals==0.5.3
jupyterlab==4.1.4
jupyterlab_pygments==0.3.0
jupyterlab_server==2.25.4
kiwisolver==1.4.5
Mako==1.3.2
MarkupSafe==2.1.5
matplotlib==3.8.3
matplotlib-inline==0.1.6
mistune==3.0.2
nbclient==0.10.0
nbconvert==7.16.2
nbformat==5.10.2
nest-asyncio==1.6.0
notebook_shim==0.2.4
numpy==1.26.4
omegaconf==2.3.0
optuna==3.5.0
overrides==7.7.0
packaging==24.0
pandas==2.2.1
pandocfilters==1.5.1
parso==0.8.3
pexpect==4.9.0
pillow==10.2.0
platformdirs==4.2.0
prometheus_client==0.20.0
prompt-toolkit==3.0.43
psutil==5.9.8
psycopg2==2.9.9
psycopg2-binary==2.9.9
ptyprocess==0.7.0
pure-eval==0.2.2
pycparser==2.21
pydantic==2.6.4
pydantic_core==2.16.3
Pygments==2.17.2
pyparsing==3.1.2
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
python-json-logger==2.0.7
pytz==2024.1
PyYAML==6.0.1
pyzmq==25.1.2
referencing==0.33.0
requests==2.31.0
rfc3339-validator==0.1.4
rfc3986-validator==0.1.1
rpds-py==0.18.0
scikit-learn==1.4.1.post1
scipy==1.12.0
seaborn==0.13.2
Send2Trash==1.8.2
setuptools==68.2.2
six==1.16.0
sniffio==1.3.1
soupsieve==2.5
SQLAlchemy==2.0.28
stack-data==0.6.3
starlette==0.36.3
terminado==0.18.1
threadpoolctl==3.3.0
tinycss2==1.2.1
tornado==6.4
tqdm==4.66.2
traitlets==5.14.2
types-python-dateutil==2.8.19.20240311
typing_extensions==4.10.0
tzdata==2024.1
uri-template==1.3.0
urllib3==2.2.1
uvicorn==0.28.0
uvloop==0.19.0
watchfiles==0.21.0
wcwidth==0.2.13
webcolors==1.13
webencodings==0.5.1
websocket-client==1.7.0
websockets==12.0
wheel==0.41.2
xgboost==2.0.3

18
service/Dockerfile Normal file
View File

@@ -0,0 +1,18 @@
# Base image with Python 3.9
FROM python:3.9
# Work directory inside the container
WORKDIR /code
# Copy the service requirements first to exploit Docker layer caching
COPY ./requirements.txt /code/requirements.txt
# Install the (minimal) serving dependencies
RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
# Copy the application code (model.json and metadata.pkl must already be in ./app)
COPY ./app /code/app
# Serve the FastAPI app with uvicorn on port 80
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "80"]

0
service/app/__init__.py Normal file
View File

27
service/app/main.py Normal file
View File

@@ -0,0 +1,27 @@
from typing import Union
from .predict import predict
from fastapi import FastAPI
import numpy as np

app = FastAPI()


@app.get("/")
def read_root():
    # simple health-check / hello endpoint
    return {"Hello": "World"}


@app.get("/predict/{data}")
def read_item(data: str):
    # the JSON payload travels as a path parameter (see the README test examples)
    final_json = {'probabilities': {}}
    try:
        res = predict(data)
        # res holds the five class probabilities (in percent) for india 0..4
        for i in range(5):
            final_json['probabilities'][f"india_{i}"] = round(float(res[i]), 2)
        final_json['success'] = True
    except Exception as e:
        final_json['success'] = False
        final_json['error'] = repr(e)
    return final_json

BIN
service/app/metadata.pkl Normal file

Binary file not shown.

1
service/app/model.json Normal file

File diff suppressed because one or more lines are too long

60
service/app/predict.py Normal file
View File

@@ -0,0 +1,60 @@
import json
import xgboost as xgb
import pickle
import pandas as pd
import numpy as np


def predict(json_input: str):
    # metadata saved by src/main.py: condensed classes, small-dataset flag,
    # known evacuation vehicles and the label encoders
    with open('app/metadata.pkl', 'rb') as f:
        to_remove, use_small, evacuations, encoders = pickle.load(f)
    # the payload arrives wrapped in quotes: strip them before parsing
    data = pd.DataFrame(json.loads(json_input[1:-1]), index=[0])
    data.drop(columns=['dateandtime', 'skiarea_id', 'day_of_year', 'minute_of_day', 'year'], inplace=True, errors='ignore')
    ##evacuation_vehicles must be explicitated and workaround for other!
    evacuation = data.evacuation_vehicles.values[0]
    if isinstance(evacuation, str):
        evacuation = [evacuation]
    for c in evacuations:
        data[c] = False
    for c in evacuation:
        data[c] = True
    for c in evacuation:
        if c not in evacuations:
            data['other'] = True
            break
    data.drop(columns=['town', 'province', 'evacuation_vehicles'], inplace=True, errors='ignore')
    data['age'] = data['age'].astype(np.float32)
    for c in data.columns:
        if c not in ['india', 'age', 'season', 'skiarea_name', 'destination']:
            data[c] = data[c].astype('str')
    if use_small:
        # same condensation of rare classes used at training time
        for c in to_remove.keys():
            for k in to_remove[c]:
                data.loc[data[c] == k, c] = 'other'
    for c in data.columns:
        if c not in ['age', 'season', 'skiarea_name', 'india']:
            data[c] = data[c].fillna('None')
            # encode with the label encoders fitted during training
            if use_small:
                data[c] = pd.Categorical(encoders['small'][c].transform(data[c]), categories=list(range(len(encoders['small'][c].classes_))), ordered=False)
            else:
                data[c] = pd.Categorical(encoders['normal'][c].transform(data[c]), categories=list(range(len(encoders['normal'][c].classes_))), ordered=False)
    bst_FS = xgb.Booster()
    bst_FS.load_model("app/model.json")
    dtest_FS = xgb.DMatrix(data[bst_FS.feature_names], enable_categorical=True)
    preds = bst_FS.predict(dtest_FS)
    # return the probabilities of the single row as percentages
    return preds[0] * 100
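A hypothetical direct call to this helper, run from the `service` folder after `metadata.pkl` and `model.json` have been copied into `service/app` (the payload is the third example in the README):
```
# Hypothetical direct use of predict(); assumes cwd == service/ and that
# app/metadata.pkl and app/model.json are in place.
from app.predict import predict

payload = '{"dateandtime":1512726900000,"skiarea_id":13.0,"skiarea_name":"Kronplatz","day_of_year":342,"minute_of_day":595,"year":2017,"season":2018,"difficulty":"easy","cause":"fall_alone","town":"Vigo Novo","province":"Venezia","gender":"M","equipment":"ski","helmet":true,"destination":"hospital_emergency_room","diagnosis":"distortion","india":"i2","age":23.0,"country":"Italia","injury_side":"L","injury_general_location":"lower_limbs","evacuation_vehicles":["snowmobile"]}'
# the endpoint passes the path parameter still wrapped in quotes, hence the extra quoting
probs = predict("'" + payload + "'")
print(probs)  # five percentages, one per india class
```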

31
service/requirements.txt Normal file
View File

@@ -0,0 +1,31 @@
annotated-types==0.6.0
anyio==4.3.0
click==8.1.7
fastapi==0.110.0
h11==0.14.0
httptools==0.6.1
idna==3.6
joblib==1.3.2
numpy==1.26.4
pandas==2.2.1
pydantic==2.6.4
pydantic_core==2.16.3
python-dateutil==2.9.0.post0
python-dotenv==1.0.1
pytz==2024.1
PyYAML==6.0.1
scikit-learn==1.4.1.post1
scipy==1.12.0
setuptools==68.2.2
six==1.16.0
sniffio==1.3.1
starlette==0.36.3
threadpoolctl==3.3.0
typing_extensions==4.10.0
tzdata==2024.1
uvicorn==0.28.0
uvloop==0.19.0
watchfiles==0.21.0
websockets==12.0
wheel==0.41.2
xgboost==2.0.3

23
src/config.yaml Normal file
View File

@@ -0,0 +1,23 @@
processing:
  skiarea_test: 'Klausberg' ##you can put it to None
  season_test_skiarea: 'Kronplatz' ##you can put it to None
  season_test_year: 2023 ##you can put it to None
  weight_type: 'sqrt'
  use_small: True ## condensate underrepresented classes (no destination!)
  reload_data: False
  use_smote: False ##I don't like to use it, leave to False
  undersampling: False ##I don't like to use it, leave to False
  test_size: 0.33
model:
  name: test
  num_boost_round: 2500
  retrain: True
  retrain_last_model: True
  n_trials: 2000
hydra:
  output_subdir: null
  run:
    dir: .
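For reference, the same file can be inspected outside Hydra with OmegaConf (already a dependency), e.g. from the repository root:
```
# Sketch: loading src/config.yaml directly with OmegaConf for quick inspection.
from omegaconf import OmegaConf

conf = OmegaConf.load("src/config.yaml")
print(conf.processing.skiarea_test, conf.model.n_trials)
```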

View File

@@ -4,49 +4,66 @@ from sklearn.metrics import confusion_matrix,matthews_corrcoef,accuracy_score
 import xgboost as xgb
 import pandas as pd
 import pickle
-import argparse
-def main(args):
-    #you can put these parameters in the args but here I keep it simpler
-    num_boost_round = 600
-    SKI_AREA_TEST= 'Klausberg' ##you can put it to None
-    SEASON_TEST_SKIAREA = 'Kronplatz'##you can put it to None
-    SEASON_TEST_YEAR= 2023 ##you can put it to None
-    weight_type = 'sqrt'
-    ##these are passed
-    reload_data = args.reload_data
-    use_smote = args.use_smote ##I don't like to use it, leave to False
-    undersampling = args.undersampling ##I don't like to use it, leave to False
-    retrain = args.retrain
-    retrain_last_model = args.retrain_last_model
-    test_size = args.test_size
+from omegaconf import DictConfig,OmegaConf
+import hydra
+import logging
+import os
+@hydra.main(config_name='config.yaml')
+def main(conf: DictConfig) -> None:
+    skiarea_test= conf.processing.skiarea_test
+    season_test_skiarea = conf.processing.season_test_skiarea
+    season_test_year= conf.processing.season_test_year
+    weight_type = conf.processing.weight_type
+    reload_data = conf.processing.reload_data
+    use_smote = conf.processing.use_smote
+    undersampling = conf.processing.undersampling
+    test_size = conf.processing.test_size
+    use_small= conf.processing.use_small
+    num_boost_round = conf.model.num_boost_round
+    retrain_last_model = conf.model.retrain_last_model
+    retrain = conf.model.retrain
+    n_trials = conf.model.n_trials
+    name = conf.model.name
+    os.makedirs(name,exist_ok=True)
+    with open(os.path.join(name,"conf.yaml"),'w') as f:
+        OmegaConf.save(conf, f)
+    logging.basicConfig(level=logging.INFO, handlers=[logging.FileHandler(os.path.join(name,"debug.log")),logging.StreamHandler() ])
     ## get the data
-    labeled,labeled_small,to_remove = retrive_data(reload_data=reload_data,threshold_under_represented=0.5,path='/home/agobbi/Projects/PID/datanalytics/PID/src')
-    with open('to_remove.pkl','wb') as f:
-        pickle.dump(to_remove,f)
+    labeled,labeled_small,to_remove,evacuations,encoders = retrive_data(reload_data=reload_data,threshold_under_represented=0.5,path='/home/agobbi/Projects/PID/datanalytics/PID/src')
+    with open(os.path.join(name,'metadata.pkl'),'wb') as f:
+        pickle.dump([to_remove,use_small,evacuations,encoders],f)
     #split the data
-    dataset,dataset_test = split(labeled_small if args.use_small else labeled ,
-                                 SKI_AREA_TEST= SKI_AREA_TEST,
-                                 SEASON_TEST_SKIAREA = SEASON_TEST_SKIAREA,
-                                 SEASON_TEST_YEAR= SEASON_TEST_YEAR,
+    dataset,dataset_test = split(labeled_small if use_small else labeled ,
+                                 skiarea_test= skiarea_test,
+                                 season_test_skiarea = season_test_skiarea,
+                                 season_test_year= season_test_year,
                                  use_smote = use_smote,
                                  undersampling = undersampling,
                                  test_size = test_size,
                                  weight_type = weight_type )
     #if you changed something you may want to retrain the model and save the best model
     if retrain:
-        print('OPTUNA hyperparameter tuning, please wait!')
-        best_model,params_final,study = train(dataset,n_trials=args.n_trials,timeout=600,num_boost_round=num_boost_round)
+        logging.info('OPTUNA hyperparameter tuning, please wait!')
+        best_model,params_final,study = train(dataset,n_trials=n_trials,timeout=6000,num_boost_round=num_boost_round)
         feat_imp = pd.Series(best_model.get_fscore()).sort_values(ascending=False)
-        with open('best_params.pkl','wb') as f:
+        with open(os.path.join(name,'best_params.pkl'),'wb') as f:
             pickle.dump([params_final,feat_imp,best_model,study],f)
     else:
-        with open('best_params.pkl','rb') as f:
+        with open(os.path.join(name,'best_params.pkl'),'rb') as f:
             params_final,feat_imp,best_model,study = pickle.load(f)
@@ -59,25 +76,33 @@ def main(args):
     ##get the scores
     preds_class_valid = best_model.predict(tmp_valid)
     preds_class_train= best_model.predict(tmp_train)
-    print('##################RESULT ON THE TRAIN SET#####################')
-    print(confusion_matrix(dataset.y_train,preds_class_train.argmax(1)))
-    print(f'MCC:{matthews_corrcoef(dataset.y_train,preds_class_train.argmax(1))}')
-    print(f'ACC:{accuracy_score(dataset.y_train,preds_class_train.argmax(1))}')
-    print('##################RESULT ON THE VALIDATION SET#####################')
-    print(confusion_matrix(dataset.y_valid,preds_class_valid.argmax(1)))
-    print(f'MCC:{matthews_corrcoef(dataset.y_valid,preds_class_valid.argmax(1))}')
-    print(f'ACC:{accuracy_score(dataset.y_valid,preds_class_valid.argmax(1))}')
+    logging.info('##################RESULT ON THE TRAIN SET#####################')
+    logging.info(confusion_matrix(dataset.y_train,preds_class_train.argmax(1)))
+    logging.info(f'MCC:{matthews_corrcoef(dataset.y_train,preds_class_train.argmax(1))}')
+    logging.info(f'ACC:{accuracy_score(dataset.y_train,preds_class_train.argmax(1))}')
+    logging.info('##################RESULT ON THE VALIDATION SET#####################')
+    logging.info(confusion_matrix(dataset.y_valid,preds_class_valid.argmax(1)))
+    logging.info(f'MCC:{matthews_corrcoef(dataset.y_valid,preds_class_valid.argmax(1))}')
+    logging.info(f'ACC:{accuracy_score(dataset.y_valid,preds_class_valid.argmax(1))}')
     #now you can train the final model, for example using gain_accuracy_train for reducing the number of features used
     if retrain_last_model:
         tot,bst_FS,FS = gain_accuracy_train(dataset,feat_imp,num_boost_round=num_boost_round,params=params_final)
-        with open('best_params_and_final_model.pkl','wb') as f:
+        with open(os.path.join(name,'best_params_and_final_model.pkl'),'wb') as f:
             pickle.dump([tot,bst_FS,FS],f)
+        bst_FS.save_model(os.path.join(name,"model.json"))
     else:
-        with open('best_params_and_final_model.pkl','rb') as f:
+        with open(os.path.join(name,'best_params_and_final_model.pkl'),'rb') as f:
             tot,bst_FS,FS = pickle.load(f)
+        bst_FS = xgb.Booster()
+        bst_FS.load_model(os.path.join(name,"model.json"))
+    ## save the model in json format, maybe it is better
     if dataset_test.X_test_area is not None:
         dtest_FS = xgb.DMatrix(dataset_test.X_test_area[bst_FS.feature_names],dataset_test.y_test_area,enable_categorical=True,)
@@ -85,7 +110,7 @@ def main(args):
         mcc = matthews_corrcoef(dataset_test.y_test_area,preds_class_test.argmax(1))
         acc = accuracy_score(dataset_test.y_test_area,preds_class_test.argmax(1))
         cm = confusion_matrix(dataset_test.y_test_area,preds_class_test.argmax(1))
-        print(f'RESULT ON THE TEST SKI AREA {mcc=}, {acc=}, \n{cm=}')
+        logging.info(f'RESULT ON THE TEST SKI AREA {mcc=}, {acc=}, \n{cm=}')
     if dataset_test.X_test_season is not None:
         dtest_season_FS = xgb.DMatrix(dataset_test.X_test_season[bst_FS.feature_names],dataset_test.y_test_season,enable_categorical=True,)
@@ -94,26 +119,10 @@ def main(args):
         acc = accuracy_score(dataset_test.y_test_season,preds_class_test_season.argmax(1))
         cm = confusion_matrix(dataset_test.y_test_season,preds_class_test_season.argmax(1))
-        print(f'RESULT ON THE TEST SKI SEASON {mcc=}, {acc=}, {cm=}')
+        logging.info(f'RESULT ON THE TEST SKI SEASON {mcc=}, {acc=}, {cm=}')
 if __name__ == "__main__":
-    parser = argparse.ArgumentParser(description='Train Optuna XGBOOST model')
-    parser.add_argument('--use_small', action='store_true', help="Aggregate under represented input classes (es: rare country)")
-    parser.add_argument('--use_smote', action='store_true', help='oversampling underrperesented target labels')
-    parser.add_argument('--retrain', action='store_true', help='Retrain the optuna searcher')
-    parser.add_argument('--reload_data', action='store_true', help='Dowload data from db')
-    parser.add_argument('--retrain_last_model', action='store_true', help='retrain the last model')
-    parser.add_argument('--n_trials', type=int,default=1000, help='number of trials per optuna')
-    parser.add_argument('--undersampling', action='store_true', help='Undersample the training dataset')
-    parser.add_argument('--test_size', type=float,default=0.33, help='Percentage of dataset to use as validation')
-    args = parser.parse_args()
-    main(args)
-#python main.py --use_small --retrain --retrain_last_model --n_trials=10 --reload_data
+    main()

View File

@@ -1,11 +1,11 @@
 import xgboost as xgb
-import optuna
 from sklearn.metrics import matthews_corrcoef, accuracy_score
 import optuna
 from utils import Dataset
 import pandas as pd
+import logging
 def objective(trial,dataset:Dataset,num_boost_round:int)->float:
     """function to maximize during the tuning phase
@@ -22,21 +22,21 @@ def objective(trial,dataset:Dataset,num_boost_round:int)->float:
     params = dict(
         learning_rate = trial.suggest_float("learning_rate", 0.01, 0.2),
         max_depth= trial.suggest_int("max_depth",5, 15),
-        min_child_weight = trial.suggest_int("min_child_weight", 1, 8),
-        gamma = trial.suggest_float("gamma", 0, 10),
+        min_child_weight = trial.suggest_int("min_child_weight", 2, 8),
+        gamma = trial.suggest_float("gamma",0, 10),
         subsample = trial.suggest_float("subsample", 0.01,1),
         colsample_bytree = trial.suggest_float("colsample_bytree", 0.01,1),
-        alpha = trial.suggest_float("alpha", 0, 10),
+        alpha = trial.suggest_float("alpha", 1, 10),
         objective= 'multi:softprob',
         nthread=4,
         num_class= 5,
         seed=27)
-    params['lambda'] = trial.suggest_float("lambda", 0, 10)
+    params['lambda'] = trial.suggest_float("lambda", 1, 10)
     dtrain = xgb.DMatrix(dataset.X_train,dataset.y_train,
                          enable_categorical=True,
-                         weight=dataset.weight_train)
+                         weight=dataset.weight_train)#np.power(dataset.weight_train,trial.suggest_float("power", 0.1, 2)))
     dvalid = xgb.DMatrix(dataset.X_valid,dataset.y_valid,
                          enable_categorical=True,
                          )
@@ -45,7 +45,7 @@ def objective(trial,dataset:Dataset,num_boost_round:int)->float:
     bst = xgb.train(params, dtrain,verbose_eval=False, num_boost_round=num_boost_round,
                     evals = [(dtrain, "train"), (dvalid, "valid")],
                     early_stopping_rounds=100)
+    logging.info(bst.best_iteration)
     preds = bst.predict(dvalid)
     ##MCC is more solid
     mcc = matthews_corrcoef(dataset.y_valid,preds.argmax(1))
@@ -119,7 +119,7 @@ def gain_accuracy_train(dataset:Dataset,feat_imp:pd.DataFrame,num_boost_round:in
     tot = pd.DataFrame(tot)
     FS = int(tot.loc[tot.acc.argmax()].FS) ## get best
-    print(f'Best model with {FS} features, retraining....')
+    logging.info(f'Best model with {FS} features, retraining....')
     dtrain_FS = xgb.DMatrix(dataset.X_train[list(feat_imp.head(FS).index)],dataset.y_train, enable_categorical=True, weight=dataset.weight_train)
     dvalid_FS = xgb.DMatrix(dataset.X_valid[list(feat_imp.head(FS).index)],dataset.y_valid,enable_categorical=True, )

View File

@@ -7,7 +7,10 @@ import pickle
 from dataclasses import dataclass
 from typing import Union
 import os
-from imblearn.under_sampling import RandomUnderSampler,RandomOverSampler
+from imblearn.under_sampling import RandomUnderSampler
+from imblearn.over_sampling import RandomOverSampler
+from sklearn.preprocessing import LabelEncoder
+import logging
 ##AUXILIARY CLASSES
 @dataclass
@@ -69,7 +72,8 @@ def prepare_new_data(dataset:pd.DataFrame,to_remove:dict)->(pd.DataFrame,pd.Data
-def retrive_data(reload_data:bool,threshold_under_represented:float,path:str)->(pd.DataFrame,pd.DataFrame):
+def retrive_data(reload_data:bool,threshold_under_represented:float,path:str)->(pd.DataFrame,pd.DataFrame,list):
     """Get data
     Args:
@@ -79,6 +83,7 @@ def retrive_data(reload_data:bool,threshold_under_represented:float,path:str)->(
     Returns:
         two pandas dataframe, one the original the second with condensed classes and a dictionarly of condesed classes
+        and a list of all evacuations and the encoders for the categorical features
     """
     if reload_data:
         engine = pg.connect("dbname='safeidx' user='fbk_mpba' host='172.104.247.67' port='5432' password='fbk2024$'")
@@ -88,8 +93,11 @@ def retrive_data(reload_data:bool,threshold_under_represented:float,path:str)->(
     else:
         with open(os.path.join(path,'data.pkl'),'rb') as f:
             df = pickle.load(f)
+    #import pdb
+    #pdb.set_trace()
+    df = df[df.year>2011]
     ## these columns can lead to overfit!
     df.drop(columns=['dateandtime','skiarea_id','day_of_year','minute_of_day','year'], inplace=True)
     ##evacuation_vehicles must be explicitated
@@ -112,7 +120,7 @@ def retrive_data(reload_data:bool,threshold_under_represented:float,path:str)->(
     ## maybe it is possible to obtain a more stable model removing such classes
     to_remove = {}
     for c in labeled.columns:
-        if c not in ['india','age','season','skiarea_name']:
+        if c not in ['india','age','season','skiarea_name','destination']:
             labeled[c] = labeled[c].astype('str')
             tmp = labeled.groupby(c)[c].count()
             tmp = 100*tmp/tmp.max()
@@ -125,13 +133,20 @@ def retrive_data(reload_data:bool,threshold_under_represented:float,path:str)->(
     ## keep the datasets
     labeled_small = labeled.copy()
+    encoders = {'small':{},'normal':{}}
     for c in to_remove.keys():
         for k in to_remove[c]:
             labeled_small.loc[labeled_small[c]==k,c] = 'other'
     for c in labeled_small.columns:
-        if c not in ['age','season','skiarea_name']:
-            labeled_small[c] = labeled_small[c].fillna('None').astype('category')
-            labeled[c] = labeled[c].fillna('None').astype('category')
+        if c not in ['age','season','skiarea_name','india']:
+            le = LabelEncoder()
+            labeled_small[c] = le.fit_transform(labeled_small[c].fillna('None'))
+            labeled_small[c] = labeled_small[c].astype('category')
+            encoders['small'][c] = le
+            le = LabelEncoder()
+            labeled[c] = le.fit_transform(labeled[c].fillna('None'))
+            labeled[c] = labeled[c].astype('category')
+            encoders['normal'][c] = le
     labeled.dropna(inplace=True)
     labeled_small.dropna(inplace=True)
@@ -139,29 +154,29 @@ def retrive_data(reload_data:bool,threshold_under_represented:float,path:str)->(
     labeled.india = labeled.india.apply(lambda x: x.replace('i','')).astype(int)
     labeled_small.india = labeled_small.india.apply(lambda x: x.replace('i','')).astype(int)
-    return labeled,labeled_small,to_remove
+    return labeled,labeled_small,to_remove,list(ev),encoders
 def split(labeled:pd.DataFrame,
-          SKI_AREA_TEST: str = 'Klausberg',
-          SEASON_TEST_SKIAREA:str = 'Kronplatz',
-          SEASON_TEST_YEAR:int = 2023,
+          skiarea_test: str = 'Klausberg',
+          season_test_skiarea:str = 'Kronplatz',
+          season_test_year:int = 2023,
           use_smote:bool = False,
           undersampling:bool=False,
          test_size:float=0.33,
          weight_type:str = 'sqrt' )->(Dataset, Dataset_test):
-    """Split the dataset into train,validation test. From the initial dataset we remove a single skiarea (SKI_AREA_TEST)
-    generating the first test set. Then we select a skieare and a starting season (SEASON_TEST_SKIAREA,SEASON_TEST_YEAR)
+    """Split the dataset into train,validation test. From the initial dataset we remove a single skiarea (skiarea_test)
+    generating the first test set. Then we select a skieare and a starting season (season_test_skiarea,season_test_year)
     and generate the seconda test set. The rest of the data are splitted 66-33 stratified on the target column (india).
     It is possible to specify the weight of eact sample. There are two strategies implemented: using the sum or the square root
     of the sum. This is used for mitigating the class umbalance. Another alternative is to use an oversampling procedure (use_smote)
     Args:
         labeled (pd.DataFrame): dataset
-        SKI_AREA_TEST (str, optional): skiarea to remove from the train and use in test. Defaults to 'Klausberg'.
-        SEASON_TEST_SKIAREA (str, optional): skiarea to remove from the dataset if the season is greater than SEASON_TEST_YEAR. Defaults to 'Kronplatz'.
-        SEASON_TEST_YEAR (int, optional): see SEASON_TEST_SKIAREA . Defaults to 2023.
+        skiarea_test (str, optional): skiarea to remove from the train and use in test. Defaults to 'Klausberg'.
+        season_test_skiarea (str, optional): skiarea to remove from the dataset if the season is greater than season_test_year. Defaults to 'Kronplatz'.
+        season_test_year (int, optional): see season_test_skiarea . Defaults to 2023.
         use_smote (bool, optional): use oversampling for class umbalance. Defaults to False.
         undersampling (bool, optional): use undersampling for class umbalance. Defaults to False.
        test_size (float, optional): percentage of dataset to use as validation. Defaults to 0.33.
@@ -174,15 +189,15 @@ def split(labeled:pd.DataFrame,
     labeled_tmp = labeled.copy()
     ##remove from dataset the corresponding test rows
-    if SKI_AREA_TEST is not None:
-        test_area = labeled[labeled.skiarea_name==SKI_AREA_TEST]
-        labeled_tmp = labeled_tmp[labeled_tmp.skiarea_name!=SKI_AREA_TEST]
+    if skiarea_test is not None:
+        test_area = labeled[labeled.skiarea_name==skiarea_test]
+        labeled_tmp = labeled_tmp[labeled_tmp.skiarea_name!=skiarea_test]
     else:
         test_area = None
-    if SEASON_TEST_SKIAREA is not None and SEASON_TEST_YEAR is not None:
-        test_area_season = labeled[(labeled.skiarea_name==SEASON_TEST_SKIAREA)&(labeled.season>=SEASON_TEST_YEAR)]
-        labeled_tmp = labeled_tmp[(labeled_tmp.skiarea_name!=SEASON_TEST_SKIAREA)|(labeled_tmp.season<SEASON_TEST_YEAR) ]
+    if season_test_skiarea is not None and season_test_year is not None:
+        test_area_season = labeled[(labeled.skiarea_name==season_test_skiarea)&(labeled.season>=season_test_year)]
+        labeled_tmp = labeled_tmp[(labeled_tmp.skiarea_name!=season_test_skiarea)|(labeled_tmp.season<season_test_year) ]
     else:
         test_area_season = None
@@ -210,12 +225,12 @@ def split(labeled:pd.DataFrame,
     ##when computing the error, these are the weights used for each class: you can punish more errpr on most severe clases
     if weight_type == 'sqrt':
         w.p = np.sqrt(w.p.sum())/w.p
-        print(w)
+        logging.info(w)
     elif weight_type == 'sum':
         w.p = w.p.sum()/w.p/w.shape[0]
-        print(w)
+        logging.info(w)
     else:
-        print(f'{weight_type=} not implemented please use a valid one: sqrt or sum, I will set all the weights to 0')
+        logging.info(f'{weight_type=} not implemented please use a valid one: sqrt or sum, I will set all the weights to 0')
         w.p = 1
     if use_smote is False and undersampling is False: