springboot 通过websocket将订阅的消息推送给指定用户

引入pom依赖:

	   <dependency>
			<groupId>org.springframework.boot</groupId>
			<artifactId>spring-boot-starter-data-redis</artifactId>
		</dependency>
		<dependency>
			<groupId>org.springframework.boot</groupId>
			<artifactId>spring-boot-starter-websocket</artifactId>
		</dependency>

websocket配置类:

package com.wty.core.config;

import org.springframework.context.annotation.Bean;
import org.springframework.stereotype.Component;
import org.springframework.web.socket.server.standard.ServerEndpointExporter;

/**
 * Spring-managed holder for the WebSocket infrastructure bean.
 */
@Component
public class WebSocketConfig {

    /**
     * Creates the {@link ServerEndpointExporter} exposed as a Spring bean.
     * <p>
     * NOTE(review): per the Spring documentation this exporter is what registers
     * {@code @ServerEndpoint}-annotated beans with the WebSocket container when
     * running on an embedded server -- confirm against the deployed setup.
     *
     * @return a new exporter instance
     */
    @Bean
    public ServerEndpointExporter serverEndpointExporter() {
        final ServerEndpointExporter exporter = new ServerEndpointExporter();
        return exporter;
    }

}

application.properties配置:

# redis 配置
spring.redis.database=0
spring.redis.host=127.0.0.1
#spring.redis.host=192.168.249.112
# redis 服务器连接端口
spring.redis.port=6379
# redis服务器连接密码(默认为空)
spring.redis.password=
# 连接池最大连接数(使用负值表示没有限制)
spring.redis.lettuce.pool.max-active=-1
# 连接池最大阻塞等待时间(使用负值表示没有限制)
spring.redis.lettuce.pool.max-wait=-1
# 连接池最大空闲连接
spring.redis.lettuce.pool.max-idle=8
# 连接池中最小空闲连接
spring.redis.lettuce.pool.min-idle=0
# 连接超时时间(毫秒)
spring.redis.timeout=3000
# 缓存
spring.cache.redis.time-to-live=-1




#redis频道
redis.pindao=remote_detection_inference
redis.chaifen=remote_transform_generate_tiles

redisConfig配置类:

package com.wty.core.config;

import com.baomidou.mybatisplus.extension.api.R;
import com.huaru.utils.Config;
import io.lettuce.core.protocol.ConnectionFacade;
import org.reactivestreams.Subscriber;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.data.redis.connection.MessageListener;
import org.springframework.data.redis.connection.RedisConnectionFactory;
import org.springframework.data.redis.connection.lettuce.LettuceConnectionFactory;
import org.springframework.data.redis.core.RedisTemplate;
import org.springframework.data.redis.core.StringRedisTemplate;
import org.springframework.data.redis.listener.PatternTopic;
import org.springframework.data.redis.listener.RedisMessageListenerContainer;
import org.springframework.data.redis.listener.adapter.MessageListenerAdapter;
import org.springframework.data.redis.serializer.GenericJackson2JsonRedisSerializer;
import org.springframework.data.redis.serializer.StringRedisSerializer;
import redis.clients.jedis.JedisPool;
import redis.clients.jedis.JedisPoolConfig;

import java.io.Serializable;
import java.time.Duration;
import java.util.ArrayList;
import java.util.List;

/**
 * Redis wiring: a JSON-serializing {@link RedisTemplate}, a {@link StringRedisTemplate},
 * and the pub/sub listener container that subscribes to the two channels configured
 * under {@code redis.pindao} and {@code redis.chaifen}.
 */
@Configuration
//@EnableCaching
public class RedisConfig {

    /** Value of {@code spring.cache.redis.time-to-live}; not read by any bean in this class. */
    @Value("${spring.cache.redis.time-to-live}")
    private Duration timeToLive = Duration.ZERO;

    /** Channel name taken from {@code redis.chaifen}. */
    @Value("${redis.chaifen}")
    private String chaifen;

    /** Channel name taken from {@code redis.pindao}. */
    @Value("${redis.pindao}")
    private String pindao;

