#!/usr/bin/env python3
# Copyright (c) 2014 Pivotal inc.


import datetime
import subprocess
import os
import platform
import pwd
import sys
import re
from contextlib import closing
from optparse import OptionParser
import pgdb
from gppylib.utils import formatInsertValuesList, Escape

# Version string handed to OptionParser; %prog is expanded by optparse.
gpsd_version = '%prog 1.0'

# Schemas left out of the statistics queries, rendered as a parenthesised SQL
# list that is appended after "NOT IN".
_system_namespaces = ('pg_toast', 'pg_bitmapindex', 'pg_temp_1', 'pg_catalog', 'information_schema')
sysnslist = '(' + ', '.join("'" + ns + "'" for ns in _system_namespaces) + ')'

# turn off optimizer to fall back to planner and speed up statistic queries
# unset search path due to CVE-2018-1058
pgoptions = ' '.join('-c ' + setting for setting in ('optimizer=off', 'gp_role=utility', 'search_path='))

def ResultIter(cursor, arraysize=1000):
    """Yield the rows of *cursor* one by one, fetching them in batches.

    Rows are pulled with ``cursor.fetchmany(arraysize)`` so only one batch is
    held at a time; iteration ends at the first empty batch.
    """
    batch = cursor.fetchmany(arraysize)
    while batch:
        yield from batch
        batch = cursor.fetchmany(arraysize)


def getVersion(envOpts):
    """Return the text printed by ``select version()`` run through psql.

    psql is started against the ``template1`` database with *envOpts* as its
    environment.

    Args:
        envOpts: mapping used as the environment of the psql process.

    Returns:
        psql's standard output as a string (including any trailing newline).

    Exits:
        Writes psql's standard error to stderr and calls ``sys.exit(1)`` when
        psql returns a non-zero status.
    """
    cmd = subprocess.Popen('psql --pset footer -Atqc "select version()" template1', shell=True,
                           stdout=subprocess.PIPE, stderr=subprocess.PIPE, env=envOpts)
    # communicate() drains both pipes while waiting for the child; calling
    # wait() first can block forever once the child fills a pipe buffer.
    out, err = cmd.communicate()
    if cmd.returncode != 0:
        # errors='replace' keeps a non-ASCII message from raising here
        sys.stderr.write('\nError while trying to find GPDB version.\n\n' + err.decode('utf-8', 'replace') + '\n\n')
        sys.exit(1)
    return out.decode('utf-8', 'replace')

def get_num_segments(cursor):
    """Return the count of gp_segment_configuration rows with role 'p' and content >= 0.

    Args:
        cursor: an open database cursor used to run the count query.

    Exits:
        Writes the database error to stderr and calls ``sys.exit(1)`` when the
        query raises ``pgdb.DatabaseError``.
    """
    try:
        cursor.execute("select count(*) from gp_segment_configuration where role='p' and content >=0;")
    except pgdb.DatabaseError as exc:
        message = '\nError while trying to retrieve number of segments.\n\n' + str(exc) + '\n\n'
        sys.stderr.write(message)
        sys.exit(1)
    else:
        row = cursor.fetchone()
        return row[0]

def dumpSchema(connectionInfo, envOpts):
    """Run ``pg_dump -s -x -O`` for the target database, writing to our stdout.

    The child inherits this process's standard output, so its dump text is
    emitted directly; only its standard error is captured.

    Args:
        connectionInfo: mapping with at least 'host', 'port', 'user' and
            'database' entries.
        envOpts: mapping used as the environment of the pg_dump process.

    Exits:
        Writes the failure details to stderr and calls ``sys.exit(1)`` when
        pg_dump cannot be started or returns a non-zero status.
    """
    # An argument list (no shell) keeps values containing spaces or shell
    # metacharacters intact as single arguments.
    dump_cmd = ['pg_dump',
                '-h', str(connectionInfo['host']),
                '-p', str(connectionInfo['port']),
                '-U', str(connectionInfo['user']),
                '-s', '-x', '-O', str(connectionInfo['database'])]
    try:
        p = subprocess.Popen(dump_cmd, stderr=subprocess.PIPE, env=envOpts)
    except OSError as e:
        # e.g. pg_dump is not on PATH; report it like any other dump failure
        sys.stderr.write('\nError while dumping schema.\n\n' + str(e) + '\n\n')
        sys.exit(1)
    # communicate() drains stderr while waiting; wait() alone can block
    # forever once the child fills the stderr pipe buffer.
    _, err = p.communicate()
    if p.returncode != 0:
        sys.stderr.write('\nError while dumping schema.\n\n' + err.decode('utf-8', 'replace') + '\n\n')
        sys.exit(1)


