#!/usr/bin/env python
#coding=utf-8

import numpy as np
import matplotlib.pyplot as plt

## EXERCISE 1
# Load the medal table as a record array with lower-cased column names.
# np.recfromcsv was removed in NumPy 2.0; np.genfromtxt with the same settings
# (comma delimiter, header row as field names, per-column type inference,
# lower-cased names) is the supported equivalent. encoding='utf-8' makes text
# columns such as 'country' native str instead of bytes, so comparisons with
# plain string literals (e.g. 'Germany') work under Python 3.
data = np.genfromtxt('../data/olympics.csv', delimiter=',', names=True,
                     dtype=None, case_sensitive='lower',
                     encoding='utf-8').view(np.recarray)

## EXERCISE 2
# Total medals per country in 2012. A missing entry in any medal column
# yields NaN; those countries are dropped before histogramming.
total_2012 = data['gold_2012'] + data['silver_2012'] + data['bronze_2012']
total_2012 = total_2012[~np.isnan(total_2012)]

# Histogram with matplotlib's default binning.
plt.figure()
plt.hist(total_2012)
plt.xlabel('number of medals')
plt.ylabel('number of countries')

# One unit-wide bin per medal count. The edges run up to max+1 so the largest
# total gets its own bin [max, max+1]; with edges stopping at max, the final
# bin [max-1, max] would merge the two largest counts.
plt.figure()
bins = np.arange(0, total_2012.max() + 2)
plt.hist(total_2012, bins)
plt.xlabel('number of medals')
plt.ylabel('number of countries')

# Small multiples: one medal-count histogram per Summer Olympics from 1976 to
# 2008 on a 3x3 grid. Every panel shares its y-axis with the previous one and
# uses the `bins` computed above.
plt.figure()
all_axes = []
shared = None
for panel, year in enumerate(range(1976, 2009, 4), start=1):
    label = str(year)
    medals_that_year = (data['gold_' + label] + data['silver_' + label]
                        + data['bronze_' + label])
    shared = plt.subplot(3, 3, panel, sharey=shared)
    # Hide tick labels everywhere; a single panel gets them back below.
    for tick_labels in (shared.get_xticklabels(), shared.get_yticklabels()):
        plt.setp(tick_labels, visible=False)
    plt.hist(medals_that_year, bins)
    plt.text(0.7, 0.8, label, transform=shared.transAxes)
    plt.xlim(0, bins.max())
    all_axes.append(shared)

# Only the bottom-left panel (index 6) carries tick labels and axis titles.
ax = all_axes[6]
for tick_labels in (ax.get_xticklabels(), ax.get_yticklabels()):
    plt.setp(tick_labels, visible=True)
ax.set_xlabel('N/o medals')
ax.set_ylabel('N/o countries')

plt.savefig('images/fig2.png')

## EXERCISE 3
plt.figure()

country_list = ['United States', 'Germany', 'France']

# select countries: row index in `data` of the first entry matching each name
# (np.where(...)[0][0] raises IndexError if a name is absent from the table)
idx = np.array([np.where(data['country'] == name)[0][0]
                for name in country_list])
sel = data[idx]

# Each country gets a group of three bars (gold, silver, bronze) within one
# unit of x-space. Sizing the bars from the number of medal types, not the
# number of countries, leaves one empty bar-width between groups however many
# countries are listed.
n_medal_types = 3
width = 1. / (n_medal_types + 1)
bins = np.arange(len(country_list))

# Grouped bar plot. align='edge' anchors each bar's left edge at its x value.
# That was matplotlib's default before 2.0; 2.0+ centres bars on x instead.
# With edge alignment a group spans [bins, bins + 3*width], and the tick label
# goes at the group's midpoint.
plt.bar(bins, sel['gold_2012'], width, color='gold', label='gold',
        align='edge')
plt.bar(bins + width, sel['silver_2012'], width, color='silver',
        label='silver', align='edge')
plt.bar(bins + 2 * width, sel['bronze_2012'], width, color='brown',
        label='bronze', align='edge')
plt.xticks(bins + n_medal_types * width / 2, country_list)
plt.ylabel('Number of medals')
plt.legend()
plt.savefig('images/fig3b.png')

# stacked bar plot: bronze at the bottom, silver stacked on it, gold on top
plt.figure()
bar_width = 0.8  # matplotlib's default bar width, made explicit
b1 = plt.bar(bins, sel['bronze_2012'], bar_width, bottom=0,
             color='brown', label='bronze', align='edge')
b2 = plt.bar(bins, sel['silver_2012'], bar_width,
             bottom=sel['bronze_2012'],
             color='silver', label='silver', align='edge')
b3 = plt.bar(bins, sel['gold_2012'], bar_width,
             bottom=sel['bronze_2012'] + sel['silver_2012'],
             color='gold', label='gold', align='edge')
# With align='edge' each stack spans [bins, bins + bar_width]; matplotlib 2.0+
# would otherwise centre it on bins and put the labels off-centre. The tick
# label goes in the middle of the stack.
plt.xticks(bins + bar_width / 2, country_list)
plt.ylabel('Number of medals')
plt.legend()
plt.savefig('images/fig3c.png')
  
## EXERCISE 4

plt.figure()
# Calculate total number of medals per country and per Olympics.
#  First, take every column except the country names.
cols = list(data.dtype.names)
cols.remove('country')
#  Then stack those fields into a plain (n_countries, n_columns) float array.
#  This is done field by field rather than with data[cols].view(np.float64):
#  since NumPy 1.16, multi-field indexing returns a view that keeps the
#  'country' bytes as padding, so re-viewing it as float64 breaks.
medals = np.column_stack([data[c] for c in cols]).astype(np.float64)
#  Sum consecutive triples of columns. NOTE(review): this assumes each year's
#  columns appear as gold, silver, bronze in that order — confirm against
#  the CSV header.
total_medals = medals.reshape(len(data), -1, 3).sum(2)

