#!/usr/bin/env python3
# Line too long            - pylint: disable=C0301
# Invalid name             - pylint: disable=C0103
#
# Copyright (c) Greenplum Inc 2008. All Rights Reserved.
#
from gppylib.mainUtils import getProgramName

import copy
import datetime
import os
import random
import sys
import json
import shutil
import signal
import traceback
from collections import defaultdict
from time import strftime, sleep

try:
    import pg, pgdb

    from gppylib.commands.unix import *
    from gppylib.commands.gp import *
    from gppylib.gparray import GpArray, MODE_NOT_SYNC, STATUS_DOWN
    from gppylib.gpparseopts import OptParser, OptChecker
    from gppylib.gplog import *
    from gppylib.db import catalog
    from gppylib.db import dbconn
    from gppylib.userinput import *
    from gppylib.operations.startSegments import MIRROR_MODE_MIRRORLESS
    from gppylib.system import configurationInterface, configurationImplGpdb
    from gppylib.system.environment import GpCoordinatorEnvironment
    from pgdb import DatabaseError
    from gppylib.gpcatalog import COORDINATOR_ONLY_TABLES_MAPPED
    from gppylib.gpcatalog import COORDINATOR_ONLY_TABLES_NON_MAPPED
    from gppylib.operations.package import SyncPackages
    from gppylib.operations.utils import ParallelOperation
    from gppylib.parseutils import line_reader, check_values, canonicalize_address
    from gppylib.heapchecksum import HeapChecksum
    from gppylib.commands.pg import PgBaseBackup
    from gppylib.mainUtils import ExceptionNoStackTraceNeeded
    from gppylib.operations.update_pg_hba_on_segments import update_pg_hba_on_segments

except ImportError as e:
    sys.exit('ERROR: Cannot import modules.  Please check that you have sourced greenplum_path.sh.  Detail: ' + str(e))

# constants
# Upper bound for -n/--parallel (number of tables expanded at a time);
# checked in validate_options() and shown in the option help.
MAX_PARALLEL_EXPANDS = 96
# Upper bound for -B/--batch-size; shown in the option help and in the
# validate_options() error message.
MAX_BATCH_SIZE = 128

# File name of the gp_segment_configuration backup; GpExpandStatus joins it
# onto the standby coordinator's data directory.
SEGMENT_CONFIGURATION_BACKUP_FILE = "gpexpand.gp_segment_configuration"

# Database name; presumably the default database gpexpand connects to
# (NOTE(review): its use is not visible in this section -- confirm).
DBNAME = 'postgres'

#global var
# Initialized to None here; it is assigned outside this section
# (NOTE(review): presumably holds the active expansion object -- confirm).
_gp_expand = None

# Program description handed to OptParser in parseargs(); whitespace is
# collapsed there via ' '.join(description.split()).
description = ("""
Adds additional segments to a pre-existing CBDB Array.
""")

# Extended help text passed to parser.setHelp() in parseargs().
# NOTE(review): the second entry keeps its source indentation inside the
# literal; whether setHelp() dedents it is not visible here.
_help = ["""
The input file should be a plain text file with a line for each segment
to add with the format:

  <hostname>:<port>:<data_directory>:<dbid>:<content>:<definedprimary>
""",
         """
         If an input file is not specified, gpexpand will ask a series of questions
         and create one.
         """,
         ]

# Known gaps recorded by the authors; not referenced in this section.
_TODO = ["""

Remaining TODO items:
====================
""",

         """* smarter heuristics on deciding which tables to reorder first. """,

         """* make sure system isn't in "readonly mode" during setup. """,

         """* need a startup validation where we check the status detail
             with gp_distribution_policy and make sure that our book
             keeping matches reality. we don't have a perfect transactional
             model since the tables can be in a different database from
             where the gpexpand schema is kept. """,

         """* currently requires that GPHOME and PYTHONPATH be set on all of the remote hosts of
              the system.  should get rid of this requirement. """
         ]

# Usage synopsis; parseargs() prefixes it with '%prog ' via set_usage().
_usage = """[-f hosts_file]

gpexpand -i input_file [-B batch_size] [-t segment_tar_dir] [-S]

gpexpand [-d duration[hh][:mm[:ss]] | [-e 'YYYY-MM-DD hh:mm:ss']]
         [-a] [-n parallel_processes]

gpexpand -r

gpexpand -c

gpexpand -? | -h | --help | --verbose | -v"""

# Basename of this script file.
EXECNAME = os.path.split(__file__)[-1]


# ----------------------- Command line option parser ----------------------

def parseargs():
    """Build the gpexpand option parser and parse the command line.

    Returns:
        A 3-tuple (options, args, parser): the parsed option values, the
        leftover positional arguments, and the parser itself so callers can
        print help or exit through it.
    """
    parser = OptParser(option_class=OptChecker,
                       description=' '.join(description.split()),
                       version='%prog version $Revision$')
    parser.setHelp(_help)
    parser.set_usage('%prog ' + _usage)
    # The stock -h is removed so it can be re-registered together with '-?'.
    parser.remove_option('-h')

    # (option strings, keyword arguments) pairs, registered in this order;
    # the order is also the order options appear in --help output.
    option_specs = (
        (('-c', '--clean'),
         dict(action='store_true', help='remove the expansion schema.')),
        (('-r', '--rollback'),
         dict(action='store_true', help='rollback failed expansion setup.')),
        (('-a', '--analyze'),
         dict(action='store_true', help='Analyze the expanded table after redistribution.')),
        (('-d', '--duration'),
         dict(type='duration', metavar='[h][:m[:s]]',
              help='duration from beginning to end.')),
        (('-e', '--end'),
         dict(type='datetime', metavar='datetime',
              help="ending date and time in the format 'YYYY-MM-DD hh:mm:ss'.")),
        (('-i', '--input'),
         dict(dest="filename", help="input expansion configuration file.", metavar="FILE")),
        (('-f', '--hosts-file'),
         dict(metavar='<hosts_file>',
              help='file containing new host names used to generate input file')),
        (('-B', '--batch-size'),
         dict(type='int', default=16, metavar="<batch_size>",
              help='Expansion configuration batch size. Valid values are 1-%d' % MAX_BATCH_SIZE)),
        (('-n', '--parallel'),
         dict(type="int", default=1, metavar="<parallel_processes>",
              help='number of tables to expand at a time. Valid values are 1-%d.' % MAX_PARALLEL_EXPANDS)),
        (('-v', '--verbose'),
         dict(action='store_true', help='debug output.')),
        (('-S', '--simple-progress'),
         dict(action='store_true', help='show simple progress.')),
        (('-t', '--tardir'),
         dict(default='.', metavar="FILE", help='Tar file directory.')),
        (('-h', '-?', '--help'),
         dict(action='help', help='show this help message and exit.')),
        (('-s', '--silent'),
         dict(action='store_true',
              help='Do not prompt for confirmation to proceed on warnings')),
        (('', '--hba-hostnames'),
         dict(action='store_true', default=False,
              help='use hostnames instead of CIDR in pg_hba.conf')),
        (('--usage',),
         dict(action="briefhelp")),
    )
    for flags, settings in option_specs:
        parser.add_option(*flags, **settings)

    parser.set_defaults(verbose=False, filters=[], slice=(None, None))

    options, args = parser.parse_args()
    return options, args, parser

def validate_options(options, args, parser):
    """Validate and normalize the parsed gpexpand command-line options.

    Any invalid value or combination is reported through the module logger
    and terminates via parser.exit().  On success *options* may be updated:

    * batch_size -- replaced by GP_MGMT_PROCESS_COUNT when -B was left at
      its default of 16 and that environment variable is set.
    * end -- converted to a datetime, or derived from duration; when both
      are given the earlier end time is kept.
    * coordinator_data_directory, gphome -- added from the environment.

    Returns:
        (options, args)
    """
    if len(args) > 0:
        logger.error('Unknown argument %s' % args[0])
        parser.exit()

    # -n sanity check
    if options.parallel > MAX_PARALLEL_EXPANDS or options.parallel < 1:
        logger.error('Invalid argument.  parallel value must be >= 1 and <= %d' % MAX_PARALLEL_EXPANDS)
        parser.print_help()
        parser.exit()

    # The environment only overrides -B when -B was left at its default.
    proccount = os.environ.get('GP_MGMT_PROCESS_COUNT')
    if options.batch_size == 16 and proccount is not None:
        try:
            options.batch_size = int(proccount)
        except ValueError:
            # Report a malformed environment value instead of dying with a traceback.
            logger.error('Invalid GP_MGMT_PROCESS_COUNT value %r: must be an integer' % proccount)
            parser.exit()

    # -B sanity check; the shared constant keeps the bound and the message in sync.
    if options.batch_size < 1 or options.batch_size > MAX_BATCH_SIZE:
        logger.error('Invalid argument.  -B value must be >= 1 and <= %s' % MAX_BATCH_SIZE)
        parser.print_help()
        parser.exit()

    # OptParse can return date instead of datetime so we might need to convert
    if options.end and not isinstance(options.end, datetime.datetime):
        options.end = datetime.datetime.combine(options.end, datetime.time(0))

    # Take a single timestamp so every comparison below agrees on "now".
    now = datetime.datetime.now()

    if options.end and options.end < now:
        logger.error('End time occurs in the past')
        parser.print_help()
        parser.exit()

    if options.end and options.duration:
        logger.warning('Both end and duration options were given.')
        # Both a duration and an end time were given: keep the earlier end.
        duration_end = now + options.duration
        if options.end > duration_end:
            logger.warning('The duration argument will be used for the expansion end time.')
            options.end = duration_end
        else:
            logger.warning('The end argument will be used for the expansion end time.')
    elif options.duration:
        options.end = now + options.duration

    # -c and -r options are mutually exclusive
    if options.rollback and options.clean:
        rollbackOpt = "--rollback" if "--rollback" in sys.argv else "-r"
        cleanOpt = "--clean" if "--clean" in sys.argv else "-c"
        logger.error("%s and %s options cannot be specified together." % (rollbackOpt, cleanOpt))
        parser.exit()

    try:
        options.coordinator_data_directory = get_coordinatordatadir()
        options.gphome = get_gphome()
    except GpError as msg:
        logger.error(msg)
        parser.exit()

    if not os.path.exists(options.coordinator_data_directory):
        logger.error('Coordinator data directory %s does not exist.' % options.coordinator_data_directory)
        parser.exit()

    return options, args


# -------------------------------------------------------------------------
# process information functions
def create_pid_file(coordinator_data_directory):
    """Creates gpexpand pid file.

    Writes the current process id to <coordinator_data_directory>/gpexpand.pid,
    replacing any existing content.

    Raises:
        IOError (OSError): if the file cannot be opened or written.
    """
    # The with-statement closes the file on every path, and an open() failure
    # propagates unchanged rather than being masked by a cleanup error.
    with open(coordinator_data_directory + '/gpexpand.pid', 'w') as fp:
        fp.write(str(os.getpid()))


def remove_pid_file(coordinator_data_directory):
    """Removes gpexpand pid file (best effort).

    A missing or unremovable pid file is silently ignored.
    """
    try:
        os.unlink(coordinator_data_directory + '/gpexpand.pid')
    except OSError:
        # Best effort: the file may already be gone or not be removable.
        # Only filesystem errors are ignored, so KeyboardInterrupt/SystemExit
        # still propagate.
        pass


def is_gpexpand_running(coordinator_data_directory):
    """Checks if there is another instance of gpexpand running.

    Reads the pid on the first line of <coordinator_data_directory>/gpexpand.pid
    and returns check_pid(pid).

    Returns:
        False if the pid file cannot be opened or read, otherwise the result
        of check_pid(pid).

    Raises:
        ValueError: if the first line of the pid file is not an integer.
    """
    try:
        # The with-statement closes the file even when int() raises.
        with open(coordinator_data_directory + '/gpexpand.pid', 'r') as fp:
            pid = int(fp.readline().strip())
        # Kept inside the try so an IOError from check_pid is still treated
        # as "not running", as before.
        return check_pid(pid)
    except IOError:
        return False


def gpexpand_status_file_exists(coordinator_data_directory):
    """Checks if the gpexpand status file (gpexpand.status) exists in the
    given coordinator data directory."""
    return os.path.exists(coordinator_data_directory + '/gpexpand.status')


def is_cluster_up_and_balanced(dburl):
    """Return True when every segment is up and in its preferred role.

    Counts gp_segment_configuration rows whose status is not 'u' or whose
    role differs from preferred_role.  The cluster is considered up and
    balanced only when that count is 0.

    Raises:
        Exception: wrapping any failure to connect or run the query; the
            original error is chained as the cause.
    """
    # Imported locally because `closing` is not among this file's explicit imports.
    from contextlib import closing

    sql = "select count(*) from gp_segment_configuration where status <> 'u' or preferred_role <> role;"
    try:
        with closing(dbconn.connect(dburl, encoding='UTF8')) as conn:
            count = dbconn.querySingleton(conn, sql)
    except Exception as e:
        raise Exception("failed to query cluster role check: %s" % str(e)) from e

    return count == 0

# -------------------------------------------------------------------------
# expansion schema

# Per-table status values held in gpexpand.status_detail.status; the
# progress views below bucket rows by them.
undone_status = "NOT STARTED"
start_status = "IN PROGRESS"
done_status = "COMPLETED"
# Not referenced by the views in this section (presumably marks tables that
# were dropped after setup -- NOTE(review): confirm).
does_not_exist_status = 'NO LONGER EXISTS'

# DDL for creating / dropping the gpexpand bookkeeping schema.
create_schema_sql = "CREATE SCHEMA gpexpand"
drop_schema_sql = "DROP SCHEMA IF EXISTS gpexpand CASCADE"

# gpexpand.status: timestamped overall status entries.  The full progress
# view below looks up the latest 'EXPANSION STARTED' row in this table.
status_table_sql = """CREATE TABLE gpexpand.status
                        ( status text,
                          updated timestamp ) """

# gpexpand.status_detail: expansion bookkeeping keyed by table (oid, name,
# rank, status, start/finish times, source size), distributed by table_oid.
status_detail_table_sql = """CREATE TABLE gpexpand.status_detail
                        ( dbname text,
                          fq_name text,
                          table_oid oid,
                          root_partition_oid oid,
                          rank int,
                          external_writable bool,
                          status text,
                          expansion_started timestamp,
                          expansion_finished timestamp,
                          source_bytes numeric ) distributed by (table_oid)"""
# gpexpand views
# Minimal progress view: number of tables per status, labelled for the
# COMPLETED and NOT STARTED buckets (any other status gets a NULL name).
progress_view_simple_sql = """CREATE VIEW gpexpand.expansion_progress AS
SELECT
    CASE status
        WHEN '%s' THEN 'Tables Expanded'
        WHEN '%s' THEN 'Tables Left'
    END AS Name,
    count(*)::text AS Value
FROM gpexpand.status_detail GROUP BY status""" % (done_status, undone_status)

