#!/usr/bin/env python
# emacs: -*- mode: python; py-indent-offset: 4; indent-tabs-mode: nil -*-
# vi: set ft=python sts=4 ts=4 sw=4 et:
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
#
#   See COPYING file distributed along with the PyMVPA package for the
#   copyright and license terms.
#
### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ### ##
"""Script to run most common analysis scenarios (e.g. cross-validation, searchlights)
"""

# Import minimal needed amount to get process going so we could
# QUICKLY parse and verify command line arguments
import glob
import mvpa2
from mvpa2.base import *
from mvpa2.misc.cmdline import *

import time
from os.path import exists, join as joinpath


def parse_cmdline():
    """Configure the shared command line parser and parse ``sys.argv``.

    Sets the usage text, registers a "preproc" option group, selects the
    option groups shown by the parser, restricts the classifier choices
    and adds the script-specific options before parsing.

    Returns
    -------
    files : list of str
      Positional arguments (samples, targets+blocks file, mask and an
      optional output prefix).
    options
      The parsed options object returned by ``parser.parse_args()``.
    """
    usage_template = """
    %s [options] <NIfTI samples> <targets+blocks> <NIfTI mask> [<output_prefix>]

    where targets+blocks is a text file that lists the class label and the
    associated block of each data sample/volume as a tuple of two integer
    values (separated by a single space). -- one tuple per line."""
    parser.usage = usage_template % sys.argv[0]

    # NOTE: opt.tr and opt.detrend are not wired into preprocessing yet
    preproc_options = [
        opt.targets_sa,
        opt.chunks_sa,
        opt.baseline_conditions,
        opt.zscore,
        opt.exclude_conditions,
        opt.include_conditions,
        opt.mean_group_sample,
    ]
    opts.add('preproc', preproc_options, "Preprocessing options")

    parser.option_groups = [opts.SVM, opts.KNN, opts.general,
                            opts.preproc, opts.common]

    # Classifiers this script knows how to construct; gnb is the default
    # because it is fast
    clf_choices = ['gnb', 'm1nn', 'knn', 'lin_nu_svmc', 'svm',
                   'lin_C_svmc', 'rbf_nu_svmc']
    opt.clf.choices = clf_choices
    opt.clf.default = 'gnb'

    # TODO: Need to be grouped nicely
    # TODO: mask for searchlight ids if not full brain mask is desired
    #       to be analyzed
    # (a header in the attributes file is deduced automatically, hence
    #  there is no --attrfile-has-header option)
    script_options = [
        opt.clf,
        Option("--analysis",
               type="choice", dest="analysis", default='preprocess',
               choices=['crossvalidation', 'searchlight', 'preprocess'],
               help="Type of analysis to perform on data"),
        Option("--dataset-summary",
               action="store_true", dest="print_dataset_summary",
               help="Print dataset summary after preprocessing"),
        Option("--log-output",
               action="store_true", dest="log_output",
               help="Log output into a file. Mention that only 'verbose' "
               "messages generated by the script are logged, not stderr "
               "or other Python output"),
        Option("--generic-searchlight",
               action="store_true", dest="use_generic_searchlight",
               help="For GNB and M1NN ad-hoc efficient searchlight "
               "implementations are used by default.  This option "
               "would enforce use of a generic searchlight"),
        Option("--cache-filename",
               action="store", dest="cache_filename",
               help="Filename to store generated dataset and reload it if "
               " present instead of loading from original volumes and"
               " preprocessing"),
    ]
    parser.add_options(script_options)

    parsed_options, positional = parser.parse_args()
    # files first, then options: the order the rest of the script uses
    return positional, parsed_options


def get_files(filename_pattern):
    """Return the paths matching the glob ``filename_pattern``, sorted.

    An empty list is returned when nothing matches.
    """
    matches = glob.glob(filename_pattern)
    matches.sort()
    return matches


