Source code for skbold.utils.crossval_splitter

from __future__ import division, print_function, absolute_import
from builtins import range
import pandas as pd
import numpy as np
import scipy.stats as stat
import os.path as op

try:
    import matplotlib.pyplot as plt
except ImportError:
    print('Matplotlib not installed; cannot plot!')


class CrossvalSplitter(object):

    def __init__(self, data, train_size, vars, cb_between_splits=False,
                 binarize=None, include=None, exclude=None, interactions=True,
                 sep='\t', index_col=0, ignore=None, iterations=1000):

        if isinstance(data, str):  # path to a file; load it into a DataFrame
            data = pd.read_csv(data, sep=sep, index_col=index_col)

        data['cv_group'] = np.nan

        if include is not None:
            data = data.loc[include]

        for var in vars.keys():
            # ignore values, such as 9999
            data.loc[data[var] == ignore, var] = np.nan

        if exclude is not None:
            for key, value in exclude.items():
                data = data[data[key] != value]

        self.data = data

        if 0 < train_size < 1:  # train_size given as a fraction
            train_size = int(np.round(data.shape[0] * train_size))

        # test set is whatever remains after the train set
        test_size = data.shape[0] - train_size

        self.train_size = train_size
        self.test_size = test_size
        self.cb_between_splits = cb_between_splits
        self.vars = vars
        self.interactions = interactions
        self.exclude = exclude
        self.ignore = ignore
        self.iterations = iterations

        self.best_all_samples = None
        self.best_train_set = None
        self.best_test_set = None
        self.best_min_p_val = 0
    def split(self, verbose=False):

        full_size = self.train_size + self.test_size

        for i in range(self.iterations):

            p_this_iter = []
            data = self.data
            all_idx = data.index

            # take two random, non-overlapping samples: train and test
            full_sample = np.random.choice(all_idx, size=full_size,
                                           replace=False)
            train_idx = full_sample[:self.train_size]
            end = (self.train_size + self.test_size)
            test_idx = full_sample[self.train_size:end]

            data.loc[train_idx, 'cv_group'] = 'train'
            data.loc[test_idx, 'cv_group'] = 'test'
            data = data.loc[full_sample]  # only take the sampled data

            # make sure everything is goin' all right
            assert len(train_idx) == self.train_size
            assert len(test_idx) == self.test_size
            assert np.sum(np.in1d(train_idx, test_idx)) == 0

            ps = self._counterbalance(data.loc[train_idx])
            p_this_iter.extend(ps)

            if self.cb_between_splits:
                ps = self._counterbalance(data.loc[test_idx])
                p_this_iter.extend(ps)

            if min(p_this_iter) > self.best_min_p_val:
                self.best_min_p_val = min(p_this_iter)
                self.best_all_samples = full_sample
                self.best_train_set = train_idx
                self.best_test_set = test_idx

                if verbose:
                    print('Iteration %d, best min p-value found: %.3f...' %
                          (i, self.best_min_p_val))

        self.data = self.data.loc[self.best_all_samples]
        self.data.loc[self.best_train_set, 'cv_group'] = 'train'
        self.data.loc[self.best_test_set, 'cv_group'] = 'test'

        return self.best_train_set, self.best_test_set
    def _counterbalance(self, data):

        p_this_set = []

        for var, values in self.vars.items():

            # a list of levels marks a categorical variable; a string does not
            categorical = not isinstance(values, str)

            if categorical:
                chisq, p = self._test_categorical(data, var, values)
                p_this_set.append(p)

        if self.interactions:
            ps = self._test_categorical_interaction(data)
            p_this_set.extend(ps)

        return p_this_set

    def _test_continuous(self, s1, s2):

        t, p = stat.ttest_ind(s1, s2, nan_policy='omit')
        return t, p

    def _test_categorical_interaction(self, data):

        p_ints = []

        for i, (var, values) in enumerate(self.vars.items()):

            if i == 0:
                cvar, cvalues = var, values
            else:
                s1 = data[data[var].isin(values)][var]
                s2 = data[data[cvar].isin(cvalues)][cvar]
                sint = s1 * s2
                count = sint.value_counts()
                chisq, p = stat.chisquare(count.tolist())
                p_ints.append(p)

        return p_ints

    def _test_categorical(self, data, var, values):

        count = data[data[var].isin(values)][var].value_counts()
        chisq, p = stat.chisquare(count.tolist())
        return chisq, p
    def save(self, out_dir, save_plots=True):

        if self.best_min_p_val == 0:
            raise IOError('split not yet run, nothing to save!')

        self.data = self.data.sort_index()
        self.data.to_csv(op.join(out_dir, 'split.tsv'), sep='\t')

        if save_plots:
            self.plot_results(out_dir)
    def plot_results(self, out_dir):

        train_idx = self.best_train_set
        test_idx = self.best_test_set
        data = self.data

        for ii, (var, values) in enumerate(self.vars.items()):

            plt.figure(ii)

            if isinstance(values, list):  # categorical: pie charts of counts
                fig, (ax1, ax2) = plt.subplots(1, 2)
                labels = values
                count = data.groupby([var, 'cv_group']).size()
                count = count.unstack(level=0)[labels]
                train_vals = count.loc['train'].values
                test_vals = count.loc['test'].values

                ax1.pie(train_vals, labels=labels, autopct='%1.1f%%',
                        shadow=True, startangle=90)
                ax1.set_title('%s train group' % var)

                ax2.pie(test_vals, labels=labels, autopct='%1.1f%%',
                        shadow=True, startangle=90)
                ax2.set_title('%s test group' % var)

                fn = op.join(out_dir, var + '.png')
                fig.savefig(fn)
            else:  # continuous: overlapping histograms
                train_vals = data.loc[train_idx, var].values
                test_vals = data.loc[test_idx, var].values
                train_vals = train_vals[~np.isnan(train_vals)]
                test_vals = test_vals[~np.isnan(test_vals)]

                plt.hist(train_vals, alpha=0.5, label='Train')
                plt.hist(test_vals, alpha=0.5, label='Test')
                plt.legend(loc='upper right')
                plt.title(var)

                fn = op.join(out_dir, var + '.png')
                plt.savefig(fn)
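
A minimal usage sketch (not part of the module above): it assumes a tab-separated file 'participants.tsv' with hypothetical categorical columns 'gender' and 'group'; the file name, column names, and level codes are illustrative only, and interactions are disabled here for simplicity.

# Usage sketch with hypothetical file and column names.
from skbold.utils.crossval_splitter import CrossvalSplitter

cv = CrossvalSplitter(
    data='participants.tsv',                 # path to a TSV file or a pandas DataFrame
    train_size=0.5,                          # fraction of samples (or an absolute count)
    vars={'gender': ['M', 'F'],              # categorical variables: list of levels
          'group': ['patient', 'control']},
    ignore=9999,                             # treat 9999 as a missing value
    interactions=False,                      # skip interaction tests in this sketch
    iterations=100                           # number of random splits to try
)

# Search for the split whose smallest counterbalancing p-value is largest.
train_idx, test_idx = cv.split(verbose=True)

# Write split.tsv and per-variable plots to the current directory.
cv.save(out_dir='.', save_plots=True)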