# Full progress view: table counts and source byte totals per status, an
# estimated expansion rate (MB/s), and an estimated time to completion.
# Rate and ETA use only COMPLETED tables whose expansion_started is later
# than the most recent 'EXPANSION STARTED' row in gpexpand.status; the
# "1 +" terms appear to guard against division by zero.
# The %s placeholders are filled positionally from the tuple at the end.
progress_view_sql = """CREATE VIEW gpexpand.expansion_progress AS
SELECT
    CASE status
        WHEN '%s' THEN 'Tables Expanded'
        WHEN '%s' THEN 'Tables Left'
        WHEN '%s' THEN 'Tables In Progress'
    END AS Name,
    count(*)::text AS Value
FROM gpexpand.status_detail GROUP BY status

UNION

SELECT
    CASE status
        WHEN '%s' THEN 'Bytes Done'
        WHEN '%s' THEN 'Bytes Left'
        WHEN '%s' THEN 'Bytes In Progress'
    END AS Name,
    SUM(source_bytes)::text AS Value
FROM gpexpand.status_detail GROUP BY status

UNION

SELECT
    'Estimated Expansion Rate' AS Name,
    (SUM(source_bytes) / (1 + extract(epoch FROM (max(expansion_finished) - min(expansion_started)))) / 1024 / 1024)::text || ' MB/s' AS Value
FROM gpexpand.status_detail
WHERE status = '%s'
AND
expansion_started > (SELECT updated FROM gpexpand.status WHERE status = '%s' ORDER BY updated DESC LIMIT 1)

UNION

SELECT
'Estimated Time to Completion' AS Name,
CAST((SUM(source_bytes) / (
SELECT 1 + SUM(source_bytes) / (1 + (extract(epoch FROM (max(expansion_finished) - min(expansion_started)))))
FROM gpexpand.status_detail
WHERE status = '%s'
AND
expansion_started > (SELECT updated FROM gpexpand.status WHERE status = '%s' ORDER BY
updated DESC LIMIT 1)))::text || ' seconds' as interval)::text AS Value
FROM gpexpand.status_detail
WHERE status = '%s'
  OR status = '%s'""" % (done_status, undone_status, start_status,
                         done_status, undone_status, start_status,
                         done_status,
                         'EXPANSION STARTED',
                         done_status,
                         'EXPANSION STARTED',
                         start_status, undone_status)

# -------------------------------------------------------------------------
class InvalidStatusError(Exception):
    """Raised by GpExpandStatus for an unreadable/invalid status file or a
    disallowed status transition."""


class ValidationError(Exception):
    """Validation failure (NOTE(review): not raised within this section)."""


# -------------------------------------------------------------------------
class GpExpandStatus():
    """Class that manages gpexpand status file.

    The status file is placed in the coordinator data directory on both the
    coordinator and the standby coordinator.  It is used to keep track of
    where we are in the progression.

    Each line of the file has the form ``STATUS:status_info``.  Statuses
    must advance exactly one step at a time in the order defined by
    ``_status_values``, unless a transition is forced (see set_status() and
    rewind()).
    """

    def __init__(self, logger, coordinator_data_directory, coordinator_mirror=None):
        self.logger = logger

        # Ordered setup phases; set_status() only allows moving to the phase
        # whose value is exactly one greater than the current phase's value.
        self._status_values = {'UNINITIALIZED': 1,
                               'EXPANSION_PREPARE_STARTED': 2,
                               'BUILD_SEGMENT_TEMPLATE_STARTED': 3,
                               'BUILD_SEGMENT_TEMPLATE_DONE': 4,
                               'BUILD_SEGMENTS_STARTED': 5,
                               'BUILD_SEGMENTS_DONE': 6,
                               'UPDATE_CATALOG_STARTED': 7,
                               'UPDATE_CATALOG_DONE': 8,
                               'SETUP_EXPANSION_SCHEMA_STARTED': 9,
                               'SETUP_EXPANSION_SCHEMA_DONE': 10,
                               'PREPARE_EXPANSION_SCHEMA_STARTED': 11,
                               'PREPARE_EXPANSION_SCHEMA_DONE': 12,
                               'EXPANSION_PREPARE_DONE': 13
                               }
        # Parallel lists of status names and their status_info, in file order.
        self._status = []
        self._status_info = []
        self._coordinator_data_directory = coordinator_data_directory
        self._coordinator_mirror = coordinator_mirror
        self._status_filename = coordinator_data_directory + '/gpexpand.status'
        # Standby-side copies only exist when a coordinator mirror is
        # configured; default to None so the attributes are always defined.
        self._status_standby_filename = None
        self._segment_configuration_standby_filename = None
        if coordinator_mirror:
            standby_dir = coordinator_mirror.getSegmentDataDirectory()
            self._status_standby_filename = standby_dir + '/gpexpand.status'
            self._segment_configuration_standby_filename = standby_dir \
                + '/' + SEGMENT_CONFIGURATION_BACKUP_FILE
        self._fp = None
        self._temp_dir = None
        self._input_filename = None
        self._gp_segment_configuration_backup = None
        # Filled in by _read_status_file() when the matching status is
        # present; initialized so the getters return None instead of raising
        # AttributeError.
        self._seg_tarfile = None
        self._number_new_segments = None

        if os.path.exists(self._status_filename):
            self._read_status_file()

    def _read_status_file(self):
        """Reads in an existing gpexpand status file.

        The file is opened in append mode and kept open as self._fp, so later
        set_status() calls append to it.  Blank lines are skipped.

        Raises:
            IOError: if the file cannot be opened.
            InvalidStatusError: if a line has no ':' separator, the file
                contains no status entries, or the last status is unknown.
        """
        self.logger.debug("Trying to read in a pre-existing gpexpand status file")
        self._fp = open(self._status_filename, 'a+')
        self._fp.seek(0)

        for line in self._fp:
            line = line.rstrip()
            if not line:
                continue
            # Split on the first ':' only: status_info may be a path or other
            # value that itself contains ':'.
            status, sep, status_info = line.partition(':')
            if not sep:
                raise InvalidStatusError('Invalid status file.  Malformed line: %s' % line)
            if status == 'BUILD_SEGMENT_TEMPLATE_STARTED':
                self._temp_dir = status_info
            elif status == 'BUILD_SEGMENTS_STARTED':
                self._seg_tarfile = status_info
            elif status == 'BUILD_SEGMENTS_DONE':
                self._number_new_segments = status_info
            elif status == 'EXPANSION_PREPARE_STARTED':
                self._input_filename = status_info
            elif status == 'UPDATE_CATALOG_STARTED':
                self._gp_segment_configuration_backup = status_info

            self._status.append(status)
            self._status_info.append(status_info)

        if not self._status:
            raise InvalidStatusError('Invalid status file.  No status entries found in %s'
                                     % self._status_filename)
        if self._status[-1] not in self._status_values:
            raise InvalidStatusError('Invalid status file.  Unknown status %s' % self._status[-1])

    def create_status_file(self):
        """Creates a new gpexpand status file.

        Creates/truncates the file with a single ``UNINITIALIZED:None`` entry,
        fsyncs it, keeps it open for appending, and copies it to the
        coordinator mirror when one is configured.

        Raises:
            IOError: if the file cannot be created or written.
        """
        if self._fp:
            # Don't leak a handle left over from reading an earlier file.
            self._fp.close()
        self._fp = open(self._status_filename, 'w')
        self._fp.write('UNINITIALIZED:None\n')
        self._fp.flush()
        os.fsync(self._fp)
        self._status.append('UNINITIALIZED')
        self._status_info.append('None')

        if self._coordinator_mirror:
            self._sync_status_file()

    def _sync_status_file(self):
        """Syncs the gpexpand status file with the coordinator mirror"""
        cpCmd = Rsync('gpexpand copying status file to coordinator mirror',
                    srcFile=self._status_filename,
                    dstFile=self._status_standby_filename,
                    dstHost=self._coordinator_mirror.getSegmentHostName())
        cpCmd.run(validateAfter=True)

    def set_status(self, status, status_info=None, force=False):
        """Sets the current status.

        gpexpand statuses must be set in order: the new status has to be
        exactly one step after the current one, otherwise an
        InvalidStatusError is raised.  If force is True, any known status may
        be written.  The entry is appended to the status file, fsynced,
        recorded in memory, and synced to the coordinator mirror if any.

        Raises:
            InvalidStatusError: if no status file is open, the status is
                unknown, or the transition is out of order and not forced.
        """
        current = self._status[-1] if self._status else None
        self.logger.debug("Transitioning from %s to %s" % (current, status))

        if not self._fp:
            raise InvalidStatusError('The status file is invalid and cannot be written to')
        if status not in self._status_values:
            raise InvalidStatusError('%s is an invalid gpexpand status' % status)
        # Only allow a transition one step forward, unless forced.
        if current is not None and not force and \
                self._status_values[status] != self._status_values[current] + 1:
            raise InvalidStatusError('Invalid status transition from %s to %s' % (current, status))
        self._fp.write('%s:%s\n' % (status, status_info))
        self._fp.flush()
        os.fsync(self._fp)
        self._status.append(status)
        self._status_info.append(status_info)
        if self._coordinator_mirror:
            self._sync_status_file()

    def get_current_status(self):
        """Gets the current (status, status_info) pair that has been written
        to the gpexpand status file, or (None, None) if there is none."""
        if self._status and self._status_info:
            return (self._status[-1], self._status_info[-1])
        return (None, None)

    def get_status_history(self):
        """Gets the full status history as a list of (status, status_info)."""
        return list(zip(self._status, self._status_info))

    def remove_status_file(self):
        """Closes and removes the gpexpand status file (and the standby copy)."""
        if self._fp:
            self._fp.close()
            self._fp = None
        if os.path.exists(self._status_filename):
            os.unlink(self._status_filename)
        if self._coordinator_mirror:
            RemoveFile.remote('gpexpand coordinator mirror status file cleanup',
                              self._coordinator_mirror.getSegmentHostName(),
                              self._status_standby_filename)

    def remove_segment_configuration_backup_file(self):
        """ Remove the segment configuration backup file (and the standby copy) """
        self.logger.debug("Removing segment configuration backup file")
        if self._gp_segment_configuration_backup is not None and \
                os.path.exists(self._gp_segment_configuration_backup):
            os.unlink(self._gp_segment_configuration_backup)
        if self._coordinator_mirror:
            RemoveFile.remote('gpexpand coordinator mirror segment configuration backup file cleanup',
                              self._coordinator_mirror.getSegmentHostName(),
                              self._segment_configuration_standby_filename)

    def sync_segment_configuration_backup_file(self):
        """ Sync the segment configuration backup file to standby """
        if self._coordinator_mirror:
            self.logger.debug("Sync segment configuration backup file")
            cpCmd = Rsync('gpexpand copying segment configuration backup file to coordinator mirror',
                        srcFile=self._gp_segment_configuration_backup,
                        dstFile=self._segment_configuration_standby_filename,
                        dstHost=self._coordinator_mirror.getSegmentHostName())
            cpCmd.run(validateAfter=True)

    def get_temp_dir(self):
        """Gets temp dir that was used during template creation"""
        return self._temp_dir

    def get_input_filename(self):
        """Gets input file that was used by expansion setup"""
        return self._input_filename

    def get_seg_tarfile(self):
        """Gets tar file recorded with BUILD_SEGMENTS_STARTED, or None"""
        return self._seg_tarfile

    def get_number_new_segments(self):
        """ Gets the number of new segments recorded with BUILD_SEGMENTS_DONE,
        or None.  When read back from the file this is a string. """
        return self._number_new_segments

    def get_gp_segment_configuration_backup(self):
        """Gets the filename of the gp_segment_configuration backup file
        created during expansion setup"""
        return self._gp_segment_configuration_backup

    def set_gp_segment_configuration_backup(self, filename):
        """Sets the filename of the gp_segment_configuration backup file"""
        self._gp_segment_configuration_backup = filename

    def can_rollback(self, status):
        """Return whether rollback is possible from *status*: only statuses
        before UPDATE_CATALOG_DONE can be rolled back.

        Raises KeyError for an unknown status."""
        return self._status_values[status] < self._status_values['UPDATE_CATALOG_DONE']

    def rewind(self, status, status_info=None):
        """
        Rewind to a particular status by force-writing it, reopening the
        status file for appending if it is not open.
        """
        self.logger.debug("Rewind the status to %s" % status)

        if not self._fp:
            self._fp = open(self._status_filename, 'a+')
        self.set_status(status, status_info, True)


# -------------------------------------------------------------------------

class ExpansionError(Exception):
    """Expansion failure (NOTE(review): not raised within this section)."""


class SegmentTemplateError(Exception):
    """Raised by SegmentTemplate when a hostname lookup, pg_basebackup, or
    source-segment selection fails."""


