Just for the record!

Jedis实现Redis分布式锁

import org.apache.commons.lang3.StringUtils;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import redis.clients.jedis.Jedis;
import redis.clients.jedis.JedisPool;
import redis.clients.jedis.JedisPoolConfig;
import redis.clients.jedis.params.SetParams;

import java.io.*;
import java.nio.charset.StandardCharsets;
import java.util.*;


/**
 * Static helpers around a shared {@link JedisPool}: a simple Redis-based distributed lock plus thin
 * wrappers over common string, set, sorted-set, hash and list commands.
 *
 * <p>Every helper borrows a connection from the pool and returns it with try-with-resources.
 *
 * @author lshbo
 */
public class JedisUtil {

    private static final Logger logger = LoggerFactory.getLogger(JedisUtil.class);

    /**
     * Lock tokens held by the current thread, keyed by lock key. One token per key lets a thread
     * hold several locks at once and release each one with the token it acquired it with.
     */
    private static final ThreadLocal<Map<String, String>> LOCK_CONTEXT_JEDIS = new ThreadLocal<>();

    private static final JedisPool jedisPool;

    static {
        int port = PropertiesManualContext.REDIS_PORT;
        String proxyHost = PropertiesManualContext.REDIS_HOST;
        String pwd = PropertiesManualContext.REDIS_PWD;
        Integer dbIndex = PropertiesManualContext.REDIS_DB_INDEX;
        int timeoutInt = parseTimeoutMillis(PropertiesManualContext.REDIS_CONNECT_TIMEOUT);

        JedisPoolConfig poolConfig = new JedisPoolConfig();
        poolConfig.setMaxTotal(1000);
        poolConfig.setMaxWaitMillis(3000);
        poolConfig.setMaxIdle(100);
        poolConfig.setMinIdle(0);
        poolConfig.setBlockWhenExhausted(true);

        // A blank password is passed as null so no AUTH is attempted. The configured DB index is
        // honoured with or without a password (it used to be ignored when no password was set).
        String password = StringUtils.isBlank(pwd) ? null : pwd;
        int database = dbIndex == null ? 0 : dbIndex;
        jedisPool = new JedisPool(poolConfig, proxyHost, port, timeoutInt, password, database);
    }

    private JedisUtil() {
    }

    /**
     * Parses a connect timeout such as {@code "2000"} or {@code "2000ms"} into milliseconds.
     *
     * @param timeout configured timeout; an optional trailing {@code "ms"} suffix is stripped
     * @return the timeout in milliseconds
     * @throws NumberFormatException if the remaining text is not an integer
     */
    private static int parseTimeoutMillis(String timeout) {
        String trimmed = timeout.trim();
        if (trimmed.endsWith("ms")) {
            trimmed = trimmed.substring(0, trimmed.length() - 2).trim();
        }
        return Integer.parseInt(trimmed);
    }

    /**
     * Tries to acquire the lock on {@code key} using the default expiry and wait timeouts from
     * {@code RedisContext}.
     *
     * @param key shared lock key
     * @return {@code true} if the lock was acquired
     */
    public static boolean tryLock(String key) {
        return tryLock(key, RedisContext.EXPIRE_TIMEOUT, RedisContext.WAIT_TIMEOUT);
    }

    /**
     * Redis distributed lock: retries until the lock is acquired or the wait timeout elapses.
     *
     * @param key               shared lock key
     * @param lockExpireTimeOut lock expiry in ms, set as the Redis key TTL so a crashed holder
     *                          cannot deadlock others
     * @param lockWaitTimeOut   how long to keep retrying, in ms; after this the attempt fails
     * @return {@code true} if the lock was acquired; {@code false} on timeout, interruption or
     *         Redis error
     */
    public static boolean tryLock(String key, long lockExpireTimeOut, long lockWaitTimeOut) {
        // Unique token per acquisition; unlock() compares it so only the holder can release.
        String value = Thread.currentThread().getName() + "_" + UUID.randomUUID().toString();
        try (Jedis jedis = jedisPool.getResource()) {
            long deadline = System.currentTimeMillis() + lockWaitTimeOut;
            SetParams setParams = new SetParams();
            // NX: only set the key if it does not already exist.
            setParams.nx();
            // PX: expiry in milliseconds.
            setParams.px(lockExpireTimeOut);
            while (true) {
                if ("OK".equals(jedis.set(key, value, setParams))) {
                    rememberLockToken(key, value);
                    return true;
                }
                if (deadline - System.currentTimeMillis() <= 0L) {
                    // Wait timeout elapsed without acquiring the lock.
                    return false;
                }
                try {
                    Thread.sleep(RedisContext.SLEEP_TIME);
                } catch (InterruptedException e) {
                    // Restore the interrupt flag and give up. Continuing would spin, because
                    // every later sleep() would throw immediately.
                    Thread.currentThread().interrupt();
                    logger.error("Thread sleep error when getting lock", e);
                    return false;
                }
            }
        } catch (Exception ex) {
            logger.error("lock error", ex);
        }
        return false;
    }

