Source code for straditize.widgets.pattern_selection
"""A widget to select patterns in the image
The :class:`PatternSelectionWidget` is used by the
:class:`straditize.widgets.selection_toolbar.SelectionToolbar` to select
patterns in the straditizer image
**Disclaimer**
Copyright (C) 2018-2019 Philipp S. Sommer
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>."""
from __future__ import division
import numpy as np
import datetime as dt
from itertools import product
from psyplot_gui.common import DockMixin
from psyplot_gui.compat.qtcompat import (
QWidget, Qt, with_qt5, QLabel, QHBoxLayout,
QVBoxLayout, QPushButton, QGridLayout)
from matplotlib.backends.backend_qt5agg import (
FigureCanvasQTAgg as FigureCanvas)
from matplotlib.widgets import RectangleSelector
from matplotlib.figure import Figure
if with_qt5:
from PyQt5.QtWidgets import (
QSizePolicy, QSlider, QGroupBox, QFormLayout, QProgressDialog)
else:
from PyQt4.QtGui import (
QSizePolicy, QSlider, QGroupBox, QFormLayout, QProgressDialog)
class EmbededMplCanvas(FigureCanvas):
    """A matplotlib figure canvas that can be embedded as a QWidget

    The canvas creates its own :class:`matplotlib.figure.Figure` and uses an
    expanding size policy in both directions."""

    def __init__(self, parent=None, *args, **kwargs):
        """
        Parameters
        ----------
        parent: QWidget
            The parent of this canvas (None by default)
        ``*args, **kwargs``
            Passed on to the :class:`matplotlib.figure.Figure` constructor
        """
        figure = Figure(*args, **kwargs)
        FigureCanvas.__init__(self, figure)
        self.setParent(parent)

        # let the canvas grow horizontally and vertically
        expanding = QSizePolicy.Expanding
        FigureCanvas.setSizePolicy(self, expanding, expanding)
        FigureCanvas.updateGeometry(self)
