package util.sqlparse.visitor.gauss.visitor;

import com.chenyang.druid.sql.ast.SQLExpr;
import com.chenyang.druid.sql.ast.SQLLimit;
import com.chenyang.druid.sql.ast.SQLObject;
import com.chenyang.druid.sql.ast.SQLStatement;
import com.chenyang.druid.sql.ast.expr.SQLBinaryOpExpr;
import com.chenyang.druid.sql.ast.expr.SQLBinaryOperator;
import com.chenyang.druid.sql.ast.expr.SQLCharExpr;
import com.chenyang.druid.sql.ast.expr.SQLIdentifierExpr;
import com.chenyang.druid.sql.ast.expr.SQLIntegerExpr;
import com.chenyang.druid.sql.ast.expr.SQLPropertyExpr;
import com.chenyang.druid.sql.ast.statement.SQLDeleteStatement;
import com.chenyang.druid.sql.ast.statement.SQLExprTableSource;
import com.chenyang.druid.sql.ast.statement.SQLInsertStatement;
import com.chenyang.druid.sql.ast.statement.SQLJoinTableSource;
import com.chenyang.druid.sql.ast.statement.SQLReplaceStatement;
import com.chenyang.druid.sql.ast.statement.SQLSelect;
import com.chenyang.druid.sql.ast.statement.SQLSelectQuery;
import com.chenyang.druid.sql.ast.statement.SQLSelectQueryBlock;
import com.chenyang.druid.sql.ast.statement.SQLSelectStatement;
import com.chenyang.druid.sql.ast.statement.SQLTableSource;
import com.chenyang.druid.sql.ast.statement.SQLUnionQuery;
import com.chenyang.druid.sql.ast.statement.SQLUpdateStatement;
import com.chenyang.druid.sql.dialect.gauss.ast.expr.tablesource.GaussExprTableSource;
import com.chenyang.druid.sql.dialect.gauss.ast.stmt.GaussSelectQueryBlock;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.regex.Pattern;
import util.sqlparse.visitor.common.bean.SQLResult;
import util.sqlparse.visitor.common.bean.StatementType;
import util.sqlparse.visitor.common.bean.TableInfo;
import util.sqlparse.visitor.common.memo.TableMemo;
import util.sqlparse.visitor.gauss.GaussNameWrapper;

/**
 * Applies row-level visit restrictions to an already parsed SQL statement.
 *
 * <p>For every table in {@link SQLResult#tables} whose normalized name matches the
 * configured table regex, a predicate {@code <alias>.<column> <> '<value>'} is appended
 * to the WHERE clause of the statement body that references the table. Independently of
 * that, an optional row limit is applied to SELECT statements.
 *
 * <p>Entries read from the parameter map:
 * <ul>
 *   <li>{@code "tableKeyWord"} - map whose first entry is (table regex, column name)</li>
 *   <li>{@code "modifyTable"} - map whose first value is a list; its first element is the
 *       value compared against the column</li>
 *   <li>{@code "limit"} - optional row limit, as a decimal string</li>
 *   <li>{@code "operateType"} - the statement kind the rule applies to</li>
 * </ul>
 * Entries written to the parameter map: {@code "isMatched"} and {@code "newSql"}.
 */
public class RowVisitController {
   private final String tableRegex;
   private final String column;
   private final String value;
   private final SQLResult result;
   // Kept raw because the public constructor accepts a raw Map; values are heterogeneous.
   @SuppressWarnings("rawtypes")
   private final Map params;
   private final String limit;

   /**
    * Captures the parsed statement and extracts the rule configuration from {@code params}.
    *
    * @param result parsed statement together with the tables it references
    * @param params rule configuration; must contain non-empty {@code "tableKeyWord"} and
    *               {@code "modifyTable"} maps (a missing or empty map fails here with the
    *               runtime exception raised by the map/iterator access)
    */
   @SuppressWarnings({"rawtypes", "unchecked"})
   public RowVisitController(SQLResult result, Map params) {
      this.result = result;
      this.params = params;
      Map<String, List<String>> modifyTable = (Map<String, List<String>>) params.get("modifyTable");
      Map<String, String> tableKeyWord = (Map<String, String>) params.get("tableKeyWord");
      this.limit = (String) params.get("limit");
      // Only the first entry of each map is used.
      Map.Entry<String, String> keyWord = tableKeyWord.entrySet().iterator().next();
      this.tableRegex = keyWord.getKey();
      this.column = keyWord.getValue();
      this.value = modifyTable.values().iterator().next().get(0);
   }

