前言

        该JBDC工具类实现了基本的增、删、改,查的操作,这并不是最主要的,主要的是实现这个工具采用的技术和思想,例如,注解的创建,使用,以及代码自动加载被注解的实体类对象,还有泛型的使用,模仿Mybatis的底层原理实现。

实现流程

创建注解

Table

用来标注实体类的对应的表

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

@Target(ElementType.TYPE)
@Retention(RetentionPolicy.RUNTIME)
public @interface Table {
    String value() default "";
}

Column

用来标注属性对应的表的字段

import java.lang.annotation.ElementType;
import java.lang.annotation.Retention;
import java.lang.annotation.RetentionPolicy;
import java.lang.annotation.Target;

@Target(ElementType.FIELD)
@Retention(RetentionPolicy.RUNTIME)
public @interface Column {

    String value() default "";
}

创建实体类

1、创建一个实体类来测试,随便找了一个数据表进行测试,数据表结构如图

2、根据数据表创建实体类,可以看到有些字段注解,有些不注解,继续看下去就知道了

import com.cw.jdbc.mysql.util.api.Column;
import com.cw.jdbc.mysql.util.api.Table;


@Table(value = "reader")
public class User {

    private Integer r_id;  // 不注解

    @Column( value = "r_name")
    private String name; // 注解赋值

    @Column
    private String r_sex; // 注解不赋值

    @Column(value = "r_age")
    private Integer age;


    public int getR_id() {
        return r_id;
    }

    public void setR_id(int r_id) {
        this.r_id = r_id;
    }

    public String getName() {
        return name;
    }

    public void setName(String name) {
        this.name = name;
    }

    public String getR_sex() {
        return r_sex;
    }

    public void setR_sex(String r_sex) {
        this.r_sex = r_sex;
    }

    public int getAge() {
        return age;
    }

    public void setAge(int age) {
        this.age = age;
    }

    @Override
    public String toString() {
        return "com.cw.jdbc.mysql.bean.User{" +
                "age=" + age +
                ", name='" + name + '\'' +
                ", r_id=" + r_id +
                ", r_sex='" + r_sex + '\'' +
                '}';
    }
}

编写工具类

仔细看代码,相关的解释都写在代码上了

import com.cw.jdbc.mysql.util.api.Column;
import com.cw.jdbc.mysql.util.api.Table;

import java.lang.reflect.Field;
import java.sql.*;
import java.util.ArrayList;
import java.util.List;

public class JDBCUtil {

    // 相当一个线程池连接,避免数据库连接频繁的连接和关闭
    private static ThreadLocal<Connection> conHolder = new ThreadLocal<Connection>();

    // 连接数据库的相关信息,这里也可以将这些信息写到配置文件中,然后读取
    private static final String DRIVER = "com.mysql.jdbc.Driver";
    private static final String URL = "jdbc:mysql://localhost:3306/ebook?useUnicode=true&characterEncoding=UTF-8&useSSL=false";
    private static final String USERNAME = "root";
    private static final String PASSWORD = "123456";

    // 注册驱动,static在调用这个类时使用一次,也就使用一次
    static {
        try {
            Class.forName(DRIVER);
        } catch (Exception e) {
            e.printStackTrace();
        }
    }


    // 获取连接的方法
    private static Connection getConnection() {
        try {
            Connection conn = conHolder.get();
            // 如果线程池中没有连接,就创建连接,存放到线程池中,否则直接返回即可
            if (conn == null) {
                conn = DriverManager.getConnection(URL, USERNAME, PASSWORD);
                conHolder.set(conn);
            }
            return conn;
        } catch (SQLException e) {
            throw new RuntimeException(e);
        }
    }

    // 关闭连接
    public static void close() throws SQLException {
        Connection conn = getConnection();
        if (conn != null) {
            conn.close();
            conHolder.remove();
        }
    }

