# Source code for straditize.evaluator

"""Evaluator class for the straditize algorithms"""
import os.path as osp
import shutil
import multiprocessing as mp
import warnings
from PIL import Image
from functools import partial
import pandas as pd
import numpy as np
from collections import OrderedDict
from straditize.common import docstrings
from straditize.straditizer import Straditizer
from psy_strat.stratplot import stratplot
from itertools import filterfalse


# Register the sections of stratplot's docstring with the docstrings
# registry so they can be re-used below via ``%(stratplot.parameters)s``
docstrings.get_sectionsf('stratplot')(stratplot)


def rmse(sim, ref):
    """Calculate the root mean squared error between simulation and reference

    Parameters
    ----------
    sim: np.ndarray
        The simulated data
    ref: np.ndarray
        The reference data

    Returns
    -------
    float
        The root mean squared error between `sim` and `ref`"""
    squared_error = np.square(sim - ref)
    return np.sqrt(squared_error.mean())
# Default axis line style for the stratigraphic diagram: draw a solid
# line on all four sides of each column axes
axislinestyle = {'left': '-', 'right': '-', 'bottom': '-', 'top': '-'}
class StraditizeEvaluator:
    """An evaluator for the straditize components"""

    @docstrings.dedent
    def __init__(self, data, *args, name='data', axislinestyle=axislinestyle,
                 **kwargs):
        """
        Parameters
        ----------
        %(stratplot.parameters)s
        """
        # the reference data that is plotted and later digitized again
        self.data = data
        # name of the data set; used as the first level of the results column
        self.name = name
        # create the stratigraphic diagram that shall be digitized
        self.sp, self.groupers = stratplot(data, *args, **kwargs)
        # additional label -> value mapping, appended to the results column
        self.labels = OrderedDict()
        if axislinestyle is not None:
            self.sp.update(axislinestyle=axislinestyle)
        # whether bars are plotted instead of areas (forwarded to the
        # data reader in :meth:`init_stradi`)
        self.use_bars = kwargs.get('use_bars')
