#!/usr/bin/env python
"""
In bayesloop, each new data study is handled by an instance of a ``Study``-class. In this way, all data, the inference
results and the appropriate post-processing routines are stored in one object that can be accessed conveniently or
stored in a file. Apart from the basic ``Study`` class, there exist a number of specialized classes that extend the
basic fit method, for example to infer the full distribution of hyper-parameters or to apply model selection to on-line
data streams.
"""

from __future__ import division, print_function
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from scipy.optimize import minimize
from scipy.misc import factorial
from scipy.misc import logsumexp
import sympy.abc as abc
from sympy import Symbol
from sympy import lambdify
from sympy.stats import density
import sympy.stats
from copy import copy, deepcopy
from collections import OrderedDict, Iterable
from inspect import getargspec
from tqdm import tqdm, tqdm_notebook
from .helper import *
from .preprocessing import *
from .transitionModels import CombinedTransitionModel
from .transitionModels import SerialTransitionModel
from .exceptions import *


class Study(object):
    """
    Fits with fixed hyper-parameters and hyper-parameter optimization. This class implements a
    forward-backward-algorithm for analyzing time series data using hierarchical models. For efficient computation,
    all parameter distributions are discretized on a parameter grid.
    """
    def __init__(self):
        self.observationModel = None
        self.transitionModel = None

        self.gridSize = []
        self.boundaries = []
        self.marginalGrid = []
        self.grid = []
        self.latticeConstant = []

        self.rawData = np.array([])
        self.formattedData = np.array([])
        self.rawTimestamps = None
        self.formattedTimestamps = None

        self.posteriorSequence = []
        self.posteriorMeanValues = []
        self.logEvidence = 0
        self.localEvidence = []

        self.selectedHyperParameters = []

        print('+ Created new study.')

    def loadExampleData(self):
        """
        Loads UK coal mining disaster data.
        """
        self.rawData = np.array([5, 4, 1, 0, 4, 3, 4, 0, 6, 3, 3, 4, 0, 2, 6, 3, 3, 5, 4, 5, 3, 1, 4,
                                 4, 1, 5, 5, 3, 4, 2, 5, 2, 2, 3, 4, 2, 1, 3, 2, 2, 1, 1, 1, 1, 3, 0,
                                 0, 1, 0, 1, 1, 0, 0, 3, 1, 0, 3, 2, 2, 0, 1, 1, 1, 0, 1, 0, 1, 0, 0,
                                 0, 2, 1, 0, 0, 0, 1, 1, 0, 2, 3, 3, 1, 1, 2, 1, 1, 1, 1, 2, 3, 3, 0,
                                 0, 0, 1, 4, 0, 0, 0, 1, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0])

        self.rawTimestamps = np.arange(1852, 1962)

        print('+ Successfully imported example data.')

    def loadData(self, array, timestamps=None):
        """
        Loads Numpy array as data.

        Args:
            array(ndarray): Numpy array containing time series data
            timestamps(ndarray): Array of timestamps (same length as data array)
        """
        self.rawData = array
        if timestamps is not None:  # load custom timestamps
            if len(timestamps) == len(array):
                self.rawTimestamps = np.array(timestamps)
            else:
                print('! WARNING: Number of timestamps does not match number of data points. Omitting timestamps.')
        else:  # set default timestamps (integer range)
            self.rawTimestamps = np.arange(len(self.rawData))
        print('+ Successfully imported array.')

    def setObservationModel(self, L, silent=False):
        """
        Sets observation model (likelihood function) for analysis and creates parameter grid for inference routine.

        Args:
            L: Observation model class (see observationModels.py)
            silent(bool): If set to True, no output is generated by this method.
        """
        self.observationModel = L

        # prepare parameter grid
        self.marginalGrid = []
        self.gridSize = []
        self.boundaries = []
        self.latticeConstant = []
        for v, n in zip(self.observationModel.parameterValues, self.observationModel.parameterNames):
            if v is None:  # if user has not specified parameter values, we try to estimate them
                try:
                    v = self.observationModel.estimateParameterValues(n, self.rawData)
                    print('+ Estimated parameter interval for "{}": [{}, {}] ({} values).'
                          .format(n, v[0], v[-1], len(v)))
                except:
                    raise ConfigurationError('Could not estimate parameter values for "{}".'.format(n))

            v = np.array(v, dtype=np.float)  # inference algorithm needs floats

            self.marginalGrid.append(v)
            self.gridSize.append(len(v))
            self.boundaries.append([v[0], v[-1]])

            # check if parameter values are equally spaced
            if np.any(np.abs(np.diff(np.diff(v))) > 10 ** -10):
                print('! WARNING: Supplied parameter values for "{}" are not equally spaced. Assuming categorical '
                      'parameter.'.format(n))
                self.latticeConstant.append(1.)
            else:  # equally spaced (regular grid)
                self.latticeConstant.append(np.abs(v[0] - v[1]))

        # create grid
        self.grid = [m for m in np.meshgrid(*self.marginalGrid, indexing='ij')]

        # if observation model is updated, transition model must know the new lattice constant
        if self.transitionModel is not None:
            self.transitionModel.latticeConstant = self.latticeConstant

        if not silent:
            print('+ Observation model: {}. Parameter(s): {}'.format(L, L.parameterNames))

    def _computePrior(self, silent=False):
        """
        Computes discrete prior probabilities (densities) for the parameters of the observation model. The custom prior
        distribution may be passed as a Numpy array that has tha same shape as the parameter grid, as a(lambda)
        function or as a (list of) SymPy random variable(s).

        Args:
            silent(bool): If set to True, no output is generated by this method.

        Returns:
            ndarray: Prior probability (density) values with the same size as the parameter grid
        """
        prior = self.observationModel.prior

        # check whether correctly shaped numpy array is provided
        if isinstance(prior, np.ndarray):
            if np.all(prior.shape == self.grid[0].shape):
                if not silent:
                    print('    + Set prior (numpy array).')
                    return prior
            else:
                raise ConfigurationError('Prior array does not match parameter grid size.')

        # check whether function is provided
        if hasattr(prior, '__call__'):
            if not silent:
                print('    + Set prior (function): {}'.format(prior.__name__))
            return prior(*self.grid)

        # check whether single random variable is provided
        if type(prior) is sympy.stats.rv.RandomSymbol:
            prior = [prior]

        # check if list/tuple is provided
        if isinstance(prior, (list, tuple)) and not isinstance(prior, str):
            if len(prior) != len(self.observationModel.parameterNames):
                raise ConfigurationError('Observation model contains {} parameters, but {} priors were provided.'
                                         .format(len(self.observationModel.parameterNames), len(prior)))

            pdf = 1
            x = [abc.x]*len(prior)
            for i, rv in enumerate(prior):
                if type(rv) is not sympy.stats.rv.RandomSymbol:
                    raise ConfigurationError('Only lambda functions or SymPy random variables can be used as a prior.')
                if len(list(rv._sorted_args[0].distribution.free_symbols)) > 0:
                    raise ConfigurationError('Prior distribution must not contain free parameters.')

                # multiply total pdf with density for current parameter
                pdf = pdf*density(rv)(x[i])

            # set density as lambda function
            if not silent:
                print('    + Set prior (sympy): {}'.format(pdf))
            return lambdify(x, pdf, modules=['numpy', {'factorial': factorial}])(*self.grid)

    def setTransitionModel(self, T, silent=False):
        """
        Set transition model which describes the parameter dynamics.

        Args:
            T: Transition model class (see transitionModels.py)
            silent(bool): If true, no output is printed by this method
        """
        # check if model is a break-point and raise error if so
        if str(T) == 'Break-point':
            raise ConfigurationError('The "BreakPoint" transition model can only be used with the '
                                     '"SerialTransitionModel" class.')

        self.transitionModel = T
        self.transitionModel.study = self
        self.transitionModel.latticeConstant = self.latticeConstant
        if not silent:
            print('+ Transition model: {}. Hyper-Parameter(s): {}'
                  .format(T, self._unpackAllHyperParameters(values=False)))

    def fit(self, forwardOnly=False, evidenceOnly=False, silent=False):
        """
        Computes the sequence of posterior distributions and evidence for each time step. Evidence is also computed for
        the complete data set.

        The results are stored in the attributes posteriorSequence, posteriorMeanValues, logEvidence and
        localEvidence. If a distribution contains only zeros during the forward or backward pass, a warning is
        printed, logEvidence is set to -inf and the method returns early.

        Args:
            forwardOnly(bool): If set to True, the fitting process is terminated after the forward pass. The resulting
                posterior distributions are so-called "filtering distributions" which - at each time step -
                only incorporate the information of past data points. This option thus emulates an online
                analysis.
            evidenceOnly(bool): If set to True, only forward pass is run and evidence is calculated. In contrast to the
                forwardOnly option, no posterior mean values are computed and no posterior distributions are stored.
            silent(bool): If set to True, no output is generated by the fitting method.
        """
        self._checkConsistency()

        if not silent:
            print('+ Started new fit:')

        self.formattedData = movingWindow(self.rawData, self.observationModel.segmentLength)
        self.formattedTimestamps = self.rawTimestamps[self.observationModel.segmentLength-1:]
        if not silent:
            print('    + Formatted data.')

        # volume of a single grid cell; turns sums over the parameter grid into integrals
        cellVolume = np.prod(self.latticeConstant)

        # initialize array for posterior distributions
        if not evidenceOnly:
            self.posteriorSequence = np.empty([len(self.formattedData)]+self.gridSize)

        # initialize array for computed evidence (marginal likelihood)
        self.logEvidence = 0
        self.localEvidence = np.empty(len(self.formattedData))

        # set prior distribution for forward-pass
        if self.observationModel.prior is not None:
            alpha = self._computePrior(silent=silent)
        else:
            alpha = np.ones(self.gridSize)  # flat prior

        # normalize prior (necessary in case an improper prior is used); the first division is not done in place, so
        # that a user-supplied prior array is never modified and may also have an integer data type
        alpha = alpha / np.sum(alpha)
        alpha /= cellVolume

        # forward pass
        for i in np.arange(0, len(self.formattedData)):

            # compute likelihood
            likelihood = self.observationModel.processedPdf(self.grid, self.formattedData[i])

            # update alpha based on likelihood
            alpha *= likelihood

            # normalization constant of alpha is used to compute evidence
            norm = np.sum(alpha)
            self.localEvidence[i] = norm*cellVolume  # integration yields evidence, not only sum

            if not norm > 0.:
                # if all probability values are zero, normalization is not possible
                print('    ! WARNING: Forward pass distribution contains only zeros, check parameter boundaries!')
                print('      Stopping inference process. Setting model evidence to zero.')
                self.logEvidence = -np.inf
                return

            # the logarithm is only taken after the check above, i.e. for a strictly positive norm
            self.logEvidence += np.log(norm)

            # normalize alpha (for numerical stability)
            alpha /= norm

            # alphas are stored as preliminary posterior distributions
            if not evidenceOnly:
                self.posteriorSequence[i] = alpha

            # compute alpha for next iteration
            alpha = self.transitionModel.computeForwardPrior(alpha, self.formattedTimestamps[i])

        self.logEvidence += np.log(cellVolume)  # integration yields evidence, not only sum
        if not silent:
            print('    + Finished forward pass.')
            print('    + Log10-evidence: {:.5f}'.format(self.logEvidence / np.log(10)))

        if not (forwardOnly or evidenceOnly):
            # set prior distribution for backward-pass
            if self.observationModel.prior is not None:
                beta = self._computePrior(silent=True)
            else:
                beta = np.ones(self.gridSize)  # flat prior

            # normalize prior (necessary in case an improper prior is used); not in place, see forward pass
            beta = beta / np.sum(beta)

            # backward pass
            for i in np.arange(0, len(self.formattedData))[::-1]:
                # posterior ~ alpha*beta
                self.posteriorSequence[i] *= beta  # alpha*beta

                # normalize posterior wrt the parameters
                norm = np.sum(self.posteriorSequence[i])
                if norm > 0.:
                    self.posteriorSequence[i] /= norm
                else:
                    # if all posterior probabilities are zero, normalization is not possible
                    print('    ! WARNING: Posterior distribution contains only zeros, check parameter boundaries!')
                    print('      Stopping inference process. Setting model evidence to zero.')
                    self.logEvidence = -np.inf
                    return

                # re-compute likelihood
                likelihood = self.observationModel.processedPdf(self.grid, self.formattedData[i])

                # compute local evidence
                try:
                    self.localEvidence[i] = 1./(np.sum(self.posteriorSequence[i]/likelihood) *
                                                cellVolume)  # integration, not only sum
                except (ZeroDivisionError, FloatingPointError):
                    # only raised if floating point errors are configured to raise (see numpy.seterr); with the
                    # NumPy default settings, a division by zero yields inf/nan values and a warning instead
                    self.localEvidence[i] = np.nan

                # compute beta for next iteration
                beta = self.transitionModel.computeBackwardPrior(beta*likelihood, self.formattedTimestamps[i])

                # normalize beta (for numerical stability)
                beta /= np.sum(beta)

            if not silent:
                print('    + Finished backward pass.')

        # posterior mean values do not need to be computed for evidence
        if evidenceOnly:
            self.posteriorMeanValues = []
        else:
            self.posteriorMeanValues = np.empty([len(self.grid), len(self.posteriorSequence)])

            for i in range(len(self.grid)):
                self.posteriorMeanValues[i] = np.array([np.sum(p*self.grid[i]) for p in self.posteriorSequence])

            if not silent:
                print('    + Computed mean parameter values.')

    def optimize(self, parameterList=[], **kwargs):
        """
        Uses the COBYLA minimization algorithm from SciPy to perform a maximization of the log-evidence with respect
        to all hyper-parameters (the parameters of the transition model) of a time series model. The starting values
        are the values set by the user when defining the transition model.

        For the optimization, only the log-evidence is computed and no parameter distributions are stored. When a local
        maximum is found, the parameter distribution is computed based on the optimal values for the hyper-parameters.

        Args:
            parameterList(list): List of hyper-parameter names to optimize. For nested transition models with multiple,
                identical hyper-parameter names, the sub-model index can be provided. By default, all hyper-parameters
                are optimized.
            **kwargs - All other keyword parameters are passed to the 'minimize' routine of scipy.optimize.

        Raises:
            ConfigurationError: If no hyper-parameter to optimize is found.
        """
        # set list of parameters to optimize; a new list is created in any case, so that the study neither aliases
        # the caller's list nor the (shared) default argument
        if isinstance(parameterList, str):  # in case only a single parameter name is provided as a string
            self.selectedHyperParameters = [parameterList]
        else:
            self.selectedHyperParameters = list(parameterList or [])

        try:
            print('+ Starting optimization...')
            self._checkConsistency()

            if self.selectedHyperParameters:
                print('  --> Parameter(s) to optimize:', self.selectedHyperParameters)
            else:
                print('  --> All model parameters are optimized (except change/break-points).')

                # load all hyper-parameter names (but remove break- and change-points)
                allHyperParameters = list(flatten(self._unpackHyperParameters(self.transitionModel)))
                points = list(flatten(self._unpackChangepointNames(self.transitionModel))) + \
                    list(flatten(self._unpackBreakpointNames(self.transitionModel)))
                self.selectedHyperParameters = [x for x in allHyperParameters if x not in points]

            # create parameter list to set start values for optimization
            x0 = self._unpackSelectedHyperParameters()

            # check if valid parameter names were entered
            if len(x0) == 0:
                raise ConfigurationError('No parameters to optimize. Check parameter names.')

            # perform optimization (maximization of log-evidence)
            result = minimize(self._optimizationStep, x0, method='COBYLA', **kwargs)

            print('+ Finished optimization.')

            # set optimal hyperparameters in transition model
            self._setSelectedHyperParameters(result.x)

            # run analysis with optimal parameter values
            self.fit()
        finally:
            # reset list of parameters to optimize, so that unpacking and setting hyper-parameters works as expected;
            # this is also done if the optimization fails with an exception
            self.selectedHyperParameters = []

    def _optimizationStep(self, x):
        """
        Objective function for scipy.optimize.minimize: sets the selected hyper-parameters to the supplied values,
        runs an evidence-only fit and reports the resulting log-evidence.

        Args:
            x(list): unpacked list of current hyper-parameter values

        Returns:
            float: negative log-evidence (minimizing it maximizes the evidence)
        """
        # apply candidate values to the transition model, then evaluate the evidence without any output
        self._setSelectedHyperParameters(x)
        self.fit(evidenceOnly=True, silent=True)

        logEvidence = self.logEvidence
        print('    + Log10-evidence: {:.5f}'.format(logEvidence / np.log(10)), '- Parameter values:', x)

        return -logEvidence

    def _unpackHyperParameters(self, transitionModel, values=False):
        """
        Returns list of all hyper-parameters (names or values), nested as the transition model.

        Args:
            transitionModel: An instance of a transition model
            values: By default, parameter names are returned; if set to True, parameter values are returned

        Returns:
            list: hyper-parameters (names or values)
        """
        paramList = []
        # recursion step for sub-models
        if hasattr(transitionModel, 'models'):
            for m in transitionModel.models:
                paramList.append(self._unpackHyperParameters(m, values=values))

        # extend hyper-parameter based on current (sub-)model
        if hasattr(transitionModel, 'hyperParameterNames'):
            if values:
                paramList.extend(transitionModel.hyperParameterValues)
            else:
                paramList.extend(transitionModel.hyperParameterNames)

        return paramList

    def _unpackAllHyperParameters(self, values=True):
        """
        Returns a flattened list of all hyper-parameters of the current transition model.

        Args:
            values(bool): If True (default), hyper-parameter values are returned, otherwise hyper-parameter names.

        Returns:
            list: all hyper-parameter values (or names) of the current transition model
        """
        nested = self._unpackHyperParameters(self.transitionModel, values=values)
        return list(flatten(nested))

    def _unpackSelectedHyperParameters(self):
        """
        Collects the values of the hyper-parameters listed (by name) in the attribute selectedHyperParameters. The
        hyper-parameters of a transition model can be split between several sub-models (using
        CombinedTransitionModel or SerialTransitionModel); the result is a single list of values that can be fed to
        the optimization step routine. If no hyper-parameters are selected, all of them are returned.

        Returns:
            list: currently selected hyper-parameter values

        Raises:
            ConfigurationError: If a selected name is not found among the hyper-parameter names.
        """
        selection = self.selectedHyperParameters

        # empty selection: fall back to all hyper-parameters
        if not selection:
            return self._unpackAllHyperParameters()

        nameTree = self._unpackHyperParameters(self.transitionModel)
        valueTree = self._unpackHyperParameters(self.transitionModel, values=True)

        selectedValues = []
        for name in selection:
            path = recursiveIndex(nameTree, name)  # first hit in the tree of names
            if len(path) == 0:
                raise ConfigurationError('Could not find any hyper-parameter named {}.'.format(name))

            # follow the same index path through the tree of values
            node = valueTree
            for step in path:
                node = node[step]
            selectedValues.append(node)

            # blank the name that was just used, so that a repeated name resolves to its next occurrence
            assignNestedItem(nameTree, path, ' ')

        return selectedValues

    def _setAllHyperParameters(self, x):
        """
        Assigns new values to all hyper-parameters of the current transition model. The values are consumed in the
        order of the flattened list of hyper-parameter names.

        Args:
            x(list): list of values (e.g. from _unpackSelectedHyperParameters)
        """
        newValues = list(x[:])  # work on a copy of the supplied values

        nameTree = self._unpackHyperParameters(self.transitionModel)
        flatNames = list(flatten(self._unpackHyperParameters(self.transitionModel)))

        for position, name in enumerate(flatNames):
            path = recursiveIndex(nameTree, name)

            # descend to the sub-model that owns the hyper-parameter (last path entry is the index within the model)
            target = self.transitionModel
            for step in path[:-1]:
                target = target.models[step]

            value = newValues[position]
            target.hyperParameterValues[target.hyperParameterNames.index(name)] = value

            # blank the name that was just used, so that a repeated name resolves to its next occurrence
            assignNestedItem(nameTree, path, ' ')

    def _setSelectedHyperParameters(self, x):
        """
        Assigns new values to the hyper-parameters listed (by name) in the attribute selectedHyperParameters. The
        hyper-parameters of a transition model can be split between several sub-models (using
        CombinedTransitionModel or SerialTransitionModel); this method finds the owning sub-model for each selected
        name and sets the corresponding value. If no hyper-parameters are selected, all of them are set.

        Args:
            x(list): list of values (e.g. from _unpackSelectedHyperParameters)

        Returns:
            int: 1 after all values have been set

        Raises:
            ConfigurationError: If a selected name is not found among the hyper-parameter names.
        """
        selection = self.selectedHyperParameters

        # empty selection: set all hyper-parameters
        if not selection:
            self._setAllHyperParameters(x)
            return 1

        newValues = list(x[:])  # work on a copy of the supplied values
        nameTree = self._unpackHyperParameters(self.transitionModel)

        for position, name in enumerate(selection):
            path = recursiveIndex(nameTree, name)  # first hit in the tree of names
            if len(path) == 0:
                raise ConfigurationError('Could not find any hyper-parameter named {}.'.format(name))

            # descend to the sub-model that owns the hyper-parameter (last path entry is the index within the model)
            target = self.transitionModel
            for step in path[:-1]:
                target = target.models[step]

            value = newValues[position]
            target.hyperParameterValues[target.hyperParameterNames.index(name)] = value

            # blank the name that was just used, so that a repeated name resolves to its next occurrence
            assignNestedItem(nameTree, path, ' ')

        return 1

    def _unpackChangepointNames(self, transitionModel):
        """
        Returns list of all hyper-parameter names that are associated with change-points, nested like the transition
        model.

        Returns:
            list: all hyper-parameter names that are associated with change-points
        """
        paramList = []
        # recursion step for sub-models
        if hasattr(transitionModel, 'models'):
            for m in transitionModel.models:
                paramList.append(self._unpackChangepointNames(m))

        # extend hyper-parameter based on current (sub-)model
        if hasattr(transitionModel, 'hyperParameterNames'):
            if str(transitionModel) == 'Change-point':
                paramList.extend(transitionModel.hyperParameterNames)

        return paramList

    def _unpackBreakpointNames(self, transitionModel):
        """
        Returns list of all hyper-parameter names that are associated with break-points, nested like the transition
        model.

        Returns:
            list: all hyper-parameter names that are associated with break-points
        """
        paramList = []
        # recursion step for sub-models
        if hasattr(transitionModel, 'models'):
            for m in transitionModel.models:
                paramList.append(self._unpackBreakpointNames(m))

        # extend hyper-parameter based on current (sub-)model
        if hasattr(transitionModel, 'hyperParameterNames'):
            if str(transitionModel) == 'Serial transition model':
                paramList.extend(transitionModel.hyperParameterNames)

        return paramList

    def _getHyperParameterIndex(self, transitionModel, name):
        """
        Looks up the position of a hyper-parameter in the flattened list of hyper-parameter names of a transition
        model. If the name occurs more than once, the position of the first occurrence is returned.

        Args:
            transitionModel: transition model instance in which to search
            name(str): Name of a hyper-parameter.

        Returns:
            int: index of the hyper-parameter

        Raises:
            PostProcessingError: If no hyper-parameter with the given name exists.
        """
        flatNames = list(flatten(self._unpackHyperParameters(transitionModel, values=False)))

        if name not in flatNames:
            raise PostProcessingError('Could not find any hyper-parameter with name: {}.'.format(name))

        return flatNames.index(name)

    def getHyperParameterValue(self, name):
        """
        Looks up the value that is currently set for a hyper-parameter of the transition model.

        Args:
            name(str): Hyper-parameter name.

        Returns:
            float: current value of the specified hyper-parameter.
        """
        allValues = self._unpackAllHyperParameters(values=True)
        position = self._getHyperParameterIndex(self.transitionModel, name)
        return allValues[position]

    def getParameterMeanValues(self, name):
        """
        Returns posterior mean values for a parameter of the observation model.

        Args:
            name(str): Name of the parameter to display

        Returns:
            ndarray: array of posterior mean values for the selected parameter
        """
        # get parameter index
        paramIndex = -1
        for i, n in enumerate(self.observationModel.parameterNames):
            if n == name:
                paramIndex = i

        # check if match was found
        if paramIndex == -1:
            raise PostProcessingError('Wrong parameter name. Available options: {0}'
                                      .format(self.observationModel.parameterNames))

        return self.posteriorMeanValues[paramIndex]

    def getParameterDistribution(self, t, name, plot=False, **kwargs):
        """
        Compute the marginal parameter distribution at a given time step.

        Args:
            t: Time step/stamp for which the parameter distribution is evaluated
            name(str): Name of the parameter to display
            plot(bool): If True, a plot of the distribution is created
            **kwargs: All further keyword-arguments are passed to the plot (see matplotlib documentation)

        Returns:
            ndarray, ndarray: The first array contains the parameter values, the second one the corresponding
                probability (density) values
        """
        if self.posteriorSequence == []:
            raise PostProcessingError('Cannot plot posterior sequence as it has not yet been computed. '
                                      'Run complete fit.')

        # check if supplied time stamp exists
        if t not in self.formattedTimestamps:
            raise PostProcessingError('Supplied time ({}) does not exist in data or is out of range.'.format(t))
        timeIndex = list(self.formattedTimestamps).index(t)  # to select corresponding posterior distribution

        # get parameter index
        paramIndex = -1
        for i, n in enumerate(self.observationModel.parameterNames):
            if n == name:
                paramIndex = i

        # check if match was found
        if paramIndex == -1:
            raise PostProcessingError('Wrong parameter name. Available options: {0}'
                                      .format(self.observationModel.parameterNames))

        axesToMarginalize = list(range(len(self.observationModel.parameterNames)))
        try:
            axesToMarginalize.remove(paramIndex)
        except ValueError:
            raise PostProcessingError('Wrong parameter index. Available indices: {}'.format(axesToMarginalize))

        x = self.marginalGrid[paramIndex]
        marginalDistribution = np.squeeze(np.apply_over_axes(np.sum, self.posteriorSequence[timeIndex],
                                                             axesToMarginalize))

        if plot:
            plt.fill_between(x, 0, marginalDistribution, **kwargs)

            plt.xlabel(self.observationModel.parameterNames[paramIndex])

            # in case an integer step size for hyper-parameter values is chosen, probability is displayed
            # (probability density otherwise)
            if self.latticeConstant[paramIndex] == 1.:
                plt.ylabel('probability')
            else:
                plt.ylabel('probability density')

        return x, marginalDistribution

    def getParameterDistributions(self, name, plot=False, **kwargs):
        """
        Computes the time series of marginal posterior distributions with respect to a given model parameter.

        Args:
            name(str): Name of the parameter to display
            plot(bool): If True, a plot of the series of distributions is created (density map)
            **kwargs: Keyword-arguments for the plot. Only 'c' or 'color' is evaluated: it sets the color from which
                the colormap of the density map is created (default: 'b').

        Returns:
            ndarray, ndarray: The first array contains the parameter values, the second one the sequence of
            corresponding posterior distributions.

        Raises:
            PostProcessingError: If no posterior sequence has been computed yet or if the parameter name is unknown.
        """
        # use len() instead of "== []": the posterior sequence may be an ndarray, for which a comparison to an empty
        # list is an element-wise operation that fails or is ambiguous (depending on the NumPy version)
        if len(self.posteriorSequence) == 0:
            raise PostProcessingError('Cannot plot posterior sequence as it has not yet been computed. '
                                      'Run complete fit.')

        # the spacing check needs at least two time stamps
        dt = self.formattedTimestamps[1:] - self.formattedTimestamps[:-1]
        if len(dt) > 0 and not np.all(dt == dt[0]):
            print('! WARNING: Time stamps are not equally spaced. This may result in false plotting of parameter '
                  'distributions.')

        # get parameter index (if a name occurs more than once, the last occurrence is used)
        matches = [i for i, n in enumerate(self.observationModel.parameterNames) if n == name]
        if not matches:
            raise PostProcessingError('Wrong parameter name. Available options: {0}'
                                      .format(self.observationModel.parameterNames))
        paramIndex = matches[-1]

        # axis 0 is time, parameter i lives on axis i+1; sum over all parameter axes except the requested one
        nParameters = len(self.observationModel.parameterNames)
        axesToMarginalize = [axis for axis in range(1, nParameters + 1) if axis != paramIndex + 1]

        x = self.marginalGrid[paramIndex]
        marginalPosteriorSequence = np.squeeze(np.apply_over_axes(np.sum, self.posteriorSequence, axesToMarginalize))

        if plot:
            if 'c' in kwargs:
                cmap = createColormap(kwargs['c'])
            elif 'color' in kwargs:
                cmap = createColormap(kwargs['color'])
            else:
                cmap = createColormap('b')

            # origin has to be 'upper' or 'lower' in current matplotlib versions; 'lower' puts the first grid value
            # at the bottom of the image
            plt.imshow(marginalPosteriorSequence.T,
                       origin='lower',
                       cmap=cmap,
                       extent=[self.formattedTimestamps[0], self.formattedTimestamps[-1]] +
                              list(self.boundaries[paramIndex]),
                       aspect='auto')

        return x, marginalPosteriorSequence

    def plotParameterEvolution(self, name, color='b', gamma=0.5, **kwargs):
        """
        Extended plot method to display a series of marginal posterior distributions corresponding to a single model
        parameter. In contrast to getParameterDistributions(), this method includes the removal of plotting
        artefacts, gamma correction as well as an overlay of the posterior mean values.

        Args:
            name(str): name of the parameter to display
            color: color from which a light colormap is created
            gamma(float): exponent for gamma correction of the displayed marginal distribution; default: 0.5
            kwargs: all further keyword-arguments are passed to the plot of the posterior mean values

        Raises:
            PostProcessingError: If no posterior sequence has been computed yet or if the parameter name is unknown.
        """
        # use len() instead of "== []": the posterior sequence may be an ndarray, for which a comparison to an empty
        # list is an element-wise operation that fails or is ambiguous (depending on the NumPy version)
        if len(self.posteriorSequence) == 0:
            raise PostProcessingError('Cannot plot posterior sequence as it has not yet been computed. '
                                      'Run complete fit.')

        # the spacing check needs at least two time stamps
        dt = self.formattedTimestamps[1:] - self.formattedTimestamps[:-1]
        if len(dt) > 0 and not np.all(dt == dt[0]):
            print('! WARNING: Time stamps are not equally spaced. This may result in false plotting of parameter '
                  'distributions.')

        # get parameter index (if a name occurs more than once, the last occurrence is used)
        matches = [i for i, n in enumerate(self.observationModel.parameterNames) if n == name]
        if not matches:
            raise PostProcessingError('Wrong parameter name. Available options: {0}'
                                      .format(self.observationModel.parameterNames))
        paramIndex = matches[-1]

        # axis 0 is time, parameter i lives on axis i+1; sum over all parameter axes except the requested one
        nParameters = len(self.observationModel.parameterNames)
        axesToMarginalize = [axis for axis in range(1, nParameters + 1) if axis != paramIndex + 1]
        marginalPosteriorSequence = np.squeeze(np.apply_over_axes(np.sum, self.posteriorSequence, axesToMarginalize))

        # clean up very small probability values, as they may create image artefacts; this is done on a copy, because
        # the squeezed array may be a view of self.posteriorSequence (nothing is summed for a single-parameter model)
        # and plotting must not alter the stored fit results
        marginalPosteriorSequence = marginalPosteriorSequence.copy()
        pmax = np.amax(marginalPosteriorSequence)
        marginalPosteriorSequence[marginalPosteriorSequence < pmax*(10**-20)] = 0

        # origin has to be 'upper' or 'lower' in current matplotlib versions; 'lower' puts the first grid value at
        # the bottom of the image
        plt.imshow(marginalPosteriorSequence.T**gamma,
                   origin='lower',
                   cmap=createColormap(color),
                   extent=[self.formattedTimestamps[0], self.formattedTimestamps[-1]] +
                          list(self.boundaries[paramIndex]),
                   aspect='auto')

        # set default color of plot to black
        if ('c' not in kwargs) and ('color' not in kwargs):
            kwargs['c'] = 'k'

        # set default linewidth to 1.5
        if ('lw' not in kwargs) and ('linewidth' not in kwargs):
            kwargs['lw'] = 1.5

        plt.plot(self.formattedTimestamps, self.posteriorMeanValues[paramIndex], **kwargs)

        plt.ylim(self.boundaries[paramIndex])
        plt.ylabel(self.observationModel.parameterNames[paramIndex])
        plt.xlabel('time step')

    def _checkConsistency(self):
        """
        This method is called at the very beginning of analysis methods to ensure that all necessary elements of the
        model are set correctly. If problem with user input is detected, an exception will be raised.
        """
        if len(self.rawData) == 0:
            raise ConfigurationError('No data loaded.')
        if not self.observationModel:
            raise ConfigurationError('No observation model chosen.')
        if not self.transitionModel:
            raise ConfigurationError('No transition model chosen.')

        # check for duplicate hyper-parameter names
        flatNames = self._unpackAllHyperParameters(values=False)
        u, i = np.unique(flatNames, return_inverse=True)
        duplicates = u[np.bincount(i) > 1]
        if len(duplicates) > 0:
            raise ConfigurationError('Detected duplicate hyper-parameter names: {}.'.format(duplicates))