   /**
    * Applies the optional row limit and, when the statement kind equals the configured
    * {@code "operateType"}, appends the row filter for every matching table reference.
    *
    * <p>When the statement kind matches, the rewritten SQL text is stored under
    * {@code "newSql"}; {@code "isMatched"} is set to {@code "true"} once at least one
    * table reference has been processed.
    */
   public void perform() {
      this.limitStatement();
      if (!this.checkRowVisit()) {
         return;
      }

      SQLStatement statement = this.result.statement;
      String normalizedRegex = GaussNameWrapper.normalize(this.tableRegex);
      // The table name may or may not be wrapped in double quotes.
      Pattern pattern = Pattern.compile("\"?" + normalizedRegex + "\"?");
      List<TableInfo> tables = this.result.tables;
      // Table sources that were already rewritten, so none gets the predicate twice.
      Set<SQLObject> parsed = new HashSet<SQLObject>();

      for (TableInfo table : tables) {
         String name = GaussNameWrapper.normalize(table.getTable().name);
         if (!pattern.matcher(name).matches()) {
            continue;
         }

         for (TableMemo memo : table.getMemos()) {
            SQLObject ref = memo.ref;
            if (parsed.contains(ref) || !(ref instanceof SQLTableSource)) {
               continue;
            }

            this.putParam("isMatched", "true");
            switch (this.result.statementType) {
               case select:
                  this.select(ref, table);
                  // 'break' (not 'continue') so the reference is recorded in 'parsed'
                  // like every other statement kind.
                  break;
               case insert:
                  this.insert(ref, table);
                  break;
               case update:
                  this.update(ref, table);
                  break;
               case replace:
                  this.replace(ref, table);
                  break;
               case delete:
                  this.delete(ref, table);
                  break;
               default:
                  this.putParam("isMatched", "false");
            }

            parsed.add(ref);
         }
      }

      this.putParam("newSql", statement.toString());
   }

   /** Stores a value in the raw parameter map. */
   @SuppressWarnings("unchecked")
   private void putParam(String key, Object paramValue) {
      this.params.put(key, paramValue);
   }

   /** Applies the configured limit when it is non-blank and the statement is a SELECT. */
   private void limitStatement() {
      if (this.limit == null || this.limit.trim().length() == 0) {
         return;
      }

      if (this.result.statement instanceof SQLSelectStatement) {
         SQLSelectStatement statement = (SQLSelectStatement) this.result.statement;
         this.setLimit(statement.getSelect(), this.limit);
      }
   }

   /**
    * Caps the row count of {@code select} at {@code limit}.
    *
    * <p>An existing limit is looked up on the select itself and then on its query block;
    * if neither exists a new limit is attached to the select. A missing row count is set
    * to {@code limit}; an integer-literal row count is replaced by the smaller of the two
    * values; any other row-count expression is left untouched.
    *
    * @param select the select to limit; ignored when {@code null}
    * @param limit  decimal row limit, surrounding whitespace allowed; ignored when {@code null}
    * @throws NumberFormatException if {@code limit} is not a valid {@code int} and the
    *                               row count has to be set or compared
    */
   private void setLimit(SQLSelect select, String limit) {
      if (select == null || limit == null) {
         return;
      }

      // limitStatement() validates the trimmed text, so parse the trimmed text as well.
      String trimmedLimit = limit.trim();

      SQLLimit limits = select.getLimit();
      if (limits == null) {
         SQLSelectQuery query = select.getQuery();
         if (query instanceof SQLSelectQueryBlock) {
            limits = ((SQLSelectQueryBlock) query).getLimit();
         }

         if (limits == null) {
            limits = new SQLLimit();
            select.setLimit(limits);
         }
      }

      SQLExpr rowCount = limits.getRowCount();
      if (rowCount == null) {
         limits.setRowCount(Integer.parseInt(trimmedLimit));
      } else if (rowCount instanceof SQLIntegerExpr) {
         int requested = Integer.parseInt(trimmedLimit);
         // The literal may hold a Long/BigInteger for large values; go through Number
         // instead of casting to Integer.
         long existing = ((Number) ((SQLIntegerExpr) rowCount).getValue()).longValue();
         // The minimum is bounded above by 'requested', so narrowing back to int is safe
         // for any non-negative existing limit.
         limits.setRowCount((int) Math.min((long) requested, existing));
      }
   }

