前言

前面使用了 gRPC 进行客户端和服务端之间的数据传输。客户端每次请求前都要先 Dial，用完之后立即 Close，下一次请求进来又要重新 Dial，这样资源消耗十分严重。于是我把 rfyiamcool 写的连接池稍作修改，实现了连接的复用。

先上对比

（此处原为两张对比截图，图片已缺失；对比数据见下方两次基准测试的输出。）

未使用连接池：
go test -bench=. -run=none
goos: linux
goarch: amd64
pkg: client
BenchmarkRpc 	    1462	    805166 ns/op
PASS
ok  	client	1.267s
使用连接池：
go test -bench=. -run=none
goos: linux
goarch: amd64
pkg: client
BenchmarkRpc 	    5713	    220720 ns/op
PASS
ok  	client	2.272s

可以看到，每次调用的平均耗时从约 805 µs 降到约 221 µs，执行速度提升了约 3.6 倍（接近 4 倍）。

代码

ServiceClientPool 中的 clients 是一个以服务名（service name）为键、*ClientPool 为值的 map。同一个 target 下的所有 service 共用同一个 ClientPool 来维护连接，一个 ClientPool 里又有多个 conn 可以使用：每次取连接时，用原子自增的计数器对连接数取余，以轮询（round-robin）的方式选用 conn，取余也保证了下标不会越界。

package common

import (
	"context"
	"errors"
	"google.golang.org/grpc"
	"google.golang.org/grpc/connectivity"
	"google.golang.org/grpc/keepalive"
	"google.golang.org/grpc/metadata"
	"strings"
	"sync"
	"sync/atomic"
	"time"
)

// Sentinel errors: ErrNotFoundClient is returned when no ClientPool is
// registered for a service name; ErrConnShutdown is reported by checkState for
// connections in TransientFailure or Shutdown.
var (
	ErrNotFoundClient = errors.New("not found grpc conn")
	ErrConnShutdown   = errors.New("grpc conn shutdown")
)

// Defaults that NewClientPoolWithOption substitutes for option fields that are
// zero or negative.
var (
	defaultClientPoolConnsSizeCap = 5
	defaultDialTimeout            = 5 * time.Second
	defaultKeepAlive              = 30 * time.Second
	defaultKeepAliveTimeout       = 10 * time.Second
)

// ClientOption configures a ClientPool. Any field that is zero or negative is
// replaced with a package default by NewClientPoolWithOption.
//
//   - ClientPoolConnsSize: number of grpc.ClientConn slots in each pool.
//   - DialTimeOut: deadline for each blocking dial.
//   - KeepAlive / KeepAliveTimeout: passed to keepalive.ClientParameters as
//     Time and Timeout respectively.
type ClientOption struct {
	ClientPoolConnsSize int
	DialTimeOut         time.Duration
	KeepAlive           time.Duration
	KeepAliveTimeout    time.Duration
}


// ClientPool holds a fixed number of grpc.ClientConn slots for a single dial
// target. Slots start empty, are filled lazily by getConn, and are chosen in
// round-robin order using the next counter.
type ClientPool struct {
	target string
	option *ClientOption
	next   int64 // round-robin counter, advanced with sync/atomic
	cap    int64 // number of slots in conns

	// The embedded mutex guards updates to conns.
	sync.Mutex

	conns []*grpc.ClientConn
}

// getConn returns a usable connection from the pool. Slots are picked in
// round-robin order via the atomic counter cc.next. If the chosen slot is
// empty, or its connection is in TransientFailure/Shutdown (see checkState),
// a new connection is dialed and stored in that slot.
//
// Concurrency notes:
//   - conns is only read or written while holding the pool mutex, so slot
//     access is race-free.
//   - The blocking dial runs without the mutex held, so callers whose slot is
//     healthy are never stuck behind another caller's dial.
//   - If several callers repair the same slot at once, the first connection
//     installed wins and the others close their extra connection instead of
//     leaking it.
func (cc *ClientPool) getConn() (*grpc.ClientConn, error) {
	// Take the remainder on the unsigned value so idx can never be negative,
	// even if the signed counter eventually wraps around.
	idx := uint64(atomic.AddInt64(&cc.next, 1)) % uint64(cc.cap)

	cc.Lock()
	conn := cc.conns[idx]
	cc.Unlock()
	if conn != nil && cc.checkState(conn) == nil {
		return conn, nil
	}

	// Dial outside the lock; connect blocks for up to DialTimeOut.
	fresh, err := cc.connect()
	if err != nil {
		return nil, err
	}

	cc.Lock()
	defer cc.Unlock()

	// Double check: the slot may have been repaired by another goroutine (or
	// the old connection may have recovered) while we were dialing. Prefer the
	// healthy connection already in use and drop ours, so in-flight RPCs on
	// it are not cancelled.
	cur := cc.conns[idx]
	if cur != nil && cc.checkState(cur) == nil {
		fresh.Close() // error irrelevant: this conn is being discarded
		return cur, nil
	}

	// Retire the broken connection before installing the new one.
	if cur != nil {
		cur.Close()
	}
	cc.conns[idx] = fresh
	return fresh, nil
}