class HyperStudy(Study):
    """
    Infers hyper-parameter distributions. This class serves as an extension to the basic Study class and allows the
    distribution of the hyper-parameters of a given transition model to be computed. For further information, see
    the documentation of the fit-method of this class.
    """
    def __init__(self):
        """
        Runs the initialization of the parent class and adds the attributes that are specific to a hyper-study. All
        of them start out empty (or as None) and are filled by other methods of this class.
        """
        super(HyperStudy, self).__init__()

        # description of the hyper-parameter grid
        self.hyperGrid, self.hyperGridValues, self.hyperGridConstant = [], [], []

        # flat lists of hyper-parameters and their names, hyper-priors and the hyper-prior values
        self.flatHyperParameters, self.flatHyperParameterNames = [], []
        self.flatHyperPriors, self.flatHyperPriorValues = [], []

        # results of a fit
        self.hyperParameterDistribution = None
        self.averagePosteriorSequence = None
        self.logEvidenceList, self.localEvidenceList = [], []

        print('  --> Hyper-study')

    def _createHyperGrid(self, silent=False):
        """
        Creates an array of hyper-parameter values that are fitted. Also determines grid constants for proper
        normalisation and computes prior probability (density) values

        Args:
            silent(bool): If true, no output is produced by this method
        """
        # extract flat list of hyper-parameter names and values
        self.flatHyperParameters = self._unpackAllHyperParameters()
        self.flatHyperParameterNames = self._unpackAllHyperParameters(values=False)
        self.flatHyperPriors = self._unpackAllHyperPriors()

        # look if change/break-points have value 'all' and assign array of all time-stamps
        for i, v in enumerate(self.flatHyperParameters):
            if isinstance(v, str) and v == 'all':
                self.flatHyperParameters[i] = self.formattedTimestamps[:-1]

        # create hyper-parameter grid
        temp = np.meshgrid(*self.flatHyperParameters, indexing='ij')
        if len(self.flatHyperParameterNames) > 0:
            self.hyperGridValues = np.array([t.ravel() for t in temp]).T
        else:
            self.hyperGridValues = np.array([])

        # find lattice constants for equally spaced hyper-parameter values
        self.hyperGridConstant = []
        for values in self.flatHyperParameters:
            if isinstance(values, Iterable) and len(values) > 1:
                a = np.array(values)
                d = a[1:] - a[:-1]
                dd = d[1:] - d[:-1]
                if np.all(np.abs(dd) < 10 ** -10):  # for equally spaced values, set difference as grid-constant
                    self.hyperGridConstant.append(np.abs(d[0]))
                else:  # for irregularly spaced values (e.g. categorical), set grid-constant to 1
                    self.hyperGridConstant.append(1)
            else:  # for single value, set grid-constant to 1
                self.hyperGridConstant.append(1)
        self.hyperGridConstant = np.array(self.hyperGridConstant)

        # evaluate hyper-prior values
        priorValuesList = []
        priorNamesList = []
        for prior, values, gridConst, name in zip(self.flatHyperPriors, self.flatHyperParameters,
                                                  self.hyperGridConstant, self.flatHyperParameterNames):
            if prior is None:
                priorValues = np.ones_like(values)
                priorNamesList.append('flat')
            elif hasattr(prior, '__call__'):
                try:
                    priorValues = [prior(value) for value in values]
                except:
                    raise ConfigurationError('Failed to set hyper-prior for "{}" from function "{}".'
                                             .format(name, prior.__name__))
                priorNamesList.append(prior.__name__)
            elif isinstance(prior, Iterable):
                if len(prior) != len(values):
                    raise ConfigurationError('Failed to set hyper-prior for "{}" from list/array.'.format(name))
                priorValues = prior
                priorNamesList.append('list/array')
            else:  # SymPy RV
                if len(list(prior._sorted_args[0].distribution.free_symbols)) > 0:
                    raise ConfigurationError('Hyper-prior for "{}" must not contain free parameters.'.format(name))

                # get symbolic representation of probability density
                x = abc.x
                symDensity = density(prior)(x)

                # get density as lambda function
                pdf = lambdify([x], symDensity, modules=['numpy', {'factorial': factorial}])

                # evaluate density
                priorValues = pdf(values)
                priorNamesList.append(str(symDensity))

            priorValuesList.append(priorValues)

        # create hyper-prior grid
        if len(self.flatHyperParameterNames) > 0:
            temp = np.meshgrid(*priorValuesList, indexing='ij')
            self.flatHyperPriorValues = np.array([t.ravel() for t in temp]).T
            self.flatHyperPriorValues = np.prod(self.flatHyperPriorValues, axis=1)  # multiply probs for all hyper-parameters
            self.flatHyperPriorValues = self.flatHyperPriorValues/np.sum(self.flatHyperPriorValues)  # renormalization
            if not silent:
                print('+ Set hyper-prior(s): {}'.format(priorNamesList))
        else:
            # we need a dummy value for transition models without hyper-parameters
            self.flatHyperPriorValues = np.array([1])

    def fit(self, forwardOnly=False, evidenceOnly=False, silent=False, nJobs=1, referenceLogEvidence=None,
            customHyperGrid=False):
        """
        This method over-rides the according method of the Study-class. It runs the algorithm for equally spaced hyper-
        parameter values as defined by the variable 'hyperGrid'. The posterior sequence represents the average
        model of all analyses. Posterior mean values are computed from this average model.

        Args:
            forwardOnly(bool): If set to True, the fitting process is terminated after the forward pass. The resulting
                posterior distributions are so-called "filtering distributions" which - at each time step -
                only incorporate the information of past data points. This option thus emulates an online
                analysis.
            evidenceOnly(bool): If set to True, only forward pass is run and evidence is calculated. In contrast to the
                forwardOnly option, no posterior mean values are computed and no posterior distributions are stored.
            silent(bool): If set to true, reduced output is created by this method.
            nJobs(int): Number of processes to employ. Multiprocessing is based on the 'pathos' module.
            referenceLogEvidence(float): Reference value to increase numerical stability when computing average
                posterior sequence. Ideally, this value represents the mean value of all log-evidence values. As an
                approximation, the default behavior of a single-process fit sets it to the first finite log-evidence
                value that is encountered; a multi-process fit uses the log-evidence of the first set of
                hyper-parameter values.
            customHyperGrid(bool): If set to true, the method "_createHyperGrid" is not called before starting the fit.
                This is used by the class "ChangepointStudy", which employs a custom version of "_createHyperGrid".

        Raises:
            ConfigurationError: If data, observation model or transition model are missing, or if hyper-parameter
                names are not unique (see "_checkConsistency").
            ImportError: If nJobs > 1 and the optional dependency 'pathos' is not installed.
        """
        print('+ Started new fit.')

        # validate the configuration before anything else, so that missing data or models are reported by a
        # ConfigurationError instead of an arbitrary error raised by the steps below
        self._checkConsistency()

        # create hyper-parameter grid
        if not customHyperGrid:
            self._createHyperGrid()

        self.formattedData = movingWindow(self.rawData, self.observationModel.segmentLength)

        if not evidenceOnly:
            self.averagePosteriorSequence = np.zeros([len(self.formattedData)]+self.gridSize)

        self.logEvidenceList = []
        self.localEvidenceList = []

        print('    + {} analyses to run.'.format(len(self.hyperGridValues)))

        if nJobs > 1:
            # check if multiprocessing is available
            try:
                from pathos.multiprocessing import ProcessPool
            except ImportError:
                raise ImportError('No module named pathos.multiprocessing. This module represents an optional '
                                  'dependency of bayesloop and is therefore not installed alongside bayesloop.')

            # compute reference log-evidence value for numerical stability when computing average posterior sequence
            if referenceLogEvidence is None:
                self._setSelectedHyperParameters(self.hyperGridValues[0])
                Study.fit(self, forwardOnly=forwardOnly, evidenceOnly=evidenceOnly, silent=True)
                referenceLogEvidence = self.logEvidence

            print('    + Creating {} processes.'.format(nJobs))
            pool = ProcessPool(nodes=nJobs)

            # use parallelFit method to create copies of this HyperStudy instance with only partial hyper-grid values
            subStudies = pool.map(self._parallelFit,
                                  range(nJobs),
                                  [nJobs]*nJobs,
                                  [forwardOnly]*nJobs,
                                  [evidenceOnly]*nJobs,
                                  [silent]*nJobs,
                                  [referenceLogEvidence]*nJobs)

            # prevent memory pile-up in main process
            pool.close()
            pool.join()
            pool.terminate()
            pool.restart()

            # merge all sub-studies (they are returned in the order of the hyper-grid chunks)
            for S in subStudies:
                self.logEvidenceList += S.logEvidenceList
                self.localEvidenceList += S.localEvidenceList
                if not evidenceOnly:
                    self.averagePosteriorSequence += S.averagePosteriorSequence
        # single process fit
        else:
            # show progressbar if silent=False
            if not silent:
                # first assume jupyter notebook and try to use tqdm-widget, if it fails, use normal tqdm-progressbar
                try:
                    enum = tqdm_notebook(enumerate(self.hyperGridValues), total=len(self.hyperGridValues))
                except Exception:  # do not intercept KeyboardInterrupt/SystemExit
                    enum = tqdm(enumerate(self.hyperGridValues), total=len(self.hyperGridValues))
            else:
                enum = enumerate(self.hyperGridValues)

            for i, hyperParamValues in enum:
                self._setSelectedHyperParameters(hyperParamValues)

                # call fit method from parent class
                Study.fit(self, forwardOnly=forwardOnly, evidenceOnly=evidenceOnly, silent=True)

                self.logEvidenceList.append(self.logEvidence)
                self.localEvidenceList.append(self.localEvidence)

                # analyses without finite evidence do not contribute to the average model
                if not np.isfinite(self.logEvidence):
                    continue

                # use the first finite log-evidence as reference value for numerical stability; a reference of
                # -inf or NaN would turn every weight (and thereby the average posterior sequence) into inf/NaN
                if referenceLogEvidence is None:
                    referenceLogEvidence = self.logEvidence

                if not evidenceOnly:
                    # note: averagePosteriorSequence has no proper normalization
                    self.averagePosteriorSequence += self.posteriorSequence *\
                                                     np.exp(self.logEvidence - referenceLogEvidence) *\
                                                     self.flatHyperPriorValues[i]

            # remove progressbar correctly
            if not silent:
                enum.close()

        if not evidenceOnly:
            # compute average posterior distribution
            normalization = np.array([np.sum(posterior) for posterior in self.averagePosteriorSequence])
            for i in range(len(self.grid)):
                normalization = normalization[:, None]  # add axis; needs to match averagePosteriorSequence
            self.averagePosteriorSequence /= normalization

            # set self.posteriorSequence to average posterior sequence for plotting reasons
            self.posteriorSequence = self.averagePosteriorSequence

            if not silent:
                print('    + Computed average posterior sequence')

        # unnormalized log-posterior of the hyper-parameter values: log-evidence plus log-prior
        logHyperParameterDistribution = np.array(self.logEvidenceList) + np.log(self.flatHyperPriorValues)

        # compute log-evidence of average model
        self.logEvidence = logsumexp(logHyperParameterDistribution)
        print('    + Log10-evidence of average model: {:.5f}'.format(self.logEvidence / np.log(10)))

        # compute hyper-parameter distribution; the largest finite log-value is subtracted before exponentiating, so
        # that np.exp cannot overflow (the subtracted constant cancels in the normalization below)
        finiteLogValues = logHyperParameterDistribution[np.isfinite(logHyperParameterDistribution)]
        scale = finiteLogValues.max() if finiteLogValues.size > 0 else 0.
        self.hyperParameterDistribution = np.exp(logHyperParameterDistribution - scale)
        self.hyperParameterDistribution /= np.sum(self.hyperParameterDistribution)
        self.hyperParameterDistribution /= np.prod(self.hyperGridConstant)  # probability density

        if not silent:
            print('    + Computed hyper-parameter distribution')

        # compute local evidence of average model
        self.localEvidence = np.sum((np.array(self.localEvidenceList).T*self.flatHyperPriorValues).T, axis=0)

        if not silent:
            print('    + Computed local evidence of average model')

        # compute posterior mean values
        if not evidenceOnly:
            self.posteriorMeanValues = np.empty([len(self.grid), len(self.posteriorSequence)])
            for i in range(len(self.grid)):
                self.posteriorMeanValues[i] = np.array([np.sum(p*self.grid[i]) for p in self.posteriorSequence])

            if not silent:
                print('    + Computed mean parameter values.')

        # clear localEvidenceList (to keep file size small for stored studies)
        self.localEvidenceList = []

        print('+ Finished fit.')

    def _parallelFit(self, idx, nJobs, forwardOnly, evidenceOnly, silent, referenceLogEvidence):
        """
        This method is called by the fit method of the HyperStudy class. It creates a copy of the current class
        instance and performs a fit based on a subset of the specified hyper-parameter grid. The method thus allows
        to distribute a HyperStudy fit among multiple processes for multiprocessing.

        Args:
            idx(int): Index from 0 to (nJobs-1), indicating which part of the hyper-grid values are to be analyzed.
            nJobs(int): Number of processes to employ. Multiprocessing is based on the 'pathos' module.
            forwardOnly(bool): If set to True, the fitting process is terminated after the forward pass. The resulting
                posterior distributions are so-called "filtering distributions" which - at each time step -
                only incorporate the information of past data points. This option thus emulates an online
                analysis.
            evidenceOnly(bool): If set to True, only forward pass is run and evidence is calculated. In contrast to the
                forwardOnly option, no posterior mean values are computed and no posterior distributions are stored.
            silent(bool): If set to True, no output is generated by the fitting method.
            referenceLogEvidence(float): Reference value to increase numerical stability when computing average
                posterior sequence. Ideally, this value represents the mean value of all log-evidence values.

        Returns:
            HyperStudy instance: Copy of this instance that holds the log-evidence values, the local evidence values
            and (unless evidenceOnly is set) the weighted, unnormalized sum of posterior sequences of its part of the
            hyper-grid.
        """
        S = copy(self)
        S.hyperGridValues = np.array_split(S.hyperGridValues, nJobs)[idx]
        S.flatHyperPriorValues = np.array_split(S.flatHyperPriorValues, nJobs)[idx]

        # copy() is shallow: without containers of its own, the sub-study would append to (and sum into) the very
        # same list/array objects as the instance it was copied from
        S.logEvidenceList = []
        S.localEvidenceList = []
        if not evidenceOnly:
            S.averagePosteriorSequence = np.zeros([len(S.formattedData)] + S.gridSize)

        # show progressbar for last process if silent=False
        showProgress = (not silent) and idx == nJobs-1
        if showProgress:
            # first assume jupyter notebook and try to use tqdm-widget, if it fails, use normal tqdm-progressbar
            try:
                enum = tqdm_notebook(enumerate(S.hyperGridValues), total=len(S.hyperGridValues))
            except Exception:  # do not intercept KeyboardInterrupt/SystemExit
                enum = tqdm(enumerate(S.hyperGridValues), total=len(S.hyperGridValues))
        else:
            enum = enumerate(S.hyperGridValues)

        for i, hyperParamValues in enum:
            S._setSelectedHyperParameters(hyperParamValues)

            # call fit method from parent class
            Study.fit(S, forwardOnly=forwardOnly, evidenceOnly=evidenceOnly, silent=True)

            S.logEvidenceList.append(S.logEvidence)
            S.localEvidenceList.append(S.localEvidence)

            # analyses without finite evidence do not contribute to the average model
            if (not evidenceOnly) and np.isfinite(S.logEvidence):
                S.averagePosteriorSequence += S.posteriorSequence *\
                                              np.exp(S.logEvidence - referenceLogEvidence) * \
                                              S.flatHyperPriorValues[i]

        # remove progressbar correctly
        if showProgress:
            enum.close()

        return S

    # optimization methods are inherited from Study class, but cannot be used in this case
    def optimize(self, *args, **kwargs):
        """
        Disables the optimization method of the parent class. All arguments are ignored.

        Raises:
            NotImplementedError: Always.
        """
        message = 'HyperStudy object has no optimizing method.'
        raise NotImplementedError(message)

    def _unpackHyperPriors(self, transitionModel):
        """
        Returns list of all hyper-priors, nested as the transition model.

        Args:
            transitionModel: An instance of a transition model

        Returns:
            list: hyper-priors
        """
        priorList = []
        # recursion step for sub-models
        if hasattr(transitionModel, 'models'):
            for m in transitionModel.models:
                priorList.append(self._unpackHyperPriors(m))

        # append prior
        if hasattr(transitionModel, 'prior'):
            # only take prior if transition model has at least one hyper-parameter or is a break-point
            # otherwise, the number of hyper-priors does not match the number of hyper-parameters
            if (hasattr(transitionModel, 'hyperParameterNames') and len(transitionModel.hyperParameterNames) > 0) or\
                    (str(transitionModel) == 'Break-point'):
                priorList.append(transitionModel.prior)
        return priorList

    def _unpackAllHyperPriors(self):
        """
        Collects the hyper-priors of the current transition model and removes the nesting.

        Returns:
            list: all hyper-priors of the current transition model (flat)
        """
        nestedPriors = self._unpackHyperPriors(self.transitionModel)
        return list(flatten(nestedPriors))

    def getHyperParameterDistribution(self, name, plot=False, **kwargs):
        """
        Marginalizes the hyper-parameter distribution of a HyperStudy fit with respect to a single hyper-parameter.

        Args:
            name(str): Name of the hyper-parameter to display
            plot(bool): If True, a bar chart of the distribution is created
            **kwargs: All further keyword-arguments are passed to the bar-plot (see matplotlib documentation)

        Returns:
            ndarray, ndarray: The first array contains the hyper-parameter values, the second one the
                corresponding probability (density) values
        """
        paramIndex = self._getHyperParameterIndex(self.transitionModel, name)

        # all hyper-parameter axes but the requested one are summed out
        axesToMarginalize = list(range(len(self.flatHyperParameterNames)))
        axesToMarginalize.remove(paramIndex)

        # give the flat distribution one axis per hyper-parameter (single values occupy an axis of length one)
        hyperGridSteps = [len(values) if isinstance(values, Iterable) else 1 for values in self.flatHyperParameters]
        gridDistribution = self.hyperParameterDistribution.reshape(hyperGridSteps, order='C')
        marginalDistribution = np.squeeze(np.apply_over_axes(np.sum, gridDistribution, axesToMarginalize))

        # the marginal distribution is an integral, not a plain sum: scale by the grid constants of the removed axes
        marginalDistribution *= np.prod([self.hyperGridConstant[axis] for axis in axesToMarginalize])

        x = self.flatHyperParameters[paramIndex]
        if not plot:
            return x, marginalDistribution

        if np.any(np.abs(np.diff(np.diff(x))) > 10 ** -10):
            # irregular spacing: treat the values as categories at integer positions
            positions = np.arange(len(x))
            plt.bar(positions, marginalDistribution, align='center', width=1., **kwargs)
            plt.xticks(positions, x)
            plt.ylabel('probability')
        else:
            # regular spacing: a grid constant of one means probabilities, anything else probability densities
            gridConstant = self.hyperGridConstant[paramIndex]
            plt.bar(x, marginalDistribution, align='center', width=gridConstant, **kwargs)
            plt.ylabel('probability' if gridConstant == 1. else 'probability density')

        plt.xlabel(self.flatHyperParameterNames[paramIndex])

        return x, marginalDistribution

    def getJointHyperParameterDistribution(self, names, plot=False, figure=None, subplot=111, **kwargs):
        """
        Marginalizes the hyper-parameter distribution of a HyperStudy fit with respect to two hyper-parameters and
        optionally displays the joint distribution as a 3D bar chart. Note that the 3D plot can only be included in an
        existing plot by passing a figure object and subplot specification.

        Args:
            names(list): List of two hyper-parameter names to display
            plot(bool): If True, a 3D-bar chart of the distribution is created
            figure: In case the plot is supposed to be part of an existing figure, it can be passed to the method. By
                default, a new figure is created.
            subplot: Characterization of subplot alignment, as in matplotlib. Default: 111
            **kwargs: all further keyword-arguments are passed to the bar3d-plot (see matplotlib documentation)

        Returns:
            ndarray, ndarray, ndarray: The first and second array contains the hyper-parameter values, the
                third one the corresponding probability (density) values

        Raises:
            PostProcessingError: If 'names' does not consist of exactly two entries.
        """
        # exactly two hyper-parameter names are required
        if not isinstance(names, Iterable) or len(names) != 2:
            raise PostProcessingError('A list of exactly two hyper-parameters has to be provided.')

        paramIndices = [self._getHyperParameterIndex(self.transitionModel, n) for n in names]

        # the indices have to be in ascending order, so that axes are labeled correctly
        if paramIndices[0] >= paramIndices[1]:
            print('! WARNING: Switching hyper-parameter order for plotting.')
            paramIndices.reverse()
        firstIndex, secondIndex = paramIndices

        # all hyper-parameter axes but the two requested ones are summed out
        axesToMarginalize = list(range(len(self.flatHyperParameterNames)))
        for index in paramIndices:
            axesToMarginalize.remove(index)

        # give the flat distribution one axis per hyper-parameter (single values occupy an axis of length one)
        hyperGridSteps = [len(values) if isinstance(values, Iterable) else 1 for values in self.flatHyperParameters]
        gridDistribution = self.hyperParameterDistribution.reshape(hyperGridSteps, order='C')
        marginalDistribution = np.squeeze(np.apply_over_axes(np.sum, gridDistribution, axesToMarginalize))

        # the marginal distribution is an integral, not a plain sum: scale by the grid constants of the removed axes
        marginalDistribution *= np.prod([self.hyperGridConstant[axis] for axis in axesToMarginalize])

        x = self.flatHyperParameters[firstIndex]
        y = self.flatHyperParameters[secondIndex]

        # irregularly spaced (categorical) values are placed at integer positions
        xCategorical = np.any(np.abs(np.diff(np.diff(x))) > 10 ** -10)
        yCategorical = np.any(np.abs(np.diff(np.diff(y))) > 10 ** -10)
        x2 = np.tile(np.arange(len(x)) if xCategorical else x, (len(y), 1)).T
        y2 = np.tile(np.arange(len(y)) if yCategorical else y, (len(x), 1))

        if not plot:
            return x, y, marginalDistribution

        # allow to add plot to predefined figure
        fig = plt.figure() if figure is None else figure
        ax = fig.add_subplot(subplot, projection='3d')

        # each bar is centered on its grid point and as wide as the grid constant of the respective axis
        xWidth = self.hyperGridConstant[firstIndex]
        yWidth = self.hyperGridConstant[secondIndex]
        heights = marginalDistribution.flatten()
        ax.bar3d(x2.flatten() - xWidth/2.,
                 y2.flatten() - yWidth/2.,
                 heights*0.,
                 xWidth,
                 yWidth,
                 heights,
                 zsort='max',
                 **kwargs
                 )

        # label categorical hyper-parameter values
        if xCategorical:
            plt.xticks(np.arange(len(x)), x)
        if yCategorical:
            plt.yticks(np.arange(len(y)), y)

        ax.set_xlabel(self.flatHyperParameterNames[firstIndex])
        ax.set_ylabel(self.flatHyperParameterNames[secondIndex])

        # in case an integer step size for hyper-parameter values is chosen, probability is displayed
        # (probability density otherwise)
        ax.set_zlabel('probability' if xWidth*yWidth == 1. else 'probability density')

        return x, y, marginalDistribution


