#!/usr/bin/env python3
#
# Copyright 2021 VMware, Inc.
# SPDX-License-Identifier: Apache-2.0
#
#------------------------------------------------------------------------
# The gpmemreport utility interprets the gpmemwatcher output.
#
# It will identify postgres processes in the following states:
# - postgres processes in <defunct> state (zombies)
# - postgres processes in sleep state - uninterruptible sleep (usually IO)
# - postgres processes with greater than 50% system CPU usage
# - postgres processes with greater than 5% system memory usage
#
# Segment summaries will include:
#
#   Total (MB)
#   Session (MB)
#   Background (MB)
#   # Active
#   # Idle
#
# Reference this Knowledge Base article for usage information
# https://community.pivotal.io/s/article/how-to-interpret-the-memwatcher-output-memreport?language=en_US
#------------------------------------------------------------------------

import gzip
import itertools
import os
import sys
from optparse import OptionParser
import re
import string
from collections import defaultdict
from datetime import datetime
from gppylib.commands import gp

# strptime pattern for the '>>>...<<<' line that getNextSample expects at the
# start of every sample (two-digit year, colon-separated fields, trailing newline).
timestamp_string = '>>>%y:%m:%d:%H:%M:%S<<<\n'
# strptime pattern applied to the values of the -s/--start and -e/--end options.
datetime_format = '%Y-%m-%d %H:%M:%S'
# Version text for --version: the GpVersion.local() output with its first three
# whitespace-separated words dropped.
# NOTE(review): presumably those leading words are a fixed banner - confirm against gppylib.
gpverstr = " ".join(gp.GpVersion.local(None, os.getenv('GPHOME', None)).split()[3:])