# -------------------------------------------------------------------------
class SegmentTemplate:
    """Class for creating, distributing and deploying new segments to an
    existing CBDB array"""

    def __init__(self, logger, statusLogger, pool,
                 gparray, coordinatorDataDirectory,
                 dburl, conn, tempDir, batch_size, is_hba_hostnames,
                 segTarDir='.', schemaTarFile='gpexpand_schema.tar'):
        """Capture the expansion context and resolve the new-segment hosts.

        The host of every expansion segment in gparray is passed through
        consolidate_hosts(), which runs a hostname command on each distinct
        host; the resulting list is stored in self.hosts.
        """
        # Collaborators and settings.
        self.logger = logger
        self.statusLogger = statusLogger
        self.pool = pool
        self.gparray = gparray
        self.dburl = dburl
        self.conn = conn
        self.batch_size = batch_size
        self.isHbaHostnames = is_hba_hostnames

        # Filesystem locations.
        self.tempDir = tempDir
        self.coordinatorDataDirectory = coordinatorDataDirectory
        self.schema_tar_file = schemaTarFile
        self.segTarDir = segTarDir
        self.segTarFile = os.path.join(segTarDir, schemaTarFile)

        self.maxDbId = self.gparray.get_max_dbid()

        expansion_hosts = [segment.getSegmentHostName()
                           for segment in self.gparray.getExpansionSegDbList()]
        self.hosts = SegmentTemplate.consolidate_hosts(pool, expansion_hosts)
        logger.debug('Hosts: %s' % self.hosts)

    @staticmethod
    def consolidate_hosts(pool, hosts):
        """Run a remote hostname command once per distinct entry of *hosts*
        and return the distinct hostnames reported back, in the order
        pool.getCompletedItems() yields the commands.

        Raises:
            SegmentTemplateError: with the command results, if any hostname
                command was not successful.
        """
        # dict.fromkeys de-duplicates while keeping first-seen order.
        unique_hosts = dict.fromkeys(hosts, 0)
        for remote in unique_hosts:
            pool.addCommand(Hostname('gpexpand associating hostnames with segments',
                                     ctxt=REMOTE, remoteHost=remote))
        pool.join()

        resolved = []
        for cmd in pool.getCompletedItems():
            if not cmd.was_successful():
                raise SegmentTemplateError(cmd.get_results())
            reported = cmd.get_hostname()
            if reported in resolved:
                continue
            logger.debug('Adding %s to host list' % reported)
            resolved.append(reported)
        return resolved

    def build_segment_template(self, newTableSpaceInfo=None):
        """Build the segment template tar file.

        The work is bracketed by a BUILD_SEGMENT_TEMPLATE_STARTED status entry
        (recording self.tempDir) and a BUILD_SEGMENT_TEMPLATE_DONE entry.
        """
        status = self.statusLogger
        status.set_status('BUILD_SEGMENT_TEMPLATE_STARTED', self.tempDir)
        # Tablespace files, when present, are packed into the template too.
        self._create_template(newTableSpaceInfo)
        self._fixup_template()
        self._tar_template()
        status.set_status('BUILD_SEGMENT_TEMPLATE_DONE')

    def build_new_segments(self):
        """Deploy the template tar file and configure the new segments.

        A BUILD_SEGMENTS_STARTED entry (with the tar file path) is recorded
        first, and a BUILD_SEGMENTS_DONE entry (with the number of expansion
        segments) once configuration completes.
        """
        status = self.statusLogger
        status.set_status('BUILD_SEGMENTS_STARTED', self.segTarFile)
        self._distribute_template()
        # FIXME: truncate the qd only tables' underlying files instead of delete the tuples
        self._configure_new_segments()
        added = len(self.gparray.getExpansionSegDbList())
        status.set_status('BUILD_SEGMENTS_DONE', added)

    def _create_template(self, newTableSpaceInfo=None):
        """Creates the schema template that is used by new segments.

        Selects a source segment, records the current segment count in
        self.oldSegCount, closes self.conn, then runs pg_basebackup from the
        coordinator into self.tempDir.  When newTableSpaceInfo is given, the
        user tablespace files are also packed into the template.

        Raises:
            SegmentTemplateError: if pg_basebackup fails; the underlying
                exception is chained as its cause.
        """
        self.logger.info('Creating segment template')

        self._select_src_segment()

        self.oldSegCount = self.gparray.get_segment_count()

        # NOTE(review): why the connection is closed at this point is not
        # visible in this section.
        self.conn.close()

        # pg_basebackup in Cloudberry needs a parameter 'target_gp_dbid'; it
        # uses it to create tablespace-related data dirs in the corresponding
        # tablespace locations. Here we are using pg_basebackup to create a
        # template for gpexpand, so we can just provide a dummy DBID to make
        # sure the name is different from any dbid in the system.
        dummyDBID = self._gen_DummyDBID()

        try:
            coordinatorSeg = self.gparray.coordinator
            cmd = PgBaseBackup(target_datadir=self.tempDir,
                               source_host=coordinatorSeg.getSegmentAddress(),
                               source_port=str(coordinatorSeg.getSegmentPort()),
                               recovery_mode=False,
                               target_gp_dbid=dummyDBID)
            cmd.run(validateAfter=True)
        except Exception as msg:
            # Chain the cause so the original failure stays in the traceback.
            raise SegmentTemplateError(msg) from msg

        # new segments' tablespace info is loaded from tablespace
        # input files if provided. If newTableSpaceInfo is None,
        # then there are no user-created tablespaces in the system,
        # no need to consider tablespace problems in gpexpand.
        if newTableSpaceInfo:
            self._handle_tablespace_template(dummyDBID, newTableSpaceInfo)

    def _handle_tablespace_template(self, dummyDBID, newTableSpaceInfo):
        """
        Pack user-created tablespaces into the segment template.

        Two steps are performed:
          1. For every entry of the coordinator's tablespace directory, copy
             the directory its symlink points to into
             `tempDir/pg_tblspc/dumps/<oid>/<dummyDBID>`, then remove the
             `<dummyDBID>` sibling of that symlink target.
          2. Save newTableSpaceInfo (the restore paths of the tablespace files
             on the new segments) as JSON in
             `tempDir/pg_tblspc/newTableSpaceInfo.json`.

        newTableSpaceInfo is a python dict shaped like:
          {
              "names"   : [ name::string ],
              "oids"    : [ oid::string ],
              "details" : { dbid::string : [ location::string ] }
          }
        "names" and "oids" list the same tablespaces in the same order, and
        each "details" entry lists locations in the order of "oids".
        """
        coordinator_tblspc_dir = self.gparray.coordinator.getSegmentTableSpaceDirectory()
        # Listed before anything is created under tempDir.
        oid_entries = os.listdir(coordinator_tblspc_dir)
        dummy_name = str(dummyDBID)
        template_tblspc_dir = os.path.join(self.tempDir, "pg_tblspc")

        # Layout of dumps_dir:
        #   dumps/<oid1>/<dummyDBID>/...
        #   dumps/<oid2>/<dummyDBID>/...
        dumps_dir = os.path.join(template_tblspc_dir, "dumps")
        os.mkdir(dumps_dir)

        for oid in oid_entries:
            location = os.readlink(os.path.join(coordinator_tblspc_dir, oid))
            oid_dump_dir = os.path.join(dumps_dir, oid)
            os.mkdir(oid_dump_dir)
            # the target name for copytree does not impact anything
            shutil.copytree(location, os.path.join(oid_dump_dir, dummy_name))
            # Remove <parent of location>/<dummyDBID>; presumably the directory
            # pg_basebackup created for the dummy dbid (NOTE(review): confirm).
            shutil.rmtree(os.path.join(os.path.dirname(location), dummy_name))

        info_path = os.path.join(template_tblspc_dir, "newTableSpaceInfo.json")
        with open(info_path, "w") as f:
            json.dump(newTableSpaceInfo, f)

    def _gen_DummyDBID(self):
        """gen a random int that surely beyond the possible dbid range"""
        return random.randint(40960, 81920)

    def _select_src_segment(self):
        """Choose the segment used as a source for pg_hba.conf and
        postgresql.conf files.

        The first segment pair's primary is used if valid, otherwise its
        mirror if it exists and is valid.  Sets self.srcSegHostname and
        self.srcSegDataDir.

        Raises:
            SegmentTemplateError: if neither member of the first pair is valid.
        """
        pair = self.gparray.segmentPairs[0]
        if pair.primaryDB.valid:
            source = pair.primaryDB
        elif pair.mirrorDB and pair.mirrorDB.valid:
            source = pair.mirrorDB
        else:
            raise SegmentTemplateError("no valid segdb for content=0 to use as a template")
        self.srcSegHostname = source.getSegmentHostName()
        self.srcSegDataDir = source.getSegmentDataDirectory()

    def _distribute_template(self):
        """Distributes the template tar file to the new segment hosts.

        Only the copy is done here (via _distribute_tarfile()); the tar file
        is not unpacked by this method.
        """
        self.logger.info('Distributing template tar file to new hosts')

        self._distribute_tarfile()

    def _distribute_tarfile(self):
        """Copy the template tar file (self.schema_tar_file) into
        self.segTarDir on every host in self.hosts.

        The copies are queued on self.pool and run before
        pool.check_results() is called, which reports any failed copy.
        """
        for host in self.hosts:
            # Use the instance logger, like the other SegmentTemplate methods.
            self.logger.debug('Copying tar file to %s' % host)
            cpCmd = Rsync(name='gpexpand distribute tar file to new hosts',
                        srcFile=self.schema_tar_file,
                        dstFile=self.segTarDir,
                        dstHost=host)
            self.pool.addCommand(cpCmd)

        self.pool.join()
        self.pool.check_results()

    def _start_new_primary_segments(self):
        """Start every new (non-mirror) expansion segment in utility mode.

        One SegmentStart per new primary is queued on self.pool; the pool is
        then joined and its results checked.
        """
        primaries = [seg for seg in self.gparray.getExpansionSegDbList()
                     if not seg.isSegmentMirror()]
        for seg in primaries:
            # Start the new segment in utility mode, on its own.
            host = seg.getSegmentHostName()
            start_cmd = SegmentStart(
                name="Starting new segment dbid %s on host %s." % (str(seg.getSegmentDbId()), host),
                gpdb=seg,
                numContentsInCluster=0,  # Starting seg on it's own.
                era=None,
                mirrormode=MIRROR_MODE_MIRRORLESS,
                utilityMode=True,
                specialMode='convertCoordinatorDataDirToSegment',
                ctxt=REMOTE,
                remoteHost=host,
                pg_ctl_wait=True,
                timeout=SEGMENT_TIMEOUT_DEFAULT)
            self.pool.addCommand(start_cmd)
        self.pool.join()
        self.pool.check_results()

    def _stop_new_primary_segments(self):
        """Stop every new primary segment in 'smart' mode.

        Mirrors in the expansion list are skipped. One SegmentStop command per
        primary is queued on the worker pool with nowait=False. The method
        blocks on pool.join(), then calls pool.check_results().
        """
        for seg in self.gparray.getExpansionSegDbList():
            if seg.isSegmentMirror():
                continue
            segStopCmd = SegmentStop(
                # getSegmentDbId must be *called*: formatting the bound method
                # itself put "<bound method ...>" into the command name.
                name="Stopping new segment dbid %s on host %s." % (str(seg.getSegmentDbId()), seg.getSegmentHostName())
                , dataDir=seg.getSegmentDataDirectory()
                , mode='smart'
                , nowait=False
                , ctxt=REMOTE
                , remoteHost=seg.getSegmentHostName()
            )
            self.pool.addCommand(segStopCmd)
        self.pool.join()
        self.pool.check_results()

    def _configure_new_segments(self):
        """Configure the new primary segments and refresh their pg_hba.conf.

        Steps:
          1. Queue one ConfigureNewSegment command per host, built from the
             expansion segment list with primaryMirror='primary', and wait
             for the worker pool.
          2. Build a content-id -> [hostnames] map over all expansion segments
             and pass it to update_pg_hba_on_segments.
          3. Start, then stop, the new primaries once each.
        """
        self.logger.info('Configuring new segments (primary)')
        info_by_host = ConfigureNewSegment.buildSegmentInfoForNewSegment(self.gparray.getExpansionSegDbList(),
                                                                         primaryMirror='primary')
        self.logger.info(info_by_host)
        for host in info_by_host:
            self.pool.addCommand(ConfigureNewSegment(name='gpexpand configure new segments',
                                                     confinfo=info_by_host[host],
                                                     logdir=get_logger_dir(),
                                                     tarFile=self.segTarFile,
                                                     newSegments=True,
                                                     verbose=gplog.logging_is_verbose(),
                                                     batchSize=self.batch_size,
                                                     ctxt=REMOTE,
                                                     remoteHost=host))

        self.pool.join()
        self.pool.check_results()

        # content id -> hostnames of every expansion segment (primary and mirror)
        # carrying that content id.
        hosts_by_content = defaultdict(list)
        for seg in self.gparray.getExpansionSegDbList():
            hosts_by_content[seg.getSegmentContentId()].append(seg.getSegmentHostName())

        update_pg_hba_on_segments(self.gparray, self.isHbaHostnames, self.batch_size, hosts_by_content)

        self._start_new_primary_segments()
        self._stop_new_primary_segments()

    def _fixup_template(self):
        """Copy postgresql.conf and pg_hba.conf from the source segment into the template.

        Both files are rsynced from self.srcSegHostname:self.srcSegDataDir into
        self.tempDir on the coordinator host. They are copied verbatim; no file
        is edited here.
        """
        localHostname = self.gparray.coordinator.getSegmentHostName()
        # (log message, rsync command-name format, file suffix under the data dir)
        conf_copies = (
            ('Copying postgresql.conf from existing segment into template',
             'gpexpand copying postgresql.conf to %s:%s/postgresql.conf',
             '/postgresql.conf'),
            ('Copying pg_hba.conf from existing segment into template',
             'gpexpand copy pg_hba.conf to %s:%s/pg_hba.conf',
             '/pg_hba.conf'),
        )
        for log_msg, name_fmt, conf_suffix in conf_copies:
            self.logger.info(log_msg)
            rsync_cmd = Rsync(name=name_fmt % (self.srcSegHostname, self.srcSegDataDir),
                              srcFile=self.srcSegDataDir + conf_suffix,
                              dstFile=self.tempDir,
                              dstHost=localHostname,
                              ctxt=REMOTE,
                              remoteHost=self.srcSegHostname)
            rsync_cmd.run(validateAfter=True)

    def _tar_template(self):
        """Pack self.tempDir into self.schema_tar_file (raises if the tar command fails)."""
        self.logger.info('Creating schema tar file')
        CreateTar('gpexpand tar segment template',
                  self.tempDir,
                  self.schema_tar_file).run(validateAfter=True)

    @staticmethod
    def cleanup_build_segment_template(tarFile, tempDir):
        """Undo build_segment_template on the local host.

        Removes the temp template directory first, then the local tar file.
        Each removal runs with validateAfter=True, so a failure raises.
        """
        for removal in (RemoveDirectory('gpexpand remove temp dir: %s' % tempDir, tempDir),
                        RemoveFile('gpexpand remove segment template file', tarFile)):
            removal.run(validateAfter=True)

    @staticmethod
    def cleanup_build_new_segments(pool, tarFile, gparray, hosts=None, removeDataDirs=False):
        """Undo build_new_segments on the remote hosts.

        Args:
            pool: worker pool the remote commands are queued on.
            tarFile: path of the template tar file on each remote host.
            gparray: array whose expansion segments are being cleaned up.
            hosts: hosts to remove tarFile from; when empty/None, the hostname
                of every expansion segment is used (one entry per segment).
            removeDataDirs: also delete each expansion segment's data directory.
        """
        if not hosts:
            hosts = [seg.getSegmentHostName() for seg in gparray.getExpansionSegDbList()]

        # Remove template tar file
        for remote_host in hosts:
            pool.addCommand(RemoveFile('gpexpand remove segment template file on host: %s' % remote_host,
                                       tarFile, ctxt=REMOTE, remoteHost=remote_host))

        if removeDataDirs:
            for seg in gparray.getExpansionSegDbList():
                seg_host, seg_dir = seg.getSegmentHostName(), seg.getSegmentDataDirectory()
                pool.addCommand(RemoveDirectory('gpexpand remove new segment data directory: %s:%s' % (seg_host, seg_dir),
                                                seg_dir, ctxt=REMOTE, remoteHost=seg_host))

        pool.join()
        pool.check_results()

    def cleanup(self):
        """Remove the local template artifacts, then the tar file copies on self.hosts."""
        self.logger.info('Cleaning up temporary template files')
        SegmentTemplate.cleanup_build_segment_template(tarFile=self.schema_tar_file,
                                                       tempDir=self.tempDir)
        SegmentTemplate.cleanup_build_new_segments(pool=self.pool,
                                                   tarFile=self.segTarFile,
                                                   gparray=self.gparray,
                                                   hosts=self.hosts)


# ------------------------------------------------------------------------------------------------------
# ------------------------------------------------------------------------------------------------------
class NewSegmentInput:
    """One row of a gpexpand input file.

    All fields are stored exactly as given; no type conversion or whitespace
    stripping happens here.
    """

    def __init__(self, hostname, address, port, datadir, dbid, contentId, role):
        # identity of the segment within the cluster
        self.dbid = dbid
        self.contentId = contentId
        self.role = role
        # where the segment lives
        self.hostname = hostname
        self.address = address
        self.port = port
        self.datadir = datadir


