#!/usr/bin/env python3
#
# Copyright 2021-2023, Julian Catchen <jcatchen@illinois.edu>
#
# This file is part of Stacks.
#
# Stacks is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Stacks is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Stacks.  If not, see <http://www.gnu.org/licenses/>.
#

import argparse
import os, sys
import gzip

#
# Global configuration variables.
#
# Defaults below are overwritten by parse_command_line().
in_path  = ""     # Input log path; when empty, main() reads the log from stdin.
out_path = ""     # Output path; when empty, main() writes to stdout.
section  = ""     # Name of the section to extract; when empty, section names are listed.
pretty   = False  # When True, the extracted section is passed through pretty_print().

def parse_command_line():
    """Parse the command line into the module-level configuration globals.

    Options set ``pretty``, ``section`` and ``out_path``. Anything argparse
    does not recognise is treated as positional: the first such argument is
    the input log path and the second, if present, is the section name. An
    explicit ``--section`` takes precedence over a positional section.

    If an input path was given but does not exist, an error message and the
    help text are written to stderr and the process exits with status 1.
    """
    global in_path
    global out_path
    global section
    global pretty

    usage = '''
    %(prog)s logfile [section]
    %(prog)s [--pretty] [--out-path path] logfile [section]
    cat logfile | %(prog)s [--pretty] [--section section]'''

    desc = ("Export a particular section of a Stacks log or distribs file. If you "
            "supply a log path alone, %(prog)s will print the available sections to output. "
            "The log file can also be supplied via stdin.")

    p = argparse.ArgumentParser(description=desc, usage=usage)

    #
    # Add options.
    #
    p.add_argument("-p", "--pretty", action="store_true",
                   help="Output data as a table with columns lined up.")
    p.add_argument("-o", "--out-path", type=str, metavar='path',
                   help="Path to output file.")
    p.add_argument("-s", "--section", type=str, metavar='section',
                   help="Name of section to output from the log file.")

    #
    # Parse the command line. parse_known_args() is used so that the log path
    # and section can be given as bare positional arguments.
    #
    args, posargs = p.parse_known_args()

    # store_true always yields a bool, so this can be assigned unconditionally.
    pretty = args.pretty
    if args.section is not None:
        section = args.section
    if args.out_path is not None:
        out_path = args.out_path

    if posargs:
        in_path = posargs[0]
    if len(posargs) > 1 and args.section is None:
        section = posargs[1]

    #
    # Test input file path.
    #
    if in_path and not os.path.exists(in_path):
        print("Input log file path does not exist ('{}')".format(in_path), file=sys.stderr)
        # Keep the help text off stdout (which may be piped as data) and
        # report the failure to the caller through a non-zero exit status.
        p.print_help(sys.stderr)
        sys.exit(1)


def scan_log_file(fh):
    """Return the names of every section announced in an open log file.

    A section is announced by a line that starts with "BEGIN "; the name is
    the remainder of that line without its newline. Names are returned in
    file order. The handle is rewound to the start before returning so the
    caller can read the file again.
    """
    names = [row[6:].strip("\n") for row in fh if row[:6] == "BEGIN "]

    # Rewind for the caller's second pass over the file.
    fh.seek(0)

    return names


def find_section(sections, section):
    """Resolve a possibly abbreviated section name against the known sections.

    Args:
        sections: Section names found in the log file.
        section:  Name requested by the user; may be a full name or a prefix.

    Returns:
        The full section name when `section` matches exactly, or when it is a
        prefix of exactly one known section. Otherwise an empty string, after
        a "not found" or "ambiguous" message has been written to stderr.
    """
    # An exact match always wins, even if it is also a prefix of other names.
    if section in sections:
        return section

    # dict.fromkeys() drops duplicate names while keeping first-seen order,
    # so a repeated section is not reported as ambiguous with itself.
    candidates = [name for name in dict.fromkeys(sections) if name.startswith(section)]

    if not candidates:
        print("Section '{}' was not found.".format(section), file=sys.stderr)
        return ""

    if len(candidates) == 1:
        return candidates[0]

    print("Section '{}' is ambiguous and could refer to more than one section of the log:".format(section), file=sys.stderr)
    for name in candidates:
        print("    " + name, file=sys.stderr)

    return ""


def pretty_print(stream):
    """Format tab-separated lines as a table with aligned columns.

    Args:
        stream: Iterable of lines without trailing newlines. It is copied
                into a list first, so one-shot iterators are accepted.

    Returns:
        A single string. Each data line has every column left-justified to
        the widest value seen in that column and followed by two spaces.
        Lines starting with '#' and empty lines are passed through unchanged
        and do not influence the column widths.
    """
    # Two passes are needed (measure, then format), so materialise the input.
    rows = list(stream)

    widths = []
    for row in rows:
        # startswith() is safe on empty lines, unlike indexing row[0].
        if not row or row.startswith('#'):
            continue

        for i, col in enumerate(row.split('\t')):
            if i == len(widths):
                widths.append(len(col))
            elif widths[i] < len(col):
                widths[i] = len(col)

    out = []
    for row in rows:
        if not row or row.startswith('#'):
            out.append(row + "\n")
            continue

        cells = [col.ljust(widths[i]) + "  " for i, col in enumerate(row.split('\t'))]
        out.append("".join(cells) + "\n")

    return "".join(out)