class ProcessListing(object):
    """One row of ``ps`` output; base class for all classified process kinds.

    Every instance registers itself in ``all_procs`` under its PID, so the
    class doubles as the index of the sample currently being processed.

    The constructor needs the keys 'PID', 'PPID', 'RSS' and 'COMMAND' in
    ``ps_dictionary``; the node summary additionally reads 'STAT', '%CPU',
    '%MEM', 'TIME' and 'WCHAN'.  All values are kept as the strings they
    were parsed as.
    """

    # PID (string, as read from ps) -> listing, for the current sample.
    all_procs = dict()

    def __init__(self, ps_dictionary):
        self.pid = ps_dictionary['PID']
        self.ppid = ps_dictionary['PPID']
        self.res_size = ps_dictionary['RSS']
        self.command = ps_dictionary['COMMAND']
        self._psdict = ps_dictionary
        ProcessListing.all_procs[self.pid] = self

    @classmethod
    def process(cls, ps_dictionary):
        """Build the listing subclass that matches the row's COMMAND.

        * first word contains "postgres" and the command starts with
          "postgres:" -> GPDBQueryProcess when a ``con<N>`` token is present,
          otherwise GPDBBackgroundProcess;
        * first word contains "postgres" and the command contains both "-D"
          and "-p" -> PostmasterProcessListing;
        * anything else -> plain ProcessListing.
        """
        cmd = ps_dictionary['COMMAND']
        if re.match(r'^\S*postgres\S*\s', cmd) is not None:
            if cmd.startswith('postgres:'):
                if re.search(r'con\d+', cmd):
                    return GPDBQueryProcess(ps_dictionary)
                return GPDBBackgroundProcess(ps_dictionary)
            if '-D' in cmd and '-p' in cmd:
                return PostmasterProcessListing(ps_dictionary)
        return ProcessListing(ps_dictionary)

    @classmethod
    def reset(cls):
        """Forget the current sample, here and in every direct subclass."""
        cls.all_procs = dict()
        for subcls in cls.__subclasses__():
            subcls.reset()

    @classmethod
    def report(cls, rpt_file):
        """Write the NODE SUMMARY section to the open text file ``rpt_file``."""
        print('NODE SUMMARY', file=rpt_file)
        print('============\n', file=rpt_file)
        for fmt, data_rows in cls._reportRows():
            for data in data_rows:
                print(fmt.format(*data), file=rpt_file)

    @staticmethod
    def _troubleSection(heading, fmt, header, data_rows):
        """Return the two report entries (heading, then table) for one section."""
        return [('\n\n{0}\n', [(heading,)]), (fmt, [header] + data_rows)]

    @classmethod
    def _reportRows(cls):
        """Return the node summary as a list of ``(format, rows)`` pairs.

        Each row is later expanded with ``format.format(*row)``.  The totals
        always come first; a section for zombies, I/O sleepers, high-CPU and
        high-memory processes follows only when at least one process matches.
        """
        procs = list(cls.all_procs.values())
        report = [
            ('{0:<50}{1:<}', [('Total Number of Processes:', len(procs))]),
            # RSS is divided by 1024 and labelled MB, i.e. it is taken to be in KB.
            ('{0:<50}{1:<10.4f}',
             [('Total Memory used (MB)', sum(float(p.res_size) for p in procs) / 1024)]),
        ]

        # Now let's see if we have any zombie or uninterruptible sleep
        zombies = [p for p in procs if p.isZombie()]
        if zombies:
            report.extend(cls._troubleSection(
                'Zombie processes\n----------------',
                '{0:<10}{1:<}',
                ['PID', 'COMMAND'],
                [(p.pid, p.command.rstrip()) for p in zombies]))

        sleepers = [p for p in procs if p.isIOSleep()]
        if sleepers:
            report.extend(cls._troubleSection(
                'I/O sleep processes\n-------------------',
                '{0:<10}{1:>10}  {2:<10} {3:<}',
                ['PID', 'TIME', 'WCHAN', 'COMMAND'],
                [(p.pid, p.time(), p.wchan(), p.command.rstrip()) for p in sleepers]))

        # Now for high CPU processes (the test is inclusive: %CPU >= 50)
        cpu_hogs = [p for p in procs if p.pcpu() >= 50]
        if cpu_hogs:
            report.extend(cls._troubleSection(
                'Processes with CPU > 50%\n------------------------------',
                '{0:<10}{1:<6}{2:<}',
                ['PID', '%CPU', 'COMMAND'],
                [(p.pid, p.pcpu(), p.command.rstrip()) for p in cpu_hogs]))

        # and finally where %MEM is high (again inclusive: %MEM >= 5)
        mem_hogs = [p for p in procs if p.pmem() >= 5]
        if mem_hogs:
            report.extend(cls._troubleSection(
                'Processes with %MEM > 5%\n--------------------',
                '{0:<10}{1:<6}{2:<}',
                # str.format has no '%' escaping, so the header is a plain '%MEM'.
                ['PID', '%MEM', 'COMMAND'],
                [(p.pid, p.pmem(), p.command.rstrip()) for p in mem_hogs]))
        return report

    def time(self):
        """Return the TIME column as parsed."""
        return self._psdict['TIME']

    def wchan(self):
        """Return the WCHAN column as parsed."""
        return self._psdict['WCHAN']

    def pmem(self):
        """Return the %MEM column as a float."""
        return float(self._psdict['%MEM'])

    def pcpu(self):
        """Return the %CPU column as a float."""
        return float(self._psdict['%CPU'])

    def isZombie(self):
        """Return True when the STAT column contains 'Z'."""
        return 'Z' in self._psdict['STAT']

    def isIOSleep(self):
        """Return True when the STAT column contains 'D'."""
        return 'D' in self._psdict['STAT']

    def procType(self):
        """Return the type tag for a process that is not a GPDB process."""
        return "OS"