# ------------------------------------------------------------------------------------------------------
# ------------------------------------------------------------------------------------------------------
class gpexpand:
    def __init__(self, logger, gparray, dburl, options, parallel=1):
        """Set up state for one gpexpand run.

        Opens a utility-mode connection (allowSystemTableMods=True) to dburl.
        Snapshots the current segment list and creates the status logger.
        Clamps options.batch_size to the number of existing segments and builds
        a worker pool of that size. Reuses the temp dir recorded by the status
        logger, or creates a new one under the coordinator data directory.
        """
        self.pastThePointOfNoReturn = False
        self.logger = logger
        self.dburl = dburl
        self.options = options
        self.numworkers = parallel
        self.gparray = gparray
        self.conn = dbconn.connect(self.dburl, utility=True, encoding='UTF8', allowSystemTableMods=True)
        self.old_segments = self.gparray.getSegDbList()

        self.statusLogger = GpExpandStatus(logger=logger,
                                           coordinator_data_directory=self.gparray.coordinator.getSegmentDataDirectory(),
                                           coordinator_mirror=self.gparray.standbyCoordinator)

        # Never run more workers than there are existing segments.
        self.options.batch_size = min(self.options.batch_size, len(self.old_segments))
        self.pool = WorkerPool(numWorkers=self.options.batch_size)

        # Prefer the temp dir of an in-progress expansion when one is recorded.
        self.tempDir = (self.statusLogger.get_temp_dir()
                        or createTempDirectoryName(self.options.coordinator_data_directory, "gpexpand"))
        self.queue = None
        self.segTemplate = None

    @staticmethod
    def prepare_gpdb_state(logger, dburl, options):
        """ Gets GPDB in the appropriate state for an expansion.
        This state will depend on if this is a new expansion setup,
        a continuation of a previous expansion or a rollback """
        # Get the database in the expected state for the expansion/rollback
        # If gpexpand status file exists ,the last run of gpexpand didn't finish properly
        status_file_exists = os.path.exists(options.coordinator_data_directory + '/gpexpand.status')
        gpexpand_db_status = None

        if not status_file_exists:
            logger.info('Querying gpexpand schema for current expansion state')
            try:
                gpexpand_db_status = gpexpand.get_status_from_db(dburl, options)
            except Exception as e:
                raise Exception('Error while trying to query the gpexpand schema: %s' % e)
            logger.debug('Expansion status returned is %s' % gpexpand_db_status)

        return gpexpand_db_status

    @staticmethod
    def get_status_from_db(dburl, options):
        """Return the newest row of gpexpand.status, or a falsy value if none.

        The status table is only queried when the local coordinator is in
        NORMAL mode. Query or connection failures are deliberately ignored.

        Raises:
            ExpansionError: no status was found but a 'gpexpand' schema exists.
        """
        gpexpand_db_status = None
        if get_local_db_mode(options.coordinator_data_directory) == 'NORMAL':
            status_conn = None
            try:
                status_conn = dbconn.connect(dburl, encoding='UTF8')
                # Get the last status entry
                cursor = dbconn.query(status_conn, 'SELECT status FROM gpexpand.status ORDER BY updated DESC LIMIT 1')
                if cursor.rowcount == 1:
                    gpexpand_db_status = cursor.fetchone()[0]
            except Exception:
                # expansion schema doesn't exist or there was a connection failure.
                pass
            finally:
                if status_conn:
                    status_conn.close()

        if gpexpand_db_status:
            return gpexpand_db_status

        # MPP-14145 - If there's no discernible status, the schema must not exist.
        # The check above really tests that gpexpand.status is non-empty, not
        # that the schema exists. If the table is empty but the schema exists,
        # a later CREATE SCHEMA would fail, so error out here instead.
        # Note: -c/--clean has similar assumptions and may not help either.
        schema_conn = dbconn.connect(dburl, encoding='UTF8', utility=True)
        try:
            schema_count = dbconn.querySingleton(schema_conn,
                                                 "SELECT count(n.nspname) FROM pg_catalog.pg_namespace n WHERE n.nspname = 'gpexpand'")
            if schema_count > 0:
                raise ExpansionError(
                    "Existing expansion state could not be determined, but a gpexpand schema already exists. Cannot proceed.")
        finally:
            schema_conn.close()

        return gpexpand_db_status

    def validate_max_connections(self):
        """Check that max_connections can sustain the requested table parallelism.

        Returns:
            True when max_connections >= 2 * options.parallel + 1; otherwise
            logs an explanation and returns False.

        Raises:
            DatabaseError: connecting or reading the GUC failed (re-raised
                with its original traceback).
        """
        conn = None
        try:
            conn = dbconn.connect(self.dburl, utility=True, encoding='UTF8')
            max_connections = int(catalog.getSessionGUC(conn, 'max_connections'))
        except DatabaseError as ex:
            if self.options.verbose:
                logger.exception(ex)
            logger.error('Failed to check max_connections GUC')
            raise
        finally:
            # conn is still None if connect() itself failed; calling close() on
            # it would raise and hide the original DatabaseError.
            if conn is not None:
                conn.close()

        if max_connections < self.options.parallel * 2 + 1:
            self.logger.error('max_connections is too small to expand %d tables at' % self.options.parallel)
            self.logger.error('a time.  This will lead to connection errors.  Either')
            self.logger.error('reduce the value for -n passed to gpexpand or raise')
            self.logger.error('max_connections in postgresql.conf')
            return False

        return True

    def rollback(self, dburl):
        """Rolls back an expansion setup that didn't successfully complete.

        Walks the recorded status history newest-first and undoes each step
        that has a rollback action: removing the segment template, removing
        the new segments' tar files and data directories, and restoring
        gp_segment_configuration. Afterwards it closes self.conn and deletes
        the status file and the segment-configuration backup file.

        Raises:
            ExpansionError: there is no history, preparation already finished
                (EXPANSION_PREPARE_DONE), or the status logger reports a step
                that cannot be rolled back.
        """
        # Each entry is indexed as entry[0] = status name; entry[1] is the
        # value recorded with it (used below as the template temp dir).
        status_history = self.statusLogger.get_status_history()
        if not status_history:
            raise ExpansionError('No status history to rollback.')

        if (status_history[-1])[0] == 'EXPANSION_PREPARE_DONE':
            raise ExpansionError('Expansion preparation complete.  Nothing to rollback')

        # Undo steps in the reverse order they were performed.
        for status in reversed(status_history):
            if not self.statusLogger.can_rollback(status[0]):
                raise ExpansionError('Catalog has been changed, the cluster can not rollback.')

            elif status[0] == 'BUILD_SEGMENT_TEMPLATE_STARTED':
                self.logger.info('Rolling back segment template build')
                # NOTE(review): the tar file is given as the relative path
                # 'gpexpand_schema.tar' (resolved against the cwd) - confirm it
                # matches where the template tar was actually written.
                SegmentTemplate.cleanup_build_segment_template('gpexpand_schema.tar', status[1])

            elif status[0] == 'BUILD_SEGMENTS_STARTED':
                self.logger.info('Rolling back building of new segments')
                # Re-read the original input file so gparray knows which
                # expansion segments (hosts/data dirs) to remove.
                newSegList = self.read_input_files(self.statusLogger.get_input_filename())
                self.addNewSegments(newSegList)
                SegmentTemplate.cleanup_build_new_segments(self.pool,
                                                           self.statusLogger.get_seg_tarfile(),
                                                           self.gparray, removeDataDirs=True)

            elif status[0] == 'UPDATE_CATALOG_STARTED':
                self.logger.info('Rolling back coordinator update')
                self.restore_coordinator()
                # Reload the array from the (restored) catalog.
                self.gparray = GpArray.initFromCatalog(dburl, utility=True)

            else:
                self.logger.debug('Skipping %s' % status[0])

        self.conn.close()

        self.statusLogger.remove_status_file()
        self.statusLogger.remove_segment_configuration_backup_file()

    def get_state(self):
        """Returns expansion state from status logger"""
        return self.statusLogger.get_current_status()[0]

    def generate_inputfile(self):
        """Write a gpexpand input file describing the expansion segments.

        The file is created in the current working directory as
        'gpexpand_inputfile_<YYYYmmdd_HHMMSS>'. It has one line per expansion
        segment, in the form:
            hostname|address|port|datadir|dbid|contentId|preferredRole

        Returns:
            The name of the file written.
        """
        outputfile = 'gpexpand_inputfile_' + strftime("%Y%m%d_%H%M%S")
        # 'with' guarantees the file is closed even if formatting a row fails.
        with open(outputfile, 'w') as outfile:
            logger.info("Generating input file...")

            for db in self.gparray.getExpansionSegDbList():
                tempStr = "%s|%s|%d|%s|%d|%d|%s" % (canonicalize_address(db.getSegmentHostName())
                                                    , canonicalize_address(db.getSegmentAddress())
                                                    , db.getSegmentPort()
                                                    , db.getSegmentDataDirectory()
                                                    , db.getSegmentDbId()
                                                    , db.getSegmentContentId()
                                                    , db.getSegmentPreferredRole()
                                                    )
                outfile.write(tempStr + "\n")

        return outputfile

    def generate_tablespace_inputfile(self, filename):
        coordinator_tblspc_dir = self.gparray.coordinator.getSegmentTableSpaceDirectory()
        tblspc_oids = os.listdir(coordinator_tblspc_dir)
        if not tblspc_oids:
            return None

        tblspc_oid_names = self.get_tablespace_oid_names()
        tblspc_info = {}

        for oid in tblspc_oids:
            if oid not in tblspc_oid_names:
                continue
            location = os.path.dirname(os.readlink(os.path.join(coordinator_tblspc_dir,
                                                                oid)))
            tblspc_info[oid] = {"location": location,
                                "name": tblspc_oid_names[int(oid)]}

        with open(filename, 'w') as f:
            names = "|".join([tblspc_info[oid]["name"]
                              for oid in tblspc_oids])
            oids = "|".join(tblspc_oids)
            headline = "tableSpaceNameOrders={names}".format(names=names)
            secline = "tableSpaceOidOrders={oids}".format(oids=oids)
            print(headline, file=f)
            print(secline, file=f)

            for db in self.gparray.getExpansionSegDbList():
                if db.isSegmentPrimary():
                    tmpStr = "{dbid}|{locations}"
                    locations = "|".join([tblspc_info[oid]["location"]
                                          for oid in tblspc_oids])
                    print(tmpStr.format(dbid=db.getSegmentDbId(),
                                             locations=locations), file=f)

        return filename

    def get_tablespace_oid_names(self):
        """Return {oid: spcname} for every row of pg_tablespace.

        Reopens self.conn (utility mode, allowSystemTableMods) when it is None.
        Key type is whatever the DB driver returns for the oid column
        (presumably int - confirm against the driver's typecasts).
        """
        if self.conn is None:
            self.conn = dbconn.connect(self.dburl, utility=True, encoding='UTF8', allowSystemTableMods=True)
        rows = dbconn.query(self.conn, "select oid, spcname from pg_tablespace").fetchall()
        return {oid: spcname for oid, spcname in rows}

    def addNewSegments(self, inputFileEntryList):
        for seg in inputFileEntryList:
            self.gparray.addExpansionSeg(content=int(seg.contentId)
                                         , preferred_role=seg.role
                                         , dbid=int(seg.dbid)
                                         , role=seg.role
                                         , hostname=seg.hostname.strip()
                                         , address=seg.address.strip()
                                         , port=int(seg.port)
                                         , datadir=os.path.abspath(seg.datadir.strip())
                                         )
        try:
            self.gparray.validateExpansionSegs()
        except Exception as e:
            raise ExpansionError('Invalid input file: %s' % e)

    def _getParsedRow(self, lineno, line):
        """Parse one 'host|address|port|datadir|dbid|content|role' line.

        Fields are validated with check_values() and returned unconverted in
        a NewSegmentInput.

        Raises:
            ExceptionNoStackTraceNeeded: the line does not have exactly 7 fields.
        """
        fields = line.split('|')
        field_count = len(fields)
        if field_count != 7:
            raise ExceptionNoStackTraceNeeded("expected 7 parts, obtained %d" % field_count)

        hostname, address, port, datadir, dbid, contentId, role = fields
        check_values(lineno, address=address, port=port, datadir=datadir, content=contentId,
                     hostname=hostname, dbid=dbid, role=role)
        return NewSegmentInput(hostname=hostname, address=address, port=port,
                               datadir=datadir, dbid=dbid, contentId=contentId, role=role)

    def read_input_files(self, inputFilename=None):
        """Read and validate the gpexpand input file.

        The -i option (self.options.filename) takes precedence over the
        inputFilename argument.

        Returns:
            A list of NewSegmentInput, one per data line.

        Raises:
            ExpansionError: no file was given, the file cannot be opened, or
                a line is malformed.
        """
        if not self.options.filename and not inputFilename:
            raise ExpansionError('Missing input file')

        if self.options.filename:
            inputFilename = self.options.filename

        parsed_rows = []
        try:
            with open(inputFilename, 'r') as input_file:
                for lineno, line in line_reader(input_file):
                    try:
                        parsed_rows.append(self._getParsedRow(lineno, line))
                    except ValueError:
                        raise ExpansionError('Missing or invalid value on line %d of file %s.' % (lineno, inputFilename))
                    except Exception as e:
                        raise ExpansionError('Invalid input file on line %d of file %s: %s' % (lineno, inputFilename, str(e)))
        except IOError:
            raise ExpansionError('Input file %s not found' % inputFilename)

        return parsed_rows

    def read_tablespace_file(self):
        """
        If there are user-created tablespaces in Cloudberry cluster,
        it returns a python dict, otherwise returns None.
        The python dict is like:
        newTableSpaceInfo is a python dict, its type spec is:
          newTableSpaceInfo :: {
                                  "names" : [ name::string ],
                                  "oids"  : [ oid::string ],
                                  "details" : {
                                                   dbid::string : [ location::string ]
                                              }
                               }
          newTableSpaceInfo[names] and newTableSpaceInfo[oids] are tablespace infos
          that are in the same order.
          newTableSpaceInfo[dbid] is a list of locations in the same order of oids.
        """
        coordinator_tblspc_dir = self.gparray.coordinator.getSegmentTableSpaceDirectory()
        if not os.listdir(coordinator_tblspc_dir):
            return None
        
        tblspc_oids = os.listdir(coordinator_tblspc_dir)
        tblspc_oid_names = self.get_tablespace_oid_names()
        flag = False
        for oid in tblspc_oids:
            if oid in tblspc_oid_names:
                flag = True
        if not flag:
            return None

        if not self.options.filename:
            raise ExpansionError('Missing tablespace input file')

        tablespace_inputfile = self.options.filename + ".ts"

        """
        Check if the tablespace input file exists or not
        In cases where the user manually creates an input file, the file
        will not be present. In such cases create the file and exit giving the
        user a chance to review it and re-run gpexpand.
        """
        if not os.path.exists(tablespace_inputfile):
            self.generate_tablespace_inputfile(tablespace_inputfile)
            self.logger.warning("Could not locate tablespace input configuration file '{0}'. A new tablespace input configuration file is written " \
                                "to '{0}'. Please review the file and re-run with: gpexpand -i {1}".format(tablespace_inputfile, self.options.filename))

            logger.info("Exiting...")
            sys.exit(1)

        new_tblspc_info = {}

        with open(tablespace_inputfile) as f:
            headline = f.readline().strip()
            secline = f.readline().strip()
            tblspc_names = headline.split('=')[1].strip().split('|')
            tblspc_oids = secline.split('=')[1].strip().split('|')
            new_tblspc_info["names"] = tblspc_names
            new_tblspc_info["oids"] = tblspc_oids
            details = {}
            for line in f:
                l = line.strip().split('|')
                dbid = l[0]
                locations = l[1:]
                details[dbid] = locations
            new_tblspc_info["details"] = details

        return new_tblspc_info

    def lock_catalog(self):
        """Open a dedicated connection and take the expansion catalog lock.

        BEGIN, gp_expand_lock_catalog() and CHECKPOINT run with autocommit off
        on self.conn_catalog_lock. The transaction is left open; unlock_catalog
        sends the COMMIT.
        """
        self.conn_catalog_lock = dbconn.connect(self.dburl, utility=True, encoding='UTF8')
        self.logger.info('Locking catalog')
        # FIXME: is CHECKPOINT inside BEGIN the one wanted by us?
        for statement in ("BEGIN", "select gp_expand_lock_catalog()", "CHECKPOINT"):
            dbconn.execSQL(self.conn_catalog_lock, statement, autocommit=False)
        self.logger.info('Locked catalog')

    def unlock_catalog(self):
        """COMMIT the transaction opened by lock_catalog and drop its connection."""
        self.logger.info('Unlocking catalog')
        lock_conn = self.conn_catalog_lock
        dbconn.execSQL(lock_conn, "COMMIT")
        lock_conn.close()
        self.conn_catalog_lock = None
        self.logger.info('Unlocked catalog')

    def add_segments(self, newTableSpaceInfo):
        """Build the segment template and use it to create the new segments.

        Stores the SegmentTemplate on self.segTemplate. Any SegmentTemplateError
        is re-raised as ExpansionError.
        """
        opts = self.options
        self.segTemplate = SegmentTemplate(logger=self.logger,
                                           statusLogger=self.statusLogger,
                                           pool=self.pool,
                                           gparray=self.gparray,
                                           coordinatorDataDirectory=opts.coordinator_data_directory,
                                           dburl=self.dburl,
                                           conn=self.conn,
                                           tempDir=self.tempDir,
                                           segTarDir=opts.tardir,
                                           batch_size=opts.batch_size,
                                           is_hba_hostnames=opts.hba_hostnames)
        try:
            self.segTemplate.build_segment_template(newTableSpaceInfo)
            self.segTemplate.build_new_segments()
        except SegmentTemplateError as err:
            raise ExpansionError(err)

    def update_original_segments(self):
        """Updates the gp_id catalog table of existing hosts"""

        # Update the gp_id of original segments
        self.newPrimaryCount = 0;
        for seg in self.gparray.getExpansionSegDbList():
            if seg.isSegmentPrimary(False):
                self.newPrimaryCount += 1

        self.newPrimaryCount += self.gparray.get_primary_count()

        if self.segTemplate:
            self.segTemplate.cleanup()

        # FIXME: update postmaster.opts

    def update_catalog(self):
        """
        Register the expansion segments in gp_segment_configuration.

        Steps:
          1. Dump the current segment configuration to a backup file under the
             coordinator data directory, record UPDATE_CATALOG_STARTED, and
             sync the backup file.
          2. Mark new primaries MODE_NOT_SYNC (only when mirroring is enabled)
             and new mirrors STATUS_DOWN.
          3. Write the array through updateSystemConfig() and reorder the
             expansion segments, since content ids may have changed.
          4. Issue CHECKPOINT and gp_expand_bump_version(), each on its own
             short-lived utility-mode connection.
          5. Record UPDATE_CATALOG_DONE and set pastThePointOfNoReturn.
        """
        self.statusLogger.set_gp_segment_configuration_backup(
            self.options.coordinator_data_directory + '/' + SEGMENT_CONFIGURATION_BACKUP_FILE)
        self.gparray.dumpToFile(self.statusLogger.get_gp_segment_configuration_backup())
        self.statusLogger.set_status('UPDATE_CATALOG_STARTED', self.statusLogger.get_gp_segment_configuration_backup())
        self.statusLogger.sync_segment_configuration_backup_file()

        # Mark expansion segment primaries not in sync
        for seg in self.gparray.getExpansionSegDbList():
            if seg.isSegmentMirror():
                continue
            if self.gparray.get_mirroring_enabled():
                seg.setSegmentMode(MODE_NOT_SYNC)

        # Set expansion segment mirror state = down
        for seg in self.gparray.getExpansionSegDbList():
            if seg.isSegmentPrimary():
                continue
            seg.setSegmentStatus(STATUS_DOWN)

        # Update the catalog
        configurationInterface.getConfigurationProvider().updateSystemConfig(
            self.gparray,
            "%s: segment config for resync" % getProgramName(),
            dbIdToForceMirrorRemoveAdd={},
            useUtilityMode=True,
            allowPrimary=True
        )

        # The content IDs may have changed, so we must make sure the array is in proper order.
        self.gparray.reOrderExpansionSegs()

        # Issue checkpoint due to forced shutdown below. try/finally makes sure
        # the connection is closed even if the statement fails.
        self.conn = dbconn.connect(self.dburl, utility=True, encoding='UTF8')
        try:
            dbconn.execSQL(self.conn, "CHECKPOINT")
        finally:
            self.conn.close()

        # increase expand version
        self.conn = dbconn.connect(self.dburl, utility=True, encoding='UTF8')
        try:
            dbconn.execSQL(self.conn, "select gp_expand_bump_version()")
        finally:
            self.conn.close()

        self.statusLogger.set_status('UPDATE_CATALOG_DONE')
        self.pastThePointOfNoReturn = True

    # --------------------------------------------------------------------------
    def cleanup_new_segments(self):
        """
        Clear coordinator-only catalog contents from the new primary segments.

        This method is called after all new segments have been configured.
        Steps:
          1. Start every new primary in utility mode.
          2. Wait (bounded) for each one to answer pg_isready.
          3. Run DELETE/TRUNCATE for COORDINATOR_ONLY_TABLES_MAPPED/_NON_MAPPED
             in every database except template0 on each new primary.
        Mirrors are skipped throughout. This used to be done once, by
        segcopy, on the original copy of the coordinator.
        """

        self.logger.info('Cleaning up databases in new segments.')
        newSegments = self.gparray.getExpansionSegDbList()
        newPrimaries = [seg for seg in newSegments if not seg.isSegmentMirror()]

        # Get a list of databases.
        conn = dbconn.connect(self.dburl, utility=True, encoding='UTF8')
        try:
            databases = catalog.getDatabaseList(conn)
        finally:
            conn.close()

        # Start all the new primary segments in utility mode.
        for seg in newPrimaries:
            segStartCmd = SegmentStart(
                name="Starting new segment dbid %s on host %s." % (str(seg.getSegmentDbId()), seg.getSegmentHostName())
                , gpdb=seg
                # total primary count after expansion, as computed by
                # update_original_segments()
                , numContentsInCluster=self.newPrimaryCount
                , era=None
                , mirrormode=MIRROR_MODE_MIRRORLESS
                , utilityMode=True
                , ctxt=REMOTE
                , remoteHost=seg.getSegmentHostName()
                , pg_ctl_wait=True
                , timeout=SEGMENT_TIMEOUT_DEFAULT)
            self.pool.addCommand(segStartCmd)
        self.pool.join()
        self.pool.check_results()

        # Poll pg_isready on every new primary: up to 11 attempts, 10s apart.
        # Only primaries are probed; the mirrors were never started.
        for attempt in range(1, 12):
            all_ready = True
            for segment in newPrimaries:
                # NOTE(review): -d receives the data directory, which pg_isready
                # treats as a database name; readiness does not depend on it.
                cmd = Command('pg_isready for segment',
                              "pg_isready -q -h %s -p %d -d %s" % (segment.getSegmentHostName(), segment.getSegmentPort(), segment.getSegmentDataDirectory()))
                cmd.run()
                if cmd.get_return_code() != 0:
                    all_ready = False
                    break
            if all_ready:
                break
            # `sleep` is imported from `time` at the top of the file; the bare
            # `time` module itself is not imported there.
            sleep(10)
            self.logger.info("Waiting for segment ready last for %s second" % (attempt * 10))
        else:
            self.logger.warning("New segments did not report ready via pg_isready; "
                                "continuing with catalog cleanup anyway")

        # Build the list of DELETE and TRUNCATE statements based on the subset of
        # COORDINATOR_ONLY_TABLES defined in gpcatalog.py. Mapped tables cannot
        # be TRUNCATEd, so we DELETE from them instead.
        delete_statements = ["delete from pg_catalog.%s" % tab for tab in COORDINATOR_ONLY_TABLES_MAPPED]
        truncate_statements = ["truncate pg_catalog.%s" % tab for tab in COORDINATOR_ONLY_TABLES_NON_MAPPED]
        statements = delete_statements + truncate_statements

        # Connect to each database in the new primaries and clean up the catalog tables.
        for seg in newPrimaries:
            for database in databases:
                if database[0] == 'template0':
                    continue
                dburl = dbconn.DbURL(hostname=seg.getSegmentHostName()
                                     , port=seg.getSegmentPort()
                                     , dbname=database[0]
                                     )
                name = "gpexpand execute segment cleanup commands. seg dbid = %s, command = %s" % (
                    seg.getSegmentDbId(), str(statements))
                execSQLCmd = ExecuteSQLStatementsCommand(name=name
                                                         , url=dburl
                                                         , sqlCommandList=statements
                                                         )
                self.pool.addCommand(execSQLCmd)

        self.pool.join()
        ### need to fix self.pool.check_results(). Call getCompletedItems to clear the queue for now.
        self.pool.check_results()
        self.pool.getCompletedItems()

    # --------------------------------------------------------------------------
    def restore_coordinator(self):
        """
        Restore gp_segment_configuration from the pre-expansion backup (rollback).

        Loads the backup array written before expansion, then, over a
        utility-mode connection with system table modifications allowed,
        removes every mirror (gp_remove_segment_mirror, by content id) and
        every primary (gp_remove_segment, by dbid) whose dbid is not in the
        backup, and commits.

        Raises:
            ExpansionError: the backup file does not exist.
            Exception: any failure while updating the catalog, including the
                sanity checks below. The underlying error is chained as
                __cause__.
        """
        backupFile = self.statusLogger.get_gp_segment_configuration_backup()

        if not os.path.exists(backupFile):
            raise ExpansionError('gp_segment_configuration backup file %s does not exist' % backupFile)

        # Create a new gpArray from the backup file
        array = GpArray.initFromFile(backupFile)

        backup_dbids = [seg.getSegmentDbId() for seg in array.getDbList()]
        originalDbIdsList = [int(dbid) for dbid in backup_dbids]
        # Comma-separated dbids used in the "dbid not in (...)" clauses below.
        originalDbIds = ", ".join(str(dbid) for dbid in backup_dbids)

        if not originalDbIds:
            return

        # Update the catalog with the contents of the backup
        restore_conn = None
        try:
            restore_conn = dbconn.connect(self.dburl, utility=True, encoding='UTF8', allowSystemTableMods=True)

            # Get a list of all the expand primary segments
            sqlStr = "select dbid from pg_catalog.gp_segment_configuration where dbid not in (%s) and role = 'p'" % str(
                originalDbIds)
            curs = dbconn.query(restore_conn, sqlStr)
            deleteDbIdList = [int(row[0]) for row in curs.fetchall()]

            # Get a list of all the expand mirror segments
            sqlStr = "select content from pg_catalog.gp_segment_configuration where dbid not in (%s) and role = 'm'" % str(
                originalDbIds)
            curs = dbconn.query(restore_conn, sqlStr)
            deleteContentList = [int(row[0]) for row in curs.fetchall()]

            #
            # The following is a sanity check to make sure we don't do something bad here.
            #
            if len(originalDbIdsList) < 2:
                self.logger.error(
                    "The original DB DIS list is to small to be correct: %s " % str(len(originalDbIdsList)))
                raise Exception("Unable to complete rollback")

            # Never remove more segments than this expansion added.
            totalToDelete = len(deleteDbIdList) + len(deleteContentList)
            if int(totalToDelete) > int(self.statusLogger.get_number_new_segments()):
                self.logger.error(
                    "There was a discrepancy between the number of expand segments to rollback (%s), and the expected number of segment to rollback (%s)" \
                    % (str(totalToDelete), str(self.statusLogger.get_number_new_segments())))
                self.logger.error("  Expanded primary segment dbids = %s", str(deleteDbIdList))
                self.logger.error("  Expansion mirror content ids   = %s", str(deleteContentList))
                raise Exception("Unable to complete rollback")

            # Mirrors are removed before primaries.
            for content in deleteContentList:
                sqlStr = "select * from gp_remove_segment_mirror(%s::smallint)" % str(content)
                dbconn.execSQL(restore_conn, sqlStr)

            for dbid in deleteDbIdList:
                sqlStr = "select * from gp_remove_segment(%s::smallint)" % str(dbid)
                dbconn.execSQL(restore_conn, sqlStr)

            restore_conn.commit()
        except Exception as e:
            raise Exception("Unable to restore coordinator. Exception: " + str(e)) from e
        finally:
            if restore_conn is not None:
                restore_conn.close()

    def sync_new_mirrors(self):
        """Run gprecoverseg ("-a -F") so the new segments sync with their mirrors.

        Does nothing when mirroring is not enabled on self.gparray.
        """
        if not self.gparray.get_mirroring_enabled():
            return
        self.logger.info('Starting new mirror segment synchronization')
        recover = GpRecoverSeg(name="gpexpand syncing mirrors", options="-a -F")
        recover.run(validateAfter=True)

    def start_prepare(self):
        """Create the status file and record EXPANSION_PREPARE_STARTED.

        Only done when an input file was supplied (self.options.filename). The
        absolute path of that file is stored alongside the status.
        """
        input_file = self.options.filename
        if not input_file:
            return
        status_logger = self.statusLogger
        status_logger.create_status_file()
        status_logger.set_status('EXPANSION_PREPARE_STARTED', os.path.abspath(input_file))

    def finalize_prepare(self):
        """Delete the prepare-phase files: first the status file, then the segment configuration backup."""
        status_logger = self.statusLogger
        for remove in (status_logger.remove_status_file,
                       status_logger.remove_segment_configuration_backup_file):
            remove()

    def setup_schema(self):
        """Create the gpexpand schema, its status tables and a progress view.

        Opens a new connection and keeps it on self.conn. Which progress view
        is created depends on self.options.simple_progress. The start and end
        of this step are recorded through the status logger.
        """
        self.statusLogger.set_status('SETUP_EXPANSION_SCHEMA_STARTED')
        self.logger.info('Creating expansion schema')
        self.conn = dbconn.connect(self.dburl, encoding='UTF8')
        for ddl in (create_schema_sql, status_table_sql, status_detail_table_sql):
            dbconn.execSQL(self.conn, ddl)

        # views
        view_sql = progress_view_simple_sql if self.options.simple_progress else progress_view_sql
        dbconn.execSQL(self.conn, view_sql)

        self.statusLogger.set_status('SETUP_EXPANSION_SCHEMA_DONE')

    def prepare_schema(self):
        """
        Fill gpexpand.status_detail with the tables that must be expanded.

        Writes a 'SETUP' row to gpexpand.status, populates regular and
        partitioned tables for every database except template0, then writes
        'SETUP DONE' and closes self.conn. Finally it marks the prepare phase
        done and calls finalize_prepare().
        """
        self.statusLogger.set_status('PREPARE_EXPANSION_SCHEMA_STARTED')

        if not self.conn:
            # No open connection yet: connect and reload the array from the catalog.
            self.conn = dbconn.connect(self.dburl, encoding='UTF8', allowSystemTableMods=True)
            self.gparray = GpArray.initFromCatalog(self.dburl)

        def insert_status(label):
            stamp = datetime.datetime.now()
            dbconn.execSQL(self.conn, "INSERT INTO gpexpand.status VALUES ( '%s', '%s' ) " % (label, stamp))

        insert_status('SETUP')

        database_names = [db[0] for db in catalog.getDatabaseList(self.conn)]
        for dbname in database_names:
            if dbname == 'template0':
                continue
            self.logger.info('Populating gpexpand.status_detail with data from database %s' % dbname)
            self._populate_regular_tables(dbname)
            self._populate_partitioned_tables(dbname)

        insert_status('SETUP DONE')

        self.conn.close()

        for status in ('PREPARE_EXPANSION_SCHEMA_DONE', 'EXPANSION_PREPARE_DONE'):
            self.statusLogger.set_status(status)

        # At this point, no rollback is possible and the system
        # including new segments has been started once before so finalize
        self.finalize_prepare()

    def _populate_regular_tables(self, dbname):
        src_bytes_str = "0" if self.options.simple_progress else "pg_relation_size(quote_ident(n.nspname) || '.' || quote_ident(c.relname))"
        sql = """SELECT
    current_database(),
    quote_ident(n.nspname) || '.' || quote_ident(c.relname) as fq_name,
    c.oid as tableoid,
    NULL as root_partition_oid,
    2 as rank,
    pe.writable is not null as external_writable,
    '%s' as undone_status,
    NULL as expansion_started,
    NULL as expansion_finished,
    %s as source_bytes
FROM
    pg_class c
    JOIN pg_namespace n ON (c.relnamespace=n.oid)
    JOIN pg_catalog.gp_distribution_policy p on (c.oid = p.localoid)
    LEFT JOIN pg_partitioned_table pp on (c.oid=pp.partrelid)
    LEFT JOIN pg_exttable pe on (c.oid=pe.reloid and pe.writable)
WHERE
    pp.partrelid is NULL
    AND NOT c.relispartition
    AND n.nspname != 'gpexpand'
    AND n.nspname != 'pg_bitmapindex'
    AND c.relpersistence != 't'
                  """ % (undone_status, src_bytes_str)
        self.logger.debug(sql)
        table_conn = self.connect_database(dbname)

        try:
            data_file = os.path.abspath('./status_detail.dat')
            self.logger.debug('status_detail data file: %s' % data_file)
            copySQL = """COPY (%s) TO '%s'""" % (sql, data_file)

            self.logger.debug(copySQL)
            dbconn.execSQL(table_conn, copySQL)
            table_conn.close()
        except Exception as e:
            raise ExpansionError(e)

        try:
            copySQL = """COPY gpexpand.status_detail FROM '%s'""" % (data_file)

            self.logger.debug(copySQL)
            dbconn.execSQL(self.conn, copySQL)
        except Exception as e:
            raise ExpansionError(e)
        finally:
            os.unlink(data_file)

    def _populate_partitioned_tables(self, dbname):
        """
        Add the leaf partitions of every partitioned table in *dbname* to status_detail.

        A leaf's policy may differ from its root's, but it must follow these rules:
        If a partitioned table is hash distributed, then all its leaf partitions
        must also be hash distributed on the same distribution key, with the
        same 'numsegments', or randomly distributed with the same 'numsegments'.
        If a partitioned table is randomly distributed, then all the leaves must
        be randomly distributed as well.

        Because a leaf partition can have a different policy from its root
        partition, leaf partitions are expanded separately and in parallel.

        Step1 (presumably what ALTER TABLE ... EXPAND PARTITION PREPARE does
        for each top-level partitioned table below):
           BEGIN;
           Lock all root/interior/leaf partitions
           Change all numsegments of root/interior/leaf partitions to size of cluster;
           Change all leaf partitions to random distribution;
           COMMIT;
        Step2 (performed later when the leaf is expanded):
           Change each leaf partition's policy back to its parent's policy with
           SET WITH (REORGANIZE=true)

        The status_detail row for each leaf stores its direct parent's oid
        (pg_inherits.inhparent) in root_partition_oid. Interior partitions
        (relkind 'p') and foreign tables (relkind 'f') are excluded.

        Raises:
            ExpansionError: if either COPY fails (original error chained).
        """
        table_conn = self.connect_database(dbname)
        try:
            cursor = dbconn.query(table_conn, """
                SELECT partrelid::regclass AS relname
                FROM pg_partitioned_table, pg_class
                WHERE partrelid = pg_class.oid AND relispartition = FALSE;
            """)
            for row in cursor:
                prepare_cmd = """
                    ALTER TABLE %s EXPAND PARTITION PREPARE;
                """ % (row.relname)
                self.logger.debug(prepare_cmd)
                dbconn.execSQL(table_conn, prepare_cmd, autocommit=True)

            src_bytes_str = "0" if self.options.simple_progress else "pg_relation_size(c.oid)"
            get_status_detail_cmd = """
             SELECT
                current_database(),
                quote_ident(n.nspname) || '.' || quote_ident(c.relname) as fq_name,
                c.oid as tableoid,
                d.oid as root_partition_oid,
                2 as rank,
                false as external_writable,
                '%s' as undone_status,
                NULL as expansion_started,
                NULL as expansion_finished,
                %s as source_bytes
            FROM
                pg_inherits a,
                pg_partitioned_table b,
                pg_class c, 
                pg_class d,
                pg_namespace n
            WHERE
                a.inhparent=b.partrelid and
                a.inhrelid = c.oid and
                a.inhparent = d.oid and
                c.relnamespace = n.oid and
                c.relkind != 'p' and
                c.relkind != 'f'
        """ % (undone_status, src_bytes_str)
            self.logger.debug(get_status_detail_cmd)

            try:
                data_file = os.path.abspath('./status_detail.dat')
                self.logger.debug('status_detail data file: %s' % data_file)
                copySQL = """COPY (%s) TO '%s'""" % (get_status_detail_cmd, data_file)

                self.logger.debug(copySQL)
                dbconn.execSQL(table_conn, copySQL)
                table_conn.commit()
            except Exception as e:
                raise ExpansionError(e) from e
        finally:
            # The per-database connection is released even when PREPARE or COPY fails.
            table_conn.close()

        try:
            copySQL = """COPY gpexpand.status_detail FROM '%s'""" % (data_file)

            self.logger.debug(copySQL)
            dbconn.execSQL(self.conn, copySQL)
        except Exception as e:
            raise ExpansionError(e) from e
        finally:
            os.unlink(data_file)

    def perform_expansion(self):
        """
        Redistribute every table still NOT STARTED in gpexpand.status_detail.

        First, rows left in start_status ("IN PROGRESS") by an interrupted run
        are reset to undone_status. Then one ExpandCommand per row is queued,
        ordered by rank, on a new WorkerPool of self.numworkers workers. The
        queue is polled every 5 seconds. If self.options.end is set and that
        time passes, the run stops early. The outcome is written to
        gpexpand.status as 'EXPANSION STOPPED' or 'EXPANSION COMPLETE'.
        """
        expansionStart = datetime.datetime.now()

        # setup a threadpool
        self.queue = WorkerPool(numWorkers=self.numworkers)

        # go through and reset any "IN PROGRESS" tables
        self.conn = dbconn.connect(self.dburl, encoding='UTF8')
        sql = "INSERT INTO gpexpand.status VALUES ( 'EXPANSION STARTED', '%s' ) " % (
            expansionStart)
        dbconn.execSQL(self.conn, sql)

        sql = """UPDATE gpexpand.status_detail set status = '%s' WHERE status = '%s' """ % (undone_status, start_status)
        dbconn.execSQL(self.conn, sql)

        # read schema and queue up commands
        sql = "SELECT * FROM gpexpand.status_detail WHERE status = 'NOT STARTED' ORDER BY rank"
        cursor = dbconn.query(self.conn, sql)

        for row in cursor:
            self.logger.debug(row)
            name = "name"
            tbl = ExpandTable(options=self.options, row=row)
            cmd = ExpandCommand(name=name, status_url=self.dburl, table=tbl, options=self.options)
            self.queue.addCommand(cmd)

        table_expand_error = False

        stopTime = None
        stoppedEarly = False
        if self.options.end:
            stopTime = self.options.end

        # wait till done.
        while not self.queue.isDone():
            logger.debug(
                "woke up.  queue: %d finished %d  " % (self.queue.assigned, self.queue.completed_queue.qsize()))
            if stopTime and datetime.datetime.now() >= stopTime:
                stoppedEarly = True
                break
            time.sleep(5)

        expansionStopped = datetime.datetime.now()

        self.pool.haltWork()
        self.pool.joinWorkers()
        self.queue.haltWork()
        self.queue.joinWorkers()

        # Doing this after the halt and join workers guarantees that no new completed items can be added
        # while we're doing a check
        for expandCommand in self.queue.getCompletedItems():
            if expandCommand.table_expand_error:
                table_expand_error = True
                break

        if stoppedEarly:
            logger.info('End time reached.  Stopping expansion.')
            sql = "INSERT INTO gpexpand.status VALUES ( 'EXPANSION STOPPED', '%s' ) " % (
                expansionStopped)
            dbconn.execSQL(self.conn, sql)
            logger.info('You can resume expansion by running gpexpand again')
        elif table_expand_error:
            logger.warning('**************************************************')
            logger.warning('One or more tables failed to expand successfully.')
            logger.warning('Please check the log file, correct the problem and')
            logger.warning('run gpexpand again to finish the expansion process')
            logger.warning('**************************************************')
            # We'll try to update the status, but if the errors were caused by
            # going into read only mode, this will fail.  That's ok though as
            # gpexpand will resume next run
            try:
                sql = "INSERT INTO gpexpand.status VALUES ( 'EXPANSION STOPPED', '%s' ) " % (
                    expansionStopped)
                dbconn.execSQL(self.conn, sql)
            except Exception as e:
                # Best effort only. Do not swallow KeyboardInterrupt/SystemExit.
                logger.debug('Unable to record EXPANSION STOPPED status: %s' % e)
        else:
            sql = "INSERT INTO gpexpand.status VALUES ( 'EXPANSION COMPLETE', '%s' ) " % (
                expansionStopped)
            dbconn.execSQL(self.conn, sql)
            logger.info("EXPANSION COMPLETED SUCCESSFULLY")

    def shutdown(self):
        """Tear down after an abrupt exit: stop the worker pools, then record EXPANSION STOPPED.

        The status insert and connection close are best effort. Any failure
        there is ignored: the connection may already be gone, or the schema may
        not exist because setup was cancelled or failed.
        """
        logger.info('Shutting down gpexpand...')
        for workers in (self.pool, self.queue):
            if workers:
                workers.haltWork()
                workers.joinWorkers()

        try:
            stopped_at = datetime.datetime.now()
            dbconn.execSQL(self.conn,
                           "INSERT INTO gpexpand.status VALUES ( 'EXPANSION STOPPED', '%s' ) " % (stopped_at))
            self.conn.close()
        except Exception:
            # Covers pgdb.OperationalError as well as a missing schema.
            pass

    def halt_work(self):
        """Halt and join self.pool, then self.queue, skipping whichever is unset."""
        for workers in (self.pool, self.queue):
            if not workers:
                continue
            workers.haltWork()
            workers.joinWorkers()

    def cleanup_schema(self, gpexpand_db_status):
        """
        Drop the gpexpand schema, optionally dumping status_detail first.

        If gpexpand_db_status is not 'EXPANSION COMPLETE', the tables still
        'NOT STARTED' are listed and the user must confirm; declining exits the
        process with status 0. The user is then offered a COPY of
        gpexpand.status_detail to
        <coordinator_data_directory>/gpexpand.status_detail. After that the
        schema is dropped and the change committed. Connections are closed even
        if a statement fails.
        """
        # drop schema
        if gpexpand_db_status != 'EXPANSION COMPLETE':
            c = dbconn.connect(self.dburl, encoding='UTF8')
            try:
                self.logger.warning('Expansion has not yet completed.  Removing the expansion')
                self.logger.warning('schema now will leave the following tables unexpanded:')
                unexpanded_tables_sql = "SELECT fq_name FROM gpexpand.status_detail WHERE status = 'NOT STARTED' ORDER BY rank"

                cursor = dbconn.query(c, unexpanded_tables_sql)
                unexpanded_tables_text = ''.join("\t%s\n" % row[0] for row in cursor)
            finally:
                c.close()

            self.logger.warning(unexpanded_tables_text)
            self.logger.warning('These tables will have to be expanded manually by setting')
            self.logger.warning('the distribution policy using the ALTER TABLE command.')
            if not ask_yesno('', "Are you sure you want to drop the expansion schema?", 'N'):
                logger.info("User Aborted. Exiting...")
                sys.exit(0)

        # See if user wants to dump the status_detail table to file
        c = dbconn.connect(self.dburl, encoding='UTF8')
        try:
            if ask_yesno('', "Do you want to dump the gpexpand.status_detail table to file?", 'Y'):
                self.logger.info(
                    "Dumping gpexpand.status_detail to %s/gpexpand.status_detail" % self.options.coordinator_data_directory)
                copy_gpexpand_status_detail_sql = "COPY gpexpand.status_detail TO '%s/gpexpand.status_detail'" % self.options.coordinator_data_directory
                dbconn.execSQL(c, copy_gpexpand_status_detail_sql)

            self.logger.info("Removing gpexpand schema")
            dbconn.execSQL(c, drop_schema_sql)
            c.commit()
        finally:
            c.close()

    def connect_database(self, dbname):
        """Connect to *dbname* using a copy of self.dburl with only the database name replaced.

        The connection allows system table modifications and uses UTF8.
        self.dburl itself is not modified.
        """
        target = copy.deepcopy(self.dburl)
        target.pgdb = dbname
        return dbconn.connect(target, encoding='UTF8', allowSystemTableMods=True)

    def sync_packages(self):
        """
        Sync installed extension packages to every host that gains new segments.

        The design decision here is to squash any exceptions resulting from the
        synchronization of packages. We should *not* disturb the user's attempts
        to expand: a failure is logged along with a hint to run gppkg --clean.
        """
        try:
            logger.info('Syncing Apache Cloudberry extensions')
            target_hosts = {seg.getSegmentHostName() for seg in self.gparray.getExpansionSegDbList()}
            sync_ops = [SyncPackages(host) for host in target_hosts]
            ParallelOperation(sync_ops, self.numworkers).run()
            # Inspect each outcome; presumably get_ret() re-raises a captured failure.
            for op in sync_ops:
                op.get_ret()
        except Exception:
            logger.exception('Syncing of Apache Cloudberry extensions has failed.')
            logger.warning('Please run gppkg --clean after successful expansion.')

    def validate_heap_checksums(self):
        """
        Verify that every segment's heap checksum setting matches the coordinator's.

        Returns:
            1 if no segment answered the checksum query (logged as fatal);
            None when the settings are consistent across the cluster.

        Raises:
            Exception: if the settings are inconsistent. Every differing segment
                is written to the log file before the exception is raised.
        """
        num_workers = min(len(self.gparray.get_hostlist()), MAX_PARALLEL_EXPANDS)
        heap_checksum_util = HeapChecksum(gparray=self.gparray, num_workers=num_workers, logger=self.logger)
        successes, failures = heap_checksum_util.get_segments_checksum_settings()
        if len(successes) == 0:
            logger.fatal("No segments responded to ssh query for heap checksum. Not expanding the cluster.")
            return 1

        consistent, inconsistent, coordinator_heap_checksum = heap_checksum_util.check_segment_consistency(successes)

        inconsistent_segment_msgs = []
        for segment in inconsistent:
            inconsistent_segment_msgs.append("dbid: %s "
                                             "checksum set to %s differs from coordinator checksum set to %s" %
                                             (segment.getSegmentDbId(), segment.heap_checksum,
                                              coordinator_heap_checksum))

        if not heap_checksum_util.are_segments_consistent(consistent, inconsistent):
            self.logger.fatal("Cluster heap checksum setting differences reported")
            self.logger.fatal("Heap checksum settings on %d of %d segment instances do not match coordinator <<<<<<<<"
                              % (len(inconsistent_segment_msgs), len(self.gparray.segmentPairs)))
            self.logger.fatal("Review %s for details" % get_logfile())
            log_to_file_only("Failed checksum consistency validation:", logging.WARN)
            self.logger.fatal("gpexpand error: Cluster will not be modified as checksum settings are not consistent "
                              "across the cluster.")

            # Log every offending segment, then abort. Raising inside the loop
            # would log only the first one, and would not raise at all if the
            # list were empty.
            for msg in inconsistent_segment_msgs:
                log_to_file_only(msg, logging.WARN)
            raise Exception("Segments have heap_checksum set inconsistently to coordinator")
        else:
            self.logger.info("Heap checksum setting consistent across cluster")


