Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(open-mysql-db): pandas support #3868

Merged
merged 9 commits into from
Jul 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 33 additions & 0 deletions extensions/open-mysql-db/python-testcases/main.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,33 @@
import pandas as pd
from sqlalchemy import create_engine

if __name__ == '__main__':
# Create a Pandas DataFrame (replace this with your actual data)
data = {'id': [1, 2, 3],
'name': ['Alice', 'Bob', 'Charlie'],
'age': [25, 30, 35],
'score': [1.1, 2.2, 3.3],
'ts': [pd.Timestamp.utcnow().timestamp(), pd.Timestamp.utcnow().timestamp(),
pd.Timestamp.utcnow().timestamp()],
'dt': [pd.to_datetime('20240101', format='%Y%m%d'), pd.to_datetime('20240201', format='%Y%m%d'),
pd.to_datetime('20240301', format='%Y%m%d')],
}
df = pd.DataFrame(data)

# Create a MySQL database engine using SQLAlchemy
engine = create_engine('mysql+pymysql://root:[email protected]:3307/demo_db')

# Replace 'username', 'password', 'host', and 'db_name' with your actual database credentials

# Define the name of the table in the database where you want to write the data
table_name = 'demo_table1'

# Write the DataFrame 'df' into the MySQL table
df.to_sql(table_name, engine, if_exists='replace', index=False)

# 'if_exists' parameter options:
# - 'fail': If the table already exists, an error will be raised.
# - 'replace': If the table already exists, it will be replaced.
# - 'append': If the table already exists, data will be appended to it.

