From ca22a8fef7a31c0235b0b2951260a7819b89993b Mon Sep 17 00:00:00 2001 From: vran Date: Mon, 18 Apr 2022 11:43:33 +0800 Subject: [PATCH] fix some security bug (#103) * fix: use hard-code secret * feat: add driver class validate * feat: optimize drvier resource code * fix:ut failed --- .../resources/application-local.properties | 3 +- api/src/main/resources/application.properties | 3 +- .../databasir/core/domain/DomainErrors.java | 2 +- .../database/service/DatabaseTypeService.java | 9 +- .../CustomDatabaseConnectionFactory.java | 10 +- .../driver/DriverResources.java | 158 +++++++++++------- .../core/infrastructure/jwt/JwtTokens.java | 8 +- .../service/DatabaseTypeServiceTest.java | 14 ++ .../test/resources/application-ut.properties | 3 +- 9 files changed, 138 insertions(+), 72 deletions(-) diff --git a/api/src/main/resources/application-local.properties b/api/src/main/resources/application-local.properties index 9933c49..7d146b2 100644 --- a/api/src/main/resources/application-local.properties +++ b/api/src/main/resources/application-local.properties @@ -9,4 +9,5 @@ spring.flyway.locations=classpath:db/migration databasir.db.url=localhost:3306 databasir.db.username=root databasir.db.password=123456 -databasir.db.driver-directory=drivers \ No newline at end of file +databasir.db.driver-directory=drivers +databasir.jwt.secret=DatabasirJwtSecret \ No newline at end of file diff --git a/api/src/main/resources/application.properties b/api/src/main/resources/application.properties index 89200d7..baf91d4 100644 --- a/api/src/main/resources/application.properties +++ b/api/src/main/resources/application.properties @@ -11,4 +11,5 @@ spring.flyway.enabled=true spring.flyway.baseline-on-migrate=true spring.flyway.locations=classpath:db/migration # driver directory -databasir.db.driver-directory=drivers \ No newline at end of file +databasir.db.driver-directory=drivers +databasir.jwt.secret=${random.uuid} \ No newline at end of file diff --git a/core/src/main/java/com/databasir/core/domain/DomainErrors.java b/core/src/main/java/com/databasir/core/domain/DomainErrors.java index 176c86e..a4fa7c9 100644 --- a/core/src/main/java/com/databasir/core/domain/DomainErrors.java +++ b/core/src/main/java/com/databasir/core/domain/DomainErrors.java @@ -44,7 +44,7 @@ public enum DomainErrors implements DatabasirErrors { DUPLICATE_COLUMN("A_10028", "重复的列"), INVALID_MOCK_DATA_SCRIPT("A_10029", "不合法的表达式"), CANNOT_DELETE_SELF("A_10030", "无法对自己执行删除账号操作"), - DRIVER_CLASS_NAME_OBTAIN_ERROR("A_10031", "获取驱动类名失败"), + DRIVER_CLASS_NOT_FOUND("A_10031", "获取驱动类名失败"), ; private final String errCode; diff --git a/core/src/main/java/com/databasir/core/domain/database/service/DatabaseTypeService.java b/core/src/main/java/com/databasir/core/domain/database/service/DatabaseTypeService.java index 773d99a..ceff088 100644 --- a/core/src/main/java/com/databasir/core/domain/database/service/DatabaseTypeService.java +++ b/core/src/main/java/com/databasir/core/domain/database/service/DatabaseTypeService.java @@ -36,6 +36,7 @@ public class DatabaseTypeService { private final DatabaseTypePojoConverter databaseTypePojoConverter; public Integer create(DatabaseTypeCreateRequest request) { + driverResources.validateJar(request.getJdbcDriverFileUrl(), request.getJdbcDriverClassName()); DatabaseTypePojo pojo = databaseTypePojoConverter.of(request); try { return databaseTypeDao.insertAndReturnId(pojo); @@ -50,7 +51,7 @@ public class DatabaseTypeService { if (DatabaseTypes.has(data.getDatabaseType())) { throw DomainErrors.MUST_NOT_MODIFY_SYSTEM_DEFAULT_DATABASE_TYPE.exception(); } - + driverResources.validateJar(request.getJdbcDriverFileUrl(), request.getJdbcDriverClassName()); DatabaseTypePojo pojo = databaseTypePojoConverter.of(request); try { databaseTypeDao.updateById(pojo); @@ -61,7 +62,7 @@ public class DatabaseTypeService { // 名称修改,下载地址修改需要删除原有的 driver if (!Objects.equals(request.getDatabaseType(), data.getDatabaseType()) || !Objects.equals(request.getJdbcDriverFileUrl(), data.getJdbcDriverFileUrl())) { - driverResources.delete(data.getDatabaseType()); + driverResources.deleteByDatabaseType(data.getDatabaseType()); } }); @@ -73,7 +74,7 @@ public class DatabaseTypeService { throw DomainErrors.MUST_NOT_MODIFY_SYSTEM_DEFAULT_DATABASE_TYPE.exception(); } databaseTypeDao.deleteById(id); - driverResources.delete(data.getDatabaseType()); + driverResources.deleteByDatabaseType(data.getDatabaseType()); }); } @@ -109,7 +110,7 @@ public class DatabaseTypeService { } public String resolveDriverClassName(DriverClassNameResolveRequest request) { - return driverResources.resolveSqlDriverNameFromJar(request.getJdbcDriverFileUrl()); + return driverResources.resolveDriverClassName(request.getJdbcDriverFileUrl()); } } diff --git a/core/src/main/java/com/databasir/core/infrastructure/connection/CustomDatabaseConnectionFactory.java b/core/src/main/java/com/databasir/core/infrastructure/connection/CustomDatabaseConnectionFactory.java index 3e5d39c..f37df45 100644 --- a/core/src/main/java/com/databasir/core/infrastructure/connection/CustomDatabaseConnectionFactory.java +++ b/core/src/main/java/com/databasir/core/infrastructure/connection/CustomDatabaseConnectionFactory.java @@ -36,8 +36,10 @@ public class CustomDatabaseConnectionFactory implements DatabaseConnectionFactor @Override public Connection getConnection(Context context) throws SQLException { - DatabaseTypePojo type = databaseTypeDao.selectByDatabaseType(context.getDatabaseType()); - File driverFile = driverResources.loadOrDownload(context.getDatabaseType(), type.getJdbcDriverFileUrl()); + String databaseType = context.getDatabaseType(); + DatabaseTypePojo type = databaseTypeDao.selectByDatabaseType(databaseType); + File driverFile = driverResources.loadOrDownloadByDatabaseType(databaseType, type.getJdbcDriverFileUrl()); + URLClassLoader loader = null; try { loader = new URLClassLoader( @@ -55,11 +57,11 @@ public class CustomDatabaseConnectionFactory implements DatabaseConnectionFactor Class clazz = null; Driver driver = null; try { - clazz = Class.forName(type.getJdbcDriverClassName(), true, loader); + clazz = Class.forName(type.getJdbcDriverClassName(), false, loader); driver = (Driver) clazz.getConstructor().newInstance(); } catch (ClassNotFoundException e) { log.error("init driver error", e); - throw DomainErrors.CONNECT_DATABASE_FAILED.exception("驱动初始化异常, 请检查 Driver name:" + e.getMessage()); + throw DomainErrors.CONNECT_DATABASE_FAILED.exception("驱动初始化异常, 请检查驱动类名:" + e.getMessage()); } catch (InvocationTargetException | InstantiationException | IllegalAccessException diff --git a/core/src/main/java/com/databasir/core/infrastructure/driver/DriverResources.java b/core/src/main/java/com/databasir/core/infrastructure/driver/DriverResources.java index cfa4780..a498287 100644 --- a/core/src/main/java/com/databasir/core/infrastructure/driver/DriverResources.java +++ b/core/src/main/java/com/databasir/core/infrastructure/driver/DriverResources.java @@ -4,16 +4,22 @@ import com.databasir.core.domain.DomainErrors; import lombok.RequiredArgsConstructor; import lombok.extern.slf4j.Slf4j; import org.apache.commons.io.IOUtils; +import org.apache.commons.lang3.ClassUtils; import org.springframework.beans.factory.annotation.Value; import org.springframework.http.HttpMethod; import org.springframework.stereotype.Component; import org.springframework.util.StreamUtils; +import org.springframework.web.client.RestClientException; import org.springframework.web.client.RestTemplate; import java.io.*; +import java.net.MalformedURLException; +import java.net.URL; +import java.net.URLClassLoader; import java.nio.file.Files; import java.nio.file.Path; import java.nio.file.Paths; +import java.util.Optional; import java.util.UUID; import java.util.jar.JarFile; @@ -27,7 +33,7 @@ public class DriverResources { private final RestTemplate restTemplate; - public void delete(String databaseType) { + public void deleteByDatabaseType(String databaseType) { Path path = Paths.get(driverFilePath(driverBaseDirectory, databaseType)); try { Files.deleteIfExists(path); @@ -36,10 +42,24 @@ public class DriverResources { } } - public String resolveSqlDriverNameFromJar(String driverFileUrl) { + public Optional loadByDatabaseType(String databaseType) { + Path path = Paths.get(driverFilePath(driverBaseDirectory, databaseType)); + if (Files.exists(path)) { + return Optional.of(path.toFile()); + } else { + return Optional.empty(); + } + } + + public File loadOrDownloadByDatabaseType(String databaseType, String driverFileUrl) { + return loadByDatabaseType(databaseType) + .orElseGet(() -> download(driverFileUrl, driverFilePath(driverBaseDirectory, databaseType))); + } + + public String resolveDriverClassName(String driverFileUrl) { String tempFilePath = "temp/" + UUID.randomUUID() + ".jar"; - File driverFile = doDownload(driverFileUrl, tempFilePath); - String className = doResolveSqlDriverNameFromJar(driverFile); + File driverFile = download(driverFileUrl, tempFilePath); + String className = resolveDriverClassName(driverFile); try { Files.deleteIfExists(driverFile.toPath()); } catch (IOException e) { @@ -48,62 +68,13 @@ public class DriverResources { return className; } - public File loadOrDownload(String databaseType, String driverFileUrl) { - String filePath = driverFilePath(driverBaseDirectory, databaseType); - Path path = Path.of(filePath); - if (Files.exists(path)) { - // ignore - log.debug("{} already exists, ignore download from {}", filePath, driverFileUrl); - return path.toFile(); - } - return this.doDownload(driverFileUrl, filePath); - } - - private File doDownload(String driverFileUrl, String filePath) { - Path path = Path.of(filePath); - - // create parent directory - if (Files.notExists(path)) { - path.getParent().toFile().mkdirs(); - try { - Files.createFile(path); - } catch (IOException e) { - log.error("create file error " + filePath, e); - throw DomainErrors.DOWNLOAD_DRIVER_ERROR.exception(e.getMessage()); - } - } - - // download - try { - return restTemplate.execute(driverFileUrl, HttpMethod.GET, null, response -> { - if (response.getStatusCode().is2xxSuccessful()) { - File file = path.toFile(); - FileOutputStream out = new FileOutputStream(file); - StreamUtils.copy(response.getBody(), out); - IOUtils.closeQuietly(out, ex -> log.error("close file error", ex)); - log.info("{} download success ", filePath); - return file; - } else { - log.error("{} download error from {}: {} ", filePath, driverFileUrl, response); - throw DomainErrors.DOWNLOAD_DRIVER_ERROR.exception("驱动下载失败:" - + response.getStatusCode() - + ", " - + response.getStatusText()); - } - }); - } catch (IllegalArgumentException e) { - log.error(filePath + " download driver error", e); - throw DomainErrors.DOWNLOAD_DRIVER_ERROR.exception(e.getMessage()); - } - } - - private String doResolveSqlDriverNameFromJar(File driverFile) { + public String resolveDriverClassName(File driverFile) { JarFile jarFile = null; try { jarFile = new JarFile(driverFile); } catch (IOException e) { log.error("resolve driver class name error", e); - throw DomainErrors.DRIVER_CLASS_NAME_OBTAIN_ERROR.exception(e.getMessage()); + throw DomainErrors.DRIVER_CLASS_NOT_FOUND.exception(e.getMessage()); } final JarFile driverJar = jarFile; @@ -119,16 +90,89 @@ public class DriverResources { return reader.readLine(); } catch (IOException e) { log.error("resolve driver class name error", e); - throw DomainErrors.DRIVER_CLASS_NAME_OBTAIN_ERROR.exception(e.getMessage()); + throw DomainErrors.DRIVER_CLASS_NOT_FOUND.exception(e.getMessage()); } finally { IOUtils.closeQuietly(reader, ex -> log.error("close reader error", ex)); } }) - .orElseThrow(DomainErrors.DRIVER_CLASS_NAME_OBTAIN_ERROR::exception); + .orElseThrow(DomainErrors.DRIVER_CLASS_NOT_FOUND::exception); IOUtils.closeQuietly(jarFile, ex -> log.error("close jar file error", ex)); return driverClassName; } + private File download(String driverFileUrl, String targetFile) { + Path path = Path.of(targetFile); + + // create parent directory + if (Files.notExists(path)) { + path.getParent().toFile().mkdirs(); + try { + Files.createFile(path); + } catch (IOException e) { + log.error("create file error " + targetFile, e); + throw DomainErrors.DOWNLOAD_DRIVER_ERROR.exception(e.getMessage()); + } + } + + // download + try { + return restTemplate.execute(driverFileUrl, HttpMethod.GET, null, response -> { + if (response.getStatusCode().is2xxSuccessful()) { + File file = path.toFile(); + FileOutputStream out = new FileOutputStream(file); + StreamUtils.copy(response.getBody(), out); + IOUtils.closeQuietly(out, ex -> log.error("close file error", ex)); + log.info("{} download success ", targetFile); + return file; + } else { + log.error("{} download error from {}: {} ", targetFile, driverFileUrl, response); + throw DomainErrors.DOWNLOAD_DRIVER_ERROR.exception("驱动下载失败:" + + response.getStatusCode() + + ", " + + response.getStatusText()); + } + }); + } catch (RestClientException e) { + log.error(targetFile + " download driver error", e); + throw DomainErrors.DOWNLOAD_DRIVER_ERROR.exception(e.getMessage()); + } + } + + public void validateJar(String driverFileUrl, String className) { + String tempFilePath = "temp/" + UUID.randomUUID() + ".jar"; + File driverFile = download(driverFileUrl, tempFilePath); + URLClassLoader loader = null; + try { + loader = new URLClassLoader( + new URL[]{driverFile.toURI().toURL()}, + this.getClass().getClassLoader() + ); + } catch (MalformedURLException e) { + log.error("load driver jar error ", e); + throw DomainErrors.DOWNLOAD_DRIVER_ERROR.exception(e.getMessage()); + } + + try { + Class clazz = Class.forName(className, false, loader); + boolean isValid = ClassUtils.getAllInterfaces(clazz) + .stream() + .anyMatch(cls -> cls.getName().equals("java.sql.Driver")); + if (!isValid) { + throw DomainErrors.DRIVER_CLASS_NOT_FOUND.exception("不合法的驱动类,请重新指定"); + } + } catch (ClassNotFoundException e) { + log.error("init driver error", e); + throw DomainErrors.DRIVER_CLASS_NOT_FOUND.exception("驱动初始化异常, 请检查驱动类名:" + e.getMessage()); + } finally { + IOUtils.closeQuietly(loader); + try { + Files.deleteIfExists(driverFile.toPath()); + } catch (IOException e) { + log.error("delete driver error " + tempFilePath, e); + } + } + } + private String driverFilePath(String baseDir, String databaseType) { String fileName = databaseType + ".jar"; String filePath; diff --git a/core/src/main/java/com/databasir/core/infrastructure/jwt/JwtTokens.java b/core/src/main/java/com/databasir/core/infrastructure/jwt/JwtTokens.java index 81fbcb7..e3e046f 100644 --- a/core/src/main/java/com/databasir/core/infrastructure/jwt/JwtTokens.java +++ b/core/src/main/java/com/databasir/core/infrastructure/jwt/JwtTokens.java @@ -5,6 +5,7 @@ import com.auth0.jwt.algorithms.Algorithm; import com.auth0.jwt.exceptions.JWTVerificationException; import com.auth0.jwt.interfaces.JWTVerifier; import lombok.extern.slf4j.Slf4j; +import org.springframework.beans.factory.annotation.Value; import org.springframework.stereotype.Component; import java.time.Instant; @@ -23,10 +24,11 @@ public class JwtTokens { private static final String ISSUER = "Databasir"; - private static final String SECRET = "Databasir2022"; + @Value("${databasir.jwt.secret}") + private String tokenSecret; public String accessToken(String username) { - Algorithm algorithm = Algorithm.HMAC256(SECRET); + Algorithm algorithm = Algorithm.HMAC256(tokenSecret); return JWT.create() .withExpiresAt(new Date(new Date().getTime() + ACCESS_EXPIRE_TIME)) @@ -36,7 +38,7 @@ public class JwtTokens { } public boolean verify(String token) { - JWTVerifier verifier = JWT.require(Algorithm.HMAC256(SECRET)) + JWTVerifier verifier = JWT.require(Algorithm.HMAC256(tokenSecret)) .withIssuer(ISSUER) .build(); try { diff --git a/core/src/test/java/com/databasir/core/domain/database/service/DatabaseTypeServiceTest.java b/core/src/test/java/com/databasir/core/domain/database/service/DatabaseTypeServiceTest.java index ed2d134..8d03586 100644 --- a/core/src/test/java/com/databasir/core/domain/database/service/DatabaseTypeServiceTest.java +++ b/core/src/test/java/com/databasir/core/domain/database/service/DatabaseTypeServiceTest.java @@ -5,14 +5,20 @@ import com.databasir.core.BaseTest; import com.databasir.core.domain.DomainErrors; import com.databasir.core.domain.database.data.DatabaseTypeCreateRequest; import com.databasir.core.domain.database.data.DatabaseTypeUpdateRequest; +import com.databasir.core.infrastructure.driver.DriverResources; import com.databasir.dao.impl.DatabaseTypeDao; import com.databasir.dao.tables.pojos.DatabaseTypePojo; import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; +import org.mockito.Mockito; import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.boot.test.mock.mockito.MockBean; import org.springframework.test.context.jdbc.Sql; import org.springframework.transaction.annotation.Transactional; +import static org.mockito.ArgumentMatchers.anyString; + @Transactional class DatabaseTypeServiceTest extends BaseTest { @@ -22,6 +28,14 @@ class DatabaseTypeServiceTest extends BaseTest { @Autowired private DatabaseTypeDao databaseTypeDao; + @MockBean + private DriverResources driverResources; + + @BeforeEach + public void setUp() { + Mockito.doNothing().when(driverResources).validateJar(anyString(), anyString()); + } + @Test void create() { DatabaseTypeCreateRequest request = new DatabaseTypeCreateRequest(); diff --git a/core/src/test/resources/application-ut.properties b/core/src/test/resources/application-ut.properties index e339f38..4547eda 100644 --- a/core/src/test/resources/application-ut.properties +++ b/core/src/test/resources/application-ut.properties @@ -15,4 +15,5 @@ spring.flyway.locations=classpath:db/migration databasir.db.url=localhost:3306 databasir.db.username=root databasir.db.password=123456 -databasir.db.driver-directory=drivers \ No newline at end of file +databasir.db.driver-directory=drivers +databasir.jwt.secret=DatabasirJwtSecret \ No newline at end of file