[docs] @classmethod def from_polnet(cls, data, *args, **kwargs): data = data.drop_duplicates( ['age', 'original_varname']).sort_values( ['age', 'original_varname']) index = data.age.unique() columns = data.original_varname.unique() df = pd.DataFrame( [], columns=pd.Index(columns, name='original_varname'), index=pd.Index(index, name='age')) for key, row in data.iterrows(): df.loc[row.age, row.original_varname] = row.percentage df = df[df.columns[(df > 1.0).any(axis=0)]] return cls(df, *args, calculate_percentages=False, percentages=True, **kwargs)
[docs] def export(self, filepath, dpi=300, labels={}): self.dpi = dpi self.filepath = filepath for key, val in sorted(labels.items()): self.labels[key] = val self.sp.export(filepath + '.png', dpi=dpi) # now hide the axes self.sp.update(axiscolor={'left': 'w', 'right': 'w'}) self.sp.export(filepath + '-no-axes.png', dpi=dpi) self.sp.update(axiscolor={'left': 'k', 'right': 'k'})
    # cached results DataFrame (see :attr:`results` / :attr:`all_results`)
    _results = None
    # resolution of the exported image (see :attr:`dpi`)
    _dpi = None
    # the reference data (see :attr:`data`)
    _data = None

    @property
    def data(self):
        """The reference data, restricted to the plotted arrays"""
        return self._data[[arr.name for arr in self.sp]]

    @data.setter
    def data(self, value):
        self._data = value

    @property
    def transformed_data(self):
        """The :attr:`data` in pixel coordinates"""
        columns = self.data.columns
        # pixel width of each column vs. the data range of its x-axis
        widths = pd.Series(np.diff(self.column_bounds).ravel(), columns)
        xwidths = pd.Series(
            np.diff([ax.get_xlim() for ax in self.sp.axes]).ravel(), columns)
        df = (self.data.fillna(0) * widths / xwidths).round()
        # linear transformation of the vertical (age) axis into pixels
        y_px = self.data_ylim - self.data_ylim[0]
        # y-limits of the first diagram axes, sorted ascending
        y_data = next(np.sort(ax.get_ylim()) for ax in self.sp.axes)
        diff_px = np.diff(y_px)[0]
        diff_data = np.diff(y_data)[0]
        slope = diff_px / diff_data
        intercept = y_px[0] - slope * y_data[0]
        df.index = np.round(intercept + slope * df.index).astype(int)
        return df

    @property
    def dpi(self):
        """Resolution of the exported image (set by :meth:`export`)"""
        if self._dpi is None:
            raise ValueError("The image has not yet been exported!")
        return self._dpi

    @dpi.setter
    def dpi(self, value):
        self._dpi = value

    @property
    def results(self):
        """The results column for the current evaluator configuration"""
        if self._results is None:
            names = [self.name, 'ntaxa', 'nsamples'] + list(self.labels)
            # NOTE(review): all entries alias the same empty list -- harmless
            # while empty; the positional ``labels`` argument is ``codes`` in
            # newer pandas versions -- confirm the supported pandas version
            levels = labels = [[]] * len(names)
            self._results = pd.DataFrame(
                [], columns=pd.MultiIndex(levels, labels, names=names),
                index=pd.Index([], name='metric'))
        column = self.results_column
        try:
            return self._results[column]
        except KeyError:
            # create the column on first access
            self._results[column] = np.nan
            return self._results[column]

    @property
    def results_column(self):
        """The column name in :attr:`all_results`"""
        return (self.name, len(self.sp), len(self.data)) + tuple(
            self.labels.values())

    @results.setter
    def results(self, value):
        # store the (metric -> value) mapping into the current column
        column = self.results_column
        for key, val in value.items():
            self._results.loc[key, column] = val

    @property
    def all_results(self):
        """The full results DataFrame with one column per configuration"""
        return self._results

    @property
    def width(self):
        """Width of the exported figure in pixels"""
        fig = next(iter(self.sp.figs))
        return fig.get_figwidth() * self.dpi

    @property
    def height(self):
        """Height of the exported figure in pixels"""
        fig = next(iter(self.sp.figs))
        return fig.get_figheight() * self.dpi

    @property
    def data_xlim(self):
        """Horizontal extent of the diagram part in pixel coordinates"""
        minx = min(ax.get_position().x0 for ax in self.sp.axes)
        maxx = max(ax.get_position().x1 for ax in self.sp.axes)
        width = self.width
        return np.round([minx * width, maxx * width]).astype(int)

    @property
    def summed_perc(self):
        """Sum of the upper x-limits (total percentage) over all columns"""
        summed_perc = np.sum([ax.get_xlim()[1] for ax in self.sp.axes])
        return summed_perc

    @property
    def data_ylim(self):
        """Vertical extent of the diagram part in pixel coordinates

        Note that image pixel rows count from the top, hence the
        ``height - ...`` inversion."""
        miny = min(ax.get_position().y0 for ax in self.sp.axes)
        maxy = max(ax.get_position().y1 for ax in self.sp.axes)
        height = self.height
        return height - np.round([maxy * height, miny * height]).astype(int)

    @property
    def column_starts(self):
        """Pixel offsets of the column starts, relative to the diagram part"""
        x0 = self.data_xlim[0]
        return -x0 + np.round(self.width * np.array(
            [ax.get_position().x0 for ax in self.sp.axes])).astype(int)

    @property
    def column_ends(self):
        """Pixel offsets of the column ends, relative to the diagram part"""
        x0 = self.data_xlim[0]
        return -x0 + np.round(self.width * np.array(
            [ax.get_position().x1 for ax in self.sp.axes])).astype(int)

    @property
    def column_bounds(self):
        """(ncols, 2) array of column start and end pixels"""
        return np.vstack([self.column_starts, self.column_ends]).T

    @property
    def full_df(self):
        """The reference data interpolated to every pixel row"""
        # get the full_df in pixel coordinates
        df = self.transformed_data
        height = np.round(np.diff(self.data_ylim)).astype(int)[0]
        ncols = df.shape[1]
        interpolated = np.zeros((height, ncols), dtype=int)
        # interpolate each column onto the pixel rows; outside the sampled
        # range the values are set to 0
        for i in np.arange(ncols):
            interpolated[:, i] = np.round(np.interp(
                np.arange(height), df.index.values, df.iloc[:, i].values,
                left=0, right=0))
        return pd.DataFrame(interpolated, index=pd.Index(np.arange(height)),
                            columns=np.arange(ncols)).fillna(0)
    def set_xtranslation(self, stradi):
        """Set the pixel-to-data x-translation for a straditizer

        Uses two reference pixels (0 and 1) of the first column: one pixel
        corresponds to the x-range of the first diagram axes divided by the
        pixel width of the first column.

        Parameters
        ----------
        stradi: straditize.straditizer.Straditizer
            The straditizer whose data reader shall be translated
        """
        stradi.data_reader.xaxis_px = np.array([0, 1])
        dx = np.diff(self.sp[0].psy.ax.get_xlim())
        stradi.data_reader.xaxis_data = np.r_[
            0, dx / np.diff(self.column_bounds[0])]
    def init_stradi(self, datalim=True, columns=True, names=True,
                    digitize=True, samples=True, axes=False):
        """Create a straditizer for the exported image

        Each keyword enables one initialization step; the method returns
        early after the last enabled step (the steps build on each other
        in the order datalim -> columns -> names -> digitize -> samples).

        Parameters
        ----------
        datalim: bool
            Set the data limits and initialize the reader
        columns: bool
            Set the column starts from the known plot layout
        names: bool
            Set the column names from the reference data
        digitize: bool
            Digitize the image and overwrite the measured values with the
            known reference values (only where the reader found data)
        samples: bool
            Set the sample locations from the reference data
        axes: bool
            If True, use the image including y-axes, otherwise the
            ``'-no-axes'`` variant

        Returns
        -------
        straditize.straditizer.Straditizer
            The (partially) initialized straditizer
        """
        path = self.filepath
        if not axes:
            path += '-no-axes'
        image = Image.open(path + '.png')
        stradi = Straditizer(image)
        if datalim:
            stradi.data_xlim = self.data_xlim
            stradi.data_ylim = self.data_ylim
            stradi.init_reader('area' if not self.use_bars else 'bars')
        else:
            return stradi
        # vertical axis translation: pixel rows 0..height map to the
        # (inverted) y-limits of the first diagram axes
        stradi.yaxis_px = np.array([0, np.diff(self.data_ylim)[0]])
        stradi.yaxis_data = np.array(self.sp[0].psy.ax.get_ylim())[::-1]
        if columns:
            stradi.data_reader.column_starts = self.column_starts
        else:
            return stradi
        if names:
            stradi.colnames_reader.column_names = self.data.columns.tolist()
        self.set_xtranslation(stradi)
        if digitize:
            stradi.data_reader.digitize()
            # replace the digitized values with the exact reference values
            # wherever the reader detected data
            stradi.data_reader._full_df.loc[:] = np.where(
                stradi.data_reader.full_df.values, self.full_df.values, 0)
        else:
            return stradi
        if samples:
            stradi.data_reader.sample_locs = self.transformed_data
        return stradi
    def evaluate_column_starts(self, close=True, base='starts_'):
        """Evaluate the automatic recognition of column starts

        Stores the percentage mismatch in the number of detected columns
        and, if all columns were found, the (percentage-scaled) RMSE and
        absolute error of the start positions -- once for the raw image and
        once after removing the y-axes.

        Parameters
        ----------
        close: bool
            If True, close the straditizer and return None, otherwise
            return the straditizer
        base: str
            Prefix for the metric names in :attr:`results`
        """
        stradi = self.init_stradi(columns=False, axes=True)
        starts = stradi.data_reader._get_column_starts()
        ref = self.column_starts
        results = self.results.copy()
        missing_cols = len(ref) - len(starts)
        results[base + 'missmatch'] = 100 * (len(starts) - len(ref)) / len(ref)
        results[base + 'missing'] = missing_cols
        if not missing_cols:
            diff = np.abs(starts - ref).sum()
            width = np.diff(self.data_xlim)[0]
            # scale pixel errors to percentage units of the diagram
            results[base + 'rmse'] = rmse(starts, ref) * self.summed_perc / width
            results[base + 'abs'] = diff * self.summed_perc / width
            # now remove the yaxes and evaluate the starts again
            stradi.data_reader.recognize_yaxes(remove=True)
            starts = stradi.data_reader.column_starts
            diff = np.abs(starts - ref).sum()
            results[base + 'rmse_removey'] = rmse(starts, ref) * \
                self.summed_perc / width
            results[base + 'abs_removey'] = \
                diff * self.summed_perc / width
        self.results = results
        return stradi.close() if close else stradi
    def evaluate_yaxes_removal(self, close=True):
        """Evaluate the recognition and removal of vertical axes

        Compares the pixels that :meth:`~straditize.binary.DataReader.recognize_yaxes`
        removes against the true y-axis pixels (the difference between the
        image with and without axes) and stores false positive, false
        negative and correct percentages in :attr:`results`.

        Parameters
        ----------
        close: bool
            If True, close both straditizers, otherwise return them
        """
        stradi = self.init_stradi(columns=False, axes=True)
        stradi_ref = self.init_stradi(digitize=False)
        # get the vertical axes that are the difference between the two images
        orig = stradi.data_reader.binary.copy()
        ref = (orig - stradi_ref.data_reader.binary).astype(bool)
        # remove the y-axes
        stradi.data_reader._get_column_starts()
        stradi.data_reader.recognize_yaxes(remove=True)
        sim = (orig - stradi.data_reader.binary).astype(bool)
        results = self.results
        ref_sum = ref.sum()
        # pixels that have wrongly been considered as being part of the y-axes
        false_positive = sim & (~ref)
        # pixels that have wrongly not been considered as part of the y-axes
        false_negative = (~sim) & ref
        # pixels that have been identified correctly
        correct = sim & ref
        # store the percentages relative to the number of true axis pixels
        results['yaxes_false_pos'] = 100 * false_positive.sum() / ref_sum
        results['yaxes_false_neg'] = 100 * false_negative.sum() / ref_sum
        results['yaxes_correct'] = 100 * correct.sum() / ref_sum
        self.results = results
        return (stradi.close(), stradi_ref.close()) if close else (
            stradi, stradi_ref)
    def evaluate_sample_accuracy(self, close=True, stradi=None,
                                 base='samples_'):
        """Evaluate the accuracy of the digitized values at the samples

        Digitizes the image and compares the digitized values at the
        reference sample ages with the reference data, overall and
        restricted to small (non-zero, <= 5) percentages.

        Parameters
        ----------
        close: bool
            If True, close the straditizer, otherwise return it
        stradi: straditize.straditizer.Straditizer or None
            An already initialized straditizer; if None, a new one is
            created via :meth:`init_stradi`
        base: str
            Prefix for the metric names in :attr:`results`
        """
        stradi = stradi or self.init_stradi(digitize=False)
        stradi.data_reader.digitize()
        if stradi.data_reader.exaggerated_reader is not None:
            stradi.data_reader.digitize_exaggerated()
        results = self.results
        ref = self.data.fillna(0)
        full_df = stradi.full_df
        # row of the digitized data closest to each reference sample age
        # NOTE(review): ``Index.get_loc(..., method='nearest')`` is removed
        # in pandas >= 2.0 (use ``get_indexer``) -- confirm pandas version
        indexes = list(map(partial(full_df.index.get_loc, method='nearest'),
                           ref.index))
        sim = full_df.iloc[indexes].values
        ref = ref.values
        results[base + 'rmse'] = rmse(sim, ref)
        results[base + 'too_high'] = 100 * (sim > ref).sum() / ref.size
        results[base + 'too_low'] = 100 * (sim < ref).sum() / ref.size
        # restrict to small but non-zero percentages (0 < ref <= 5)
        mask5p = ref.astype(bool) & (~np.isnan(ref)) & (ref <= 5)
        sim = sim[mask5p]
        ref = ref[mask5p]
        results[base + '5p_rmse'] = rmse(sim, ref)
        results[base + '5p_too_high'] = 100 * (sim > ref).sum() / ref.size
        results[base + '5p_too_low'] = 100 * (sim < ref).sum() / ref.size
        self.results = results
        return stradi.close() if close else stradi
    def evaluate_sample_position(self, close=True, stradi=None,
                                 base='samples_'):
        """Evaluate the automatic identification of sample locations

        Finds samples in the digitized image and compares their number and
        vertical (age) position with the reference samples.

        Parameters
        ----------
        close: bool
            If True, close the straditizer, otherwise return it
        stradi: straditize.straditizer.Straditizer or None
            An already initialized straditizer; if None, a new one is
            created via :meth:`init_stradi`
        base: str
            Prefix for the metric names in :attr:`results`
        """
        stradi = stradi or self.init_stradi(samples=False)
        stradi.data_reader.add_samples(
            *stradi.data_reader.find_samples(max_len=8))
        final = stradi.final_df
        ref = self.data
        nfound = len(final)
        nref = len(ref)
        results = self.results
        results[base + 'missmatch'] = 100 * abs(nfound - nref) / nref
        results[base + 'missing'] = nref - nfound
        # for every reference sample, the closest identified sample
        # NOTE(review): ``method='nearest'`` is removed in pandas >= 2.0 --
        # confirm the supported pandas version
        closest = list(map(
            partial(final.index.get_loc, method='nearest'), ref.index))
        age_range = ref.index.max() - ref.index.min()
        # normalized rmse of the age
        results[base + 'nrmse_y'] = rmse(
            final.index[closest].values, ref.index.values) / age_range * 100
        self.results = results
        return stradi.close() if close else stradi