    /**
     *  查询操作
     * @param object 如果该对象中的所有属性值都为null,做全表查询,否则条件查询
     * @return
     * @throws SQLException
     * @throws IllegalAccessException
     * @throws NoSuchFieldException
     * @throws InstantiationException
     */
    public static List query(Object object) throws SQLException, IllegalAccessException, NoSuchFieldException, InstantiationException {

        // 获取数据库连接
        Connection conn = getConnection();

        Class clazz = object.getClass();
        // 获取类的表的注解值
        Table table = (Table)clazz.getAnnotation(Table.class);
        String tableName = "";
        // 如果实体类没有注解则默认使用类名
        if (table == null){
            tableName = clazz.getName();
        }else{
            // 获取表名
            tableName = table.value();
        }

        String sql = "";

        // 获取类的所有属性
        Field[] fields = clazz.getDeclaredFields();
        // 定义条件查询参数语句
        String queryParam = "";
        // 条件查询参数对应的参数值
        ArrayList<Object> paramList = new ArrayList<>();
        String columnName = "";
        for (Field field:fields){
            field.setAccessible(true);
            Object value = field.get(object);

            // 属性值不为空的属性作为条件进行查询
            if (value != null){
                Column column = field.getAnnotation(Column.class);
                // 根据注解的值作为表中的字段名,如果属性没有注解,就默认属性名作为字段名
                // 因此此时就需要属性名与表的字段名必须一致
                if ( column != null){
                    columnName = column.value();
                }else{
                    columnName = field.getName();
                }
                // 拼接查询参数语句
                queryParam = queryParam + " " + columnName + " = ? " + "and";
                paramList.add(value);
            }
        }

        // 判断是条件查询还是全表扫描
        if (queryParam.equals("")){
            sql = "select * from " + tableName;
        }else{
            String temp = "select * from " + tableName + " where" + queryParam;
            // 将最后的" and"删除
            sql = temp.substring(0,temp.length()-3);
        }
        PreparedStatement pstmt = conn.prepareStatement(sql);
        for (int i = 0; i < paramList.size(); i++){
            pstmt.setObject(i+1,paramList.get(i));
        }
        ResultSet resultSet = pstmt.executeQuery();
        return handler(resultSet,pstmt,clazz);
    }

    /**
     * 对查询返回的数据集进行自动封装到Bean,返回的是一个bean集合
     *
     * @param resultSet 结果集参数
     * @param clazz 类的类型
     * @return
     */
    private static <T> List<T> handler(ResultSet resultSet,PreparedStatement pstmt, Class<T> clazz) throws SQLException, IllegalAccessException, InstantiationException, NoSuchFieldException {

        List<T> beanList = new ArrayList<T>();
        while (resultSet.next()) {
            // 创建一个Bean对象
            T bean = clazz.newInstance();

            // 获取类中的所有属性
            Field[] fields = clazz.getDeclaredFields();

            for (Field field : fields) {
                Column column = (Column) field.getAnnotation(Column.class);

                // 获取属性中的注解中的对应表的字段名
                String columnName = "";
                // 如果没有使用注解或者注解中没有规定字段名,默认取属性名
                if (column == null || column.value().equals("")){
                    columnName = field.getName();
                }else{
                    columnName = column.value();
                }
                // 根据字段名获取值
                Object value = resultSet.getObject(columnName);
                // 私有的属性不能直接访问
                field.setAccessible(true);
                // 给bean中的属性赋值
                field.set(bean, value);
            }
            beanList.add(bean);
        }
        pstmt.close();
        return beanList;
    }


    /**
     *  更新操作
     * @param oldObject 锁定需要更新的行(需要更新的数据)
     * @param newObject 更新的数据
     * @throws SQLException
     * @throws IllegalAccessException
     */
    public static void update(Object oldObject, Object newObject) throws SQLException, IllegalAccessException {

        Class clazz = newObject.getClass();

        Field[] fields = clazz.getDeclaredFields();

        // 获取表名
        Table table = (Table)clazz.getAnnotation(Table.class);

        String tableName = "";
        // 如果类没有被注解或者注解没有被赋值,默认取类名作为表名
        if (table == null || table.value().equals("")){
            tableName = clazz.getName();
        }else{
            tableName = table.value();
        }
        String setParams = "";
        String queryParams = "";
        String columnName = "";
        ArrayList<Object> params = new ArrayList<Object>();
        // 更新的字段,属性中
        for (Field field: fields){
            field.setAccessible(true);
            Object newValue = field.get(newObject);

            if (newValue != null){
                Column column = field.getAnnotation(Column.class);
                if (column == null || column.value().equals("")){
                    columnName = field.getName();
                }else{
                    columnName = column.value();
                }
                setParams = setParams + columnName + " = ?" + ",";
                params.add(newValue);
            }
        }
        // 将最后的","删除
        setParams = setParams.substring(0,setParams.length()-1);
        // 用来条件查询需要更新的行的字段
        for (Field field: fields){
            field.setAccessible(true);
            Object oldValue = field.get(oldObject);

            if (oldValue != null){
                Column column = field.getAnnotation(Column.class);
                if (column == null || column.value().equals("")){
                    columnName = field.getName();
                }else{
                    columnName = column.value();
                }
                queryParams = queryParams + " " + columnName + " = ? " + "and";
                params.add(oldValue);
            }
        }

        // 将最后的"and"删除
        queryParams = queryParams.substring(0,queryParams.length()-3);

        String sql = "update " + tableName + " set " + setParams + " where" + queryParams;
        exit(sql, params);
    }