# -----------------------------------------------
class ExpandTable():
    """
    One gpexpand.status_detail row plus the operations that act on it.

    When *row* is given, it is unpacked in column order: dbname, fq_name,
    table_oid, root_partition_oid, rank, external_writable, status,
    expansion_started, expansion_finished, source_bytes.
    root_partition_oid is None for tables that are not leaf partitions.
    """

    def __init__(self, options, row=None):
        self.options = options
        if row is not None:
            (self.dbname, self.fq_name, self.table_oid,
             self.root_partition_oid,
             self.rank, self.external_writable, self.status,
             self.expansion_started, self.expansion_finished,
             self.source_bytes) = row

    @staticmethod
    def _sql_literal(value):
        """Return *value* as a single-quoted SQL literal (quotes doubled), or NULL for None."""
        if value is None:
            return 'NULL'
        return "'%s'" % str(value).replace("'", "''")

    def _build_insert_sql(self):
        """Build the INSERT statement that stores this table in gpexpand.status_detail.

        None values become SQL NULL. Previously, a None root_partition_oid
        crashed the '%d' format and None timestamps were inserted as 'None'.
        """
        lit = self._sql_literal
        return """INSERT INTO gpexpand.status_detail
                            VALUES (%s,%s,%s,
                                    %s,%d,%s,%s,%s,%s,%d)
                    """ % (lit(self.dbname), lit(self.fq_name), self.table_oid,
                           lit(self.root_partition_oid),
                           self.rank, lit(self.external_writable), lit(self.status),
                           lit(self.expansion_started), lit(self.expansion_finished),
                           self.source_bytes)

    def add_table(self, conn):
        """Insert this table as a new row of gpexpand.status_detail over *conn*."""
        insertSQL = self._build_insert_sql()
        logger.info('Added table %s.%s' % (self.dbname, self.fq_name))
        logger.debug(insertSQL)
        dbconn.execSQL(conn, insertSQL)

    def mark_started(self, status_conn, table_conn, start_time, cancel_flag):
        """Record start_status, *start_time* and the current relation size.

        Does nothing if *cancel_flag* is set. The size is read with
        pg_relation_size over *table_conn*. The row is updated over *status_conn*.
        """
        if cancel_flag:
            return
        sql = "SELECT pg_relation_size(%s)" % (self.table_oid)
        row = dbconn.queryRow(table_conn, sql)
        src_bytes = int(row[0])
        logger.debug(" Table: %s has %d bytes" % (self.fq_name, src_bytes))

        sql = """UPDATE gpexpand.status_detail
                  SET status = '%s', expansion_started='%s',
                      source_bytes = %d
                  WHERE dbname = %s
                        AND table_oid = %s """ % (start_status, start_time,
                                                  src_bytes, self._sql_literal(self.dbname),
                                                  self.table_oid)

        logger.debug("Mark Started: " + sql)
        dbconn.execSQL(status_conn, sql)

    def reset_started(self, status_conn):
        """Put this table back to undone_status and clear both timestamps."""
        sql = """UPDATE gpexpand.status_detail
                 SET status = '%s', expansion_started=NULL, expansion_finished=NULL
                 WHERE dbname = %s
                 AND table_oid = %s """ % (undone_status,
                                           self._sql_literal(self.dbname), self.table_oid)

        logger.debug('Resetting detailed_status: %s' % sql)
        dbconn.execSQL(status_conn, sql)

    def expand(self, table_conn, cancel_flag):
        """
        Redistribute the table over *table_conn*, then commit.

        A leaf partition (root_partition_oid set) gets its parent's distribution
        policy re-applied with SET WITH (REORGANIZE=true). Any other table runs
        ALTER TABLE ... EXPAND TABLE. The table is ANALYZEd afterwards when
        options.analyze is set.

        Returns True if the statement ran, or False if *cancel_flag* was set.
        """
        # expand leaf partitions separately in parallel
        # FIXME: alter table on external table does not throw
        #        a warning, but it will throw error in 6X
        #        do we still need using alter external table?
        if self.root_partition_oid is not None:
            get_dist_cmd = """
                select pg_get_table_distributedby(%d) as distribution_policy;
            """ % (self.root_partition_oid)
            res = dbconn.queryRow(table_conn, get_dist_cmd)
            sql = "ALTER TABLE %s SET WITH (REORGANIZE=true) %s" % (self.fq_name, res.distribution_policy)
        else:
            # FIXME: Can "ONLY" be allowed in "EXPAND TABLE"?
            sql = 'ALTER TABLE %s EXPAND TABLE' % self.fq_name

        logger.info('Expanding %s.%s' % (self.dbname, self.fq_name))
        logger.debug("Expand SQL: %s" % sql)

        # check is atomic in python
        if not cancel_flag:
            dbconn.execSQL(table_conn, sql)
            # the ALTER TABLE command requires a commit to execute
            table_conn.commit()
            if self.options.analyze:
                sql = 'ANALYZE %s' % (self.fq_name)
                logger.info('Analyzing %s' % (self.fq_name))
                dbconn.execSQL(table_conn, sql)

            return True

        # I can only get here if the cancel flag is True
        return False

    def mark_finished(self, status_conn, start_time, finish_time):
        """Record done_status together with the start and finish times."""
        sql = """UPDATE gpexpand.status_detail
                  SET status = '%s', expansion_started='%s', expansion_finished='%s'
                  WHERE dbname = %s
                  AND table_oid = %s """ % (done_status, start_time, finish_time,
                                            self._sql_literal(self.dbname), self.table_oid)
        logger.debug(sql)
        dbconn.execSQL(status_conn, sql)

    def mark_does_not_exist(self, status_conn, finish_time):
        """Record does_not_exist_status (the table was dropped) and *finish_time*."""
        sql = """UPDATE gpexpand.status_detail
                  SET status = '%s', expansion_finished='%s'
                  WHERE dbname = %s
                  AND table_oid = %s """ % (does_not_exist_status, finish_time,
                                            self._sql_literal(self.dbname), self.table_oid)
        logger.debug(sql)
        dbconn.execSQL(status_conn, sql)