    /**
     * Builds a template that stores keys and hash keys as plain strings and
     * values and hash values as JSON.
     *
     * @param lettuceConnectionFactory connection factory backing the template
     * @return the fully initialised template
     */
    @Bean
    public RedisTemplate<String, Serializable> redisTemplate(LettuceConnectionFactory lettuceConnectionFactory){
        RedisTemplate<String, Serializable> redisTemplate = new RedisTemplate<>();
        StringRedisSerializer stringSerializer = new StringRedisSerializer();
        GenericJackson2JsonRedisSerializer jsonSerializer = new GenericJackson2JsonRedisSerializer();

        redisTemplate.setKeySerializer(stringSerializer);
        redisTemplate.setValueSerializer(jsonSerializer);
        redisTemplate.setHashKeySerializer(stringSerializer);
        // Previously setHashKeySerializer was called a second time here, which replaced the
        // string hash-key serializer with JSON and left the hash-value serializer unset.
        redisTemplate.setHashValueSerializer(jsonSerializer);
        redisTemplate.setConnectionFactory(lettuceConnectionFactory);
        // Apply the settings immediately instead of waiting for the container callback.
        redisTemplate.afterPropertiesSet();
        return redisTemplate;
    }

    /**
     * Creates the pub/sub listener container and subscribes {@code listenerAdapter2}
     * to both configured channels.
     *
     * @param connectionFactory connection factory used for the subscriptions
     * @param listenerAdapter   not used by this method; kept so the bean signature is unchanged
     * @param listenerAdapter2  adapter that receives messages from both channels
     * @return the configured container, registered under the bean name {@code "container"}
     */
    @Bean("container")
    RedisMessageListenerContainer container(RedisConnectionFactory connectionFactory, MessageListenerAdapter listenerAdapter, MessageListenerAdapter listenerAdapter2){
        RedisMessageListenerContainer container = new RedisMessageListenerContainer();
        // The container only needs the RedisConnectionFactory interface, so no downcast
        // to the Lettuce implementation is required.
        container.setConnectionFactory(connectionFactory);

        List<PatternTopic> topics = new ArrayList<>();
        topics.add(new PatternTopic(pindao));
        topics.add(new PatternTopic(chaifen));
        // One listener may be bound to several topics at once; further listeners can be
        // added with additional addMessageListener calls.
        container.addMessageListener(listenerAdapter2, topics);
        return container;
    }

    /**
     * Wraps the project's Redis listener in a {@link MessageListenerAdapter}.
     *
     * @param receiver listener that handles incoming messages
     * @return adapter delegating to {@code receiver}
     */
    @Bean
    MessageListenerAdapter listenerAdapter2(RedisMessageListener2 receiver) {
        return new MessageListenerAdapter(receiver);
    }

    /**
     * Exposes a {@link StringRedisTemplate}, the {@link RedisTemplate} specialisation
     * for string keys and values.
     *
     * @param connectionFactory connection factory backing the template
     * @return a new string template
     */
    @Bean
    StringRedisTemplate template(RedisConnectionFactory connectionFactory) {
        return new StringRedisTemplate(connectionFactory);
    }

}

websocket配置redis监听器类:

package com.wty.core.controller;

import com.alibaba.fastjson.JSONObject;
import com.huaru.core.config.RedisMessageListener2;
import com.huaru.core.model.RemoteBreakUp;
import com.huaru.core.service.Impl.RemoteSensingImageServiceImpl;
import com.huaru.core.service.RemoteBreakUpService;
import com.huaru.core.service.RemoteSensingImageService;
import com.huaru.utils.SpringContextUtils;
import com.huaru.utils.SpringUtils;
import com.huaru.utils.StringUtil;
import lombok.extern.slf4j.Slf4j;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.data.redis.listener.PatternTopic;
import org.springframework.data.redis.listener.RedisMessageListenerContainer;
import org.springframework.scheduling.concurrent.ThreadPoolTaskExecutor;
import org.springframework.web.bind.annotation.RestController;

import javax.annotation.Resource;
import javax.websocket.*;
import javax.websocket.server.PathParam;
import javax.websocket.server.ServerEndpoint;
import java.io.IOException;
import java.util.Date;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArraySet;
import java.util.concurrent.atomic.AtomicInteger;

/**
 * WebSocket endpoint served at {@code /socket/{tokenid}}.
 * <p>
 * The {@code tokenid} path parameter identifies the connected user. It is used both as
 * the key under which the open endpoint is stored and as the name of the Redis topic
 * the per-connection listener is subscribed to.
 */
@ServerEndpoint("/socket/{tokenid}")
@RestController
@Slf4j
public class WebsocketEndpoint {

    /** Number of currently open connections. */
    private static final AtomicInteger onlineCount = new AtomicInteger(0);

    /** Open endpoints keyed by {@code tokenid}. */
    private static final ConcurrentHashMap<String, WebsocketEndpoint> webSocketSet = new ConcurrentHashMap<>();

    /** Identifier of the user owning this connection; set in {@link #onOpen}. */
    private String tokenid;

    /** Executor on which outgoing messages are written. */
    private ThreadPoolTaskExecutor threadPoolTaskExecutor = SpringUtils.getBean(ThreadPoolTaskExecutor.class);