   /** Adds the row filter to the select that contains {@code ref}. */
   private void select(SQLObject ref, TableInfo table) {
      SQLExprTableSource tableSource = (SQLExprTableSource) ref;
      String alias = this.resolveAlias(tableSource, table);
      SQLObject body = this.getStmtBody(StatementType.select, tableSource);
      if (body != null) {
         // For StatementType.select, getStmtBody only ever returns a SQLSelect.
         this.replaceSelect((SQLSelect) body, alias);
      }
   }

   /** Adds the row filter to the select feeding an INSERT, if {@code ref} lives in one. */
   private void insert(SQLObject ref, TableInfo table) {
      SQLExprTableSource tableSource = (SQLExprTableSource) ref;
      String alias = this.resolveAlias(tableSource, table);
      SQLObject body = this.getStmtBody(StatementType.insert, tableSource);
      if (body instanceof SQLSelect) {
         this.replaceSelect((SQLSelect) body, alias);
      }
   }

   /** Adds the row filter to the enclosing select, or to the UPDATE statement itself. */
   private void update(SQLObject ref, TableInfo table) {
      SQLExprTableSource tableSource = (SQLExprTableSource) ref;
      String alias = this.resolveAlias(tableSource, table);
      SQLObject body = this.getStmtBody(StatementType.update, tableSource);
      if (body instanceof SQLSelect) {
         this.replaceSelect((SQLSelect) body, alias);
      } else if (body instanceof SQLUpdateStatement) {
         ((SQLUpdateStatement) body).addWhere(this.createBinaryExpr(alias));
      }
   }

   /** Adds the row filter to the select feeding a REPLACE, if {@code ref} lives in one. */
   private void replace(SQLObject ref, TableInfo table) {
      SQLExprTableSource tableSource = (SQLExprTableSource) ref;
      String alias = this.resolveAlias(tableSource, table);
      SQLObject body = this.getStmtBody(StatementType.replace, tableSource);
      if (body instanceof SQLSelect) {
         this.replaceSelect((SQLSelect) body, alias);
      }
   }

   /** Adds the row filter to the enclosing select, or to the DELETE statement itself. */
   private void delete(SQLObject ref, TableInfo table) {
      SQLExprTableSource tableSource = (SQLExprTableSource) ref;
      String alias = this.resolveAlias(tableSource, table);
      SQLObject body = this.getStmtBody(StatementType.delete, tableSource);
      if (body instanceof SQLSelect) {
         this.replaceSelect((SQLSelect) body, alias);
      } else if (body instanceof SQLDeleteStatement) {
         ((SQLDeleteStatement) body).addWhere(this.createBinaryExpr(alias));
      }
   }

   /** Returns the alias of the table source, or the table's own name when it has none. */
   private String resolveAlias(SQLExprTableSource tableSource, TableInfo table) {
      String alias = tableSource.getAlias();
      return alias != null ? alias : table.getTable().name;
   }