[docs]class PatternSelectionWidget(QWidget, DockMixin):
"""A widget to select patterns in the image
This widget consists of an :class:`EmbededMplCanvas` to display the template
for the pattern and uses the :func:`skimage.feature.match_template`
function to identify it in the :attr:`arr`
See Also
--------
straditize.widgets.selection_toolbar.SelectionToolbar.start_pattern_selection
"""
#: The template to look for in the :attr:`arr`
template = None
#: The selector to select the template in the original image
selector = None
#: The extents of the :attr:`template` in the original image
template_extents = None
#: The matplotlib artist of the :attr:`template` in the
#: :attr:`template_fig`
template_im = None
#: The :class:`EmbededMplCanvas` to display the :attr:`template`
template_fig = None
axes = None
#: A QSlider to set the threshold for the template correlation
sl_thresh = None
_corr_plot = None
key_press_cid = None
def __init__(self, arr, data_obj, remove_selection=False, *args, **kwargs):
    """
    Parameters
    ----------
    arr: np.ndarray of shape ``(Ny, Nx)``
        The labeled selection array
    data_obj: straditize.label_selection.LabelSelection
        The data object whose image shall be selected
    remove_selection: bool
        If True, remove the selection on apply
    ``*args, **kwargs``
        Forwarded to the constructor of the parent classes
    """
    super(PatternSelectionWidget, self).__init__(*args, **kwargs)
    self.arr = arr
    self.data_obj = data_obj
    self.remove_selection = remove_selection
    self.template = None

    # ------------------------------------------------------------------
    # widgets
    # ------------------------------------------------------------------

    # canvas that displays the currently selected template
    self.template_fig = EmbededMplCanvas()

    # checkable button that toggles the template selection
    select_template = QPushButton('Select a template')
    select_template.setCheckable(True)
    self.btn_select_template = select_template

    # checkable group box to allow fractions of the template; it stays
    # disabled until a template has been selected
    fraction_box = QGroupBox('Template fractions')
    fraction_box.setCheckable(True)
    fraction_box.setChecked(False)
    fraction_box.setEnabled(False)
    self.fraction_box = fraction_box

    # slider (in percent) and label for the fraction of the template
    self.sl_fraction = QSlider(Qt.Horizontal)
    self.lbl_fraction = QLabel('0.75')
    self.sl_fraction.setValue(75)

    # slider and label for the increments of the fractions
    self.sl_increments = QSlider(Qt.Horizontal)
    self.sl_increments.setValue(3)
    self.sl_increments.setMinimum(1)
    self.lbl_increments = QLabel('3')

    # button that computes the correlation
    self.btn_correlate = QPushButton('Find template')
    self.btn_correlate.setEnabled(False)

    # checkable button that shows or hides the correlation plot
    self.btn_plot_corr = QPushButton('Plot correlation')
    self.btn_plot_corr.setCheckable(True)
    self.btn_plot_corr.setEnabled(False)

    # button, slider and label for the threshold-based subselection; the
    # slider and its label are hidden until the selection is started
    self.btn_select = QPushButton('Select pattern')
    self.sl_thresh = QSlider(Qt.Horizontal)
    self.lbl_thresh = QLabel('0.5')
    self.btn_select.setCheckable(True)
    self.btn_select.setEnabled(False)
    self.sl_thresh.setValue(75)
    self.sl_thresh.setVisible(False)
    self.lbl_thresh.setVisible(False)

    # cancel and apply buttons
    self.btn_cancel = QPushButton('Cancel')
    self.btn_close = QPushButton('Apply')
    self.btn_close.setEnabled(False)

    # ------------------------------------------------------------------
    # layout
    # ------------------------------------------------------------------
    layout = QVBoxLayout()
    layout.addWidget(self.template_fig)

    template_row = QHBoxLayout()
    template_row.addStretch(0)
    template_row.addWidget(self.btn_select_template)
    layout.addLayout(template_row)

    # one grid row per slider: title, slider, current value
    grid = QGridLayout()
    grid_rows = [
        ('Fraction', self.sl_fraction, self.lbl_fraction),
        ('Increments', self.sl_increments, self.lbl_increments),
    ]
    for row, (title, slider, value_label) in enumerate(grid_rows):
        grid.addWidget(QLabel(title), row, 0)
        grid.addWidget(slider, row, 1)
        grid.addWidget(value_label, row, 2)
    self.fraction_box.setLayout(grid)
    layout.addWidget(self.fraction_box)

    for button in (self.btn_correlate, self.btn_plot_corr,
                   self.btn_select):
        layout.addWidget(button)

    thresh_row = QHBoxLayout()
    thresh_row.addWidget(self.sl_thresh)
    thresh_row.addWidget(self.lbl_thresh)
    layout.addLayout(thresh_row)

    button_row = QHBoxLayout()
    button_row.addWidget(self.btn_cancel)
    button_row.addWidget(self.btn_close)
    layout.addLayout(button_row)

    self.setLayout(layout)

    # ------------------------------------------------------------------
    # signals
    # ------------------------------------------------------------------
    def show_fraction(value):
        # the slider holds percent, the label shows the fraction
        self.lbl_fraction.setText(str(value / 100.))

    def show_increments(value):
        self.lbl_increments.setText(str(value))

    def show_threshold(value):
        # map the slider value to the displayed threshold (50 -> 0.0)
        self.lbl_thresh.setText(str((value - 50) / 50.))

    self.btn_select_template.clicked.connect(
        self.toggle_template_selection)
    self.sl_fraction.valueChanged.connect(show_fraction)
    self.sl_increments.valueChanged.connect(show_increments)
    self.btn_correlate.clicked.connect(self.start_correlation)
    self.btn_plot_corr.clicked.connect(self.toggle_correlation_plot)
    self.sl_thresh.valueChanged.connect(show_threshold)
    self.sl_thresh.valueChanged.connect(self.modify_selection)
    self.btn_select.clicked.connect(self.toggle_selection)
    self.btn_cancel.clicked.connect(self.cancel)
    self.btn_close.clicked.connect(self.remove_plugin)