def dumpGlobals(connectionInfo, envOpts):
    """Run ``pg_dumpall -g`` against the server, writing to our stdout.

    The child inherits this process's standard output, so its dump text is
    emitted directly; only its standard error is captured.

    Args:
        connectionInfo: mapping with at least 'host', 'port', 'user' and
            'database' entries; the database is passed to ``-l``.
        envOpts: mapping used as the environment of the pg_dumpall process.

    Exits:
        Writes the failure details to stderr and calls ``sys.exit(1)`` when
        pg_dumpall cannot be started or returns a non-zero status.
    """
    # An argument list (no shell) keeps values containing spaces or shell
    # metacharacters intact as single arguments.
    dump_cmd = ['pg_dumpall',
                '-h', str(connectionInfo['host']),
                '-p', str(connectionInfo['port']),
                '-U', str(connectionInfo['user']),
                '-l', str(connectionInfo['database']),
                '-g']
    try:
        p = subprocess.Popen(dump_cmd, stderr=subprocess.PIPE, env=envOpts)
    except OSError as e:
        # e.g. pg_dumpall is not on PATH; report it like any other dump failure
        sys.stderr.write('\nError while dumping globals.\n\n' + str(e) + '\n\n')
        sys.exit(1)
    # communicate() drains stderr while waiting; wait() alone can block
    # forever once the child fills the stderr pipe buffer.
    _, err = p.communicate()
    if p.returncode != 0:
        sys.stderr.write('\nError while dumping globals.\n\n' + err.decode('utf-8', 'replace') + '\n\n')
        sys.exit(1)


def dumpTupleCount(cur):
    """Print an ``UPDATE pg_class`` statement for each relation outside the system schemas.

    Every statement sets relpages, reltuples and relallvisible to the values
    read from pg_class, and selects the row by relation name and schema name.

    Args:
        cur: an open database cursor; it is used to run the pg_class query
            and its ``description`` supplies the column names.
    """
    stmt = "SELECT pgc.relname, pgn.nspname, pgc.relpages, pgc.reltuples, pgc.relallvisible FROM pg_class pgc, pg_namespace pgn WHERE pgc.relnamespace = pgn.oid and pgn.nspname NOT IN " + \
        sysnslist
    setStmt = '-- Table: {table}\n' \
        'UPDATE pg_class\nSET\n' \
        '{assignments}\n' \
        'WHERE relname = \'{relname}\' AND relnamespace = ' \
        '(SELECT oid FROM pg_namespace WHERE nspname = \'{nspname}\');\n\n'

    cur.execute(stmt)
    columns = [x[0] for x in cur.description]
    # casts for relpages, reltuples, relallvisible (columns 2..4 of the query)
    types = ['int', 'real', 'int']
    for vals in ResultIter(cur):
        assignments = ',\n'.join('\t%s = %s::%s' % t
                                 for t in zip(columns[2:], vals[2:], types))
        # Names are placed inside single-quoted SQL literals, so they are
        # passed through Escape exactly as dumpStats does for its literals;
        # otherwise a name containing a quote yields invalid SQL.
        print(setStmt.format(table=vals[0],
                             assignments=assignments,
                             relname=Escape(vals[0]),
                             nspname=Escape(vals[1])))


def dumpStats(cur, inclHLL):
    """Print an ``INSERT INTO pg_statistic`` statement for each statistics row.

    The query joins pg_statistic with pg_class, pg_namespace, pg_attribute and
    pg_type, skips the system schemas, and orders rows by relation name and
    attribute number.

    Args:
        cur: an open database cursor used to run the query.
        inclHLL: forwarded unchanged to ``formatInsertValuesList``
            (set from the --hll command-line option by the caller).
    """
    query = ''.join((
        'SELECT pgc.relname, pgn.nspname, pga.attname, pgtn.nspname, pgt.typname, pgs.* ',
        'FROM pg_class pgc, pg_statistic pgs, pg_namespace pgn, pg_attribute pga, pg_type pgt, pg_namespace pgtn ',
        'WHERE pgc.relnamespace = pgn.oid and pgn.nspname NOT IN ',
        sysnslist,
        ' and pgc.oid = pgs.starelid ',
        'and pga.attrelid = pgc.oid ',
        'and pga.attnum = pgs.staattnum ',
        'and pga.atttypid = pgt.oid ',
        'and pgt.typnamespace = pgtn.oid ',
        'ORDER BY pgc.relname, pgs.staattnum',
    ))
    template = ('--\n'
                '-- Table: {0}, Attribute: {1}\n'
                '--\n'
                'INSERT INTO pg_statistic VALUES (\n'
                '{2});\n\n')

    cur.execute(query)

    for row in ResultIter(cur):
        table_name, schema_name, attr_name = row[0], row[1], row[2]
        # starelid is written as a regclass cast of the schema-qualified name
        # rather than as a raw oid
        starelid = "'%s.%s'::regclass" % (Escape(schema_name), Escape(table_name))
        rendered_values = formatInsertValuesList(row, starelid, inclHLL)
        print(template.format(table_name, attr_name, ',\n'.join(rendered_values)))


