package com.palacesun.engine.wrapper;

import java.util.HashSet;
import java.util.Iterator;
import java.util.LinkedHashSet;
import java.util.Set;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Heuristic extractor of schema (database) qualifiers from MySQL statements.
 *
 * <p>This class does not parse SQL. It strips comments, collapses whitespace and
 * then scans the text with regular expressions for table references following
 * {@code FROM}, {@code JOIN}, {@code INSERT ... INTO}, {@code UPDATE},
 * {@code DELETE ... FROM} and {@code CREATE/ALTER/DROP TABLE}. For every
 * reference of the form {@code schema.table} the part before the dot is reported.
 *
 * <p>Known limitation: string literals are not recognised, so keyword-like or
 * comment-like text inside quotes is treated as SQL.
 */
public class MysqlSchemaExtractor {
    private static final int CI = Pattern.CASE_INSENSITIVE;

    /**
     * Block comments (which may span lines) and both single-line comment forms.
     * A single alternation is used so that whichever comment starts first in the
     * text wins, e.g. a {@code --} inside a block comment does not eat live SQL.
     */
    private static final Pattern COMMENT_PATTERN = Pattern.compile("(?s:/\\*.*?\\*/)|--.*|#.*");
    private static final Pattern WHITESPACE_PATTERN = Pattern.compile("\\s+");
    private static final Pattern OUTER_BACKTICKS_PATTERN = Pattern.compile("^`+|`+$");

    private static final Pattern SELECT_PATTERN = Pattern.compile("\\bSELECT\\b", CI);
    private static final Pattern UNION_PATTERN = Pattern.compile("\\bUNION\\b", CI);
    private static final Pattern FROM_PATTERN = Pattern.compile("\\bFROM\\b", CI);
    /** First token that terminates a FROM clause; keywords must be whole words. */
    private static final Pattern FROM_CLAUSE_END_PATTERN =
            Pattern.compile("\\b(?:WHERE|GROUP|HAVING|ORDER|LIMIT|UNION)\\b|\\)", CI);
    private static final Pattern JOIN_PATTERN = Pattern.compile("\\bJOIN\\s+([^\\s,;)+]+)", CI);

    private static final Pattern INSERT_PATTERN =
            Pattern.compile("\\bINSERT\\s+(?:\\w+\\s+)*INTO\\s+([^\\s\\(]+)", CI);
    /** Skips the optional LOW_PRIORITY / IGNORE modifiers so they are not taken for the table. */
    private static final Pattern UPDATE_PATTERN =
            Pattern.compile("\\bUPDATE\\s+(?:LOW_PRIORITY\\s+)?(?:IGNORE\\s+)?([^\\s]+)", CI);
    private static final Pattern DELETE_PATTERN =
            Pattern.compile("\\bDELETE\\s+(?:\\w+\\s+)*FROM\\s+([^\\s]+)", CI);
    private static final Pattern CREATE_TABLE_PATTERN =
            Pattern.compile("\\bCREATE\\s+TABLE\\s+(?:IF\\s+NOT\\s+EXISTS\\s+)?([^\\s\\(]+)", CI);
    private static final Pattern ALTER_TABLE_PATTERN = Pattern.compile("\\bALTER\\s+TABLE\\s+([^\\s]+)", CI);
    private static final Pattern DROP_TABLE_PATTERN =
            Pattern.compile("\\bDROP\\s+TABLE\\s+(?:IF\\s+EXISTS\\s+)?([^\\s]+)", CI);

    // Statement-type probes, compiled once instead of on every call.
    private static final Pattern INSERT_START_PATTERN = Pattern.compile("^\\s*INSERT\\s", CI);
    private static final Pattern UPDATE_START_PATTERN = Pattern.compile("^\\s*UPDATE\\s", CI);
    private static final Pattern DELETE_START_PATTERN = Pattern.compile("^\\s*DELETE\\s", CI);
    private static final Pattern CREATE_TABLE_START_PATTERN = Pattern.compile("^\\s*CREATE\\s+TABLE\\s", CI);
    private static final Pattern ALTER_TABLE_START_PATTERN = Pattern.compile("^\\s*ALTER\\s+TABLE\\s", CI);
    private static final Pattern DROP_TABLE_START_PATTERN = Pattern.compile("^\\s*DROP\\s+TABLE\\s", CI);

    /** Kept public for source compatibility; every method of this class is static. */
    public MysqlSchemaExtractor() {
    }

    /**
     * Collects the schema qualifiers used by the table references in {@code sql}.
     *
     * @param sql a single SQL statement; may be {@code null}
     * @return a mutable set of schema names in order of discovery, exactly as
     *         written in the statement (back-ticks included); empty when
     *         {@code sql} is {@code null} or contains no qualified reference
     */
    public static Set<String> extractSchemasFromSql(String sql) {
        Set<String> schemas = new LinkedHashSet<String>();
        if (sql == null) {
            return schemas;
        }

        String processedSql = preprocessSql(sql);
        Set<String> tableRefs = new LinkedHashSet<String>();

        // The FROM and JOIN scans run over the whole statement, so references
        // inside sub-queries and INSERT ... SELECT are found without a separate
        // sub-query pass.
        if (SELECT_PATTERN.matcher(processedSql).find() || UNION_PATTERN.matcher(processedSql).find()) {
            collectFromClauseTables(processedSql, tableRefs);
            addAllMatches(JOIN_PATTERN, processedSql, tableRefs);
        }
        if (INSERT_START_PATTERN.matcher(processedSql).find()) {
            addFirstMatch(INSERT_PATTERN, processedSql, tableRefs);
        }
        if (UPDATE_START_PATTERN.matcher(processedSql).find()) {
            addFirstMatch(UPDATE_PATTERN, processedSql, tableRefs);
            addAllMatches(JOIN_PATTERN, processedSql, tableRefs);
        }
        if (DELETE_START_PATTERN.matcher(processedSql).find()) {
            addFirstMatch(DELETE_PATTERN, processedSql, tableRefs);
            addAllMatches(JOIN_PATTERN, processedSql, tableRefs);
        }
        if (CREATE_TABLE_START_PATTERN.matcher(processedSql).find()) {
            addFirstMatch(CREATE_TABLE_PATTERN, processedSql, tableRefs);
        }
        if (ALTER_TABLE_START_PATTERN.matcher(processedSql).find()) {
            addFirstMatch(ALTER_TABLE_PATTERN, processedSql, tableRefs);
        }
        if (DROP_TABLE_START_PATTERN.matcher(processedSql).find()) {
            addFirstMatch(DROP_TABLE_PATTERN, processedSql, tableRefs);
        }

        for (String tableRef : tableRefs) {
            String schema = schemaOf(tableRef);
            if (schema != null) {
                schemas.add(schema);
            }
        }
        return schemas;
    }

