Skip to content

Commit

Permalink
feat(open-mysql-db): pandas support (#3868)
Browse files Browse the repository at this point in the history
* feat(open-mysql-db): refactor

1. remove unnecessary instance var port
2. fix cause null bug
3. remove unnecessary throws
4. fix ctx.close() sequence bug
5. config sessionTimeout and requestTimeout
6. add docs of SqlEngine

* feat(open-mysql-db): refactor

* feat(open-mysql-db): revert passsword

* feat(open-mysql-db): mock commit and schema table count

* feat(open-mysql-db): replace data type text with string

* feat(open-mysql-db): remove null

---------

Co-authored-by: yangwucheng <[email protected]>
  • Loading branch information
yangwucheng and yangwucheng authored Jul 1, 2024
1 parent 1c1e213 commit b7e592c
Show file tree
Hide file tree
Showing 3 changed files with 95 additions and 2 deletions.
33 changes: 33 additions & 0 deletions extensions/open-mysql-db/python-testcases/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import pandas as pd
from sqlalchemy import create_engine

if __name__ == '__main__':
# Create a Pandas DataFrame (replace this with your actual data)
data = {'id': [1, 2, 3],
'name': ['Alice', 'Bob', 'Charlie'],
'age': [25, 30, 35],
'score': [1.1, 2.2, 3.3],
'ts': [pd.Timestamp.utcnow().timestamp(), pd.Timestamp.utcnow().timestamp(),
pd.Timestamp.utcnow().timestamp()],
'dt': [pd.to_datetime('20240101', format='%Y%m%d'), pd.to_datetime('20240201', format='%Y%m%d'),
pd.to_datetime('20240301', format='%Y%m%d')],
}
df = pd.DataFrame(data)

# Create a MySQL database engine using SQLAlchemy
engine = create_engine('mysql+pymysql://root:[email protected]:3307/demo_db')

# Replace 'username', 'password', 'host', and 'db_name' with your actual database credentials

# Define the name of the table in the database where you want to write the data
table_name = 'demo_table1'

# Write the DataFrame 'df' into the MySQL table
df.to_sql(table_name, engine, if_exists='replace', index=False)

# 'if_exists' parameter options:
# - 'fail': If the table already exists, an error will be raised.
# - 'replace': If the table already exists, it will be replaced.
# - 'append': If the table already exists, data will be appended to it.

print("Data written to MySQL table successfully!")
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,9 @@ private void handleQuery(
&& !queryStringWithoutComment.startsWith("set @@execute_mode=")) {
// ignore SET command
ctx.writeAndFlush(OkResponse.builder().sequenceId(query.getSequenceId() + 1).build());
} else if (queryStringWithoutComment.equalsIgnoreCase("COMMIT")) {
// ignore COMMIT command
ctx.writeAndFlush(OkResponse.builder().sequenceId(query.getSequenceId() + 1).build());
} else if (useDbMatcher.matches()) {
sqlEngine.useDatabase(getConnectionId(ctx), useDbMatcher.group(1));
ctx.writeAndFlush(OkResponse.builder().sequenceId(query.getSequenceId() + 1).build());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import com._4paradigm.openmldb.jdbc.SQLResultSet;
import com._4paradigm.openmldb.mysql.mock.MockResult;
import com._4paradigm.openmldb.mysql.util.TypeUtil;
import com._4paradigm.openmldb.proto.NS;
import com._4paradigm.openmldb.sdk.Column;
import com._4paradigm.openmldb.sdk.Schema;
import com._4paradigm.openmldb.sdk.SdkOption;
Expand Down Expand Up @@ -56,6 +57,12 @@ public class OpenmldbMysqlServer {
Pattern.compile(
"(?i)SELECT COUNT\\(\\*\\) FROM information_schema\\.TABLES WHERE TABLE_SCHEMA = '(.+)'");

// SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = 'xzs' AND table_name =
// 't_exam_paper'
private final Pattern selectCountSchemaTablesPattern =
Pattern.compile(
"(?i)SELECT COUNT\\(\\*\\) FROM information_schema\\.TABLES WHERE TABLE_SCHEMA = '(.+)' AND table_name = '(.+)'");

// SELECT COUNT(*) FROM information_schema.COLUMNS WHERE TABLE_SCHEMA = 'xzs'
private final Pattern selectCountColumnsPattern =
Pattern.compile(
Expand Down Expand Up @@ -182,6 +189,10 @@ public void query(
return;
}

if (mockSelectSchemaTableCount(connectionId, resultSetWriter, sql)) {
return;
}

// This mock must execute before mockPatternQuery
// SELECT COUNT(*) FROM information_schema.TABLES WHERE TABLE_SCHEMA = 'demo_db'
// UNION SELECT COUNT(*) FROM information_schema.COLUMNS WHERE TABLE_SCHEMA =
Expand Down Expand Up @@ -240,6 +251,7 @@ public void query(
return;
}

String originalSql = sql;
if (sql.startsWith("SHOW FULL TABLES")) {
// SHOW FULL TABLES WHERE Table_type != 'VIEW'
Matcher showTablesFromDbMatcher = showTablesFromDbPattern.matcher(sql);
Expand All @@ -250,6 +262,14 @@ public void query(
} else {
sql = "SHOW TABLES";
}
} else if (sql.matches("(?i)(?s)^\\s*CREATE TABLE.*$")) {
// convert data type TEXT to STRING
sql = sql.replaceAll("(?i) TEXT", " STRING");
// sql = sql.replaceAll("(?i) DATETIME", " DATE");
if (!sql.toLowerCase().contains(" not null")
&& sql.toLowerCase().contains(" null")) {
sql = sql.replaceAll("(?i) null", "");
}
} else {
Matcher crateDatabaseMatcher = createDatabasePattern.matcher(sql);
Matcher selectLimitMatcher = selectLimitPattern.matcher(sql);
Expand All @@ -264,7 +284,7 @@ public void query(

if (sql.toLowerCase().startsWith("select") || sql.toLowerCase().startsWith("show")) {
SQLResultSet resultSet = (SQLResultSet) stmt.getResultSet();
outputResultSet(resultSetWriter, resultSet, sql);
outputResultSet(resultSetWriter, resultSet, originalSql);
}

System.out.println("Success to execute OpenMLDB SQL: " + sql);
Expand Down Expand Up @@ -622,6 +642,36 @@ private boolean mockPatternQuery(ResultSetWriter resultSetWriter, String sql) {
return false;
}

private boolean mockSelectSchemaTableCount(
int connectionId, ResultSetWriter resultSetWriter, String sql) throws SQLException {
// SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = 'xzs' AND
// table_name = 't_exam_paper'
Matcher selectCountSchemaTablesMatcher = selectCountSchemaTablesPattern.matcher(sql);
if (selectCountSchemaTablesMatcher.matches()) {
// COUNT(*)
List<QueryResultColumn> columns = new ArrayList<>();
columns.add(new QueryResultColumn("COUNT(*)", "VARCHAR(255)"));
resultSetWriter.writeColumns(columns);

List<String> row;
String dbName = selectCountSchemaTablesMatcher.group(1);
String tableName = selectCountSchemaTablesMatcher.group(2);
row = new ArrayList<>();
NS.TableInfo tableInfo =
sqlClusterExecutorMap.get(connectionId).getTableInfo(dbName, tableName);
if (tableInfo == null || tableInfo.getName().equals("")) {
row.add("0");
} else {
row.add("1");
}
resultSetWriter.writeRow(row);

resultSetWriter.finish();
return true;
}
return false;
}

private boolean mockSelectCountUnion(
int connectionId, ResultSetWriter resultSetWriter, String sql) throws SQLException {
// SELECT COUNT(*) FROM information_schema.TABLES WHERE TABLE_SCHEMA = 'demo_db'
Expand Down Expand Up @@ -713,14 +763,18 @@ public void outputResultSet(ResultSetWriter resultSetWriter, SQLResultSet result
// Add schema
for (int i = 0; i < columnCount; i++) {
String columnName = schema.getColumnName(i);
if (sql.equalsIgnoreCase("show table status") && columnName.equalsIgnoreCase("table_id")) {
if ((sql.startsWith("SHOW FULL TABLES") || sql.equalsIgnoreCase("show table status"))
&& columnName.equalsIgnoreCase("table_id")) {
tableIdColumnIndex = i;
continue;
}
int columnType = schema.getColumnType(i);
columns.add(
new QueryResultColumn(columnName, TypeUtil.openmldbTypeToMysqlTypeString(columnType)));
}
if (sql.startsWith("SHOW FULL TABLES")) {
columns.add(new QueryResultColumn("Table_type", "VARCHAR(255)"));
}

resultSetWriter.writeColumns(columns);

Expand All @@ -739,6 +793,9 @@ public void outputResultSet(ResultSetWriter resultSetWriter, SQLResultSet result
String columnValue = TypeUtil.getResultSetStringColumn(resultSet, i + 1, type);
row.add(columnValue);
}
if (sql.startsWith("SHOW FULL TABLES")) {
row.add("BASE TABLE");
}

resultSetWriter.writeRow(row);
}
Expand Down

0 comments on commit b7e592c

Please sign in to comment.