利用 ZooKeeper 实现分布式锁的方式有很多,本文只是其中的一种,这里不再过多介绍原理,只简单说一句:高并发下,每个请求都在锁的根节点下创建一个临时顺序节点,然后对所有子节点排序,序号最小的节点获得锁,其余节点各自监听排在自己前面的那个节点;持有锁的请求执行完成后删除自己的节点,ZooKeeper 随即通知监听该节点的后一个节点,让它获得锁,从而实现按顺序执行。

pom依赖:

<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>zk.lock</groupId>
    <artifactId>zkLock</artifactId>
    <version>1.0-SNAPSHOT</version>
    <name>zkLock</name>
    <packaging>jar</packaging>

    <properties>
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <zkclient.version>0.10</zkclient.version>
    </properties>

    <dependencies>
        <!-- ZooKeeper client used by the lock and serializer classes.
             NOTE(review): the Java code also imports org.apache.log4j directly, but log4j is
             not declared here; it presumably arrives transitively through zkclient. Consider
             declaring it explicitly so the build does not depend on that. -->
        <dependency>
            <groupId>com.101tec</groupId>
            <artifactId>zkclient</artifactId>
            <version>${zkclient.version}</version>
        </dependency>
    </dependencies>

    <build>
        <!-- NOTE(review): the jar name does not match the artifactId (zkLock); confirm this is intended. -->
        <finalName>ETSSchedule</finalName>
        <plugins>
            <plugin>
                <groupId>org.apache.maven.plugins</groupId>
                <artifactId>maven-compiler-plugin</artifactId>
                <!-- Pin the plugin version: builds stay reproducible and Maven no longer warns
                     that build.plugins.plugin.version is missing. -->
                <version>3.8.1</version>
                <configuration>
                    <source>1.8</source>
                    <target>1.8</target>
                </configuration>
            </plugin>
        </plugins>
    </build>

</project>

分布式锁的实现类:

package com.zookeeper.lock;

import org.I0Itec.zkclient.IZkDataListener;
import org.I0Itec.zkclient.ZkClient;
import org.apache.log4j.Logger;

import java.util.Collections;
import java.util.List;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.locks.Condition;
import java.util.concurrent.locks.Lock;

/**
 * @author: alex
 * @Date: 2019/4/8
 * @Description: Distributed lock backed by ZooKeeper.
 *
 * <p>Every acquiring thread creates an ephemeral sequential child node under
 * {@link Constant#LOCKPATH}. The thread whose node sorts first holds the lock. Every other
 * thread watches the node directly ahead of its own and re-checks when that node is deleted,
 * so waiters are served in node-creation order.
 *
 * <p>Node paths are kept in {@link ThreadLocal}s, so one instance can be shared by the threads
 * of a JVM. A thread that already holds the lock gets it again immediately, but a single
 * {@link #unlock()} releases it (there is no hold count).
 */
public class ZookeeperDistributedLock implements Lock, AutoCloseable {

    protected static Logger logger = Logger.getLogger(ZookeeperDistributedLock.class);

    /** Connect string used by the no-argument constructor. */
    private static final String DEFAULT_HOSTS = "127.0.0.1:2181";

    /** Passed as a timeout to mean "wait without a deadline". */
    private static final long NO_TIMEOUT = -1L;

    /** ZooKeeper connect string the client was created with. */
    private String hosts;

    /** ZooKeeper client shared by every thread using this instance. */
    private final ZkClient client;

    /** Full path of the calling thread's own lock node; null while the thread is not queued. */
    private final ThreadLocal<String> currentPath = new ThreadLocal<>();

    /** Full path of the node directly ahead of the calling thread's node. */
    private final ThreadLocal<String> beforePath = new ThreadLocal<>();

    /**
     * Connects to the default ZooKeeper address {@code 127.0.0.1:2181}.
     */
    public ZookeeperDistributedLock() {
        this(DEFAULT_HOSTS);
    }

    /**
     * Connects to the given ZooKeeper servers and makes sure the lock root node exists.
     *
     * @param hosts ZooKeeper connect string, e.g. {@code "host1:2181,host2:2181"}
     */
    public ZookeeperDistributedLock(String hosts) {
        this.hosts = hosts;
        this.client = new ZkClient(hosts);
        this.client.setZkSerializer(new MyZkSerializer());
        if (!this.client.exists(Constant.LOCKPATH)) {
            try {
                // createParents=true also tolerates another process creating the root first
                this.client.createPersistent(Constant.LOCKPATH, true);
            } catch (RuntimeException e) {
                logger.error("ZkClient create root node failed...", e);
            }
        }
    }