// checkState reports whether conn may be handed out. It returns
// ErrConnShutdown when the connection is in TransientFailure or Shutdown, and
// nil for any other connectivity state.
func (cc *ClientPool) checkState(conn *grpc.ClientConn) error {
	if s := conn.GetState(); s == connectivity.TransientFailure || s == connectivity.Shutdown {
		return ErrConnShutdown
	}
	return nil
}

// connect dials cc.target and blocks until the connection is up or
// cc.option.DialTimeOut elapses. The connection uses no transport security and
// the keepalive parameters from cc.option.
//
// NOTE(review): newer grpc-go releases deprecate grpc.WithInsecure and
// grpc.DialContext in favour of insecure.NewCredentials and grpc.NewClient;
// confirm against the grpc-go version in go.mod before migrating.
func (cc *ClientPool) connect() (*grpc.ClientConn, error) {
	ctx, cancel := context.WithTimeout(context.TODO(), cc.option.DialTimeOut)
	defer cancel()

	kp := keepalive.ClientParameters{
		Time:    cc.option.KeepAlive,
		Timeout: cc.option.KeepAliveTimeout,
	}
	dialOpts := []grpc.DialOption{
		grpc.WithInsecure(),
		grpc.WithBlock(),
		grpc.WithKeepaliveParams(kp),
	}

	conn, err := grpc.DialContext(ctx, cc.target, dialOpts...)
	if err != nil {
		return nil, err
	}
	return conn, nil
}

// Close closes every connection currently held by the pool and clears the
// slots. Clearing releases the closed *grpc.ClientConn values and lets a later
// getConn dial a fresh connection instead of closing an already-closed one a
// second time. Errors from the individual Close calls are ignored.
func (cc *ClientPool) Close() {
	cc.Lock()
	defer cc.Unlock()

	for i, conn := range cc.conns {
		if conn == nil {
			continue
		}
		conn.Close()
		cc.conns[i] = nil
	}
}

// NewClientPoolWithOption creates a ClientPool for target. No connection is
// dialed here; the slots start empty and are filled lazily by getConn.
//
// option may be nil. Any field that is zero or negative is replaced with the
// package default. Defaults are applied to a private copy, so the caller's
// ClientOption is never modified. This matters because ServiceClientPool.Init
// passes the same *ClientOption to every pool it creates.
func NewClientPoolWithOption(target string, option *ClientOption) *ClientPool {
	var opt ClientOption
	if option != nil {
		opt = *option
	}

	if opt.ClientPoolConnsSize <= 0 {
		opt.ClientPoolConnsSize = defaultClientPoolConnsSizeCap
	}
	if opt.DialTimeOut <= 0 {
		opt.DialTimeOut = defaultDialTimeout
	}
	if opt.KeepAlive <= 0 {
		opt.KeepAlive = defaultKeepAlive
	}
	if opt.KeepAliveTimeout <= 0 {
		opt.KeepAliveTimeout = defaultKeepAliveTimeout
	}

	return &ClientPool{
		target: target,
		option: &opt,
		cap:    int64(opt.ClientPoolConnsSize),
		conns:  make([]*grpc.ClientConn, opt.ClientPoolConnsSize),
	}
}

// TargetServiceNames maps a dial target (e.g. "127.0.0.1:7778") to the gRPC
// service names served there; ServiceClientPool.Init consumes it. The zero
// value is ready to use. It is not safe for concurrent use.
type TargetServiceNames struct {
	m map[string][]string
}

// NewTargetServiceNames returns an empty TargetServiceNames.
func NewTargetServiceNames() *TargetServiceNames {
	return &TargetServiceNames{
		m: make(map[string][]string),
	}
}

// Set appends serviceNames to the list registered for target. Calling it with
// no service names is a no-op and does not register target.
func (h *TargetServiceNames) Set(target string, serviceNames ...string) {
	if len(serviceNames) == 0 {
		return
	}
	// Lazily allocate so a zero-value TargetServiceNames does not panic.
	if h.m == nil {
		h.m = make(map[string][]string)
	}
	h.m[target] = append(h.m[target], serviceNames...)
}

// list returns the underlying target -> service names map (not a copy).
func (h *TargetServiceNames) list() map[string][]string {
	return h.m
}

// len returns the number of registered targets.
func (h *TargetServiceNames) len() int {
	return len(h.m)
}

// ServiceClientPool routes calls to a ClientPool by gRPC service name: the
// clients map is keyed by service name, and all services registered for the
// same target share one ClientPool, whose connections are then handed out.
type ServiceClientPool struct {
	clients   map[string]*ClientPool
	option    *ClientOption
	clientCap int
}

// NewServiceClientPool returns a ServiceClientPool whose pools will be built
// with option. clients stays nil until Init is called, so GetClient reports
// ErrNotFoundClient for every name before that.
func NewServiceClientPool(option *ClientOption) *ServiceClientPool {
	sc := &ServiceClientPool{option: option}
	sc.clientCap = option.ClientPoolConnsSize
	return sc
}

