diff --git a/src/main/java/com/btk5h/skriptdb/skript/EffExecuteStatement.java b/src/main/java/com/btk5h/skriptdb/skript/EffExecuteStatement.java index 6a74670..9a935b2 100644 --- a/src/main/java/com/btk5h/skriptdb/skript/EffExecuteStatement.java +++ b/src/main/java/com/btk5h/skriptdb/skript/EffExecuteStatement.java @@ -25,6 +25,7 @@ import java.util.*; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; +import java.util.regex.Pattern; /** * Executes a statement on a database and optionally stores the result in a variable. Expressions @@ -45,17 +46,20 @@ import java.util.concurrent.Executors; */ public class EffExecuteStatement extends Effect { private static final ExecutorService threadPool = Executors.newFixedThreadPool(SkriptDB.getInstance().getConfig().getInt("thread-pool-size", 10)); + private static final Pattern ARGUMENT_PLACEHOLDER = Pattern.compile("(? query; private Expression dataSource; + private Expression queryArguments; private VariableString var; private boolean isLocal; private boolean isList; @@ -157,142 +161,144 @@ public class EffExecuteStatement extends Effect { } private Pair> parseQuery(Event e) { - if (!(query instanceof VariableString)) { - return new Pair<>(query.getSingle(e), null); + if (queryArguments != null) { + Object[] args = queryArguments.getArray(e); + String queryString = query.getSingle(e); + int queryArgCount = (int) ARGUMENT_PLACEHOLDER.matcher(queryString).results().count(); + if (queryArgCount != args.length) { + Skript.warning("Your query has %d question marks, but you provided %d arguments."); + args = Arrays.copyOf(args, queryArgCount); + } + return new Pair<>(query.getSingle(e), List.of(args)); + } else if (query instanceof VariableString && !((VariableString) query).isSimple()) { + return parseVariableQuery(e, (VariableString) query); } - VariableString q = (VariableString) query; - if (q.isSimple()) { - return new Pair<>(q.toString(e), null); - } - + return new Pair<>(query.getSingle(e), null); + } + + private Pair> parseVariableQuery(Event e, VariableString varQuery) { StringBuilder sb = new StringBuilder(); - List parameters = new ArrayList<>(); - Object[] objects = SkriptUtil.getTemplateString(q); - + List parameters = new LinkedList<>(); + Object[] objects = SkriptUtil.getTemplateString(varQuery); + for (int i = 0; i < objects.length; i++) { - Object o = objects[i]; - if (o instanceof String) { - sb.append(o); + if (objects[i] instanceof String) { + sb.append(objects[i]); } else { - Expression expr; - if (o instanceof Expression) - expr = (Expression) o; - else - expr = SkriptUtil.getExpressionFromInfo(o); - - String before = getString(objects, i - 1); - String after = getString(objects, i + 1); - boolean standaloneString = false; - - if (before != null && after != null) { - if (before.endsWith("'") && after.endsWith("'")) { - standaloneString = true; - } - } - + Expression expr = objects[i] instanceof Expression ? (Expression) objects[i] : SkriptUtil.getExpressionFromInfo(objects[i]); + boolean standaloneString = isStandaloneString(objects, i); Object expressionValue = expr.getSingle(e); - if (expr instanceof ExprUnsafe) { - sb.append(expressionValue); - - if (standaloneString && expressionValue instanceof String) { - String rawExpression = ((ExprUnsafe) expr).getRawExpression(); - Skript.warning( - String.format("Unsafe may have been used unnecessarily. Try replacing 'unsafe %1$s' with %1$s", - rawExpression)); - } - } else { - parameters.add(expressionValue); - sb.append('?'); - - if (standaloneString) { - Skript.warning("Do not surround expressions with quotes!"); - } + Pair toAppend = parseExpressionQuery(expr, expressionValue, standaloneString); + sb.append(toAppend.getFirst()); + if (toAppend.getSecond() != null) { + parameters.add(toAppend.getSecond()); } } } return new Pair<>(sb.toString(), parameters); } + + private Pair parseExpressionQuery(Expression expr, Object expressionValue, boolean standaloneString) { + if (expr instanceof ExprUnsafe) { + if (standaloneString && expressionValue instanceof String) { + Skript.warning( + String.format("Unsafe may have been used unnecessarily. Try replacing 'unsafe %1$s' with %1$s", + ((ExprUnsafe) expr).getRawExpression())); + } + return new Pair<>((String) expressionValue, null); + } else { + if (standaloneString) { + Skript.warning("Do not surround expressions with quotes!"); + } + return new Pair<>("?", expressionValue); + } + } private Object executeStatement(DataSource ds, String baseVariable, Pair> query) { if (ds == null) { return "Data source is not set"; } - Map variableList = new HashMap<>(); - try (Connection conn = ds.getConnection(); - PreparedStatement stmt = createStatement(conn, query)) { + try (Connection conn = ds.getConnection()) { + try (PreparedStatement stmt = createStatement(conn, query)) { + boolean hasResultSet = stmt.execute(); - boolean hasResultSet = stmt.execute(); - - if (baseVariable != null) { - if (isList) { - baseVariable = baseVariable.substring(0, baseVariable.length() - 1); - } - - if (hasResultSet) { - CachedRowSet crs = SkriptDB.getRowSetFactory().createCachedRowSet(); - crs.populate(stmt.getResultSet()); - - if (isList) { - ResultSetMetaData meta = crs.getMetaData(); - int columnCount = meta.getColumnCount(); - - for (int i = 1; i <= columnCount; i++) { - String label = meta.getColumnLabel(i); - variableList.put(baseVariable + label, label); - } - - int rowNumber = 1; - try { - while (crs.next()) { - for (int i = 1; i <= columnCount; i++) { - variableList.put(baseVariable + meta.getColumnLabel(i).toLowerCase(Locale.ENGLISH) - + Variable.SEPARATOR + rowNumber, crs.getObject(i)); - } - rowNumber++; - } - } catch (SQLException ex) { - return ex.getMessage(); - } - } else { - crs.last(); - variableList.put(baseVariable, crs.getRow()); - } - } else if (!isList) { - //if no results are returned and the specified variable isn't a list variable, put the affected rows count in the variable - variableList.put(baseVariable, stmt.getUpdateCount()); + if (baseVariable != null) { + return processBaseVariable(baseVariable, stmt, hasResultSet); } + return Map.of(); } } catch (SQLException ex) { return ex.getMessage(); } + } + + private Object processBaseVariable(String baseVariable, PreparedStatement stmt, boolean hasResultSet) throws SQLException { + Map variableList = new HashMap<>(); + if (isList) { + baseVariable = baseVariable.substring(0, baseVariable.length() - 1); + } + + if (hasResultSet) { + CachedRowSet crs = SkriptDB.getRowSetFactory().createCachedRowSet(); + crs.populate(stmt.getResultSet()); + + if (isList) { + return fetchQueryResultSet(crs, baseVariable); + } else { + crs.last(); + variableList.put(baseVariable, crs.getRow()); + } + } else if (!isList) { + //if no results are returned and the specified variable isn't a list variable, put the affected rows count in the variable + return Map.of(baseVariable, stmt.getUpdateCount()); + } + return Map.of(); + } + + private Map fetchQueryResultSet(CachedRowSet crs, String baseVariable) throws SQLException { + Map variableList = new HashMap<>(); + ResultSetMetaData meta = crs.getMetaData(); + int columnCount = meta.getColumnCount(); + + for (int i = 1; i <= columnCount; i++) { + String label = meta.getColumnLabel(i); + variableList.put(baseVariable + label, label); + } + + int rowNumber = 1; + while (crs.next()) { + for (int i = 1; i <= columnCount; i++) { + variableList.put(baseVariable + meta.getColumnLabel(i).toLowerCase(Locale.ENGLISH) + + Variable.SEPARATOR + rowNumber, crs.getObject(i)); + } + rowNumber++; + } return variableList; } private PreparedStatement createStatement(Connection conn, Pair> query) throws SQLException { PreparedStatement stmt = conn.prepareStatement(query.getFirst()); - List parameters = query.getSecond(); - - if (parameters != null) { - for (int i = 0; i < parameters.size(); i++) { - stmt.setObject(i + 1, parameters.get(i)); + if (query.getSecond() != null) { + Iterator iter = query.getSecond().iterator(); + for (int i = 1; iter.hasNext(); i++) { + stmt.setObject(i, iter.next()); } } return stmt; } + + private boolean isStandaloneString(Object[] objects, int index) { + String before = getString(objects, index - 1); + String after = getString(objects, index + 1); + return before != null && before.endsWith("'") && after != null && after.endsWith("'"); + } private String getString(Object[] objects, int index) { - if (index < 0 || index >= objects.length) { - return null; + if (index >= 0 && index < objects.length && objects[index] instanceof String) { + return (String) objects[index]; } - - Object object = objects[index]; - - if (object instanceof String) { - return (String) object; - } - return null; } @@ -334,15 +340,21 @@ public class EffExecuteStatement extends Effect { return false; } dataSource = (Expression) exprs[1]; - Expression expr = exprs[2]; + if (exprs[2] != null) { + if (query instanceof VariableString && !((VariableString) query).isSimple()) { + Skript.warning("Your query string contains expresions, but you've also provided query arguments. Consider using `unsafe` keyword before your query."); + } + queryArguments = (Expression) exprs[2]; + } + Expression resultHolder = exprs[3]; quickly = matchedPattern == 1; - if (expr instanceof Variable) { - Variable varExpr = (Variable) expr; + if (resultHolder instanceof Variable) { + Variable varExpr = (Variable) resultHolder; var = varExpr.getName(); isLocal = varExpr.isLocal(); isList = varExpr.isList(); - } else if (expr != null) { - Skript.error(expr + " is not a variable"); + } else if (resultHolder != null) { + Skript.error(resultHolder + " is not a variable"); return false; } return true;