def toggle_template_selection(self):
    """Enable or disable the template selection

    Depending on the state of the :attr:`btn_select_template` button and
    on whether a :attr:`selector` already exists, this method hides the
    selector, shows it again, or creates a new one. Afterwards the figure
    of the :attr:`data_obj` is redrawn and, if a :attr:`template` is
    available, the widgets for the correlation are enabled."""
    button = self.btn_select_template

    def show_selector(visible, text):
        # (de)activate the existing selector together with its artists
        self.selector.set_active(visible)
        for artist in self.selector.artists:
            artist.set_visible(visible)
        button.setText(text)

    has_selector = self.selector is not None
    if has_selector and not button.isChecked():
        show_selector(False, 'Select a template')
    elif has_selector and self.template_im is not None:
        show_selector(True, 'Apply')
    else:
        # no usable selector yet: create one on the axes of the data object
        self.selector = RectangleSelector(
            self.data_obj.ax, self.update_image, interactive=True)
        if self.template_extents is not None:
            self.selector.draw_shape(self.template_extents)
        canvas = self.data_obj.ax.figure.canvas
        self.key_press_cid = canvas.mpl_connect(
            'key_press_event', self.update_image)
        button.setText('Cancel')
    self.data_obj.draw_figure()

    template = self.template
    if template is not None:
        self.fraction_box.setEnabled(True)
        # the increments cannot exceed the smaller side of the template
        self.sl_increments.setMaximum(min(template.shape[:2]))
        self.btn_correlate.setEnabled(True)
def update_image(self, *args, **kwargs):
    """Update the template image based on the :attr:`selector` extents

    This method is registered as callback of the :attr:`selector` and of
    the ``'key_press_event'`` of the canvas; the positional and keyword
    arguments of these callbacks are accepted but not used.

    If the first artist of the selector is not visible, the
    :attr:`template` and the :attr:`template_extents` are reset to None.
    Otherwise the rounded extents of the selector are stored in
    :attr:`template_extents`, the corresponding part of the :attr:`arr`
    becomes the new :attr:`template` and is displayed in the
    :attr:`template_fig`."""
    if self.template_im is not None:
        # drop the image of the previous template
        self.template_im.remove()
        self.template_im = None
    elif self.axes is None:
        self.axes = self.template_fig.figure.add_subplot(111)
        self.template_fig.figure.subplots_adjust(bottom=0.3)

    if not self.selector.artists[0].get_visible():
        self.template_extents = None
        self.template = None
        self.btn_select_template.setText('Cancel')
    else:
        extents = np.round(self.selector.extents).astype(int)
        self.template_extents = extents
        x, y = extents.reshape((2, 2))
        extent = getattr(self.data_obj, 'extent', None)
        if extent is not None:
            # x and y are views of the stored extents: subtract on new
            # arrays (not in place) so that template_extents keeps the
            # values taken from the selector
            x = x - int(min(extent[:2]))
            y = y - int(min(extent[2:]))
        slx = slice(*sorted(x))
        sly = slice(*sorted(y))
        self.template = template = self.arr[sly, slx]
        if template.ndim == 3:
            self.template_im = self.axes.imshow(template)
        else:
            self.template_im = self.axes.imshow(template, cmap='binary')
        self.btn_select_template.setText('Apply')
    self.template_fig.draw()