    /** Session of this connection, used to exchange data with the client; set in {@link #onOpen}. */
    private Session session;

    private static final Logger LOGGER = LoggerFactory.getLogger(WebsocketEndpoint.class);

    /** Listener container looked up under the bean name {@code "container"}. */
    private RedisMessageListenerContainer container = SpringUtils.getBean("container");

    /** Redis listener created for this connection in {@link #onOpen}. */
    private RedisMessageListener2 listener2;

    /**
     * Handles a newly opened socket: records the connection, bumps the online counter
     * and subscribes a fresh Redis listener to the topic named after {@code tokenid}.
     *
     * @param session  session of the new connection
     * @param tokenid  user identifier taken from the URL path
     * @throws Exception propagated from the underlying calls
     */
    @OnOpen
    public void onOpen(Session session, @PathParam("tokenid") String tokenid) throws Exception {
        LOGGER.info(String.format("用户:%s 打开了Socket链接", tokenid));
        this.session = session;
        this.tokenid = tokenid;
        // Make this endpoint reachable through sendMessageToUser.
        webSocketSet.put(tokenid, this);
        addOnlineCount();

        listener2 = new RedisMessageListener2();
        listener2.setSession(session);
        // The original line referenced the undefined names "listener" and "userId".
        listener2.setUserId(tokenid);
        container.addMessageListener(listener2, new PatternTopic(tokenid));
    }

    /**
     * Handles socket close: removes this endpoint from the registry, decrements the
     * online counter and unsubscribes the Redis listener.
     */
    @OnClose
    public void onClose() {
        // The registry is keyed by tokenid, so remove(this) never matched anything.
        // The two-argument remove only drops the entry if it still points at this
        // endpoint, so a newer connection for the same tokenid is left in place.
        if (tokenid != null) {
            webSocketSet.remove(tokenid, this);
        }
        subOnlineCount();
        getOnlineCount();
        // listener2 is only assigned in onOpen; guard against a close without a completed open.
        if (listener2 != null) {
            container.removeMessageListener(listener2);
        }
        LOGGER.info(String.format("%s关闭了Socket链接Close a html ", tokenid));
    }

    /**
     * Handles a text message from the client by replying with a fixed heartbeat payload.
     *
     * @param message text received from the client
     * @param session session the message arrived on
     */
    @OnMessage
    public void onMessage(String message, Session session) {
        getOnlineCount();
        LOGGER.info("收到一条数据消息----------" + message + "----------------------------------------");
        // Business-specific handling can be added here.
        try {
            // Heartbeat reply.
            Map<String, Object> map = new HashMap<>();
            map.put("type", "0");
            map.put("data", "soeket连接已建立");
            JSONObject jsonObject = new JSONObject(map);
            this.sendMessage(jsonObject.toJSONString());
        } catch (Exception e) {
            LOGGER.error("Failed to send heartbeat reply to " + tokenid, e);
        }
    }

    /**
     * Counter incremented by {@link #addOne()}.
     */
    public volatile int p = 0;

    /**
     * Increments {@link #p} and prints the new value together with the current thread name.
     */
    public synchronized void addOne() {
        p++;
        System.out.println(Thread.currentThread().getName() + "------->" + "自增==>" + p);
    }

    /**
     * Logs a socket error.
     *
     * @param session session the error occurred on
     * @param error   the error
     */
    @OnError
    public void onError(Session session, Throwable error) {
        LOGGER.error("socket链接错误", error);
    }

    /**
     * Sends a text message to this connection's client on the task executor.
     * Does nothing when the session has not been set or is no longer open.
     *
     * @param message text to send
     */
    public void sendMessage(String message) {
        final Session target = session;
        if (target == null || !target.isOpen()) {
            return;
        }
        getOnlineCount();
        threadPoolTaskExecutor.execute(() -> {
            // Writes to one session are serialised by locking on it.
            synchronized (target) {
                try {
                    target.getBasicRemote().sendText(message);
                    LOGGER.info(Thread.currentThread().getName() + "执行了" + "发送内容为=> " + message);
                } catch (Exception e) {
                    LOGGER.error("Failed to send WebSocket message to " + tokenid, e);
                }
            }
        });
    }

    /**
     * Sends a message to the connection registered for the given user.
     * A {@code null} user id or a user without a registered connection is skipped.
     *
     * @param message text to send
     * @param userid  identifier the target connection was opened with
     */
    public void sendMessageToUser(String message, String userid) {
        if (userid == null) {
            return;
        }
        WebsocketEndpoint endpoint = webSocketSet.get(userid);
        if (endpoint == null) {
            // Previously this case threw a NullPointerException.
            LOGGER.warn("No open WebSocket connection for user " + userid + "; message dropped");
            return;
        }
        endpoint.sendMessage(message);
        LOGGER.info("消息發送成功");
    }

