Spring手动扫描包路径及容器外获取Bean实例
Spring内扫描指定包下的类(@ComponentScan原理);容器外获取Spring容器内的Bean实例;HttpServletRequest封装类,支持RequestBody的多次读取
·
最近做的项目有一个需求,希望开放指定包下的Controller给其他应用调用,但需要验证其许可。
解决方案:定义一个Filter,在init初始化方法内扫描指定包下的所有Controller,生成开放URL集合;在doFilter方法内对请求参数进行校验(校验所用的token由加盐MD5生成)。
方案用到了两个工具类。第一个是HttpServletRequest的包装类,主要是为了解决RequestBody无法重复读取的问题:正常情况下HttpServletRequest的输入流只能被读取一次。
/**
 * {@link HttpServletRequest} wrapper that caches the request body so it can be read more than once.
 *
 * <p>The body is copied into memory once, in the constructor. Every later call to
 * {@link #getInputStream()} or {@link #getReader()} is served from that copy and starts
 * again at the first byte.
 *
 * @create 2017-12-11 17:18
 */
public class HttpRequestTwiceReadingWrapper extends HttpServletRequestWrapper {

    /** In-memory copy of the request body; never {@code null}. */
    private final byte[] requestBody;

    /**
     * Wraps {@code request} and eagerly drains its input stream into memory.
     *
     * @param request the request whose body should be cached
     */
    public HttpRequestTwiceReadingWrapper(HttpServletRequest request) {
        super(request);
        byte[] body;
        try {
            body = StreamUtils.copyToByteArray(request.getInputStream());
        } catch (IOException e) {
            // Best effort: report the failure and continue with an empty body.
            e.printStackTrace();
            body = new byte[0];
        }
        this.requestBody = body;
    }

    /**
     * Returns a fresh stream over the cached body.
     *
     * @return a new stream positioned at the first byte of the cached body
     * @throws IOException declared by the overridden method; not thrown by this implementation
     */
    @Override
    public ServletInputStream getInputStream() throws IOException {
        final ByteArrayInputStream cached = new ByteArrayInputStream(requestBody);
        return new ServletInputStream() {
            @Override
            public boolean isFinished() {
                // Finished once every cached byte has been consumed.
                return cached.available() == 0;
            }

            @Override
            public boolean isReady() {
                // The data is already in memory, so a read never blocks.
                return true;
            }

            @Override
            public void setReadListener(ReadListener readListener) {
                // Intentionally a no-op: the listener is never invoked by this stream.
            }

            @Override
            public int read() throws IOException {
                return cached.read();
            }

            @Override
            public int read(byte[] b, int off, int len) throws IOException {
                // Bulk copy instead of the inherited byte-by-byte loop.
                return cached.read(b, off, len);
            }
        };
    }

    /**
     * Returns a fresh reader over the cached body.
     *
     * <p>The body is decoded with the request's declared character encoding when there is one,
     * and with the platform default charset otherwise.
     *
     * @return a new buffered reader positioned at the start of the cached body
     * @throws IOException if the declared character encoding is not supported
     */
    @Override
    public BufferedReader getReader() throws IOException {
        String encoding = getCharacterEncoding();
        InputStreamReader decoder = (encoding != null)
                ? new InputStreamReader(getInputStream(), encoding)
                : new InputStreamReader(getInputStream());
        return new BufferedReader(decoder);
    }
}
第二个工具类是Spring外获取容器内的Bean的类。
/**
 * Static access point that lets objects living outside the Spring container look up beans.
 *
 * <p>The class is itself a Spring component; the container hands it the {@link BeanFactory}
 * through {@link BeanFactoryAware}, and the static lookup methods delegate to that factory.
 *
 * @create 2017-12-12 14:19
 */
@Component
public class SpringBeanInstanceAccessor implements BeanFactoryAware {

    // @Autowired cannot inject static fields, so the factory is captured through BeanFactoryAware instead.
    private static BeanFactory factory;

    /**
     * Stores the bean factory supplied by the container.
     *
     * @param beanFactory the owning bean factory
     * @throws BeansException declared by the interface; not thrown by this implementation
     */
    @Override
    public void setBeanFactory(BeanFactory beanFactory) throws BeansException {
        factory = beanFactory;
    }