def validate_options(files, options):
    """Let's verify provided options and do some basic checks

    So we could fail earlier than later.  All detected problems are
    collected and reported at once on stderr before exiting.

    Parameters
    ----------
    files : list of str
      Positional arguments: samples (may be a glob pattern), the
      targets+blocks file, the mask and an optional output prefix.
      Modified in place: an empty output prefix ('') is appended when
      only 3 files were given.
    options
      Parsed command line options.

    Raises
    ------
    SystemExit
      With code 1 if any problem was found.
    """
    # Let's spit out all errors at once -- it is annoying to redo
    bad_options_msgs = []

    if options.baseline_conditions and len(options.baseline_conditions) > 1:
        bad_options_msgs.append(
            "--baseline-conditions must list only a single condition")

    if options.include_conditions and options.exclude_conditions:
        bad_options_msgs.append(
            "Specify either --exclude-conditions OR --include-conditions"
            " NOT BOTH")

    if len(files) not in (3, 4):
        bad_options_msgs.append("Please provide 3 or 4 files in the command line")
    else:
        # Samples may be given as a glob pattern -- it must match
        # something, otherwise loading would only fail much later
        dfiles = get_files(files[0])
        if not dfiles:
            bad_options_msgs.append(
                "No data volumes match %r" % files[0])
        else:
            # Quickly check provided files for obvious problems: loading
            # the headers fails early on missing/unreadable volumes
            try:
                import nibabel as nib
                for f in dfiles + [files[2]]:
                    nib.load(f).get_header()
            except Exception as e:
                bad_options_msgs.append(
                    "Failed to open input volumes. Error was: %s" % e)
            # TODO: compare shapes (e.g. data vs mask dimensionality) to
            #       catch mismatches before the actual loading

    if bad_options_msgs:
        sys.stderr.write("There were errors in options specification:"
                         + '\n E: '.join([''] + bad_options_msgs)
                         + '\n')
        raise SystemExit(1)

    # assure all "files" specs present
    #
    # Or may be we should just enforce suffix?  besides simple
    # classification there would be no use case where printing to the
    # screen would be the desired output format
    if len(files) < 4:
        files.append("")              # so output prefix was not specified