def start_correlation(self):
    """Look for the correlations of template and source

    The fraction and the increments are taken from the corresponding
    sliders if the :attr:`fraction_box` is checked, otherwise the entire
    template is used. The result of :meth:`correlate_template` is stored
    as ``_correlation`` unless it is None (which is what
    :meth:`correlate_template` returns when the user cancels), and the
    buttons for plotting the correlation and selecting the pattern are
    enabled if a correlation is available."""
    if self.fraction_box.isChecked():
        self._fraction = self.sl_fraction.value() / 100.
        increments = self.sl_increments.value()
    else:
        self._fraction = 0
        increments = 1
    corr = self.correlate_template(
        self.arr, self.template, self._fraction, increments)
    if corr is not None:
        self._correlation = corr
    # _correlation does not exist yet if the very first run was cancelled
    enable = getattr(self, '_correlation', None) is not None
    self.btn_plot_corr.setEnabled(enable)
    self.btn_select.setEnabled(enable)
def toggle_selection(self):
    """Modify the selection (or not) based on the template correlation

    If the :attr:`btn_select` button is checked, the current selection of
    the :attr:`data_obj` is remembered, all labels are unselected and the
    selection is modified based on the current value of the
    :attr:`sl_thresh` slider. Otherwise the remembered selection is
    restored. In both cases the figure of the data object is redrawn."""
    data = self.data_obj
    threshold_widgets = (self.sl_thresh, self.lbl_thresh)

    if not self.btn_select.isChecked():
        # restore what has been remembered when the selection started
        if data._selection_arr is not None:
            data._selection_arr[:] = self._orig_selection_arr
            data._select_img.set_array(self._orig_selection_arr)
            data.select_labels(self._selected_labels)
            data._update_magni_img()
        del self._orig_selection_arr, self._selected_labels
        self.btn_select.setText('Select pattern')
        self.btn_close.setEnabled(False)
        for widget in threshold_widgets:
            widget.setVisible(False)
    else:
        # remember the current state to be able to reset it later
        self._orig_selection_arr = data._selection_arr.copy()
        self._selected_labels = data.selected_labels
        self._select_cmap = data._select_cmap
        self._select_norm = data._select_norm
        self.btn_select.setText('Reset')
        self.btn_close.setEnabled(True)
        data.unselect_all_labels()
        for widget in threshold_widgets:
            widget.setVisible(True)
        self.modify_selection(self.sl_thresh.value())
    data.draw_figure()
def modify_selection(self, i):
    """Modify the selection based on the correlation threshold

    Nothing happens unless the :attr:`btn_select` button is checked.

    Parameters
    ----------
    i: int
        The value of the :attr:`sl_thresh` slider. It is converted to
        the correlation threshold via ``(i - 50) / 50``"""
    if not self.btn_select.isChecked():
        return
    data = self.data_obj
    threshold = (i - 50.) / 50.
    selection = data._selection_arr
    if self.remove_selection:
        # start from the selection that was active when the pattern
        # selection started and flag the matches with -1
        selection[:] = self._orig_selection_arr
        selection[self._correlation >= threshold] = -1
    else:
        # start from the original selection array of the data object and
        # give the matches a value that is not yet in use
        selection[:] = data._orig_selection_arr
        new_value = selection.max() + 1
        selection[self._correlation >= threshold] = new_value
    data._select_img.set_array(data._selection_arr)
    data._update_magni_img()
    data.draw_figure()