[docs] def evaluate_full(self, close=True): stradi = self.evaluate_column_starts(False, 'full_starts_') self.set_xtranslation(stradi) if len(stradi.data_reader.column_starts) == len(self.column_starts): stradi = self.evaluate_sample_accuracy( False, stradi, 'full_samples_') return self.evaluate_sample_position( close, stradi, 'full_samples_')
[docs] def run(self): """Run all evaluations""" self.evaluate_column_starts() self.evaluate_yaxes_removal() self.evaluate_sample_accuracy() self.evaluate_sample_position() self.evaluate_full()
[docs] def close(self): import matplotlib.pyplot as plt self.sp.close(figs=True, data=True, ds=True) del self.sp plt.close('all')
class NoVerticalsEvaluator(StraditizeEvaluator):
    """An evaluator for an image without y-axis"""

    def export(self, *args, **kwargs):
        """Export the diagram and replace the full image by the axes-free one

        After the regular export, the ``'-no-axes'`` variant overwrites the
        full image so that all evaluations run on an image without y-axes.
        """
        super().export(*args, **kwargs)
        shutil.copyfile(self.filepath + '-no-axes.png',
                        self.filepath + '.png')

    def evaluate_column_starts(self, close=True, base='starts_'):
        """Evaluate the column starts recognition (without y-axes removal)

        Same as the base class implementation but skips the
        ``*_removey`` metrics since the image contains no vertical axes.
        """
        stradi = self.init_stradi(columns=False, axes=True)
        starts = stradi.data_reader._get_column_starts()
        ref = self.column_starts
        results = self.results.copy()
        missing_cols = len(ref) - len(starts)
        results[base + 'missmatch'] = 100 * (len(starts) - len(ref)) / len(ref)
        results[base + 'missing'] = missing_cols
        if not missing_cols:
            diff = np.abs(starts - ref).sum()
            width = np.diff(self.data_xlim)[0]
            # scale pixel errors to percentage units of the diagram
            results[base + 'rmse'] = rmse(starts, ref) * self.summed_perc / width
            results[base + 'abs'] = diff * self.summed_perc / width
        self.results = results
        return stradi.close() if close else stradi

    def evaluate_yaxes_removal(self, close=True):
        # nothing to evaluate: the image contains no vertical axes
        return