# -----------------------------------------------
class ExecuteSQLStatementsCommand(SQLCommand):
    """
    Run a list of SQL statements, in order, over one utility-mode connection.

    The outcome is exposed through get_results() (a CommandResult: rc 0 on
    success, rc 1 with the error text in stderr on failure) and
    was_successful(). The error message is also kept in self.error.
    """

    def __init__(self, name, url, sqlCommandList):
        self.name = name
        self.url = url
        self.sqlCommandList = sqlCommandList
        self.conn = None
        self.error = None

        SQLCommand.__init__(self, name)

    def run(self, validateAfter=False):
        # Tracks the statement in flight so a failure can name it.
        current = None

        self.results = CommandResult(rc=0, stdout=b"", stderr=b"", completed=True, halt=False)

        try:
            self.conn = dbconn.connect(self.url, utility=True, encoding='UTF8', allowSystemTableMods=True)
            for current in self.sqlCommandList:
                dbconn.execSQL(self.conn, current)
        except Exception as exc:
            message = str(exc)
            logger.error("Exception in ExecuteSQLStatements. URL = %s" % str(self.url))
            logger.error("  Statement = %s" % str(current))
            logger.error("  Exception = %s" % message)
            self.error = message
            self.results = CommandResult(rc=1, stdout=b"", stderr=message.encode(),
                                         completed=False, halt=True)
        finally:
            if self.conn is not None:
                self.conn.close()

    def set_results(self, results):
        raise ExecutionError("TODO:  must implement", None)

    def get_results(self):
        return self.results

    def was_successful(self):
        return self.error is None

    def validate(self, expected_rc=0):
        raise ExecutionError("TODO:  must implement", None)