    /** Records the token for a lock the current thread has just acquired. */
    private static void rememberLockToken(String key, String value) {
        Map<String, String> held = LOCK_CONTEXT_JEDIS.get();
        if (held == null) {
            held = new HashMap<>();
            LOCK_CONTEXT_JEDIS.set(held);
        }
        held.put(key, value);
    }

    /**
     * Releases the lock on {@code key} if the current thread acquired it via {@code tryLock}.
     *
     * <p>The token is sent to {@code RedisContext.REDIS_UNLOCK_LUA}. NOTE(review): that script is
     * presumably "delete KEYS[1] only if its value equals ARGV[1]"; confirm in RedisContext.
     *
     * <p>The thread-local token is always dropped, even if the Redis call fails. In that case the
     * key still expires through its PX TTL.
     *
     * @param key shared lock key
     */
    public static void unlock(String key) {
        Map<String, String> held = LOCK_CONTEXT_JEDIS.get();
        String value = held == null ? null : held.remove(key);
        if (held != null && held.isEmpty()) {
            // Do not leave an empty map behind on pooled threads.
            LOCK_CONTEXT_JEDIS.remove();
        }
        if (value == null) {
            // This thread does not hold a lock on this key.
            return;
        }
        try (Jedis jedis = jedisPool.getResource()) {
            jedis.eval(RedisContext.REDIS_UNLOCK_LUA, Collections.singletonList(key), Collections.singletonList(value));
        } catch (Exception e) {
            logger.error("unlock error", e);
        }
    }

    /**
     * Serializes an object with Java serialization and encodes the bytes as a URL-encoded string.
     *
     * <p>The bytes are first mapped 1:1 to chars via ISO-8859-1, then URL-encoded as UTF-8.
     *
     * @param obj object to serialize (must be {@link Serializable})
     * @return the encoded string, or {@code null} if serialization failed (the error is logged)
     */
    public static String objectSerialiable(Object obj) {
        ByteArrayOutputStream buffer = new ByteArrayOutputStream();
        try (ObjectOutputStream out = new ObjectOutputStream(buffer)) {
            out.writeObject(obj);
            out.flush();
            String latin1 = new String(buffer.toByteArray(), StandardCharsets.ISO_8859_1);
            return java.net.URLEncoder.encode(latin1, "UTF-8");
        } catch (IOException e) {
            logger.error(e.getMessage(), e);
            return null;
        }
    }

    /**
     * Reverses {@link #objectSerialiable(Object)}: URL-decodes, then Java-deserializes.
     *
     * <p>SECURITY: this uses Java native deserialization. Only feed it data this application
     * wrote itself, because deserializing attacker-controlled bytes can execute code.
     *
     * @param serStr encoded string; may be {@code null}
     * @return the deserialized object, or {@code null} if {@code serStr} is null or decoding
     *         failed (the error is logged)
     */
    public static Object objectDeserialization(String serStr) {
        if (serStr == null) {
            return null;
        }
        try {
            String latin1 = java.net.URLDecoder.decode(serStr, "UTF-8");
            try (ObjectInputStream in = new ObjectInputStream(
                    new ByteArrayInputStream(latin1.getBytes(StandardCharsets.ISO_8859_1)))) {
                return in.readObject();
            }
        } catch (ClassNotFoundException | IOException e) {
            logger.error(e.getMessage(), e);
            return null;
        }
    }

    /**
     * Returns the string value at {@code k}.
     *
     * @return the value, or {@code null} if the key does not exist
     */
    public static String get(String k) {
        try (Jedis jedis = jedisPool.getResource()) {
            return jedis.get(k);
        }
    }

