"""
==============================================================
Compare DupleBalance with Decomposition-based McIL Methods
==============================================================

In this example, we compare :class:`duplebalance.DupleBalanceClassifier`
with decomposition-based multi-class imbalanced learning (McIL) methods, each
of which combines a decomposition scheme with a binary imbalanced learning (IL) method.
"""

# %%
# Seed forwarded to the classifiers trained below so that repeated runs of
# this example are reproducible.
RANDOM_STATE = 42

print(__doc__)

# %% [markdown]
# Preparation
# -----------
# First, we import the necessary packages and generate an example
# multi-class imbalanced dataset.

from duplebalance import DupleBalanceClassifier
from duplebalance import DecompositionBasedClassifier
from duplebalance.base import sort_dict_by_key

import pandas as pd
from collections import Counter
import matplotlib.pyplot as plt

from sklearn.datasets import make_classification
from sklearn.model_selection import train_test_split
from sklearn.metrics import roc_auc_score

# %% [markdown]
# Make a 5-class imbalanced classification task

# Synthetic 5-class task with class proportions 5% / 5% / 15% / 25% / 50%
# (roughly 100 / 100 / 300 / 500 / 1000 of the 2000 samples).
X, y = make_classification(
    n_classes=5, class_sep=1,  # 5-class task
    weights=[0.05, 0.05, 0.15, 0.25, 0.5],
    n_informative=10, n_redundant=1,
    flip_y=0,  # no randomly re-assigned labels
    n_features=20, n_clusters_per_class=1, n_samples=2000, random_state=0,
)

# Hold out half of the samples for evaluation; reuse the example-wide seed
# instead of a separate hard-coded value.
X_train, X_test, y_train, y_test = train_test_split(
    X, y, test_size=0.5, random_state=RANDOM_STATE,
)

# Per-class sample counts, sorted by key via sort_dict_by_key.
origin_distr = sort_dict_by_key(Counter(y_train))
test_distr = sort_dict_by_key(Counter(y_test))
print(f'Original training dataset shape {origin_distr}')
print(f'Original test dataset shape {test_distr}')

# One row per trained model: [method name, score, #base estimators,
# #training samples]. Converted to a DataFrame later for plotting.
all_results = []
all_results_columns = ['Method', 'Score', '#Base Estimators', '#Training Samples']

# %% [markdown]
# Train a DupleBalance Classifier
# --------------------------------------------------

# Train a DupleBalanceClassifier with a small ensemble. The test split is also
# passed as an evaluation set so training progress can be logged against it
# (train_verbose controls what gets printed).

ensemble_init_kwargs = {
    'n_estimators': 5,
    'random_state': RANDOM_STATE,
}

clf = DupleBalanceClassifier(**ensemble_init_kwargs).fit(
    X_train, y_train,
    perturb_alpha='auto',
    sample_weight=None,
    eval_datasets={'test': (X_test, y_test)},
    train_verbose={
        'granularity': 1,
        'print_distribution': True,
        'print_metrics': True,
    },
)

# Multi-class one-vs-one AUROC averaged with class-prevalence weights
# (average='weighted'), so it is reported as "Weighted", not "Balanced".
y_pred_proba = clf.predict_proba(X_test)
score = roc_auc_score(y_test, y_pred_proba, average='weighted', multi_class='ovo')
n_training_samples = sum(clf.estimators_n_training_samples_)
print("DupleBalance {} | Weighted AUROC: {:.3f} | #Training Samples: {:d}".format(
    ensemble_init_kwargs['n_estimators'], score, n_training_samples
))
all_results.append(
    ['DupleBalance', score, len(clf.estimators_), n_training_samples]
)

# %% [markdown]
# Train Decomposition + Binary IL Classifiers
# -------------------------------------------

# Train every combination of a decomposition scheme and a binary imbalanced
# learning (IL) method, and score each one on the test split with the same
# prevalence-weighted one-vs-one AUROC used for DupleBalance.

ALL_DECOMP = ['ova', 'ovo', 'ecoc']
ALL_BINARY = [
    'clean', 'enn', 'oneside', 'tomeklink', 'smote', 'border',
    'oups', 'ans', 'ccr', 'gazzah', 'smotersb', 'smotetomek',
]

for decomposition in ALL_DECOMP:
    for binary_il in ALL_BINARY:
        clf = DecompositionBasedClassifier(
            binary_il=binary_il,
            decomposition=decomposition,
            random_state=RANDOM_STATE,
        ).fit(X_train, y_train)
        y_pred_proba = clf.predict_proba(X_test)
        score = roc_auc_score(y_test, y_pred_proba, average='weighted', multi_class='ovo')
        n_training_samples = sum(clf.estimators_n_training_samples_)
        print("Decomp: {:<5s} | BinaryIL: {:<10s} | Weighted AUROC: {:.3f} | #Training Samples: {:d}".format(
            decomposition, binary_il, score, n_training_samples
        ))
        # Rows are named '<decomposition>+<binary_il>'; the plotting code
        # relies on this format to group results by decomposition scheme.
        all_results.append(
            [f'{decomposition}+{binary_il}', score, len(clf.estimators_), n_training_samples]
        )
    # Visually separate the log output of each decomposition scheme.
    print('\n')

# %% [markdown]
# Results Visualization
# --------------------------

import matplotlib.pyplot as plt
import seaborn as sns
# sns.set_context('talk')

def get_decomposition_mask(decomposition, all_results):
    """Return a row mask selecting one decomposition's results plus DupleBalance.

    Decomposition-based rows are named ``'<decomposition>+<binary_il>'``, so a
    row matches when the text before the first ``'+'`` equals
    ``decomposition`` exactly. This works for scheme names of any length and
    does not match names that merely share a prefix. The ``'DupleBalance'``
    row is always kept so it appears as a reference in every subplot.

    Parameters
    ----------
    decomposition : str
        Decomposition scheme to select, e.g. ``'ova'``, ``'ovo'`` or ``'ecoc'``.
    all_results : pandas.DataFrame
        Results table with a ``'Method'`` column of method names.

    Returns
    -------
    list of bool
        One flag per row of ``all_results``, usable for boolean indexing.
    """
    return [
        method == 'DupleBalance' or method.partition('+')[0] == decomposition
        for method in all_results['Method'].values
    ]

all_results = pd.DataFrame(all_results, columns=all_results_columns)

# One subplot per decomposition scheme. Each plots that scheme's methods
# together with DupleBalance: score against total training samples used.
# squeeze=False keeps `axes` 2-D even if ALL_DECOMP has a single entry.
figure, axes = plt.subplots(
    1, len(ALL_DECOMP), figsize=(4 * len(ALL_DECOMP), 6), squeeze=False,
)

for decomposition, ax in zip(ALL_DECOMP, axes.flatten()):

    results_vis = all_results[get_decomposition_mask(decomposition, all_results)]
    sns.scatterplot(
        data=results_vis,
        x='#Training Samples', y='Score', hue='Method', style='Method',
        s=300, ax=ax,
    )

    # Thick black frame around each subplot.
    for spine in ax.spines.values():
        spine.set_color('black')
        spine.set_linewidth(2)

    ax.grid(color='black', linestyle='-.', alpha=0.3)
    # 'Score' was computed with roc_auc_score(average='weighted',
    # multi_class='ovo'), so label it as a weighted (not macro) average.
    ax.set_ylabel('AUROC (weighted)')
    ax.legend(
        columnspacing=0.2,
        borderaxespad=0.2,
        handletextpad=0.2,
        labelspacing=0.2,
    )
    ax.set_title(f"DupleBalance versus {decomposition.upper()}")

plt.tight_layout()
plt.show()