Files
pid/notebooks/old_notebooks/test_binary.ipynb

1191 lines
79 KiB
Plaintext

{
"cells": [
{
"cell_type": "code",
"execution_count": 4,
"id": "7c5d059b-ed8a-4e2e-9420-25890f648895",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_42878/2472232159.py:7: UserWarning: pandas only supports SQLAlchemy connectable (engine/connection) or database string URI or sqlite3 DBAPI2 connection. Other DBAPI2 objects are not tested. Please consider using SQLAlchemy.\n",
" df = pd.read_sql('select * from data_safeidx', con=engine)\n"
]
}
],
"source": [
"import pandas as pd\n",
"import psycopg2 as pg\n",
"import matplotlib.pyplot as plt\n",
"import numpy as np\n",
"from sklearn.model_selection import train_test_split\n",
"engine = pg.connect(\"dbname='safeidx' user='fbk_mpba' host='172.104.247.67' port='5432' password='fbk2024$'\")\n",
"df = pd.read_sql('select * from data_safeidx', con=engine)"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "03aa2a04-93fa-469e-a678-685cacdebd6c",
"metadata": {},
"outputs": [
{
"data": {
"text/html": [
"<div>\n",
"<style scoped>\n",
" .dataframe tbody tr th:only-of-type {\n",
" vertical-align: middle;\n",
" }\n",
"\n",
" .dataframe tbody tr th {\n",
" vertical-align: top;\n",
" }\n",
"\n",
" .dataframe thead th {\n",
" text-align: right;\n",
" }\n",
"</style>\n",
"<table border=\"1\" class=\"dataframe\">\n",
" <thead>\n",
" <tr style=\"text-align: right;\">\n",
" <th></th>\n",
" <th>difficulty</th>\n",
" <th>cause</th>\n",
" <th>town</th>\n",
" <th>province</th>\n",
" <th>gender</th>\n",
" <th>equipment</th>\n",
" <th>helmet</th>\n",
" <th>destination</th>\n",
" <th>diagnosis</th>\n",
" <th>india</th>\n",
" <th>age</th>\n",
" <th>country</th>\n",
" <th>injury_side</th>\n",
" <th>injury_general_location</th>\n",
" <th>evacuation_vehicles</th>\n",
" </tr>\n",
" </thead>\n",
" <tbody>\n",
" <tr>\n",
" <th>0</th>\n",
" <td>novice</td>\n",
" <td>fall_alone</td>\n",
" <td>SIKLOS</td>\n",
" <td></td>\n",
" <td>F</td>\n",
" <td>ski</td>\n",
" <td>None</td>\n",
" <td>hospital_emergency_room</td>\n",
" <td>distortion</td>\n",
" <td>None</td>\n",
" <td>32.0</td>\n",
" <td>Ungheria</td>\n",
" <td>L</td>\n",
" <td>lower_limbs</td>\n",
" <td>[akja]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>1</th>\n",
" <td>advanced</td>\n",
" <td>fall_alone</td>\n",
" <td>MALMO</td>\n",
" <td></td>\n",
" <td>M</td>\n",
" <td>ski</td>\n",
" <td>None</td>\n",
" <td>hospital_emergency_room</td>\n",
" <td>bruise</td>\n",
" <td>None</td>\n",
" <td>32.0</td>\n",
" <td>Svezia</td>\n",
" <td>R</td>\n",
" <td>skull_or_face</td>\n",
" <td>[akja]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>2</th>\n",
" <td>advanced</td>\n",
" <td>fall_alone</td>\n",
" <td>CALDARO</td>\n",
" <td>BZ</td>\n",
" <td>F</td>\n",
" <td>ski</td>\n",
" <td>None</td>\n",
" <td>domicile</td>\n",
" <td>other</td>\n",
" <td>None</td>\n",
" <td>12.0</td>\n",
" <td>Italia</td>\n",
" <td>R</td>\n",
" <td>None</td>\n",
" <td>[snowmobile]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>3</th>\n",
" <td>advanced</td>\n",
" <td>collision_person</td>\n",
" <td>LINZ</td>\n",
" <td></td>\n",
" <td>M</td>\n",
" <td>ski</td>\n",
" <td>None</td>\n",
" <td>hospital_emergency_room</td>\n",
" <td>bruise</td>\n",
" <td>None</td>\n",
" <td>58.0</td>\n",
" <td>Austria</td>\n",
" <td>R</td>\n",
" <td>lower_limbs</td>\n",
" <td>[snowmobile]</td>\n",
" </tr>\n",
" <tr>\n",
" <th>4</th>\n",
" <td>advanced</td>\n",
" <td>collision_person</td>\n",
" <td>RUSAVA</td>\n",
" <td></td>\n",
" <td>M</td>\n",
" <td>ski</td>\n",
" <td>None</td>\n",
" <td>other</td>\n",
" <td>bruise</td>\n",
" <td>None</td>\n",
" <td>25.0</td>\n",
" <td>Repubblica Ceca</td>\n",
" <td>L</td>\n",
" <td>lower_limbs</td>\n",
" <td>[other]</td>\n",
" </tr>\n",
" </tbody>\n",
"</table>\n",
"</div>"
],
"text/plain": [
" difficulty cause town province gender equipment helmet \\\n",
"0 novice fall_alone SIKLOS F ski None \n",
"1 advanced fall_alone MALMO M ski None \n",
"2 advanced fall_alone CALDARO BZ F ski None \n",
"3 advanced collision_person LINZ M ski None \n",
"4 advanced collision_person RUSAVA M ski None \n",
"\n",
" destination diagnosis india age country \\\n",
"0 hospital_emergency_room distortion None 32.0 Ungheria \n",
"1 hospital_emergency_room bruise None 32.0 Svezia \n",
"2 domicile other None 12.0 Italia \n",
"3 hospital_emergency_room bruise None 58.0 Austria \n",
"4 other bruise None 25.0 Repubblica Ceca \n",
"\n",
" injury_side injury_general_location evacuation_vehicles \n",
"0 L lower_limbs [akja] \n",
"1 R skull_or_face [akja] \n",
"2 R None [snowmobile] \n",
"3 R lower_limbs [snowmobile] \n",
"4 L lower_limbs [other] "
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"df.head()"
]
},
{
"cell_type": "code",
"execution_count": 99,
"id": "babc2e8b-1030-4e8a-aa41-6d2a788959a5",
"metadata": {},
"outputs": [
{
"ename": "AttributeError",
"evalue": "'Series' object has no attribute 'evacuation_vehicles'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m/tmp/ipykernel_42878/3368331793.py\u001b[0m in \u001b[0;36m?\u001b[0;34m()\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[0mev\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m{\u001b[0m\u001b[0;34m}\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mrow\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mdf\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0miterrows\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m----> 3\u001b[0;31m \u001b[0mev\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mev\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0munion\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mset\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mrow\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mevacuation_vehicles\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;32m~/miniconda3/envs/pid/lib/python3.11/site-packages/pandas/core/generic.py\u001b[0m in \u001b[0;36m?\u001b[0;34m(self, name)\u001b[0m\n\u001b[1;32m 6289\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mname\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_accessors\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6290\u001b[0m \u001b[0;32mand\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_info_axis\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_can_hold_identifiers_and_holds_name\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 6291\u001b[0m ):\n\u001b[1;32m 6292\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0mname\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 6293\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mobject\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__getattribute__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mname\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
"\u001b[0;31mAttributeError\u001b[0m: 'Series' object has no attribute 'evacuation_vehicles'"
]
}
],
"source": [
"ev = set({})\n",
"for i,row in df.iterrows():\n",
" ev = ev.union(set(row.evacuation_vehicles))"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "c8d6cc1c-f4f5-44ec-8652-b135963452ab",
"metadata": {},
"outputs": [],
"source": [
"for c in ev:\n",
" df[c] = False\n",
"for i,row in df.iterrows():\n",
" for c in row.evacuation_vehicles:\n",
" df.loc[i,c] = True"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "758c3317-1b02-4aed-b94d-7b6998d23797",
"metadata": {},
"outputs": [],
"source": [
"df.drop(columns=['town','province','evacuation_vehicles'],inplace=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "adadc0dc-9d6e-4277-8956-d1d4b2492e7e",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 100,
"id": "33617e77-7c2b-41a3-96c0-8930aa5ac869",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"(array([1.3808e+04, 0.0000e+00, 8.8100e+02, 0.0000e+00, 0.0000e+00,\n",
" 3.3690e+03, 0.0000e+00, 1.5200e+02, 0.0000e+00, 1.1000e+01]),\n",
" array([0. , 0.4, 0.8, 1.2, 1.6, 2. , 2.4, 2.8, 3.2, 3.6, 4. ]),\n",
" <BarContainer object of 10 artists>)"
]
},
"execution_count": 100,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"labeled = df[~pd.isna(df.india)]\n",
"plt.hist(labeled.india)"
]
},
{
"cell_type": "code",
"execution_count": 101,
"id": "2bda819b-4d4b-4e71-960a-53dd74d80b71",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_42878/2289208715.py:1: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" labeled['is_i1']=0\n"
]
},
{
"data": {
"text/plain": [
"(array([ 3532., 0., 0., 0., 0., 0., 0., 0.,\n",
" 0., 14689.]),\n",
" array([0. , 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 1. ]),\n",
" <BarContainer object of 10 artists>)"
]
},
"execution_count": 101,
"metadata": {},
"output_type": "execute_result"
},
{
"data": {
"image/png": "",
"text/plain": [
"<Figure size 640x480 with 1 Axes>"
]
},
"metadata": {},
"output_type": "display_data"
}
],
"source": [
"labeled['is_i1']=0\n",
"labeled.loc[labeled.india.isin(['i1','i0']),'is_i1']=1\n",
"plt.hist(labeled.is_i1)"
]
},
{
"cell_type": "code",
"execution_count": 103,
"id": "e8c139d9-bf61-45ec-9da1-7eaf4ff754b4",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/tmp/ipykernel_42878/1116779111.py:1: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" labeled[c] = labeled[c].astype(np.float32).fillna(np.nan)\n",
"/tmp/ipykernel_42878/1116779111.py:4: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" labeled[c] = labeled[c].fillna('None').astype('category')\n",
"/tmp/ipykernel_42878/1116779111.py:4: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" labeled[c] = labeled[c].fillna('None').astype('category')\n",
"/tmp/ipykernel_42878/1116779111.py:4: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" labeled[c] = labeled[c].fillna('None').astype('category')\n",
"/tmp/ipykernel_42878/1116779111.py:4: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" labeled[c] = labeled[c].fillna('None').astype('category')\n",
"/tmp/ipykernel_42878/1116779111.py:4: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" labeled[c] = labeled[c].fillna('None').astype('category')\n",
"/tmp/ipykernel_42878/1116779111.py:4: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" labeled[c] = labeled[c].fillna('None').astype('category')\n",
"/tmp/ipykernel_42878/1116779111.py:4: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" labeled[c] = labeled[c].fillna('None').astype('category')\n",
"/tmp/ipykernel_42878/1116779111.py:4: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" labeled[c] = labeled[c].fillna('None').astype('category')\n",
"/tmp/ipykernel_42878/1116779111.py:4: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" labeled[c] = labeled[c].fillna('None').astype('category')\n",
"/tmp/ipykernel_42878/1116779111.py:4: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" labeled[c] = labeled[c].fillna('None').astype('category')\n",
"/tmp/ipykernel_42878/1116779111.py:4: SettingWithCopyWarning: \n",
"A value is trying to be set on a copy of a slice from a DataFrame.\n",
"Try using .loc[row_indexer,col_indexer] = value instead\n",
"\n",
"See the caveats in the documentation: https://pandas.pydata.org/pandas-docs/stable/user_guide/indexing.html#returning-a-view-versus-a-copy\n",
" labeled[c] = labeled[c].fillna('None').astype('category')\n"
]
}
],
"source": [
"labeled[c] = labeled[c].astype(np.float32).fillna(np.nan)\n",
"for c in labeled.columns:\n",
" if c!='age':\n",
" labeled[c] = labeled[c].fillna('None').astype('category')\n"
]
},
{
"cell_type": "code",
"execution_count": 104,
"id": "fabf354e-f39e-4cde-af84-c65a277d309a",
"metadata": {},
"outputs": [],
"source": [
"X_train, X_test, y_train, y_test = train_test_split( labeled.drop(columns=['india','is_i1']),\n",
" labeled.is_i1, test_size=0.33, random_state=0,stratify=labeled.is_i1)"
]
},
{
"cell_type": "code",
"execution_count": 105,
"id": "cf6cf5d8-d43e-499e-98a5-65ecd0b8ccda",
"metadata": {},
"outputs": [],
"source": [
"X_train, X_valid, y_train, y_valid = train_test_split(X_train,y_train, test_size=0.33, random_state=0,stratify=y_train)"
]
},
{
"cell_type": "code",
"execution_count": 147,
"id": "774ad570-a60b-475b-80cf-4f9b9949cc9d",
"metadata": {},
"outputs": [],
"source": [
"weight_train = (1-y_train.values.astype(int))*5+1 ## peso classi unbalanced"
]
},
{
"cell_type": "code",
"execution_count": 149,
"id": "1fcc5234-abad-459a-9420-810833657796",
"metadata": {},
"outputs": [],
"source": [
"from catboost import CatBoostClassifier, Pool\n",
"\n",
"train_data = Pool(data=X_train,\n",
" label=y_train,\n",
" weight=weight_train,cat_features=[c for c in X_train.columns if c!='age'])\n",
"valid_data = Pool(data=X_valid,\n",
" label=y_valid,cat_features=[c for c in X_train.columns if c!='age']\n",
" )\n",
"model = CatBoostClassifier(iterations=1000)\n",
"\n",
"model.fit(train_data,eval_set=valid_data,verbose=False,early_stopping_rounds=100)\n",
"preds_class = model.predict(valid_data,)"
]
},
{
"cell_type": "code",
"execution_count": 150,
"id": "8df84007-110b-4f31-bc7d-53e7d5c9a178",
"metadata": {},
"outputs": [],
"source": [
"preds_class_valid = model.predict(valid_data)\n",
"preds_class_train= model.predict(train_data)"
]
},
{
"cell_type": "code",
"execution_count": 151,
"id": "55627ec4-fd24-4815-98d4-d8462bbfdd9a",
"metadata": {},
"outputs": [],
"source": [
"from sklearn.metrics import confusion_matrix,matthews_corrcoef,accuracy_score"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "877e914d-44d4-4299-8d3c-24c4fc353317",
"metadata": {},
"outputs": [],
"source": []
},
{
"cell_type": "code",
"execution_count": 152,
"id": "776395d9-a8e2-4fda-90e1-11d6dbe80de8",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 622 159]\n",
" [ 778 2470]]\n",
"0.4623150659581546\n",
"0.7674360883593944\n",
"########################################\n",
"[[1436 149]\n",
" [1358 5236]]\n",
"0.5834604503367398\n",
"0.8157476464115417\n"
]
}
],
"source": [
"print(confusion_matrix(y_valid,preds_class_valid))\n",
"print(matthews_corrcoef(y_valid,preds_class_valid))\n",
"print(accuracy_score(y_valid,preds_class_valid))\n",
"print('########################################')\n",
"print(confusion_matrix(y_train,preds_class_train))\n",
"print(matthews_corrcoef(y_train,preds_class_train))\n",
"print(accuracy_score(y_train,preds_class_train))"
]
},
{
"cell_type": "code",
"execution_count": 157,
"id": "34c397bc-529a-4c52-b30e-957b28021200",
"metadata": {},
"outputs": [
{
"ename": "TypeError",
"evalue": "DMatrix.__init__() got an unexpected keyword argument 'scale_pos_weight'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mTypeError\u001b[0m Traceback (most recent call last)",
"Cell \u001b[0;32mIn[157], line 5\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mxgboost\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mxgb\u001b[39;00m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;66;03m# Create regression matrices\u001b[39;00m\n\u001b[0;32m----> 5\u001b[0m dtrain \u001b[38;5;241m=\u001b[39m \u001b[43mxgb\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mDMatrix\u001b[49m\u001b[43m(\u001b[49m\u001b[43mX_train\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43my_train\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mastype\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;28;43mint\u001b[39;49m\u001b[43m)\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43menable_categorical\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;28;43;01mTrue\u001b[39;49;00m\u001b[43m,\u001b[49m\u001b[43mscale_pos_weight\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[38;5;241;43m100\u001b[39;49m\u001b[43m)\u001b[49m\u001b[38;5;66;03m#,weight=weight_train)\u001b[39;00m\n\u001b[1;32m 6\u001b[0m dvalid \u001b[38;5;241m=\u001b[39m xgb\u001b[38;5;241m.\u001b[39mDMatrix(X_valid, y_valid\u001b[38;5;241m.\u001b[39mastype(\u001b[38;5;28mint\u001b[39m), enable_categorical\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n",
"File \u001b[0;32m~/miniconda3/envs/pid/lib/python3.11/site-packages/xgboost/core.py:730\u001b[0m, in \u001b[0;36mrequire_keyword_args.<locals>.throw_if.<locals>.inner_f\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 728\u001b[0m \u001b[38;5;28;01mfor\u001b[39;00m k, arg \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mzip\u001b[39m(sig\u001b[38;5;241m.\u001b[39mparameters, args):\n\u001b[1;32m 729\u001b[0m kwargs[k] \u001b[38;5;241m=\u001b[39m arg\n\u001b[0;32m--> 730\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mfunc\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n",
"\u001b[0;31mTypeError\u001b[0m: DMatrix.__init__() got an unexpected keyword argument 'scale_pos_weight'"
]
}
],
"source": [
"##try with xgboost\n",
"import xgboost as xgb\n",
"\n",
"# Create regression matrices\n",
"dtrain = xgb.DMatrix(X_train, y_train.astype(int), enable_categorical=True)#,weight=weight_train)\n",
"dvalid = xgb.DMatrix(X_valid, y_valid.astype(int), enable_categorical=True)"
]
},
{
"cell_type": "code",
"execution_count": 161,
"id": "7bb54c1d-ed2c-496a-b149-d567246bfee0",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/agobbi/miniconda3/envs/pid/lib/python3.11/site-packages/xgboost/core.py:160: UserWarning: [11:26:44] WARNING: /workspace/src/learner.cc:742: \n",
"Parameters: { \"scale_pos_weight\" } are not used.\n",
"\n",
" warnings.warn(smsg, UserWarning)\n"
]
}
],
"source": [
"params = {\"objective\": \"multi:softprob\", \"num_class\": 2,'scale_pos_weight':100}\n",
"n = 1000\n",
"\n",
"results = xgb.train(\n",
" params, dtrain,\n",
"\n",
" num_boost_round=n,\n",
" evals = [(dtrain, \"train\"), (dvalid, \"valid\")],\n",
" verbose_eval=False,\n",
" early_stopping_rounds=100\n",
" # metrics=[\"mlogloss\", \"auc\", \"merror\"],\n",
")\n"
]
},
{
"cell_type": "code",
"execution_count": 182,
"id": "353b9430-61ff-4d15-869e-8a5a06a1fb51",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 492 289]\n",
" [ 471 2777]]\n",
"0.4495009568013873\n",
"0.811367585008687\n",
"########################################\n",
"[[1584 1]\n",
" [ 389 6205]]\n",
"0.8688028589259579\n",
"0.9523169091575987\n"
]
}
],
"source": [
"preds_class_valid = results.predict(dvalid)\n",
"preds_class_train= results.predict(dtrain)\n",
"print(confusion_matrix(y_valid,preds_class_valid.argmax(1)))\n",
"print(matthews_corrcoef(y_valid,preds_class_valid.argmax(1)))\n",
"print(accuracy_score(y_valid,preds_class_valid.argmax(1)))\n",
"print('########################################')\n",
"print(confusion_matrix(y_train,preds_class_train.argmax(1)))\n",
"print(matthews_corrcoef(y_train,preds_class_train.argmax(1)))\n",
"print(accuracy_score(y_train,preds_class_train.argmax(1)))"
]
},
{
"cell_type": "code",
"execution_count": 183,
"id": "5120687b-02ed-42ee-80ca-543fe6aa540d",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"/home/agobbi/miniconda3/envs/pid/lib/python3.11/site-packages/xgboost/sklearn.py:889: UserWarning: `early_stopping_rounds` in `fit` method is deprecated for better compatibility with scikit-learn, use `early_stopping_rounds` in constructor or`set_params` instead.\n",
" warnings.warn(\n",
"/home/agobbi/miniconda3/envs/pid/lib/python3.11/site-packages/xgboost/core.py:160: UserWarning: [11:33:33] WARNING: /workspace/src/learner.cc:742: \n",
"Parameters: { \"scale_pos_weight\" } are not used.\n",
"\n",
" warnings.warn(smsg, UserWarning)\n"
]
},
{
"data": {
"text/html": [
"<style>#sk-container-id-4 {\n",
" /* Definition of color scheme common for light and dark mode */\n",
" --sklearn-color-text: black;\n",
" --sklearn-color-line: gray;\n",
" /* Definition of color scheme for unfitted estimators */\n",
" --sklearn-color-unfitted-level-0: #fff5e6;\n",
" --sklearn-color-unfitted-level-1: #f6e4d2;\n",
" --sklearn-color-unfitted-level-2: #ffe0b3;\n",
" --sklearn-color-unfitted-level-3: chocolate;\n",
" /* Definition of color scheme for fitted estimators */\n",
" --sklearn-color-fitted-level-0: #f0f8ff;\n",
" --sklearn-color-fitted-level-1: #d4ebff;\n",
" --sklearn-color-fitted-level-2: #b3dbfd;\n",
" --sklearn-color-fitted-level-3: cornflowerblue;\n",
"\n",
" /* Specific color for light theme */\n",
" --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
" --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, white)));\n",
" --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, black)));\n",
" --sklearn-color-icon: #696969;\n",
"\n",
" @media (prefers-color-scheme: dark) {\n",
" /* Redefinition of color scheme for dark theme */\n",
" --sklearn-color-text-on-default-background: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
" --sklearn-color-background: var(--sg-background-color, var(--theme-background, var(--jp-layout-color0, #111)));\n",
" --sklearn-color-border-box: var(--sg-text-color, var(--theme-code-foreground, var(--jp-content-font-color1, white)));\n",
" --sklearn-color-icon: #878787;\n",
" }\n",
"}\n",
"\n",
"#sk-container-id-4 {\n",
" color: var(--sklearn-color-text);\n",
"}\n",
"\n",
"#sk-container-id-4 pre {\n",
" padding: 0;\n",
"}\n",
"\n",
"#sk-container-id-4 input.sk-hidden--visually {\n",
" border: 0;\n",
" clip: rect(1px 1px 1px 1px);\n",
" clip: rect(1px, 1px, 1px, 1px);\n",
" height: 1px;\n",
" margin: -1px;\n",
" overflow: hidden;\n",
" padding: 0;\n",
" position: absolute;\n",
" width: 1px;\n",
"}\n",
"\n",
"#sk-container-id-4 div.sk-dashed-wrapped {\n",
" border: 1px dashed var(--sklearn-color-line);\n",
" margin: 0 0.4em 0.5em 0.4em;\n",
" box-sizing: border-box;\n",
" padding-bottom: 0.4em;\n",
" background-color: var(--sklearn-color-background);\n",
"}\n",
"\n",
"#sk-container-id-4 div.sk-container {\n",
" /* jupyter's `normalize.less` sets `[hidden] { display: none; }`\n",
" but bootstrap.min.css set `[hidden] { display: none !important; }`\n",
" so we also need the `!important` here to be able to override the\n",
" default hidden behavior on the sphinx rendered scikit-learn.org.\n",
" See: https://github.com/scikit-learn/scikit-learn/issues/21755 */\n",
" display: inline-block !important;\n",
" position: relative;\n",
"}\n",
"\n",
"#sk-container-id-4 div.sk-text-repr-fallback {\n",
" display: none;\n",
"}\n",
"\n",
"div.sk-parallel-item,\n",
"div.sk-serial,\n",
"div.sk-item {\n",
" /* draw centered vertical line to link estimators */\n",
" background-image: linear-gradient(var(--sklearn-color-text-on-default-background), var(--sklearn-color-text-on-default-background));\n",
" background-size: 2px 100%;\n",
" background-repeat: no-repeat;\n",
" background-position: center center;\n",
"}\n",
"\n",
"/* Parallel-specific style estimator block */\n",
"\n",
"#sk-container-id-4 div.sk-parallel-item::after {\n",
" content: \"\";\n",
" width: 100%;\n",
" border-bottom: 2px solid var(--sklearn-color-text-on-default-background);\n",
" flex-grow: 1;\n",
"}\n",
"\n",
"#sk-container-id-4 div.sk-parallel {\n",
" display: flex;\n",
" align-items: stretch;\n",
" justify-content: center;\n",
" background-color: var(--sklearn-color-background);\n",
" position: relative;\n",
"}\n",
"\n",
"#sk-container-id-4 div.sk-parallel-item {\n",
" display: flex;\n",
" flex-direction: column;\n",
"}\n",
"\n",
"#sk-container-id-4 div.sk-parallel-item:first-child::after {\n",
" align-self: flex-end;\n",
" width: 50%;\n",
"}\n",
"\n",
"#sk-container-id-4 div.sk-parallel-item:last-child::after {\n",
" align-self: flex-start;\n",
" width: 50%;\n",
"}\n",
"\n",
"#sk-container-id-4 div.sk-parallel-item:only-child::after {\n",
" width: 0;\n",
"}\n",
"\n",
"/* Serial-specific style estimator block */\n",
"\n",
"#sk-container-id-4 div.sk-serial {\n",
" display: flex;\n",
" flex-direction: column;\n",
" align-items: center;\n",
" background-color: var(--sklearn-color-background);\n",
" padding-right: 1em;\n",
" padding-left: 1em;\n",
"}\n",
"\n",
"\n",
"/* Toggleable style: style used for estimator/Pipeline/ColumnTransformer box that is\n",
"clickable and can be expanded/collapsed.\n",
"- Pipeline and ColumnTransformer use this feature and define the default style\n",
"- Estimators will overwrite some part of the style using the `sk-estimator` class\n",
"*/\n",
"\n",
"/* Pipeline and ColumnTransformer style (default) */\n",
"\n",
"#sk-container-id-4 div.sk-toggleable {\n",
" /* Default theme specific background. It is overwritten whether we have a\n",
" specific estimator or a Pipeline/ColumnTransformer */\n",
" background-color: var(--sklearn-color-background);\n",
"}\n",
"\n",
"/* Toggleable label */\n",
"#sk-container-id-4 label.sk-toggleable__label {\n",
" cursor: pointer;\n",
" display: block;\n",
" width: 100%;\n",
" margin-bottom: 0;\n",
" padding: 0.5em;\n",
" box-sizing: border-box;\n",
" text-align: center;\n",
"}\n",
"\n",
"#sk-container-id-4 label.sk-toggleable__label-arrow:before {\n",
" /* Arrow on the left of the label */\n",
" content: \"▸\";\n",
" float: left;\n",
" margin-right: 0.25em;\n",
" color: var(--sklearn-color-icon);\n",
"}\n",
"\n",
"#sk-container-id-4 label.sk-toggleable__label-arrow:hover:before {\n",
" color: var(--sklearn-color-text);\n",
"}\n",
"\n",
"/* Toggleable content - dropdown */\n",
"\n",
"#sk-container-id-4 div.sk-toggleable__content {\n",
" max-height: 0;\n",
" max-width: 0;\n",
" overflow: hidden;\n",
" text-align: left;\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-4 div.sk-toggleable__content.fitted {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-4 div.sk-toggleable__content pre {\n",
" margin: 0.2em;\n",
" border-radius: 0.25em;\n",
" color: var(--sklearn-color-text);\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-4 div.sk-toggleable__content.fitted pre {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-fitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-4 input.sk-toggleable__control:checked~div.sk-toggleable__content {\n",
" /* Expand drop-down */\n",
" max-height: 200px;\n",
" max-width: 100%;\n",
" overflow: auto;\n",
"}\n",
"\n",
"#sk-container-id-4 input.sk-toggleable__control:checked~label.sk-toggleable__label-arrow:before {\n",
" content: \"▾\";\n",
"}\n",
"\n",
"/* Pipeline/ColumnTransformer-specific style */\n",
"\n",
"#sk-container-id-4 div.sk-label input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" color: var(--sklearn-color-text);\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-4 div.sk-label.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"/* Estimator-specific style */\n",
"\n",
"/* Colorize estimator box */\n",
"#sk-container-id-4 div.sk-estimator input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-4 div.sk-estimator.fitted input.sk-toggleable__control:checked~label.sk-toggleable__label {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-4 div.sk-label label.sk-toggleable__label,\n",
"#sk-container-id-4 div.sk-label label {\n",
" /* The background is the default theme color */\n",
" color: var(--sklearn-color-text-on-default-background);\n",
"}\n",
"\n",
"/* On hover, darken the color of the background */\n",
"#sk-container-id-4 div.sk-label:hover label.sk-toggleable__label {\n",
" color: var(--sklearn-color-text);\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"/* Label box, darken color on hover, fitted */\n",
"#sk-container-id-4 div.sk-label.fitted:hover label.sk-toggleable__label.fitted {\n",
" color: var(--sklearn-color-text);\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"/* Estimator label */\n",
"\n",
"#sk-container-id-4 div.sk-label label {\n",
" font-family: monospace;\n",
" font-weight: bold;\n",
" display: inline-block;\n",
" line-height: 1.2em;\n",
"}\n",
"\n",
"#sk-container-id-4 div.sk-label-container {\n",
" text-align: center;\n",
"}\n",
"\n",
"/* Estimator-specific */\n",
"#sk-container-id-4 div.sk-estimator {\n",
" font-family: monospace;\n",
" border: 1px dotted var(--sklearn-color-border-box);\n",
" border-radius: 0.25em;\n",
" box-sizing: border-box;\n",
" margin-bottom: 0.5em;\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-0);\n",
"}\n",
"\n",
"#sk-container-id-4 div.sk-estimator.fitted {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-0);\n",
"}\n",
"\n",
"/* on hover */\n",
"#sk-container-id-4 div.sk-estimator:hover {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-2);\n",
"}\n",
"\n",
"#sk-container-id-4 div.sk-estimator.fitted:hover {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-2);\n",
"}\n",
"\n",
"/* Specification for estimator info (e.g. \"i\" and \"?\") */\n",
"\n",
"/* Common style for \"i\" and \"?\" */\n",
"\n",
".sk-estimator-doc-link,\n",
"a:link.sk-estimator-doc-link,\n",
"a:visited.sk-estimator-doc-link {\n",
" float: right;\n",
" font-size: smaller;\n",
" line-height: 1em;\n",
" font-family: monospace;\n",
" background-color: var(--sklearn-color-background);\n",
" border-radius: 1em;\n",
" height: 1em;\n",
" width: 1em;\n",
" text-decoration: none !important;\n",
" margin-left: 1ex;\n",
" /* unfitted */\n",
" border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
" color: var(--sklearn-color-unfitted-level-1);\n",
"}\n",
"\n",
".sk-estimator-doc-link.fitted,\n",
"a:link.sk-estimator-doc-link.fitted,\n",
"a:visited.sk-estimator-doc-link.fitted {\n",
" /* fitted */\n",
" border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
" color: var(--sklearn-color-fitted-level-1);\n",
"}\n",
"\n",
"/* On hover */\n",
"div.sk-estimator:hover .sk-estimator-doc-link:hover,\n",
".sk-estimator-doc-link:hover,\n",
"div.sk-label-container:hover .sk-estimator-doc-link:hover,\n",
".sk-estimator-doc-link:hover {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-3);\n",
" color: var(--sklearn-color-background);\n",
" text-decoration: none;\n",
"}\n",
"\n",
"div.sk-estimator.fitted:hover .sk-estimator-doc-link.fitted:hover,\n",
".sk-estimator-doc-link.fitted:hover,\n",
"div.sk-label-container:hover .sk-estimator-doc-link.fitted:hover,\n",
".sk-estimator-doc-link.fitted:hover {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-3);\n",
" color: var(--sklearn-color-background);\n",
" text-decoration: none;\n",
"}\n",
"\n",
"/* Span, style for the box shown on hovering the info icon */\n",
".sk-estimator-doc-link span {\n",
" display: none;\n",
" z-index: 9999;\n",
" position: relative;\n",
" font-weight: normal;\n",
" right: .2ex;\n",
" padding: .5ex;\n",
" margin: .5ex;\n",
" width: min-content;\n",
" min-width: 20ex;\n",
" max-width: 50ex;\n",
" color: var(--sklearn-color-text);\n",
" box-shadow: 2pt 2pt 4pt #999;\n",
" /* unfitted */\n",
" background: var(--sklearn-color-unfitted-level-0);\n",
" border: .5pt solid var(--sklearn-color-unfitted-level-3);\n",
"}\n",
"\n",
".sk-estimator-doc-link.fitted span {\n",
" /* fitted */\n",
" background: var(--sklearn-color-fitted-level-0);\n",
" border: var(--sklearn-color-fitted-level-3);\n",
"}\n",
"\n",
".sk-estimator-doc-link:hover span {\n",
" display: block;\n",
"}\n",
"\n",
"/* \"?\"-specific style due to the `<a>` HTML tag */\n",
"\n",
"#sk-container-id-4 a.estimator_doc_link {\n",
" float: right;\n",
" font-size: 1rem;\n",
" line-height: 1em;\n",
" font-family: monospace;\n",
" background-color: var(--sklearn-color-background);\n",
" border-radius: 1rem;\n",
" height: 1rem;\n",
" width: 1rem;\n",
" text-decoration: none;\n",
" /* unfitted */\n",
" color: var(--sklearn-color-unfitted-level-1);\n",
" border: var(--sklearn-color-unfitted-level-1) 1pt solid;\n",
"}\n",
"\n",
"#sk-container-id-4 a.estimator_doc_link.fitted {\n",
" /* fitted */\n",
" border: var(--sklearn-color-fitted-level-1) 1pt solid;\n",
" color: var(--sklearn-color-fitted-level-1);\n",
"}\n",
"\n",
"/* On hover */\n",
"#sk-container-id-4 a.estimator_doc_link:hover {\n",
" /* unfitted */\n",
" background-color: var(--sklearn-color-unfitted-level-3);\n",
" color: var(--sklearn-color-background);\n",
" text-decoration: none;\n",
"}\n",
"\n",
"#sk-container-id-4 a.estimator_doc_link.fitted:hover {\n",
" /* fitted */\n",
" background-color: var(--sklearn-color-fitted-level-3);\n",
"}\n",
"</style><div id=\"sk-container-id-4\" class=\"sk-top-container\"><div class=\"sk-text-repr-fallback\"><pre>XGBClassifier(base_score=None, booster=None, callbacks=None,\n",
" colsample_bylevel=None, colsample_bynode=None,\n",
" colsample_bytree=None, device=None, early_stopping_rounds=None,\n",
" enable_categorical=True, eval_metric=None, feature_types=None,\n",
" gamma=None, grow_policy=None, importance_type=None,\n",
" interaction_constraints=None, learning_rate=None, max_bin=None,\n",
" max_cat_threshold=None, max_cat_to_onehot=None,\n",
" max_delta_step=None, max_depth=None, max_leaves=None,\n",
" min_child_weight=None, missing=nan, monotone_constraints=None,\n",
" multi_strategy=None, n_estimators=1000, n_jobs=None, num_class=2,\n",
" num_parallel_tree=None, ...)</pre><b>In a Jupyter environment, please rerun this cell to show the HTML representation or trust the notebook. <br />On GitHub, the HTML representation is unable to render, please try loading this page with nbviewer.org.</b></div><div class=\"sk-container\" hidden><div class=\"sk-item\"><div class=\"sk-estimator fitted sk-toggleable\"><input class=\"sk-toggleable__control sk-hidden--visually\" id=\"sk-estimator-id-4\" type=\"checkbox\" checked><label for=\"sk-estimator-id-4\" class=\"sk-toggleable__label fitted sk-toggleable__label-arrow fitted\">&nbsp;XGBClassifier<span class=\"sk-estimator-doc-link fitted\">i<span>Fitted</span></span></label><div class=\"sk-toggleable__content fitted\"><pre>XGBClassifier(base_score=None, booster=None, callbacks=None,\n",
" colsample_bylevel=None, colsample_bynode=None,\n",
" colsample_bytree=None, device=None, early_stopping_rounds=None,\n",
" enable_categorical=True, eval_metric=None, feature_types=None,\n",
" gamma=None, grow_policy=None, importance_type=None,\n",
" interaction_constraints=None, learning_rate=None, max_bin=None,\n",
" max_cat_threshold=None, max_cat_to_onehot=None,\n",
" max_delta_step=None, max_depth=None, max_leaves=None,\n",
" min_child_weight=None, missing=nan, monotone_constraints=None,\n",
" multi_strategy=None, n_estimators=1000, n_jobs=None, num_class=2,\n",
" num_parallel_tree=None, ...)</pre></div> </div></div></div></div>"
],
"text/plain": [
"XGBClassifier(base_score=None, booster=None, callbacks=None,\n",
" colsample_bylevel=None, colsample_bynode=None,\n",
" colsample_bytree=None, device=None, early_stopping_rounds=None,\n",
" enable_categorical=True, eval_metric=None, feature_types=None,\n",
" gamma=None, grow_policy=None, importance_type=None,\n",
" interaction_constraints=None, learning_rate=None, max_bin=None,\n",
" max_cat_threshold=None, max_cat_to_onehot=None,\n",
" max_delta_step=None, max_depth=None, max_leaves=None,\n",
" min_child_weight=None, missing=nan, monotone_constraints=None,\n",
" multi_strategy=None, n_estimators=1000, n_jobs=None, num_class=2,\n",
" num_parallel_tree=None, ...)"
]
},
"execution_count": 183,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"# early_stopping_rounds is an estimator parameter (XGBoost >= 1.6); passing it to\n",
"# fit() is deprecated and rejected by newer releases, so it is configured here.\n",
"# NOTE(review): scale_pos_weight is documented for binary objectives; confirm it\n",
"# has any effect under multi:softprob before relying on it for class imbalance.\n",
"model = xgb.XGBClassifier(\n",
"    n_estimators=1000,\n",
"    scale_pos_weight=50,\n",
"    objective=\"multi:softprob\",\n",
"    num_class=2,\n",
"    enable_categorical=True,\n",
"    early_stopping_rounds=100,\n",
")\n",
"# Labels are cast to int both for training and for the validation set that\n",
"# drives early stopping.\n",
"model.fit(\n",
"    X_train,\n",
"    y_train.astype(int),\n",
"    eval_set=[(X_valid, y_valid.astype(int))],\n",
"    verbose=False,\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 184,
"id": "9cdb16a1-0761-4e3d-9214-6fef6ced875a",
"metadata": {},
"outputs": [],
"source": [
"# Run the fitted model on the validation split first, then on the training split.\n",
"preds_class_valid, preds_class_train = (\n",
"    model.predict(features) for features in (X_valid, X_train)\n",
")"
]
},
{
"cell_type": "code",
"execution_count": 185,
"id": "65228d59-e4c2-411d-8535-b36c15a55bfd",
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"[[ 308 473]\n",
" [ 151 3097]]\n",
"0.4328302503376973\n",
"0.8451228592702904\n",
"########################################\n",
"[[ 825 760]\n",
" [ 221 6373]]\n",
"0.5763770395263901\n",
"0.8800586868810368\n"
]
}
],
"source": [
"def _print_split_metrics(y_true, raw_preds):\n",
"    \"\"\"Print confusion matrix, Matthews correlation and accuracy for one split.\n",
"\n",
"    raw_preds is reduced with argmax over axis 1 once, and the resulting\n",
"    labels are reused for all three metrics.\n",
"    \"\"\"\n",
"    labels = raw_preds.argmax(1)\n",
"    print(confusion_matrix(y_true, labels))\n",
"    print(matthews_corrcoef(y_true, labels))\n",
"    print(accuracy_score(y_true, labels))\n",
"\n",
"\n",
"# Validation metrics first, then training metrics, separated by a rule.\n",
"_print_split_metrics(y_valid, preds_class_valid)\n",
"print('########################################')\n",
"_print_split_metrics(y_train, preds_class_train)"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "95461970-fad9-4c50-84d3-e139bedfec3f",
"metadata": {},
"outputs": [],
"source": [
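"# NOTE(review): pasted output, presumably from an earlier run of the metrics\n",
"# cell, kept for comparison. Commented out so this code cell no longer raises\n",
"# a SyntaxError on Restart & Run All.\n",
"# 0.4495009568013873\n",
"# 0.811367585008687\n",
"########################################\n",
"# [[1584 1]\n",
"# [ 389 6205]]\n",
"# 0.8688028589259579\n",
"# 0.9523169091575987"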
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.0"
}
},
"nbformat": 4,
"nbformat_minor": 5
}