def parseCmdLine():
    """Build and return the command-line parser; the arguments are not parsed here.

    The parser uses ``conflict_handler="resolve"`` so that ``-h`` can mean
    --host while ``-?``/``--help`` shows the help text.
    """
    disclaimer = "WARNING: This tool collects statistics about your data, including most common values, which requires some data elements to be included in the output file. Please review output file to ensure it is within corporate policy to transport the output file."
    parser = OptionParser(usage='Usage: %prog [options] <database>',
                          version=gpsd_version,
                          conflict_handler="resolve",
                          epilog=disclaimer)

    # (flags, keyword settings) in the order the options are registered
    option_specs = [
        (('-?', '--help'),
         dict(action='help', help='Show this help message and exit')),
        (('-h', '--host'),
         dict(action='store', dest='host', help='Specify a remote host')),
        (('-p', '--port'),
         dict(action='store', dest='port', help='Specify a port other than 5432')),
        (('-U', '--user'),
         dict(action='store', dest='user', help='Connect as someone other than current user')),
        (('-s', '--stats-only'),
         dict(action='store_false', dest='dumpSchema', default=True,
              help='Just dump the stats, do not do a schema dump')),
        (('-l', '--hll'),
         dict(action='store_true', dest='dumpHLL', default=False,
              help='Include HLL stats')),
    ]
    for flags, settings in option_specs:
        parser.add_option(*flags, **settings)
    return parser

def main():
    """Write a statistics dump for one database to standard output.

    Steps, in order:
      1. parse the command line (exactly one database name is required);
      2. print a comment header with database, segment count, date, time,
         command line and server version;
      3. unless --stats-only was given, run dumpGlobals and dumpSchema
         around a ``\\connect <database>`` line;
      4. print the settings needed to load the statistics, then the output
         of dumpTupleCount and dumpStats.

    Exits with status 1 when the statistics queries raise
    ``pgdb.DatabaseError``; helper functions may also exit on their own
    failures.
    """
    parser = parseCmdLine()
    options, args = parser.parse_args()
    # parser.error() prints the usage text and exits the process itself
    if not args:
        parser.error("No database specified")
    if len(args) > 1:
        parser.error("Only one database may be specified")

    # OK - now let's setup all the arguments & options
    # NOTE: envOpts is os.environ itself, so setting PGOPTIONS below also
    # changes this process's own environment, not only the child processes'.
    envOpts = os.environ
    db = args[0]
    host = options.host or platform.node()
    user = options.user or envOpts.get('PGUSER') or pwd.getpwuid(os.geteuid())[0]
    port = options.port or envOpts.get('PGPORT') or '5432'
    inclSchema = options.dumpSchema
    inclHLL = options.dumpHLL

    envOpts['PGOPTIONS'] = pgoptions

    version = getVersion(envOpts)

    timestamp = datetime.datetime.today()

    connectionInfo = {
        'user': user,
        'host': host,
        'port': port,
        'database': db,
        'options': pgoptions,
    }
    num_segments = 0
    with closing(pgdb.connect(**connectionInfo)) as connection:
        with closing(connection.cursor()) as cursor:
            num_segments = get_num_segments(cursor)
    sys.stdout.writelines(['\n-- Greenplum database Statistics Dump',
                           '\n-- Copyright (C) 2007 - 2014 Pivotal',
                           '\n-- Database: ' + db,
                           '\n-- Num Segments: ' + str(num_segments),
                           '\n-- Date:     ' + timestamp.date().isoformat(),
                           '\n-- Time:     ' + timestamp.time().isoformat(),
                           '\n-- CmdLine:  ' + ' '.join(sys.argv),
                           '\n-- Version:  ' + version + '\n\n'])

    # Flush before each child process runs: the children write straight to
    # the inherited stdout, so buffered text would otherwise come out late.
    sys.stdout.flush()
    if inclSchema:
        dumpGlobals(connectionInfo, envOpts)
    # Now be sure that when we load the rest we are doing it in the right
    # database
    print('\\connect ' + db + '\n')
    sys.stdout.flush()
    if inclSchema:
        dumpSchema(connectionInfo, envOpts)

    # Now we have to explicitly allow editing of these pg_class &
    # pg_statistic tables
    sys.stdout.writelines(['\n-- ',
                           '\n-- Allow system table modifications',
                           '\n-- ',
                           '\nset allow_system_table_mods=true;\n\n'])
    # turn off optimizer when loading stats. Orca adds a bit of overhead, but
    # it's significant when small insert queries take 1 vs .1ms
    sys.stdout.write('set optimizer to off;\n\n')
    sys.stdout.flush()

    try:
        with closing(pgdb.connect(**connectionInfo)) as connection:
            with closing(connection.cursor()) as cursor:
                dumpTupleCount(cursor)
                dumpStats(cursor, inclHLL)
        # Upon success, leave a warning message about data collected
        sys.stderr.writelines(['WARNING: This tool collects statistics about your data, including most common values, '
                               'which requires some data elements to be included in the output file.\n',
                               'Please review output file to ensure it is within corporate policy to transport the output file.\n'])

    except pgdb.DatabaseError as err:  # only database errors; anything else propagates
        sys.stderr.write('Error while dumping statistics:\n')
        sys.stderr.write(str(err) + '\n')
        sys.exit(1)


# Run the dump only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
