# File: //lib/python3/dist-packages/sqlparse/engine/grouping.py
#
# Copyright (C) 2009-2020 the sqlparse authors and contributors
# <see AUTHORS file>
#
# This module is part of python-sqlparse and is released under
# the BSD License: https://opensource.org/licenses/BSD-3-Clause
from sqlparse import sql
from sqlparse import tokens as T
from sqlparse.utils import recurse, imt
# Token-type tuples shared by the grouping helpers in this module when
# deciding which neighbouring tokens may take part in a group.
# Number token types (generic, integer and float).
T_NUMERICAL = (T.Number, T.Number.Integer, T.Number.Float)
# String token types (generic, single and symbol).
T_STRING = (T.String, T.String.Single, T.String.Symbol)
# Name token types (generic and placeholder).
T_NAME = (T.Name, T.Name.Placeholder)
def _group_matching(tlist, cls):
    """Group spans delimited by ``cls.M_OPEN`` and ``cls.M_CLOSE``.

    Iterates over a snapshot of *tlist*, remembering the position of every
    opening token on a stack. Each closing token is paired with the most
    recent unmatched opener and the span between them (inclusive) is
    replaced in *tlist* by a ``cls`` group. Groups of a different class
    are descended into recursively. Nothing is returned; *tlist* is
    modified in place.
    """
    pending_opens = []
    # Number of positions by which the live list is shorter than the
    # snapshot being iterated, due to the groupings made so far.
    collapsed = 0
    for snapshot_idx, tok in enumerate(list(tlist)):
        live_idx = snapshot_idx - collapsed

        # Whitespace can be neither an opener nor a closer, so skip it
        # before doing any of the more expensive checks below.
        if tok.is_whitespace:
            continue

        # A group of another class (e.g. a parenthesis while looking for
        # CASE) may itself contain open/close pairs of ``cls``; handle
        # those by recursing. Ideally all open/close token kinds would be
        # checked in a single pass to avoid this recursion.
        if tok.is_group and not isinstance(tok, cls):
            _group_matching(tok, cls)
            continue

        if tok.match(*cls.M_OPEN):
            pending_opens.append(live_idx)
            continue

        if not tok.match(*cls.M_CLOSE):
            continue

        # A closer without an opener indicates invalid SQL / unbalanced
        # tokens. Keep scanning instead of aborting, because other valid
        # groups may still follow.
        if not pending_opens:
            continue

        start_idx = pending_opens.pop()
        tlist.group_tokens(cls, start_idx, live_idx)
        collapsed += live_idx - start_idx
def group_brackets(tlist):
    """Group the open/close spans of ``sql.SquareBrackets`` in *tlist*."""
    _group_matching(tlist, cls=sql.SquareBrackets)
def group_parenthesis(tlist):
    """Group the open/close spans of ``sql.Parenthesis`` in *tlist*."""
    _group_matching(tlist, cls=sql.Parenthesis)
def group_case(tlist):
    """Group the open/close spans of ``sql.Case`` in *tlist*."""
    _group_matching(tlist, cls=sql.Case)
def group_if(tlist):
    """Group the open/close spans of ``sql.If`` in *tlist*."""
    _group_matching(tlist, cls=sql.If)
def group_for(tlist):
    """Group the open/close spans of ``sql.For`` in *tlist*."""
    _group_matching(tlist, cls=sql.For)
def group_begin(tlist):
    """Group the open/close spans of ``sql.Begin`` in *tlist*."""
    _group_matching(tlist, cls=sql.Begin)
def group_typecasts(tlist):
    """Group ``<prev> :: <next>`` sequences into an ``sql.Identifier``.

    The group runs from the token before the ``::`` punctuation through
    the token after it; the neighbours only have to exist.
    """
    def is_cast_marker(token):
        return token.match(T.Punctuation, '::')

    def exists(token):
        return token is not None

    def span(tlist, pidx, tidx, nidx):
        # Cover the previous token, the '::' and the next token.
        return pidx, nidx

    _group(tlist, sql.Identifier, is_cast_marker,
           valid_prev=exists, valid_next=exists, post=span)