# -----------------------------------------------
class ExpandCommand(SQLCommand):
    """
    Worker-pool command that expands the single table described by an ExpandTable.

    Two connections are used: one to status_url, to update
    gpexpand.status_detail, and one to the table's own database.
    self.table_expand_error is set on any failure other than a user cancel.
    """

    def __init__(self, name, status_url, table, options):
        self.status_url = status_url
        self.table = table
        self.options = options
        self.cmdStr = "Expand %s.%s" % (table.dbname, table.fq_name)
        # Same URL as the status connection, but pointing at the table's database.
        self.table_url = copy.deepcopy(status_url)
        self.table_url.pgdb = table.dbname
        self.table_expand_error = False

        SQLCommand.__init__(self, name)

    def run(self, validateAfter=False):
        """Expand self.table and record the outcome. Both connections are always closed."""
        status_conn = None
        table_conn = None

        try:
            status_conn = dbconn.connect(self.status_url, encoding='UTF8')
            table_conn = dbconn.connect(self.table_url, encoding='UTF8')
        except DatabaseError as ex:
            if self.options.verbose:
                logger.exception(ex)
            logger.error(str(ex).strip())
            if status_conn is not None:
                status_conn.close()
            if table_conn is not None:
                table_conn.close()
            self.table_expand_error = True
            return

        try:
            self._expand_and_record(status_conn, table_conn)
        finally:
            # Previously, a failure in mark_finished/reset_started leaked both connections.
            try:
                status_conn.close()
            finally:
                table_conn.close()

    def _expand_and_record(self, status_conn, table_conn):
        """Expand the table over already-open connections and update its status_detail row."""
        table_exp_success = False
        start_time = None

        # validate table hasn't been dropped
        try:
            sql = """select * from pg_class c where c.oid = %d """ % (self.table.table_oid)

            cursor = dbconn.query(table_conn, sql)

            if cursor.rowcount == 0:
                logger.info('%s no longer exists in database %s' % (self.table.fq_name,
                                                                       self.table.dbname))

                self.table.mark_does_not_exist(status_conn, datetime.datetime.now())
                return

            # Set conn for  cancel
            self.cancel_conn = table_conn
            start_time = datetime.datetime.now()
            if not self.options.simple_progress:
                self.table.mark_started(status_conn, table_conn, start_time, self.cancel_flag)

            table_exp_success = self.table.expand(table_conn, self.cancel_flag)

        except Exception as ex:
            message = str(ex)
            if 'canceling statement due to user request' not in message and not self.cancel_flag:
                self.table_expand_error = True
                if self.options.verbose:
                    logger.exception(ex)
                logger.error('Table %s.%s failed to expand: %s' % (self.table.dbname,
                                                                   self.table.fq_name,
                                                                   message.strip()))
            else:
                logger.info('ALTER TABLE of %s.%s canceled' % (
                    self.table.dbname, self.table.fq_name))

        if table_exp_success:
            end_time = datetime.datetime.now()
            # update metadata
            logger.info(
                "Finished expanding %s.%s" % (self.table.dbname, self.table.fq_name))
            self.table.mark_finished(status_conn, start_time, end_time)
        elif not self.options.simple_progress:
            logger.info("Resetting status_detail for %s.%s" % (
                self.table.dbname, self.table.fq_name))
            self.table.reset_started(status_conn)

    def set_results(self, results):
        raise ExecutionError("TODO:  must implement", None)

    def get_results(self):
        raise ExecutionError("TODO:  must implement", None)

    def was_successful(self):
        raise ExecutionError("TODO:  must implement", None)

    def validate(self, expected_rc=0):
        raise ExecutionError("TODO:  must implement", None)


# ------------------------------- UI Help --------------------------------
def read_hosts_file(hosts_file):
    """Return the host names listed in *hosts_file*, one per line.

    Surrounding whitespace is stripped from every line.  Blank lines and lines
    whose first non-whitespace character is ``#`` are skipped.

    :param hosts_file: path of the hosts file to read.
    :returns: list of host names, in file order.
    :raises ExpansionError: if the file does not exist or cannot be read.
    """
    try:
        with open(hosts_file, 'r') as f:
            stripped = (line.strip() for line in f)
            return [host for host in stripped if host and not host.startswith('#')]
    except FileNotFoundError as err:
        raise ExpansionError('Hosts file %s not found' % hosts_file) from err
    except OSError as err:
        # Permission problems, directories, I/O failures, ...: the file may well
        # exist, so do not report it as "not found".
        raise ExpansionError('Hosts file %s could not be read: %s' % (hosts_file, err)) from err