# Extract years to label the plot: every third column name is 'gold_YYYY'.
# A list (not a lazy map object) is required for len() under Python 3.
years = [name[5:] for name in cols[::3]]

plt.plot(total_medals[idx, :].T)
plt.xlim(0, len(years))
plt.xticks(np.arange(len(years)), years, rotation=30, size=10)
plt.legend(country_list)
plt.xlabel('Summer Olympics')
plt.ylabel('Total medals')
plt.savefig('images/fig4a.png')

# Sparkline plot: one small black line per country, stacked vertically.
plt.figure(figsize=(2.5, 8))
plt.axes([0.5, 0.02, 0.5, 0.96], frameon=False)
total_medals = np.ma.masked_invalid(total_medals)
# Rescale each country's series to [-0.4, 0.4]: subtract its minimum, divide
# by its range, then shift and shrink so neighbouring lines (one unit apart)
# do not overlap. Masked (missing) entries stay masked throughout.
spark = total_medals.copy()
spark = spark - spark.min(1)[:, None]
spark = spark / spark.max(1)[:, None]
spark = (spark - 0.5) * 0.8

# Number of countries to show (never more than the table holds, so `offset`
# and the selected rows below always have the same length).
n_countries = min(30, len(data))
offset = np.arange(1, n_countries + 1)

# Sort countries by medals in the last column (2012, assuming chronological
# column order), treating missing values as 0. Keep the top n_countries in
# ascending order, so the best country is drawn at the top.
i = np.argsort(total_medals[:, -1].filled(0))[-n_countries:]

plt.plot((spark[i, :] + offset[:, None]).T, 'k')
plt.yticks(offset, data['country'][i], size=10)
plt.xlim(0, spark.shape[1])
plt.ylim(-0.5, n_countries + 0.5)
plt.xticks([])
plt.savefig('images/fig4b.png')

## EXERCISE 5
plt.close('all')
# Read per-country statistics. np.recfromcsv was removed in NumPy 2.0;
# np.genfromtxt with the same settings is the supported equivalent, and
# encoding='utf-8' yields str (not bytes) text columns under Python 3.
# NOTE(review): rows are combined position-by-position with `data` below, so
# this file is assumed to list countries in the same order as olympics.csv —
# confirm against how country_aligned.csv is produced.
country_stats = np.genfromtxt('../data/country_aligned.csv', delimiter=',',
                              names=True, dtype=None,
                              case_sensitive='lower',
                              encoding='utf-8').view(np.recarray)

# 2012 medal totals and the candidate explanatory variables; NaN/inf entries
# become masked.
total_2012 = data['gold_2012'] + data['silver_2012'] + data['bronze_2012']
total_2012 = np.ma.masked_invalid(total_2012)
pop = np.ma.masked_invalid(country_stats['population'])
gdp = np.ma.masked_invalid(country_stats['gdp'])
ath = np.ma.masked_invalid(country_stats['athletes_2012'])
country = data['country']

# matplotlib helper that drops, from all arrays in one step, every position
# that is masked or non-finite in any of them, keeping the arrays aligned.
from matplotlib import cbook
ath, gdp, pop, total_2012, country = cbook.delete_masked_points(
    ath, gdp, pop, total_2012, country)

# 2x2 grid: 2012 medal count against population, GDP per capita and athletes.
fig = plt.figure()
fig.subplots_adjust(wspace=0.3, hspace=0.3)
plt.subplot(221)
plt.plot(pop, total_2012, '.')
plt.xlabel('population')
plt.ylabel('medals in 2012')
plt.subplot(222)
plt.plot(gdp, total_2012, '.')
plt.xticks(np.arange(0, 120000, 20000))
plt.xlabel('GDP per capita ($US)')
plt.ylabel('medals in 2012')
plt.subplot(223)
plt.plot(ath, total_2012, '.')
plt.xlabel('athletes')
plt.ylabel('medals in 2012')

plt.savefig('images/fig5a.png')

# scatter plot: medals against athletes, marker size (the `s` argument of
# scatter) proportional to GDP per capita, the largest GDP getting size 300
plt.figure()
marker_size = gdp * 1. / gdp.max() * 300.
pts = plt.scatter(ath, total_2012, marker_size)
plt.xlabel('Number of athletes')
plt.ylabel('Number of medals')
plt.savefig('images/fig5b.png')

# add country names for the ten countries with the most medals
top_ten = np.argsort(total_2012)[-10:]
labelled = zip(ath[top_ten], total_2012[top_ten], country[top_ten])
for n_athletes, n_medals, name in labelled:
    plt.annotate(name, xy=(n_athletes, n_medals))
plt.savefig('images/fig5c.png')


# interactive plot: the same GDP-sized scatter, wired to a DataCursor
# (project helper loaded from ../code) with a label template showing the
# country name, the medal count (y) and the athlete count (x).
import sys
sys.path.append('../code')
from data_cursor import DataCursor

plt.figure()
sizes = gdp * 1. / gdp.max() * 300.
pts = plt.scatter(ath, total_2012, sizes)
plt.xlabel('Number of athletes')
plt.ylabel('Number of medals')

cursor_template = "{lab}\nMedals:{y}\nAthletes:{x}"
DataCursor([pts], country, template=cursor_template)
plt.show()
