diff --git a/core/src/main/java/com/databasir/core/domain/DomainErrors.java b/core/src/main/java/com/databasir/core/domain/DomainErrors.java index 176c86e..a4fa7c9 100644 --- a/core/src/main/java/com/databasir/core/domain/DomainErrors.java +++ b/core/src/main/java/com/databasir/core/domain/DomainErrors.java @@ -44,7 +44,7 @@ public enum DomainErrors implements DatabasirErrors { DUPLICATE_COLUMN("A_10028", "重复的列"), INVALID_MOCK_DATA_SCRIPT("A_10029", "不合法的表达式"), CANNOT_DELETE_SELF("A_10030", "无法对自己执行删除账号操作"), - DRIVER_CLASS_NAME_OBTAIN_ERROR("A_10031", "获取驱动类名失败"), + DRIVER_CLASS_NOT_FOUND("A_10031", "获取驱动类名失败"), ; private final String errCode; diff --git a/core/src/main/java/com/databasir/core/domain/database/service/DatabaseTypeService.java b/core/src/main/java/com/databasir/core/domain/database/service/DatabaseTypeService.java index 773d99a..8db700a 100644 --- a/core/src/main/java/com/databasir/core/domain/database/service/DatabaseTypeService.java +++ b/core/src/main/java/com/databasir/core/domain/database/service/DatabaseTypeService.java @@ -36,6 +36,7 @@ public class DatabaseTypeService { private final DatabaseTypePojoConverter databaseTypePojoConverter; public Integer create(DatabaseTypeCreateRequest request) { + driverResources.validateJar(request.getJdbcDriverFileUrl(), request.getJdbcDriverClassName()); DatabaseTypePojo pojo = databaseTypePojoConverter.of(request); try { return databaseTypeDao.insertAndReturnId(pojo); @@ -50,7 +51,7 @@ public class DatabaseTypeService { if (DatabaseTypes.has(data.getDatabaseType())) { throw DomainErrors.MUST_NOT_MODIFY_SYSTEM_DEFAULT_DATABASE_TYPE.exception(); } - + driverResources.validateJar(request.getJdbcDriverFileUrl(), request.getJdbcDriverClassName()); DatabaseTypePojo pojo = databaseTypePojoConverter.of(request); try { databaseTypeDao.updateById(pojo); diff --git a/core/src/main/java/com/databasir/core/infrastructure/connection/CustomDatabaseConnectionFactory.java b/core/src/main/java/com/databasir/core/infrastructure/connection/CustomDatabaseConnectionFactory.java index 3e5d39c..bf0f860 100644 --- a/core/src/main/java/com/databasir/core/infrastructure/connection/CustomDatabaseConnectionFactory.java +++ b/core/src/main/java/com/databasir/core/infrastructure/connection/CustomDatabaseConnectionFactory.java @@ -55,11 +55,11 @@ public class CustomDatabaseConnectionFactory implements DatabaseConnectionFactor Class clazz = null; Driver driver = null; try { - clazz = Class.forName(type.getJdbcDriverClassName(), true, loader); + clazz = Class.forName(type.getJdbcDriverClassName(), false, loader); driver = (Driver) clazz.getConstructor().newInstance(); } catch (ClassNotFoundException e) { log.error("init driver error", e); - throw DomainErrors.CONNECT_DATABASE_FAILED.exception("驱动初始化异常, 请检查 Driver name:" + e.getMessage()); + throw DomainErrors.CONNECT_DATABASE_FAILED.exception("驱动初始化异常, 请检查驱动类名:" + e.getMessage()); } catch (InvocationTargetException | InstantiationException | IllegalAccessException diff --git a/core/src/main/java/com/databasir/core/infrastructure/driver/DriverResources.java b/core/src/main/java/com/databasir/core/infrastructure/driver/DriverResources.java index cfa4780..4bc9683 100644 --- a/core/src/main/java/com/databasir/core/infrastructure/driver/DriverResources.java +++ b/core/src/main/java/com/databasir/core/infrastructure/driver/DriverResources.java @@ -4,6 +4,7 @@ import com.databasir.core.domain.DomainErrors; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.IOUtils; +import org.apache.commons.lang3.ClassUtils; import org.springframework.beans.factory.annotation.Value; import org.springframework.http.HttpMethod; import org.springframework.stereotype.Component; @@ -11,6 +12,9 @@ import org.springframework.util.StreamUtils; import org.springframework.web.client.RestTemplate; import java.io.*; +import java.net.MalformedURLException; +import java.net.URL; +import java.net.URLClassLoader; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; @@ -36,6 +40,41 @@ public class DriverResources { } } + public void validateJar(String driverFileUrl, String className) { + String tempFilePath = "temp/" + UUID.randomUUID() + ".jar"; + File driverFile = doDownload(driverFileUrl, tempFilePath); + URLClassLoader loader = null; + try { + loader = new URLClassLoader( + new URL[]{driverFile.toURI().toURL()}, + this.getClass().getClassLoader() + ); + } catch (MalformedURLException e) { + log.error("load driver jar error ", e); + throw DomainErrors.DOWNLOAD_DRIVER_ERROR.exception(e.getMessage()); + } + + try { + Class clazz = Class.forName(className, false, loader); + boolean isValid = ClassUtils.getAllInterfaces(clazz) + .stream() + .anyMatch(cls -> cls.getName().equals("java.sql.Driver")); + if (!isValid) { + throw DomainErrors.DRIVER_CLASS_NOT_FOUND.exception("不合法的驱动类,请重新指定"); + } + } catch (ClassNotFoundException e) { + log.error("init driver error", e); + throw DomainErrors.DRIVER_CLASS_NOT_FOUND.exception("驱动初始化异常, 请检查驱动类名:" + e.getMessage()); + } finally { + IOUtils.closeQuietly(loader); + try { + Files.deleteIfExists(driverFile.toPath()); + } catch (IOException e) { + log.error("delete driver error " + tempFilePath, e); + } + } + } + public String resolveSqlDriverNameFromJar(String driverFileUrl) { String tempFilePath = "temp/" + UUID.randomUUID() + ".jar"; File driverFile = doDownload(driverFileUrl, tempFilePath); @@ -103,7 +142,7 @@ public class DriverResources { jarFile = new JarFile(driverFile); } catch (IOException e) { log.error("resolve driver class name error", e); - throw DomainErrors.DRIVER_CLASS_NAME_OBTAIN_ERROR.exception(e.getMessage()); + throw DomainErrors.DRIVER_CLASS_NOT_FOUND.exception(e.getMessage()); } final JarFile driverJar = jarFile; @@ -119,12 +158,12 @@ public class DriverResources { return reader.readLine(); } catch (IOException e) { log.error("resolve driver class name error", e); - throw DomainErrors.DRIVER_CLASS_NAME_OBTAIN_ERROR.exception(e.getMessage()); + throw DomainErrors.DRIVER_CLASS_NOT_FOUND.exception(e.getMessage()); } finally { IOUtils.closeQuietly(reader, ex -> log.error("close reader error", ex)); } }) - .orElseThrow(DomainErrors.DRIVER_CLASS_NAME_OBTAIN_ERROR::exception); + .orElseThrow(DomainErrors.DRIVER_CLASS_NOT_FOUND::exception); IOUtils.closeQuietly(jarFile, ex -> log.error("close jar file error", ex)); return driverClassName; }