"""
==============================================================
Compare DupleBalance with Ad-hoc McIL Methods
==============================================================

In this example, we compare :class:`duplebalance.DupleBalanceClassifier`
with other ad-hoc multi-class imbalanced learning (McIL) methods.
"""

# %%
# Show the example's description when run as a script.
print(__doc__)

# Seed forwarded to every ensemble constructed below.
RANDOM_STATE: int = 42

# %% [markdown]
# Preparation
# -----------
# First, we import the necessary packages and generate an example
# multi-class imbalanced dataset.

from duplebalance import DupleBalanceClassifier
from duplebalance import AdhocMultiClassifier
from duplebalance.base import sort_dict_by_key

import pandas as pd
from collections import Counter
import matplotlib.pyplot as plt

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score

# %% [markdown]
# Generate a 5-class imbalanced classification task and split it in half
# into training and test sets.

# 5-class task with a long-tailed class prior: two minority classes at 5%
# each, then 15%, 25%, and a 50% majority class.
X, y = make_classification(
    n_classes=5, class_sep=1,
    weights=[0.05, 0.05, 0.15, 0.25, 0.5],
    n_informative=10, n_redundant=1, flip_y=0,
    n_features=20, n_clusters_per_class=1, n_samples=2000,
    random_state=0,
)

# Hold out half of the samples for evaluation; the split uses the
# script-wide seed so the whole example is reproducible from one constant.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.5, random_state=RANDOM_STATE,
)

# Per-class sample counts of each split, ordered by class label.
origin_distr = sort_dict_by_key(Counter(y_train))
test_distr = sort_dict_by_key(Counter(y_test))
print(f'Original training dataset shape {origin_distr}')
print(f'Original test dataset shape {test_distr}')

# Accumulator for the comparison table: one row per fitted model, with
# fields in the order given by ``all_results_columns``.
all_results_columns = ['Method', 'Score', '#Base Estimators', '#Training Samples']
all_results = list()

# %% [markdown]
# Train a DupleBalance Classifier
# --------------------------------------------------
# Fit a DupleBalanceClassifier for each ensemble size in ``n_estimators_list``
# and record its AUROC on the test set.

n_estimators_list = [1, 5, 10, 25, 50]

ensemble_init_kwargs = {
    'random_state': RANDOM_STATE,
}


def _balanced_auroc(clf):
    """Return the macro-averaged one-vs-one AUROC of ``clf`` on the test split.

    ``average='macro'`` gives every class pair the same weight regardless of
    class frequency, which is what "Balanced AUROC" in the log lines and
    "AUROC (macro)" on the plot's y-axis describe. A prevalence-weighted
    average (``'weighted'``) would be dominated by the majority classes and
    hide differences on the minority classes this comparison is about.
    """
    y_pred_proba = clf.predict_proba(X_test)
    return roc_auc_score(y_test, y_pred_proba, average='macro', multi_class='ovo')


# Fit one DupleBalanceClassifier per ensemble size.
for n_estimators in n_estimators_list:
    clf = DupleBalanceClassifier(
        n_estimators=n_estimators,
        **ensemble_init_kwargs
    ).fit(
        X_train, y_train,
        perturb_alpha=0.7,
        sample_weight=None,
        # NOTE(review): the test split is passed here, presumably only for
        # progress reporting (train_verbose=False); confirm it does not
        # influence training, otherwise the reported scores are optimistic.
        eval_datasets={'test': (X_test, y_test)},
        train_verbose=False,
    )
    score = _balanced_auroc(clf)
    n_training_samples = sum(clf.estimators_n_training_samples_)
    print("DupleBalance {:<2d} | Balanced AUROC: {:.3f} | #Training Samples: {:d}".format(
        n_estimators, score, n_training_samples
    ))
    all_results.append(
        ['DupleBalance', score, len(clf.estimators_), n_training_samples]
    )

# %% [markdown]
# Train Ad-hoc McIL Classifiers
# -----------------------------
# Fit every ad-hoc multi-class imbalanced learning (McIL) method at each
# ensemble size, with the same seed and the same evaluation metric.

ALL_ADHOC_METHOD = ['mdoboost', 'mdobagging', 'soupboost', 'soupbagging', 'mrrbagging', 'adacost', 'asymboost']

for n_estimators in n_estimators_list:
    for method in ALL_ADHOC_METHOD:
        clf = AdhocMultiClassifier(
            method=method,
            n_estimators=n_estimators,
            **ensemble_init_kwargs
        ).fit(X_train, y_train)
        score = _balanced_auroc(clf)
        n_training_samples = sum(clf.estimators_n_training_samples_)
        print("Ad-hoc method: {:<15s} {:<2d} | Balanced AUROC: {:.3f} | #Training Samples: {:d}".format(
            method, n_estimators, score, n_training_samples
        ))
        all_results.append(
            [method, score, len(clf.estimators_), n_training_samples]
        )
    # Blank-line separator between ensemble sizes in the log.
    print('\n')

# %% [markdown]
# Results Visualization
# --------------------------

import matplotlib.pyplot as plt
import seaborn as sns
sns.set_context('talk')

results_vis = pd.DataFrame(all_results, columns=all_results_columns)

fig, ax = plt.subplots(figsize=(10, 6))
# One line per method: test score against the number of training samples
# summed over the ensemble's base estimators.
sns.lineplot(
    data=results_vis,
    x='#Training Samples', y='Score', hue='Method', style='Method',
    markers=True, err_style='bars', linewidth=4, markersize=12, alpha=0.9,
    ax=ax,
)

# Draw a solid black frame around the axes.
for spine in ax.spines.values():
    spine.set_color('black')
    spine.set_linewidth(2)

ax.grid(color='black', linestyle='-.', alpha=0.3)
ax.set_ylabel('AUROC (macro)')
# Tight spacing keeps the legend compact.
ax.legend(
    columnspacing=0.2,
    borderaxespad=0.2,
    handletextpad=0.2,
    labelspacing=0.2,
    handlelength=None,
)
ax.set_title("DupleBalance versus Ad-hoc Methods")
plt.show()