class PostmasterProcessListing(ProcessListing):
    """A postmaster process together with the query and background
    processes that have been attributed to it.

    A postmaster is identified by the number that follows ``-p`` on its
    command line; that number is stored as ``segID`` and used as the key of
    ``segments`` and as the column label ("seg<N>") in the report.
    """

    # segID (int, parsed from "-p") -> postmaster listing, for the current sample.
    segments = dict()
    # Parent PID -> children handed to addChildProcess whose parent is not
    # (yet) a known postmaster.  Nothing in this class drains it.
    pending_procs = defaultdict(list)

    # Captures the 4-6 digit number after the last " -p " in the command line.
    _PORT_PATTERN = re.compile(r'(^.*\s-p\s)(\d{4,6})(.*$)')

    def __init__(self, ps_dictionary):
        """Register the postmaster under the port found in its command line.

        Raises:
            ValueError: the command contains no ``-p <4-6 digits>`` argument.
        """
        super(PostmasterProcessListing, self).__init__(ps_dictionary)
        self.memTotal = self.res_size
        self.bgProcs = []
        # session key ("con<N>:<cmd>") -> list of query processes
        self.queries = defaultdict(list)
        match = self._PORT_PATTERN.search(ps_dictionary['COMMAND'])
        if match is None:
            # Fail with the offending command instead of an AttributeError on None.
            raise ValueError(
                'Cannot determine port (-p) for postmaster command: {0!r}'.format(
                    ps_dictionary['COMMAND']))
        self.segID = int(match.group(2))

        PostmasterProcessListing.segments[self.segID] = self

    @classmethod
    def reset(cls):
        """Drop all segments and pending children of the current sample."""
        cls.segments = dict()
        cls.pending_procs = defaultdict(list)

    @classmethod
    def addChildProcess(cls, procListing):
        """Attach ``procListing`` to its parent postmaster, or park it.

        The parent is looked up by PID in the process index.  ``segments``
        is keyed by port number, so it cannot be used to resolve a parent PID.
        """
        parent = ProcessListing.all_procs.get(procListing.ppid)
        if isinstance(parent, PostmasterProcessListing):
            parent._addChildProcess(procListing)
        else:
            cls.pending_procs[procListing.ppid].append(procListing)

    @classmethod
    def report(cls, rpt_file):
        """Write the SEGMENT SUMMARY and the per-session memory table.

        Query and background processes are linked to their postmasters first,
        so this is expected to run once per sample.
        """
        GPDBQueryProcess.linkWithSegmentProcsess()
        GPDBBackgroundProcess.linkWithSegmentProcsess()

        print('\n\nSEGMENT SUMMARY', file=rpt_file)
        print('===============\n', file=rpt_file)

        segs = sorted(PostmasterProcessListing.segments.values(), key=lambda seg: seg.segID)

        report_rows = [seg._reportRow() for seg in segs]
        print('{0:<15}{1:>15}{2:>15}{3:>20}{4:>15}{5:>15}'.format('Segment', 'Total (MB)', 'Session (MB)',
                                                                  'Background (MB)', '# Active', '# Idle'),
              file=rpt_file)

        row_fmt = '{0:<15}{1:>15.2f}{2:>15.2f}{3:>20.2f}{4:>15}{5:>15}'
        for row in report_rows:
            print(row_fmt.format(*row), file=rpt_file)

        # session key -> memory per segment, one slot per entry of segs
        query_values = {}
        for index, seg in enumerate(segs):
            for sess, mem in seg._getQueryReport():
                query_values.setdefault(sess, [0.0] * len(segs))[index] += mem

        print('\n\n', file=rpt_file)
        header_fmt = '{0:20}' + ''.join('{%d:>10}' % i for i in range(1, len(segs) + 1))
        print(header_fmt.format('', *['seg' + str(seg.segID) for seg in segs]), file=rpt_file)
        print('-' * (30 + (10 * len(segs))), file=rpt_file)
        # One extra numeric column at the end holds the total across segments.
        value_fmt = '{0:20}' + ''.join('{%d:10.2f}' % i for i in range(1, len(segs) + 2))
        for sess, totals in query_values.items():
            print(value_fmt.format(*([sess] + totals + [sum(totals)])), file=rpt_file)

    def _getQueryReport(self):
        """Return ``(session key, RSS / 1024)`` for every session on this segment."""
        return [(sess, sum(float(p.res_size) for p in procs) / 1024)
                for sess, procs in self.queries.items()]

    def _addChildProcess(self, procListing):
        """Dispatch a child on its ``procType()``; unknown types are printed."""
        kind = procListing.procType()
        if kind == 'QUERY':
            self.addQueryProcess(procListing)
        elif kind == 'BG':
            self.addBGProcess(procListing)
        else:
            print(kind)

    def addQueryProcess(self, qProc):
        """File a query process under its session key."""
        self.queries[qProc.sessionKey()].append(qProc)

    def addBGProcess(self, bProc):
        """Record a background process of this postmaster."""
        self.bgProcs.append(bProc)

    def _reportRow(self):
        """Return ``(segID, total, session, background, #active, #idle)``.

        The three memory figures are RSS sums divided by 1024; the total also
        includes the postmaster's own RSS.
        """
        all_queries = list(itertools.chain.from_iterable(self.queries.values()))
        sess_mem = sum(float(q.res_size) for q in all_queries) / 1024
        bg_mem = float(sum(float(b.res_size) for b in self.bgProcs) / 1024)
        total = float(self.res_size) / 1024 + sess_mem + bg_mem
        num_idle = sum(1 for q in all_queries if q.isIdle())
        num_active = len(all_queries) - num_idle

        return (self.segID, total, sess_mem, bg_mem, num_active, num_idle)