def group_tzcasts(tlist):
    """Group a ``T.Keyword.TZCast`` token with both of its neighbours.

    The resulting ``sql.Identifier`` runs from the token before the
    TZCast keyword through the token after it.
    """
    def is_tzcast(token):
        return token.ttype == T.Keyword.TZCast

    def exists(token):
        return token is not None

    def span(tlist, pidx, tidx, nidx):
        # Cover previous token, the keyword itself and the next token.
        return pidx, nidx

    _group(tlist, sql.Identifier, is_tzcast,
           valid_prev=exists, valid_next=exists, post=span)
def group_typed_literal(tlist):
    """Group typed literals into ``sql.TypedLiteral`` in two passes.

    Pass one groups a token matching ``sql.TypedLiteral.M_OPEN`` with a
    following token matching ``M_CLOSE``. Pass two extends an existing
    ``sql.TypedLiteral`` with a following token matching ``M_EXTEND``.
    In both passes the group starts at the matched token, not at the
    token preceding it.
    """
    # definitely not complete, see e.g.:
    # https://docs.microsoft.com/en-us/sql/odbc/reference/appendixes/interval-literal-syntax
    # https://docs.microsoft.com/en-us/sql/odbc/reference/appendixes/interval-literals
    # https://www.postgresql.org/docs/9.1/datatype-datetime.html
    # https://www.postgresql.org/docs/9.1/functions-datetime.html
    literal = sql.TypedLiteral

    def opens_literal(token):
        return imt(token, m=literal.M_OPEN)

    def is_literal_group(token):
        return isinstance(token, literal)

    def has_prev(token):
        return token is not None

    def closes_literal(token):
        return token is not None and token.match(*literal.M_CLOSE)

    def extends_literal(token):
        return token is not None and token.match(*literal.M_EXTEND)

    def span(tlist, pidx, tidx, nidx):
        # The previous token is only required to exist; it is not part
        # of the group.
        return tidx, nidx

    # Pass 1: build the basic literal.
    _group(tlist, literal, opens_literal, has_prev, closes_literal,
           span, extend=False)
    # Pass 2: append an extension token to an already grouped literal.
    _group(tlist, literal, is_literal_group, has_prev, extends_literal,
           span, extend=True)
def group_period(tlist):
    """Group dotted names (``a.b``) into an ``sql.Identifier``.

    The token before the ``.`` must be a square-bracket group, an
    identifier, a name or a symbol string. The token after it is checked
    only when the span is computed: if it is not acceptable, the group
    ends at the ``.`` itself.
    """
    prev_classes = sql.SquareBrackets, sql.Identifier
    prev_types = T.Name, T.String.Symbol
    next_classes = sql.SquareBrackets, sql.Function
    next_types = T.Name, T.String.Symbol, T.Wildcard

    def is_period(token):
        return token.match(T.Punctuation, '.')

    def prev_ok(token):
        return imt(token, i=prev_classes, t=prev_types)

    def next_ok(token):
        # issue261, allow invalid next token
        return True

    def span(tlist, pidx, tidx, nidx):
        # next_ validation is being performed here. issue261
        if nidx is None:
            following = None
        else:
            following = tlist[nidx]
        if imt(following, i=next_classes, t=next_types):
            return pidx, nidx
        return pidx, tidx

    _group(tlist, sql.Identifier, is_period, prev_ok, next_ok, span)
def group_as(tlist):
    """Group ``<prev> AS <next>`` into an ``sql.Identifier``.

    The previous token must be a non-keyword or ``NULL``; the next token
    must exist and must not be of type DML, DDL or CTE.
    """
    excluded_next = T.DML, T.DDL, T.CTE

    def is_as(token):
        return token.is_keyword and token.normalized == 'AS'

    def prev_ok(token):
        if token.normalized == 'NULL':
            return True
        return not token.is_keyword

    def next_ok(token):
        if imt(token, t=excluded_next):
            return False
        return token is not None

    def span(tlist, pidx, tidx, nidx):
        return pidx, nidx

    _group(tlist, sql.Identifier, is_as, prev_ok, next_ok, span)
def group_assignment(tlist):
    """Group ``<prev> := <next>`` into an ``sql.Assignment``.

    When a ``;`` punctuation token follows the right-hand side, the
    group is stretched to include it.
    """
    semicolon = T.Punctuation, ';'

    def is_assign(token):
        return token.match(T.Assignment, ':=')

    def operand_ok(token):
        if token is None:
            return False
        # NOTE: the right-hand side is the token type itself, not a
        # tuple, so this relies on the containment check implemented by
        # the token type object (defined in sqlparse.tokens).
        return token.ttype not in T.Keyword

    def span(tlist, pidx, tidx, nidx):
        semi_idx, _ = tlist.token_next_by(m=semicolon, idx=nidx)
        # Fall back to the right-hand side when no semicolon was found.
        return pidx, semi_idx or nidx

    _group(tlist, sql.Assignment, is_assign, operand_ok, operand_ok, span)