    /**
     * Prints the current online count and returns the counter itself.
     *
     * @return the shared online counter
     */
    public static AtomicInteger getOnlineCount() {
        System.out.println(new Date() + "在线人数为" + onlineCount);
        return onlineCount;
    }

    /** Atomically increments the online counter. */
    public static void addOnlineCount() {
        WebsocketEndpoint.onlineCount.getAndIncrement();
    }

    /** Atomically decrements the online counter. */
    public static void subOnlineCount() {
        WebsocketEndpoint.onlineCount.getAndDecrement();
    }
}

监听器类:

package com.wty.core.config;

import com.alibaba.fastjson.JSONObject;
import com.huaru.core.controller.WebsocketEndpoint;
import com.huaru.core.model.RemoteBreakUp;
import com.huaru.core.service.Impl.RemoteBreakUpServiceImpl;
import com.huaru.core.service.RemoteBreakUpService;
import com.huaru.utils.Config;
import com.huaru.utils.SpringContextUtils;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Value;
import org.springframework.data.redis.connection.Message;
import org.springframework.data.redis.connection.MessageListener;
import org.springframework.stereotype.Component;

import javax.swing.*;
import javax.websocket.Session;
import java.io.IOException;
import java.util.concurrent.atomic.AtomicInteger;

@Component
public class RedisMessageListener2 implements MessageListener {

    //用户的session
    private Session session;

    //用户的ID
    private String userId;

 

    public String getUserId() {
        return userId;
    }

    public void setUserId(String userId) {
        this.userId = userId;
    }

    public Session getSession() {
        return session;
    }

    public void setSession(Session session) {
        this.session = session;
    }


    @Override
    public void onMessage(Message message, byte[] bytes) {

        WebsocketEndpoint websocketEndpoint = (WebsocketEndpoint)SpringContextUtils.getBean("websocketEndpoint");
        // 获主题名称
        String channel = new String(message.getChannel());
        String msg = new String(message.getBody());     //消息体
        if (message!= null) {
            synchronized (this) {
                System.out.println("redis中订阅消息");
                if(channel.equals("remote_transform_generate_tiles")){
                	//根据自身的业务需要编写
                	//下面是通过用户id,调用socket给对应的用户发送信息
                     websocketEndpoint.sendMessageToUser(message.toString(),userId);
                }else if(channel.equals("remote
                	//根据自身的业务需要编写
                	//下面是通过用户id,调用socket给对应的用户发送信息
                     websocketEndpoint.sendMessageToUser(message.toString(),"userId");
                }
            }
        } 
    }

}

spring获取上下文工具类:


package com.wty.utils;

import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.stereotype.Component;

/**
 * Static access to the Spring {@link ApplicationContext}.
 *
 * @author Mark sunlightcs@gmail.com
 */
@Component
public class SpringContextUtils implements ApplicationContextAware {
	/** Context stored by {@link #setApplicationContext}; unset until that method has been called. */
	public static ApplicationContext applicationContext;

	/**
	 * Stores the given context in the static field.
	 *
	 * @param applicationContext the context to store
	 */
	@Override
	public void setApplicationContext(ApplicationContext applicationContext)
			throws BeansException {
		SpringContextUtils.applicationContext = applicationContext;
	}

	/** Returns the stored context; single read point for the lookups below. */
	private static ApplicationContext context() {
		return applicationContext;
	}

	/** Looks up a bean by name. */
	public static Object getBean(String name) {
		final Object bean = context().getBean(name);
		return bean;
	}

	/** Looks up a bean by type. */
	public static <T> T getBean(Class<T> requiredType) {
		final T bean = context().getBean(requiredType);
		return bean;
	}

	/** Looks up a bean by name and type. */
	public static <T> T getBean(String name, Class<T> requiredType) {
		final T bean = context().getBean(name, requiredType);
		return bean;
	}

	/** Asks the context whether it contains a bean with the given name. */
	public static boolean containsBean(String name) {
		final boolean present = context().containsBean(name);
		return present;
	}

	/** Asks the context whether the named bean is a singleton. */
	public static boolean isSingleton(String name) {
		final boolean singleton = context().isSingleton(name);
		return singleton;
	}

	/** Asks the context for the type of the named bean. */
	public static Class<? extends Object> getType(String name) {
		final Class<? extends Object> type = context().getType(name);
		return type;
	}

}

参照博客:

Logo

为开发者提供学习成长、分享交流、生态实践、资源工具等服务,帮助开发者快速成长。

更多推荐