def output_section(in_fh, out_fh, sections, section):
    """Copy the body of one named section from in_fh to out_fh.

    The body is every line strictly between the "BEGIN <section>" and
    "END <section>" marker lines. It is written either verbatim, or through
    pretty_print() when the global `pretty` flag is set. Nothing is written
    if either marker is missing. The `sections` argument is not used.
    """
    begin_marker = "BEGIN " + section
    end_marker = "END " + section

    # A single shared generator lets the two loops below consume the file
    # one after the other.
    lines = (raw.strip("\n") for raw in in_fh)

    # Phase 1: skip everything up to and including the BEGIN marker.
    for text in lines:
        if text == begin_marker:
            break
    else:
        return

    # Phase 2: gather the body and emit it once the END marker is reached.
    body = []
    for text in lines:
        if text == end_marker:
            if pretty:
                out_fh.write(pretty_print(body))
            else:
                out_fh.write("\n".join(body) + "\n")
            return
        body.append(text)


def stream_section(in_fh, out_fh, section):
    """Process a log in a single pass, for input that cannot be rewound.

    Args:
        in_fh:   Iterable of log lines.
        out_fh:  Writable text handle that receives all regular output.
        section: Exact section name to extract. When empty, the names of all
                 sections in the log are written to out_fh, one per line.

    When a section name is given, the lines strictly between its
    "BEGIN <section>" and "END <section>" markers are written to out_fh,
    through pretty_print() if the global `pretty` flag is set. If the BEGIN
    marker never appears, a "not found" message is written to stderr. If the
    END marker never appears, nothing is written.
    """
    listing = len(section) == 0
    begin_marker = "BEGIN " + section
    end_marker = "END " + section

    in_section = False
    body = []

    for line in in_fh:
        line = line.strip("\n")

        if listing:
            # Write to out_fh rather than print() so that the listing honours
            # the chosen output destination.
            if line[0:6] == "BEGIN ":
                out_fh.write(line[6:] + "\n")
            continue

        if not in_section:
            if line == begin_marker:
                in_section = True
            continue

        if line == end_marker:
            if pretty:
                out_fh.write(pretty_print(body))
            else:
                out_fh.write("\n".join(body) + "\n")
            return

        body.append(line)

    if not listing and not in_section:
        print("Section '{}' was not found.".format(section), file=sys.stderr)


def main():
    """Entry point: parse arguments, open input and output, run the extraction.

    Input comes from the file named on the command line (transparently
    decompressed when its name ends in ".gz") or from stdin when no path is
    given. Output goes to --out-path when supplied, otherwise to stdout.

    Stdin cannot be rewound, so it is handled in one pass by
    stream_section(). A file is scanned for section names first; with no
    section requested the names are listed, otherwise the request is resolved
    with find_section() and that section is written out.
    """
    parse_command_line()

    in_fh  = sys.stdin
    out_fh = sys.stdout

    try:
        #
        # Open input log file, if no file path is given, assume stdin.
        #
        if len(in_path) > 0:
            if in_path.endswith(".gz"):
                # Text mode is required: the section markers are compared
                # against str values, which never equal the bytes that
                # binary mode would return.
                in_fh = gzip.open(in_path, 'rt')
            else:
                in_fh = open(in_path)

        #
        # Set output file handle. It must be opened for writing.
        #
        if len(out_path) > 0:
            out_fh = open(out_path, 'w')

        if in_fh is sys.stdin:
            stream_section(in_fh, out_fh, section)
            return

        #
        # Scan the log file for section headings
        #
        sections = scan_log_file(in_fh)

        if len(sections) == 0:
            # Diagnostic, not data: keep it out of the regular output stream.
            print("No printable sections in file, '{}'.".format(in_path), file=sys.stderr)
            return

        if len(section) == 0:
            for s in sections:
                out_fh.write(s + "\n")
            return

        #
        # If a section was supplied, but it is ambiguous, try to match it.
        #
        found_section = find_section(sections, section)

        if len(found_section) == 0:
            return

        output_section(in_fh, out_fh, sections, found_section)

    finally:
        # Close only the handles opened here; leave the standard streams alone.
        if in_fh is not sys.stdin:
            in_fh.close()
        if out_fh is not sys.stdout:
            out_fh.close()


#                                                                              #
#------------------------------------------------------------------------------#
#                                                                              #
# Run the extractor only when this file is executed as a script, so that
# importing it as a module has no side effects.
if __name__ == "__main__":
    main()