def run_analysis(files, options):
    """Load (or reload) the dataset, preprocess it and run the analysis.

    Parameters
    ----------
    files : list of str
      ``[samples, targets+blocks file, mask]`` optionally followed by an
      output prefix.  ``samples`` may be a glob pattern.  An empty or
      missing output prefix means that results are reported via
      ``verbose`` instead of being written to disk.
    options
      Parsed command line options (see ``parse_cmdline``).

    Raises
    ------
    ValueError
      If ``options.clf`` or ``options.analysis`` is not recognized.
    RuntimeError
      If both include and exclude conditions reach this point.
    """
    # Finally import the rest of the suite (deferred so that command line
    # parsing and validation stay quick)
    verbose(1, "Importing PyMVPA v. %s suite" % mvpa2.__version__)

    import numpy as np
    import mvpa2.suite as mv

    verbose(1, "Loading data")
    verbose(3, "Files:  %s" % '\n     '.join([''] + files))
    verbose(2, "Analysis options: %s" % options)

    # data filename (could be a glob pattern)
    dfile = files[0]
    # text file with targets and block definitions (chunks)
    cfile = files[1]
    # mask volume filename
    mfile = files[2]
    # output prefix: validate_options() stores '' when none was given
    ofile = files[3] if len(files) > 3 else ''

    if options.cache_filename and exists(options.cache_filename):
        # the cache holds the already preprocessed dataset
        verbose(2, "Reloading data from %s" % options.cache_filename)
        data = mv.h5load(options.cache_filename)
    else:
        # read conditions into an array (assumed to be two columns of integers)
        # TODO: We need some generic helper to read conditions stored in some
        #       common formats
        verbose(2, "Reading conditions from file %s" % cfile)
        with open(cfile) as cfile_handle:
            ncolumns = len(cfile_handle.readline().split())
        if ncolumns > 2:
            verbose(3, "Detected more than 2 columns, assuming present header")
            attrs = mv.ColumnData(cfile, header=True)
        else:
            attrs = mv.SampleAttributes(cfile, literallabels=True)

        dfiles = get_files(dfile)
        verbose(2, "Loading %d volume files %s" % (len(dfiles), ', '.join(dfiles)))
        data = mv.fmri_dataset(dfiles,
                               targets=attrs.targets,
                               chunks=attrs.chunks,
                               mask=mfile)

        # Assign possible additional attributes
        if len(attrs) > 2:
            for k, v in attrs.items():
                if k not in data.sa:
                    data.sa[k] = v

        verbose(1, "Preprocessing")

        # First z-score since baseline condition might be removed
        # later on in exclude/include_conditions handling
        if options.zscore:
            verbose(2, "Zscoring data samples")
            # validate_options() allows at most a single baseline
            # condition; the option may also be unset (None) or empty
            if options.baseline_conditions:
                param_est = options.baseline_conditions[0]
            else:
                param_est = None

            # TODO: verify if enforcing dtype here is needed/desired
            mv.zscore(data, chunks_attr=options.chunks_sa,
                      param_est=param_est,
                      dtype='float32')

        # Exclude or keep only some conditions if requested
        # TODO: this is like we used to have 'select()' -- so we just
        #       need to move this functionality back and replace here
        #       with 1,2 lines
        if options.exclude_conditions or options.include_conditions:
            include = bool(options.include_conditions)
            # well we had a check that not both specified but never
            # hurts to double check.
            #
            # TODO: think may be we should allow both with first
            # treating includes and then excludes, e.g. include
            # face/house but exclude chunks > 5
            if include and options.exclude_conditions:
                raise RuntimeError("Somehow our cmdline args check failed "
                                   "and we got both exclude and include "
                                   "conditions here where it shouldn't "
                                   "happen")
            if include:
                conditions_ = options.include_conditions
            else:
                conditions_ = options.exclude_conditions

            # resultant mask matching specified conditions: per each sa
            # we collect the union of its values and then intersect
            # across all listed sa's
            bmask = np.ones(len(data), dtype=bool)
            for sa, values in conditions_:
                bmask_sa = np.zeros(len(data), dtype=bool)
                sa_type = data.sa[sa].value.dtype.type
                for v in values:
                    # from the cmdline everything is a string, so convert
                    # v to the dtype stored in ds.sa[sa]
                    bmask_v = data.sa[sa].value == sa_type(v)
                    if not np.sum(bmask_v):
                        verbose(1, "Listed value %r was not found in .sa.%s"
                                % (v, sa))
                        continue
                    bmask_sa |= bmask_v
                bmask &= bmask_sa

            if include:
                verbose(2, "Including %d samples matching the include_conditions=%s"
                        % (np.sum(bmask), options.include_conditions))
                data = data[bmask]
            else:
                verbose(2, "Excluding %d samples matching the exclude_conditions=%s"
                        % (np.sum(bmask), options.exclude_conditions))
                data = data[~bmask]

        if options.mean_group_sample:
            mgs_args = (options.chunks_sa, options.targets_sa)
            verbose(2, "Computing mean sample per each %s/%s" % mgs_args)
            # TODO: verify if enforcing dtype here is needed/desired
            data = data.get_mapped(mv.mean_group_sample(mgs_args))

        if options.cache_filename:
            verbose(2, "Storing dataset into %s" % options.cache_filename)
            mv.h5save(options.cache_filename, data)

    if verbose.level > 2 or options.print_dataset_summary:
        verbose(1, data.summary())

    if options.analysis == 'preprocess':
        # We are done
        return

    verbose(1, "Creating analysis pipeline")

    # TODO: may be we just need --clf-expr which gets evaluated
    #       and thus would be the most flexible and may be even
    #       simplest way to specify arbitrary classifier

    # TODO: cmdline for other types
    partitioner = mv.NFoldPartitioner(cvtype=options.crossfolddegree)

    # Choice of the learner
    if options.clf == 'm1nn':
        # TODO: options for parameters
        clf = mv.kNN(1) # M1NN()
    elif options.clf == 'gnb':
        # TODO: options for parameters
        clf = mv.GNB()
    elif options.clf == 'knn':
        clf = mv.kNN(k=options.knearestdegree)
    elif options.clf == 'lin_nu_svmc':
        clf = mv.LinearNuSVMC(nu=options.svm_nu)
    elif options.clf in ['lin_C_svmc', 'svm']:
        clf = mv.LinearCSVMC(C=options.svm_C)
    elif options.clf == 'rbf_nu_svmc':
        clf = mv.RbfNuSVMC(nu=options.svm_nu)
    else:
        raise ValueError('Unknown classifier type: %s' % repr(options.clf))

    if options.targets_sa != 'targets':
        verbose(3, "Assigning space=%r for %s to operate on"
                % (options.targets_sa, clf))
        clf.space = options.targets_sa

    verbose(2, "Using '%s' classifier" % options.clf)

    verbose(3, "Assigning a measure to be CrossValidation")
    # cross-validation of the selected classifier (in each sphere when
    # used within a searchlight)
    cv = mv.CrossValidation(clf, partitioner)

    #
    # Define the final measure to estimate
    #
    if options.analysis == 'searchlight':

        # Keyword arguments for the Searchlights
        slkwargs = dict(
            #TODO roi_ids=center_ids,
            #TODO nproc=8,
            #TODO errorfx=mean_mismatch_error
            #TODO enable_ca=['null_t'] etc
            )

        if not options.use_generic_searchlight \
               and options.clf in ['gnb', 'm1nn']:
            # There is an ad-hoc fast implementation.  Explicit if instead
            # of a dictionary lookup since M1NNSearchlight might not be
            # available yet
            if options.clf == 'gnb':
                SearchlightClass = mv.GNBSearchlight
            else:
                # must be m1nn since above we have only 2 choices
                SearchlightClass = mv.M1NNSearchlight

            # They require partitioner to be specified
            slargs = (clf, partitioner)
        else:
            SearchlightClass = mv.Searchlight
            # We just need to provide the measure, in this case cv
            slargs = (cv,)

        # Add the query engine
        # TODO: handle ER datasets
        # The sphere is defined over voxel_indices, so presumably
        # options.radius is in voxels rather than mm -- confirm
        slargs = slargs + (
            mv.IndexQueryEngine(voxel_indices=mv.Sphere(options.radius)),)

        verbose(3, "Generating a Searchlight instance")
        measure = SearchlightClass(*slargs, **slkwargs)

    elif options.analysis == 'crossvalidation':
        measure = cv

    else:
        raise ValueError("Unknown type of analysis %s" % options.analysis)

    #
    # Finally run
    #
    verbose(1, "Estimating the measure on loaded data")
    verbose(2, "Measure: %r" % measure)

    t0 = time.time()
    results = measure(data)
    # NOTE: the key spelling is kept as-is since stored results may be
    #       read elsewhere under this name
    results.a['proccessing_time'] = time.time() - t0

    if ofile:
        verbose(1, "Storing results using prefix %r" % ofile)
        verbose(2, "Histogram of results: %s"
                % (np.histogram(results.samples),))

        # We could "always" save to HDF5, which also would assure
        # that we have output directory
        mv.h5save(ofile + 'results.hdf5', results, mkdir=True)

        # TODO: some might not be mapped to Nifti that easily
        # map the result vector back into a nifti image
        rimg = mv.map2nifti(data, results)
        # save to file
        rimg.to_filename(ofile + 'results.nii.gz')

    else:
        # no output prefix given -- just report the results
        verbose(1, "Results:\n%r" % results)