def correlate_template(self, arr, template, fraction=False, increment=1,
                       report=True):
    """Correlate a template with the `arr`

    This method uses the :func:`skimage.feature.match_template` function
    to find the given `template` in the source array `arr`. Only the
    bounding box of the selected part of the :attr:`data_obj` is searched
    and everything outside of the selected part is set to 0.

    Parameters
    ----------
    arr: np.ndarray of shape ``(Ny, Nx)``
        The labeled selection array (see :attr:`arr`), the source of the
        given `template`
    template: np.ndarray of shape ``(ny, nx)``
        The template from ``arr`` that shall be searched
    fraction: float
        If not null, we will look through the given fraction of the
        template to look for partial matches as well
    increment: int
        The increment of the loop with the `fraction`.
    report: bool
        If True and `fraction` is not null, a QProgressDialog is opened
        to inform the user about the progress

    Returns
    -------
    np.ndarray or None
        The correlation in the shape of `arr`, or None if the user
        cancelled the progress dialog

    Raises
    ------
    ValueError
        If nothing is selected in the :attr:`data_obj`"""
    mask = self.data_obj.selected_part
    x = mask.any(axis=0)
    if not x.any():
        raise ValueError("No data selected!")
    # imported lazily and only after the selection has been validated
    from skimage.feature import match_template
    y = mask.any(axis=1)

    # bounding box of the selected part
    xmin = x.argmax()
    xmax = len(x) - x[::-1].argmax()
    ymin = y.argmax()
    ymax = len(y) - y[::-1].argmax()

    if arr.ndim == 3:
        mask = np.tile(mask[..., np.newaxis], (1, 1, arr.shape[-1]))
    src = np.where(mask[ymin:ymax, xmin:xmax],
                   arr[ymin:ymax, xmin:xmax], 0)

    if not fraction:
        corr = match_template(src, template)
        full_shape = np.array(corr.shape)
    else:  # loop through the template to allow partial matches
        shp = np.array(template.shape, dtype=int)[:2]
        ny = shp[0]
        fshp = np.round(fraction * shp).astype(int)
        fny, fnx = fshp
        it = list(product(
            range(0, fny, increment), range(0, fnx, increment)))
        ntot = len(it)
        full_shape = fshp - shp + src.shape
        corr = np.zeros(full_shape, dtype=float)
        if report:
            txt = 'Searching template...'
            dialog = QProgressDialog(txt, 'Cancel',
                                     0, ntot)
            dialog.setWindowModality(Qt.WindowModal)
            t0 = dt.datetime.now()
        for k, (i, j) in enumerate(it):
            if report:
                dialog.setValue(k)
                if k and not k % 10:
                    # extrapolate the remaining time from the time that
                    # the first k iterations needed
                    passed = (dt.datetime.now() - t0).total_seconds()
                    dialog.setLabelText(
                        txt + ' %1.0f seconds remaning' % (
                            (passed * (ntot / k - 1.))))
                if dialog.wasCanceled():
                    return None
            # match the template without its last i rows and its first j
            # columns and keep the element-wise maximum of all matches
            y_end, x_start = fshp - (i, j) - 1
            sly = slice(y_end, full_shape[0])
            slx = slice(0, -x_start or full_shape[1])
            corr[sly, slx] = np.maximum(
                corr[sly, slx],
                match_template(src, template[:-i or ny, j:]))

    ret = np.zeros_like(arr, dtype=corr.dtype)
    dny, dnx = src.shape - full_shape
    # paste the correlation at every offset inside the bounding box and
    # keep the element-wise maximum
    for i, j in product(range(dny + 1), range(dnx + 1)):
        ret[ymin + i:ymax - dny + i, xmin + j:xmax - dnx + j] = np.maximum(
            ret[ymin + i:ymax - dny + i, xmin + j:xmax - dnx + j], corr)
    return np.where(mask, ret, 0)
def toggle_correlation_plot(self):
    """Toggle the correlation plot between :attr:`template` and :attr:`arr`

    If no correlation plot exists, the correlation is drawn slightly
    above the selection image of the :attr:`data_obj` (with the same
    extent) and a colorbar is added. Otherwise the plot and its colorbar
    are removed. In both cases the figure is redrawn afterwards."""
    data = self.data_obj
    if self._corr_plot is not None:
        for artist in (self._corr_cbar, self._corr_plot):
            try:
                artist.remove()
            except ValueError:
                pass
        del self._corr_plot, self._corr_cbar
    else:
        selection_image = data._select_img
        self._corr_plot = data.ax.imshow(
            self._correlation,
            extent=selection_image.get_extent(),
            zorder=selection_image.zorder + 0.1)
        colorbar = data.ax.figure.colorbar(
            self._corr_plot, orientation='vertical')
        self._corr_cbar = colorbar
        colorbar.set_label('Correlation')
    data.draw_figure()
