Files
pid/notebooks/old_notebooks/test_multi.ipynb

1555 lines
192 KiB
Plaintext
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"id": "7c5d059b-ed8a-4e2e-9420-25890f648895",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_46791/2472232159.py:1: DeprecationWarning: \n",
"Pyarrow will become a required dependency of pandas in the next major release of pandas (pandas 3.0),\n",
"(to allow more performant data types, such as the Arrow string type, and better interoperability with other libraries)\n",
"but was not found to be installed on your system.\n",
"If this would cause problems for you,\n",
"please provide us feedback at https://github.com/pandas-dev/pandas/issues/54466\n",
" \n",
" import pandas as pd\n",
"/tmp/ipykernel_46791/2472232159.py:7: UserWarning: pandas only supports SQLAlchemy connectable (engine/connection) or database string URI or sqlite3 DBAPI2 connection. Other DBAPI2 objects are not tested. Please consider using SQLAlchemy.\n",
" df = pd.read_sql('select * from data_safeidx', con=engine)\n"
]
}
],
"source": [
"import os\n",
"from getpass import getpass\n",
"\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"import pandas as pd\n",
"import psycopg2 as pg\n",
"from sklearn.model_selection import train_test_split\n",
"\n",
"# Connection settings are read from the environment so that no secret is stored in the notebook.\n",
"# When SAFEIDX_DB_PASSWORD is not set, the password is asked for interactively.\n",
"# NOTE(review): a password was previously committed in this cell and should be rotated.\n",
"engine = pg.connect(\n",
"    dbname=os.environ.get('SAFEIDX_DB_NAME', 'safeidx'),\n",
"    user=os.environ.get('SAFEIDX_DB_USER', 'fbk_mpba'),\n",
"    host=os.environ.get('SAFEIDX_DB_HOST', '172.104.247.67'),\n",
"    port=os.environ.get('SAFEIDX_DB_PORT', '5432'),\n",
"    password=os.environ.get('SAFEIDX_DB_PASSWORD') or getpass('safeidx DB password: '),\n",
")\n",
"# `engine` is a raw psycopg2 connection; pandas accepts it but warns that only SQLAlchemy is tested.\n",
"df = pd.read_sql('select * from data_safeidx', con=engine)"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "03aa2a04-93fa-469e-a678-685cacdebd6c",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>difficulty</th>\n",
" <th>cause</th>\n",
" <th>town</th>\n",
" <th>province</th>\n",
" <th>gender</th>\n",
" <th>equipment</th>\n",
" <th>helmet</th>\n",
" <th>destination</th>\n",
" <th>diagnosis</th>\n",
" <th>india</th>\n",
" <th>age</th>\n",
" <th>country</th>\n",
" <th>injury_side</th>\n",
" <th>injury_general_location</th>\n",
" <th>evacuation_vehicles</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>novice</td>\n",
" <td>fall_alone</td>\n",
" <td>SIKLOS</td>\n",
" <td></td>\n",
" <td>F</td>\n",
" <td>ski</td>\n",
" <td>None</td>\n",
" <td>hospital_emergency_room</td>\n",
" <td>distortion</td>\n",
" <td>None</td>\n",
" <td>32.0</td>\n",
" <td>Ungheria</td>\n",
" <td>L</td>\n",
" <td>lower_limbs</td>\n",
" <td>[akja]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>advanced</td>\n",
" <td>fall_alone</td>\n",
" <td>MALMO</td>\n",
" <td></td>\n",
" <td>M</td>\n",
" <td>ski</td>\n",
" <td>None</td>\n",
" <td>hospital_emergency_room</td>\n",
" <td>bruise</td>\n",
" <td>None</td>\n",
" <td>32.0</td>\n",
" <td>Svezia</td>\n",
" <td>R</td>\n",
" <td>skull_or_face</td>\n",
" <td>[akja]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>advanced</td>\n",
" <td>fall_alone</td>\n",
" <td>CALDARO</td>\n",
" <td>BZ</td>\n",
" <td>F</td>\n",
" <td>ski</td>\n",
" <td>None</td>\n",
" <td>domicile</td>\n",
" <td>other</td>\n",
" <td>None</td>\n",
" <td>12.0</td>\n",
" <td>Italia</td>\n",
" <td>R</td>\n",
" <td>None</td>\n",
" <td>[snowmobile]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>advanced</td>\n",
" <td>collision_person</td>\n",
" <td>LINZ</td>\n",
" <td></td>\n",
" <td>M</td>\n",
" <td>ski</td>\n",
" <td>None</td>\n",
" <td>hospital_emergency_room</td>\n",
" <td>bruise</td>\n",
" <td>None</td>\n",
" <td>58.0</td>\n",
" <td>Austria</td>\n",
" <td>R</td>\n",
" <td>lower_limbs</td>\n",
" <td>[snowmobile]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>advanced</td>\n",
" <td>collision_person</td>\n",
" <td>RUSAVA</td>\n",
" <td></td>\n",
" <td>M</td>\n",
" <td>ski</td>\n",
" <td>None</td>\n",
" <td>other</td>\n",
" <td>bruise</td>\n",
" <td>None</td>\n",
" <td>25.0</td>\n",
" <td>Repubblica Ceca</td>\n",
" <td>L</td>\n",
" <td>lower_limbs</td>\n",
" <td>[other]</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" difficulty cause town province gender equipment helmet \\\n",
"0 novice fall_alone SIKLOS F ski None \n",
"1 advanced fall_alone MALMO M ski None \n",
"2 advanced fall_alone CALDARO BZ F ski None \n",
"3 advanced collision_person LINZ M ski None \n",
"4 advanced collision_person RUSAVA M ski None \n",
"\n",
" destination diagnosis india age country \\\n",
"0 hospital_emergency_room distortion None 32.0 Ungheria \n",
"1 hospital_emergency_room bruise None 32.0 Svezia \n",
"2 domicile other None 12.0 Italia \n",
"3 hospital_emergency_room bruise None 58.0 Austria \n",
"4 other bruise None 25.0 Repubblica Ceca \n",
"\n",
" injury_side injury_general_location evacuation_vehicles \n",
"0 L lower_limbs [akja] \n",
"1 R skull_or_face [akja] \n",
"2 R None [snowmobile] \n",
"3 R lower_limbs [snowmobile] \n",
"4 L lower_limbs [other] "
]
},
"execution_count": 2,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "babc2e8b-1030-4e8a-aa41-6d2a788959a5",
"metadata": {},
"outputs": [],
"source": [
"# Collect every distinct vehicle label found in any row's evacuation_vehicles list.\n",
"ev = {vehicle for vehicles in df.evacuation_vehicles for vehicle in vehicles}"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "c8d6cc1c-f4f5-44ec-8652-b135963452ab",
"metadata": {},
"outputs": [],
"source": [
"# One boolean indicator column per vehicle label: True where the row's list contains it.\n",
"# A single vectorised pass per label replaces the per-cell df.loc writes inside iterrows().\n",
"for c in ev:\n",
"    df[c] = df.evacuation_vehicles.map(lambda vehicles, vehicle=c: vehicle in vehicles).astype(bool)"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "758c3317-1b02-4aed-b94d-7b6998d23797",
"metadata": {},
"outputs": [],
"source": [
"# town and province are dropped; evacuation_vehicles is dropped because it is now one-hot encoded.\n",
"df = df.drop(columns=['town', 'province', 'evacuation_vehicles'])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "adadc0dc-9d6e-4277-8956-d1d4b2492e7e",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 36,
"id": "33617e77-7c2b-41a3-96c0-8930aa5ac869",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([1.3808e+04, 0.0000e+00, 8.8100e+02, 0.0000e+00, 0.0000e+00,\n",
" 3.3690e+03, 0.0000e+00, 1.5200e+02, 0.0000e+00, 1.1000e+01]),\n",
" array([0. , 0.4, 0.8, 1.2, 1.6, 2. , 2.4, 2.8, 3.2, 3.6, 4. ]),\n",
" <BarContainer object of 10 artists>)"
]
},
"execution_count": 36,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"# Keep only the rows that have a target label. .copy() makes `labeled` an independent frame,\n",
"# so later column assignments do not write into a slice of df (SettingWithCopyWarning).\n",
"labeled = df[df.india.notna()].copy()\n",
"plt.hist(labeled.india)"
]
},
{
"cell_type": "code",
"execution_count": 37,
"id": "e8c139d9-bf61-45ec-9da1-7eaf4ff754b4",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_46791/382759161.py:1: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" labeled['age'] = labeled['age'].astype(np.float32).fillna(np.nan)\n",
"/tmp/ipykernel_46791/382759161.py:4: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" labeled[c] = labeled[c].fillna('None').astype('category')\n",
"/tmp/ipykernel_46791/382759161.py:4: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" labeled[c] = labeled[c].fillna('None').astype('category')\n",
"/tmp/ipykernel_46791/382759161.py:4: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" labeled[c] = labeled[c].fillna('None').astype('category')\n",
"/tmp/ipykernel_46791/382759161.py:4: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" labeled[c] = labeled[c].fillna('None').astype('category')\n",
"/tmp/ipykernel_46791/382759161.py:4: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" labeled[c] = labeled[c].fillna('None').astype('category')\n",
"/tmp/ipykernel_46791/382759161.py:4: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" labeled[c] = labeled[c].fillna('None').astype('category')\n",
"/tmp/ipykernel_46791/382759161.py:4: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" labeled[c] = labeled[c].fillna('None').astype('category')\n",
"/tmp/ipykernel_46791/382759161.py:4: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" labeled[c] = labeled[c].fillna('None').astype('category')\n",
"/tmp/ipykernel_46791/382759161.py:4: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" labeled[c] = labeled[c].fillna('None').astype('category')\n",
"/tmp/ipykernel_46791/382759161.py:4: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" labeled[c] = labeled[c].fillna('None').astype('category')\n",
"/tmp/ipykernel_46791/382759161.py:4: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" labeled[c] = labeled[c].fillna('None').astype('category')\n",
"/tmp/ipykernel_46791/382759161.py:5: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" labeled.dropna(inplace=True)\n"
]
}
],
"source": [
"# Work on an independent frame so the assignments below never target a slice of df.\n",
"labeled = labeled.copy()\n",
"# age is the only numeric feature; missing values become NaN through the float cast.\n",
"labeled['age'] = labeled['age'].astype(np.float32)\n",
"# Every other column becomes categorical, with missing values as the explicit level 'None'.\n",
"for c in labeled.columns:\n",
"    if c != 'age':\n",
"        labeled[c] = labeled[c].fillna('None').astype('category')\n",
"# After the fill above only age can still be missing; drop those rows.\n",
"labeled = labeled.dropna()"
]
},
{
"cell_type": "code",
"execution_count": 38,
"id": "fabf354e-f39e-4cde-af84-c65a277d309a",
"metadata": {},
"outputs": [],
"source": [
"# Hold out 33% of the labelled rows as the test set, stratified on the target column.\n",
"features = labeled.drop(columns=['india'])\n",
"target = labeled.india\n",
"X_train, X_test, y_train, y_test = train_test_split(\n",
"    features, target, test_size=0.33, random_state=0, stratify=target\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 39,
"id": "cf6cf5d8-d43e-499e-98a5-65ecd0b8ccda",
"metadata": {},
"outputs": [],
"source": [
"# Carve a validation set (33%) out of the training rows, stratified on the target.\n",
"# X_train / y_train are reassigned from themselves, so re-running this cell alone shrinks them again.\n",
"X_train, X_valid, y_train, y_valid = train_test_split(\n",
"    X_train, y_train, test_size=0.33, random_state=0, stratify=y_train\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 177,
"id": "71d8d93f-cb4f-402f-b4ce-b67e7b964c1b",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" class p\n",
"0 i0 0.238038\n",
"1 i1 0.014605\n",
"2 i2 0.059746\n",
"3 i3 1.326712\n",
"4 i4 18.043281\n"
]
}
],
"source": [
"# Per-class weight: sqrt(total number of training rows) / number of rows in the class.\n",
"classes, counts = np.unique(y_train, return_counts=True)\n",
"w = pd.DataFrame({'class': classes, 'p': np.sqrt(counts.sum()) / counts})\n",
"print(w)\n",
"# Look the weight up per row with map(): the result is aligned with y_train by construction\n",
"# and has float dtype, with no reliance on the row order returned by a merge.\n",
"weight_train = y_train.astype(object).map(w.set_index('class').p).to_numpy(dtype=float)"
]
},
{
"cell_type": "code",
"execution_count": 184,
"id": "d1453965-f927-4edc-ad3b-421997d62268",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" class p\n",
"0 i0 4.634106\n",
"1 i1 1.147881\n",
"2 i2 2.321652\n",
"3 i3 10.940346\n",
"4 i4 40.346004\n"
]
}
],
"source": [
"# Per-class weight: sqrt(total number of training rows / number of rows in the class).\n",
"classes, counts = np.unique(y_train, return_counts=True)\n",
"w = pd.DataFrame({'class': classes, 'p': np.sqrt(counts.sum() / counts)})\n",
"print(w)\n",
"# Look the weight up per row with map(): the result is aligned with y_train by construction\n",
"# and has float dtype, with no reliance on the row order returned by a merge.\n",
"weight_train = y_train.astype(object).map(w.set_index('class').p).to_numpy(dtype=float)"
]
},
{
"cell_type": "code",
"execution_count": 172,
"id": "774ad570-a60b-475b-80cf-4f9b9949cc9d",
"metadata": {},
"outputs": [],
"source": [
"#weight_train = (1-y_train.values.astype(int))*5+1 ## weight for unbalanced classes"
]
},
{
"cell_type": "code",
"execution_count": 178,
"id": "1fcc5234-abad-459a-9420-810833657796",
"metadata": {},
"outputs": [],
"source": [
"from catboost import CatBoostClassifier, Pool\n",
"\n",
"# Every feature except the numeric age column is handed to CatBoost as categorical.\n",
"cat_cols = [c for c in X_train.columns if c != 'age']\n",
"\n",
"# Only the training pool carries the per-row class weights.\n",
"train_data = Pool(data=X_train, label=y_train, weight=weight_train, cat_features=cat_cols)\n",
"valid_data = Pool(data=X_valid, label=y_valid, cat_features=cat_cols)\n",
"\n",
"model = CatBoostClassifier(iterations=1000)\n",
"model.fit(train_data, eval_set=valid_data, verbose=False, early_stopping_rounds=100)\n",
"preds_class = model.predict(valid_data)"
]
},
{
"cell_type": "code",
"execution_count": 179,
"id": "8df84007-110b-4f31-bc7d-53e7d5c9a178",
"metadata": {},
"outputs": [],
"source": [
"# Class predictions of the fitted model for the validation and training pools.\n",
"preds_class_valid, preds_class_train = model.predict(valid_data), model.predict(train_data)"
]
},
{
"cell_type": "code",
"execution_count": 180,
"id": "55627ec4-fd24-4815-98d4-d8462bbfdd9a",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import confusion_matrix,matthews_corrcoef,accuracy_score"
]
},
{
"cell_type": "code",
"execution_count": 181,
"id": "7451f713-fc81-4688-8440-996ffb280572",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 145 40 2 0 0]\n",
" [ 98 2388 555 2 0]\n",
" [ 1 211 512 20 0]\n",
" [ 0 3 21 10 0]\n",
" [ 0 1 0 0 1]]\n",
"0.4797388666004143\n",
"0.7620947630922693\n",
"########################################\n",
"[[ 349 27 3 0 0]\n",
" [ 194 5033 946 4 0]\n",
" [ 6 274 1213 17 0]\n",
" [ 0 3 12 53 0]\n",
" [ 1 0 0 0 4]]\n",
"0.6100405041372103\n",
"0.8172994225334808\n"
]
}
],
"source": [
"# Confusion matrix, Matthews correlation and accuracy: validation split first, then training.\n",
"for split_no, (y_true, y_pred) in enumerate([(y_valid, preds_class_valid), (y_train, preds_class_train)]):\n",
"    if split_no:\n",
"        print('########################################')\n",
"    print(confusion_matrix(y_true, y_pred))\n",
"    print(matthews_corrcoef(y_true, y_pred))\n",
"    print(accuracy_score(y_true, y_pred))"
]
},
{
"cell_type": "code",
"execution_count": 144,
"id": "877e914d-44d4-4299-8d3c-24c4fc353317",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 145 40 2 0 0]\n",
" [ 98 2388 555 2 0]\n",
" [ 1 211 512 20 0]\n",
" [ 0 3 21 10 0]\n",
" [ 0 1 0 0 1]]\n",
"0.4797388666004143\n",
"0.7620947630922693\n",
"########################################\n",
"[[ 349 27 3 0 0]\n",
" [ 194 5033 946 4 0]\n",
" [ 6 274 1213 17 0]\n",
" [ 0 3 12 53 0]\n",
" [ 1 0 0 0 4]]\n",
"0.6100405041372103\n",
"0.8172994225334808\n"
]
}
],
"source": [
"# Confusion matrix, Matthews correlation and accuracy: validation split first, then training.\n",
"for split_no, (y_true, y_pred) in enumerate([(y_valid, preds_class_valid), (y_train, preds_class_train)]):\n",
"    if split_no:\n",
"        print('########################################')\n",
"    print(confusion_matrix(y_true, y_pred))\n",
"    print(matthews_corrcoef(y_true, y_pred))\n",
"    print(accuracy_score(y_true, y_pred))"
]
},
{
"cell_type": "code",
"execution_count": 139,
"id": "776395d9-a8e2-4fda-90e1-11d6dbe80de8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 129 57 1 0 0]\n",
" [ 13 2912 118 0 0]\n",
" [ 0 456 286 2 0]\n",
" [ 0 4 30 0 0]\n",
" [ 0 2 0 0 0]]\n",
"0.5046851486832885\n",
"0.8296758104738154\n",
"########################################\n",
"[[ 249 128 2 0 0]\n",
" [ 39 5946 192 0 0]\n",
" [ 1 908 601 0 0]\n",
" [ 1 10 53 4 0]\n",
" [ 1 0 1 0 3]]\n",
"0.5234784446629579\n",
"0.8358520702789041\n"
]
}
],
"source": [
"# Confusion matrix, Matthews correlation and accuracy: validation split first, then training.\n",
"for split_no, (y_true, y_pred) in enumerate([(y_valid, preds_class_valid), (y_train, preds_class_train)]):\n",
"    if split_no:\n",
"        print('########################################')\n",
"    print(confusion_matrix(y_true, y_pred))\n",
"    print(matthews_corrcoef(y_true, y_pred))\n",
"    print(accuracy_score(y_true, y_pred))"
]
},
{
"cell_type": "code",
"execution_count": 201,
"id": "34c397bc-529a-4c52-b30e-957b28021200",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 136 49 2 0 0]\n",
" [ 40 2665 337 1 0]\n",
" [ 1 300 419 24 0]\n",
" [ 0 3 21 10 0]\n",
" [ 0 1 0 0 1]]\n",
"0.5028705314408995\n",
"0.8057356608478803\n",
"########################################\n",
"[[ 367 12 0 0 0]\n",
" [ 54 5682 439 2 0]\n",
" [ 1 344 1164 1 0]\n",
" [ 0 0 0 68 0]\n",
" [ 0 0 0 0 5]]\n",
"0.7373051729464436\n",
"0.8951959700208871\n"
]
}
],
"source": [
"## try with xgboost\n",
"import xgboost as xgb\n",
"\n",
"# The integer class id is obtained by stripping the 'i' from each label; done once and reused below.\n",
"y_train_int = y_train.apply(lambda x: x.replace('i', '')).astype(int)\n",
"y_valid_int = y_valid.apply(lambda x: x.replace('i', '')).astype(int)\n",
"feature_names = list(X_train.columns.values)\n",
"\n",
"# Classification matrices; only the training matrix carries the per-row class weights.\n",
"dtrain = xgb.DMatrix(X_train, y_train_int, enable_categorical=True, feature_names=feature_names, weight=weight_train)\n",
"dvalid = xgb.DMatrix(X_valid, y_valid_int, enable_categorical=True, feature_names=feature_names)\n",
"\n",
"params = {'objective': 'multi:softprob', 'num_class': 5, 'eta': 0.05, 'min_child_weight': 10, 'lambda': 1.5}\n",
"n = 1000\n",
"\n",
"# NOTE(review): per the XGBoost docs early stopping follows the last entry of evals (here 'valid') — confirm.\n",
"results = xgb.train(\n",
"    params,\n",
"    dtrain,\n",
"    num_boost_round=n,\n",
"    evals=[(dtrain, 'train'), (dvalid, 'valid')],\n",
"    verbose_eval=False,\n",
"    early_stopping_rounds=100,\n",
")\n",
"\n",
"# predict() output is reduced with argmax(1) to obtain one predicted class id per row.\n",
"preds_class_valid = results.predict(dvalid)\n",
"preds_class_train = results.predict(dtrain)\n",
"for split_no, (y_true, y_prob) in enumerate([(y_valid_int, preds_class_valid), (y_train_int, preds_class_train)]):\n",
"    if split_no:\n",
"        print('########################################')\n",
"    y_pred = y_prob.argmax(1)\n",
"    print(confusion_matrix(y_true, y_pred))\n",
"    print(matthews_corrcoef(y_true, y_pred))\n",
"    print(accuracy_score(y_true, y_pred))"
]
},
{
"cell_type": "code",
"execution_count": 207,
"id": "04747338-22aa-498e-aea0-449b196d8f9f",
"metadata": {},
"outputs": [],
"source": [
"# 5-fold cross-validation of the current params on the training matrix, scored with mlogloss.\n",
"cvresult = xgb.cv(\n",
"    params,\n",
"    dtrain,\n",
"    num_boost_round=1000,\n",
"    nfold=5,\n",
"    metrics='mlogloss',\n",
"    early_stopping_rounds=100,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 208,
"id": "0040e5e5-9e45-4c03-b3e4-bd88f10f6838",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>train-mlogloss-mean</th>\n",
" <th>train-mlogloss-std</th>\n",
" <th>test-mlogloss-mean</th>\n",
" <th>test-mlogloss-std</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>1.538693</td>\n",
" <td>0.000438</td>\n",
" <td>1.544775</td>\n",
" <td>0.002529</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>1.475059</td>\n",
" <td>0.000783</td>\n",
" <td>1.486535</td>\n",
" <td>0.004961</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>1.417342</td>\n",
" <td>0.001149</td>\n",
" <td>1.433997</td>\n",
" <td>0.007307</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>1.364478</td>\n",
" <td>0.001440</td>\n",
" <td>1.386108</td>\n",
" <td>0.009380</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>1.315964</td>\n",
" <td>0.001718</td>\n",
" <td>1.342237</td>\n",
" <td>0.011211</td>\n",
" </tr>\n",
" <tr>\n",
" <th>...</th>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" <td>...</td>\n",
" </tr>\n",
" <tr>\n",
" <th>95</th>\n",
" <td>0.438063</td>\n",
" <td>0.004375</td>\n",
" <td>0.685584</td>\n",
" <td>0.082405</td>\n",
" </tr>\n",
" <tr>\n",
" <th>96</th>\n",
" <td>0.436326</td>\n",
" <td>0.004270</td>\n",
" <td>0.685423</td>\n",
" <td>0.082891</td>\n",
" </tr>\n",
" <tr>\n",
" <th>97</th>\n",
" <td>0.434673</td>\n",
" <td>0.004317</td>\n",
" <td>0.685321</td>\n",
" <td>0.083182</td>\n",
" </tr>\n",
" <tr>\n",
" <th>98</th>\n",
" <td>0.433180</td>\n",
" <td>0.004361</td>\n",
" <td>0.685375</td>\n",
" <td>0.083748</td>\n",
" </tr>\n",
" <tr>\n",
" <th>99</th>\n",
" <td>0.431569</td>\n",
" <td>0.004265</td>\n",
" <td>0.685245</td>\n",
" <td>0.083837</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"<p>100 rows × 4 columns</p>\n",
"</div>"
],
"text/plain": [
" train-mlogloss-mean train-mlogloss-std test-mlogloss-mean \\\n",
"0 1.538693 0.000438 1.544775 \n",
"1 1.475059 0.000783 1.486535 \n",
"2 1.417342 0.001149 1.433997 \n",
"3 1.364478 0.001440 1.386108 \n",
"4 1.315964 0.001718 1.342237 \n",
".. ... ... ... \n",
"95 0.438063 0.004375 0.685584 \n",
"96 0.436326 0.004270 0.685423 \n",
"97 0.434673 0.004317 0.685321 \n",
"98 0.433180 0.004361 0.685375 \n",
"99 0.431569 0.004265 0.685245 \n",
"\n",
" test-mlogloss-std \n",
"0 0.002529 \n",
"1 0.004961 \n",
"2 0.007307 \n",
"3 0.009380 \n",
"4 0.011211 \n",
".. ... \n",
"95 0.082405 \n",
"96 0.082891 \n",
"97 0.083182 \n",
"98 0.083748 \n",
"99 0.083837 \n",
"\n",
"[100 rows x 4 columns]"
]
},
"execution_count": 208,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"cvresult"
]
},
{
"cell_type": "code",
"execution_count": 115,
"id": "1ec2f8d4-e56a-4978-b3e1-696ab3bec1df",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(4010, 24, 5)"
]
},
"execution_count": 115,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"import shap\n",
"\n",
"# Margin scores (output_margin=True) for the validation matrix.\n",
"pred = results.predict(dvalid, output_margin=True)\n",
"\n",
"# Explain the trained booster on the validation matrix.\n",
"explainer = shap.TreeExplainer(\n",
"    results,\n",
"    feature_names=list(X_train.columns.values),\n",
")\n",
"explanation = explainer(dvalid)\n",
"\n",
"# The recorded shape is (4010, 24, 5); presumably (rows, features, classes) — TODO confirm.\n",
"shap_values = explanation.values\n",
"shap_values.shape"
]
},
{
"cell_type": "code",
"execution_count": 119,
"id": "e31317a7-3257-405f-9762-bf8d77699176",
"metadata": {},
"outputs": [
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 800x550 with 2 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"shap.plots.beeswarm(explanation[:,:,2])\n"
]
},
{
"cell_type": "code",
"execution_count": 132,
"id": "8ccca6dc-5bca-45c4-8ea1-b21a7ae8b433",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"array(['diagnosis', 'destination', 'age', 'helicopter',\n",
" 'injury_general_location', 'difficulty', 'cause', 'country',\n",
" 'injury_side', 'ski_lift'], dtype=object)"
]
},
"execution_count": 132,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Mean absolute SHAP value per feature: average over axis 0 first, then over the last axis.\n",
"vals = np.abs(shap_values).mean(axis=0).mean(axis=1)\n",
"feature_importance = pd.DataFrame(\n",
"    list(zip(X_train.columns, vals)),\n",
"    columns=['col_name', 'feature_importance_vals'],\n",
").sort_values(by=['feature_importance_vals'], ascending=False)\n",
"# Names of the ten highest-ranked features.\n",
"feature_importance.col_name.head(10).values"
]
},
{
"cell_type": "code",
"execution_count": 135,
"id": "d3fa6df7-1056-401f-bea8-aa724ff46613",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 118 67 2 0 0]\n",
" [ 30 2794 217 2 0]\n",
" [ 0 435 300 9 0]\n",
" [ 0 5 24 5 0]\n",
" [ 0 2 0 0 0]]\n",
"0.43952307639594157\n",
"0.8022443890274314\n",
"########################################\n",
"[[ 371 8 0 0 0]\n",
" [ 1 6091 85 0 0]\n",
" [ 0 353 1156 1 0]\n",
" [ 0 0 0 68 0]\n",
" [ 0 0 0 0 5]]\n",
"0.8543147994467417\n",
"0.9449563828480158\n"
]
}
],
"source": [
"## try with xgboost, restricted to the ten highest-ranked features\n",
"import xgboost as xgb\n",
"\n",
"top_features = feature_importance.col_name[0:10].values\n",
"X_train_small = X_train[top_features]\n",
"X_valid_small = X_valid[top_features]\n",
"\n",
"# The integer class id is obtained by stripping the 'i' from each label; done once and reused below.\n",
"y_train_int = y_train.apply(lambda x: x.replace('i', '')).astype(int)\n",
"y_valid_int = y_valid.apply(lambda x: x.replace('i', '')).astype(int)\n",
"small_feature_names = list(X_train_small.columns.values)\n",
"\n",
"# Classification matrices; no per-row class weights are passed in this run.\n",
"dtrain = xgb.DMatrix(X_train_small, y_train_int, enable_categorical=True, feature_names=small_feature_names)\n",
"dvalid = xgb.DMatrix(X_valid_small, y_valid_int, enable_categorical=True, feature_names=small_feature_names)\n",
"\n",
"params = {'objective': 'multi:softprob', 'num_class': 5}\n",
"n = 1000\n",
"\n",
"# NOTE(review): per the XGBoost docs early stopping follows the last entry of evals (here 'valid') — confirm.\n",
"results = xgb.train(\n",
"    params,\n",
"    dtrain,\n",
"    num_boost_round=n,\n",
"    evals=[(dtrain, 'train'), (dvalid, 'valid')],\n",
"    verbose_eval=False,\n",
"    early_stopping_rounds=100,\n",
")\n",
"\n",
"# predict() output is reduced with argmax(1) to obtain one predicted class id per row.\n",
"preds_class_valid = results.predict(dvalid)\n",
"preds_class_train = results.predict(dtrain)\n",
"for split_no, (y_true, y_prob) in enumerate([(y_valid_int, preds_class_valid), (y_train_int, preds_class_train)]):\n",
"    if split_no:\n",
"        print('########################################')\n",
"    y_pred = y_prob.argmax(1)\n",
"    print(confusion_matrix(y_true, y_pred))\n",
"    print(matthews_corrcoef(y_true, y_pred))\n",
"    print(accuracy_score(y_true, y_pred))"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "5c79de7d-8336-4d07-9571-79be5f92c381",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 54,
"id": "5120687b-02ed-42ee-80ca-543fe6aa540d",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/agobbi/miniconda3/envs/pid/lib/python3.11/site-packages/xgboost/sklearn.py:889: UserWarning: `early_stopping_rounds` in `fit` method is deprecated for better compatibility with scikit-learn, use `early_stopping_rounds` in constructor or`set_params` instead.\n",
" warnings.warn(\n"
]
},
{
"data": {
"text/html": [
"<style>#sk-container-id-1 {\n",
" /* Definition of color scheme common for light and dark mode */\n",
" --sklearn-color-text: black;\n",
" --sklearn-color-line: gray;\n",
" /* Definition of color scheme for unfitted estimators */\n",
" --sklearn-color-unfitted-level-0: #fff5e6;\n",
" --sklearn-color-unfitted-level-1: #f6e4d2;\n",
" --sklearn-color-unfitted-level-2: #ffe0b3;\n",
" --sklearn-color-unfitted-level-3: chocolate;\n",
" /* Definition of color scheme for fitted estimators */\n",
" --sklearn-color-fitted-level-0: #f0f8ff;\n",
" --sklearn-color-fitted-level-1: #d4ebff;\n",
" --sklearn-color-fitted-level-2: #b3dbfd;\n",
" --sklearn-color-fitted-level-3: cornflowerblue;\n",
"\n",
" /* Specific color for light theme */\n",
" --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
" --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n",
" --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
" --sklearn-color-icon: #696969;\n",
"\n",
" @media (prefers-color-scheme: dark) {\n",
" /* Redefinition of color scheme for dark theme */\n",
" --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
" --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n",
" --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
" --sklearn-color-icon: #878787;\n",
" }\n",
"}\n",
"\n",
"#sk-container-id-1 {\n",
" color: var(--sklearn-color-text);\n",
"}\n",
"\n",
"#sk-container-id-1 pre {\n",
" padding: 0;\n",
"}\n",
"\n",
"#sk-container-id-1 input.sk-hidden--visually {\n",
" border: 0;\n",
" clip: rect(1px 1px 1px 1px);\n",
" clip: rect(1px, 1px, 1px, 1px);\n",
" height: 1px;\n",
" margin: -1px;\n",
" overflow: hidden;\n",
" padding: 0;\n",
" position: absolute;\n",
" width: 1px;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-dashed-wrapped {\n",
" border: 1px dashed var(--sklearn-color-line);\n",
" margin: 0 0.4em 0.5em 0.4em;\n",
" box-sizing: border-box;\n",
" padding-bottom: 0.4em;\n",
" background-color: var(--sklearn-color-background);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-container {\n",
" /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n",
" but bootstrap.min.css set `[hidden] { display: none !important; }`\n",
" so we also need the `!important` here to be able to override the\n",
" default hidden behavior on the sphinx rendered scikit-learn.org.\n",
" See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n",
" display: inline-block !important;\n",
" position: relative;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-text-repr-fallback {\n",
" display: none;\n",
"}\n",
"\n",
"div.sk-parallel-item,\n",
"div.sk-serial,\n",
"div.sk-item {\n",
" /* draw centered vertical line to link estimators */\n",
" background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n",
" background-size: 2px 100%;\n",
" background-repeat: no-repeat;\n",
" background-position: center center;\n",
"}\n",
"\n",
"/* Parallel-specific style estimator block */\n",
"\n",
"#sk-container-id-1 div.sk-parallel-item::after {\n",
" content: \"\";\n",
" width: 100%;\n",
" border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n",
" flex-grow: 1;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-parallel {\n",
" display: flex;\n",
" align-items: stretch;\n",
" justify-content: center;\n",
" background-color: var(--sklearn-color-background);\n",
" position: relative;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-parallel-item {\n",
" display: flex;\n",
" flex-direction: column;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-parallel-item:first-child::after {\n",
" align-self: flex-end;\n",
" width: 50%;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-parallel-item:last-child::after {\n",
" align-self: flex-start;\n",
" width: 50%;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-parallel-item:only-child::after {\n",
" width: 0;\n",
"}\n",
"\n",
"/* Serial-specific style estimator block */\n",
"\n",
"#sk-container-id-1 div.sk-serial {\n",
" display: flex;\n",
" flex-direction: column;\n",
" align-items: center;\n",
" background-color: var(--sklearn-color-background);\n",
" padding-right: 1em;\n",
" padding-left: 1em;\n",
"}\n",
"\n",
"\n",
"/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n",
"clickable and can be expanded/collapsed.\n",
"- Pipeline and ColumnTransformer use this feature and define the default style\n",
"- Estimators will overwrite some part of the style using the `sk-estimator` class\n",
"*/\n",
"\n",
"/* Pipeline and ColumnTransformer style (default) */\n",
"\n",
"#sk-container-id-1 div.sk-toggleable {\n",
" /* Default theme specific background. It is overwritten whether we have a\n",
" specific estimator or a Pipeline/ColumnTransformer */\n",
" background-color: var(--sklearn-color-background);\n",
"}\n",
"\n",
"/* Toggleable label */\n",
"#sk-container-id-1 label.sk-toggleable__label {\n",
" cursor: pointer;\n",
" display: block;\n",
" width: 100%;\n",
" margin-bottom: 0;\n",
" padding: 0.5em;\n",
" box-sizing: border-box;\n",
" text-align: center;\n",
"}\n",
"\n",
"#sk-container-id-1 label.sk-toggleable__label-arrow:before {\n",
" /* Arrow on the left of the label */\n",
" content: \"▸\";\n",
" float: left;\n",
" margin-right: 0.25em;\n",
" color: var(--sklearn-color-icon);\n",
"}\n",
"\n",
"#sk-container-id-1 label.sk-toggleable__label-arrow:hover:before {\n",
" color: var(--sklearn-color-text);\n",
"}\n",
"\n",
"/* Toggleable content - dropdown */\n",
"\n",
"#sk-container-id-1 div.sk-toggleable__content {\n",
" max-height: 0;\n",
" max-width: 0;\n",
" overflow: hidden;\n",
" text-align: left;\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-toggleable__content.fitted {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-toggleable__content pre {\n",
" margin: 0.2em;\n",
" border-radius: 0.25em;\n",
" color: var(--sklearn-color-text);\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-toggleable__content.fitted pre {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-fitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-1 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n",
" /* Expand drop-down */\n",
" max-height: 200px;\n",
" max-width: 100%;\n",
" overflow: auto;\n",
"}\n",
"\n",
"#sk-container-id-1 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n",
" content: \"▾\";\n",
"}\n",
"\n",
"/* Pipeline/ColumnTransformer-specific style */\n",
"\n",
"#sk-container-id-1 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" color: var(--sklearn-color-text);\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"/* Estimator-specific style */\n",
"\n",
"/* Colorize estimator box */\n",
"#sk-container-id-1 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-label label.sk-toggleable__label,\n",
"#sk-container-id-1 div.sk-label label {\n",
" /* The background is the default theme color */\n",
" color: var(--sklearn-color-text-on-default-background);\n",
"}\n",
"\n",
"/* On hover, darken the color of the background */\n",
"#sk-container-id-1 div.sk-label:hover label.sk-toggleable__label {\n",
" color: var(--sklearn-color-text);\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"/* Label box, darken color on hover, fitted */\n",
"#sk-container-id-1 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n",
" color: var(--sklearn-color-text);\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"/* Estimator label */\n",
"\n",
"#sk-container-id-1 div.sk-label label {\n",
" font-family: monospace;\n",
" font-weight: bold;\n",
" display: inline-block;\n",
" line-height: 1.2em;\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-label-container {\n",
" text-align: center;\n",
"}\n",
"\n",
"/* Estimator-specific */\n",
"#sk-container-id-1 div.sk-estimator {\n",
" font-family: monospace;\n",
" border: 1px dotted var(--sklearn-color-border-box);\n",
" border-radius: 0.25em;\n",
" box-sizing: border-box;\n",
" margin-bottom: 0.5em;\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-estimator.fitted {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-0);\n",
"}\n",
"\n",
"/* on hover */\n",
"#sk-container-id-1 div.sk-estimator:hover {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-1 div.sk-estimator.fitted:hover {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"/* Specification for estimator info (e.g. \"i\" and \"?\") */\n",
"\n",
"/* Common style for \"i\" and \"?\" */\n",
"\n",
".sk-estimator-doc-link,\n",
"a:link.sk-estimator-doc-link,\n",
"a:visited.sk-estimator-doc-link {\n",
" float: right;\n",
" font-size: smaller;\n",
" line-height: 1em;\n",
" font-family: monospace;\n",
" background-color: var(--sklearn-color-background);\n",
" border-radius: 1em;\n",
" height: 1em;\n",
" width: 1em;\n",
" text-decoration: none !important;\n",
" margin-left: 1ex;\n",
" /* unfitted */\n",
" border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
" color: var(--sklearn-color-unfitted-level-1);\n",
"}\n",
"\n",
".sk-estimator-doc-link.fitted,\n",
"a:link.sk-estimator-doc-link.fitted,\n",
"a:visited.sk-estimator-doc-link.fitted {\n",
" /* fitted */\n",
" border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
" color: var(--sklearn-color-fitted-level-1);\n",
"}\n",
"\n",
"/* On hover */\n",
"div.sk-estimator:hover .sk-estimator-doc-link:hover,\n",
".sk-estimator-doc-link:hover,\n",
"div.sk-label-container:hover .sk-estimator-doc-link:hover,\n",
".sk-estimator-doc-link:hover {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-3);\n",
" color: var(--sklearn-color-background);\n",
" text-decoration: none;\n",
"}\n",
"\n",
"div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n",
".sk-estimator-doc-link.fitted:hover,\n",
"div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n",
".sk-estimator-doc-link.fitted:hover {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-3);\n",
" color: var(--sklearn-color-background);\n",
" text-decoration: none;\n",
"}\n",
"\n",
"/* Span, style for the box shown on hovering the info icon */\n",
".sk-estimator-doc-link span {\n",
" display: none;\n",
" z-index: 9999;\n",
" position: relative;\n",
" font-weight: normal;\n",
" right: .2ex;\n",
" padding: .5ex;\n",
" margin: .5ex;\n",
" width: min-content;\n",
" min-width: 20ex;\n",
" max-width: 50ex;\n",
" color: var(--sklearn-color-text);\n",
" box-shadow: 2pt 2pt 4pt #999;\n",
" /* unfitted */\n",
" background: var(--sklearn-color-unfitted-level-0);\n",
" border: .5pt solid var(--sklearn-color-unfitted-level-3);\n",
"}\n",
"\n",
".sk-estimator-doc-link.fitted span {\n",
" /* fitted */\n",
" background: var(--sklearn-color-fitted-level-0);\n",
" border: var(--sklearn-color-fitted-level-3);\n",
"}\n",
"\n",
".sk-estimator-doc-link:hover span {\n",
" display: block;\n",
"}\n",
"\n",
"/* \"?\"-specific style due to the `<a>` HTML tag */\n",
"\n",
"#sk-container-id-1 a.estimator_doc_link {\n",
" float: right;\n",
" font-size: 1rem;\n",
" line-height: 1em;\n",
" font-family: monospace;\n",
" background-color: var(--sklearn-color-background);\n",
" border-radius: 1rem;\n",
" height: 1rem;\n",
" width: 1rem;\n",
" text-decoration: none;\n",
" /* unfitted */\n",
" color: var(--sklearn-color-unfitted-level-1);\n",
" border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
"}\n",
"\n",
"#sk-container-id-1 a.estimator_doc_link.fitted {\n",
" /* fitted */\n",
" border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
" color: var(--sklearn-color-fitted-level-1);\n",
"}\n",
"\n",
"/* On hover */\n",
"#sk-container-id-1 a.estimator_doc_link:hover {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-3);\n",
" color: var(--sklearn-color-background);\n",
" text-decoration: none;\n",
"}\n",
"\n",
"#sk-container-id-1 a.estimator_doc_link.fitted:hover {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-3);\n",
"}\n",
"</style><div id=\"sk-container-id-1\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>XGBClassifier(base_score=None, booster=None, callbacks=None,\n",
" colsample_bylevel=None, colsample_bynode=None,\n",
" colsample_bytree=None, device=None, early_stopping_rounds=None,\n",
" enable_categorical=True, eval_metric=None, feature_types=None,\n",
" gamma=None, grow_policy=None, importance_type=None,\n",
" interaction_constraints=None, learning_rate=None, max_bin=None,\n",
" max_cat_threshold=None, max_cat_to_onehot=None,\n",
" max_delta_step=None, max_depth=None, max_leaves=None,\n",
" min_child_weight=None, missing=nan, monotone_constraints=None,\n",
" multi_strategy=None, n_estimators=1000, n_jobs=None, num_class=5,\n",
" num_parallel_tree=None, ...)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-1\" type=\"checkbox\" checked><label for=\"sk-estimator-id-1\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">&nbsp;XGBClassifier<span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></label><div class=\"sk-toggleable__content fitted\"><pre>XGBClassifier(base_score=None, booster=None, callbacks=None,\n",
" colsample_bylevel=None, colsample_bynode=None,\n",
" colsample_bytree=None, device=None, early_stopping_rounds=None,\n",
" enable_categorical=True, eval_metric=None, feature_types=None,\n",
" gamma=None, grow_policy=None, importance_type=None,\n",
" interaction_constraints=None, learning_rate=None, max_bin=None,\n",
" max_cat_threshold=None, max_cat_to_onehot=None,\n",
" max_delta_step=None, max_depth=None, max_leaves=None,\n",
" min_child_weight=None, missing=nan, monotone_constraints=None,\n",
" multi_strategy=None, n_estimators=1000, n_jobs=None, num_class=5,\n",
" num_parallel_tree=None, ...)</pre></div> </div></div></div></div>"
],
"text/plain": [
"XGBClassifier(base_score=None, booster=None, callbacks=None,\n",
" colsample_bylevel=None, colsample_bynode=None,\n",
" colsample_bytree=None, device=None, early_stopping_rounds=None,\n",
" enable_categorical=True, eval_metric=None, feature_types=None,\n",
" gamma=None, grow_policy=None, importance_type=None,\n",
" interaction_constraints=None, learning_rate=None, max_bin=None,\n",
" max_cat_threshold=None, max_cat_to_onehot=None,\n",
" max_delta_step=None, max_depth=None, max_leaves=None,\n",
" min_child_weight=None, missing=nan, monotone_constraints=None,\n",
" multi_strategy=None, n_estimators=1000, n_jobs=None, num_class=5,\n",
" num_parallel_tree=None, ...)"
]
},
"execution_count": 54,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Strip the 'i' characters from the string labels and cast them to int, once per split,\n",
"# so the same integer targets are used for fitting and for the early-stopping eval set.\n",
"y_train_int = y_train.apply(lambda x: x.replace('i', '')).astype(int)\n",
"y_valid_int = y_valid.apply(lambda x: x.replace('i', '')).astype(int)\n",
"\n",
"# early_stopping_rounds is configured on the estimator: passing it to fit() has been\n",
"# deprecated since XGBoost 1.6 and is rejected by later 2.x releases.\n",
"model = xgb.XGBClassifier(\n",
"    n_estimators=1000,\n",
"    objective=\"multi:softprob\",\n",
"    num_class=5,\n",
"    enable_categorical=True,\n",
"    early_stopping_rounds=100,\n",
")\n",
"# Kept as the last expression so the fitted estimator is displayed as the cell output.\n",
"model.fit(X_train, y_train_int, eval_set=[(X_valid, y_valid_int)], verbose=False)"
]
},
{
"cell_type": "code",
"execution_count": 55,
"id": "9cdb16a1-0761-4e3d-9214-6fef6ced875a",
"metadata": {},
"outputs": [],
"source": [
"# Run the fitted model on the validation split first, then on the training split.\n",
"preds_class_valid, preds_class_train = (\n",
"    model.predict(features) for features in (X_valid, X_train)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 60,
"id": "7c980af9-65ce-4472-806b-2e45295eb86c",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"4"
]
},
"execution_count": 60,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# Sanity check: largest class index that appears in the validation predictions.\n",
"np.max(preds_class_valid)"
]
},
{
"cell_type": "code",
"execution_count": 61,
"id": "65228d59-e4c2-411d-8535-b36c15a55bfd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 125 61 1 0 0]\n",
" [ 15 2885 143 0 0]\n",
" [ 1 440 296 7 0]\n",
" [ 1 5 24 4 0]\n",
" [ 0 1 0 0 1]]\n",
"0.49681030860342146\n",
"0.8256857855361596\n",
"########################################\n",
"[[ 294 84 1 0 0]\n",
" [ 18 5995 164 0 0]\n",
" [ 1 739 770 0 0]\n",
" [ 0 9 18 41 0]\n",
" [ 1 0 1 0 3]]\n",
"0.6440004170345552\n",
"0.8727116353360364\n"
]
}
],
"source": [
"# Print confusion matrix, Matthews correlation coefficient and accuracy for the\n",
"# validation split, then a separator line, then the same three metrics for the\n",
"# training split.\n",
"evaluated_splits = [(y_valid, preds_class_valid), (y_train, preds_class_train)]\n",
"for split_position, (split_labels, split_preds) in enumerate(evaluated_splits):\n",
"    if split_position:\n",
"        print('########################################')\n",
"    # Convert the string labels to integers once per split instead of once per metric.\n",
"    split_labels_int = split_labels.apply(lambda x: x.replace('i', '')).astype(int)\n",
"    for metric_fn in (confusion_matrix, matthews_corrcoef, accuracy_score):\n",
"        print(metric_fn(split_labels_int, split_preds))"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}