SqlDruidParser.java 2.73 KB
Newer Older
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82
package io.hmit.modules.datasource.db;

import com.alibaba.druid.DbType;
import com.alibaba.druid.sql.SQLUtils;
import com.alibaba.druid.sql.ast.SQLStatement;
import com.alibaba.druid.sql.dialect.mysql.visitor.MySqlSchemaStatVisitor;
import com.alibaba.druid.sql.dialect.oracle.visitor.OracleSchemaStatVisitor;
import com.alibaba.druid.sql.visitor.SchemaStatVisitor;
import com.alibaba.druid.stat.TableStat;
import lombok.extern.slf4j.Slf4j;

import java.util.*;

/**
 * <h1>Sql解析器</h1>
 * @author Shen && syf0412@vip.qq.com
 * @since 2022/8/2 15:37
 **/
@Slf4j
public class SqlDruidParser {

    public static Map<String,Object> sqlParser(DbType dbType, String sql) {
        Map<String, Object> resultMap = new HashMap<>();
        List<SQLStatement> stmtList = null;
        try {
            String result = SQLUtils.format(sql, dbType);
            String[] sqlParam = result.split(";");
            List<String> sqlList = Arrays.asList(sqlParam);
            resultMap.put("executeSql", sqlList);
            stmtList = SQLUtils.parseStatements(sql, dbType);
        }catch(Exception e){
            log.error("SQL解析异常:{}", sql);
            resultMap.put("executeType", null);
            return resultMap;
        }
        SQLStatement stmt = stmtList.get(0);
        SchemaStatVisitor visitor = null;
        if (Objects.equals(dbType, DbType.oracle)) {
            visitor = new OracleSchemaStatVisitor();
        } else if (Objects.equals(dbType, DbType.mysql)) {
            visitor = new MySqlSchemaStatVisitor();
        }
        stmt.accept(visitor);
        Object method = getFirstOrNull(visitor.getTables());
        resultMap.put("tables", visitor.getTables());
        resultMap.put("fields", visitor.getColumns());
        resultMap.put("tableName", getTable(visitor.getTables()));
        resultMap.put("sqlContent", stmt);
        if (method != null) {
            resultMap.put("executeType", method.toString().toUpperCase());
        } else {
            resultMap.put("executeType", "SELECT");
        }
        return resultMap;
    }

    private static Object getFirstOrNull(Map<TableStat.Name, TableStat> map) {
        Object obj = null;
        for (Map.Entry<TableStat.Name, TableStat> entry : map.entrySet()) {
            obj = entry.getValue();
            if (obj != null) {
                break;
            }
        }
        return  obj;
    }

    private static Object getTable(Map<TableStat.Name, TableStat> map) {
        String tableName = null;
        for (Map.Entry<TableStat.Name, TableStat> entry : map.entrySet()) {
            TableStat.Name keys = entry.getKey();
            tableName =  keys.getName();
            if (tableName != null) {
                break;
            }
        }
        return tableName;
    }



}