    /**
     * Looks up a bean by name and required type.
     *
     * @param beanName name of the bean
     * @param clazz    type the bean must match
     * @param <T>      bean type
     * @return the bean instance
     * @throws IllegalStateException if no bean factory has been injected yet
     */
    public static <T> T getBean(String beanName, Class<T> clazz) {
        return requireFactory().getBean(beanName, clazz);
    }

    /**
     * Looks up a bean by type.
     *
     * @param clazz type the bean must match
     * @param <T>   bean type
     * @return the bean instance
     * @throws IllegalStateException if no bean factory has been injected yet
     */
    public static <T> T getBean(Class<T> clazz) {
        return requireFactory().getBean(clazz);
    }

    /** Returns the injected factory, failing with a clear message instead of a bare NPE. */
    private static BeanFactory requireFactory() {
        BeanFactory current = factory;
        if (current == null) {
            throw new IllegalStateException(
                    "SpringBeanInstanceAccessor has not been initialized by the Spring container yet");
        }
        return current;
    }
}
最后一个就是Filter了:扫描指定包下的Controller生成URL集合,获取RequestBody内的参数,通过Spring的Mapper查询参数对应的key,拼接 key + "," + Referer + "," + RequestBody,再通过加盐MD5生成token,最后与HttpHeader内传入的token进行比对验证。
/**
* 请求许可验证过滤器
* [token]和[timestamp]由HttpHeader传入
*
* @create 2017-12-08 11:26
*/
public class CrosRequestPermitCheckingFilter implements Filter {
private static final Logger LOGGER = LoggerFactory.getLogger(CrosRequestPermitCheckingFilter.class);
private static final String RESOURCE_PATTERN = "**/*.class";
private final List<TypeFilter> includeFilters = new LinkedList<TypeFilter>();
private final List<TypeFilter> excludeFilters = new LinkedList<TypeFilter>();
private static final String APPLICATION_NAMESPACE = "";
private static final Integer PERMIT_VALIDITY_IN_MINUTE = 5;
private static final List<String> EXPOSED_CONTROLLER_PACKAGES = Arrays.asList("com.bob.mvc.controller");
private Set<String> exposedRequestUriSet = new LinkedHashSet<String>();
/**
 * Scans the exposed controller packages and fills the set of exposed request URIs.
 *
 * <p>Any exception raised while scanning is logged and swallowed, so the filter still starts.
 * NOTE(review): in that case (and when no controller is found) the URI set stays empty and
 * {@code doFilter} lets every request through unchecked - confirm this fail-open behavior is intended.
 *
 * @param filterConfig filter configuration supplied by the container (not used)
 * @throws ServletException declared by the interface; not thrown by this implementation
 */
@Override
public void init(FilterConfig filterConfig) throws ServletException {
    includeFilters.add(new AnnotationTypeFilter(RequestMapping.class, false));
    try {
        List<String> controllerNames = new ArrayList<String>();
        for (String exposedPackage : EXPOSED_CONTROLLER_PACKAGES) {
            controllerNames.addAll(this.findCandidateControllers(exposedPackage));
        }
        if (!controllerNames.isEmpty()) {
            generateExposedURL(this.transformToClass(controllerNames), APPLICATION_NAMESPACE);
            if (LOGGER.isDebugEnabled()) {
                LOGGER.debug("扫描指定Controller包,发现开放URL:{}", exposedRequestUriSet.toString());
            }
        } else if (LOGGER.isWarnEnabled()) {
            LOGGER.warn("扫描指定包{}时未发现符合的开放Controller类", EXPOSED_CONTROLLER_PACKAGES.toString());
        }
    } catch (Exception e) {
        LOGGER.error("扫描开放Controller出现异常", e);
    }
}
/**
 * Applies the permit check to requests whose URI is in the exposed set.
 *
 * <p>Requests for other URIs are forwarded untouched, with their body still unread.
 * For exposed URIs the request is wrapped so the body can be read here and again downstream.
 * If the check fails, the failure message is written to the response and the chain is
 * NOT invoked.
 *
 * @param servletRequest  incoming request; cast to {@link HttpServletRequest}
 * @param servletResponse response the rejection message is written to
 * @param filterChain     remaining filter chain
 * @throws IOException      if reading the request or writing the response fails
 * @throws ServletException if a downstream filter or servlet fails
 */
@Override
public void doFilter(ServletRequest servletRequest, ServletResponse servletResponse, FilterChain filterChain)
        throws IOException, ServletException {
    HttpServletRequest original = (HttpServletRequest) servletRequest;
    String path = original.getRequestURI();
    if (path.endsWith("/")) {
        path = path.substring(0, path.length() - 1);
    }
    if (!exposedRequestUriSet.contains(path)) {
        // Decide before wrapping: the wrapper drains the body, which would leave the
        // original request with nothing for downstream components to read.
        filterChain.doFilter(servletRequest, servletResponse);
        return;
    }
    HttpServletRequest request = new HttpRequestTwiceReadingWrapper(original);
    try {
        processCrosRequestPermitCkecking(request);
    } catch (IllegalArgumentException | IllegalStateException e) {
        writeResult(servletResponse, e.getMessage());
        // Rejected: stop here so the protected controller is never reached.
        return;
    }
    filterChain.doFilter(request, servletResponse);
}
/**
 * Validates the permit carried by a cross-origin request.
 *
 * <p>Steps: the {@code timestamp} header must be present, numeric and not older than
 * {@code PERMIT_VALIDITY_IN_MINUTE} minutes; {@code appcode} and {@code campusId} are read
 * from the request body and used to look up the key; finally the {@code token} header is
 * verified against {@code key + "," + Referer + "," + requestBody}.
 *
 * <p>NOTE(review): only a lower bound is enforced on the timestamp, so a timestamp in the
 * future is accepted - confirm whether an upper bound is wanted.
 *
 * @param request request whose body can be read here without consuming it for later readers
 * @throws IOException              if the request body cannot be read
 * @throws IllegalArgumentException if a required value is missing or malformed
 * @throws IllegalStateException    if the timestamp is invalid or expired, or the token does not match
 */
private void processCrosRequestPermitCkecking(HttpServletRequest request) throws IOException {
    String timestamp = request.getHeader("timestamp");
    Assert.hasText(timestamp, "跨域请求未指定[timestamp]");
    Assert.state(isNumber(timestamp), "[timestamp]不是一个有效的时间戳");
    long issuedAt;
    try {
        issuedAt = Long.parseLong(timestamp);
    } catch (NumberFormatException e) {
        // All digits but too large for a long: report it like any other invalid timestamp.
        throw new IllegalStateException("[timestamp]不是一个有效的时间戳", e);
    }
    long earliestAccepted = System.currentTimeMillis() - 1000L * 60L * PERMIT_VALIDITY_IN_MINUTE;
    Assert.state(issuedAt >= earliestAccepted, "跨域请求许可已过期");

    String referer = request.getHeader("Referer");
    String requestBody = getRequestBodyInString(request);

    String appcode = getTargetProperty(requestBody, "appcode");
    Assert.hasText(appcode, "跨域请求[appcode]不存在");

    String campusId = getTargetProperty(requestBody, "campusId");
    // Check null/empty explicitly so a missing campusId is rejected with the proper message
    // instead of failing later with a NullPointerException or NumberFormatException.
    Assert.isTrue(campusId != null && !campusId.isEmpty() && isNumber(campusId), "跨域请求[campusId]不正确");

    String key = selectKey(appcode, campusId);
    Assert.notNull(key, "跨域请求[appcode]和[campusId]相应的key不存在");

    String signedContent = key + "," + referer + "," + requestBody;
    Assert.state(verify(signedContent, request.getHeader("token")), "跨域请求[token]和参数不匹配");
}
/**
 * Writes a rejection message to the response body as UTF-8 text.
 *
 * <p>The charset is declared on the response so clients decode the UTF-8 bytes correctly.
 *
 * @param servletResponse response to write to
 * @param result          message to send; {@code null} is written as an empty body
 * @throws IOException if the response output stream cannot be written
 */
private void writeResult(ServletResponse servletResponse, String result) throws IOException {
    String message = (result == null) ? "" : result;
    servletResponse.setCharacterEncoding("UTF-8");
    servletResponse.setContentType("text/plain;charset=UTF-8");
    servletResponse.getOutputStream().write(message.getBytes("UTF-8"));
}
/**
 * Reads the whole request body and decodes it as UTF-8.
 *
 * <p>The stream is read until end-of-stream, so bodies larger than one buffer are not truncated.
 *
 * @param request request to read the body from
 * @return the body text, or {@code ""} when the body is empty
 * @throws IOException if the body cannot be read
 */
private String getRequestBodyInString(HttpServletRequest request) throws IOException {
    InputStream is = request.getInputStream();
    java.io.ByteArrayOutputStream collected = new java.io.ByteArrayOutputStream();
    byte[] chunk = new byte[2048];
    int count;
    while ((count = is.read(chunk)) != -1) {
        collected.write(chunk, 0, count);
    }
    return collected.toString("UTF-8");
}
/**
 * Extracts the value of a property from a JSON-like string.
 *
 * <p>The property name must match a whole key (quotes around the key are optional). A quoted
 * value is returned without its quotes and may contain commas or colons; an unquoted value
 * runs up to the next comma, closing brace/bracket or whitespace. The first match wins.
 *
 * <p>This is a lightweight extraction, not a JSON parser: escape sequences inside the value
 * are returned as written and nesting is not interpreted.
 *
 * @param json     text to search; may be {@code null}
 * @param property name of the property to look for
 * @return the trimmed value, or {@code null} if the input is empty or the property is absent
 */
private String getTargetProperty(String json, String property) {
    if (json == null || json.isEmpty() || property == null || property.isEmpty()) {
        return null;
    }
    java.util.regex.Pattern keyValue = java.util.regex.Pattern.compile(
            // The key must start at the beginning or right after '{', ',' or whitespace.
            "(?:^|[{,\\s])\"?" + java.util.regex.Pattern.quote(property) + "\"?\\s*:\\s*"
                    // group 1: quoted value; group 2: bare value
                    + "(?:\"((?:[^\"\\\\]|\\\\.)*)\"|([^,}\\]\\s]+))");
    java.util.regex.Matcher matcher = keyValue.matcher(json);
    if (!matcher.find()) {
        return null;
    }
    String quoted = matcher.group(1);
    return (quoted != null ? quoted : matcher.group(2)).trim();
}
/**
 * Checks a salted-MD5 token against the plain text it is supposed to sign.
 *
 * <p>The token is read as 16 groups of 3 characters: the first and third character of each
 * group form the 32-character hex digest, the middle character of each group forms the
 * 16-character salt. The token matches when the digest equals
 * {@code md5Hex(password + salt)}.
 *
 * @param password plain text that was signed
 * @param md5      token to verify; characters beyond the first 48 are ignored
 * @return {@code true} if the token matches; {@code false} if either argument is empty,
 *         the token is shorter than 48 characters, or the digest differs
 */
private boolean verify(String password, String md5) {
    if (StringUtils.isEmpty(password) || StringUtils.isEmpty(md5)) {
        return false;
    }
    if (md5.length() < 48) {
        // Too short to hold digest + salt; reject instead of failing on charAt().
        return false;
    }
    char[] digestChars = new char[32];
    char[] saltChars = new char[16];
    for (int group = 0; group < 16; group++) {
        int offset = group * 3;
        digestChars[group * 2] = md5.charAt(offset);
        saltChars[group] = md5.charAt(offset + 1);
        digestChars[group * 2 + 1] = md5.charAt(offset + 2);
    }
    String expected = md5Hex(password + new String(saltChars));
    if (expected == null) {
        return false;
    }
    // Constant-time comparison so response timing does not leak how many characters matched.
    return MessageDigest.isEqual(
            new String(digestChars).getBytes(java.nio.charset.StandardCharsets.UTF_8),
            expected.getBytes(java.nio.charset.StandardCharsets.UTF_8));
}
/**
 * Computes the MD5 digest of a string and returns it as a hexadecimal string.
 *
 * <p>The input is encoded as UTF-8 so the digest does not depend on the platform default charset.
 *
 * @param src text to digest
 * @return the hex-encoded digest, or {@code null} if {@code src} is {@code null} or the MD5
 *         algorithm is unavailable
 */
private String md5Hex(String src) {
    if (src == null) {
        return null;
    }
    try {
        MessageDigest md5 = MessageDigest.getInstance("MD5");
        byte[] digest = md5.digest(src.getBytes(java.nio.charset.StandardCharsets.UTF_8));
        return new String(new Hex().encode(digest), java.nio.charset.StandardCharsets.US_ASCII);
    } catch (java.security.NoSuchAlgorithmException e) {
        LOGGER.error("MD5 digest algorithm is unavailable", e);
        return null;
    }
}
/**
 * Tells whether a string consists solely of digit characters.
 *
 * @param value text to check; may be {@code null}
 * @return {@code true} if {@code value} is non-empty and every character is a digit according
 *         to {@link Character#isDigit(char)}; {@code false} for {@code null} or an empty string
 */
private boolean isNumber(String value) {
    if (value == null || value.isEmpty()) {
        return false;
    }
    for (int i = 0; i < value.length(); i++) {
        if (!Character.isDigit(value.charAt(i))) {
            return false;
        }
    }
    return true;
}
/**
 * Collects the names of the classes under a package that pass {@code isCandidateController}.
 *
 * <p>Every {@code .class} resource below the package (sub-packages included) is resolved from
 * the classpath and inspected through its metadata reader.
 *
 * @param basePackage dot-separated package name to scan
 * @return fully qualified names of the matching classes, in resolution order
 * @throws IOException if resources cannot be resolved or class metadata cannot be read
 */
private List<String> findCandidateControllers(String basePackage) throws IOException {
    if (LOGGER.isDebugEnabled()) {
        LOGGER.debug("开始扫描包[" + basePackage + "]下的所有类");
    }
    String searchPath = CLASSPATH_ALL_URL_PREFIX + replaceDotByDelimiter(basePackage) + '/' + RESOURCE_PATTERN;
    ResourceLoader loader = new DefaultResourceLoader();
    MetadataReaderFactory metadataFactory = new SimpleMetadataReaderFactory(loader);
    Resource[] classFiles = ResourcePatternUtils.getResourcePatternResolver(loader).getResources(searchPath);

    List<String> matched = new ArrayList<String>();
    for (Resource classFile : classFiles) {
        MetadataReader metadata = metadataFactory.getMetadataReader(classFile);
        if (!isCandidateController(metadata, metadataFactory)) {
            continue;
        }
        String className = metadata.getClassMetadata().getClassName();
        matched.add(className);
        if (LOGGER.isDebugEnabled()) {
            LOGGER.debug("扫描到符合要求开放Controller类:[" + className + "]");
        }
    }
    return matched;
}
/**
 * Decides whether a scanned class qualifies as an exposed controller.
 *
 * <p>A class matching any exclude filter is rejected; otherwise it is accepted only if it
 * matches at least one include filter.
 *
 * @param reader        metadata reader of the class being inspected
 * @param readerFactory factory the filters may use to read metadata of related classes
 * @return {@code true} if the class passes the filters
 * @throws IOException if a filter fails to read class metadata
 */
protected boolean isCandidateController(MetadataReader reader, MetadataReaderFactory readerFactory) throws IOException {
    if (matchesAnyFilter(this.excludeFilters, reader, readerFactory)) {
        return false;
    }
    return matchesAnyFilter(this.includeFilters, reader, readerFactory);
}

/** Returns {@code true} as soon as one of {@code filters} matches the class behind {@code reader}. */
private boolean matchesAnyFilter(List<TypeFilter> filters, MetadataReader reader, MetadataReaderFactory readerFactory)
        throws IOException {
    for (TypeFilter candidate : filters) {
        if (candidate.match(reader, readerFactory)) {
            return true;
        }
    }
    return false;
}
/**
 * Loads the classes with the given names through this filter's class loader.
 *
 * @param classNames fully qualified class names
 * @return the loaded classes, in the same order as {@code classNames}
 * @throws ClassNotFoundException if one of the classes cannot be found
 */
private List<Class<?>> transformToClass(List<String> classNames) throws ClassNotFoundException {
    ClassLoader loader = this.getClass().getClassLoader();
    List<Class<?>> loaded = new ArrayList<Class<?>>(classNames.size());
    for (String name : classNames) {
        Class<?> type = ClassUtils.forName(name, loader);
        loaded.add(type);
    }
    return loaded;
}
/**
 * Converts a dot-separated package name into a slash-separated resource path.
 *
 * @param path package name, for example {@code com.bob.mvc}
 * @return {@code path} with every {@code "."} replaced by {@code "/"};
 *         {@code null} when {@code path} is {@code null}
 */
private String replaceDotByDelimiter(String path) {
    if (path == null) {
        return null;
    }
    return path.replace(".", "/");
}
/**
 * Inspects the controllers and registers every URL they expose.
 *
 * <p>For each method carrying {@link RequestMapping}, every combination of a class-level path
 * and a method-level path is added to the exposed URI set as
 * {@code prefix + classPath + methodPath}. A missing or empty {@code value()} counts as the
 * single path {@code ""}.
 *
 * <p>NOTE(review): only {@code value()} of a directly present {@code @RequestMapping} is read;
 * paths declared some other way (e.g. through a composed annotation) are presumably not
 * picked up - verify against the controllers in use.
 *
 * @param controllers controller classes to inspect
 * @param prefix      text prepended to every generated URL
 */
private void generateExposedURL(List<Class<?>> controllers, String prefix) {
    for (Class<?> controller : controllers) {
        RequestMapping classMapping = controller.getAnnotation(RequestMapping.class);
        String[] classPaths = pathsOrRoot(classMapping == null ? null : classMapping.value());
        ReflectionUtils.doWithMethods(controller,
            (method) -> {
                String[] methodPaths = pathsOrRoot(method.getAnnotation(RequestMapping.class).value());
                // Register every class-path/method-path combination, not just the first of each.
                for (String classPath : classPaths) {
                    for (String methodPath : methodPaths) {
                        exposedRequestUriSet.add(prefix + classPath + methodPath);
                    }
                }
            },
            (method) -> method.isAnnotationPresent(RequestMapping.class)
        );
    }
}

/** Returns the declared paths, or a single empty path when none are declared. */
private String[] pathsOrRoot(String[] declared) {
    return (declared == null || declared.length == 0) ? new String[] {""} : declared;
}
/**
 * Looks up the signing key through the {@code CampusMapper} bean obtained from the Spring container.
 *
 * <p>The lookup is done by {@code campusId} only; {@code appcode} is accepted but not used
 * by this implementation.
 *
 * @param appcode  application code from the request (currently unused)
 * @param campusId campus id from the request; must be parseable as an integer
 * @return whatever {@code CampusMapper.selectKey} returns for the campus id
 * @throws NumberFormatException if {@code campusId} is not a valid integer
 */
public String selectKey(String appcode, String campusId) {
    CampusMapper campusMapper = (CampusMapper) SpringBeanInstanceAccessor.getBean(CampusMapper.class);
    return campusMapper.selectKey(Integer.valueOf(campusId));
}
/**
 * Picks the path to use from the {@code value()} of a {@linkplain RequestMapping} annotation.
 *
 * <p>Only the first declared value is used; when none is declared the empty string stands in.
 *
 * @param mappings values taken from {@link RequestMapping#value()}; may be {@code null} or empty
 * @return the first mapping, or {@code ""} if there is none
 */
private String transformMappings(String[] mappings) {
    if (ObjectUtils.isEmpty(mappings)) {
        return "";
    }
    return mappings[0];
}
/**
 * Called when the filter is taken out of service; this implementation does nothing.
 */
@Override
public void destroy() {
    // Intentionally empty.
}
}
更多推荐
已为社区贡献6条内容
所有评论(0)