    /**
     * Acquires the lock, waiting as long as necessary. An interrupt does not abort the wait;
     * the interrupt status is re-asserted once the lock is held.
     */
    @Override
    public void lock() {
        try {
            acquire(false, NO_TIMEOUT);
        } catch (InterruptedException e) {
            // Cannot happen: non-interruptible acquisition never rethrows.
            throw new AssertionError(e);
        }
    }

    /**
     * Acquires the lock unless the current thread is interrupted. On interruption the thread's
     * queued node is removed, so it does not block other waiters.
     *
     * @throws InterruptedException if the thread is interrupted before or while waiting
     */
    @Override
    public void lockInterruptibly() throws InterruptedException {
        acquire(true, NO_TIMEOUT);
    }

    /**
     * Acquires the lock only if it is free right now. On failure the thread's node is removed
     * again, so a failed attempt leaves nothing behind in the queue.
     *
     * @return true if the lock was acquired
     */
    @Override
    public boolean tryLock() {
        try {
            return acquire(false, 0L);
        } catch (InterruptedException e) {
            // Cannot happen: non-interruptible acquisition never rethrows.
            throw new AssertionError(e);
        }
    }

    /**
     * Acquires the lock if it becomes free within the given time. A non-positive time does not wait.
     *
     * @param time maximum time to wait
     * @param unit unit of {@code time}
     * @return true if the lock was acquired, false if the time elapsed first
     * @throws InterruptedException if the thread is interrupted before or while waiting
     */
    @Override
    public boolean tryLock(long time, TimeUnit unit) throws InterruptedException {
        return acquire(true, Math.max(0L, unit.toNanos(time)));
    }

    /**
     * Releases the lock by deleting the calling thread's node and clearing its thread-local state.
     *
     * @throws IllegalMonitorStateException if the calling thread does not hold the lock
     */
    @Override
    public void unlock() {
        if (this.currentPath.get() == null) {
            throw new IllegalMonitorStateException("Current thread does not hold the lock");
        }
        withdraw();
    }

    /**
     * Conditions are not supported by this lock.
     *
     * @throws UnsupportedOperationException always
     */
    @Override
    public Condition newCondition() {
        throw new UnsupportedOperationException("ZookeeperDistributedLock does not support conditions");
    }

    /**
     * Closes the underlying ZooKeeper client. The instance must not be used afterwards.
     */
    @Override
    public void close() {
        this.client.close();
    }

    public String getHosts() {
        return hosts;
    }

    /**
     * Updates the value reported by {@link #getHosts()} only. The client created in the
     * constructor keeps its connection; use {@link #ZookeeperDistributedLock(String)} to
     * connect elsewhere.
     */
    public void setHosts(String hosts) {
        this.hosts = hosts;
    }

    /**
     * Shared acquisition loop behind every public locking method.
     *
     * @param interruptible whether an interrupt aborts the wait with {@link InterruptedException}.
     *                      Otherwise the interrupt is remembered and re-asserted before returning.
     * @param timeoutNanos  maximum time to wait, or {@link #NO_TIMEOUT} to wait indefinitely
     * @return true once the lock is held, false if the timeout elapsed first
     * @throws InterruptedException only when {@code interruptible} is true
     */
    private boolean acquire(boolean interruptible, long timeoutNanos) throws InterruptedException {
        if (interruptible && Thread.interrupted()) {
            throw new InterruptedException();
        }
        final boolean timed = timeoutNanos != NO_TIMEOUT;
        final long deadline = timed ? System.nanoTime() + timeoutNanos : 0L;
        // A thread re-locking a lock it already holds must not lose that node on failure.
        final boolean alreadyHeld = this.currentPath.get() != null;
        boolean acquired = false;
        boolean interrupted = false;
        try {
            for (;;) {
                if (attemptAcquire()) {
                    acquired = true;
                    return true;
                }
                long remaining = NO_TIMEOUT;
                if (timed) {
                    remaining = deadline - System.nanoTime();
                    if (remaining <= 0L) {
                        return false;
                    }
                }
                try {
                    awaitPredecessor(remaining);
                } catch (InterruptedException e) {
                    if (interruptible) {
                        throw e;
                    }
                    // lock()/tryLock() must not abort: remember the interrupt and keep waiting.
                    interrupted = true;
                }
            }
        } finally {
            if (!acquired && !alreadyHeld) {
                // Leave the queue so later waiters are not blocked behind an abandoned node.
                try {
                    withdraw();
                } catch (RuntimeException cleanupFailure) {
                    logger.error("Failed to remove lock node after an unsuccessful acquire", cleanupFailure);
                }
            }
            if (interrupted) {
                Thread.currentThread().interrupt();
            }
        }
    }