class GPDBQueryProcess(ProcessListing):
    """A ``postgres:`` process whose command line carries a ``con<N>`` token.

    ``session`` is that token.  ``pgcmd`` is the ``cmd<N>`` token when
    present, 'dtx' when the command contains "Distributed", and 'idle'
    otherwise.  ``segment`` stays None until linkWithSegmentProcsess has
    found the owning postmaster.
    """

    # Every query process created since the last reset().
    _myprocs = []

    _SESSION_PATTERN = re.compile(r'con\d+')
    _COMMAND_PATTERN = re.compile(r'cmd\d+')

    def __init__(self, ps_dictionary):
        if 'COMMAND' not in ps_dictionary:
            # Show the malformed row before the base constructor raises KeyError.
            print(ps_dictionary)
        super(GPDBQueryProcess, self).__init__(ps_dictionary)
        command = ps_dictionary['COMMAND']
        self.session = self._SESSION_PATTERN.search(command).group()
        self.segment = None
        cmd_val = self._COMMAND_PATTERN.search(command)
        if cmd_val:
            self.pgcmd = cmd_val.group()
        elif 'Distributed' in command:
            self.pgcmd = 'dtx'
        else:
            self.pgcmd = 'idle'
        GPDBSession.addQuery(self)
        GPDBQueryProcess._myprocs.append(self)

    @classmethod
    def reset(cls):
        """Forget all query processes of the current sample."""
        cls._myprocs = []

    @staticmethod
    def _findPostmaster(proc):
        """Walk the parent chain of ``proc`` up to a postmaster.

        Returns None when a parent is missing from the process index or when
        the chain loops back on itself (for example a listed process that is
        its own parent), so the walk always terminates.
        """
        visited = set()
        while not isinstance(proc, PostmasterProcessListing):
            if proc.pid in visited:
                return None
            visited.add(proc.pid)
            proc = ProcessListing.all_procs.get(proc.ppid)
            if proc is None:
                return None
        return proc

    @classmethod
    def linkWithSegmentProcsess(cls):
        """Attach every query process to its postmaster and set ``segment``.

        Processes for which no postmaster can be found are reported on
        stdout and skipped.
        """
        for p in cls._myprocs:
            postmaster = cls._findPostmaster(p)
            if postmaster is None:
                print(
                    "Unable to find postmaster proces for line:\n\tPID: {0}\n\tPPID: {1}\n\tCMD: {2} ".format(
                        p.pid, p.ppid, p.command))
                continue
            postmaster.addQueryProcess(p)
            p.segment = postmaster.segID

    def sessionKey(self):
        """Return "<session>:<pgcmd>", the key used to group query processes."""
        return ':'.join([self.session, self.pgcmd])

    def procType(self):
        """Return the type tag for query processes."""
        return 'QUERY'

    def isIdle(self):
        """Return True when no cmd<N> or "Distributed" marker was found."""
        return self.pgcmd == 'idle'