def group_comparison(tlist):
    """Group ``<operand> <comparison-operator> <operand>`` spans.

    Each operand must be one of the listed group classes, a numeric,
    string or name token, or the keyword ``NULL``. The result is an
    ``sql.Comparison``; existing groups are not extended.
    """
    operand_classes = (sql.Parenthesis, sql.Function, sql.Identifier,
                       sql.Operation, sql.TypedLiteral)
    operand_types = T_NUMERICAL + T_STRING + T_NAME

    def is_comparison(token):
        return token.ttype == T.Operator.Comparison

    def operand_ok(token):
        if imt(token, t=operand_types, i=operand_classes):
            return True
        is_null = token and token.is_keyword and token.normalized == 'NULL'
        return bool(is_null)

    def span(tlist, pidx, tidx, nidx):
        return pidx, nidx

    _group(tlist, sql.Comparison, is_comparison,
           operand_ok, operand_ok, span, extend=False)
@recurse(sql.Identifier)
def group_identifier(tlist):
    """Wrap each symbol-string or name token in its own ``sql.Identifier``."""
    name_types = (T.String.Symbol, T.Name)
    pos, found = tlist.token_next_by(t=name_types)
    while found:
        # Single-token group: start and end index are the same.
        tlist.group_tokens(sql.Identifier, pos, pos)
        pos, found = tlist.token_next_by(t=name_types, idx=pos)
def group_arrays(tlist):
    """Attach a square-bracket group to the token that precedes it.

    The preceding token must be a square-bracket group, identifier,
    function, name or symbol string. The resulting ``sql.Identifier``
    ends at the bracket group; the token after it is never included.
    Nested groups are not descended into by this pass.
    """
    base_classes = sql.SquareBrackets, sql.Identifier, sql.Function
    base_types = T.Name, T.String.Symbol

    def is_brackets(token):
        return isinstance(token, sql.SquareBrackets)

    def base_ok(token):
        return imt(token, i=base_classes, t=base_types)

    def anything(token):
        return True

    def span(tlist, pidx, tidx, nidx):
        # Stop at the bracket group itself.
        return pidx, tidx

    _group(tlist, sql.Identifier, is_brackets,
           base_ok, anything, span, extend=True, recurse=False)
def group_operator(tlist):
    """Group ``<operand> <operator> <operand>`` into an ``sql.Operation``.

    A wildcard token between two valid operands is treated as an
    operator: its ``ttype`` is rewritten to ``T.Operator`` when the
    group is formed.
    """
    operand_types = T_NUMERICAL + T_STRING + T_NAME
    operand_classes = (sql.SquareBrackets, sql.Parenthesis, sql.Function,
                       sql.Identifier, sql.Operation, sql.TypedLiteral)
    date_keywords = ('CURRENT_DATE', 'CURRENT_TIME', 'CURRENT_TIMESTAMP')

    def is_operator(token):
        return imt(token, t=(T.Operator, T.Wildcard))

    def operand_ok(token):
        by_kind = imt(token, i=operand_classes, t=operand_types)
        return by_kind or (token and token.match(T.Keyword, date_keywords))

    def span(tlist, pidx, tidx, nidx):
        # Normalise the middle token (possibly a wildcard) to an operator.
        tlist[tidx].ttype = T.Operator
        return pidx, nidx

    _group(tlist, sql.Operation, is_operator,
           operand_ok, operand_ok, span, extend=False)
def group_identifier_list(tlist):
    """Group comma-separated items into an ``sql.IdentifierList``.

    Both neighbours of the comma must be one of the listed group
    classes, match the ``null``/``role`` keyword pattern, or have one of
    the listed token types.
    """
    role_pattern = T.Keyword, ('null', 'role')
    item_classes = (sql.Function, sql.Case, sql.Identifier, sql.Comparison,
                    sql.IdentifierList, sql.Operation)
    item_types = (T_NUMERICAL + T_STRING + T_NAME
                  + (T.Keyword, T.Comment, T.Wildcard))

    def is_comma(token):
        return token.match(T.Punctuation, ',')

    def item_ok(token):
        return imt(token, i=item_classes, m=role_pattern, t=item_types)

    def span(tlist, pidx, tidx, nidx):
        return pidx, nidx

    _group(tlist, sql.IdentifierList, is_comma,
           item_ok, item_ok, span, extend=True)