   /**
    * Appends the row filter to {@code select}.
    *
    * <p>For a union query the filter is added only to those direct relations that are
    * {@link GaussSelectQueryBlock}s whose FROM clause references {@code alias}; for any
    * other query the filter is added to the select as a whole.
    */
   private void replaceSelect(SQLSelect select, String alias) {
      SQLSelectQuery query = select.getQuery();
      if (!(query instanceof SQLUnionQuery)) {
         select.addWhere(this.createBinaryExpr(alias));
         return;
      }

      for (SQLSelectQuery relation : ((SQLUnionQuery) query).getRelations()) {
         if (!(relation instanceof GaussSelectQueryBlock)) {
            continue;
         }

         GaussSelectQueryBlock block = (GaussSelectQueryBlock) relation;
         if (this.tableMatch(block.getFrom(), alias)) {
            // Build a separate predicate per branch so no AST node is shared between
            // several parents.
            block.addWhere(this.createBinaryExpr(alias));
         }
      }
   }

   /**
    * Tells whether {@code tableSource} refers to {@code alias}.
    *
    * <p>Joins are searched recursively on both sides. A {@link GaussExprTableSource}
    * matches when its alias equals {@code alias} or, if it has no alias, when its
    * expression is a plain identifier equal to {@code alias}. Every other kind of table
    * source does not match.
    */
   private boolean tableMatch(SQLTableSource tableSource, String alias) {
      if (tableSource instanceof SQLJoinTableSource) {
         SQLJoinTableSource join = (SQLJoinTableSource) tableSource;
         return this.tableMatch(join.getLeft(), alias) || this.tableMatch(join.getRight(), alias);
      }

      if (!(tableSource instanceof GaussExprTableSource)) {
         return false;
      }

      GaussExprTableSource exprSource = (GaussExprTableSource) tableSource;
      String sourceAlias = exprSource.getAlias();
      if (sourceAlias != null) {
         return sourceAlias.equals(alias);
      }

      SQLExpr sourceExpr = exprSource.getExpr();
      if (sourceExpr instanceof SQLIdentifierExpr) {
         return ((SQLIdentifierExpr) sourceExpr).getName().equals(alias);
      }

      return false;
   }

   /** Builds a new {@code <alias>.<column> <> '<value>'} predicate. */
   private SQLExpr createBinaryExpr(String alias) {
      SQLPropertyExpr columnRef = new SQLPropertyExpr();
      columnRef.setName(this.column);
      columnRef.setOwner(alias);

      SQLCharExpr literal = new SQLCharExpr();
      literal.setText(this.value);

      SQLBinaryOpExpr predicate = new SQLBinaryOpExpr();
      predicate.setOperator(SQLBinaryOperator.LessThanOrGreater);
      predicate.setLeft(columnRef);
      predicate.setRight(literal);
      return predicate;
   }

   /**
    * Walks up the parent chain of {@code tableSource} and returns the first ancestor that
    * is either a {@link SQLSelect} or the statement node corresponding to
    * {@code stmtType}.
    *
    * @return the enclosing body, or {@code null} when no such ancestor exists
    */
   private SQLObject getStmtBody(StatementType stmtType, SQLTableSource tableSource) {
      SQLObject cursor = tableSource.getParent();
      while (cursor != null) {
         boolean isBody = cursor instanceof SQLSelect
               || (stmtType == StatementType.insert && cursor instanceof SQLInsertStatement)
               || (stmtType == StatementType.update && cursor instanceof SQLUpdateStatement)
               || (stmtType == StatementType.replace && cursor instanceof SQLReplaceStatement)
               || (stmtType == StatementType.delete && cursor instanceof SQLDeleteStatement);
         if (isBody) {
            return cursor;
         }

         cursor = cursor.getParent();
      }

      return null;
   }

   /**
    * Tells whether the rule applies to the current statement, i.e. whether the statement
    * kind is one of select/insert/update/replace/delete and equals the configured
    * {@code "operateType"} (compared case-insensitively).
    */
   private boolean checkRowVisit() {
      String operateType = (String) this.params.get("operateType");
      String expected;
      switch (this.result.statementType) {
         case select:
            expected = "select";
            break;
         case insert:
            expected = "insert";
            break;
         case update:
            expected = "update";
            break;
         case replace:
            expected = "replace";
            break;
         case delete:
            expected = "delete";
            break;
         default:
            return false;
      }

      return expected.equalsIgnoreCase(operateType);
   }
}