print("Data written to MySQL table successfully!")
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,9 @@ private void handleQuery(
&& !queryStringWithoutComment.startsWith("set @@execute_mode=")) {
// ignore SET command
ctx.writeAndFlush(OkResponse.builder().sequenceId(query.getSequenceId() + 1).build());
} else if (queryStringWithoutComment.equalsIgnoreCase("COMMIT")) {
// ignore COMMIT command
ctx.writeAndFlush(OkResponse.builder().sequenceId(query.getSequenceId() + 1).build());
} else if (useDbMatcher.matches()) {
sqlEngine.useDatabase(getConnectionId(ctx), useDbMatcher.group(1));
ctx.writeAndFlush(OkResponse.builder().sequenceId(query.getSequenceId() + 1).build());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import com._4paradigm.openmldb.jdbc.SQLResultSet;
import com._4paradigm.openmldb.mysql.mock.MockResult;
import com._4paradigm.openmldb.mysql.util.TypeUtil;
import com._4paradigm.openmldb.proto.NS;
import com._4paradigm.openmldb.sdk.Column;
import com._4paradigm.openmldb.sdk.Schema;
import com._4paradigm.openmldb.sdk.SdkOption;
Expand Down Expand Up @@ -56,6 +57,12 @@ public class OpenmldbMysqlServer {
Pattern.compile(
"(?i)SELECT COUNT\\(\\*\\) FROM information_schema\\.TABLES WHERE TABLE_SCHEMA = '(.+)'");

// SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = 'xzs' AND table_name =
// 't_exam_paper'
private final Pattern selectCountSchemaTablesPattern =
Pattern.compile(
"(?i)SELECT COUNT\\(\\*\\) FROM information_schema\\.TABLES WHERE TABLE_SCHEMA = '(.+)' AND table_name = '(.+)'");

// SELECT COUNT(*) FROM information_schema.COLUMNS WHERE TABLE_SCHEMA = 'xzs'
private final Pattern selectCountColumnsPattern =
Pattern.compile(
Expand Down Expand Up @@ -182,6 +189,10 @@ public void query(
return;
}

if (mockSelectSchemaTableCount(connectionId, resultSetWriter, sql)) {
return;
}

// This mock must execute before mockPatternQuery
// SELECT COUNT(*) FROM information_schema.TABLES WHERE TABLE_SCHEMA = 'demo_db'
// UNION SELECT COUNT(*) FROM information_schema.COLUMNS WHERE TABLE_SCHEMA =
Expand Down Expand Up @@ -240,6 +251,7 @@ public void query(
return;
}

String originalSql = sql;
if (sql.startsWith("SHOW FULL TABLES")) {
// SHOW FULL TABLES WHERE Table_type != 'VIEW'
Matcher showTablesFromDbMatcher = showTablesFromDbPattern.matcher(sql);
Expand All @@ -250,6 +262,14 @@ public void query(
} else {
sql = "SHOW TABLES";
}
} else if (sql.matches("(?i)(?s)^\\s*CREATE TABLE.*$")) {
// convert data type TEXT to STRING
sql = sql.replaceAll("(?i) TEXT", " STRING");
// sql = sql.replaceAll("(?i) DATETIME", " DATE");
if (!sql.toLowerCase().contains(" not null")
&& sql.toLowerCase().contains(" null")) {
sql = sql.replaceAll("(?i) null", "");
}
} else {
Matcher crateDatabaseMatcher = createDatabasePattern.matcher(sql);
Matcher selectLimitMatcher = selectLimitPattern.matcher(sql);
Expand All @@ -264,7 +284,7 @@ public void query(

if (sql.toLowerCase().startsWith("select") || sql.toLowerCase().startsWith("show")) {
SQLResultSet resultSet = (SQLResultSet) stmt.getResultSet();
outputResultSet(resultSetWriter, resultSet, sql);
outputResultSet(resultSetWriter, resultSet, originalSql);
}

System.out.println("Success to execute OpenMLDB SQL: " + sql);
Expand Down Expand Up @@ -622,6 +642,36 @@ private boolean mockPatternQuery(ResultSetWriter resultSetWriter, String sql) {
return false;
}

private boolean mockSelectSchemaTableCount(
int connectionId, ResultSetWriter resultSetWriter, String sql) throws SQLException {
// SELECT COUNT(*) FROM information_schema.tables WHERE table_schema = 'xzs' AND
// table_name = 't_exam_paper'
Matcher selectCountSchemaTablesMatcher = selectCountSchemaTablesPattern.matcher(sql);
if (selectCountSchemaTablesMatcher.matches()) {
// COUNT(*)
List<QueryResultColumn> columns = new ArrayList<>();
columns.add(new QueryResultColumn("COUNT(*)", "VARCHAR(255)"));
resultSetWriter.writeColumns(columns);

List<String> row;
String dbName = selectCountSchemaTablesMatcher.group(1);
String tableName = selectCountSchemaTablesMatcher.group(2);
row = new ArrayList<>();
NS.TableInfo tableInfo =
sqlClusterExecutorMap.get(connectionId).getTableInfo(dbName, tableName);
if (tableInfo == null || tableInfo.getName().equals("")) {
row.add("0");
} else {
row.add("1");
}
resultSetWriter.writeRow(row);

resultSetWriter.finish();
return true;
}
return false;
}

private boolean mockSelectCountUnion(
int connectionId, ResultSetWriter resultSetWriter, String sql) throws SQLException {
// SELECT COUNT(*) FROM information_schema.TABLES WHERE TABLE_SCHEMA = 'demo_db'
Expand Down Expand Up @@ -713,14 +763,18 @@ public void outputResultSet(ResultSetWriter resultSetWriter, SQLResultSet result
// Add schema
for (int i = 0; i < columnCount; i++) {
String columnName = schema.getColumnName(i);
if (sql.equalsIgnoreCase("show table status") && columnName.equalsIgnoreCase("table_id")) {
if ((sql.startsWith("SHOW FULL TABLES") || sql.equalsIgnoreCase("show table status"))
&& columnName.equalsIgnoreCase("table_id")) {
tableIdColumnIndex = i;
continue;
}
int columnType = schema.getColumnType(i);
columns.add(
new QueryResultColumn(columnName, TypeUtil.openmldbTypeToMysqlTypeString(columnType)));
}
if (sql.startsWith("SHOW FULL TABLES")) {
columns.add(new QueryResultColumn("Table_type", "VARCHAR(255)"));
}

resultSetWriter.writeColumns(columns);

Expand All @@ -739,6 +793,9 @@ public void outputResultSet(ResultSetWriter resultSetWriter, SQLResultSet result
String columnValue = TypeUtil.getResultSetStringColumn(resultSet, i + 1, type);
row.add(columnValue);
}
if (sql.startsWith("SHOW FULL TABLES")) {
row.add("BASE TABLE");
}

resultSetWriter.writeRow(row);
}
Expand Down
Loading