class ExaggerationsEvaluator(StraditizeEvaluator):
    """An evaluator with exaggerations"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # draw red exaggerations (factor 5, areax plot, RGBA red)
        self.sp.update(exag_factor=5, exag='areax', exag_color=[1, 0, 0, 1])

    def init_stradi(self, *args, **kwargs):
        """Create the straditizer and mark the red pixels as exaggerations

        Extends the base class method: if the reader and its columns are
        initialized, an exaggerations reader (factor 5) is created and all
        red pixels of the image are marked as exaggerated.
        """
        stradi = super().init_stradi(*args, **kwargs)
        if stradi.data_reader is not None and stradi.data_reader.columns:
            exag = stradi.data_reader.create_exaggerations_reader(5)
            im = np.asarray(stradi.data_reader.image)
            # red pixels: strong first (red) channel, weak middle channels
            # (``im[..., 1:-1]`` selects G and B -- assumes an RGBA image;
            # TODO confirm the image mode)
            mask = (im[..., 0] > 200) & (im[..., 1:-1].sum(axis=-1) < 100)
            exag.mark_as_exaggerations(mask)
        return stradi
class BaselineScenario:
    """The baseline evaluation scenario for straditize with data from POLNET

    This class uses the default settings of the :class:`StraditizeEvaluator`
    and runs the analysis for a given dataset from POLNET."""

    def __init__(self, output_dir='.'):
        # directory where the exported diagram images are stored
        self.output_dir = output_dir
        # POLNET entity ids (``e_``) whose evaluation failed
        self.failed = []
        # collected per-entity result series (filled by :meth:`run`)
        self._all_results = []
        # concatenated results DataFrame (set at the end of :meth:`run`)
        self.results = None

    def __reduce__(self):
        # custom pickling so that worker processes do not receive the
        # already collected results
        return (self.__class__, (self.output_dir, ),
                {'failed': self.failed,
                 'results': self.results,
                 '_all_results': []  # do not distribute all results
                 }
                )

    def run(self, data, processes=None):
        """Evaluate every POLNET entity in `data` in parallel

        Parameters
        ----------
        data: pandas.DataFrame
            The POLNET table with an ``e_`` (entity id) column; every
            entity is evaluated by calling this instance (see
            :meth:`__call__`) in a worker process
        processes: int or None
            Number of worker processes (None: use all available cores)
        """
        # assume failure for every entity; successful ones are removed below
        self.failed.extend(data.e_.unique())
        all_results = self._all_results
        grouped = data.groupby('e_')
        progress_args = (grouped.ngroups, 'Progress', 'Complete', 50)
        # NOTE(review): ``print_progressbar`` is not defined in the visible
        # imports of this module -- presumably provided elsewhere; verify
        print_progressbar(0, *progress_args)
        with warnings.catch_warnings():
            warnings.filterwarnings('ignore', 'Distinct samples merged from',
                                    UserWarning)
            warnings.filterwarnings('ignore', 'divide by zero encountered',
                                    RuntimeWarning)
            pool = mp.Pool(processes)
            for i, results in enumerate(pool.imap_unordered(self, grouped),
                                        1):
                # a failed evaluation returns the scalar entity id
                # (ndim 0), a successful one a Series (ndim 1)
                if np.ndim(results):
                    all_results.append(results)
                    self.failed.remove(int(results.name[0]))
                print_progressbar(i, *progress_args)
            pool.close()
            pool.join()
            pool.terminate()
        self.results = pd.concat(all_results, axis=1, sort=False).T
        self.results.index.names = self.index_names

    # names of the index levels of the final :attr:`results` DataFrame
    index_names = ['e_', 'ntaxa', 'nsamples']

    def init_evaluator(self, name, data, *args, **kwargs):
        """Initialize an evaluator for a given data set"""
        import matplotlib.pyplot as plt
        if 'fig' not in kwargs:
            kwargs['fig'] = plt.figure(figsize=(12, 6))
        return StraditizeEvaluator.from_polnet(
            data, *args, name=str(name), **kwargs)

    def export_evaluator(self, evaluator, *args, **kwargs):
        # save the diagram images into this scenario's output directory
        evaluator.export(osp.join(self.output_dir, evaluator.name),
                         *args, **kwargs)

    def __call__(self, data_tuple):
        """Evaluate one ``(entity id, group)`` tuple

        Returns the evaluator's results Series on success, or the entity
        id on failure (so :meth:`run` can keep track of failed entities).
        """
        key, group = data_tuple
        try:
            evaluator = self.init_evaluator(key, group)
        except Exception:
            return key
        else:
            try:
                self.export_evaluator(evaluator)
                evaluator.run()
            except Exception:
                evaluator.close()
                return key
            results = evaluator.results
            evaluator.close()
            return results
class DPI600Scenario(BaselineScenario):
    """Another evaluation scenario but with a resolution of 600 dpi"""

    def export_evaluator(self, *args, **kwargs):
        """Export the diagram images, always at a resolution of 600 dpi."""
        forced = dict(kwargs, dpi=600)
        return super().export_evaluator(*args, **forced)
class DPI150Scenario(BaselineScenario):
    """Another evaluation scenario but with a resolution of 150 dpi"""

    def export_evaluator(self, *args, **kwargs):
        """Export the diagram images, always at a resolution of 150 dpi."""
        forced = dict(kwargs, dpi=150)
        return super().export_evaluator(*args, **forced)
class BlackWhiteScenario(BaselineScenario):
    """An evaluation scenario with a binary (black and white) image"""

    def init_evaluator(self, *args, **kwargs):
        """Create the evaluator and switch all diagram colors to black."""
        ev = super().init_evaluator(*args, **kwargs)
        ev.sp.update(color='k')
        return ev
class NoVerticalsScenario(BaselineScenario):
    """An evaluation scenario without y-axes in the plot"""

    def init_evaluator(self, name, data, *args, **kwargs):
        """Set up a :class:`NoVerticalsEvaluator` for one data set."""
        import matplotlib.pyplot as plt
        has_figure = 'fig' in kwargs
        if not has_figure:
            kwargs['fig'] = plt.figure(figsize=(12, 6))
        return NoVerticalsEvaluator.from_polnet(
            data, *args, name=str(name), **kwargs)
class ExaggerationsScenario(BaselineScenario):
    """An evaluation scenario with an exaggerated plot of low percentages"""

    def init_evaluator(self, name, data, *args, **kwargs):
        """Set up an :class:`ExaggerationsEvaluator` for one data set."""
        import matplotlib.pyplot as plt
        has_figure = 'fig' in kwargs
        if not has_figure:
            kwargs['fig'] = plt.figure(figsize=(12, 6))
        return ExaggerationsEvaluator.from_polnet(
            data, *args, name=str(name), **kwargs)