def interview_setup(gparray, options):
    """Interactively build a gpexpand input file for a new expansion.

    Confirms that the user wants to start an expansion and collects the new
    hosts (from ``options.hosts_file`` or a prompt). When mirroring is enabled,
    it also asks for a mirroring strategy. It then asks for any additional
    primary/mirror data directories and registers everything on *gparray*.
    The module-level ``_gp_expand`` object then writes the input file (and a
    tablespace input file when it returns one), and the user is told to
    re-run ``gpexpand -i <file>``.

    Calls ``sys.exit(0)`` if the user declines to continue.

    :param gparray: GpArray for the current cluster; expansion hosts and data
        directories are added to it in place.
    :param options: parsed command-line options; ``hosts_file`` is read.
    :raises ExpansionError: when mirroring is enabled but fewer than two new
        hosts were given, when the number of entered directories does not
        match, when the generated ports are contiguous, or when neither hosts
        nor segments were added.  Exceptions from validating the new hosts,
        other than "No new hosts to add", propagate unchanged.
    """
    help_text = """
System Expansion is used to add segments to an existing CBDB array.
gpexpand did not detect a System Expansion that is in progress.

Before initiating a System Expansion, you need to provision and burn-in
the new hardware.  Please be sure to run gpcheckperf to make sure the
new hardware is working properly.

Please refer to the Admin Guide for more information."""

    if not ask_yesno(help_text, "Would you like to initiate a new System Expansion", 'N'):
        logger.info("User Aborted. Exiting...")
        sys.exit(0)

    help_text = """
This utility can handle some expansion scenarios by asking a few questions.
More complex expansions can be done by providing an input file with
the --input <file>.  Please see the docs for the format of this file. """

    standard, message = gparray.isStandardArray()
    if not standard:
        help_text = help_text + """

       The current system appears to be non-standard.
       """
        help_text = help_text + message
        help_text = help_text + """
       gpexpand may not be able to symmetrically distribute the new segments appropriately.
       It is recommended that you specify your own input file with appropriate values."""
        if not ask_yesno(help_text, "Are you sure you want to continue with this gpexpand session?", 'N'):
            logger.info("User Aborted. Exiting...")
            sys.exit(0)

    # NOTE(review): this text is assembled but never displayed; help_text is
    # reassigned before any later prompt reads it.  Confirm whether it was
    # meant to be passed as the background text of the host-list prompt.
    help_text = help_text + """

We'll now ask you a few questions to try and build this file for you.
You'll have the opportunity to save this file and inspect it/modify it
before continuing by re-running this utility and providing the input file. """

    def datadir_validator(input_value, *args):
        # Reject empty answers and paths containing spaces (None => re-ask).
        if not input_value or ' ' in input_value:
            return None
        return input_value

    if options.hosts_file:
        new_hosts = read_hosts_file(options.hosts_file)
    else:
        new_hosts = ask_list(None,
                             "\nEnter a comma separated list of new hosts you want\n"
                             "to add to your array.  Do not include interface hostnames.\n"
                             "**Enter a blank line to only add segments to existing hosts**", [])
        new_hosts = [host.strip() for host in new_hosts]

    num_new_hosts = len(new_hosts)

    mirror_type = 'none'

    if gparray.get_mirroring_enabled():
        if num_new_hosts < 2:
            raise ExpansionError('You must be adding two or more hosts when expanding a system with mirroring enabled.')
        mirror_type = ask_string(
            "\nYou must now specify a mirroring strategy for the new hosts.  Spread mirroring places\n"
            "a given hosts mirrored segments each on a separate host.  You must be \n"
            "adding more hosts than the number of segments per host to use this. \n"
            "Grouped mirroring places all of a given hosts segments on a single \n"
            "mirrored host.  You must be adding at least 2 hosts in order to use this.\n\n",
            "What type of mirroring strategy would you like?",
            'grouped', ['spread', 'grouped'])

    try:
        gparray.addExpansionHosts(new_hosts, mirror_type)
        gparray.validateExpansionSegs()
    except Exception as ex:
        # Only "no genuinely new hosts" is tolerated: fall back to adding
        # segments to the existing hosts.
        if str(ex) != 'No new hosts to add':
            raise
        num_new_hosts = 0
        print()
        print('** No hostnames were given that do not already exist in the **')
        print('** array. Additional segments will be added existing hosts. **')

    help_text = """
    By default, new hosts are configured with the same number of primary
    segments as existing hosts.  Optionally, you can increase the number
    of segments per host.

    For example, if existing hosts have two primary segments, entering a value
    of 2 will initialize two additional segments on existing hosts, and four
    segments on new hosts.  In addition, mirror segments will be added for
    these new primary segments if mirroring is enabled.
    """
    num_new_datadirs = ask_int(help_text, "How many new primary segments per host do you want to add?", None, 0, 0, 128)

    if num_new_datadirs > 0:
        new_datadirs = []
        new_mirrordirs = []

        for i in range(1, num_new_datadirs + 1):
            new_datadir = ask_input(None, 'Enter new primary data directory %d' % i, '',
                                    '/data/gpdb_p%d' % i, datadir_validator, None)
            new_datadirs.append(new_datadir.strip())

        if len(new_datadirs) != num_new_datadirs:
            raise ExpansionError(
                'The number of data directories entered does not match the number of primary segments added')

        if gparray.get_mirroring_enabled():
            for i in range(1, num_new_datadirs + 1):
                new_mirrordir = ask_input(None, 'Enter new mirror data directory %d' % i, '',
                                          '/data/gpdb_m%d' % i, datadir_validator, None)
                new_mirrordirs.append(new_mirrordir.strip())

            if len(new_mirrordirs) != num_new_datadirs:
                raise ExpansionError(
                    'The number of new mirror data directories entered does not match the number of segments added')

        gparray.addExpansionDatadirs(datadirs=new_datadirs,
                                     mirrordirs=new_mirrordirs,
                                     mirror_type=mirror_type)
        try:
            gparray.validateExpansionSegs()
        except Exception as ex:
            if str(ex).startswith('Port'):
                raise ExpansionError(
                    'Current primary and mirror ports are contiguous.  The input file for gpexpand will need to be created manually.') from ex
            # Any other validation failure used to be silently discarded.  Keep
            # generating the file (the user is asked to review it anyway), but
            # say what is wrong with it.
            logger.warning('The proposed expansion segments failed validation: %s' % str(ex).strip())
            logger.warning('Review and correct the generated input file before re-running gpexpand.')
    elif num_new_hosts == 0:
        raise ExpansionError('No new hosts or segments were entered.')

    print("\nGenerating configuration file...\n")

    outfile = _gp_expand.generate_inputfile()
    print("""\nInput configuration file was written to '%s'.""" % (outfile))

    outfile_ts = _gp_expand.generate_tablespace_inputfile(outfile + ".ts")
    if outfile_ts:
        print("""Tablespace Input configuration file was written to '%s'.""" % (outfile_ts))

    print()

    print("""Please review the file and make sure that it is correct then re-run
with: gpexpand -i %s
                """ % (outfile))


def sig_handler(sig, arg):
    """Handle SIGTERM/SIGHUP: shut gpexpand down, then re-deliver the signal.

    If an expansion object exists it is shut down first.  The SIGTERM and
    SIGHUP handlers are then reset to their defaults, and the same signal is
    sent to this process again so that it terminates with the default action.

    :param sig: the signal number being handled.
    :param arg: the current stack frame (unused).
    """
    expander = _gp_expand
    if expander is not None:
        expander.shutdown()

    # Restore default dispositions so the re-sent signal is not caught again.
    for signum in (signal.SIGTERM, signal.SIGHUP):
        signal.signal(signum, signal.SIG_DFL)

    os.kill(os.getpid(), sig)


# --------------------------------------------------------------------------
# Main
# --------------------------------------------------------------------------
def _log_setup_complete(logger):
    """Log the banner shown when the setup phase of an expansion has finished."""
    logger.info('************************************************')
    logger.info('Initialization of the system expansion complete.')
    logger.info('To begin table expansion onto the new segments')
    logger.info('rerun gpexpand')
    logger.info('************************************************')


def main(options, args, parser):
    """Entry point for gpexpand.

    Steps:
      - installs SIGTERM/SIGHUP handlers and validates the command line;
      - refuses to run while another gpexpand holds the pid file;
      - loads the cluster configuration from the catalog;
      - dispatches on the recorded expansion state (database status, or the
        on-disk setup status when there is none).

    Depending on that state and the options, it cleans up the gpexpand
    schema, rolls back a partial setup, expands tables, runs the setup phase
    from an input file, retries an unfinished setup that can no longer be
    rolled back, or starts the interactive interview.

    Leaves via ``sys.exit``:
      - 0 on success, clean-up or completed rollback;
      - 1 when another instance is running, the database is unreachable or
        rollback is refused;
      - 3 on unexpected errors;
      - a message when interrupted by the user.

    :param options: parsed command-line options (re-validated here).
    :param args: positional command-line arguments.
    :param parser: option parser, passed to ``validate_options``.
    """
    global _gp_expand

    remove_pid = True
    # Bind the names the exception handlers below read before anything can
    # fail, so an early failure cannot become an UnboundLocalError there.
    logger = get_default_logger()
    gpexpand_db_status = None
    try:
        # setup signal handlers so we can clean up correctly
        signal.signal(signal.SIGTERM, sig_handler)
        signal.signal(signal.SIGHUP, sig_handler)

        setup_tool_logging(EXECNAME, getLocalHostname(), getUserName())

        options, args = validate_options(options, args, parser)

        if options.verbose:
            enable_verbose_logging()

        if is_gpexpand_running(options.coordinator_data_directory):
            logger.error('gpexpand is already running.  Only one instance')
            logger.error('of gpexpand is allowed at a time.')
            # The pid file belongs to the other instance; do not remove it.
            remove_pid = False
            sys.exit(1)
        create_pid_file(options.coordinator_data_directory)

        # prepare provider for updateSystemConfig
        gpEnv = GpCoordinatorEnvironment(options.coordinator_data_directory, True)
        configurationInterface.registerConfigurationProvider(
            configurationImplGpdb.GpConfigurationProviderUsingGpdbCatalog())
        configurationInterface.getConfigurationProvider().initializeProvider(gpEnv.getCoordinatorPort())

        dburl = dbconn.DbURL(dbname=DBNAME, port=gpEnv.getCoordinatorPort())

        gpexpand_db_status = gpexpand.prepare_gpdb_state(logger, dburl, options)

        # Get array configuration
        try:
            gparray = GpArray.initFromCatalog(dburl, utility=True)
        except DatabaseError:
            logger.error('Failed to connect to database.  Make sure the')
            logger.error('Cloudberry instance you wish to expand is running')
            logger.error('and that your environment is correct, then rerun')
            logger.error('gpexpand ' + ' '.join(sys.argv[1:]))
            sys.exit(1)

        _gp_expand = gpexpand(logger, gparray, dburl, options, parallel=options.parallel)

        # The on-disk setup status is only consulted when the database has no
        # recorded expansion status.
        gpexpand_file_status = None
        if not gpexpand_db_status:
            gpexpand_file_status = _gp_expand.get_state()

        if options.clean and gpexpand_db_status is not None:
            _gp_expand.cleanup_schema(gpexpand_db_status)
            logger.info('Cleanup Finished.  exiting...')
            sys.exit(0)

        if options.rollback:
            try:
                if gpexpand_db_status:
                    logger.error('A previous expansion is either in progress or has')
                    logger.error('completed.  Since the setup portion of the expansion')
                    logger.error('has finished successfully there is nothing to rollback.')
                    sys.exit(1)
                if gpexpand_file_status is None:
                    logger.error('There is no partially completed setup to rollback.')
                    sys.exit(1)
                _gp_expand.rollback(dburl)
                logger.info('Rollback complete.')
                sys.exit(0)
            except ExpansionError as e:
                logger.error(e)
                sys.exit(1)

        # check if the cluster is in good health
        if gpexpand_file_status is None and gpexpand_db_status is None and not is_cluster_up_and_balanced(dburl):
            logger.warning('One or more segments are either down or not in preferred role.')

        if gpexpand_db_status in ('SETUP DONE', 'EXPANSION STOPPED'):
            if not _gp_expand.validate_max_connections():
                raise ValidationError()
            _gp_expand.perform_expansion()
        elif gpexpand_db_status == 'EXPANSION STARTED':
            logger.info('It appears the last run of gpexpand did not exit cleanly.')
            logger.info('Resuming the expansion process...')
            if not _gp_expand.validate_max_connections():
                raise ValidationError()
            _gp_expand.perform_expansion()
        elif gpexpand_db_status == 'EXPANSION COMPLETE':
            logger.info('Expansion has already completed.')
            logger.info('If you want to expand again, run gpexpand -c to remove')
            logger.info('the gpexpand schema and begin a new expansion')
        elif gpexpand_db_status is None and gpexpand_file_status is None and options.filename:
            _gp_expand.validate_heap_checksums()
            newSegList = _gp_expand.read_input_files()
            _gp_expand.addNewSegments(newSegList)
            newTableSpaceInfo = _gp_expand.read_tablespace_file()
            _gp_expand.sync_packages()
            _gp_expand.start_prepare()
            _gp_expand.lock_catalog()
            _gp_expand.add_segments(newTableSpaceInfo)
            _gp_expand.update_original_segments()
            _gp_expand.cleanup_new_segments()
            _gp_expand.update_catalog()
            _gp_expand.unlock_catalog()
            _gp_expand.setup_schema()
            _gp_expand.prepare_schema()
            _gp_expand.sync_new_mirrors()
            _log_setup_complete(logger)
        elif gpexpand_file_status is not None and not _gp_expand.statusLogger.can_rollback(gpexpand_file_status):
            # gpexpand cannot rollback if new segments are online and the
            # catalog lock has been released in phase1.  So if gpexpand fails
            # after releasing the catalog lock in phase1, it must retry the
            # failing work.
            logger.warning('The last gpexpand setup did not complete successfully.')
            logger.warning('But you can not rollback to the original state, for new segments have been online.')
            logger.warning('So retry the failing work again in gpexpand setup.')

            # Clean the schema first; close the connection even if the drop fails.
            conn = dbconn.connect(_gp_expand.dburl, encoding='UTF8')
            try:
                dbconn.execSQL(conn, drop_schema_sql)
            finally:
                conn.close()

            # Reset status to UPDATE_CATALOG_DONE so the following work can continue.
            _gp_expand.statusLogger.rewind('UPDATE_CATALOG_DONE', 'Reset status forcedly to retry the following work')

            # Redo the following work.
            _gp_expand.setup_schema()
            _gp_expand.prepare_schema()
            _gp_expand.sync_new_mirrors()
            _log_setup_complete(logger)
        elif options.filename is None and gpexpand_file_status is None:
            interview_setup(gparray, options)
        else:
            logger.error('The last gpexpand setup did not complete successfully.')
            logger.error('Please run gpexpand -r to rollback to the original state.')

        logger.info("Exiting...")
        sys.exit(0)

    except ValidationError:
        logger.info('Bringing Apache Cloudberry back online...')
        if _gp_expand is not None:
            _gp_expand.shutdown()
        # NOTE(review): sys.exit() with no argument exits with status 0 even
        # though validation failed; confirm whether callers rely on this.
        sys.exit()
    except Exception as e:
        if options and options.verbose:
            logger.exception("gpexpand failed. exiting...")
        else:
            logger.error("gpexpand failed: %s \n\nExiting..." % e)
        if _gp_expand is not None and _gp_expand.pastThePointOfNoReturn:
            logger.error(
                'gpexpand is past the point of rollback. Any remaining issues must be addressed outside of gpexpand.')
        if _gp_expand is not None:
            if not (gpexpand_db_status is None and _gp_expand.get_state() is None):
                if not _gp_expand.pastThePointOfNoReturn:
                    logger.error('Please run \'gpexpand -r\' to rollback to the original state.')
            _gp_expand.shutdown()
        sys.exit(3)
    except KeyboardInterrupt:
        # Disable SIGINT while we shutdown.
        signal.signal(signal.SIGINT, signal.SIG_IGN)

        if _gp_expand is not None:
            _gp_expand.shutdown()

        # Re-enable SIGINT
        signal.signal(signal.SIGINT, signal.default_int_handler)

        sys.exit('\nUser Interrupted')

    finally:
        try:
            if remove_pid and options:
                remove_pid_file(options.coordinator_data_directory)
        except NameError:
            pass

        if _gp_expand is not None:
            _gp_expand.halt_work()

if __name__ == '__main__':
    # Script entry point: parse the command line, then hand off to main(),
    # which finishes by calling sys.exit() with the appropriate status.
    # options/args/parser are bound as module-level names here.
    options, args, parser = parseargs()
    main(options, args, parser)