    /**
     *  插入操作
     * @param object 插入的对象数据
     * @throws SQLException
     * @throws IllegalAccessException
     */
    public static void insert(Object object) throws SQLException, IllegalAccessException {
        Class clazz = object.getClass();
        // 获取表名
        String tableName = "";
        Table table = (Table) clazz.getAnnotation(Table.class);
        if(table == null || table.value().equals("")){
            tableName = clazz.getName();
        }else{
            tableName = table.value();
        }

        Field [] fields = clazz.getDeclaredFields();
        String sql = "";
        String columnName = "";
        String insertParams = "";
        String valueParams = "";
        ArrayList<Object> params = new ArrayList<>();
        for (Field field:fields){
            field.setAccessible(true);
            Object value = field.get(object);
            if (value != null){
                Column column = field.getAnnotation(Column.class);
                if (column == null || column.value().equals("")){
                    columnName = field.getName();
                }else{
                    columnName = column.value();
                }

                insertParams = insertParams + columnName + ",";
                valueParams = valueParams + "?,";
                params.add(value);
            }
        }
        // 删除最后的","
        insertParams = insertParams.substring(0,insertParams.length()-1);
        valueParams = valueParams.substring(0,valueParams.length()-1);
        sql = "insert into " + tableName + "(" + insertParams + ") values(" + valueParams + ")";
        exit(sql, params);
    }

    /**
     * 删除操作
     * @param object 删除的对象
     * @throws SQLException
     * @throws IllegalAccessException
     */
    public static void delete(Object object) throws SQLException, IllegalAccessException {

        Class clazz = object.getClass();
        // 获取表名
        String tableName = "";
        Table table = (Table) clazz.getAnnotation(Table.class);
        if(table == null || table.value().equals("")){
            tableName = clazz.getName();
        }else{
            tableName = table.value();
        }

        Field [] fields = clazz.getDeclaredFields();
        String sql = "";
        String columnName = "";
        String queryParams = "";
        ArrayList<Object> params = new ArrayList<>();
        for (Field field:fields){
            field.setAccessible(true);
            Object value = field.get(object);
            if (value != null){
                Column column = field.getAnnotation(Column.class);
                if (column == null || column.value().equals("")){
                    columnName = field.getName();
                }else{
                    columnName = column.value();
                }

                queryParams = queryParams + " " + columnName + " =? " + "and";
                params.add(value);
            }
        }
        // 删除最后的"and"
        queryParams = queryParams.substring(0,queryParams.length()-3);
        sql = "delete from " + tableName + " where" + queryParams;
        exit(sql, params);
    }

    /**
     * 提取写操作(增,删,改)的共同操作
     *
     * @param sql
     * @param params
     * @throws SQLException
     */
    private static void exit(String sql,ArrayList<Object> params) throws SQLException {

        Connection conn = getConnection();

        PreparedStatement pstmt = conn.prepareStatement(sql);

        for (int i = 0; i < params.size(); i++){
            pstmt.setObject(i+1,params.get(i));
        }
        pstmt.executeUpdate();
        pstmt.close();
    }
}

测试

1、全表查询

public static void main(String[] args) throws SQLException, NoSuchFieldException, InstantiationException, IllegalAccessException {

        // 全表查询
        User user1 = new User();

        List query = JDBCUtil.query(user1);

        query.forEach( r->{
            System.out.println(r.toString());
        } );

        // 关闭连接
        JDBCUtil.close();
    }

 2、条件查询

public static void main(String[] args) throws SQLException, NoSuchFieldException, InstantiationException, IllegalAccessException {

        // 条件查询
        User user2 = new User();
        user2.setName("李白");
        List query = JDBCUtil.query(user2);
        query.forEach( r->{
            System.out.println(r.toString());
        } );

        // 关闭连接
        JDBCUtil.close();
    }

 

Logo

助力广东及东莞地区开发者,代码托管、在线学习与竞赛、技术交流与分享、资源共享、职业发展,成为松山湖开发者首选的工作与学习平台

更多推荐