    /**
     * Makes sure the calling thread has a node in the queue, then checks whether that node is
     * first. When it is not, records the node directly ahead of it in {@link #beforePath}.
     *
     * @return true if the calling thread's node is the smallest child, i.e. it holds the lock
     */
    private boolean attemptAcquire() {
        while (true) {
            String myPath = this.currentPath.get();
            if (myPath == null) {
                myPath = this.client.createEphemeralSequential(Constant.LOCKPATH + Constant.SEPARATOR, "data");
                this.currentPath.set(myPath);
            }
            List<String> children = this.client.getChildren(Constant.LOCKPATH);
            // Sequential node names carry a zero-padded counter, so lexical order is creation order.
            Collections.sort(children);
            String myName = myPath.substring(Constant.LOCKPATH.length() + Constant.SEPARATOR.length());
            int index = children.indexOf(myName);
            if (index == 0) {
                this.beforePath.remove();
                return true;
            }
            if (index > 0) {
                this.beforePath.set(Constant.LOCKPATH + Constant.SEPARATOR + children.get(index - 1));
                return false;
            }
            // Our node is no longer listed (e.g. the session expired and the ephemeral node was
            // removed): enqueue a fresh node and check again.
            logger.warn("Lock node " + myPath + " disappeared, re-creating it...");
            this.currentPath.remove();
        }
    }

    /**
     * Blocks until the node recorded in {@link #beforePath} is deleted or the timeout elapses.
     *
     * @param timeoutNanos maximum wait, or {@link #NO_TIMEOUT} to wait indefinitely
     * @return true if the predecessor is gone, false if the timeout elapsed first
     * @throws InterruptedException if the calling thread is interrupted while waiting
     */
    private boolean awaitPredecessor(long timeoutNanos) throws InterruptedException {
        final String predecessor = this.beforePath.get();
        final CountDownLatch deleted = new CountDownLatch(1);
        IZkDataListener listener = new IZkDataListener() {
            @Override
            public void handleDataChange(String dataPath, Object data) {
            }

            @Override
            public void handleDataDeleted(String dataPath) {
                logger.info("Ephemeral node has been deleted....");
                deleted.countDown();
            }
        };
        // Subscribe before checking existence so a deletion in between is not missed.
        this.client.subscribeDataChanges(predecessor, listener);
        try {
            if (!this.client.exists(predecessor)) {
                return true;
            }
            if (timeoutNanos == NO_TIMEOUT) {
                deleted.await();
                return true;
            }
            return deleted.await(timeoutNanos, TimeUnit.NANOSECONDS);
        } finally {
            this.client.unsubscribeDataChanges(predecessor, listener);
        }
    }

    /**
     * Clears the calling thread's node paths and deletes its node, if it has one.
     */
    private void withdraw() {
        String path = this.currentPath.get();
        this.currentPath.remove();
        this.beforePath.remove();
        if (path != null) {
            this.client.delete(path);
        }
    }
}

序列化类:

package com.zookeeper.lock;

import org.I0Itec.zkclient.exception.ZkMarshallingError;
import org.I0Itec.zkclient.serialize.ZkSerializer;
import org.apache.log4j.Logger;

import java.io.UnsupportedEncodingException;
import java.nio.charset.StandardCharsets;

/**
 * @author: alex
 * @Date: 2019/4/8
 * @Description: 序列化
 */
public class MyZkSerializer implements ZkSerializer {

    protected static Logger logger = Logger.getLogger(MyZkSerializer.class);