// Init builds one ClientPool per target in m and registers it under each of
// that target's service names, replacing any previously registered clients.
// If a service name is listed under several targets, whichever target is
// visited last wins (map iteration order is unspecified).
// NOTE(review): pools from an earlier Init call are not closed here.
func (sc *ServiceClientPool) Init(m *TargetServiceNames) {
	registry := make(map[string]*ClientPool, m.len())
	for target, services := range m.list() {
		pool := NewClientPoolWithOption(target, sc.option)
		for _, name := range services {
			registry[name] = pool
		}
	}
	sc.clients = registry
}

// GetClientWithFullMethod is GetClient keyed by a full method name such as
// "/hello.FirstService/SayHello"; the service segment is extracted with
// SpiltFullMethod.
func (sc *ServiceClientPool) GetClientWithFullMethod(fullMethod string) (*grpc.ClientConn, error) {
	return sc.GetClient(sc.SpiltFullMethod(fullMethod))
}

// GetClient returns a connection from the ClientPool registered for service
// sname, or ErrNotFoundClient when no pool is registered under that name.
func (sc *ServiceClientPool) GetClient(sname string) (*grpc.ClientConn, error) {
	if pool, found := sc.clients[sname]; found {
		return pool.getConn()
	}
	return nil, ErrNotFoundClient
}

// Close closes the ClientPool registered for service sname; unknown names are
// ignored. Services on the same target share one ClientPool, so this also
// closes the connections those other services use.
func (sc *ServiceClientPool) Close(sname string) {
	if pool, found := sc.clients[sname]; found {
		pool.Close()
	}
}

// CloseAll closes every registered ClientPool. A pool shared by several
// service names is closed once per name.
func (sc *ServiceClientPool) CloseAll() {
	for name := range sc.clients {
		sc.clients[name].Close()
	}
}

// SpiltFullMethod extracts the service segment from a gRPC full method name,
// e.g. "/hello.FirstService/SayHello" yields "hello.FirstService". Input that
// does not contain exactly two '/' characters yields "". Nothing else is
// validated.
// NOTE(review): the name is a misspelling of "Split"; it is kept as-is so
// existing callers keep compiling.
func (sc *ServiceClientPool) SpiltFullMethod(fullMethod string) string {
	// With a limit of 4, exactly three parts come back only when the input
	// holds exactly two separators, which matches the original len==3 check.
	parts := strings.SplitN(fullMethod, "/", 4)
	if len(parts) == 3 {
		return parts[1]
	}
	return ""
}

// Invoke performs a unary RPC for fullMethod on a pooled connection; the
// service segment of fullMethod selects the ClientPool, and lookup errors are
// returned unchanged. Outgoing metadata already on ctx is copied, each entry
// in headers is Set on the copy (replacing existing values for that key), and
// the call is made with the resulting metadata.
func (sc *ServiceClientPool) Invoke(ctx context.Context, fullMethod string, headers map[string]string, args interface{}, reply interface{}, opts ...grpc.CallOption) error {
	conn, err := sc.GetClient(sc.SpiltFullMethod(fullMethod))
	if err != nil {
		return err
	}

	outgoing := metadata.MD{}
	if existing, ok := metadata.FromOutgoingContext(ctx); ok {
		// Copy so the caller's metadata is not mutated by the Set calls below.
		outgoing = existing.Copy()
	}
	for key, value := range headers {
		outgoing.Set(key, value)
	}

	return conn.Invoke(metadata.NewOutgoingContext(ctx, outgoing), fullMethod, args, reply, opts...)
}

// Target address and service name that init registers in the package-wide
// ServiceClientPool.
const (
	ADDRESS     = "127.0.0.1:7778"
	SERVICENAME = "hello.FirstService"
)

// scp is the package-wide ServiceClientPool, built once by init and exposed
// through GetScp.
var scp *ServiceClientPool

// init builds scp with the package defaults and registers SERVICENAME at
// ADDRESS. No connection is dialed here; the pools start with empty slots.
func init() {
	opt := &ClientOption{
		ClientPoolConnsSize: defaultClientPoolConnsSizeCap,
		DialTimeOut:         defaultDialTimeout,
		KeepAlive:           defaultKeepAlive,
		KeepAliveTimeout:    defaultKeepAliveTimeout,
	}

	names := NewTargetServiceNames()
	names.Set(ADDRESS, SERVICENAME)

	scp = NewServiceClientPool(opt)
	scp.Init(names)
}

// GetScp returns the package-wide ServiceClientPool created by init.
func GetScp() *ServiceClientPool {
	pool := scp
	return pool
}

使用示例

	clientPool := common.GetScp()
	reply := &pb.HelloReply{}
	err := clientPool.Invoke(context.Background(), "/hello.FirstService/SayHello", nil, &pb.HelloRequest{Name: "lubenwei", Age: "21"}, reply)
	fmt.Println("耗时:", time.Since(startTime))
	if err != nil {
		fmt.Println("超时:", err)
		return
	}
	fmt.Println(reply.Time)
Logo

更多推荐