    /**
     * Sets a string value.
     *
     * @return the status reply from Redis
     */
    public static String set(String k, String v) {
        try (Jedis jedis = jedisPool.getResource()) {
            return jedis.set(k, v);
        }
    }

    /**
     * Serializes {@code obj} via {@link #objectSerialiable(Object)} and stores it at {@code k}.
     *
     * <p>If serialization fails, the serialized value is {@code null} and is passed to SET as-is.
     */
    public static String setSer(String k, Object obj) {
        return set(k, objectSerialiable(obj));
    }

    /**
     * Reads {@code k} and deserializes it via {@link #objectDeserialization(String)}.
     *
     * @return the object, or {@code null} if the key is missing or decoding fails
     */
    public static Object getDes(String k) {
        String v = get(k);
        return objectDeserialization(v);
    }

    /** Adds one or more members to the set at {@code key}. */
    public static void sadd(String key, String... value) {
        try (Jedis jedis = jedisPool.getResource()) {
            jedis.sadd(key, value);
        }
    }

    /** Returns all members of the set at {@code key}. */
    public static Set<String> smembers(String key) {
        try (Jedis jedis = jedisPool.getResource()) {
            return jedis.smembers(key);
        }
    }

    /** Removes the given members from the set at {@code key}. */
    public static void srem(String key, String... value) {
        try (Jedis jedis = jedisPool.getResource()) {
            jedis.srem(key, value);
        }
    }

    /** Returns whether {@code value} is a member of the set at {@code key}. */
    public static boolean sismember(String key, String value) {
        try (Jedis jedis = jedisPool.getResource()) {
            return jedis.sismember(key, value);
        }
    }

    /**
     * Adds {@code value} to the sorted set at {@code key}.
     *
     * @param score score used for ordering
     */
    public static void zadd(String key, String value, double score) {
        try (Jedis jedis = jedisPool.getResource()) {
            jedis.zadd(key, score, value);
        }
    }

    /**
     * Returns the sorted-set members between positions {@code start} and {@code end} (inclusive),
     * ordered by ascending score. 0 is the first element and -1 the last.
     */
    public static Set<String> zrange(String key, int start, int end) {
        try (Jedis jedis = jedisPool.getResource()) {
            return jedis.zrange(key, start, end);
        }
    }

    /**
     * Returns the sorted-set members between positions {@code start} and {@code end} (inclusive),
     * ordered by score from high to low.
     */
    public static Set<String> zrevrange(String key, int start, int end) {
        try (Jedis jedis = jedisPool.getResource()) {
            return jedis.zrevrange(key, start, end);
        }
    }

    /**
     * Sets the given field/value pairs on the hash at {@code key}; existing fields are overwritten.
     *
     * @param map field-to-value pairs
     * @return the status reply, "OK" on success
     */
    public static String hmset(String key, Map<String, String> map) {
        try (Jedis jedis = jedisPool.getResource()) {
            return jedis.hmset(key, map);
        }
    }

    /**
     * Sets a single field on the hash at {@code key}.
     *
     * @return the status reply, "OK" on success
     */
    public static String hmset(String key, String hashKey, String value) {
        Map<String, String> map = new HashMap<>(2);
        // Previously stored `key` here by mistake instead of `value`.
        map.put(hashKey, value);
        return hmset(key, map);
    }

    /**
     * Appends a value to the tail (right end) of the list at {@code key}.
     *
     * @return the list length after the push
     */
    public static long rpush(String key, String value) {
        try (Jedis jedis = jedisPool.getResource()) {
            return jedis.rpush(key, value);
        }
    }

    /**
     * Appends a binary value to the tail (right end) of the list at {@code key}.
     *
     * @return the list length after the push
     */
    private static long rpush(byte[] key, byte[] value) {
        try (Jedis jedis = jedisPool.getResource()) {
            return jedis.rpush(key, value);
        }
    }

    /**
     * Deletes {@code key}.
     *
     * @return the number of keys removed
     */
    public static long del(String key) {
        try (Jedis jedis = jedisPool.getResource()) {
            return jedis.del(key);
        }
    }

    /**
     * Removes members from the sorted set at {@code key}.
     *
     * @return the number of members removed (1 when a single member was removed)
     */
    public static long zrem(String key, String... value) {
        try (Jedis jedis = jedisPool.getResource()) {
            return jedis.zrem(key, value);
        }
    }