    /**
     * 反序列化
     * @param bytes 字节数组
     * @return 实体
     * @throws ZkMarshallingError
     */
    public Object deserialize(byte[] bytes) throws ZkMarshallingError {
        try {
            return new String(bytes, "UTF-8");
        } catch (UnsupportedEncodingException e) {
            logger.error("MyZkSerializer deserialize happened unsupportedEncodingException...");
            logger.error(e);
            throw new ZkMarshallingError(e);
        }
    }

    /**
     * 序列化
     * @param obj 实体
     * @return 字节数组
     * @throws ZkMarshallingError
     */
    public byte[] serialize(Object obj) throws ZkMarshallingError {
        try {
            return String.valueOf(obj).getBytes("UTF-8");
        } catch (UnsupportedEncodingException e) {
            logger.error("MyZkSerializer serialize happened unsupportedEncodingException...");
            logger.error(e);
            throw new ZkMarshallingError(e);
        }
    }
}

常量类:

package com.zookeeper.lock;

/**
 * @author: alex
 * @Date: 2019/4/9
 * @Description: Constants shared by the ZooKeeper lock classes.
 */
public final class Constant {

    /** ZooKeeper path separator. */
    public static final String SEPARATOR = "/";

    /** Persistent parent node under which the per-request lock nodes are created. */
    public static final String LOCKPATH = "/zk-lock";

    private Constant() {
        // Constants holder; not meant to be instantiated.
    }
}

测试类:

package com.zookeeper.lock.test;

import com.zookeeper.lock.ZookeeperDistributedLock;

/**
 * @author: alex
 * @Date: 2019/4/9
 * @Description: Demo business code whose critical section is guarded by the ZooKeeper lock.
 */
public class DemoService {

    /** Shared counter; incremented only while the distributed lock is held. */
    private static int count = 0;

    /**
     * One lock instance for all calls. The lock keeps its per-thread node paths in ThreadLocals,
     * so threads can share it. Creating a new lock per call would open a ZooKeeper client that is
     * never closed.
     */
    private static final ZookeeperDistributedLock LOCK = new ZookeeperDistributedLock();

    /**
     * Sleeps one second, increments the counter and prints a greeting while holding the lock.
     *
     * @param name name included in the greeting
     */
    public void sayHello(String name) {
        // Acquire before entering try: if lock() fails there is nothing to release, and
        // unlock() must not run for a lock this thread does not hold.
        LOCK.lock();
        try {
            try {
                Thread.sleep(1000);
            } catch (InterruptedException e) {
                // Demo keeps going, but the interrupt status is preserved for the caller.
                Thread.currentThread().interrupt();
            }
            count++;
            System.out.println(Thread.currentThread().getName() + " say hello to " + name + "_" + count);
        } finally {
            LOCK.unlock();
        }
    }
}

模拟高并发请求:使用 CyclicBarrier 阻塞住所有线程,等它们全部就绪后再同时执行。

package com.zookeeper.lock.test;

import java.util.concurrent.BrokenBarrierException;
import java.util.concurrent.CyclicBarrier;

/**
 * @author: alex
 * @Date: 2019/4/9
 * @Description: Starts several threads that call DemoService at the same moment to simulate
 * concurrent requests.
 */
public class DemoThread {

    /** Number of simulated requests; also the barrier's party count, so the two cannot diverge. */
    private static final int THREAD_COUNT = 10;

    /** Holds every worker until all of them are ready, then releases them together. */
    static final CyclicBarrier cyclicBarrier = new CyclicBarrier(THREAD_COUNT);

    static class DemoRun implements Runnable {

        private final int i;

        public DemoRun(int i) {
            this.i = i;
        }

        @Override
        public void run() {
            try {
                DemoService demoService = new DemoService();
                cyclicBarrier.await();
                demoService.sayHello("name_" + i);
            } catch (InterruptedException e) {
                // Restore the interrupt status before reporting.
                Thread.currentThread().interrupt();
                e.printStackTrace();
            } catch (BrokenBarrierException e) {
                e.printStackTrace();
            }
        }
    }

    public static void main(String[] args) {
        for (int i = 0; i < THREAD_COUNT; i++) {
            new Thread(new DemoRun(i)).start();
        }
    }
}

 

Logo

权威|前沿|技术|干货|国内首个API全生命周期开发者社区

更多推荐