背景:随着业务的发展,我们同一套业务系统需支持提供给多个客户(不同的企业用户)使用,所以需确保在多用户环境下,各用户间数据的隔离。但目前系统在早期设计的时候没有考虑到多租户的情况,业务数据没有做到充分隔离(有些表做了字段区分,有些没有)。
目前数据访问层用的是MyBatis框架,sql语句散布在xml里,dao注解里,量非常大。另外,租户字段(companyId)定义也不是所有的业务实体类都有。
基于现状,一个个修改sql,这样工作量太大了,所以排除掉一个个修改sql的方案。只能考虑怎样可以统一修改sql。而租户字段(companyId)的传递也需要有统一处理的地方。
我们先给没有租户字段(companyId)的表加上字段。然后考虑怎样给字段统一添加值的改造。因为业务系统目前是使用Mybatis做持久化,Mybatis有拦截器的功能,是否可以通过自定义Mybatis拦截器拦截下所有的 sql 语句,然后对其进行动态修改,自动添加company_id 字段及其字段值,实现数据隔离呢?答案是肯定的。
先看下Mybatis的核心对象:
Mybatis核心对象 |
解释 |
SqlSession |
作为MyBatis工作的主要顶层API,表示和数据库交互的会话,完成必要数据库增删改查功能。 |
Executor |
MyBatis执行器,是MyBatis 调度的核心,负责SQL语句的生成和查询缓存的维护。 |
StatementHandler |
封装了JDBC Statement操作,负责对JDBC statement 的操作,如设置参数、将Statement结果集转换成List集合。 |
ParameterHandler |
负责对用户传递的参数转换成JDBC Statement 所需要的参数。 |
ResultSetHandler |
负责将JDBC返回的ResultSet结果集对象转换成List类型的集合。 |
TypeHandler |
负责JAVA数据类型和jdbc数据类型之间的映射和转换。 |
MAppedStatement |
MappedStatement维护了一条mapper.xml文件里面 select 、update、delete、insert节点的封装。 |
SqlSource |
负责根据用户传递的parameterObject,动态地生成SQL语句,将信息封装到BoundSql对象中。 |
BoundSql |
表示动态生成的SQL语句以及相应的参数信息。 |
Configuration |
MyBatis所有的配置信息都维持在Configuration对象。 |
Mybatis拦截器可以拦截Executor、ParameterHandler、StatementHandler、ResultSetHandler四个对象里面的方法。Executor是Mybatis的核心接口。Mybatis中所有的Mapper语句的执行都是通过Executor进行的。其中增删改语句是通过Executor接口的update方法,查询语句是通过query方法。所以我们可以拦截Executor,拦载所有的select 、insert、update、delete语句进行改造,添加company_id字段及字段值。
创建一个自定义的拦截器:
/**
* Mybatis - 通用拦截器。用于拦截sql并自动补充公共字段。包括query、insert、update、delete语句
*/
@Slf4j
@Intercepts(
{
@Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class}),
@Signature(type = Executor.class, method = "query", args = {MappedStatement.class, Object.class, RowBounds.class, ResultHandler.class, CacheKey.class, BoundSql.class}),
@Signature(type = Executor.class, method = "update", args = {MappedStatement.class, Object.class})
}
)
public class AutoFillParamInterceptor implements Interceptor {
private static final String LAST_INSERT_ID_SQL = "LAST_INSERT_ID()";
private static final String COMPANY_ID = "company_id";
/**
* 拦截主要的逻辑
* @param invocation
* @return
* @throws Throwable
*/
@Override
public Object intercept(Invocation invocation) throws Throwable {
final Object[] args = invocation.getArgs();
final MappedStatement ms = (MappedStatement) args[0];
final Object paramObj = args[1];
// 1.通过注解判断是否需要处理此SQL
String namespace = ms.getId();
String className = namespace.substring(0, namespace.lastIndexOf("."));
//selectByExample
String methodName = StringUtils.substringAfterLast(namespace, ".");
Class<?> classType = Class.forName(className);
if (classType.isAnnotationPresent(IgnoreAutoFill.class)) {
//注解在类上
String userType = classType.getAnnotation(IgnoreAutoFill.class).userType();
if (StringUtils.isNotBlank(userType)) {
//ignore特定的用户类型,其他均拦截
if (userType.equals(getCurrentUserType())) {
return invocation.proceed();
}
} else {
return invocation.proceed();
}
} else {
//注解在方法上
for (Method method : classType.getMethods()) {
if (!methodName.equals(method.getName())) {
continue;
} else {
if (method.isAnnotationPresent(IgnoreAutoFill.class)) {
String userType = method.getAnnotation(IgnoreAutoFill.class).userType();
if (StringUtils.isNotBlank(userType)) {
//ignore特定的用户类型,其他均拦截
if (userType.equals(getCurrentUserType())) {
return invocation.proceed();
}
} else {
return invocation.proceed();
}
}
break;
}
}
}
// 2.获取SQL语句
BoundSql boundSql = ms.getBoundSql(paramObj);
// 原始sql
String originalSql = boundSql.getSql();
log.debug("originalSql:{}", originalSql);
// 3.根据语句类型改造SQL语句
switch (ms.getSqlCommandType()) {
case INSERT: {
originalSql = convertInsertSQL(originalSql);
args[0] = newMappedStatement(ms, boundSql, originalSql, paramObj);
break;
}
case UPDATE:
case DELETE: {
originalSql = SQLUtils.addCondition(originalSql, COMPANY_ID + "='" + getCompanyId() +"'", null);
args[0] = newMappedStatement(ms, boundSql, originalSql, paramObj);
break;
}
case SELECT: {
if (!StringUtils.containsIgnoreCase(originalSql, LAST_INSERT_ID_SQL)) {
//where 条件拼接 companyId
MySQLStatementParser parser = new MySqlStatementParser(originalSql);
SQLStatement statement = parser.parseStatement();
SQLSelectStatement selectStatement = (SQLSelectStatement) statement;
SQLSelect sqlSelect = selectStatement.getSelect();
SQLSelectQuery query = sqlSelect.getQuery();
addSelectCondition(query, COMPANY_ID + "='" + getCompanyId() + "'");
originalSql = SQLUtils.toSQLString(selectStatement, JdbcConstants.MYSQL);
// 将新生成的MappedStatement对象替换到参数列表中
args[0] = newMappedStatement(ms, boundSql, originalSql, paramObj);
}
break;
}
}
log.debug("modifiedSql:{}", originalSql);
// 4.应用修改后的SQL语句
return invocation.proceed();
}
private void addSelectCondition(SQLSelectQuery query, String condition){
if (query instanceof SQLUnionQuery) {
SQLUnionQuery sqlUnionQuery = (SQLUnionQuery) query;
addSelectCondition(sqlUnionQuery.getLeft(), condition);
addSelectCondition(sqlUnionQuery.getRight(), condition);
} else if (query instanceof SQLSelectQueryBlock) {
SQLSelectQueryBlock selectQueryBlock = (SQLSelectQueryBlock) query;
SQLTableSource tableSource = selectQueryBlock.getFrom();
String conditionTmp = condition;
String alias = getLeftAlias(tableSource);
if (StringUtils.isNotBlank(alias)) {
//拼接别名
conditionTmp = alias + "." + condition;
}
SQLExpr conditionExpr = SQLUtils.toMySqlExpr(conditionTmp);
selectQueryBlock.addCondition(conditionExpr);
}
}
private String getLeftAlias(SQLTableSource tableSource) {
if (tableSource != null) {
if (tableSource instanceof SQLExprTableSource) {
if (StringUtils.isNotBlank(tableSource.getAlias())) {
return tableSource.getAlias();
}
} else if (tableSource instanceof SQLJoinTableSource) {
SQLJoinTableSource join = (SQLJoinTableSource) tableSource;
return getLeftAlias(join.getLeft());
}
}
return null;
}
/**
* 用于封装目标对象的,通过该方法我们可以返回目标对象本身,也可以返回一个它的代理
* @param target
* @return
*/
@Override
public Object plugin(Object target) {
//只拦截Executor对象,减少目标被代理的次数
if (target instanceof Executor) {
return Plugin.wrap(target, this);
}
return target;
}
/**
* 注册当前拦截器的时候可以设置一些属性
*/
@Override
public void setProperties(Properties properties) {
}
private String convertInsertSQL(String originalSql) {
MySqlStatementParser parser = new MySqlStatementParser(originalSql);
SQLStatement statement = parser.parseStatement();
MySqlSchemaStatVisitor visitor = new MySqlSchemaStatVisitor();
statement.accept(visitor);
MySqlInsertStatement myStatement = (MySqlInsertStatement) statement;
String tableName = myStatement.getTableName().getSimpleName();
List<SQLExpr> columns = myStatement.getColumns();
List<SQLInsertStatement.ValuesClause> vcl = myStatement.getValuesList();
if (columns == null || columns.size() <= 0 || myStatement.getQuery() != null) {
return originalSql;
}
if (!visitor.containsColumn(tableName, COMPANY_ID)) {
SQLExpr columnExpr = SQLUtils.toMySqlExpr(COMPANY_ID);
columns.add(columnExpr);
SQLExpr valuesExpr = SQLUtils.toMySqlExpr("'" + getCompanyId() + "'");
vcl.stream().forEach(v -> v.addValue(valuesExpr));
}
return SQLUtils.toSQLString(myStatement, JdbcConstants.MYSQL);
}
private MappedStatement newMappedStatement(MappedStatement ms, BoundSql boundSql,
String sql, Object parameter){
BoundSql newBoundSql = new BoundSql(ms.getConfiguration(),sql, new ArrayList(boundSql.getParameterMappings()), parameter);
for (ParameterMapping mapping : boundSql.getParameterMappings()) {
String prop = mapping.getProperty();
if (boundSql.hasAdditionalParameter(prop)) {
newBoundSql.setAdditionalParameter(prop, boundSql.getAdditionalParameter(prop));
}
}
return copyFromOriMappedStatement(ms, new WarpBoundSqlSqlSource(newBoundSql));
}
private MappedStatement copyFromOriMappedStatement(MappedStatement ms, SqlSource newSqlSource) {
MappedStatement.Builder builder = new MappedStatement.Builder(ms.getConfiguration(),ms.getId(),newSqlSource,ms.getSqlCommandType());
builder.cache(ms.getCache()).databaseId(ms.getDatabaseId())
.fetchSize(ms.getFetchSize())
.flushCacheRequired(ms.isFlushCacheRequired())
.keyColumn(StringUtils.join(ms.getKeyColumns(), ','))
.keyGenerator(ms.getKeyGenerator())
.keyProperty(StringUtils.join(ms.getKeyProperties(), ','))
.lang(ms.getLang()).parameterMap(ms.getParameterMap())
.resource(ms.getResource()).resultMaps(ms.getResultMaps())
.resultOrdered(ms.isResultOrdered())
.resultSets(StringUtils.join(ms.getResultSets(), ','))
.resultSetType(ms.getResultSetType()).statementType(ms.getStatementType())
.timeout(ms.getTimeout()).useCache(ms.isUseCache());
return builder.build();
}
static class WarpBoundSqlSqlSource implements SqlSource {
private final BoundSql boundSql;
public WarpBoundSqlSqlSource(BoundSql boundSql) {
this.boundSql = boundSql;
}
@Override
public BoundSql getBoundSql(Object parameterObject) {
return boundSql;
}
}
public String getCompanyId() {
//先从authenticationFacade取
String companyId = CompanyContext.getCompanyId();
if(StringUtils.isBlank(companyId)){
log.error("Can not get the companyId! {}", companyId);
throw new RuntimeException("Can not get the companyId! " + companyId);
}
return companyId;
}
public String getCurrentUserType() {
//authenticationFacade取
AuthenticationFacade authenticationFacade = ApplicationContextProvider.getBean(AuthenticationFacade.class);
Integer currentUserType = authenticationFacade.getCurrentUserType();
if (currentUserType == null) {
log.error("Can not get the currentUserType! {}", currentUserType);
throw new RuntimeException("Can not get the currentUserType! " + currentUserType);
}
UserTypeEnum userTypeEnum = UserTypeEnum.getByCode(currentUserType);
return userTypeEnum.getUserType();
}
}
虽然大部分sql都需要做条件过滤,但也有些特殊情况某些sql可能不需要过滤companyId条件,所以增加一个注解,如果不需要拦截的sql可以在Mapper类或方法上添加此注解,这样可以兼容不需要拦截的方法。
添加 IgnoreAutoFill 注解:
/**
* 用于标注在不需要被拦截器处理的SQL上(Mapper类)
*/
@Target({ElementType.METHOD, ElementType.TYPE})
@Retention(RetentionPolicy.RUNTIME)
public @interface IgnoreAutoFill {
String userType() default "";
}
Mapper示例:
public interface PostRecordDOMapper {
long countByExample(PostRecordDOExample example);
int deleteByExample(PostRecordDOExample example);
int deleteByPrimaryKey(Long id);
int insert(PostRecordDO record);
int insertSelective(PostRecordDO record);
List<PostRecordDO> selectByExample(PostRecordDOExample example);
@IgnoreAutoFill
List<PostRecordDO> selectByExampleAllCompany(PostRecordDOExample example);
PostRecordDO selectByPrimaryKey(Long id);
int updateByExampleSelective(@Param("record") PostRecordDO record, @Param("example") PostRecordDOExample example);
int updateByExample(@Param("record") PostRecordDO record, @Param("example") PostRecordDOExample example);
int updateByPrimaryKeySelective(PostRecordDO record);
int updateByPrimaryKey(PostRecordDO record);
void batchInsert(@Param("items") List<PostRecordDO> items);
}
在拦截器中,我们使用阿里的druid做sql解析,修改sql。
加入 druid 依赖:
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>druid</artifactId>
<version>1.1.6</version>
</dependency>
拦截修改sql时,对于select、update、delete语句,我们直接添加company_id条件,对于insert语名,先判断原sql的参数列表里有没有company_id字段,如果有的话不作处理(说明原来就做了字段隔离),没有才自动给它添加company_id字段及值。
至此,我们解决了统一修改sql的问题,那还有一个重要问题,填充的字段值从哪里取得呢?因为调用持久层Mapper类方法的入参并不一定带有租户字段(companyId)信息过来,有些方法甚至只会传一个id的参数,像 deleteByPrimaryKey(Long id);selectByPrimaryKey(Long id);即使是传对象参数,对象实体类也不一定有租户字段(companyId)。所以如何传递租户字段(companyId)是一个改造难点。
考虑一翻,我们是否可以用 ThreadLocal 来存取呢?答案是肯定的。
要传递多租户字段(companyId)值,得先取得companyId值。因为每一个系统用户都有所属的companyId,所以只要在用户登录系统的时候,从token中拿到用户所属的companyId,然后set进ThreadLocal。后续线程的处理都可以从ThreadLocal中取得companyId。这样Mybatis拦截器也就随时都可以取得companyId的值进行sql参数或者条件的拼接改造了。
多租户上下文信息:
@Slf4j
public class CompanyContext implements AutoCloseable {
private static final TransmittableThreadLocal<String> COMPANY_ID_CTX = new TransmittableThreadLocal<>();
public CompanyContext(String companyId) {
COMPANY_ID_CTX.remove();
COMPANY_ID_CTX.set(companyId);
}
public static String getCompanyId(){
return COMPANY_ID_CTX.get();
}
@Override
public void close() throws Exception {
COMPANY_ID_CTX.remove();
}
public static void remove(){
COMPANY_ID_CTX.remove();
}
}
但是,系统的业务处理不可能只用一个线程从头处理到结束,很多时候为了加快业务的处理,都是需要用到线程池的。
那么,问题又来了,不同线程间如何将这个companyId的ThreadLocal值传递下去呢?
这也是有解决方案的。
Transmittable ThreadLocal
Alibaba 有一个 Transmittable ThreadLocal 库,提供了一个TransmittableThreadLocal,它是 ThreadLocal 的一个扩展,提供了将变量的值从一个线程传递到另一个线程的能力。当一个任务被提交到线程池时,TransmittableThreadLocal 变量的值被捕获并传递给执行任务的工作线程。这确保了正确的值在工作线程中可用,即使它最初在不同的线程中设置。
使用Transmittable ThreadLocal 库,需引入依赖:
<dependency>
<groupId>com.alibaba</groupId>
<artifactId>transmittable-thread-local</artifactId>
<version>2.11.5</version>
</dependency>
使用的时候,调用一下TtlExecutors工具提供的getTtlExecutor静态方法,传入一个Executor,即可获取一个支持 TTL (TransmittableThreadLocal)传递的 Executor 实例,此线程池就确保了上下文信息的正确传递,可放心使用了,如下所示:
@Bean(name = "exportDataExecutorPool")
public Executor exportDataExecutorPool() {
ThreadPoolTaskExecutor threadPoolTaskExecutor = new ThreadPoolTaskExecutor();
threadPoolTaskExecutor.setCorePoolSize(CPU_NUM);
threadPoolTaskExecutor.setMaxPoolSize(CPU_NUM * 2);
threadPoolTaskExecutor.setKeepAliveSeconds(60);
threadPoolTaskExecutor.setQueueCapacity(100);
threadPoolTaskExecutor.setThreadNamePrefix("ExportData Thread-");
threadPoolTaskExecutor.setRejectedExecutionHandler(new ThreadPoolExecutor.CallerRunsPolicy());
threadPoolTaskExecutor.initialize();
return TtlExecutors.getTtlExecutor(threadPoolTaskExecutor);
}
这样就可以确保线程池的线程随时可以都取到正确的companyId了。
至此,是不是就完成了改造了呢?
还没有。
为什么呢?
如果是同一个JVM确实是没问题了,如果不同的JVM呢?
一般较为复杂的系统都会按业务划分成不同的模块,同一个模块也可能部署多个不同的实例,这些不同的模块或不同的实例间的通信一般是通过远程调用或者消息队列进行数据传递。那么问题就来了,如何在不同的模块或实例间传递这个companyId呢?
目前我们系统的远程调用用的是RestTemplate,消息队列用的Kafka。那就要考虑怎么把companyId统一传递出去了。
远程调用 RestTemplate 的改造
@Slf4j
public class BearerTokenHeaderInterceptor implements ClientHttpRequestInterceptor {
public BearerTokenHeaderInterceptor() {
}
@Override
public ClientHttpResponse intercept(HttpRequest request, byte[] body,
ClientHttpRequestExecution execution) throws IOException {
//通过拦截器统一把companyId放到header
String companyId = CompanyContext.getCompanyId();
log.info("companyId={}", companyId);
if (!StringUtils.isEmpty(companyId)) {
request.getHeaders().set("companyId", companyId);
}
return execution.execute(request, body);
}
}
注意创建 RestTemplate 时需要把这个拦截器加进去:
@Bean
@LoadBalanced
public RestTemplate restTemplate(RestTemplateBuilder restTemplateBuilder) {
final RestTemplate restTemplate = restTemplateBuilder
.setConnectTimeout(Duration.ofMillis(getConnectTimeout()))
.setReadTimeout(Duration.ofMillis(getReadTimeout()))
.requestFactory(()->httpRequestFactory())
.build();
List<ClientHttpRequestInterceptor> interceptors = restTemplate.getInterceptors();
if (interceptors == null) {
interceptors = Collections.emptyList();
}
interceptors = new ArrayList<>(interceptors);
interceptors.removeIf(BearerTokenHeaderInterceptor.class::isInstance);
interceptors.add(new BearerTokenHeaderInterceptor());
restTemplate.setInterceptors(interceptors);
return restTemplate;
}
@Slf4j
public class TokenParseAndLoginFilter extends OncePerRequestFilter {
@Override
protected void doFilterInternal(HttpServletRequest request,
HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException {
String accessToken = null;
String companyId = null;
try {
//从header取得并设置companyId本地变量
companyId = request.getHeader("companyId");
new CompanyContext(companyId);
filterChain.doFilter(request, response);
} catch (Exception e) {
log.error("request error:",e);
response.setContentType(MediaType.APPLICATION_JSON_UTF8_VALUE);
response.setStatus(500);
response.getWriter().write(e.getMessage());
response.getWriter().close();
}
}
}
消息队列 kafka 的改造
/**
* 发送消息
*/
public void sendMsg(String topic, Object value, Map<String, String> headers) {
RecordHeaders kafkaHeaders = new RecordHeaders();
headers.forEach((k,v)->{
RecordHeader recordHeader = new RecordHeader(k,v.getBytes());
kafkaHeaders.add(recordHeader);
});
RecordHeader recordHeader = new RecordHeader("companyId", CompanyContext.getCompanyId().getBytes());
kafkaHeaders.add(recordHeader);
//kafka默认分区
ProducerRecord<String, String> producerRecord = new ProducerRecord<String, String>(topic, null, null, JsonUtil.toJson(value), kafkaHeaders);
kafkaTemplate.send(producerRecord);
}
/**
* 获取实例-手动处理ack
*/
@Bean
public KafkaListenerContainerFactory<ConcurrentMessageListenerContainer<String, String>> kafkaManualAckListenerContainerFactory() {
ConcurrentKafkaListenerContainerFactory<String, String> factory = new ConcurrentKafkaListenerContainerFactory<>();
factory.setConsumerFactory(consumerFactory());
factory.setConcurrency(concurrency);
factory.getContainerProperties().setPollTimeout(3000);
//RetryingAcknowledgingMessageListenerAdapter
factory.getContainerProperties().setAckMode(ContainerProperties.AckMode.MANUAL);
factory.setRetryTemplate(retryTemplate);
factory.setRecoveryCallback(recoveryCallback());
factory.setRecordFilterStrategy(consumerRecord -> {
String companyId = getHead(consumerRecord, "company_id");
// 设置companyId本地变量
new CompanyContext(companyId);
logger.info("Getting the company from kafka message header : {}", companyId);
if(needRequestId) {
String requestId = getHead(consumerRecord, KafkaHeadEnum.REQUEST_ID.getKey());
new RequestIdContext(requestId);
}
return false;
});
return factory;
}
至此,我们就完成了多租户数据隔离的改造。