    /**
     * Stores {@code value} at {@code key} in database {@code dbIndex}, optionally with a TTL.
     *
     * <p>NOTE(review): SELECT changes the DB of the pooled connection. The pool factory is
     * expected to reselect the configured DB on the next borrow; confirm for the Jedis version in
     * use.
     *
     * @param expireTime TTL in seconds; applied only when positive
     */
    public static void saveValueByKey(int dbIndex, byte[] key, byte[] value, int expireTime) {
        try (Jedis jedis = jedisPool.getResource()) {
            jedis.select(dbIndex);
            if (expireTime > 0) {
                // SETEX stores the value and TTL atomically, so the key can never end up without
                // its TTL.
                jedis.setex(key, expireTime, value);
            } else {
                jedis.set(key, value);
            }
        }
    }

    /**
     * Reads the binary value at {@code key} from database {@code dbIndex}.
     *
     * @return the value, or {@code null} if absent
     */
    public static byte[] getValueByKey(int dbIndex, byte[] key) {
        try (Jedis jedis = jedisPool.getResource()) {
            jedis.select(dbIndex);
            return jedis.get(key);
        }
    }

    /**
     * Deletes {@code key} from database {@code dbIndex}.
     *
     * @param dbIndex database number
     */
    public static void deleteByKey(int dbIndex, byte[] key) {
        try (Jedis jedis = jedisPool.getResource()) {
            jedis.select(dbIndex);
            jedis.del(key);
        }
    }

    /** Returns the number of members in the sorted set at {@code key}. */
    public static long zcard(String key) {
        try (Jedis jedis = jedisPool.getResource()) {
            return jedis.zcard(key);
        }
    }

    /** Returns whether {@code key} exists. */
    public static boolean exists(String key) {
        try (Jedis jedis = jedisPool.getResource()) {
            return jedis.exists(key);
        }
    }

    /**
     * Renames {@code oldKey} to {@code newKey}.
     *
     * @return the status reply from Redis
     */
    public static String rename(String oldKey, String newKey) {
        try (Jedis jedis = jedisPool.getResource()) {
            return jedis.rename(oldKey, newKey);
        }
    }

    /**
     * Sets a TTL on {@code key}.
     *
     * @param seconds TTL in seconds
     */
    public static void expire(String key, int seconds) {
        try (Jedis jedis = jedisPool.getResource()) {
            jedis.expire(key, seconds);
        }
    }

    /** Removes the TTL from {@code key}. */
    public static void persist(String key) {
        try (Jedis jedis = jedisPool.getResource()) {
            jedis.persist(key);
        }
    }

    /**
     * Sets {@code key} to {@code value} only if it does not exist yet, then applies the TTL.
     *
     * <p>For a positive TTL this is one atomic SET NX EX, so a crash cannot leave the key without
     * a TTL. For a non-positive TTL the legacy SETNX + EXPIRE sequence is kept (Redis deletes a
     * key given a non-positive EXPIRE).
     *
     * @param timeOut TTL in seconds
     */
    public static void setnxWithTimeOut(String key, String value, int timeOut) {
        try (Jedis jedis = jedisPool.getResource()) {
            if (timeOut > 0) {
                SetParams params = new SetParams();
                params.nx();
                params.ex(timeOut);
                jedis.set(key, value, params);
            } else if (0 != jedis.setnx(key, value)) {
                jedis.expire(key, timeOut);
            }
        }
    }

    /**
     * Increments the integer at {@code key} by one.
     *
     * @return the value after the increment
     */
    public static long incr(String key) {
        try (Jedis jedis = jedisPool.getResource()) {
            return jedis.incr(key);
        }
    }

    /**
     * Returns the Redis server's current time, in seconds, via a Lua call to TIME.
     * This may fail on codis.
     *
     * @return seconds since the epoch, or 0 if the reply was empty or not a list
     */
    public static long currentTimeSecond() {
        try (Jedis jedis = jedisPool.getResource()) {
            Object reply = jedis.eval("return redis.call('TIME')", 0);
            if (reply instanceof List && !((List<?>) reply).isEmpty()) {
                return Long.parseLong(String.valueOf(((List<?>) reply).get(0)));
            }
        }
        return 0L;
    }

}

Logo

鸿蒙生态一站式服务平台。

更多推荐