{ "cells": [ { "cell_type": "markdown", "id": "iraqi-wound", "metadata": {}, "source": [ "## Dependencies" ] }, { "cell_type": "code", "execution_count": 1, "id": "legal-router", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: imbalanced-learn in /opt/conda/lib/python3.8/site-packages (0.8.0)\n", "Requirement already satisfied: scipy>=0.19.1 in /opt/conda/lib/python3.8/site-packages (from imbalanced-learn) (1.6.0)\n", "Requirement already satisfied: scikit-learn>=0.24 in /opt/conda/lib/python3.8/site-packages (from imbalanced-learn) (0.24.1)\n", "Requirement already satisfied: numpy>=1.13.3 in /opt/conda/lib/python3.8/site-packages (from imbalanced-learn) (1.20.0)\n", "Requirement already satisfied: joblib>=0.11 in /opt/conda/lib/python3.8/site-packages (from imbalanced-learn) (1.0.0)\n", "Requirement already satisfied: threadpoolctl>=2.0.0 in /opt/conda/lib/python3.8/site-packages (from scikit-learn>=0.24->imbalanced-learn) (2.1.0)\n", "Requirement already satisfied: python-slugify in /opt/conda/lib/python3.8/site-packages (4.0.1)\n", "Requirement already satisfied: text-unidecode>=1.3 in /opt/conda/lib/python3.8/site-packages (from python-slugify) (1.3)\n" ] } ], "source": [ "import sys\n", "\n", "# Pin the versions this notebook was last executed with, so that a fresh\n", "# environment reproduces the analysis instead of pulling whatever is newest.\n", "!{sys.executable} -m pip install imbalanced-learn==0.8.0\n", "!{sys.executable} -m pip install python-slugify==4.0.1" ] }, { "cell_type": "markdown", "id": "superior-clinton", "metadata": {}, "source": [ "## Dataset overview" ] }, { "cell_type": "code", "execution_count": 2, "id": "assumed-progress", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Bankrupt?ROA(C) before interest and depreciation before interestROA(A) before interest and % after taxROA(B) before interest and depreciation after taxOperating Gross MarginRealized Sales Gross MarginOperating Profit RatePre-tax net Interest RateAfter-tax net Interest RateNon-industry income and expenditure/revenue...Net Income to Total AssetsTotal assets to GNP priceNo-credit IntervalGross Profit to SalesNet Income to Stockholder's EquityLiability to EquityDegree of Financial Leverage (DFL)Interest Coverage Ratio (Interest expense to EBIT)Net Income FlagEquity to Liability
010.3705940.4243890.4057500.6014570.6014570.9989690.7968870.8088090.302646...0.7168450.0092190.6228790.6014530.8278900.2902020.0266010.56405010.016469
110.4642910.5382140.5167300.6102350.6102350.9989460.7973800.8093010.303556...0.7952970.0083230.6236520.6102370.8399690.2838460.2645770.57017510.020794
210.4260710.4990190.4722950.6014500.6013640.9988570.7964030.8083880.302035...0.7746700.0400030.6238410.6014490.8367740.2901890.0265550.56370610.016474
310.3998440.4512650.4577330.5835410.5835410.9987000.7969670.8089660.303350...0.7395550.0032520.6229290.5835380.8346970.2817210.0266970.56466310.023982
410.4650220.5384320.5222980.5987830.5987830.9989730.7973660.8093040.303475...0.7950160.0038780.6235210.5987820.8399730.2785140.0247520.57561710.035490
\n", "

5 rows × 96 columns

\n", "
" ], "text/plain": [ " Bankrupt? ROA(C) before interest and depreciation before interest \\\n", "0 1 0.370594 \n", "1 1 0.464291 \n", "2 1 0.426071 \n", "3 1 0.399844 \n", "4 1 0.465022 \n", "\n", " ROA(A) before interest and % after tax \\\n", "0 0.424389 \n", "1 0.538214 \n", "2 0.499019 \n", "3 0.451265 \n", "4 0.538432 \n", "\n", " ROA(B) before interest and depreciation after tax \\\n", "0 0.405750 \n", "1 0.516730 \n", "2 0.472295 \n", "3 0.457733 \n", "4 0.522298 \n", "\n", " Operating Gross Margin Realized Sales Gross Margin \\\n", "0 0.601457 0.601457 \n", "1 0.610235 0.610235 \n", "2 0.601450 0.601364 \n", "3 0.583541 0.583541 \n", "4 0.598783 0.598783 \n", "\n", " Operating Profit Rate Pre-tax net Interest Rate \\\n", "0 0.998969 0.796887 \n", "1 0.998946 0.797380 \n", "2 0.998857 0.796403 \n", "3 0.998700 0.796967 \n", "4 0.998973 0.797366 \n", "\n", " After-tax net Interest Rate Non-industry income and expenditure/revenue \\\n", "0 0.808809 0.302646 \n", "1 0.809301 0.303556 \n", "2 0.808388 0.302035 \n", "3 0.808966 0.303350 \n", "4 0.809304 0.303475 \n", "\n", " ... Net Income to Total Assets Total assets to GNP price \\\n", "0 ... 0.716845 0.009219 \n", "1 ... 0.795297 0.008323 \n", "2 ... 0.774670 0.040003 \n", "3 ... 0.739555 0.003252 \n", "4 ... 
0.795016 0.003878 \n", "\n", " No-credit Interval Gross Profit to Sales \\\n", "0 0.622879 0.601453 \n", "1 0.623652 0.610237 \n", "2 0.623841 0.601449 \n", "3 0.622929 0.583538 \n", "4 0.623521 0.598782 \n", "\n", " Net Income to Stockholder's Equity Liability to Equity \\\n", "0 0.827890 0.290202 \n", "1 0.839969 0.283846 \n", "2 0.836774 0.290189 \n", "3 0.834697 0.281721 \n", "4 0.839973 0.278514 \n", "\n", " Degree of Financial Leverage (DFL) \\\n", "0 0.026601 \n", "1 0.264577 \n", "2 0.026555 \n", "3 0.026697 \n", "4 0.024752 \n", "\n", " Interest Coverage Ratio (Interest expense to EBIT) Net Income Flag \\\n", "0 0.564050 1 \n", "1 0.570175 1 \n", "2 0.563706 1 \n", "3 0.564663 1 \n", "4 0.575617 1 \n", "\n", " Equity to Liability \n", "0 0.016469 \n", "1 0.020794 \n", "2 0.016474 \n", "3 0.023982 \n", "4 0.035490 \n", "\n", "[5 rows x 96 columns]" ] }, "execution_count": 2, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import pandas as pd\n", "\n", "df = pd.read_csv('company_bankrupcy.csv').dropna()\n", "df.head()" ] }, { "cell_type": "markdown", "id": "premier-lincoln", "metadata": {}, "source": [ "## Imbalance Analysis\n", "\n", "\n", "It is always helpful to know how imbalanced the dataset is. Are there more positive samples than negative ones? By how much?" ] }, { "cell_type": "code", "execution_count": 3, "id": "cooked-april", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
countpct
OK65990.967737
KO2200.032263
\n", "
" ], "text/plain": [ " count pct\n", "OK 6599 0.967737\n", "KO 220 0.032263" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "count = len(df)\n", "stats = df[['Bankrupt?', ' After-tax net Interest Rate']]\\\n", " .groupby(['Bankrupt?'])\\\n", " .agg(['count'])\\\n", " .reset_index(drop=True)\\\n", " .rename(index={0: \"OK\", 1: \"KO\"})\\\n", " .T.reset_index(drop=True, level=0).T\n", "\n", "stats['pct'] = stats['count'] / count\n", "stats" ] }, { "cell_type": "markdown", "id": "based-virtue", "metadata": {}, "source": [ "It seems this dataset is highly imbalanced. Most of the samples belong to the negative class (not bankrupt), whereas only about 3% of the samples are positive (bankrupt)." ] }, { "cell_type": "markdown", "id": "imperial-clinton", "metadata": {}, "source": [ "### Imbalanced: Resampling & Feature selection\n", "\n", "We need to narrow down the number of features that are really relevant for classifying the samples as bankruptcy or business as usual. The steps I'm going to take are:\n", "\n", "- Correlation Matrix\n", "- Scatter Matrix\n", "- Lasso feature selection\n", "- SelectKBest feature selection" ] }, { "cell_type": "markdown", "id": "sitting-posting", "metadata": {}, "source": [ "First, I'm normalizing the **DataFrame's column names**" ] }, { "cell_type": "code", "execution_count": 4, "id": "delayed-velvet", "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
bankruptroa-c-before-interest-and-depreciation-before-interestroa-a-before-interest-and-after-taxroa-b-before-interest-and-depreciation-after-taxoperating-gross-marginrealized-sales-gross-marginoperating-profit-ratepre-tax-net-interest-rateafter-tax-net-interest-ratenon-industry-income-and-expenditure-revenue...net-income-to-total-assetstotal-assets-to-gnp-priceno-credit-intervalgross-profit-to-salesnet-income-to-stockholder-s-equityliability-to-equitydegree-of-financial-leverage-dflinterest-coverage-ratio-interest-expense-to-ebitnet-income-flagequity-to-liability
010.3705940.4243890.4057500.6014570.6014570.9989690.7968870.8088090.302646...0.7168450.0092190.6228790.6014530.8278900.2902020.0266010.56405010.016469
110.4642910.5382140.5167300.6102350.6102350.9989460.7973800.8093010.303556...0.7952970.0083230.6236520.6102370.8399690.2838460.2645770.57017510.020794
210.4260710.4990190.4722950.6014500.6013640.9988570.7964030.8083880.302035...0.7746700.0400030.6238410.6014490.8367740.2901890.0265550.56370610.016474
310.3998440.4512650.4577330.5835410.5835410.9987000.7969670.8089660.303350...0.7395550.0032520.6229290.5835380.8346970.2817210.0266970.56466310.023982
410.4650220.5384320.5222980.5987830.5987830.9989730.7973660.8093040.303475...0.7950160.0038780.6235210.5987820.8399730.2785140.0247520.57561710.035490
\n", "

5 rows × 96 columns

\n", "
" ], "text/plain": [ " bankrupt roa-c-before-interest-and-depreciation-before-interest \\\n", "0 1 0.370594 \n", "1 1 0.464291 \n", "2 1 0.426071 \n", "3 1 0.399844 \n", "4 1 0.465022 \n", "\n", " roa-a-before-interest-and-after-tax \\\n", "0 0.424389 \n", "1 0.538214 \n", "2 0.499019 \n", "3 0.451265 \n", "4 0.538432 \n", "\n", " roa-b-before-interest-and-depreciation-after-tax operating-gross-margin \\\n", "0 0.405750 0.601457 \n", "1 0.516730 0.610235 \n", "2 0.472295 0.601450 \n", "3 0.457733 0.583541 \n", "4 0.522298 0.598783 \n", "\n", " realized-sales-gross-margin operating-profit-rate \\\n", "0 0.601457 0.998969 \n", "1 0.610235 0.998946 \n", "2 0.601364 0.998857 \n", "3 0.583541 0.998700 \n", "4 0.598783 0.998973 \n", "\n", " pre-tax-net-interest-rate after-tax-net-interest-rate \\\n", "0 0.796887 0.808809 \n", "1 0.797380 0.809301 \n", "2 0.796403 0.808388 \n", "3 0.796967 0.808966 \n", "4 0.797366 0.809304 \n", "\n", " non-industry-income-and-expenditure-revenue ... \\\n", "0 0.302646 ... \n", "1 0.303556 ... \n", "2 0.302035 ... \n", "3 0.303350 ... \n", "4 0.303475 ... 
\n", "\n", " net-income-to-total-assets total-assets-to-gnp-price no-credit-interval \\\n", "0 0.716845 0.009219 0.622879 \n", "1 0.795297 0.008323 0.623652 \n", "2 0.774670 0.040003 0.623841 \n", "3 0.739555 0.003252 0.622929 \n", "4 0.795016 0.003878 0.623521 \n", "\n", " gross-profit-to-sales net-income-to-stockholder-s-equity \\\n", "0 0.601453 0.827890 \n", "1 0.610237 0.839969 \n", "2 0.601449 0.836774 \n", "3 0.583538 0.834697 \n", "4 0.598782 0.839973 \n", "\n", " liability-to-equity degree-of-financial-leverage-dfl \\\n", "0 0.290202 0.026601 \n", "1 0.283846 0.264577 \n", "2 0.290189 0.026555 \n", "3 0.281721 0.026697 \n", "4 0.278514 0.024752 \n", "\n", " interest-coverage-ratio-interest-expense-to-ebit net-income-flag \\\n", "0 0.564050 1 \n", "1 0.570175 1 \n", "2 0.563706 1 \n", "3 0.564663 1 \n", "4 0.575617 1 \n", "\n", " equity-to-liability \n", "0 0.016469 \n", "1 0.020794 \n", "2 0.016474 \n", "3 0.023982 \n", "4 0.035490 \n", "\n", "[5 rows x 96 columns]" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from slugify import slugify\n", "\n", "columns = [slugify(name) for name in df.columns.values]\n", "df.columns = columns\n", "df.head()" ] }, { "cell_type": "markdown", "id": "interim-danger", "metadata": {}, "source": [ "Now I'm getting **target column** and the **rest of the columns** as features" ] }, { "cell_type": "code", "execution_count": 5, "id": "structural-probability", "metadata": {}, "outputs": [], "source": [ "target = 'bankrupt'\n", "# every column except the target is a candidate feature\n", "features = [name for name in df.columns.values if name != target]" ] }, { "cell_type": "markdown", "id": "informational-popularity", "metadata": {}, "source": [ "## RESAMPLING" ] }, { "cell_type": "code", "execution_count": 6, "id": "joint-evidence", "metadata": {}, "outputs": [], "source": [ "from imblearn.over_sampling import SMOTE\n", "\n", "X = df[features]\n", "y = df[target]\n", "\n", "# RESAMPLING\n", "# Fixed seed so the synthetic minority samples (and every score derived from them) are reproducible.\n", "# NOTE(review): oversampling happens before the train/test split further down, so synthetic rows\n", "# built from test-set neighbours can reach the training set; consider resampling the training fold only.\n", "rnd = SMOTE(random_state=0)\n", "X, y = rnd.fit_resample(X, 
y)" ] }, { "cell_type": "markdown", "id": "raising-administrator", "metadata": {}, "source": [ "## SCALE" ] }, { "cell_type": "code", "execution_count": 7, "id": "sized-translator", "metadata": {}, "outputs": [], "source": [ "from sklearn.preprocessing import StandardScaler\n", "\n", "X_scaled = StandardScaler().fit_transform(X)" ] }, { "cell_type": "markdown", "id": "elegant-halloween", "metadata": {}, "source": [ "## FEATURE SELECTION" ] }, { "cell_type": "markdown", "id": "becoming-omaha", "metadata": {}, "source": [ "Creating initial target/features selection. I tried ALL but even a trivial matrix plot was taking forever. That's why I decided to do a top 10 feature selection directly." ] }, { "cell_type": "code", "execution_count": 8, "id": "adult-excellence", "metadata": { "scrolled": true }, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/opt/conda/lib/python3.8/site-packages/sklearn/feature_selection/_univariate_selection.py:114: UserWarning: Features [93] are constant.\n", " warnings.warn(\"Features %s are constant.\" % constant_features_idx,\n", "/opt/conda/lib/python3.8/site-packages/sklearn/feature_selection/_univariate_selection.py:116: RuntimeWarning: invalid value encountered in true_divide\n", " f = msb / msw\n" ] }, { "data": { "text/plain": [ "['roa-c-before-interest-and-depreciation-before-interest',\n", " 'roa-a-before-interest-and-after-tax',\n", " 'tax-rate-a',\n", " 'net-value-per-share-c',\n", " 'operating-profit-per-share-yuan-y',\n", " 'total-debt-total-net-worth',\n", " 'debt-ratio',\n", " 'operating-profit-paid-in-capital',\n", " 'liability-assets-flag']" ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from sklearn.feature_selection import SelectKBest\n", "from sklearn.feature_selection import f_classif\n", "\n", "fs = SelectKBest(f_classif, k=10)\n", "# Store the selection under its own name: X_scaled stays the full scaled matrix,\n", "# so re-running this cell selects from all features again.\n", "X_selected = fs.fit_transform(X_scaled, y)\n", "\n", "# get_support() returns positions in the feature matrix, whose columns follow `features`.\n", "# df.columns also contains the target column, so indexing it would shift every name by one.\n", "cols = [features[idx] for idx in fs.get_support(indices=True)]\n", "cols" ] },
{ "cell_type": "markdown", "id": "broad-rings", "metadata": {}, "source": [ "Now creating the scatter matrix to see feature distributions" ] }, { "cell_type": "markdown", "id": "electoral-transparency", "metadata": {}, "source": [ "## CROSS VALIDATION" ] }, { "cell_type": "code", "execution_count": 9, "id": "consistent-chinese", "metadata": {}, "outputs": [], "source": [ "from sklearn.model_selection import cross_val_score, train_test_split\n", "\n", "# train and evaluate on the selected features only\n", "X = X_selected\n", "\n", "X_train, X_test, y_train, y_test = train_test_split(X, y, random_state=0)" ] }, { "cell_type": "code", "execution_count": 10, "id": "emerging-mailman", "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "LogisticRegression\n", "==================\n", "{'C': 1} 0.8926910223286709\n", "---\n", "KNeighborsClassifier\n", "====================\n", "{'n_neighbors': 5} 0.9531649031498819\n", "---\n", "DecisionTreeClassifier\n", "======================\n", "{'max_depth': 6} 0.916512484124205\n", "---\n" ] } ], "source": [ "from sklearn.linear_model import LogisticRegression\n", "from sklearn.svm import LinearSVC\n", "from sklearn.neural_network import MLPClassifier\n", "from sklearn.tree import DecisionTreeClassifier\n", "from sklearn.neighbors import KNeighborsClassifier\n", "from sklearn.ensemble import RandomForestClassifier\n", "from sklearn.model_selection import GridSearchCV\n", "\n", "def print_search_result(name, search_result):\n", " print(name)\n", " print(\"=\" * len(name))\n", " print(search_result.best_params_, search_result.best_score_)\n", " print('---')\n", " \n", "classifiers = [\n", " {\n", " 'classifier': LogisticRegression(),\n", " 'params': { 'C': [1, 5, 10, 20, 40] }\n", " },\n", " {\n", " 'classifier': KNeighborsClassifier(),\n", " 'params': { 'n_neighbors': [5, 10, 15] }\n", " },\n", " {\n", " 'classifier': DecisionTreeClassifier(),\n", " 'params': { 'max_depth': [3, 4, 5, 6] }\n", " }\n", 
"]\n", "\n", "for entry in classifiers:\n", " classifier = entry['classifier']\n", " classifier_name = type(classifier).__name__\n", " \n", " search_params = entry['params']\n", " search_result = GridSearchCV(classifier, param_grid=search_params, scoring='recall')\\\n", " .fit(X_train, y_train)\n", " \n", " print_search_result(classifier_name, search_result)\n", " " ] }, { "cell_type": "code", "execution_count": 11, "id": "boxed-worry", "metadata": {}, "outputs": [], "source": [ "from sklearn.linear_model import LogisticRegression\n", "from sklearn.tree import DecisionTreeClassifier\n", "from sklearn.neighbors import KNeighborsClassifier\n", "\n", "log = LogisticRegression(C=1).fit(X_train, y_train)\n", "knn = KNeighborsClassifier(n_neighbors=5).fit(X_train, y_train)\n", "dtc = DecisionTreeClassifier(max_depth=6).fit(X_train, y_train)\n", "\n", "lst = [log, knn, dtc]" ] }, { "cell_type": "markdown", "id": "champion-control", "metadata": {}, "source": [ "## Decision functions" ] }, { "cell_type": "code", "execution_count": 12, "id": "moral-papua", "metadata": {}, "outputs": [], "source": [ "def get_y_predict(clsf, samples):\n", "    \"\"\"Return a 1-D array holding one positive-class score per sample.\n", "\n", "    Probabilities from ``predict_proba`` are preferred so that every model is\n", "    scored on the same [0, 1] scale; the precision/recall cell compares these\n", "    scores against a single THRESHOLD. Estimators that do not expose\n", "    ``predict_proba`` fall back to ``decision_function``.\n", "    \"\"\"\n", "    predict_proba = getattr(clsf, 'predict_proba', None)\n", "    if callable(predict_proba):\n", "        # only interested in positive score; [:, 1] keeps the result 1-D\n", "        return predict_proba(samples)[:, 1]\n", "    return clsf.decision_function(samples)" ] }, { "cell_type": "markdown", "id": "oriental-sharing", "metadata": {}, "source": [ "## Precision/Recall curves" ] }, { "cell_type": "code", "execution_count": 13, "id": "native-restoration", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import numpy as np\n", "from sklearn.metrics import precision_recall_curve\n", "import matplotlib.pyplot as plt\n", "\n", "# threshold I'm interested in\n", "THRESHOLD = 0.75\n", "\n", "# plt.subplots() already creates the figure (an extra plt.figure() only emits an empty one);\n", "# squeeze=False keeps a 2-D axes array whatever the number of classifiers\n", "_, axes = plt.subplots(1, len(lst), figsize=(15, 4), squeeze=False)\n", "\n", "# plotting precision-recall charts, one panel per fitted classifier\n", "for ax, classifier in zip(axes[0], lst):\n", "    y_predict = get_y_predict(classifier, X_test)\n", "    precision, recall, thresholds = precision_recall_curve(y_test, y_predict)\n", "    ax.set(title=type(classifier).__name__, xlabel='Precision', ylabel='Recall')\n", "    ax.step(precision, recall)\n", "    # mark the operating point whose score cut-off is closest to THRESHOLD\n", "    criteria = np.argmin(np.abs(thresholds - THRESHOLD))\n", "    ax.plot(precision[criteria], recall[criteria], 'o', c='r')\n", "    ax.grid(axis='both', linestyle='--', c='#cccccc')\n", "\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 64, "id": "printable-eight", "metadata": {}, "outputs": [ { "data": { "text/plain": [ "
" ] }, "metadata": {}, "output_type": "display_data" }, { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "from sklearn.metrics import roc_curve, auc\n", "\n", "# plt.subplots() already creates the figure (an extra plt.figure() only emits an empty one);\n", "# squeeze=False keeps a 2-D axes array whatever the number of classifiers\n", "_, axes = plt.subplots(1, len(lst), figsize=(15, 4), squeeze=False)\n", "\n", "# one ROC panel per fitted classifier\n", "for ax, classifier in zip(axes[0], lst):\n", "    classifier_name = type(classifier).__name__\n", "    # continuous score used to rank the test samples\n", "    y_predict = get_y_predict(classifier, X_test)\n", "\n", "    # calculating FPR and TPR\n", "    fpr, tpr, thresholds = roc_curve(y_test, y_predict)\n", "\n", "    # calculating the area under the curve\n", "    roc_auc = auc(fpr, tpr)\n", "\n", "    ax.set(\n", "        title=\"{0} (AUC={1:.2f})\".format(classifier_name, roc_auc),\n", "        xlabel='False Positive Rate',\n", "        ylabel='True Positive Rate',\n", "    )\n", "    ax.plot(fpr, tpr, c='k')\n", "    # diagonal = performance of a random classifier\n", "    ax.plot([0, 1], [0, 1], c='k', linestyle='--')\n", "    ax.fill_between(fpr, tpr, hatch='\\\\', color='none', edgecolor='#cccccc')\n", "\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": 70, "id": "acceptable-adaptation", "metadata": {}, "outputs": [ { "data": { "image/png": "\n", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "from sklearn.metrics import roc_curve, auc\n", "\n", "# single square figure with every classifier's ROC curve overlaid\n", "_, roc_ax = plt.subplots(figsize=(5, 5))\n", "\n", "for model in lst:\n", "    # score the test set, then trace that model's ROC curve\n", "    scores = get_y_predict(model, X_test)\n", "    fpr, tpr, _ = roc_curve(y_test, scores)\n", "    roc_ax.plot(fpr, tpr, label=type(model).__name__)\n", "\n", "# chance-level reference line\n", "roc_ax.plot([0, 1], [0, 1], c='green', linestyle='--')\n", "roc_ax.set_title('ROC')\n", "roc_ax.set_xlabel('False Positive Rate')\n", "roc_ax.set_ylabel('True Positive Rate')\n", "roc_ax.legend(loc=\"lower right\", fontsize=11)\n", "plt.show()" ] }, { "cell_type": "code", "execution_count": null, "id": "foreign-revolution", "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.6" } }, "nbformat": 4, "nbformat_minor": 5 }