class GPDBBackgroundProcess(ProcessListing):
    """A ``postgres:`` process without a ``con<N>`` token in its command line."""

    # Every background process created since the last reset().
    _myprocs = []

    def __init__(self, ps_dictionary):
        super(GPDBBackgroundProcess, self).__init__(ps_dictionary)
        GPDBBackgroundProcess._myprocs.append(self)

    @classmethod
    def reset(cls):
        """Forget all background processes of the current sample."""
        cls._myprocs = []

    @staticmethod
    def _findPostmaster(proc):
        """Walk the parent chain of ``proc`` up to a postmaster.

        Returns None when a parent is missing from the process index or when
        the chain loops back on itself (for example a listed process that is
        its own parent), so the walk always terminates.
        """
        visited = set()
        while not isinstance(proc, PostmasterProcessListing):
            if proc.pid in visited:
                return None
            visited.add(proc.pid)
            proc = ProcessListing.all_procs.get(proc.ppid)
            if proc is None:
                return None
        return proc

    @classmethod
    def linkWithSegmentProcsess(cls):
        """Attach every background process to its postmaster.

        Processes for which no postmaster can be found are reported on
        stdout and skipped.
        """
        for p in cls._myprocs:
            postmaster = cls._findPostmaster(p)
            if postmaster is None:
                print(
                    "Unable to find postmaster proces for line:\n\tPID: {0}\n\tPPID: {1}\n\tCMD: {2}".format(
                        p.pid, p.ppid, p.command))
                continue
            postmaster.addBGProcess(p)

    def procType(self):
        """Return the type tag for background processes."""
        return 'BG'


class GPDBSession(object):
    """All query processes that share one session key, across segments.

    Sessions are created on demand by addQuery and indexed by session key
    in the class-level ``sessions`` dictionary.
    """

    # session key -> GPDBSession
    sessions = dict()

    def __init__(self, sess_id):
        self.sessionID = sess_id
        # Keyed by the process's ``segment`` at registration time; that value
        # may still be None then, so reports regroup by the current value.
        self.queryProcesses = defaultdict(list)

    @classmethod
    def reset(cls):
        """Forget every session."""
        cls.sessions = dict()

    @classmethod
    def addQuery(cls, queryProcess):
        """Register ``queryProcess`` with the session named by its sessionKey()."""
        key = queryProcess.sessionKey()
        sess = cls.sessions.get(key)
        if sess is None:
            sess = GPDBSession(key)
            cls.sessions[key] = sess
        sess._adddQueryProcess(queryProcess)

    @classmethod
    def report(cls):
        """Print the session keys, then one memory block per session, to stdout."""
        print("\n\nGPDB SESSION REPORT")
        print(list(cls.sessions.keys()))
        for k in sorted(cls.sessions.keys()):
            cls.sessions[k]._report()

    def _adddQueryProcess(self, queryProcess):
        self.queryProcesses[queryProcess.segment].append(queryProcess)

    def _report(self):
        """Print the total and per-segment RSS of this session.

        Grouping uses each process's current ``segment`` rather than the
        bucket it was stored in, because the segment is normally assigned
        only after the process has been registered here.
        """
        seg_mem = defaultdict(int)
        for procs in self.queryProcesses.values():
            for p in procs:
                seg_mem[p.segment] += int(p.res_size)
        total_mem = sum(seg_mem.values())

        print("Memory usage for " + self.sessionID)
        print("    Total Memory: " + str(total_mem))
        for key, mem in seg_mem.items():
            print("        seg %s: %d" % (key, mem))


def parsePSOutput(psLines):
    """Turn ``ps`` output lines into one dictionary per process row.

    The first line supplies the column names.  Each following line is split
    on whitespace into at most that many fields, so the last column keeps
    its embedded spaces.  A line with fewer fields than the header yields a
    dictionary with only the leading columns.

    Args:
        psLines: sequence of text lines, header first.

    Returns:
        A list of ``{column name: value}`` dictionaries; empty when
        ``psLines`` is empty or holds only the header.
    """
    if not psLines:
        # Nothing to parse - not even a header line.
        return []
    header = psLines[0].split()
    max_splits = len(header) - 1
    return [dict(zip(header, line.split(None, max_splits))) for line in psLines[1:]]