    /**
     * Rewrites {@code sqlAll} so that every schema qualifier found by
     * {@link #extractSchemasFromSql(String)} is wrapped in exactly one pair of
     * back-ticks.
     *
     * <p>Only occurrences used as a qualifier (immediately followed by a dot and
     * not part of a longer identifier) are rewritten; other text that merely
     * contains the schema name is left untouched. The operation is idempotent.
     *
     * @param sqlAll the SQL text to rewrite; may be {@code null}
     * @return the rewritten SQL, or {@code sqlAll} itself when it is {@code null},
     *         empty, or contains no schema qualifier
     */
    public static String handleSchema(String sqlAll) {
        if (sqlAll == null || sqlAll.isEmpty()) {
            return sqlAll;
        }

        // Normalise to bare names first; iterating a separate set avoids mutating
        // the extracted set while it is being traversed.
        Set<String> bareNames = new LinkedHashSet<String>();
        for (String schema : extractSchemasFromSql(sqlAll)) {
            String bare = OUTER_BACKTICKS_PATTERN.matcher(schema).replaceAll("");
            if (!bare.isEmpty()) {
                bareNames.add(bare);
            }
        }

        String result = sqlAll;
        for (String bare : bareNames) {
            // The name is user text: quote it for both the pattern and the replacement.
            String literal = Pattern.quote(bare);
            Pattern qualifier = Pattern.compile(
                    "(?<![\\p{L}\\p{N}_$`.])(?:`" + literal + "`|" + literal + ")(?=\\s*\\.)");
            result = qualifier.matcher(result).replaceAll(Matcher.quoteReplacement("`" + bare + "`"));
        }
        return result;
    }

    /** Replaces comments with a space, collapses whitespace runs and trims the result. */
    private static String preprocessSql(String sql) {
        String withoutComments = COMMENT_PATTERN.matcher(sql).replaceAll(" ");
        return WHITESPACE_PATTERN.matcher(withoutComments).replaceAll(" ").trim();
    }

    /** Adds the first token of every comma-separated item of every FROM clause in {@code sql}. */
    private static void collectFromClauseTables(String sql, Set<String> tableRefs) {
        Matcher fromMatcher = FROM_PATTERN.matcher(sql);
        while (fromMatcher.find()) {
            String afterFrom = sql.substring(fromMatcher.end());
            Matcher endMatcher = FROM_CLAUSE_END_PATTERN.matcher(afterFrom);
            int end = endMatcher.find() ? endMatcher.start() : afterFrom.length();

            for (String item : afterFrom.substring(0, end).split(",")) {
                String tableRef = firstToken(item.trim());
                if (!tableRef.isEmpty()) {
                    tableRefs.add(tableRef);
                }
            }
        }
    }

    /** Returns {@code text} up to its first space, which drops any alias that follows the table. */
    private static String firstToken(String text) {
        int space = text.indexOf(' ');
        return space > 0 ? text.substring(0, space) : text;
    }

    /** Adds capture group 1 of the first match of {@code pattern}, if there is one. */
    private static void addFirstMatch(Pattern pattern, String sql, Set<String> tableRefs) {
        Matcher matcher = pattern.matcher(sql);
        if (matcher.find()) {
            tableRefs.add(matcher.group(1));
        }
    }

    /** Adds capture group 1 of every match of {@code pattern}. */
    private static void addAllMatches(Pattern pattern, String sql, Set<String> tableRefs) {
        Matcher matcher = pattern.matcher(sql);
        while (matcher.find()) {
            tableRefs.add(matcher.group(1));
        }
    }

    /**
     * Returns the schema part of a raw table reference, or {@code null} when the
     * reference is not schema-qualified.
     *
     * <p>Leading opening parentheses (from parenthesised join groups) are
     * ignored. For a reference starting with a back-tick the schema is the whole
     * quoted identifier, and only when it is directly followed by a dot, so a dot
     * inside a quoted table name is not mistaken for a qualifier separator.
     */
    private static String schemaOf(String tableRef) {
        int start = 0;
        while (start < tableRef.length() && tableRef.charAt(start) == '(') {
            ++start;
        }
        String ref = tableRef.substring(start);

        if (ref.startsWith("`")) {
            int close = ref.indexOf('`', 1);
            boolean qualified = close > 1 && close + 1 < ref.length() && ref.charAt(close + 1) == '.';
            return qualified ? ref.substring(0, close + 1) : null;
        }

        int dot = ref.indexOf('.');
        if (dot <= 0) {
            return null;
        }
        String schema = ref.substring(0, dot).trim();
        return schema.isEmpty() ? null : schema;
    }
}