def main(files=None, options=None):
    """Parse/validate arguments and run the analysis.

    Wrapped into a function call for easy profiling later on.

    Parameters
    ----------
    files : list of str, optional
      Positional arguments as returned by ``parse_cmdline``.
    options : optional
      Parsed options as returned by ``parse_cmdline``.

    Either both ``files`` and ``options`` must be given, or neither of
    them, in which case they are parsed from the command line.

    Raises
    ------
    ValueError
      If only one of ``files`` and ``options`` is given.
    """
    if files is None and options is None:
        files, options = parse_cmdline()
    elif files is None or options is None:
        raise ValueError("Please specify both files and options, "
                         "or none of them, so they are parsed from "
                         "the command line")

    validate_options(files, options)

    # --log-output: mirror verbose output into a timestamped log file
    # named after the output prefix (files[3], '' when none was given)
    if options.log_output:
        logfile = files[3] + '%s.log' \
                  % time.strftime('%Y%m%d:%H%M',
                                  time.localtime(time.time()))
        verbose(2, "Logging to %s" % logfile)
        verbose.handlers = [ sys.stdout, open(logfile, 'a') ]

    run_analysis(files, options)



if __name__ == "__main__":
    # Run only when executed as a script; importing the module does not
    # trigger any parsing or analysis, so main() can be called directly.
    main()
