import json
import numpy as np
from pandas import DataFrame
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from sklearn.linear_model import LinearRegression
from sklearn.metrics import mean_squared_error
from sklearn.ensemble import GradientBoostingRegressor
from IPython import display

# Load the complete MTG card database; the resulting `data` dict maps each card name to its card record.
with open('AllCards.json', 'r', encoding='utf8') as card_file:
    data = json.load(card_file)
print(f'{len(data)} cards read.')


def is_number(s):
    """Return True if *s* can be converted with ``float()``, otherwise False.

    Used to keep only cards whose power/toughness entries are numeric. Strings
    that ``float()`` rejects (e.g. ``'*'`` or ``''``) return False, and so do
    values of an unsupported type such as ``None``, for which ``float()``
    raises ``TypeError`` rather than ``ValueError``.
    """
    try:
        float(s)
    except (TypeError, ValueError):
        return False
    return True


# Conservatively collect names of creatures whose rules text is exactly "Flying" or blank, and print one CSV
# row per kept creature: "card name", power, toughness, 1/0 for flying/non-flying, converted mana cost (CMC).
creatures = []  # list of creature names with numeric power and toughness and lacking transform layout
flying_creatures = []  # subset of creature names with rule text indicating "Flying" only
non_flying_creatures = []  # subset of creature names with blank rule text
for card_name, card_data in data.items():
    card_types = card_data.get('types', [])
    # A missing (or null) layout must not raise KeyError/TypeError; treat it as a non-transform card.
    layout = card_data.get('layout') or ''
    if ('Creature' in card_types
            and 'power' in card_data and is_number(card_data['power'])
            and 'toughness' in card_data and is_number(card_data['toughness'])
            and 'transform' not in layout):  # Note: prominent outliers were all transform back sides
        creatures.append(card_name)
        card_text = card_data.get('text', '')
        if card_text == 'Flying':
            is_flying = True
            flying_creatures.append(card_name)
        elif card_text == '':
            is_flying = False
            non_flying_creatures.append(card_name)
        else:
            continue  # any other rules text is excluded from the analysis
        # Double embedded quotes so the quoted name stays a single valid CSV field (RFC 4180).
        csv_name = card_name.replace('"', '""')
        print('"{}", {}, {}, {}, {}'.format(csv_name, card_data['power'], card_data['toughness'],
                                          int(is_flying), card_data['convertedManaCost']))
print(len(creatures), 'creatures found.')
print(len(flying_creatures), 'flying creatures filtered.')
print(len(non_flying_creatures), 'non-flying creatures filtered.')
print()

# Turn every filtered creature into a [power, toughness, flying, cmc] row, load the rows into a
# DataFrame, shuffle it reproducibly (random_state=0), and print summary statistics.
def _creature_row(name, flying_flag):
    """Return the feature row [power, toughness, flying_flag, cmc] for the card called *name*."""
    card = data[name]
    return [float(card['power']), float(card['toughness']), flying_flag, card['convertedManaCost']]


flying_data = (_creature_row(name, 1) for name in flying_creatures)
non_flying_data = (_creature_row(name, 0) for name in non_flying_creatures)
all_data = [*flying_data, *non_flying_data]
full_df = (DataFrame(all_data, columns=['power', 'toughness', 'flying', 'cmc'])
           .sample(frac=1, random_state=0)  # shuffle DataFrame
           .reset_index(drop=True))
print(full_df.describe(include='all'))


def jitter_arr(arr, *, scale=.02, rng=None):
    """Return *arr* plus Gaussian noise, for jittering overlapping plot points.

    The noise standard deviation is ``scale`` times the range (max - min) of
    *arr*, so constant input comes back unchanged and empty input returns an
    empty array instead of raising.

    Args:
        arr: sequence or array of numbers.
        scale: fraction of the data range used as the noise standard deviation.
        rng: optional ``numpy.random.Generator``; when omitted the global
            ``np.random`` state is used, as before.

    Returns:
        A new float numpy array with the same shape as *arr*.
    """
    values = np.asarray(arr, dtype=float)
    if values.size == 0:
        return values
    stdev = scale * (values.max() - values.min())
    if rng is None:
        noise = np.random.randn(*values.shape)
    else:
        noise = rng.standard_normal(values.shape)
    return values + noise * stdev


# Scatterplot the jittered data: red circles are non-flying creatures, blue triangles are flying creatures.
minval = 0
maxval = 10
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
ax.set_xlim((minval, maxval))
ax.set_ylim((minval, maxval))
for f, c, m, a, label in [(0, 'r', 'o', .2, 'non-flying'), (1, 'b', '^', .5, 'flying')]:
    subset = full_df.loc[full_df['flying'] == f]  # filter once per group rather than once per axis
    xs = list(subset['power'])
    ys = list(subset['toughness'])
    zs = list(subset['cmc'])
    ax.scatter(jitter_arr(xs), jitter_arr(ys), jitter_arr(zs), c=c, marker=m, alpha=a, label=label)
ax.set_xlabel('Power')
ax.set_ylabel('Toughness')
ax.set_zlabel('CMC')
ax.legend()
plt.show()

# Fit an ordinary least-squares model cmc ~ power + toughness + flying and print its coefficients
# together with the R-squared it achieves on the data it was fitted to.
# https://scikit-learn.org/stable/modules/linear_model.html#ordinary-least-squares
feature_columns = ['power', 'toughness', 'flying']
X = full_df[feature_columns].values
y = full_df['cmc'].values
reg = LinearRegression()
reg.fit(X, y)
power_coef, toughness_coef, flying_coef = reg.coef_
print(f'Linear regression model:\ncmc = {power_coef} * power\n    + {toughness_coef} * toughness\n'
      f'    + {flying_coef} * flying\n    + {reg.intercept_}')