@recurse(sql.Comment)
def group_comments(tlist):
    """Group runs of comment tokens into ``sql.Comment``.

    A run starts at a comment token and continues over comments and
    whitespace. It is grouped only when a terminating token (neither
    comment nor whitespace) is found; the group then ends at the token
    directly before that terminator.
    """
    def comment_or_ws(tk):
        return imt(tk, t=T.Comment) or tk.is_whitespace

    start, comment = tlist.token_next_by(t=T.Comment)
    while comment:
        stop, terminator = tlist.token_not_matching(comment_or_ws, idx=start)
        if terminator is not None:
            # Step back one token (whitespace included) from the
            # terminator to find the end of the run.
            stop, _ = tlist.token_prev(stop, skip_ws=False)
            tlist.group_tokens(sql.Comment, start, stop)
        start, comment = tlist.token_next_by(t=T.Comment, idx=start)
@recurse(sql.Where)
def group_where(tlist):
    """Group each WHERE clause into an ``sql.Where``.

    A clause starts at a token matching ``sql.Where.M_OPEN`` and ends
    just before the next token matching ``sql.Where.M_CLOSE``. When no
    closing token is found, it ends at the last entry of
    ``tlist._groupable_tokens``.
    """
    open_idx, opener = tlist.token_next_by(m=sql.Where.M_OPEN)
    while opener:
        close_idx, closer = tlist.token_next_by(
            m=sql.Where.M_CLOSE, idx=open_idx)
        if closer is not None:
            last = tlist.tokens[close_idx - 1]
        else:
            last = tlist._groupable_tokens[-1]
        # TODO: convert this to an index instead of an end token.
        # The values above are presumably len(tlist) and close_idx - 1.
        last_idx = tlist.token_index(last)
        tlist.group_tokens(sql.Where, open_idx, last_idx)
        open_idx, opener = tlist.token_next_by(
            m=sql.Where.M_OPEN, idx=open_idx)
@recurse()
def group_aliased(tlist):
    """Group an aliasable token with an ``sql.Identifier`` that follows it.

    Aliasable tokens are numbers and the group classes listed below. The
    pair becomes (or extends) an ``sql.Identifier``.
    """
    aliasable = (sql.Parenthesis, sql.Function, sql.Case, sql.Identifier,
                 sql.Operation, sql.Comparison)
    pos, candidate = tlist.token_next_by(i=aliasable, t=T.Number)
    while candidate:
        after_idx, after = tlist.token_next(pos)
        if isinstance(after, sql.Identifier):
            tlist.group_tokens(sql.Identifier, pos, after_idx, extend=True)
        pos, candidate = tlist.token_next_by(i=aliasable, t=T.Number, idx=pos)
@recurse(sql.Function)
def group_functions(tlist):
    """Group a name followed by a parenthesis into an ``sql.Function``.

    Nothing is grouped when *tlist* directly contains both a ``CREATE``
    and a ``TABLE`` token, so that ``name (...)`` in a CREATE TABLE
    statement is not turned into a function call. This check is
    case-insensitive, so ``create table`` is handled like
    ``CREATE TABLE``.
    """
    # Compare upper-cased values: a case-sensitive comparison would miss
    # lower- or mixed-case spellings and group the table's column list as
    # a function.
    seen_values = {tok.value.upper() for tok in tlist.tokens}
    if 'CREATE' in seen_values and 'TABLE' in seen_values:
        return

    tidx, token = tlist.token_next_by(t=T.Name)
    while token:
        nidx, next_ = tlist.token_next(tidx)
        if isinstance(next_, sql.Parenthesis):
            tlist.group_tokens(sql.Function, tidx, nidx)
        tidx, token = tlist.token_next_by(t=T.Name, idx=tidx)
def group_order(tlist):
    """Group an identifier or number with the ASC/DESC keyword after it."""
    order_idx, order_kw = tlist.token_next_by(t=T.Keyword.Order)
    while order_kw:
        before_idx, before = tlist.token_prev(order_idx)
        resume_at = order_idx
        if imt(before, i=sql.Identifier, t=T.Number):
            tlist.group_tokens(sql.Identifier, before_idx, order_idx)
            # The new group now sits where the preceding token was.
            resume_at = before_idx
        order_idx, order_kw = tlist.token_next_by(
            t=T.Keyword.Order, idx=resume_at)
