From 8e32dad3b7078223ceae7a6c72cde761c60a4722 Mon Sep 17 00:00:00 2001 From: cyjseagull Date: Wed, 6 Nov 2024 13:53:14 +0800 Subject: [PATCH] optimize pir service publish perf (#122) * optimize pir service publish perf * optimize pir * use jdbcTemplate to execute sql --- .../wedpr/common/utils/CSVFileParser.java | 127 +++-------- .../publish/model/PirServiceSetting.java | 7 +- .../mybatis/MybatisConfigurationFactory.java | 7 + .../core/impl/PirDatasetConstructorImpl.java | 209 +++++++++++------- .../task/plugin/pir/dao/NativeSQLMapper.java | 32 --- .../pir/dao/NativeSQLMapperWrapper.java | 38 +++- .../pir/service/impl/PirServiceImpl.java | 13 +- wedpr-pir/conf/wedpr.properties | 2 +- 8 files changed, 207 insertions(+), 228 deletions(-) delete mode 100644 wedpr-components/task-plugin/pir/src/main/java/com/webank/wedpr/components/task/plugin/pir/dao/NativeSQLMapper.java diff --git a/wedpr-common/utils/src/main/java/com/webank/wedpr/common/utils/CSVFileParser.java b/wedpr-common/utils/src/main/java/com/webank/wedpr/common/utils/CSVFileParser.java index 7e730fdc..d1cadae2 100644 --- a/wedpr-common/utils/src/main/java/com/webank/wedpr/common/utils/CSVFileParser.java +++ b/wedpr-common/utils/src/main/java/com/webank/wedpr/common/utils/CSVFileParser.java @@ -15,11 +15,13 @@ package com.webank.wedpr.common.utils; +import com.fasterxml.jackson.annotation.JsonIgnoreProperties; import com.opencsv.CSVReaderHeaderAware; import com.webank.wedpr.common.config.WeDPRCommonConfig; import java.io.*; import java.nio.file.Paths; import java.util.*; +import lombok.Data; import org.apache.commons.lang3.ArrayUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -45,19 +47,8 @@ private static Object loadCSVFile(String filePath, int chunkSize, ParseHandler h } } - public static Set getFields(String filePath) throws Exception { - return (Set) - (loadCSVFile( - filePath, - WeDPRCommonConfig.getReadChunkSize(), - new ParseHandler() { - @Override - public Object call(CSVReaderHeaderAware reader) throws Exception { - return reader.readMap().keySet(); - } - })); - } - + @Data + @JsonIgnoreProperties(ignoreUnknown = true) public static class ExtractConfig { private String originalFilePath; private List extractFields; @@ -66,8 +57,6 @@ public static class ExtractConfig { private Integer writeChunkSize = WeDPRCommonConfig.getWriteChunkSize(); private Integer readChunkSize = WeDPRCommonConfig.getReadChunkSize(); - public ExtractConfig() {} - public ExtractConfig( String originalFilePath, List extractFields, String extractFilePath) { this.originalFilePath = originalFilePath; @@ -75,54 +64,6 @@ public ExtractConfig( this.extractFilePath = extractFilePath; } - public String getOriginalFilePath() { - return originalFilePath; - } - - public void setOriginalFilePath(String originalFilePath) { - this.originalFilePath = originalFilePath; - } - - public List getExtractFields() { - return extractFields; - } - - public void setExtractFields(List extractFields) { - this.extractFields = extractFields; - } - - public String getExtractFilePath() { - return extractFilePath; - } - - public void setExtractFilePath(String extractFilePath) { - this.extractFilePath = extractFilePath; - } - - public String getFieldSplitter() { - return fieldSplitter; - } - - public void setFieldSplitter(String fieldSplitter) { - this.fieldSplitter = fieldSplitter; - } - - public Integer getWriteChunkSize() { - return writeChunkSize; - } - - public void setWriteChunkSize(Integer writeChunkSize) { - this.writeChunkSize = writeChunkSize; - } - - public Integer getReadChunkSize() { - return readChunkSize; - } - - public void setReadChunkSize(Integer readChunkSize) { - this.readChunkSize = readChunkSize; - } - @Override public String toString() { return "ExtractConfig{" @@ -157,7 +98,7 @@ public Object call(CSVReaderHeaderAware reader) throws Exception { Map fieldsMapping = Common.trimAndMapping(headerInfo.keySet()); for (String field : extractConfig.getExtractFields()) { - if (!fieldsMapping.keySet().contains(field.trim())) { + if (!fieldsMapping.containsKey(field.trim())) { String errorMsg = "extractFields failed for the field " + field @@ -204,36 +145,38 @@ public Object call(CSVReaderHeaderAware reader) throws Exception { }); } - public static List> processCsv2SqlMap(String[] tableFields, String csvFilePath) + public interface RowContentHandler { + void handle(List rowContent) throws Exception; + } + + public static void processCsvContent( + String[] tableFields, String csvFilePath, RowContentHandler rowContentHandler) throws Exception { - return (List>) - loadCSVFile( - csvFilePath, - WeDPRCommonConfig.getReadChunkSize(), - reader -> { - List> resultValue = new ArrayList<>(); - Map row; - while ((row = reader.readMap()) != null) { - List rowValue = new ArrayList<>(); - for (String field : tableFields) { - Map rowFieldsMapping = - Common.trimAndMapping(row.keySet()); - if (!rowFieldsMapping.keySet().contains(field.trim())) { - String errorMsg = - "extractFields failed for the field " - + field - + " not existed in the file " - + ArrayUtils.toString( - rowFieldsMapping.keySet()); - logger.warn(errorMsg); - throw new WeDPRException(-1, errorMsg); - } - rowValue.add(row.get(rowFieldsMapping.get(field))); - } - resultValue.add(rowValue); + loadCSVFile( + csvFilePath, + WeDPRCommonConfig.getReadChunkSize(), + reader -> { + Map row; + while ((row = reader.readMap()) != null) { + List rowValue = new ArrayList<>(); + for (String field : tableFields) { + Map rowFieldsMapping = + Common.trimAndMapping(row.keySet()); + if (!rowFieldsMapping.containsKey(field.trim())) { + String errorMsg = + "extractFields failed for the field " + + field + + " not existed in the file " + + ArrayUtils.toString(rowFieldsMapping.keySet()); + logger.warn(errorMsg); + throw new WeDPRException(-1, errorMsg); } - return resultValue; - }); + rowValue.add(row.get(rowFieldsMapping.get(field))); + } + rowContentHandler.handle(rowValue); + } + return Boolean.TRUE; + }); } public static boolean writeMapData( diff --git a/wedpr-components/db-mapper/service-publish/src/main/java/com/webank/wedpr/components/db/mapper/service/publish/model/PirServiceSetting.java b/wedpr-components/db-mapper/service-publish/src/main/java/com/webank/wedpr/components/db/mapper/service/publish/model/PirServiceSetting.java index 25bdcb8a..58d564d5 100644 --- a/wedpr-components/db-mapper/service-publish/src/main/java/com/webank/wedpr/components/db/mapper/service/publish/model/PirServiceSetting.java +++ b/wedpr-components/db-mapper/service-publish/src/main/java/com/webank/wedpr/components/db/mapper/service/publish/model/PirServiceSetting.java @@ -16,7 +16,6 @@ package com.webank.wedpr.components.db.mapper.service.publish.model; import com.webank.wedpr.common.utils.Common; -import com.webank.wedpr.common.utils.Constant; import com.webank.wedpr.common.utils.ObjectMapperFactory; import com.webank.wedpr.common.utils.WeDPRException; import java.util.Collections; @@ -42,14 +41,10 @@ public List obtainQueriedFields(PirSearchType searchType, List q if (searchType == PirSearchType.SearchValue) { // remove duplicated fields Set queriedFieldSet = new HashSet<>(queriedFields); - if (queriedFieldSet.contains(idField)) { - queriedFieldSet.remove(idField); - queriedFieldSet.add(Constant.PIR_ID_FIELD_NAME); - } return (List) CollectionUtils.intersection(queriedFieldSet, accessibleValueQueryFields); } - return Collections.singletonList(Constant.PIR_ID_FIELD_NAME); + return Collections.singletonList(idField); } public void setSearchType(String searchType) { diff --git a/wedpr-components/mybatis/src/main/java/com/webank/wedpr/components/mybatis/MybatisConfigurationFactory.java b/wedpr-components/mybatis/src/main/java/com/webank/wedpr/components/mybatis/MybatisConfigurationFactory.java index ad32f139..d89095a5 100644 --- a/wedpr-components/mybatis/src/main/java/com/webank/wedpr/components/mybatis/MybatisConfigurationFactory.java +++ b/wedpr-components/mybatis/src/main/java/com/webank/wedpr/components/mybatis/MybatisConfigurationFactory.java @@ -39,6 +39,7 @@ import org.springframework.core.io.DefaultResourceLoader; import org.springframework.core.io.Resource; import org.springframework.core.io.support.PathMatchingResourcePatternResolver; +import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.jdbc.datasource.DataSourceTransactionManager; import org.springframework.transaction.PlatformTransactionManager; import org.springframework.transaction.annotation.EnableTransactionManagement; @@ -117,6 +118,12 @@ public PageInterceptor pageInterceptor() { return pageInterceptor; } + @Bean + @Primary + public JdbcTemplate primaryJdbcTemplate() { + return new JdbcTemplate(dataSource); + } + @Bean public MybatisPlusInterceptor mybatisPlusInterceptor() { MybatisPlusInterceptor interceptor = new MybatisPlusInterceptor(); diff --git a/wedpr-components/task-plugin/pir/src/main/java/com/webank/wedpr/components/task/plugin/pir/core/impl/PirDatasetConstructorImpl.java b/wedpr-components/task-plugin/pir/src/main/java/com/webank/wedpr/components/task/plugin/pir/core/impl/PirDatasetConstructorImpl.java index f2e48418..df3d3dd0 100644 --- a/wedpr-components/task-plugin/pir/src/main/java/com/webank/wedpr/components/task/plugin/pir/core/impl/PirDatasetConstructorImpl.java +++ b/wedpr-components/task-plugin/pir/src/main/java/com/webank/wedpr/components/task/plugin/pir/core/impl/PirDatasetConstructorImpl.java @@ -29,92 +29,94 @@ import com.webank.wedpr.components.storage.builder.StoragePathBuilder; import com.webank.wedpr.components.task.plugin.pir.config.PirServiceConfig; import com.webank.wedpr.components.task.plugin.pir.core.PirDatasetConstructor; -import com.webank.wedpr.components.task.plugin.pir.dao.NativeSQLMapper; +import java.sql.SQLException; import java.util.ArrayList; import java.util.Arrays; import java.util.List; +import org.apache.commons.lang3.StringUtils; +import org.apache.commons.lang3.tuple.ImmutablePair; +import org.apache.commons.lang3.tuple.Pair; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.jdbc.core.JdbcTemplate; public class PirDatasetConstructorImpl implements PirDatasetConstructor { - private static Logger logger = LoggerFactory.getLogger(PirDatasetConstructor.class); + private static final Logger logger = LoggerFactory.getLogger(PirDatasetConstructor.class); private final DatasetMapper datasetMapper; - private final StoragePathBuilder storagePathBuilder; private final FileStorageInterface fileStorageInterface; - private final NativeSQLMapper nativeSQLMapper; + private final JdbcTemplate jdbcTemplate; + private final String dbName; public PirDatasetConstructorImpl( DatasetMapper datasetMapper, FileStorageInterface fileStorageInterface, - StoragePathBuilder storagePathBuilder, - NativeSQLMapper nativeSQLMapper) { + JdbcTemplate jdbcTemplate) + throws SQLException { this.datasetMapper = datasetMapper; this.fileStorageInterface = fileStorageInterface; - this.storagePathBuilder = storagePathBuilder; - this.nativeSQLMapper = nativeSQLMapper; + this.jdbcTemplate = jdbcTemplate; + this.dbName = this.jdbcTemplate.getDataSource().getConnection().getCatalog(); + logger.info("Current database name: {}", this.dbName); } @Override public void construct(PirServiceSetting serviceSetting) throws Exception { - List allTables = this.nativeSQLMapper.showAllTables(); - String datasetID = serviceSetting.getDatasetId(); - String tableId = - com.webank.wedpr.components.task.plugin.pir.utils.Constant.datasetId2tableId( - datasetID); - if (allTables.contains(tableId)) { - logger.info("The dataset {} has already been constructed into {}", datasetID, tableId); - return; - } - Dataset dataset = this.datasetMapper.getDatasetByDatasetId(datasetID, false); - DataSourceType dataSourceType = DataSourceType.fromStr(dataset.getDataSourceType()); - if (dataSourceType != DataSourceType.CSV && dataSourceType != DataSourceType.EXCEL) { - throw new WeDPRException("PIR only support CSV and excel DataSources now!"); - } - logger.info("constructFromCSV, dataset: {}", dataset.getDatasetId()); - constructFromCSV(dataset, serviceSetting.getIdField()); - logger.info("constructFromCSV success, dataset: {}", dataset.getDatasetId()); - } + Dataset dataset = null; + String tableId = null; + try { + String datasetID = serviceSetting.getDatasetId(); + tableId = + com.webank.wedpr.components.task.plugin.pir.utils.Constant.datasetId2tableId( + datasetID); + if (tableExists(tableId)) { + logger.info( + "The dataset {} has already been constructed into {}", datasetID, tableId); + return; + } + dataset = this.datasetMapper.getDatasetByDatasetId(datasetID, false); + // create table + DataSourceType dataSourceType = DataSourceType.fromStr(dataset.getDataSourceType()); + long startT = System.currentTimeMillis(); + logger.info( + "Load pir service, dataset: {}, type: {}", + dataset.getDatasetId(), + dataSourceType.name()); - private void constructFromCSV(Dataset dataset, String idField) throws Exception { - StoragePath storagePath = - StoragePathBuilder.getInstance( - dataset.getDatasetStorageType(), dataset.getDatasetStoragePath()); - String localFilePath = - Common.joinPath(PirServiceConfig.getPirCacheDir(), dataset.getDatasetId()); - this.fileStorageInterface.download(storagePath, localFilePath); - logger.info( - "Download dataset {} success, localFilePath: {}", - dataset.getDatasetId(), - localFilePath); - String[] datasetFields = - Arrays.stream(dataset.getDatasetFields().trim().split(",")) - .map(String::trim) - .toArray(String[]::new); - List datasetFieldsList = Arrays.asList(datasetFields); - if (datasetFieldsList.contains(Constant.PIR_ID_FIELD_NAME)) { - throw new WeDPRException("Conflict with sys field " + Constant.PIR_ID_FIELD_NAME); - } - if (datasetFieldsList.contains(Constant.PIR_ID_HASH_FIELD_NAME)) { - throw new WeDPRException("Conflict with sys field " + Constant.PIR_ID_HASH_FIELD_NAME); - } - Long startTime = System.currentTimeMillis(); - List> sqlValues = - CSVFileParser.processCsv2SqlMap(datasetFields, localFilePath); - if (sqlValues.size() == 0) { + constructFromCSV(tableId, dataset, serviceSetting.getIdField()); logger.info( - "constructFromCSV with empty dataset, datasetID: {}, datasetPath: {}", + "Load pir success, dataset: {}, type: {}, timecost: {}ms", dataset.getDatasetId(), - localFilePath); - return; + dataSourceType.name(), + System.currentTimeMillis() - startT); + } catch (Exception e) { + logger.warn( + "Publish pir service failed, dataset: {}, e: ", + (dataset == null ? "empty" : dataset.getDatasetId()), + e); + if (StringUtils.isNotBlank(tableId)) { + logger.info("Revert the created table: {}", tableId); + this.jdbcTemplate.execute("drop table if exists " + tableId); + } + throw e; } - logger.info( - "processCsv2SqlMap success, timecost: {}ms", - System.currentTimeMillis() - startTime); - String tableId = - com.webank.wedpr.components.task.plugin.pir.utils.Constant.datasetId2tableId( - dataset.getDatasetId()); + } + + private boolean tableExists(String tableName) { + String query = + String.format( + "select count(*) " + + "from information_schema.tables " + + "where table_name = ? and table_schema = '%s'", + this.dbName); + Integer result = jdbcTemplate.queryForObject(query, Integer.class, tableName); + return result != null && result > 0; + } + private Pair, Integer> createPirTableForDataset( + String tableId, String idField, String[] datasetFields) { + + logger.info("Create table {}", tableId); // all the field + id_hash field String[] fieldsWithType = new String[datasetFields.length + 1]; List tableFields = new ArrayList<>(); @@ -122,8 +124,8 @@ private void constructFromCSV(Dataset dataset, String idField) throws Exception for (int i = 0; i < datasetFields.length; i++) { // the idField if (idField.equalsIgnoreCase(datasetFields[i])) { - fieldsWithType[i] = Constant.PIR_ID_FIELD_NAME + " VARCHAR(255)"; - tableFields.add(Constant.PIR_ID_FIELD_NAME); + fieldsWithType[i] = idField + " VARCHAR(255)"; + tableFields.add(idField); idFieldIndex = i; } else { fieldsWithType[i] = datasetFields[i] + " TEXT"; @@ -140,22 +142,73 @@ private void constructFromCSV(Dataset dataset, String idField) throws Exception tableId, String.join(",", fieldsWithType), Constant.PIR_ID_HASH_FIELD_NAME, - Constant.PIR_ID_FIELD_NAME); - logger.info("constructFromCSV, execute sql: {}", sql); - this.nativeSQLMapper.executeNativeUpdateSql(sql); + idField); + logger.info("createPirTableForDataset, execute sql: {}", sql); + this.jdbcTemplate.execute(sql); + return new ImmutablePair<>(tableFields, idFieldIndex); + } - StringBuilder sb = new StringBuilder(); - for (List values : sqlValues) { - // add hash for the idField - values.add(CryptoToolkitFactory.hash(values.get(idFieldIndex))); - sb.append("(").append(Common.joinAndAddDoubleQuotes(values)).append("), "); + private void constructFromCSV(String tableId, Dataset dataset, String idField) + throws Exception { + StoragePath storagePath = + StoragePathBuilder.getInstance( + dataset.getDatasetStorageType(), dataset.getDatasetStoragePath()); + String localFilePath = + Common.joinPath(PirServiceConfig.getPirCacheDir(), dataset.getDatasetId()); + this.fileStorageInterface.download(storagePath, localFilePath); + logger.info( + "Download dataset {} success, localFilePath: {}", + dataset.getDatasetId(), + localFilePath); + String[] datasetFields = + Arrays.stream(dataset.getDatasetFields().trim().split(",")) + .map(String::trim) + .toArray(String[]::new); + List datasetFieldsList = Arrays.asList(datasetFields); + if (datasetFieldsList.contains(Constant.PIR_ID_HASH_FIELD_NAME)) { + throw new WeDPRException("Conflict with sys field " + Constant.PIR_ID_HASH_FIELD_NAME); } - String insertValues = sb.toString(); - insertValues = insertValues.substring(0, insertValues.length() - 2); - sql = - String.format( - "INSERT INTO %s (%s) VALUES %s ", - tableId, String.join(",", tableFields), insertValues); - this.nativeSQLMapper.executeNativeUpdateSql(sql); + Pair, Integer> tableInfo = + createPirTableForDataset(tableId, idField, datasetFields); + Integer idFieldIndex = tableInfo.getRight(); + + long startTime = System.currentTimeMillis(); + final Long[] publishedRecorders = {0L}; + final Long reportRecorders = 10000L; + CSVFileParser.processCsvContent( + datasetFields, + localFilePath, + new CSVFileParser.RowContentHandler() { + @Override + public void handle(List rowContent) throws Exception { + StringBuilder sb = new StringBuilder(); + // add hash for the idField + rowContent.add(CryptoToolkitFactory.hash(rowContent.get(idFieldIndex))); + sb.append("(") + .append(Common.joinAndAddDoubleQuotes(rowContent)) + .append(")"); + // insert the row-content into sql + String sql = + String.format( + "INSERT INTO %s (%s) VALUES %s ", + tableId, String.join(",", tableInfo.getLeft()), sb); + publishedRecorders[0] += 1; + if (publishedRecorders[0] % reportRecorders == 0) { + logger.info( + "table: {}, dataset: {} publishing, publishedRecorders: {}, timecost: {}ms", + tableId, + dataset.getDatasetId(), + publishedRecorders[0], + (System.currentTimeMillis() - startTime)); + } + jdbcTemplate.execute(sql); + } + }); + logger.info( + "Publish pir success, table: {}, dataset: {}, publishedRecorders: {}, timecost: {}ms", + tableId, + dataset.getDatasetId(), + publishedRecorders[0], + (System.currentTimeMillis() - startTime)); } } diff --git a/wedpr-components/task-plugin/pir/src/main/java/com/webank/wedpr/components/task/plugin/pir/dao/NativeSQLMapper.java b/wedpr-components/task-plugin/pir/src/main/java/com/webank/wedpr/components/task/plugin/pir/dao/NativeSQLMapper.java deleted file mode 100644 index c20d3b88..00000000 --- a/wedpr-components/task-plugin/pir/src/main/java/com/webank/wedpr/components/task/plugin/pir/dao/NativeSQLMapper.java +++ /dev/null @@ -1,32 +0,0 @@ -/* - * Copyright 2017-2025 [webank-wedpr] - * - * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except - * in compliance with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software distributed under the License - * is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express - * or implied. See the License for the specific language governing permissions and limitations under - * the License. - * - */ - -package com.webank.wedpr.components.task.plugin.pir.dao; - -import java.util.List; -import java.util.Map; -import org.apache.ibatis.annotations.*; - -@Mapper -public interface NativeSQLMapper { - @Update(value = "${nativeSql}") - public void executeNativeUpdateSql(@Param("nativeSql") String nativeSql); - - @Select("SHOW TABLES") - public List showAllTables(); - - @Select(value = "${nativeSql}") - public List> executeNativeQuerySql(@Param("nativeSql") String nativeSql); -} diff --git a/wedpr-components/task-plugin/pir/src/main/java/com/webank/wedpr/components/task/plugin/pir/dao/NativeSQLMapperWrapper.java b/wedpr-components/task-plugin/pir/src/main/java/com/webank/wedpr/components/task/plugin/pir/dao/NativeSQLMapperWrapper.java index 6af43d4f..8a04110b 100644 --- a/wedpr-components/task-plugin/pir/src/main/java/com/webank/wedpr/components/task/plugin/pir/dao/NativeSQLMapperWrapper.java +++ b/wedpr-components/task-plugin/pir/src/main/java/com/webank/wedpr/components/task/plugin/pir/dao/NativeSQLMapperWrapper.java @@ -24,19 +24,37 @@ import com.webank.wedpr.components.pir.sdk.model.PirParamEnum; import com.webank.wedpr.components.pir.sdk.model.PirQueryParam; import com.webank.wedpr.components.task.plugin.pir.model.PirDataItem; +import java.sql.ResultSet; +import java.sql.ResultSetMetaData; +import java.sql.SQLException; +import java.util.HashMap; import java.util.LinkedList; import java.util.List; import java.util.Map; import org.apache.commons.lang3.StringUtils; import org.slf4j.Logger; import org.slf4j.LoggerFactory; +import org.springframework.jdbc.core.JdbcTemplate; +import org.springframework.jdbc.core.RowMapper; public class NativeSQLMapperWrapper { - private static Logger logger = LoggerFactory.getLogger(NativeSQLMapperWrapper.class); - private final NativeSQLMapper nativeSQLMapper; + private static final Logger logger = LoggerFactory.getLogger(NativeSQLMapperWrapper.class); + private final JdbcTemplate jdbcTemplate; - public NativeSQLMapperWrapper(NativeSQLMapper nativeSQLMapper) { - this.nativeSQLMapper = nativeSQLMapper; + public NativeSQLMapperWrapper(JdbcTemplate jdbcTemplate) { + this.jdbcTemplate = jdbcTemplate; + } + + public class GeneralRowMapper implements RowMapper> { + @Override + public Map mapRow(ResultSet rs, int rowNum) throws SQLException { + Map result = new HashMap<>(); + ResultSetMetaData resultSetMetaData = rs.getMetaData(); + for (int i = 1; i <= resultSetMetaData.getColumnCount(); i++) { + result.put(resultSetMetaData.getColumnName(i), rs.getString(i)); + } + return result; + } } public List query( @@ -87,7 +105,7 @@ public List executeQuery( tableName, condition); logger.debug("executeQuery: {}", sql); - return toPirDataList(this.nativeSQLMapper.executeNativeQuerySql(sql)); + return toPirDataList(serviceSetting, this.jdbcTemplate.query(sql, new GeneralRowMapper())); } public List executeFuzzyMatchQuery( @@ -108,22 +126,22 @@ public List executeFuzzyMatchQuery( tableName, condition); logger.debug("executeQuery: {}", sql); - return toPirDataList(this.nativeSQLMapper.executeNativeQuerySql(sql)); + return toPirDataList(serviceSetting, this.jdbcTemplate.query(sql, new GeneralRowMapper())); } - protected static List toPirDataList(List> values) - throws Exception { + protected static List toPirDataList( + PirServiceSetting serviceSetting, List> values) throws Exception { if (values == null || values.isEmpty()) { return null; } List result = new LinkedList<>(); int i = 0; - for (Map row : values) { + for (Map row : values) { PirDataItem pirTable = new PirDataItem(); pirTable.setId(i); // the key, Note: here must use the idField value since the client use the idField value // to calculateZ0 - pirTable.setPirKey((String) row.get(Constant.PIR_ID_FIELD_NAME)); + pirTable.setPirKey(row.get(serviceSetting.getIdField())); // the values pirTable.setPirValue(ObjectMapperFactory.getObjectMapper().writeValueAsString(row)); logger.trace("toPirDataList result: {}", pirTable.toString()); diff --git a/wedpr-components/task-plugin/pir/src/main/java/com/webank/wedpr/components/task/plugin/pir/service/impl/PirServiceImpl.java b/wedpr-components/task-plugin/pir/src/main/java/com/webank/wedpr/components/task/plugin/pir/service/impl/PirServiceImpl.java index 7db8ef4a..ca5ad356 100644 --- a/wedpr-components/task-plugin/pir/src/main/java/com/webank/wedpr/components/task/plugin/pir/service/impl/PirServiceImpl.java +++ b/wedpr-components/task-plugin/pir/src/main/java/com/webank/wedpr/components/task/plugin/pir/service/impl/PirServiceImpl.java @@ -36,14 +36,12 @@ import com.webank.wedpr.components.pir.sdk.model.PirQueryParam; import com.webank.wedpr.components.pir.sdk.model.PirQueryRequest; import com.webank.wedpr.components.storage.api.FileStorageInterface; -import com.webank.wedpr.components.storage.builder.StoragePathBuilder; import com.webank.wedpr.components.storage.config.HdfsStorageConfig; import com.webank.wedpr.components.storage.config.LocalStorageConfig; import com.webank.wedpr.components.task.plugin.pir.core.Obfuscator; import com.webank.wedpr.components.task.plugin.pir.core.PirDatasetConstructor; import com.webank.wedpr.components.task.plugin.pir.core.impl.ObfuscatorImpl; import com.webank.wedpr.components.task.plugin.pir.core.impl.PirDatasetConstructorImpl; -import com.webank.wedpr.components.task.plugin.pir.dao.NativeSQLMapper; import com.webank.wedpr.components.task.plugin.pir.dao.NativeSQLMapperWrapper; import com.webank.wedpr.components.task.plugin.pir.handler.PirServiceHook; import com.webank.wedpr.components.task.plugin.pir.model.ObfuscationParam; @@ -58,13 +56,14 @@ import org.slf4j.LoggerFactory; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.beans.factory.annotation.Qualifier; +import org.springframework.jdbc.core.JdbcTemplate; import org.springframework.stereotype.Service; @Service public class PirServiceImpl implements PirService { private static final Logger logger = LoggerFactory.getLogger(PirServiceImpl.class); - @Autowired private NativeSQLMapper nativeSQLMapper; + @Autowired private JdbcTemplate jdbcTemplate; @Autowired private DatasetMapper datasetMapper; @Autowired private HdfsStorageConfig hdfsConfig; @Autowired private LocalStorageConfig localStorageConfig; @@ -99,13 +98,9 @@ public class PirServiceImpl implements PirService { @PostConstruct public void init() throws Exception { this.obfuscator = new ObfuscatorImpl(); - this.nativeSQLMapperWrapper = new NativeSQLMapperWrapper(nativeSQLMapper); + this.nativeSQLMapperWrapper = new NativeSQLMapperWrapper(jdbcTemplate); this.pirDatasetConstructor = - new PirDatasetConstructorImpl( - datasetMapper, - fileStorage, - new StoragePathBuilder(hdfsConfig, localStorageConfig), - nativeSQLMapper); + new PirDatasetConstructorImpl(datasetMapper, fileStorage, jdbcTemplate); this.pirServiceHook = new PirServiceHook(serviceHook, serviceInvokeMapper); this.pirTopicSubscriber = new PirTopicSubscriberImpl( diff --git a/wedpr-pir/conf/wedpr.properties b/wedpr-pir/conf/wedpr.properties index bb3399f3..70a94846 100644 --- a/wedpr-pir/conf/wedpr.properties +++ b/wedpr-pir/conf/wedpr.properties @@ -24,7 +24,7 @@ wedpr.crypto.symmetric.iv=123456 wedpr.mybatis.mapperLocations=classpath*:mapper/*Mapper.xml # Note: the basePackage can't set to com.webank.wedpr simply for the mybatis will scan the Service -wedpr.mybatis.BasePackage=com.webank.wedpr.components.db.mapper.dataset.mapper,com.webank.wedpr.components.db.mapper.service.publish.dao,com.webank.wedpr.components.task.plugin.pir.dao,com.webank.wedpr.components.api.credential.dao +wedpr.mybatis.BasePackage=com.webank.wedpr.components.db.mapper.dataset.mapper,com.webank.wedpr.components.db.mapper.service.publish.dao,com.webank.wedpr.components.api.credential.dao # the pir config wedpr.pir.cache.dir=.cache