print('R-squared value: ', reg.score(X, y))
print()

# Plot the fitted linear model's flying (blue) and non-flying (red) planes against the scatterplotted data.
# Replot scatterplot: red circles are non-flying creatures, blue triangles are flying creatures.
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
for f, c, m, a in [(0, 'r', 'o', .2), (1, 'b', '^', .5)]:
    subset = full_df.loc[full_df['flying'] == f]
    xs = list(subset['power'])
    ys = list(subset['toughness'])
    zs = list(subset['cmc'])
    ax.scatter(jitter_arr(xs), jitter_arr(ys), jitter_arr(zs), c=c, marker=m, alpha=a)
ax.set_xlabel('Power')
ax.set_ylabel('Toughness')
ax.set_zlabel('CMC')

# Integer power/toughness grid on which the fitted model is evaluated.
x_vals = np.arange(minval, maxval + 1, 1)
y_vals = np.arange(minval, maxval + 1, 1)
Xs, Ys = np.meshgrid(x_vals, y_vals)

# Flying plane (flying = 1). The intercept belongs in both planes; without it they are drawn
# offset from the fitted model by reg.intercept_.
Zs = reg.coef_[0]*Xs + reg.coef_[1]*Ys + reg.coef_[2] + reg.intercept_
ax.plot_surface(Xs, Ys, Zs, color='b', alpha=.25, linewidth=0, antialiased=False)

# Non-flying plane (flying = 0).
Zs = reg.coef_[0]*Xs + reg.coef_[1]*Ys + reg.intercept_
ax.plot_surface(Xs, Ys, Zs, color='r', alpha=.25, linewidth=0, antialiased=False)

plt.show()  # We see against these planes that the data is nonlinear.

# Create a better nonlinear model using gradient-boosted regression trees.
# https://scikit-learn.org/stable/modules/ensemble.html#regression
# X and y come from full_df, which was shuffled when it was built, so a contiguous
# 2/3 : 1/3 slice gives a random train/test split.
train_set_size = 2 * X.shape[0] // 3
X_train, X_test = X[:train_set_size, :], X[train_set_size:, :]
y_train, y_test = y[:train_set_size], y[train_set_size:]
# `loss` is left at its default (least squares in every scikit-learn version). The old alias
# loss='ls' was deprecated in 1.0 and removed in 1.2, where passing it raises an error.
est = GradientBoostingRegressor(learning_rate=.1, n_estimators=40, verbose=0,
                                random_state=0, subsample=.7,
                                max_depth=3).fit(X_train, y_train)
print('Gradient-boosted regression tree mean squared error:', mean_squared_error(y_test, est.predict(X_test)))
# Tuning advice here:
# https://www.analyticsvidhya.com/blog/2016/02/complete-guide-parameter-tuning-gradient-boosting-gbm-python/

# Redraw the jittered scatterplot (red circles: non-flying, blue triangles: flying) and overlay the
# gradient-boosted model's predicted CMC surfaces over the power/toughness grid.
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')
for flag, colour, marker, alpha in [(0, 'r', 'o', .2), (1, 'b', '^', .5)]:
    subset = full_df.loc[full_df['flying'] == flag]
    ax.scatter(jitter_arr(list(subset['power'])),
               jitter_arr(list(subset['toughness'])),
               jitter_arr(list(subset['cmc'])),
               c=colour, marker=marker, alpha=alpha)
ax.set_xlabel('Power')
ax.set_ylabel('Toughness')
ax.set_zlabel('CMC')

# Flying surface: every grid point queried with flying = 1.
grid_shape = Xs.shape
Zs = est.predict(np.c_[Xs.ravel(), Ys.ravel(), np.ones(grid_shape).ravel()])
flying_Zs = Zs.reshape(grid_shape)
surf = ax.plot_surface(Xs, Ys, flying_Zs, color='b', alpha=.25, linewidth=0, antialiased=False)

# Non-flying surface: the same grid queried with flying = 0.
Zs = est.predict(np.c_[Xs.ravel(), Ys.ravel(), np.zeros(grid_shape).ravel()])
non_flying_Zs = Zs.reshape(grid_shape)
ax.plot_surface(Xs, Ys, non_flying_Zs, color='r', alpha=.25, linewidth=0, antialiased=False)
plt.show()

# Provide a tabular summary of the gradient-boosted regression tree flying premium
# (predicted flying CMC minus predicted non-flying CMC) at each grid point.
premium_Zs = flying_Zs - non_flying_Zs
# np.meshgrid's default 'xy' indexing puts y (toughness) on axis 0 and x (power) on axis 1, so the
# grid is transposed to put power on the rows and toughness on the columns, as the header states.
flying_premium_df = DataFrame(premium_Zs.T, columns=y_vals, index=x_vals)
flying_premium_df.index.name = 'power'
flying_premium_df.columns.name = 'toughness'
flying_premium_df = flying_premium_df.round(1)
print('Flying premium given power (row) and toughness (column):')
display.display(flying_premium_df)

# Given the negative values, it's likely that outlier predictions where data points are sparse are not meaningful.
# Therefore, we compute a sample-weighted mean of premiums: for every observed (power, toughness) pair,
# predict CMC with flying = 1 and with flying = 0, then average the differences.
num_samples = X.shape[0]
X_flying = X.copy()
X_flying[:, 2] = 1  # column 2 is the 'flying' feature
X_non_flying = X.copy()
X_non_flying[:, 2] = 0
# Two batched predict calls replace 2 * num_samples single-row calls; the result is the same.
average_flying_cost = np.mean(est.predict(X_flying) - est.predict(X_non_flying))
print('Average sample-weighted cost of flying (CMC):', round(average_flying_cost, 3))
