PK ! &, dlc_gui/__init__.pyimport sys if hasattr(sys, "_called_from_test"): print("In pytest, skipping __init__.py") else: from .gui import show # pragma: no cover __all__ = ["show"] # pragma: no cover PK ! IZ dlc_gui/__main__.pyimport argparse import dlc_gui parser = argparse.ArgumentParser() parser.add_argument("config", help="Abs path to config.yaml", nargs="?") args = parser.parse_args() config = args.config dlc_gui.show(config) PK ! T=} } dlc_gui/data.py""" This module handles the data creation and handling of dlc_gui. It creates the main pandas DataFrame, color palette, and `frames_dict`. `frames_dict` is a dictionary of frame_name (str): frame_path (path object) pairs. """ # TODO make sure each exception encloses paths with quotes from pathlib import Path from typing import List, Tuple, Union import numpy as np import pandas as pd import dlc_gui.util class DataModel: """ Create useful data structures such as the main pandas DataFrame used by DeepLabCut, the frames_dict which keeps allows translation between abs and rel paths, and the color palette. """ def __init__(self, config_path): # Initialize without a given directory of frames or a h5 file # Define attributes as empty or None, because the rest of the code # expects their existence self.config_path = config_path self.config_dict = dlc_gui.util.read_config_file(self.config_path) self.scorer = self.config_dict["scorer"] self.bodyparts = self.config_dict["bodyparts"] # Make sure the project path is valid and exists try: self.project_path = Path(self.config_dict["project_path"]).resolve() if not self.project_path.is_dir(): raise FileNotFoundError( "'project_path' ({0}) in config.yaml does not exist.".format( self.config_dict["project_path"] ) ) except TypeError as e: raise TypeError( "'project_path' in config.yaml has an invalid type of {0}".format( type(self.config_dict["project_path"]) ) ) from e self.colors, self.colors_opposite, self.colors_opaque = self.color_palette( len(self.bodyparts) ) # Define variables to be used in gui.py for QFileDialog default dirs self.labeled_data_path = self.project_path / "labeled-data" self.save_path_hdf = self.labeled_data_path / "CollectedData_{}.h5".format( self.scorer ) self.save_path_pkl = self.labeled_data_path / "CollectedData_{}.pkl".format( self.scorer ) # TODO find replacement for pd.concat to avoid defining as None self.data_frame = None self.frames_dict = {} def init_from_dir(self, dir: Union[str, Path]) -> int: # Defines self.frames_dict and self.data_frame based on a dir # Exit codes: # 0 - Success # 1 - No images found in directory # 2 - Dir not in project path # 3 - Invalid path string given (e.g. None) try: dir = Path(dir) except TypeError: return 3 if self.project_path not in dir.parents: return 2 self.frames_paths = sorted(Path(dir).glob("*.png")) if not self.frames_paths: return 1 self.frames_names = [ str(Path(frame_path).relative_to(self.project_path)) for frame_path in self.frames_paths ] self.frames_dict = dict(zip(self.frames_names, self.frames_paths)) init_nan = np.empty((len(self.frames_paths), 2)) init_nan[:] = np.nan for bodypart in self.bodyparts: index = pd.MultiIndex.from_product( [[self.scorer], [bodypart], ["x", "y"]], names=["scorer", "bodyparts", "coords"], ) frame = pd.DataFrame(init_nan, columns=index, index=self.frames_names) self.data_frame = pd.concat([self.data_frame, frame], axis=1) return 0 def init_from_file(self, file: Union[str, Path]) -> int: # Defines self.frames_dict and self.data_frame based on a h5 file # Due to the inconsistencies between ``to_csv``, ``from_csv``, # ``read_csv``, etc., ONLY '.h5' files will be accepted. # https://github.com/pandas-dev/pandas/issues/13262 # TODO Proper extension checking # Exit codes: # 0 - Success # 1 - Invalid file given # 2 - Malformed h5 file # 3 - Could not find the DataFrame index from the file try: file = Path(file) except TypeError: return 1 if not file.is_file(): return 1 if file.suffix in (".hdf", ".h5"): try: self.data_frame = pd.read_hdf(file, "df_with_missing") except KeyError: return 2 elif file.suffix in (".pkl", ".pickle"): self.data_frame = pd.read_pickle(file) try: self.frames_names = sorted(self.data_frame.index.tolist()) except AttributeError: return 3 self.frames_paths = [Path(self.project_path, _) for _ in self.frames_names] # TODO avoid copy pasted code self.frames_dict = dict(zip(self.frames_names, self.frames_paths)) return 0 def color_palette(self, number: int) -> Tuple[List, List, List]: # Create a list of QColors and their opposites equal in length to the # number of bodyparts # TODO set alpha from config hues = np.linspace(0, 1, number, endpoint=False) colors = [(h, 1, 1, 0.5) for h in hues] colors_opposite = [(abs(0.5 - h), 1, 1, 0.5) for h in hues] colors_opaque = [(h, 1, 1, 1) for h in hues] return colors, colors_opposite, colors_opaque def add_coords_to_dataframe(self, frame, bodypart, coords): if all(coord is None for coord in coords): coords = (np.nan, np.nan) try: self.data_frame.loc[frame, self.scorer][bodypart, "x"] = coords[0] self.data_frame.loc[frame, self.scorer][bodypart, "y"] = coords[1] except KeyError as e: raise KeyError( "The scorer of the config.yaml does not match this .h5 file." ) from e def get_coords_from_dataframe(self, frame, bodypart): x = self.data_frame.loc[frame, self.scorer][bodypart, "x"] y = self.data_frame.loc[frame, self.scorer][bodypart, "y"] if np.isnan(x): x = None if np.isnan(y): y = None return (x, y) def save_as_pkl(self, path): if path: pd.to_pickle(self.data_frame, path) def save_as_hdf(self, path): if path: self.data_frame.to_hdf(path, "df_with_missing", format="table", mode="w") PK ! .tU tU dlc_gui/gui.py""" This module handles all the GUI aspects of dlc_gui. This module creates a main window containing the main widget containing subset widgets. """ # TODO add feature to specify project_path import sys from typing import Union import webbrowser from PySide2.QtCore import QEvent, QRect, QRectF, Qt from PySide2.QtGui import QBrush, QColor, QCursor, QKeySequence, QPen, QPixmap from PySide2.QtWidgets import ( QAction, QApplication, QCheckBox, QDesktopWidget, QFileDialog, QGraphicsEllipseItem, QGraphicsScene, QGraphicsView, QGridLayout, QLabel, QListWidget, QMainWindow, QShortcut, QSlider, QSplitter, QToolTip, QVBoxLayout, QWidget, ) import dlc_gui.data import dlc_gui.util class GraphicsScene(QGraphicsScene): def __init__(self, parent): super(GraphicsScene, self).__init__(parent) def load_image(self, image: str) -> None: # Load frame png into scene self.frame_image = QPixmap() self.frame_image.load(image) self.addPixmap(self.frame_image) class GraphicsView(QGraphicsView): def __init__(self, parent): super(GraphicsView, self).__init__(parent) self.scene = GraphicsScene(self) self.setScene(self.scene) self.fitInView(self.scene.sceneRect(), aspectRadioMode=Qt.KeepAspectRatio) self.viewport().setCursor(Qt.CrossCursor) # keep track of the current scale value to prevent zooming too far out self.current_scale = 1.0 def zoom(self, pos: tuple, scale: float, anchor: str = "NoAnchor") -> None: if anchor == "NoAnchor": self.setTransformationAnchor(QGraphicsView.NoAnchor) self.setResizeAnchor(QGraphicsView.NoAnchor) old_pos = self.mapToScene(*pos) # Prevent zooming out beyond 0.3 or in beyond 33 if (self.current_scale > 0.3 or scale > 1) and ( self.current_scale < 33 or scale < 1 ): self.scale(scale, scale) self.current_scale *= scale new_pos = self.mapToScene(*pos) translate_delta = (new_pos - old_pos).toTuple() self.translate(*translate_delta) # Scroll wheel to zoom in and out def wheelEvent(self, event): scale_factor = 1.25 if event.delta() > 0: self.zoom(event.pos().toTuple(), scale_factor) else: self.zoom(event.pos().toTuple(), 1 / scale_factor) # Toggle on dragging when middle mouse is pressed def mousePressEvent(self, event): if ( event.button() == Qt.MiddleButton or QApplication.keyboardModifiers() == Qt.ShiftModifier ): self.__og_pos = event.pos() self.viewport().setCursor(Qt.ClosedHandCursor) else: return super(GraphicsView, self).mousePressEvent(event) # TODO find a legit way to engage the ScrollHandDrag mode # rather than simply using it to change the cursor look # Translate scene using scrollbars while middle button is held def mouseMoveEvent(self, event): if ( event.button() == Qt.MiddleButton or QApplication.keyboardModifiers() == Qt.ShiftModifier ): offset = self.__og_pos - event.pos() self.__og_pos = event.pos() self.verticalScrollBar().setValue( self.verticalScrollBar().value() + offset.y() ) self.horizontalScrollBar().setValue( self.horizontalScrollBar().value() + offset.x() ) else: super(GraphicsView, self).mouseMoveEvent(event) # Toggle off dragging when middle mouse is pressed def mouseReleaseEvent(self, event): if ( event.button() == Qt.MiddleButton or QApplication.keyboardModifiers() == Qt.ShiftModifier ): self.viewport().setCursor(Qt.CrossCursor) else: return super(GraphicsView, self).mouseReleaseEvent(event) class MainWidget(QWidget): """ Create the main user interface and controls, connected to DataModel """ def __init__(self, parent, config_path): super(MainWidget, self).__init__(parent) # Create the main widgets self.graphics_view = GraphicsView(self) self.bodyparts_view = QListWidget() self.dot_label_lines_state = QCheckBox("Show dot labels") self.dot_size_slider = QSlider(Qt.Horizontal) self.dot_size_label = QLabel(parent=self.dot_size_slider) self.frames_view = QListWidget() # Setup the data_model, get config values, and setup widgets based on data_model self.config_path = config_path data_model = dlc_gui.data.DataModel(self.config_path) self.init_from_data_model(data_model) # Setup the checkbox for dot labels visibility self.dot_label_lines_state.stateChanged.connect(lambda x: self.update_scene()) # Set up the dot_size_slider dot_size_from_config = self.data_model.config_dict["dotsize"] self.dot_size_slider.setMinimum(2) self.dot_size_slider.setMaximum(100) self.dot_size_slider.setValue(dot_size_from_config) self.dot_size_slider.setTickPosition(QSlider.TicksBothSides) self.dot_size_slider.setTickInterval(10) self.dot_size_label.setText( "Label dot size: {} (from config.yaml)".format(dot_size_from_config) ) self.dot_size_slider.valueChanged.connect( lambda: self.dot_size_label.setText( "Label dot size: {}".format(self.dot_size_slider.value()) ) ) self.dot_size_slider.valueChanged.connect(lambda: self.update_scene()) # Add a widget to add a layout containing the bodyparts and the slider labeling_widget = QWidget() labeling_layout = QVBoxLayout() labeling_layout.addWidget(self.bodyparts_view) labeling_layout.addWidget(self.dot_label_lines_state) labeling_layout.addWidget(self.dot_size_label) labeling_layout.addWidget(self.dot_size_slider) labeling_widget.setLayout(labeling_layout) # Set the main layout of Widget main_layout = QGridLayout() splitter = QSplitter() splitter.addWidget(self.frames_view) splitter.addWidget(self.graphics_view) splitter.addWidget(labeling_widget) splitter.setStretchFactor(1, 1) main_layout.addWidget(splitter) self.setLayout(main_layout) # Set up events self.frames_view.currentItemChanged.connect(lambda x: self.update_scene()) self.graphics_view.scene.installEventFilter(self) shortcut_next_bodypart = QShortcut(QKeySequence("d"), self) shortcut_prev_bodypart = QShortcut(QKeySequence("a"), self) shortcut_next_frame = QShortcut(QKeySequence("s"), self) shortcut_prev_frame = QShortcut(QKeySequence("w"), self) # setattr is used as an assignment function, because lambdas cannot assign # prior to python 3.8 shortcut_next_bodypart.activated.connect(lambda: self.switch_bodypart("next")) shortcut_prev_bodypart.activated.connect(lambda: self.switch_bodypart("prev")) shortcut_next_frame.activated.connect( lambda: setattr(self, "current_frame_row", self.current_frame_row + 1) ) shortcut_prev_frame.activated.connect( lambda: setattr(self, "current_frame_row", self.current_frame_row - 1) ) shortcut_toggle_dot_label_lines_state = QShortcut(QKeySequence("f"), self) shortcut_toggle_dot_label_lines_state.activated.connect( lambda: self.dot_label_lines_state.toggle() ) self.save_file_dialog = QFileDialog() self.save_file_dialog.setFileMode(QFileDialog.AnyFile) self.save_file_dialog.setAcceptMode(QFileDialog.AcceptSave) def init_from_data_model_from_file(self, dataframe_file) -> None: data_model = dlc_gui.data.DataModel(self.config_path) exit_code = data_model.init_from_file(dataframe_file) if exit_code == 1: self.send_status("Invalid file.", 5) elif exit_code == 2 or exit_code == 3: self.send_status( "Error: {} is not structured correctly.".format(dataframe_file), 5 ) elif exit_code == 0: self.init_from_data_model(data_model) def init_from_data_model_from_dir(self, dir: str) -> None: data_model = dlc_gui.data.DataModel(self.config_path) exit_code = data_model.init_from_dir(dir) if exit_code == 1: self.send_status("Error: No frames (*.png) found in {}".format(dir), 5) elif exit_code == 2: self.send_status( 'Error: {} is not within the project path "{}"'.format( dir, self.data_model.config_dict["project_path"] ), 5, ) elif exit_code == 0: self.init_from_data_model(data_model) def init_from_data_model(self, data_model) -> None: self.data_model = data_model self.bodyparts = self.data_model.bodyparts self.project_path = self.data_model.project_path # Populate the frames and bodyparts lists if self.data_model.frames_dict: self.frames_dict = self.data_model.frames_dict self.frames_view.clear() for frame in self.frames_dict.keys(): self.frames_view.addItem(frame) self.bodyparts_view.clear() for bodypart in self.bodyparts: self.bodyparts_view.addItem(bodypart) # Convert tuples to QColors self.data_model.colors = [ QColor.fromHsvF(*color) for color in self.data_model.colors ] self.data_model.colors_opposite = [ QColor.fromHsvF(*color) for color in self.data_model.colors_opposite ] self.data_model.colors_opaque = [ QColor.fromHsvF(*color) for color in self.data_model.colors_opaque ] # Set initial selections for both listwidgets, load the first frame, # and add color icons frames_view_items = self.frames_view.findItems("*", Qt.MatchWildcard) if frames_view_items: self.frames_view.setCurrentItem(frames_view_items[0]) self.graphics_view.scene.load_image(str(list(self.frames_dict.values())[0])) self.bodyparts_view_items = self.bodyparts_view.findItems("*", Qt.MatchWildcard) if self.bodyparts_view_items: self.bodyparts_view.setCurrentItem(self.bodyparts_view_items[0]) for item, color in zip( self.bodyparts_view_items, self.data_model.colors_opaque ): pixmap = QPixmap(100, 100) pixmap.fill(color) item.setIcon(pixmap) self.update_scene() def eventFilter(self, obj, event): """ Implement left and right mouse clicking functionalities """ # Check if the click is within the QGraphicsView if ( obj is self.graphics_view.scene and event.type() == QEvent.Type.GraphicsSceneMousePress ): scene_pos = event.scenePos() coords = (scene_pos.x(), scene_pos.y()) frame = self.current_frame_text bodypart = self.current_bodypart_text if frame and bodypart: if event.buttons() == Qt.LeftButton: self.data_model.add_coords_to_dataframe(frame, bodypart, coords) elif event.buttons() == Qt.RightButton: self.data_model.add_coords_to_dataframe( frame, bodypart, (None, None) ) self.update_scene() return super(MainWidget, self).eventFilter(obj, event) def switch_bodypart(self, prev_or_next: str) -> None: if prev_or_next == "prev": setattr(self, "current_bodypart_row", self.current_bodypart_row - 1) elif prev_or_next == "next": setattr(self, "current_bodypart_row", self.current_bodypart_row + 1) tool_tip = QToolTip() tool_tip.showText( QCursor.pos(), self.current_bodypart_text, self.graphics_view, QRect(), 300 ) @property def current_bodypart_row(self) -> int: return self.bodyparts_view.currentRow() @current_bodypart_row.setter def current_bodypart_row(self, row) -> None: self.bodyparts_view.setCurrentRow(row) @property def current_frame_row(self) -> int: return self.frames_view.currentRow() @current_frame_row.setter def current_frame_row(self, row) -> None: self.frames_view.setCurrentRow(row) @property def current_bodypart_text(self) -> Union[str, None]: try: return self.bodyparts_view.currentItem().text() except AttributeError: return None @property def current_frame_text(self) -> Union[str, None]: try: return self.frames_view.currentItem().text() except AttributeError: return None def send_status(self, msg, timeout) -> None: self.parent().status_bar.showMessage(msg, timeout * 1000) def save_as_pkl(self) -> None: self.save_file_dialog.selectFile(str(self.data_model.save_path_pkl)) self.save_file_dialog.setNameFilter("(*.pkl *.pickle)") self.save_file_dialog.setDefaultSuffix(".pkl") if self.save_file_dialog.exec(): save_path = self.save_file_dialog.selectedFiles()[0] self.data_model.save_as_pkl(save_path) def save_as_hdf(self) -> None: self.save_file_dialog.selectFile(str(self.data_model.save_path_hdf)) self.save_file_dialog.setNameFilter("(*.h5 *.hdf)") self.save_file_dialog.setDefaultSuffix(".h5") if self.save_file_dialog.exec(): save_path = self.save_file_dialog.selectedFiles()[0] self.data_model.save_as_hdf(save_path) # Updating scene is in MainWidget and not GraphicsScene because it needs to know # current frame and current bodypart, both properties of MainWidget def update_scene(self) -> None: def add_dots_to_scene( coords: tuple, size: float, brush_color: QColor, pen_color: QColor, tooltip: str, ) -> None: # Adds dots to the scene x, y = coords dot_rect = QRectF(x - size / 2, y - size / 2, size, size) dot_brush = QBrush(Qt.SolidPattern) dot_brush.setColor(brush_color) dot_pen = QPen(dot_brush, size / 40) dot_pen.setColor(pen_color) dot_ellipse = QGraphicsEllipseItem(dot_rect) dot_ellipse.setPen(dot_pen) dot_ellipse.setBrush(dot_brush) dot_ellipse.setToolTip(tooltip) self.graphics_view.scene.addItem(dot_ellipse) def add_dot_labels_to_scene( text: str, coords: tuple, size: float, fg_color: QColor, bg_color: QColor ) -> None: x, y = coords x_offset = size * 2 y_offset = size * 2 label = self.graphics_view.scene.addText(text) label.setPos( x + label.boundingRect().width() / 2 + x_offset, y - label.boundingRect().height() / 2 - y_offset, ) label.setHtml( """