class ChangepointStudy(HyperStudy):
    """
    Infers change-points and structural breaks. This class builds on the HyperStudy-class and the change-point
    transition model to perform a series of analyses with varying change point times. It subsequently computes the
    average model from all possible change points and creates a probability distribution of change point times. It
    supports any number of change-points and arbitrarily combined models.
    """
    def __init__(self):
        super(ChangepointStudy, self).__init__()

        # store all possible combinations of change-points (even the ones that are assigned a probability of zero),
        # to reconstruct change-point distribution after analysis
        self.allHyperGridValues = []
        self.allHyperPriorValues = []
        self.mask = []  # mask to select valid change-point combinations

        self.userDefinedGrid = False  # needed to ensure that user-defined hyper-grid is not overwritten by fit-method
        self.hyperGridBackup = []  # needed to reconstruct hyperGrid attribute in the case of break-point model
        print('  --> Change-point analysis')

    def fit(self, forwardOnly=False, evidenceOnly=False, silent=False, nJobs=1):
        """
        This method over-rides the corresponding method of the HyperStudy-class. It runs the algorithm for all possible
        combinations of change-points (and possibly scans a range of values for other hyper-parameters). The posterior
        sequence represents the average model of all analyses. Posterior mean values are computed from this average
        model.

        Args:
            forwardOnly(bool): If set to True, the fitting process is terminated after the forward pass. The resulting
                posterior distributions are so-called "filtering distributions" which - at each time step -
                only incorporate the information of past data points. This option thus emulates an online
                analysis.
            evidenceOnly(bool): If set to True, only forward pass is run and evidence is calculated. In contrast to the
                forwardOnly option, no posterior mean values are computed and no posterior distributions are stored.
            silent(bool): If set to True, reduced output is generated by the fitting method.
            nJobs(int): Number of processes to employ. Multiprocessing is based on the 'pathos' module.

        Raises:
            NotImplementedError: If more than one serial transition model is found, or if change-points and
                break-points are mixed in the transition model.
            ConfigurationError: If the transition model contains neither change-points nor break-points.
        """
        # format data/timestamps once, so number of data segments is known
        self.formattedData = movingWindow(self.rawData, self.observationModel.segmentLength)
        self.formattedTimestamps = self.rawTimestamps[self.observationModel.segmentLength - 1:]

        # nested serial transition models are not supported, as the correct order is not determined correctly
        if len(list(flatten(self._unpackSerialTransitionModels(self.transitionModel)))) > 1:
            raise NotImplementedError('Multiple instances of SerialTransition models are currently not supported by '
                                      'ChangepointStudy.')

        # determine names of change/break-points
        changepoints = list(flatten(self._unpackChangepointNames(self.transitionModel)))
        breakpoints = list(flatten(self._unpackBreakpointNames(self.transitionModel)))

        # both types are not allowed at the moment, as the correct order is not determined correctly
        if len(changepoints) > 0 and len(breakpoints) > 0:
            raise NotImplementedError('Detected both change-points (Changepoint transition model) and break-points '
                                      '(SerialTransitionModel). Currently, only one type is supported in a single '
                                      'transition model.')

        # at least one change/break-point should be present
        if len(changepoints) == 0 and len(breakpoints) == 0:
            raise ConfigurationError('No change-points or break-points detected in transition model. Check transition '
                                     'model.')

        self.flatHyperParameters = self._unpackAllHyperParameters()
        self.flatHyperParameterNames = self._unpackAllHyperParameters(values=False)

        # select the names of the time points (change-points or break-points) that have to be ordered
        if len(changepoints) > 0:
            print('+ Detected {} change-point(s) in transition model: {}'.format(len(changepoints), changepoints))
            points = changepoints
        else:
            print('+ Detected {} break-point(s) in transition model: {}'.format(len(breakpoints), breakpoints))
            points = breakpoints

        # first create standard hyper-grid
        self._createHyperGrid()
        self.allHyperGridValues = self.hyperGridValues[:]
        self.allHyperPriorValues = self.flatHyperPriorValues[:]

        # extract hyper-grid values that belong to changepoints
        # (the builtin bool is used as dtype: the alias np.bool has been removed from recent NumPy releases)
        pointMask = np.sum([np.array(self.flatHyperParameterNames) == p for p in points], axis=0).astype(bool)
        maskedHyperGridValues = self.allHyperGridValues[:, pointMask]

        # only accept if change-point values are ordered (and not equal)
        self.mask = np.array(
            [all(x[i] < x[i + 1] for i in range(len(points) - 1)) for x in maskedHyperGridValues],
            dtype=bool)
        self.hyperGridValues = self.allHyperGridValues[self.mask]
        self.flatHyperPriorValues = self.allHyperPriorValues[self.mask]/np.sum(self.allHyperPriorValues[self.mask])

        # call fit method of hyper-study
        HyperStudy.fit(self,
                       forwardOnly=forwardOnly,
                       evidenceOnly=evidenceOnly,
                       silent=silent,
                       nJobs=nJobs,
                       customHyperGrid=True)

        # for proper plotting, hyperGridValues must include all possible combinations of hyper-parameter values. We
        # therefore have to include invalid combinations and assign the probability zero to them.
        temp = np.zeros(len(self.allHyperGridValues))
        temp[self.mask] = self.hyperParameterDistribution
        self.hyperParameterDistribution = temp

        temp = np.zeros(len(self.allHyperPriorValues))
        temp[self.mask] = self.flatHyperPriorValues
        self.flatHyperPriorValues = temp

    def _unpackSerialTransitionModels(self, transitionModel):
        """
        Returns list of all occurrences of serial transition models in the transition model, nested like the transition
        model.

        Args:
            transitionModel: (sub-)model that is searched recursively via its 'models' attribute (if present)

        Returns:
            list: all serial transition models
        """
        modelList = []
        # recursion step for sub-models
        if hasattr(transitionModel, 'models'):
            for m in transitionModel.models:
                modelList.append(self._unpackSerialTransitionModels(m))

        # add the current (sub-)model itself if it identifies as a serial transition model
        if hasattr(transitionModel, 'hyperParameterNames'):
            if str(transitionModel) == 'Serial transition model':
                modelList.append(transitionModel)

        return modelList

    def getDurationDistribution(self, names, plot=False, **kwargs):
        """
        Computes the distribution of the number of time steps between two change/break-points. This distribution of
        duration is created from the joint distribution of the two specified change/break-points.

        Args:
            names(list): List of two parameter names of change/break-points to display
                (first and second model parameter)
            plot(bool): If True, a bar chart of the distribution is created
            **kwargs: All further keyword-arguments are passed to the bar-plot (see matplotlib documentation)

        Returns:
            ndarray, ndarray: The first array contains the number of time steps, the second one the corresponding
                probability values.

        Raises:
            PostProcessingError: If 'names' is not a collection of exactly two entries.
        """
        # check if list with two elements is provided
        if not isinstance(names, Iterable) or not len(names) == 2:
            raise PostProcessingError('A list of exactly two hyper-parameters has to be provided.')
        paramIndices = [self._getHyperParameterIndex(self.transitionModel, n) for n in names]

        # check if parameter indices are in ascending order (so axes are labeled correctly)
        if not paramIndices[0] < paramIndices[1]:
            print('! WARNING: Switching hyper-parameter order for plotting.')
            paramIndices = paramIndices[::-1]

        # all differences between the two time points that occur among the valid grid points
        validValues = self.hyperGridValues[:, paramIndices].T
        duration = np.unique(validValues[1] - validValues[0])

        # differences for *all* grid points (hyperParameterDistribution is defined on the complete grid after 'fit')
        allValues = self.allHyperGridValues[:, paramIndices]
        isOrdered = allValues[:, 1] > allValues[:, 0]
        differences = (allValues[:, 1] - allValues[:, 0]).round(10)  # rounding needed because of finite precision
        probabilities = np.array(self.hyperParameterDistribution)

        # collect probabilities for the different durations. Grid points whose difference does not occur among the
        # valid combinations (possible for more than two change-points) are invalid combinations that carry zero
        # probability; they are skipped instead of triggering a failing index look-up.
        durationDistribution = np.array([np.sum(probabilities[isOrdered & (differences == d)])
                                         for d in duration.round(10)], dtype=float)

        # properly normalize duration distribution
        durationDistribution /= np.sum(durationDistribution)

        if plot:
            plt.bar(duration, durationDistribution, align='center', width=duration[0], **kwargs)

            plt.xlabel('duration between {} and {} (in time steps)'
                       .format(self.flatHyperParameterNames[paramIndices[0]],
                               self.flatHyperParameterNames[paramIndices[1]]))
            plt.ylabel('probability')

        return duration, durationDistribution