def getNextSample(file, start, end):
    """Read the next sample from an open binary gpmemwatcher stream.

    A sample is a timestamp line, a header line with the column names and
    one line per process, closed by an empty line or the end of the file.

    Args:
        file: binary file object positioned at a sample boundary.
        start: earliest timestamp (datetime) to keep, or None for no limit.
        end: latest timestamp (datetime) to keep, or None for no limit.

    Returns:
        ``(timestamp, rows)``, where ``rows`` is a list of
        ``{column name: value}`` dictionaries, left empty when the timestamp
        is outside [start, end]; or None once only blank lines, or nothing,
        remain in the file.

    Raises:
        Exception: the next non-blank line does not start with '>>>'.
        ValueError: that line does not match the timestamp format.
    """
    def read_text_line():
        # Process command lines may hold bytes that are not valid UTF-8;
        # substitute them rather than abort the whole report.
        return file.readline().decode('utf-8', errors='replace')

    line = read_text_line()
    while line == '\n':
        line = read_text_line()
    if not line:
        # End of file, possibly after trailing blank lines.
        return None
    print(line)
    if not line.startswith('>>>'):
        raise Exception("Processing file failed - did not find timestamp where expected")
    timestamp = datetime.strptime(line, timestamp_string)

    process_this_listing = (start is None or timestamp >= start) and (end is None or timestamp <= end)

    header = read_text_line().split()
    n_cols = len(header)
    rows = []
    for raw_line in file:
        line = raw_line.decode('utf-8', errors='replace')
        if line == '\n':
            break
        # Rows of a filtered-out sample are still consumed so the stream
        # ends up at the next sample boundary.
        if process_this_listing:
            rows.append(dict(zip(header, line.rstrip().split(None, n_cols - 1))))
    return (timestamp, rows)


def writeFile(directory, timestamp, entries):
    """Classify one sample's rows and write its report file.

    The report is created in ``directory`` and named after ``timestamp``
    (``YYYYmmdd-HHMMSS``).  It holds the polling timestamp, the node summary
    and the segment summary.

    Args:
        directory: directory that receives the report file.
        timestamp: datetime of the sample.
        entries: iterable of ps-row dictionaries belonging to the sample.
    """
    report_name = os.path.join(directory, timestamp.strftime('%Y%m%d-%H%M%S'))
    try:
        for e in entries:
            ProcessListing.process(e)
        with open(report_name, 'w') as f:
            print('{0:20}{1}\n'.format('Polling Timestamp:', timestamp.isoformat()), file=f)
            ProcessListing.report(f)
            PostmasterProcessListing.report(f)
    finally:
        # Clear the per-sample state even if the report fails.  GPDBSession
        # is not a ProcessListing subclass, so it is reset explicitly;
        # otherwise it keeps every query listing of every sample alive.
        ProcessListing.reset()
        GPDBSession.reset()


def parseCmdLine():
    """Parse the command line.

    Exactly one positional argument, the gzip file to read, is required;
    anything else makes the parser print an error and exit.

    Returns:
        ``(options, args)`` as produced by OptionParser.parse_args();
        ``options.start_time`` and ``options.end_time`` are the raw option
        strings, or None when the option was not given.
    """
    usage = 'usage: %prog [options] GZIP_FILE'
    parser = OptionParser(usage, version='%prog version ' + gpverstr)
    parser.add_option('-s', '--start', dest='start_time', metavar='START',
                      help='Start of reporting period, as "YYYY-MM-DD HH:MM:SS"')
    parser.add_option('-e', '--end', dest='end_time', metavar='END',
                      help='End of reporting period, as "YYYY-MM-DD HH:MM:SS"')
    (option, args) = parser.parse_args()
    if len(args) != 1:
        parser.error("Must specify 1 file name")
    return (option, args)


if __name__ == '__main__':
    # Script entry point: write one report file into the current directory
    # for every sample of the given gzip file that falls inside the
    # requested period.
    (options, args) = parseCmdLine()

    start = None
    if options.start_time:
        start = datetime.strptime(options.start_time, datetime_format)
    end = None
    if options.end_time:
        end = datetime.strptime(options.end_time, datetime_format)

    # The context manager closes the archive however the loop ends.
    with gzip.open(args[0], 'r') as f:
        while True:
            sample = getNextSample(f, start, end)
            if not sample:
                break
            sample_time, sample_rows = sample
            if sample_rows:  # skip samples that produced no rows
                writeFile('.', sample_time, sample_rows)

