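Background

Our business needs long-lived connections for real-time data exchange, and the existing business systems are all dispatched through a gateway. Because the websocket service is stateful, we want to register it in the same registry so that the gateway can load-balance, forward, and manage the long-lived connections. This demo was written to test that idea.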

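Approach

Messages can be sent both through an HTTP request API and over the long-lived connection.

  • A client must log in first to obtain a token, which is used for authentication and for looking up user information. When the client opens a long-lived connection, the gateway authenticates the token, dispatches the connection, and records which netty service the user is connected to.
  • The netty service acts only as a messaging component. Business logic is pushed back down to the business systems, which communicate with it over MQ, so that heavy processing in the messaging service does not block message delivery.
  • MQ routing rules: each business system subscribes to its own business queue, and the netty service delivers to a specific MQ queue depending on the user's operation. When a business system needs to push to a client, it first looks up the routing queue key from the user connection information managed by the gateway and publishes to MQ; each netty service listens on its own queue.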
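The routing rule in the last bullet can be illustrated in plain Java. This is only a sketch: the in-memory map stands in for the user connection information managed by the gateway, and the names `RoutingSketch`, `register`, `unregister`, `queueKeyFor`, and the `netty.push.` queue prefix are hypothetical, not taken from the project.

```java
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical sketch: resolve which MQ queue a push for a given user should go to.
class RoutingSketch {
    // userId -> "host:port" of the netty node that holds the user's connection
    // (stands in for the connection information managed by the gateway)
    private final Map<String, String> userToNode = new ConcurrentHashMap<>();

    // Called when a user's long-lived connection is established
    void register(String userId, String host, int port) {
        userToNode.put(userId, host + ":" + port);
    }

    // Called when the connection is closed
    void unregister(String userId) {
        userToNode.remove(userId);
    }

    // Each netty node listens on its own queue; here the key is derived from the node address
    Optional<String> queueKeyFor(String userId) {
        return Optional.ofNullable(userToNode.get(userId)).map(node -> "netty.push." + node);
    }

    public static void main(String[] args) {
        RoutingSketch routing = new RoutingSketch();
        routing.register("u1", "10.0.0.5", 9000);
        System.out.println(routing.queueKeyFor("u1").orElse("offline"));
        System.out.println(routing.queueKeyFor("u2").orElse("offline"));
    }
}
```

A business system would publish its push message to the returned queue key, and would treat an empty result as "user not connected".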

Code directory structure

Implementation

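Gateway is itself built on netty, so it supports routing websocket long-lived connections well. Normally a route predicate configuration is enough to load-balance and forward to a websocket service, but that does not meet our requirements here.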

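Custom load balancing in the gateway

Gateway has two client-side load balancer filters, LoadBalancerClientFilter and ReactiveLoadBalancerClientFilter. We need the reactive one, so set spring.cloud.loadbalancer.ribbon.enabled=false to switch to ReactiveLoadBalancerClientFilter.

Reading the ReactiveLoadBalancerClientFilter source shows a choose method that selects the service instance; the actual selection logic lives in the choose method of a ReactorLoadBalancer implementation. So we need to implement ReactorLoadBalancer to customize load balancing.

The load-balancing strategy we chose is consistent hashing.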
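For readers unfamiliar with consistent hashing, here is a minimal, self-contained sketch of a hash ring with virtual nodes. It is an illustration under assumptions, not the project's ClusterManager: the class name `HashRingSketch` and its methods are hypothetical, and the hash is a plain 32-bit FNV-1a masked to a non-negative value.

```java
import java.util.SortedMap;
import java.util.TreeMap;

// Minimal consistent hash ring with virtual nodes (illustrative only).
class HashRingSketch {
    private final TreeMap<Integer, String> ring = new TreeMap<>();
    private final int virtualNodes;

    HashRingSketch(int virtualNodes) {
        this.virtualNodes = virtualNodes;
    }

    // 32-bit FNV-1a, masked so the result is never negative
    static int hash(String s) {
        int h = (int) 2166136261L;
        for (int i = 0; i < s.length(); i++) {
            h = (h ^ s.charAt(i)) * 16777619;
        }
        return h & 0x7fffffff;
    }

    // Place several virtual nodes per real node to spread load more evenly
    void addNode(String node) {
        for (int i = 0; i < virtualNodes; i++) {
            ring.put(hash(node + "&&VN" + i), node);
        }
    }

    // Remove only the ring positions that currently belong to this node
    void removeNode(String node) {
        ring.values().removeIf(node::equals);
    }

    // Walk clockwise from the key's hash; wrap to the first position if the tail is empty
    String lookup(String key) {
        if (ring.isEmpty()) {
            return null;
        }
        SortedMap<Integer, String> tail = ring.tailMap(hash(key));
        return tail.isEmpty() ? ring.firstEntry().getValue() : tail.get(tail.firstKey());
    }

    public static void main(String[] args) {
        HashRingSketch ring = new HashRingSketch(100);
        ring.addNode("10.0.0.1:9000");
        ring.addNode("10.0.0.2:9000");
        System.out.println("user-42 -> " + ring.lookup("user-42"));
    }
}
```

The property that matters for long-lived connections is that removing one node only remaps the keys that were on that node; keys on the remaining nodes keep their assignment.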

Custom load balancer class WebsocketLoadBalancer

@Slf4j
public class WebsocketLoadBalancer implements ReactorServiceInstanceLoadBalancer {

    private final String serviceId;

    private ObjectProvider<ServiceInstanceListSupplier> serviceInstanceListSupplierProvider;

    private ClusterManager clusterManager;

    public WebsocketLoadBalancer(ObjectProvider<ServiceInstanceListSupplier> serviceInstanceListSupplierProvider,
                                      String serviceId,
                                      ClusterManager clusterManager) {
        this.serviceId = serviceId;
        this.serviceInstanceListSupplierProvider = serviceInstanceListSupplierProvider;
        this.clusterManager = clusterManager;
    }


    @Override
    public Mono<Response<ServiceInstance>> choose(Request request) {
        if (this.serviceInstanceListSupplierProvider == null) {
            log.warn("No servers available for service: " + this.serviceId);
            // Without a supplier there is nothing to choose from
            return Mono.just(new EmptyResponse());
        }
        ServiceInstanceListSupplier supplier = this.serviceInstanceListSupplierProvider.getIfAvailable(NoopServiceInstanceListSupplier::new);
        return supplier.get().next().map(serviceInstances -> getInstanceResponse(serviceInstances, request));
    }

    private Response<ServiceInstance> getInstanceResponse(List<ServiceInstance> instances, Request request) {
        if (instances.isEmpty()) {
            log.warn("No servers available for service: " + this.serviceId);
            return new EmptyResponse();
        }
        // The token travels in the websocket sub-protocol header
        List<String> token = ((WebsocketLoadBalancerRequest) request).getRequest().getHeaders().get(AuthConstant.SUB_PROTOCOL);
        ServiceInstance instance = null;
        if (!CollectionUtils.isEmpty(token)) {
            // Look up the node for this token on the consistent hash ring
            ServerNode server = clusterManager.getServer(token.get(0));
            if (server != null) {
                log.info("ws request: applying load balancing");
                instance = instances.stream().filter(v -> server.getInstanceId().equals(v.getMetadata().get("nacos.instanceId"))).findFirst().orElse(null);
                log.info("ws load-balanced node: {}", JSON.toJSONString(instance));
            }
        }
        if (instance == null) {
            // No token, or no matching instance: fall back to a random instance
            // (uses java.util.concurrent.ThreadLocalRandom)
            instance = instances.get(ThreadLocalRandom.current().nextInt(instances.size()));
        }
        return new DefaultResponse(instance);
    }
}

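Because only the netty service needs this particular load-balancing strategy, we configure it specifically for that service.

WebsocketLoadBalancerConfig configuration class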

@LoadBalancerClient(value = "nettyBus",configuration = WebsocketLoadBalancerConfig.class)
public class WebsocketLoadBalancerConfig {
    @Bean
    @LoadBalanced
    public RestTemplate restTemplate(){
        return new RestTemplate();
    }

    @Bean
    public ReactorLoadBalancer<ServiceInstance> websocketLoadBalancer(Environment environment,
                                                                      ClusterManager clusterManager,
                                                                      LoadBalancerClientFactory loadBalancerClientFactory){
        String name = environment.getProperty(LoadBalancerClientFactory.PROPERTY_NAME);
        return new WebsocketLoadBalancer(
                loadBalancerClientFactory.getLazyProvider(name, ServiceInstanceListSupplier.class),
                name,
                clusterManager
        );
    }

    @Bean
    public WebsocketLoadBalancerClientFilter websocketLoadBalancerClientFilter(LoadBalancerClientFactory clientFactory, LoadBalancerProperties properties){
        return new WebsocketLoadBalancerClientFilter(clientFactory,properties);
    }
}

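ClusterManager above is our own implementation of the consistent hash ring. Because the Request passed to the choose method in ReactiveLoadBalancerClientFilter gives no access to the request headers we need to parse, we replaced ReactiveLoadBalancerClientFilter with our own subclass, WebsocketLoadBalancerClientFilter.

The WebsocketLoadBalancerClientFilter class differs little from the original source: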

@Slf4j
public class WebsocketLoadBalancerClientFilter extends ReactiveLoadBalancerClientFilter {

    private LoadBalancerClientFactory clientFactory;
    private LoadBalancerProperties properties;

    public WebsocketLoadBalancerClientFilter(LoadBalancerClientFactory clientFactory, LoadBalancerProperties properties) {
        super(clientFactory,properties);
        this.clientFactory = clientFactory;
        this.properties = properties;
    }

    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        URI url = (URI)exchange.getAttribute(ServerWebExchangeUtils.GATEWAY_REQUEST_URL_ATTR);
        String schemePrefix = (String)exchange.getAttribute(ServerWebExchangeUtils.GATEWAY_SCHEME_PREFIX_ATTR);
        if (url != null && ("lb".equals(url.getScheme()) || "lb".equals(schemePrefix))) {
            ServerWebExchangeUtils.addOriginalRequestUrl(exchange, url);
            if (log.isTraceEnabled()) {
                log.trace(ReactiveLoadBalancerClientFilter.class.getSimpleName() + " url before: " + url);
            }
            return this.choose(exchange).doOnNext((response) -> {
                if (!response.hasServer()) {
                    throw NotFoundException.create(this.properties.isUse404(), "Unable to find instance for " + url.getHost());
                } else {
                    URI uri = exchange.getRequest().getURI();
                    String overrideScheme = null;
                    if (schemePrefix != null) {
                        overrideScheme = url.getScheme();
                    }

                    DelegatingServiceInstance serviceInstance = new DelegatingServiceInstance((ServiceInstance)response.getServer(), overrideScheme);
                    URI requestUrl = LoadBalancerUriTools.reconstructURI(serviceInstance, uri);
                    log.info("WebsocketLoadBalancerClientFilter url chosen: " + requestUrl);
                    exchange.getAttributes().put(ServerWebExchangeUtils.GATEWAY_REQUEST_URL_ATTR, requestUrl);
                }
            }).then(chain.filter(exchange));
        } else {
            return chain.filter(exchange);
        }
    }

    private Mono<Response<ServiceInstance>> choose(ServerWebExchange exchange) {
        URI uri = (URI)exchange.getAttribute(ServerWebExchangeUtils.GATEWAY_REQUEST_URL_ATTR);
        ServerHttpRequest request = exchange.getRequest();
        ReactorLoadBalancer<ServiceInstance> loadBalancer = (ReactorLoadBalancer)this.clientFactory.getInstance(uri.getHost(), ReactorLoadBalancer.class, new Class[]{ServiceInstance.class});
        if (loadBalancer == null) {
            throw new NotFoundException("No loadbalancer available for " + uri.getHost());
        } else {
            return loadBalancer.choose(this.createRequest(request));
        }

    }

    private Request createRequest(ServerHttpRequest request) {
        return new WebsocketLoadBalancerRequest(request);
    }
}

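ClusterManager implements the consistent hash ring. The registry used here is nacos.

  • Subscribe to service online/offline events from the registry
  • Maintain the hash ring of services locally in the gateway

Service listener RegistrationCenterChangeEventListener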
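The ring only needs to be rebuilt when the set of registered instances actually changes. One way to express that check is to compare sets of instance IDs, as in this sketch. The class name `InstanceChangeSketch` and its `refresh` method are hypothetical, and plain strings stand in for nacos instances.

```java
import java.util.Arrays;
import java.util.HashSet;
import java.util.List;
import java.util.Set;

// Hypothetical sketch: decide whether a registry event requires rebuilding the hash ring.
class InstanceChangeSketch {
    private Set<String> currentIds = new HashSet<>();

    // Returns true (and remembers the new set) only when membership changed
    synchronized boolean refresh(List<String> instanceIds) {
        Set<String> incoming = new HashSet<>(instanceIds);
        if (incoming.equals(currentIds)) {
            return false; // same instances, possibly in a different order
        }
        currentIds = incoming;
        return true; // the caller would rebuild the hash ring here
    }

    public static void main(String[] args) {
        InstanceChangeSketch sketch = new InstanceChangeSketch();
        System.out.println(sketch.refresh(Arrays.asList("a", "b"))); // first event
        System.out.println(sketch.refresh(Arrays.asList("b", "a"))); // same members, reordered
    }
}
```

Set equality ignores ordering and handles additions, removals, and replacements uniformly, so no separate size comparison is needed.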

@Component
@Slf4j
public class RegistrationCenterChangeEventListener implements EventListener {



    @Value("${spring.cloud.nacos.discovery.server-addr}")
    private String discoveryServerListList;

    @Value("${server.listener.name}")
    private String serverListenerName;


    @Autowired
    private ClusterManager clusterManager;

    private NamingService namingService;


    /**
     * key: serverId, value: service instance
     */

    public RegistrationCenterChangeEventListener() {

    }

    @Override
    @PostConstruct
    public void addListener() {

        try {
            this.namingService = NamingFactory.createNamingService(discoveryServerListList);
            // Subscribe to instance changes of the netty service and rebuild the hash ring on each event
            namingService.subscribe(serverListenerName, event -> {
                log.info(event.toString());
                if (event instanceof NamingEvent) {
                    NamingEvent namingEvent = (NamingEvent) event;
                    log.info("- - - - - - - - - instance change detected for service [" + serverListenerName + "], event: {} - - - - - - - - - ",namingEvent.getEventType());
                    clusterManager.refreshNettyServer(namingEvent.getInstances());
                }
            });
        } catch (NacosException e) {
           // Pass the exception as the last argument so the stack trace is logged
           log.error("Failed to subscribe to nacos service {}", serverListenerName, e);
        }
    }


}

ClusterManager class:

/**
 * Custom storage of users' netty connection information.
 * Uses a local hash ring plus a redis cache of the node each user is connected to.
 */
@Slf4j
public class ClusterManager {

    @Resource
    private RedisService redisService;

    @Value("${netty.virtual.node}")
    private Integer VIRTUAL_NODES;


    private final StampedLock stampedLock = new StampedLock();

    private static SortedMap<Integer, ServerNode> virtualNodes = new TreeMap<Integer, ServerNode>();

    private static List<Instance> currentInstance = new ArrayList<>();

    public void refreshNettyServer(List<Instance> instances){
        log.info("Refreshing consistent hash ring, instance count: {}",instances.size());
        if(CollectionUtils.isEmpty(instances)){
            // No instances available; send a custom alert here
        }
        if(!CollectionUtils.isEmpty(currentInstance) && instances.size()==currentInstance.size()){
            // Same size: the set changed only if some current instance is missing from the new list
            List<String> newIds = instances.stream().map(Instance::getInstanceId).collect(Collectors.toList());
            boolean unchanged = currentInstance.stream().allMatch(v -> newIds.contains(v.getInstanceId()));
            if(unchanged){
                log.info("No instance changes detected");
                return;
            }
        }
        currentInstance = instances;
        long lock = stampedLock.writeLock();
        try {
            virtualNodes.clear();
            // TreeMap is not thread-safe, so fill it sequentially rather than from a parallel stream
            for (Instance v : instances) {
                for (int i = 0; i < VIRTUAL_NODES; i++) {
                    String virtualNodeName = v.getIp()+":"+v.getPort()+  "&&VN" + i;
                    int hash = getHash(virtualNodeName);
                    ServerNode serverNode = ServerNode.builder().instanceId(v.getInstanceId())
                            .host(v.getIp()).port(v.getPort()).clusterName(v.getServiceName()).build();
                    virtualNodes.put(hash, serverNode);
                }
            }
        }catch (Exception e){
            log.error("Failed to refresh consistent hash ring", e);
        }finally {
            stampedLock.unlockWrite(lock);
            log.info("Consistent hash ring refreshed: {}",JSON.toJSONString(instances));
        }
    }

    private  int getHash(String str) {
        final int p = 16777619;
        int hash = (int)2166136261L;
        for (int i = 0; i < str.length(); i++) {
            hash = (hash ^ str.charAt(i)) * p;
        }
        hash += hash << 13;
        hash ^= hash >> 7;
        hash += hash << 3;
        hash ^= hash >> 17;
        hash += hash << 5;

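        // If the result is negative, take its absolute value.
        // Math.abs(Integer.MIN_VALUE) is still negative, so handle that case explicitly.
        if (hash < 0) {
            hash = hash == Integer.MIN_VALUE ? 0 : Math.abs(hash);
        }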
        return hash;
    }

    /**
     * Returns the node that the given key should be routed to
     */
    public  ServerNode getServer(String node) {
        // Hash of the key to be routed
        int hash = getHash(node);
        // Hold the read lock for the whole lookup so a concurrent refresh cannot clear the ring midway
        long stamp = stampedLock.readLock();
        try {
            if (virtualNodes.isEmpty()) {
                return null;
            }
            // All virtual nodes whose hash is greater than or equal to the key's hash
            SortedMap<Integer, ServerNode> subMap = virtualNodes.tailMap(hash);
            // The first key is the nearest node clockwise; wrap around to the start of the ring if the tail is empty
            Integer i = subMap.isEmpty() ? virtualNodes.firstKey() : subMap.firstKey();
            return virtualNodes.get(i);
        } finally {
            stampedLock.unlockRead(stamp);
        }
    }


    /*
    * Handles a client coming online
    * */
    public Channel addChannel(Channel channel,UserChannelInfo userChannelInfo) {
        // Add the user under the server it is connected to
        redisService.hset(String.format(Constant.NETTY_CONNECT_USER_KEY,userChannelInfo.getServerHost()+userChannelInfo.getServerPort()),
                userChannelInfo.getUserId(),JSON.toJSONString(userChannelInfo));
        // Cache which server the user's connection is on
        redisService.hset(Constant.USER_CONNECT_KEY, userChannelInfo.getUserId(),JSON.toJSONString(userChannelInfo));
        return channel;

    }

    /**
     * Handles a client going offline
     * @param userChannelInfo connection information of the user whose channel closed
     */
    public void channelCloseHandle(UserChannelInfo userChannelInfo) {
        log.info("- - - - - - - - - " + userChannelInfo.getUserId() + " offline from server " + userChannelInfo.getServerHost() + "  - - - - - - - - - ");
        redisService.hdel(String.format(Constant.NETTY_CONNECT_USER_KEY,userChannelInfo.getServerHost()+userChannelInfo.getServerPort()),
                userChannelInfo.getUserId());
        redisService.hdel(Constant.USER_CONNECT_KEY,userChannelInfo.getUserId());
    }



}

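Finally, the gateway authenticates the connection. We add a custom Filter to the gateway filter chain that reads the connection request headers, parses them, and passes the result down to the netty service, which makes the final decision on the connection.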

@Component
@Slf4j
public class WebsocketFilter implements GlobalFilter, Ordered {

    @Override
    public Mono<Void> filter(ServerWebExchange exchange, GatewayFilterChain chain) {
        URI requestUrl = (URI)exchange.getRequiredAttribute(ServerWebExchangeUtils.GATEWAY_REQUEST_URL_ATTR);
        String scheme = requestUrl.getScheme();
        log.info("gateway -> "+WebsocketFilter.class+" forwarding scheme: {}",scheme);
        if(Constant.WS_PROTOCOL.equals(scheme) || Constant.WSS_PROTOCOL.equals(scheme)){
            ServerHttpRequest request = exchange.getRequest();
            HttpHeaders headers = request.getHeaders();
            List<String> protocols = headers.get(AuthConstant.SUB_PROTOCOL);
            if (protocols != null) {
                // Split comma-separated sub-protocol values into individual entries
                protocols = protocols.stream()
                        .flatMap(header -> Arrays.stream(StringUtils.commaDelimitedListToStringArray(header)))
                        .map(String::trim).collect(Collectors.toList());
            }
            if (protocols == null || protocols.isEmpty()) {
                // No token supplied, so protocols.get(0) would fail.
                // Rejecting with 401 is an assumption; adjust to the desired policy.
                exchange.getResponse().setStatusCode(HttpStatus.UNAUTHORIZED);
                return exchange.getResponse().setComplete();
            }
            String token = protocols.get(0);
            log.info("Sub-protocol "+AuthConstant.SUB_PROTOCOL+": {}", token);
            String userId = JWTUtil.getUserIdofSring(token);
            // Pass the parsed user id down to the netty service in a request header
            ServerHttpRequest build = exchange.getRequest().mutate().headers(wssheaders -> {
                wssheaders.add(AuthConstant.AUTH_HEADER, userId);
            }).build();
            ServerWebExchange wssExchange = exchange.mutate().request(build).build();
            return chain.filter(wssExchange);
        }
        return chain.filter(exchange);
    }

    @Override
    public int getOrder() {
        return Ordered.LOWEST_PRECEDENCE - 2;
    }

}
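The header handling in the filter can be tried in isolation with plain Java. This sketch assumes the token is carried as the first entry of the Sec-WebSocket-Protocol header, as the filter does; the class name `SubProtocolSketch` and its methods are hypothetical.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;

// Hypothetical sketch: flatten Sec-WebSocket-Protocol header values and pick the token.
class SubProtocolSketch {
    // Turns header values such as ["tokenA, chat"] into ["tokenA", "chat"]
    static List<String> parse(List<String> headerValues) {
        List<String> result = new ArrayList<>();
        if (headerValues == null) {
            return result;
        }
        for (String value : headerValues) {
            for (String part : value.split(",")) {
                String trimmed = part.trim();
                if (!trimmed.isEmpty()) {
                    result.add(trimmed);
                }
            }
        }
        return result;
    }

    // The first sub-protocol is treated as the token; null when none was sent
    static String firstToken(List<String> headerValues) {
        List<String> protocols = parse(headerValues);
        return protocols.isEmpty() ? null : protocols.get(0);
    }

    public static void main(String[] args) {
        System.out.println(parse(Arrays.asList("tokenA, chat", "v2")));
        System.out.println(firstToken(null));
    }
}
```

Returning null for a missing header makes the "no token" case explicit, so the caller can decide whether to reject the handshake.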

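With the above, each user's long-lived connection is consistently routed to the same server (as long as no server goes down). Connection migration when a server goes offline still needs to be considered.

Registering the netty service with nacos

Integrate netty with Spring Boot and disable the embedded Tomcat so that only the single netty port is started. Modify the startup class:

/**
    * Disable the embedded Tomcat
    * */
    public static void main(String[] args) {
        new SpringApplicationBuilder(NettyBusApplication.class).web(WebApplicationType.NONE).run(args);
    }
NettyServer configuration, which also registers the service with the nacos registry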
@Configuration
@Slf4j
public class NettyServer implements ApplicationListener<ApplicationStartedEvent> {

    @Value("${netty.port}")
    private int port;

    @Value("${netty.name}")
    private String name;

    @Value("${spring.cloud.nacos.discovery.server-addr}")
    private String nacosServer;

    private EventLoopGroup bossGroup = null;
    private EventLoopGroup workerGroup = null;


    @Override
    public void onApplicationEvent(@NonNull ApplicationStartedEvent applicationStartedEvent) {
        start();
    }

    public void start() {
        bossGroup = new NioEventLoopGroup(2);
        workerGroup  = new NioEventLoopGroup(6);
        ServerBootstrap bootstrap  = new ServerBootstrap();
        bootstrap.group(bossGroup,workerGroup )
                .channel(NioServerSocketChannel.class)
                .localAddress(this.port)
                // Length of the queue of pending connections
                .option(ChannelOption.SO_BACKLOG, 600)
                // Send data immediately; this applies to accepted connections, so it is a childOption
                .childOption(ChannelOption.TCP_NODELAY, true)
                // Keep connections alive
                .childOption(ChannelOption.SO_KEEPALIVE, true)
                // Handle new connections, distinguishing tcp from websocket
                .childHandler(new NettyServerInitializer());
        ChannelFuture channelFuture  = bootstrap.bind().syncUninterruptibly().addListener(future -> {
            NamingService namingService = NamingFactory.createNamingService(nacosServer);
            // Register the service with the registry
            InetAddress address = InetAddress.getLocalHost();
            namingService.registerInstance(name, address.getHostAddress(), Integer.valueOf(port));
            log.info(name + " registered with nacos");
            log.info(NettyServer.class + " started, listening on: "+this.port);

        });
        channelFuture.channel().closeFuture().addListener(future -> {
            destroy();
        });

    }

    public void destroy() {
        log.info(NettyServer.class +" service stopped");
        workerGroup.shutdownGracefully();
        bossGroup.shutdownGracefully();
    }

}

We are also considering connections over other protocols, so the netty service handles different protocols separately.

@Component
@ChannelHandler.Sharable
public class NettyServerInitializer extends ChannelInitializer<SocketChannel> {



    @Override
    protected void initChannel(SocketChannel socketChannel){
        socketChannel.pipeline().addLast("socketChoose", SpringContextUtil.getBeanByClass(SocketChooseHandler.class));
    }
}

@Component
@Slf4j
@ChannelHandler.Sharable
public class SocketChooseHandler extends ChannelInboundHandlerAdapter {



    /**
     * Prefix of the request line in a WebSocket handshake
     */
    private static final String WEBSOCKET_PREFIX = "GET /";

    private final static String match = "sec-websocket-protocol:([\\s\\S]*?)sec-websocket-version";



    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        if(msg instanceof ByteBuf){
            // Wrap the buffer so that peeking does not move the original reader index
            ByteBuf byteBuf = Unpooled.wrappedBuffer((ByteBuf) msg);
            String protocol = getBufStart(byteBuf);
            if (protocol.startsWith(WEBSOCKET_PREFIX)) {
                // For a websocket connection, install the handlers below
                // HttpServerCodec: decodes requests and encodes responses as HTTP messages
                ctx.pipeline().addLast("http-codec", new HttpServerCodec());
                // HttpObjectAggregator: combines the parts of an HTTP message into one complete message
                ctx.pipeline().addLast("aggregator", new HttpObjectAggregator(65535));
                // ChunkedWriteHandler: writes large content in chunks
                ctx.pipeline().addLast("http-chunked", new ChunkedWriteHandler());
                ctx.pipeline().addLast("WebSocketAggregator", new WebSocketFrameAggregator(65535));
                // Original note: on a missed agreed heartbeat, close the channel to free resources.
                // The handler added here performs authentication.
                ctx.pipeline().addLast(new NettyWebSocketAuthHandler());
            } else {
                ctx.pipeline().addLast(new StringDecoder());
                // To be replaced with a custom codec later
                ctx.pipeline().addLast(new StringEncoder());
                ctx.pipeline().addLast(new NettyServerHandler());
            }
            // The protocol has been decided, so this handler is no longer needed on the channel
            ctx.pipeline().remove(this.getClass());
        }else {
            log.info("Invalid connection");
        }
        super.channelRead(ctx,msg);
    }

    private String getBufStart(ByteBuf in) {
        int length = in.readableBytes();
        // Mark the reader index
        in.markReaderIndex();
        byte[] content = new byte[length];
        in.readBytes(content);
        return new String(content);
    }




}
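The protocol detection in SocketChooseHandler only needs the first few bytes of the first packet. As a sketch of the same idea without netty, the check can be done directly on a byte array instead of decoding the whole buffer into a String. The class name `ProtocolSniffSketch` and its method are hypothetical.

```java
import java.nio.charset.StandardCharsets;

// Hypothetical sketch: detect a WebSocket handshake by its HTTP request-line prefix.
class ProtocolSniffSketch {
    private static final byte[] WEBSOCKET_PREFIX = "GET /".getBytes(StandardCharsets.US_ASCII);

    // True when the first bytes look like the start of an HTTP GET request
    static boolean looksLikeWebSocketHandshake(byte[] firstBytes) {
        if (firstBytes == null || firstBytes.length < WEBSOCKET_PREFIX.length) {
            return false;
        }
        for (int i = 0; i < WEBSOCKET_PREFIX.length; i++) {
            if (firstBytes[i] != WEBSOCKET_PREFIX[i]) {
                return false;
            }
        }
        return true;
    }

    public static void main(String[] args) {
        byte[] handshake = "GET /nettyBus HTTP/1.1\r\n".getBytes(StandardCharsets.US_ASCII);
        byte[] rawTcp = "hello".getBytes(StandardCharsets.US_ASCII);
        System.out.println(looksLikeWebSocketHandshake(handshake));
        System.out.println(looksLikeWebSocketHandshake(rawTcp));
    }
}
```

Note that any plain HTTP GET request matches this prefix, so the check separates HTTP-based traffic from raw TCP rather than proving that the request is a WebSocket upgrade.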

Before the connection is established we also need to verify permissions in NettyWebSocketAuthHandler; the connection only succeeds once that check passes.

@Slf4j
@Component
@ChannelHandler.Sharable
public class NettyWebSocketAuthHandler extends ChannelInboundHandlerAdapter {

    private final String headerKey = "Sec-WebSocket-Protocol";




    @Override
    public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception {
        // Parse the token from the headers and store the user id on the channel
        String token = "";
        if(msg instanceof FullHttpRequest){
            FullHttpRequest request = (FullHttpRequest) msg;
            String userId = request.headers().get(AuthConstant.AUTH_HEADER);
            token = request.headers().get(AuthConstant.SUB_PROTOCOL);
            String uri = request.getUri();
            log.info("Auth uri: {}, userId: {}, token: {}",uri,userId,token);
            if(StringUtils.isNotBlank(userId)){
                log.info("Auth succeeded, userId: {}",userId);
                AttributeKey<String> auth = AttributeKey.valueOf("userId");
                ctx.channel().attr(auth).set(userId);
            }
            request.setUri(URLUtil.getPath(uri));
        }
        // Install the websocket handlers; /nettyBus is the websocket path, and the sub-protocol is set as well
        handlerAdded(ctx,token);
        super.channelRead(ctx,msg);
    }


    public void handlerAdded(ChannelHandlerContext ctx,String token) {
        ChannelPipeline cp = ctx.pipeline();
        if (cp.get(WebSocketServerProtocolHandler.class) == null) {
            ctx.pipeline().addLast("ProtocolHandler", new WebSocketServerProtocolHandler("/nettyBus", token,true));
            ctx.pipeline().addLast(new NettyWebSocketHandler());
            ctx.pipeline().remove(this.getClass());
        }
    }

    @Override
    public void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception {
        // Log at error level and pass the exception itself so the stack trace is kept
        log.error("NettyWebSocketAuthHandler exception", cause);
    }

}

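For every connection, the gateway caches the user's long-lived connection information and channel. Managing these channels has some performance cost, and it means the gateway maintains a large number of connections; this is to be optimized later.

That covers most of the changes; the remaining business logic is not described in detail here.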