class OnlineStudy(HyperStudy):
    """
    Enables model selection for online data streams. This class builds on the Study-class and features a step-method
    to include new data points in the study as they arrive from a data stream. This online-analysis is performed in an
    forward-only way, resulting in filtering-distributions only. In contrast to a normal study, however, one can add
    multiple transition models to account for different types of parameter dynamics (similar to a Hyper study). The
    Online study then computes the probability distribution over all transition models for each new data point,
    enabling real-time model selection.

    Args:
        storeHistory(bool): If true, posterior distributions and their mean values, as well as hyper-posterior
            distributions are stored for all time steps.
    """
    def __init__(self, storeHistory=False):
        """
        Creates an empty online study. All model-related attributes start out as None and are filled in by
        'addTransitionModel' and 'step'.

        Args:
            storeHistory(bool): If true, results of all time steps are kept in the history lists created here.
        """
        super(OnlineStudy, self).__init__()

        # book-keeping of the transition models added to the study and of their hyper-parameters
        self.transitionModels = self.transitionModelNames = None
        self.tmCount = self.tmCounts = None
        self.hyperParameterValues = self.allFlatHyperParameterValues = None
        self.hyperParameterNames = self.hyperGridConstants = None

        # quantities of the forward algorithm and the different priors
        self.alpha = self.beta = self.normi = None
        self.hyperPrior = self.hyperPriorValues = None
        self.transitionModelPrior = None

        # posterior distributions of the current time step
        self.parameterPosterior = None
        self.transitionModelPosterior = None
        self.marginalizedPosterior = None

        # distributions over hyper-parameters and transition models
        self.hyperParameterDistribution = None
        self.transitionModelDistribution = None
        self.localHyperEvidence = None

        # containers for the results of past time steps
        self.storeHistory = storeHistory
        self.posteriorMeanValues = []
        self.posteriorSequence = []
        self.hyperParameterSequence = []
        self.transitionModelSequence = []
        print('  --> Online study')

        self.debug = []

    def addTransitionModel(self, name, transitionModel):
        """
        Adds a transition model to the list of transition models that are fitted in each time step. Note that a list of
        hyper-parameter values can be supplied.

        Args:
            name(str): a custom name for this transition model to identify it in post-processing methods
            transitionModel: instance of a transition model class.

        Example:
            Here, 'S' denotes the OnlineStudy instance. In the first example, we assume a Poisson observation model and
            add a Gaussian random walk with varying standard deviation to the rate parameter 'lambda':

                S.setObservationModel(bl.om.Poisson('lambda', bl.oint(0, 6, 1000)))
                S.addTransitionModel(bl.tm.GaussianRandomWalk('sigma', [0, 0.1, 0.2, 0.3], target='lambda'))
        """
        if self.transitionModels is None:
            self.transitionModels = []
            self.transitionModelNames = []
            self.hyperParameterValues = []
            self.allFlatHyperParameterValues = []
            self.hyperParameterNames = []
            self.hyperGridConstants = []
            self.hyperPrior = []
            self.hyperPriorValues = []

        # extract hyper-parameter values and names
        self.setTransitionModel(transitionModel, silent=True)
        self._createHyperGrid(silent=True)

        self.transitionModels.append(transitionModel)
        self.transitionModelNames.append(name)
        self.hyperParameterValues.append(self.hyperGridValues[:])
        self.allFlatHyperParameterValues.append(self.flatHyperParameters)
        self.hyperParameterNames.append(self.flatHyperParameterNames[:])
        self.hyperGridConstants.append(self.hyperGridConstant[:])
        self.hyperPrior.append(self.flatHyperPriors[:])

        # different normalization routine than in hyper-study
        self.hyperPriorValues.append(self.flatHyperPriorValues[:]/np.prod(self.hyperGridConstant))

        # count individual transition models
        self.tmCounts = []
        for hpv in self.hyperParameterValues:
            if len(hpv) > 0:
                self.tmCounts.append(len(hpv))
            else:
                self.tmCounts.append(1)
        self.tmCount = np.sum(self.tmCounts)

        if len(self.hyperGridValues) > 0:
            print('+ Added transition model: {} ({} combination(s) of the following hyper-parameters: {})'
                  .format(name, len(self.hyperGridValues), self.hyperParameterNames[-1]))
        else:
            print('+ Added transition model: {} (no hyper-parameters)'.format(name))

    def setTransitionModelPrior(self, transitionModelPrior, silent=False):
        """
        Sets prior probabilities for transition models added to the online study instance.

        Args:
            transitionModelPrior: List/Array of probabilities, one for each transition model. If the list does not sum
                to one, it will be re-normalised.
            silent: If true, no output is generated by this method.
        """
        if not (isinstance(transitionModelPrior, Iterable) and len(transitionModelPrior) == len(self.transitionModels)):
            raise ConfigurationError('Length of transition model prior ({}) does not fit number of transition models '
                                     '({})'.format(len(transitionModelPrior), len(self.transitionModels)))

        self.transitionModelPrior = np.array(transitionModelPrior)

        if not np.sum(transitionModelPrior) == 1.:
            print('+ WARNING: Transition model prior does not sum up to one. Will re-normalize.')
            self.transitionModelPrior /= np.sum(self.transitionModelPrior)

        if not silent:
            print('+ Set custom transition model prior.')

    def adoptHyperParameterDistribution(self):
        """
        Will set the current hyper-parameter distribution as the new hyper-parameter prior, if a distribution has
        already been computed. Is usually called after the 'step' method.
        """
        if self.hyperParameterDistribution is not None:
            self.hyperPriorValues = deepcopy(self.hyperParameterDistribution)

    def adoptTransitionModelDistribution(self):
        """
        Will set the current transition model distribution as the new transition model prior, if a distribution has
        already been computed. Is usually called after the 'step' method.
        """
        if self.transitionModelDistribution is not None:
            self.transitionModelPrior = deepcopy(self.transitionModelDistribution)

    def step(self, dataPoint):
        """
        Update the current parameter distribution by adding a new data point to the data set.

        Args:
            dataPoint(float, int, ndarray): Float, int, or 1D-array of those (for multidimensional data).
        """
        # at least one transition model has to be set or added
        if (self.tmCount is None) and (self.transitionModel is None):
            raise ConfigurationError('No transition model set or added.')

        # if one only sets a transition model, but does not use addTransitionModel, we add it here
        if (self.tmCount is None) and (self.transitionModel is not None):
            self.addTransitionModel('transition model', self.transitionModel)

        if not isinstance(dataPoint, list):
            dataPoint = [dataPoint]

        if len(self.rawData) == 0:
            # to check the model consistency the first time that 'step' is called
            self.rawData = np.array(dataPoint)
            Study._checkConsistency(self)

            self.rawTimestamps = np.array([0])
            self.formattedTimestamps = []
        else:
            self.rawData = np.append(self.rawData, np.array(dataPoint), axis=0)
            self.rawTimestamps = np.append(self.rawTimestamps, self.rawTimestamps[-1]+1)

        # only proceed if at least one data segment can be created
        if len(self.rawData) < self.observationModel.segmentLength:
            print('    + Not enough data points to start analysis. Will wait for more data.')
            return

        self.formattedTimestamps.append(self.rawTimestamps[-1])

        # initialize hyper-prior as flat
        if self.hyperPrior is None:
            self.hyperPrior = 'flat hyper-prior'
            self.hyperPriorValues = [np.ones(tmc) / (tmc * np.prod(hgc))
                                     for tmc, hgc in zip(self.tmCounts, self.hyperGridConstants)]
            print('    + Initialized flat hyper-prior.')

        # initialize transition model prior as flat
        if self.transitionModelPrior is None:
            self.transitionModelPrior = np.ones(len(self.transitionModels))/len(self.transitionModels)
            print('    + Initialized flat transition mode prior.')

        # initialize alpha with prior distribution
        if self.alpha is None:
            if self.observationModel.prior is not None:
                if isinstance(self.observationModel.prior, np.ndarray):
                    self.alpha = self.observationModel.prior
                else:  # prior is set as a function
                    self.alpha = self.observationModel.prior(*self.grid)
            else:
                self.alpha = np.ones(self.gridSize)  # flat prior

            # normalize prior (necessary in case an improper prior is used)
            self.alpha /= np.sum(self.alpha)
            print('    + Initialized prior.')

        # initialize normi as an array of ones
        if self.normi is None:
            self.normi = [np.ones(tmc) for tmc in self.tmCounts]
            print('    + Initialized normalization factors.')

        # initialize parameter posterior
        if self.parameterPosterior is None:
            self.parameterPosterior = [np.zeros([tmc] + self.gridSize) for tmc in self.tmCounts]

        # initialize hyper-posterior
        if self.hyperParameterDistribution is None:
            self.hyperParameterDistribution = [np.zeros(tmc) for tmc in self.tmCounts]

        # initialize transition model evidence
        if self.localHyperEvidence is None:
            self.localHyperEvidence = np.zeros(len(self.transitionModels))

        # initialize transition model distribution (normalized version of transition model evidence array)
        if self.transitionModelDistribution is None:
            self.transitionModelDistribution = np.zeros(len(self.transitionModels))

        # initialize transition model posterior (needs to be re-initialized each time step)
        self.transitionModelPosterior = np.zeros([len(self.transitionModels)] + self.gridSize)

        # initialize marginalized posterior
        if self.marginalizedPosterior is None:
            self.marginalizedPosterior = np.zeros(self.gridSize)

        # select data segment
        dataSegment = self.rawData[-self.observationModel.segmentLength:]

        # compute current likelihood only once
        likelihood = self.observationModel.processedPdf(self.grid, dataSegment)

        # loop over all hypotheses/transition models
        for i, (tm, hpv) in enumerate(zip(self.transitionModels, self.hyperParameterValues)):
            self.setTransitionModel(tm, silent=True)  # set current transition model

            if len(hpv) == 0:
                hpv = [None]

            # loop over all hyper-parameter values to fit
            for j, x in enumerate(hpv):
                # set current hyper-parameter values
                if x is not None:
                    self._setAllHyperParameters(x)

                # compute alpha_i
                if np.sum(self.marginalizedPosterior) == 0.:  # first time step, so use predefined prior
                    alphai = self.alpha*likelihood
                else:  # in all other time step transform "old" alpha/posterior
                    alphai = self.transitionModel.computeForwardPrior(self.alpha, len(self.formattedData)-1)*likelihood
                ni = np.sum(alphai)

                # hyper-post. values are not normalized at this point: hyper-like. * hyper-prior
                self.hyperParameterDistribution[i][j] = (ni/self.normi[i][j])*self.hyperPriorValues[i][j]

                # store parameter posterior
                self.parameterPosterior[i][j] = alphai/ni

                # update normalization constant
                self.normi[i][j] = ni

            # compute hyper-evidence to properly normalize hyper-parameter distribution
            self.localHyperEvidence[i] = np.sum(self.hyperParameterDistribution[i] *
                                                np.prod(self.hyperGridConstants[i]))

            # normalize hyper-parameter distribution of current transition model
            self.hyperParameterDistribution[i] /= self.localHyperEvidence[i]

            # compute parameter posterior, marginalized over current hyper-parameter values of current transition model
            hpd = deepcopy(self.hyperParameterDistribution[i])
            while len(self.parameterPosterior[i].shape) > len(hpd.shape):
                hpd = np.expand_dims(hpd, axis=-1)
            self.transitionModelPosterior[i] = np.sum(self.parameterPosterior[i] *
                                                      hpd *
                                                      np.prod(self.hyperGridConstants[i]), axis=0)

        # compute distribution of transition models; normalizing constant of this distribution represents the local
        # evidence of current data point, marginalizing over all transition models
        self.transitionModelDistribution = self.localHyperEvidence * self.transitionModelPrior
        self.localEvidence = np.sum(self.transitionModelDistribution)
        self.transitionModelDistribution /= self.localEvidence

        # normalize marginalized posterior
        tmd = deepcopy(self.transitionModelDistribution)
        while len(self.transitionModelPosterior.shape) > len(tmd.shape):
            tmd = np.expand_dims(tmd, axis=-1)
        self.marginalizedPosterior = np.sum(self.transitionModelPosterior * tmd, axis=0)

        # compute new alpha
        self.alpha = self.marginalizedPosterior

        # store results for future plotting
        if self.storeHistory:
            self.posteriorMeanValues.append(np.array([np.sum(self.marginalizedPosterior*g) for g in self.grid]))
            self.posteriorSequence.append(self.marginalizedPosterior.copy())
            self.hyperParameterSequence.append(deepcopy(self.hyperParameterDistribution))
            self.transitionModelSequence.append(deepcopy(self.transitionModelDistribution))

    # optimization methods are inherited from the Study class, but cannot be used in an online study

    def fit(self, *args, **kwargs):
        raise NotImplementedError('OnlineStudy object has no "fit" method. Use "step" instead.')

    def getParameterDistribution(self, t, name, plot=False, **kwargs):
        """
        Compute the marginal parameter distribution at a given time step. Only available if Online Study is created
        with flag 'storeHistory=True'. The computation is delegated to the corresponding method of the Study class.

        Args:
            t(int, float): Time step/stamp for which the parameter distribution is evaluated
            name(str): Name of the parameter to display
            plot(bool): If True, a plot of the distribution is created
            **kwargs: All further keyword-arguments are passed to the plot (see matplotlib documentation)

        Returns:
            ndarray, ndarray: The first array contains the parameter values, the second one the corresponding
                probability (density) values

        Raises:
            PostProcessingError: If the study does not store the history of past time steps.
        """
        if not self.storeHistory:
            raise PostProcessingError('To get past parameter distributions, Online Study must be called with flag '
                                      '"storeHistory=True". Use "getCurrentParameterDistribution" instead.')

        # plotting function of Study class can only handle arrays, not lists
        timestamps, posteriors = self.formattedTimestamps, self.posteriorSequence
        self.formattedTimestamps = np.array(timestamps)
        self.posteriorSequence = np.array(posteriors)

        try:
            # hand the result back to the caller
            return Study.getParameterDistribution(self, t, name, plot=plot, **kwargs)
        finally:
            # restore the lists (also if the call above fails), so online study may continue to append values
            self.formattedTimestamps = timestamps
            self.posteriorSequence = posteriors

    def getCurrentParameterDistribution(self, name, plot=False, **kwargs):
        """
        Compute the current marginal parameter distribution.

        Args:
            name(str): Name of the parameter to display
            plot(bool): If True, a plot of the distribution is created
            **kwargs: All further keyword-arguments are passed to the plot (see matplotlib documentation)

        Returns:
            ndarray, ndarray: The first array contains the parameter values, the second one the corresponding
                probability (density) values
        """
        # get parameter index
        paramIndex = -1
        for i, n in enumerate(self.observationModel.parameterNames):
            if n == name:
                paramIndex = i

        # check if match was found
        if paramIndex == -1:
            raise PostProcessingError('Wrong parameter name. Available options: {0}'
                                      .format(self.observationModel.parameterNames))

        axesToMarginalize = list(range(len(self.observationModel.parameterNames)))
        try:
            axesToMarginalize.remove(paramIndex)
        except ValueError:
            raise PostProcessingError('Wrong parameter index. Available indices: {}'.format(axesToMarginalize))

        x = self.marginalGrid[paramIndex]
        marginalDistribution = np.squeeze(np.apply_over_axes(np.sum, self.marginalizedPosterior, axesToMarginalize))

        if plot:
            plt.fill_between(x, 0, marginalDistribution, **kwargs)

            plt.xlabel(self.observationModel.parameterNames[paramIndex])

            # in case an integer step size for hyper-parameter values is chosen, probability is displayed
            # (probability density otherwise)
            if self.latticeConstant[paramIndex] == 1.:
                plt.ylabel('probability')
            else:
                plt.ylabel('probability density')

        return x, marginalDistribution

    def getParameterDistributions(self, name, plot=False, **kwargs):
        """
        Computes the time series of marginal posterior distributions with respect to a given model parameter. Only
        available if Online Study is created with flag 'storeHistory=True'. The computation is delegated to the
        corresponding method of the Study class.

        Args:
            name(str): Name of the parameter to display
            plot(bool): If True, a plot of the series of distributions is created (density map)
            **kwargs: All further keyword-arguments are passed to the plot (see matplotlib documentation)

        Returns:
            ndarray, ndarray: The first array contains the parameter values, the second one the sequence of
                corresponding posterior distributions.

        Raises:
            PostProcessingError: If the study does not store the history of past time steps.
        """
        if not self.storeHistory:
            raise PostProcessingError('To get past parameter distributions, Online Study must be called with flag '
                                      '"storeHistory=True". Use "getCurrentParameterDistribution" instead.')

        # plotting function of Study class can only handle arrays, not lists
        timestamps, posteriors = self.formattedTimestamps, self.posteriorSequence
        self.formattedTimestamps = np.array(timestamps)
        self.posteriorSequence = np.array(posteriors)

        try:
            # hand the result back to the caller
            return Study.getParameterDistributions(self, name, plot=plot, **kwargs)
        finally:
            # restore the lists (also if the call above fails), so online study may continue to append values
            self.formattedTimestamps = timestamps
            self.posteriorSequence = posteriors

    def plotParameterEvolution(self, name, color='b', gamma=0.5, **kwargs):
        """
        Plots a series of marginal posterior distributions corresponding to a single model parameter, together with the
        posterior mean values. Only available if Online Study is created with flag 'storeHistory=True'. The plot is
        created by the corresponding method of the Study class.

        Args:
            name(str): Name of the parameter to display
            color: color from which a light colormap is created
            gamma(float): exponent for gamma correction of the displayed marginal distribution; default: 0.5
            kwargs: all further keyword-arguments are passed to the plot of the posterior mean values

        Raises:
            PostProcessingError: If the study does not store the history of past time steps.
        """
        if not self.storeHistory:
            raise PostProcessingError('To plot past parameter distributions, Online Study must be called with flag '
                                      '"storeHistory=True". Use "getCurrentParameterDistribution" instead.')

        # plotting function of Study class can only handle arrays, not lists; the mean values are transposed before
        # they are handed over
        timestamps = self.formattedTimestamps
        meanValues = self.posteriorMeanValues
        posteriors = self.posteriorSequence
        self.formattedTimestamps = np.array(timestamps)
        self.posteriorMeanValues = np.array(meanValues).T
        self.posteriorSequence = np.array(posteriors)

        try:
            Study.plotParameterEvolution(self, name, color=color, gamma=gamma, **kwargs)
        finally:
            # restore the lists (also if plotting fails), so online study may continue to append values
            self.formattedTimestamps = timestamps
            self.posteriorMeanValues = meanValues
            self.posteriorSequence = posteriors

    def getCurrentTransitionModelDistribution(self):
        """
        Returns the current probabilities for each transition model defined in the Online Study.

        Returns:
            ndarray: Normalized array of transition model probabilities.
        """
        return self.transitionModelDistribution

    def getCurrentTransitionModelProbability(self, transitionModel):
        """
        Returns the current posterior probability for a specified transition model.

        Args:
            transitionModel(str): Name of the transition model

        Returns:
            float: Posterior probability value for the specified transition model
        """
        transitionModelIndex = self.transitionModelNames.index(transitionModel)
        return self.transitionModelDistribution[transitionModelIndex]

    def getTransitionModelDistributions(self):
        """
        Collects the stored transition model distributions of all analyzed time steps into one array. Each stored
        distribution holds the posterior probability values of all transition models included in the online study.
        Only available if Online Study is created with flag 'storeHistory=True'.

        Returns:
            ndarray: Array built from ``transitionModelSequence``, containing the posterior probability values for all
                transition models for all time steps analyzed

        Raises:
            PostProcessingError: If the study does not store its history
        """
        if self.storeHistory:
            return np.array(self.transitionModelSequence)

        # without stored history, only the current distribution can be provided
        raise PostProcessingError('To get past transition model distributions, Online Study must be called with '
                                  'flag "storeHistory=True". Use "getCurrentTransitionModelDistribution" instead.')

    def getTransitionModelProbabilities(self, transitionModel):
        """
        Extracts the posterior probability values of one transition model for all time steps analyzed. Only available
        if Online Study is created with flag 'storeHistory=True'.

        Args:
            transitionModel(str): Name of the transition model

        Returns:
            ndarray: Array containing the posterior probability values for the specified transition model for all time
                steps analyzed

        Raises:
            PostProcessingError: If the study does not store its history
        """
        if not self.storeHistory:
            raise PostProcessingError('To get past transition model distributions, Online Study must be called with '
                                      'flag "storeHistory=True". Use "getCurrentTransitionModelDistribution" instead.')

        # the position of the name selects the matching column of the (time step x transition model) history
        column = self.transitionModelNames.index(transitionModel)
        history = np.array(self.transitionModelSequence)
        return history[:, column]

    def getCurrentParameterMeanValue(self, name):
        """
        Computes the posterior mean value of one parameter of the observation model from the stored marginalized
        posterior distribution.

        Args:
            name(str): Name of the parameter

        Returns:
            float: posterior mean value, computed as the sum over ``marginalizedPosterior`` weighted with the grid
                values of the parameter (presumably assumes that the posterior sums to one)

        Raises:
            PostProcessingError: If the observation model has no parameter with the given name
        """
        parameterNames = self.observationModel.parameterNames

        # collect all positions at which the name occurs; the last occurrence is the one that is used
        matches = [index for index, candidate in enumerate(parameterNames) if candidate == name]
        if not matches:
            raise PostProcessingError('Wrong parameter name. Available options: {0}'
                                      .format(self.observationModel.parameterNames))

        parameterGrid = self.grid[matches[-1]]
        return np.sum(self.marginalizedPosterior*parameterGrid)

    def getHyperParameterMeanValue(self, t, name, transitionModel=0):
        """
        Computes the mean value of a single hyper-parameter from the joint hyper-parameter distribution of a given
        transition model at a given time step. Only available if Online Study is created with flag
        'storeHistory=True'.

        Args:
            t(int): Time step at which to compute the mean value
            name(str): name of hyper-parameter
            transitionModel(int, str): Index or name of the transition model that contains the hyper-parameter;
                default: 0 (first transition model)

        Returns:
            float: Mean value of the specified hyper-parameter at time step t

        Raises:
            PostProcessingError: If the transition model is neither given as int nor as str, if no distribution is
                stored for time step t, or if no transition model with the given index exists
        """
        # translate the transition model into its index
        if not isinstance(transitionModel, (str, int)):
            raise PostProcessingError('Transition model must be specified by either index (int) or name (str).')
        if isinstance(transitionModel, str):
            tmIndex = self.transitionModelNames.index(transitionModel)
        else:
            tmIndex = transitionModel

        # distributions of all transition models at the requested time step
        try:
            distributionsAtT = self.hyperParameterSequence[t]
        except IndexError:
            raise PostProcessingError('No hyper-parameter distribution found for t={}. Choose 0 <= t <= {}.'
                                      .format(t, len(self.formattedTimestamps) - 1))

        # pick everything that belongs to the chosen transition model
        try:
            # extra axis so that each probability value is broadcast against the values of all hyper-parameters
            weights = distributionsAtT[tmIndex][:, None]
            values = self.hyperParameterValues[tmIndex]
            gridConstants = self.hyperGridConstants[tmIndex]
        except IndexError:
            raise PostProcessingError('Transition model with index {} does not exist. Options: 0-{}.'
                                      .format(tmIndex, len(self.transitionModels) - 1))

        hpIndex = self._getHyperParameterIndex(self.transitionModels[tmIndex], name)

        # probability-weighted sum over all hyper-parameter value combinations (one mean per hyper-parameter)
        weightedValues = values*weights*np.prod(gridConstants)
        allMeanValues = np.sum(weightedValues, axis=0)
        return allMeanValues[hpIndex]

    def getHyperParameterMeanValues(self, name, transitionModel=0):
        """
        Computes, for every stored time step, the mean values of the joint hyper-parameter distribution of a given
        transition model.

        Args:
            name(str, None): name of hyper-parameter (if name=None, mean values of ALL hyper-parameters are returned)
            transitionModel(int, str): Index or name of the transition model that contains the hyper-parameter;
                default: 0 (first transition model)

        Returns:
            ndarray: Sequence of mean values of the specified hyper-parameter (one value per time step); if name=None,
                an array with one such sequence per hyper-parameter (first axis: hyper-parameter, second axis: time)

        Raises:
            PostProcessingError: If the transition model is neither given as int nor as str, or if an IndexError
                occurs while accessing the data of the transition model
        """
        # translate the transition model into its index
        if not isinstance(transitionModel, (str, int)):
            raise PostProcessingError('Transition model must be specified by either index (int) or name (str).')
        if isinstance(transitionModel, str):
            tmIndex = self.transitionModelNames.index(transitionModel)
        else:
            tmIndex = transitionModel

        # gather the distributions of the chosen transition model for all time steps
        try:
            perTimeStep = [distributions[tmIndex].tolist() for distributions in self.hyperParameterSequence]
            # extra axis so that each probability value is broadcast against the values of all hyper-parameters
            weights = np.array(perTimeStep)[:, :, None]
            values = self.hyperParameterValues[tmIndex]
            gridConstants = self.hyperGridConstants[tmIndex]
        except IndexError:
            raise PostProcessingError('Transition model with index {} does not exist. Options: 0-{}.'
                                      .format(tmIndex, len(self.transitionModels) - 1))

        # probability-weighted sum over all hyper-parameter value combinations, then put hyper-parameters first
        weightedValues = weights * values * np.prod(gridConstants)
        meanValues = np.sum(weightedValues, axis=1).T

        if name is None:
            return meanValues

        hpIndex = self._getHyperParameterIndex(self.transitionModels[tmIndex], name)
        return meanValues[hpIndex]

    def getHyperParameterDistribution(self, t, name, transitionModel=0, plot=False, **kwargs):
        """
        Computes marginal hyper-parameter distribution of a single hyper-parameter at a specific time step in an
        OnlineStudy fit.

        Args:
            t(int): Time step at which to compute distribution
            name(str): hyper-parameter name
            transitionModel(int, str): Index or name of the transition model that contains the hyper-parameter;
                default: 0 (first transition model)
            plot(bool): If True, a bar chart of the distribution is created
            **kwargs: All further keyword-arguments are passed to the bar-plot (see matplotlib documentation)

        Returns:
            ndarray, ndarray: The first array contains the hyper-parameter values, the second one the
                corresponding probability (density) values. The second array is always a new array, so modifying it
                does not affect the distributions stored in the study.

        Raises:
            PostProcessingError: If the transition model is neither given as int nor as str, if no distribution is
                stored for time step t, or if no transition model with the given index exists
        """
        # find index of transition model
        if isinstance(transitionModel, str):
            transitionModelIndex = self.transitionModelNames.index(transitionModel)
        elif isinstance(transitionModel, int):
            transitionModelIndex = transitionModel
        else:
            raise PostProcessingError('Transition model must be specified by either index (int) or name (str).')

        # access hyper-parameter distribution
        try:
            hyperParameterDistribution = self.hyperParameterSequence[t]
        except IndexError:
            raise PostProcessingError('No hyper-parameter distribution found for t={}. Choose 0 <= t <= {}.'
                                      .format(t, len(self.formattedTimestamps)-1))

        try:
            hyperParameterDistribution = hyperParameterDistribution[transitionModelIndex]
        except IndexError:
            raise PostProcessingError('Transition model with index {} does not exist. Options: 0-{}.'
                                      .format(transitionModelIndex, len(self.transitionModels)-1))

        # get hyper-parameter index
        paramIndex = self._getHyperParameterIndex(self.transitionModels[transitionModelIndex], name)

        # all axes of the hyper-parameter grid except the one of the requested hyper-parameter are summed out
        axesToMarginalize = list(range(len(self.hyperParameterNames[transitionModelIndex])))
        axesToMarginalize.remove(paramIndex)

        # reshape hyper-parameter grid for easy marginalization
        hyperGridSteps = [len(x) for x in self.allFlatHyperParameterValues[transitionModelIndex]]
        distribution = hyperParameterDistribution.reshape(hyperGridSteps, order='C')
        marginalDistribution = np.squeeze(np.apply_over_axes(np.sum, distribution, axesToMarginalize))

        # marginal distribution is not created by sum, but by the integral
        integrationFactor = np.prod([self.hyperGridConstants[transitionModelIndex][axis] for axis in axesToMarginalize])
        # Multiply out-of-place: if there is only one hyper-parameter, nothing is summed out and the array above is
        # still a view of the stored distribution. An in-place operation would write into the stored history and the
        # returned array would alias it.
        marginalDistribution = marginalDistribution * integrationFactor

        x = self.allFlatHyperParameterValues[transitionModelIndex][paramIndex]
        if plot:
            # check if categorical (values are not equally spaced)
            if np.any(np.abs(np.diff(np.diff(x))) > 10 ** -10):
                plt.bar(np.arange(len(x)), marginalDistribution, align='center', width=1., **kwargs)
                plt.xticks(np.arange(len(x)), x)
                plt.ylabel('probability')
            # regular spacing
            else:
                plt.bar(x, marginalDistribution, align='center',
                        width=self.hyperGridConstants[transitionModelIndex][paramIndex],
                        **kwargs)
                # a grid constant of one means that values can be read as probabilities instead of densities
                if self.hyperGridConstants[transitionModelIndex][paramIndex] == 1.:
                    plt.ylabel('probability')
                else:
                    plt.ylabel('probability density')

            plt.xlabel(self.hyperParameterNames[transitionModelIndex][paramIndex])

        return x, marginalDistribution

    def getCurrentHyperParameterDistribution(self, name, transitionModel=0, plot=False, **kwargs):
        """
        Computes marginal hyper-parameter distribution of a single hyper-parameter at the current time step in an
        OnlineStudy fit, based on the currently stored ``hyperParameterDistribution``.

        Args:
            name(str): hyper-parameter name
            transitionModel(int, str): Index or name of the transition model that contains the hyper-parameter;
                default: 0 (first transition model)
            plot(bool): If True, a bar chart of the distribution is created
            **kwargs: All further keyword-arguments are passed to the bar-plot (see matplotlib documentation)

        Returns:
            ndarray, ndarray: The first array contains the hyper-parameter values, the second one the
                corresponding probability (density) values. The second array is always a new array, so modifying it
                does not affect the distribution stored in the study.

        Raises:
            PostProcessingError: If the transition model is neither given as int nor as str, or if no transition
                model with the given index exists
        """
        # find index of transition model
        if isinstance(transitionModel, str):
            transitionModelIndex = self.transitionModelNames.index(transitionModel)
        elif isinstance(transitionModel, int):
            transitionModelIndex = transitionModel
        else:
            raise PostProcessingError('Transition model must be specified by either index (int) or name (str).')

        try:
            hyperParameterDistribution = self.hyperParameterDistribution[transitionModelIndex]
        except IndexError:
            raise PostProcessingError('Transition model with index {} does not exist. Options: 0-{}.'
                                      .format(transitionModelIndex, len(self.transitionModels) - 1))

        # get hyper-parameter index
        paramIndex = self._getHyperParameterIndex(self.transitionModels[transitionModelIndex], name)

        # all axes of the hyper-parameter grid except the one of the requested hyper-parameter are summed out
        axesToMarginalize = list(range(len(self.hyperParameterNames[transitionModelIndex])))
        axesToMarginalize.remove(paramIndex)

        # reshape hyper-parameter grid for easy marginalization
        hyperGridSteps = [len(x) for x in self.allFlatHyperParameterValues[transitionModelIndex]]
        distribution = hyperParameterDistribution.reshape(hyperGridSteps, order='C')
        marginalDistribution = np.squeeze(np.apply_over_axes(np.sum, distribution, axesToMarginalize))

        # marginal distribution is not created by sum, but by the integral
        integrationFactor = np.prod([self.hyperGridConstants[transitionModelIndex][axis] for axis in axesToMarginalize])
        # Multiply out-of-place: if there is only one hyper-parameter, nothing is summed out and the array above is
        # still a view of the stored distribution. An in-place operation would write into the study's current
        # distribution and the returned array would alias it.
        marginalDistribution = marginalDistribution * integrationFactor

        x = self.allFlatHyperParameterValues[transitionModelIndex][paramIndex]
        if plot:
            # check if categorical (values are not equally spaced)
            if np.any(np.abs(np.diff(np.diff(x))) > 10 ** -10):
                plt.bar(np.arange(len(x)), marginalDistribution, align='center', width=1., **kwargs)
                plt.xticks(np.arange(len(x)), x)
                plt.ylabel('probability')
            # regular spacing
            else:
                plt.bar(x, marginalDistribution, align='center',
                        width=self.hyperGridConstants[transitionModelIndex][paramIndex],
                        **kwargs)
                # a grid constant of one means that values can be read as probabilities instead of densities
                if self.hyperGridConstants[transitionModelIndex][paramIndex] == 1.:
                    plt.ylabel('probability')
                else:
                    plt.ylabel('probability density')

            plt.xlabel(self.hyperParameterNames[transitionModelIndex][paramIndex])

        return x, marginalDistribution

    def getHyperParameterDistributions(self, name, transitionModel=0):
        """
        Computes marginal hyper-parameter distributions of a single hyper-parameter for all time steps in an OnlineStudy
        fit.

        Args:
            name(str): hyper-parameter name
            transitionModel(int, str): Index or name of the transition model that contains the hyper-parameter;
                default: 0 (first transition model)

        Returns:
            ndarray, ndarray: The first array contains the (sorted, unique) hyper-parameter values, the second one the
                corresponding probability (density) values (first axis is time).

        Raises:
            PostProcessingError: If the transition model is neither given as int nor as str, if no hyper-parameter
                distributions are stored, or if no transition model with the given index exists
        """
        # find index of transition model
        if isinstance(transitionModel, str):
            transitionModelIndex = self.transitionModelNames.index(transitionModel)
        elif isinstance(transitionModel, int):
            transitionModelIndex = transitionModel
        else:
            raise PostProcessingError('Transition model must be specified by either index (int) or name (str).')

        # without any stored distribution there is nothing to marginalize
        if len(self.hyperParameterSequence) == 0:
            raise PostProcessingError('No hyper-parameter distributions found. Only available if Online Study is '
                                      'created with flag "storeHistory=True" and at least one time step was analyzed.')

        # Access hyper-parameter distributions of the chosen transition model only. The complete sequence must not be
        # converted into a single array: transition models may have hyper-parameter grids of different size, and
        # current NumPy versions refuse to build an array from such ragged nested sequences.
        try:
            hyperParameterSequence = [hp[transitionModelIndex] for hp in self.hyperParameterSequence]
        except IndexError:
            raise PostProcessingError('Transition model with index {} does not exist. Options: 0-{}.'
                                      .format(transitionModelIndex, len(self.transitionModels) - 1))

        paramIndex = self._getHyperParameterIndex(self.transitionModels[transitionModelIndex], name)

        # values of the requested hyper-parameter for every point of the flattened hyper-parameter grid
        hpv = np.array(self.hyperParameterValues[transitionModelIndex])
        paramHpv = hpv[:, paramIndex]
        uniqueValues = np.unique(paramHpv)  # np.unique already returns sorted values

        # marginalize the hyper-posterior probabilities: for each unique value, sum up the probabilities of all grid
        # points that share this value (masks do not depend on the time step and are therefore computed only once)
        masks = [paramHpv == value for value in uniqueValues]
        marginalDistribution = np.array([[np.sum(hp[mask]) for mask in masks] for hp in hyperParameterSequence])

        # renormalize marginal probability density (integration over all other hyper-parameters)
        otherGridConstants = list(self.hyperGridConstants[transitionModelIndex])
        del otherGridConstants[paramIndex]
        marginalDistribution = marginalDistribution * np.prod(otherGridConstants)

        return uniqueValues, marginalDistribution

    def plotHyperParameterEvolution(self, name, transitionModel=0, color='b', gamma=0.5, **kwargs):
        """
        Plot method to display a series of marginal posterior distributions corresponding to a single hyper-parameter.
        This method includes the removal of plotting artefacts, gamma correction as well as an overlay of the posterior
        mean values.

        Args:
            name(str): hyper-parameter name
            transitionModel(int, str): Index or name of the transition model that contains the hyper-parameter;
                default: 0 (first transition model)
            color: color from which a light colormap is created
            gamma(float): exponent for gamma correction of the displayed marginal distribution; default: 0.5
            kwargs: all further keyword-arguments are passed to the plot of the posterior mean values

        Raises:
            PostProcessingError: If the transition model is neither given as int nor as str
        """
        # find index of transition model
        if isinstance(transitionModel, str):
            transitionModelIndex = self.transitionModelNames.index(transitionModel)
        elif isinstance(transitionModel, int):
            transitionModelIndex = transitionModel
        else:
            raise PostProcessingError('Transition model must be specified by either index (int) or name (str).')

        # get sequence of hyper-parameter distributions
        uniqueValues, marginalDistribution = self.getHyperParameterDistributions(name, transitionModelIndex)
        paramIndex = self._getHyperParameterIndex(self.transitionModels[transitionModelIndex], name)

        # compute hyper-posterior mean values
        meanValues = self.getHyperParameterMeanValues(None, transitionModelIndex)[paramIndex]

        # clean up very small probability values, as they may create image artefacts
        pmax = np.amax(marginalDistribution)
        marginalDistribution[marginalDistribution < pmax * (10 ** -20)] = 0

        # time runs along the horizontal axis, hyper-parameter values along the vertical axis; origin must be given as
        # 'lower' (smallest hyper-parameter value at the bottom), as matplotlib only accepts 'upper' or 'lower'
        plt.imshow(marginalDistribution.T ** gamma,
                   origin='lower',
                   cmap=createColormap(color),
                   extent=[self.formattedTimestamps[0], self.formattedTimestamps[-1]] + [uniqueValues[0], uniqueValues[-1]],
                   aspect='auto')

        # set default color of plot to black
        if ('c' not in kwargs) and ('color' not in kwargs):
            kwargs['c'] = 'k'

        # set default linewidth to 1.5
        if ('lw' not in kwargs) and ('linewidth' not in kwargs):
            kwargs['lw'] = 1.5

        plt.plot(self.formattedTimestamps, meanValues, **kwargs)

        plt.ylim(uniqueValues[0], uniqueValues[-1])
        plt.ylabel(self.hyperParameterNames[transitionModelIndex][paramIndex])
        plt.xlabel('time step')