def to_dock(self, main, title=None, position=None, docktype='df', *args,
            **kwargs):
    """Create the dock for this widget via the parent implementation

    If `position` is None, the dock widget area of the dock of the
    ``help_explorer`` of `main` is used. When this widget had no dock
    before the call, the toggle view action of the new dock is connected
    to :meth:`maybe_tabify`.

    Returns
    -------
    object
        Whatever the ``to_dock`` method of the parent class returns"""
    if position is None:
        position = main.dockWidgetArea(main.help_explorer.dock)
    # connect the toggle action only once, when the dock is created
    had_no_dock = self.dock is None
    result = super(PatternSelectionWidget, self).to_dock(
        main, title, position, docktype=docktype, *args, **kwargs)
    if had_no_dock:
        toggle_action = self.dock.toggleViewAction()
        toggle_action.triggered.connect(self.maybe_tabify)
    return result
def maybe_tabify(self):
    """Tabify the dock of this widget with the dock of the help explorer

    This is only done if this widget is shown and if both docks are in
    the same dock widget area of the main window."""
    main = self.dock.parent()
    if not self.is_shown:
        return
    help_dock = main.help_explorer.dock
    if main.dockWidgetArea(help_dock) == main.dockWidgetArea(self.dock):
        main.tabifyDockWidget(help_dock, self.dock)
def cancel(self):
    """Reset the selection (if it has been modified) and remove the widget

    If the :attr:`btn_select` button is checked, it is unchecked and
    :meth:`toggle_selection` is called before :meth:`remove_plugin`."""
    select_button = self.btn_select
    if select_button.isChecked():
        select_button.setChecked(False)
        self.toggle_selection()
    self.remove_plugin()
def remove_plugin(self):
    """Clean up everything this widget created and remove it

    The :attr:`selector`, the correlation plot and the key press
    connection are removed from the figure of the :attr:`data_obj`, the
    template figure is closed, the references that this widget holds are
    deleted and finally the ``remove_plugin`` method of the parent class
    is called and the dock is removed from the ``dockwidgets`` of the
    psyplot_gui main window."""
    from psyplot_gui.main import mainwindow

    if self.selector is not None:
        self.selector.disconnect_events()
        for artist in self.selector.artists:
            try:
                artist.remove()
            except ValueError:
                pass
        self.data_obj.draw_figure()
        del self.selector
    if self._corr_plot is not None:
        self.toggle_correlation_plot()
    if self.key_press_cid is not None:
        self.data_obj.ax.figure.canvas.mpl_disconnect(self.key_press_cid)

    # best-effort clean up of the template figure: each step may fail if
    # the corresponding object does not exist (anymore)
    try:
        self.template_im.remove()
    except (AttributeError, ValueError):
        pass
    try:
        # NOTE(review): template_fig is the canvas; ``delaxes`` is
        # presumably a Figure method, in which case this always ends in
        # the AttributeError branch -- confirm
        self.template_fig.delaxes(self.axes)
    except (AttributeError, ValueError, RuntimeError):
        pass
    try:
        self.template_fig.close()
    except (AttributeError, RuntimeError):
        pass

    # release the references that are held by this widget
    for attr in ['data_obj', 'arr', 'template', 'key_press_cid',
                 '_orig_selection_arr', '_selected_labels', '_select_cmap',
                 '_select_norm', '_correlation', 'axes', 'template_fig',
                 'template_im']:
        try:
            delattr(self, attr)
        except AttributeError:
            pass

    try:
        super().remove_plugin()
    except RuntimeError:  # already removed
        pass
    if self.dock is not None:
        try:
            mainwindow.dockwidgets.remove(self.dock)
        except (ValueError, AttributeError):
            pass
        # only delete the reference while a dock is present, otherwise
        # there may be no instance attribute to delete
        del self.dock