@recurse()
def align_comments(tlist):
    """Merge each ``sql.Comment`` into a token list directly before it.

    When the token preceding a comment is an ``sql.TokenList``, the
    comment is grouped together with it (extending that list).
    """
    comment_idx, comment = tlist.token_next_by(i=sql.Comment)
    while comment:
        before_idx, before = tlist.token_prev(comment_idx)
        resume_at = comment_idx
        if isinstance(before, sql.TokenList):
            tlist.group_tokens(
                sql.TokenList, before_idx, comment_idx, extend=True)
            # Continue the search from the merged group's position.
            resume_at = before_idx
        comment_idx, comment = tlist.token_next_by(
            i=sql.Comment, idx=resume_at)
def group_values(tlist):
    """Group a VALUES keyword with the parentheses that follow it.

    The ``sql.Values`` group runs from the first ``VALUES`` keyword to
    the last ``sql.Parenthesis`` found after it. Nothing is grouped when
    there is no such keyword or no parenthesis after it.
    """
    values_idx, token = tlist.token_next_by(m=(T.Keyword, 'VALUES'))
    last_paren_idx = None
    cursor = values_idx
    while token:
        if isinstance(token, sql.Parenthesis):
            last_paren_idx = cursor
        cursor, token = tlist.token_next(cursor)
    if last_paren_idx is None:
        return
    tlist.group_tokens(sql.Values, values_idx, last_paren_idx, extend=True)
def group(stmt):
    """Apply every grouping pass to *stmt*, in order, and return *stmt*.

    The passes modify *stmt* in place. Their order matters, because
    later passes operate on the groups created by earlier ones.
    """
    passes = (
        group_comments,

        # _group_matching
        group_brackets,
        group_parenthesis,
        group_case,
        group_if,
        group_for,
        group_begin,

        group_functions,
        group_where,
        group_period,
        group_arrays,
        group_identifier,
        group_order,
        group_typecasts,
        group_tzcasts,
        group_typed_literal,
        group_operator,
        group_comparison,
        group_as,
        group_aliased,
        group_assignment,

        align_comments,
        group_identifier_list,
        group_values,
    )
    for apply_pass in passes:
        apply_pass(stmt)
    return stmt
def _group(tlist, cls, match,
           valid_prev=lambda t: True,
           valid_next=lambda t: True,
           post=None,
           extend=True,
           recurse=True
           ):
    """Group tokens that are joined by a middle token, e.g. ``x < y``.

    :param tlist: token list to modify in place.
    :param cls: group class handed to ``tlist.group_tokens``.
    :param match: predicate identifying the middle token.
    :param valid_prev: predicate for the previous non-whitespace token.
    :param valid_next: predicate for the token returned by
        ``tlist.token_next`` (may be ``None``).
    :param post: ``post(tlist, pidx, tidx, nidx)`` returning the
        ``(from_idx, to_idx)`` span to group; it is called whenever a
        group is formed, so the ``None`` default only works if nothing
        matches.
    :param extend: forwarded to ``tlist.group_tokens``.
    :param recurse: when true, descend into groups that are not
        instances of *cls*.
    """
    # Number of positions by which the live list is shorter than the
    # snapshot being iterated, due to the groupings made so far.
    collapsed = 0
    before_idx, before = None, None
    for snapshot_idx, tok in enumerate(list(tlist)):
        live_idx = snapshot_idx - collapsed
        # A negative index should not happen; skip defensively. Whitespace
        # is skipped too, so it never becomes the "previous" token.
        if live_idx < 0 or tok.is_whitespace:
            continue

        if recurse and tok.is_group and not isinstance(tok, cls):
            _group(tok, cls, match, valid_prev, valid_next, post, extend)

        if match(tok):
            after_idx, after = tlist.token_next(live_idx)
            if before and valid_prev(before) and valid_next(after):
                start, stop = post(tlist, before_idx, live_idx, after_idx)
                merged = tlist.group_tokens(cls, start, stop, extend=extend)
                collapsed += stop - start
                # The new group becomes the previous token, so chains
                # such as ``a, b, c`` can keep growing.
                before_idx, before = start, merged
                continue

        before_idx, before = live_idx, tok