大家好,今天来聊聊如何给自动化脚本加上高效的缓存机制。

为什么要缓存?

看几个典型场景:

  1. API数据重复请求:同样的数据反复请求,浪费带宽还容易被限流
  2. 数据库频繁查询:热点数据反复读取,增加数据库压力
  3. 文件解析重复执行:大文件每次都重新解析,CPU飙升
  4. 计算结果重复计算:同样的输入反复计算,纯属浪费

合理的缓存可以提升几倍甚至几十倍的性能。

一、基础内存缓存

从最简单的开始:

"""基础内存缓存"""

import time
import functools
from typing import Any, Callable, Optional, Tuple
from dataclasses import dataclass
import hashlib

@dataclass
class CacheItem:
    """A cached value together with its creation time and optional lifetime."""
    value: Any
    created_at: float
    ttl: Optional[float] = None  # lifetime in seconds; None means "never expires"

    def is_expired(self) -> bool:
        """Return True once more than ``ttl`` seconds have passed since ``created_at``."""
        if self.ttl is not None:
            age = time.time() - self.created_at
            return age > self.ttl
        # No TTL configured: the item never expires.
        return False


class SimpleCache:
    """In-memory cache with per-item TTL and least-recently-used eviction.

    Entries are kept in one ordered mapping whose order doubles as the LRU
    order: the first key is the least recently used, the last key the most
    recently used. All bookkeeping operations are O(1).
    """

    def __init__(self, max_size: int = 128):
        """Create an empty cache holding at most ``max_size`` entries.

        A ``max_size`` of zero or less disables storage: ``set`` is a no-op.
        """
        # Local import keeps the class independent of the module's import block.
        from collections import OrderedDict

        self._cache = OrderedDict()
        self._max_size = max_size

    def get(self, key: str) -> Optional[Any]:
        """Return the value stored under ``key``, or None if absent or expired.

        A successful lookup marks the key as most recently used; an expired
        entry is removed as a side effect.
        """
        item = self._cache.get(key)
        if item is None:
            return None

        if item.is_expired():
            self.delete(key)
            return None

        self._cache.move_to_end(key)
        return item.value

    def set(self, key: str, value: Any, ttl: Optional[float] = None):
        """Store ``value`` under ``key`` with an optional lifetime of ``ttl`` seconds.

        Inserting a new key into a full cache first evicts the least recently
        used entry. Overwriting an existing key never evicts.
        """
        if self._max_size <= 0:
            # Nothing can be stored; avoids evicting from an empty cache.
            return

        if key not in self._cache and len(self._cache) >= self._max_size:
            self._cache.popitem(last=False)

        self._cache[key] = CacheItem(
            value=value,
            created_at=time.time(),
            ttl=ttl
        )
        # Assigning to an existing key keeps its old position, so refresh it.
        self._cache.move_to_end(key)

    def delete(self, key: str):
        """Remove ``key`` if present; do nothing otherwise."""
        self._cache.pop(key, None)

    def clear(self):
        """Remove every entry."""
        self._cache.clear()

    def cleanup_expired(self):
        """Remove all expired entries and return how many were removed."""
        expired_keys = [k for k, v in self._cache.items() if v.is_expired()]
        for key in expired_keys:
            self.delete(key)
        return len(expired_keys)


# 缓存装饰器
def cached(ttl: Optional[float] = None, max_size: int = 128):
    """Return a decorator that memoizes a function's results in a ``SimpleCache``.

    Args:
        ttl: Lifetime in seconds of each cached result; None means no expiry.
        max_size: Maximum number of results kept before LRU eviction.

    The cache is created once per ``cached(...)`` call. The decorated function
    gains two attributes: ``cache`` (the underlying ``SimpleCache``) and
    ``clear_cache`` (drops every cached result). Results equal to None are
    cached as well.
    """
    _cache = SimpleCache(max_size)
    # SimpleCache.get() returns None on a miss, so a result of None is stored
    # as this placeholder to keep "cached None" distinguishable from "absent".
    _none_marker = object()

    def decorator(func: Callable):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            key = _generate_key(func.__name__, args, kwargs)

            hit = _cache.get(key)
            if hit is _none_marker:
                return None
            if hit is not None:
                return hit

            # Miss: compute, store, and return the fresh result.
            result = func(*args, **kwargs)
            _cache.set(key, _none_marker if result is None else result, ttl)
            return result

        # Expose cache controls on the wrapped function.
        wrapper.cache = _cache
        wrapper.clear_cache = _cache.clear
        return wrapper

    return decorator


def _generate_key(func_name: str, args: tuple, kwargs: dict) -> str:
    """Build a cache key from a function name and its call arguments.

    Args:
        func_name: Name used as the key prefix.
        args: Positional arguments of the call.
        kwargs: Keyword arguments of the call; their order does not matter.

    Returns:
        A readable key, or its MD5 hex digest when the readable form is
        longer than 100 characters.

    Values are rendered with ``repr`` so that ``1`` and ``"1"`` differ, and
    positional and keyword parts carry distinct tags so that a positional
    ``"x=1"`` cannot collide with the keyword argument ``x=1``.
    """
    def describe(value) -> str:
        # Objects carrying a __dict__ are keyed by identity, everything else
        # by its repr.
        if hasattr(value, '__dict__'):
            return f"<{type(value).__name__}#{id(value)}>"
        return repr(value)

    key_parts = [func_name]
    key_parts.extend(f"a:{describe(arg)}" for arg in args)
    key_parts.extend(
        f"k:{name}={describe(value)}" for name, value in sorted(kwargs.items())
    )

    key_str = '|'.join(key_parts)

    # Long keys are replaced by a fixed-length digest (not used for security).
    if len(key_str) > 100:
        return hashlib.md5(key_str.encode()).hexdigest()

    return key_str


# 使用示例
if __name__ == '__main__':
    @cached(ttl=60, max_size=100)
    def fetch_data(api_url: str, params: Optional[dict] = None):
        """Simulate an API request that takes one second."""
        print(f"正在请求: {api_url}")
        time.sleep(1)  # simulated network latency
        return {"status": "success", "data": [1, 2, 3]}
    
    # First call: runs the function body, so the request line is printed.
    result1 = fetch_data("https://api.example.com/data")
    
    # Second call with the same argument: served from the cache, nothing printed.
    result2 = fetch_data("https://api.example.com/data")
    
    print(f"结果相同: {result1 == result2}")
    
    # Drop everything cached for fetch_data.
    fetch_data.clear_cache()
    
    # The cache is empty again, so this call runs the function body once more.
    result3 = fetch_data("https://api.example.com/data")

二、多级缓存实现

更实用的多级缓存策略:

"""多级缓存"""

import time
import threading
import pickle
from pathlib import Path
from typing import Any, Optional, TypeVar, Generic
from abc import ABC, abstractmethod

T = TypeVar('T')


class CacheLevel(ABC):
    """Abstract interface that every cache tier implements."""

    @abstractmethod
    def get(self, key: str) -> Optional[Any]:
        """Look up ``key`` and return its value, or None when unavailable."""

    @abstractmethod
    def set(self, key: str, value: Any, ttl: Optional[float] = None):
        """Store ``value`` under ``key``, optionally with a lifetime in seconds."""

    @abstractmethod
    def delete(self, key: str):
        """Remove ``key`` from this tier."""

    @abstractmethod
    def clear(self):
        """Remove every entry from this tier."""


class MemoryCache(CacheLevel):
    """In-memory cache tier with TTL expiry and least-recently-used eviction.

    Every public method runs under one re-entrant lock.
    """

    def __init__(self, max_size: int = 1000, default_ttl: float = 300):
        """Create an empty tier.

        Args:
            max_size: Entry count at which inserting a new key evicts the
                least recently used one.
            default_ttl: Lifetime in seconds applied when ``set`` gets no TTL.
        """
        self._cache = {}
        self._max_size = max_size
        self._default_ttl = default_ttl
        # Re-entrant because get() and _evict_lru() call delete() while holding it.
        self._lock = threading.RLock()
        self._access_times = {}

    def get(self, key: str) -> Optional[Any]:
        """Return the value for ``key``, or None if it is absent or expired."""
        with self._lock:
            item = self._cache.get(key)
            if item is None:
                return None

            if self._is_expired(item):
                self.delete(key)
                return None

            # Record the access so LRU eviction sees this key as recent.
            self._access_times[key] = time.time()
            return item['value']

    def set(self, key: str, value: Any, ttl: Optional[float] = None):
        """Store ``value`` under ``key``.

        ``ttl`` is the lifetime in seconds; None selects ``default_ttl``.
        An explicit 0 is honoured rather than replaced by the default.
        """
        with self._lock:
            # Only a brand-new key can grow the cache; overwriting an existing
            # key must not evict an unrelated entry.
            if key not in self._cache and len(self._cache) >= self._max_size:
                self._evict_lru()

            now = time.time()
            self._cache[key] = {
                'value': value,
                'created_at': now,
                'ttl': self._default_ttl if ttl is None else ttl
            }
            self._access_times[key] = now

    def delete(self, key: str):
        """Remove ``key`` and its access record; missing keys are ignored."""
        with self._lock:
            self._cache.pop(key, None)
            self._access_times.pop(key, None)

    def clear(self):
        """Remove every entry."""
        with self._lock:
            self._cache.clear()
            self._access_times.clear()

    def _is_expired(self, item: dict) -> bool:
        """Return True when the entry has outlived its TTL (None = no expiry)."""
        if item['ttl'] is None:
            return False
        return time.time() - item['created_at'] > item['ttl']

    def _evict_lru(self):
        """Remove the entry with the oldest access time, if there is one."""
        if not self._access_times:
            return

        lru_key = min(self._access_times.items(), key=lambda x: x[1])[0]
        self.delete(lru_key)


class DiskCache(CacheLevel):
    """Disk-backed cache tier: one pickle file per key plus an in-memory index.

    The index (creation time, TTL, file size) lives only in memory, so files
    left behind by an earlier process are never served by ``get``; ``clear``
    removes them.
    """

    def __init__(self, cache_dir: str = './cache', max_size_mb: int = 500):
        """Create the cache directory if needed and start with an empty index.

        Args:
            cache_dir: Directory holding the ``*.cache`` files.
            max_size_mb: Total size limit in MiB, enforced after every write.
        """
        self._cache_dir = Path(cache_dir)
        self._cache_dir.mkdir(parents=True, exist_ok=True)
        self._max_size = max_size_mb * 1024 * 1024
        # Re-entrant: get() and set() (through _check_size) call delete()
        # while holding the lock, which would deadlock with a plain Lock.
        self._lock = threading.RLock()
        self._memory_index = {}  # key -> {'created_at', 'ttl', 'size'}

    def _path_for(self, key: str) -> Path:
        """Return the file path for ``key``.

        The key is hashed so that characters such as '/', ':' or '..' can
        neither escape the cache directory nor form an invalid file name.
        """
        digest = hashlib.sha256(key.encode('utf-8')).hexdigest()
        return self._cache_dir / f"{digest}.cache"

    def get(self, key: str) -> Optional[Any]:
        """Return the value for ``key``, or None if absent, expired or unreadable."""
        with self._lock:
            meta = self._memory_index.get(key)
            if meta is None:
                return None

            ttl = meta['ttl']
            if ttl is not None and time.time() - meta['created_at'] > ttl:
                self.delete(key)
                return None

            try:
                with open(self._path_for(key), 'rb') as f:
                    # NOTE(security): unpickling can execute code embedded in
                    # the file; the cache directory must not be writable by
                    # untrusted parties.
                    return pickle.load(f)
            except FileNotFoundError:
                # The file vanished behind our back; drop the stale index entry.
                self._memory_index.pop(key, None)
                return None
            except Exception:
                # Best effort: a corrupt or unreadable entry counts as a miss.
                return None

    def set(self, key: str, value: Any, ttl: Optional[float] = None):
        """Pickle ``value`` to disk under ``key``.

        ``ttl`` is the lifetime in seconds (None = no expiry). Write failures
        are reported on stdout and otherwise ignored.
        """
        with self._lock:
            filepath = self._path_for(key)

            try:
                with open(filepath, 'wb') as f:
                    pickle.dump(value, f)

                self._memory_index[key] = {
                    'created_at': time.time(),
                    'ttl': ttl,
                    'size': filepath.stat().st_size
                }

                self._check_size()

            except Exception as e:
                print(f"磁盘缓存写入失败: {e}")

    def delete(self, key: str):
        """Remove the file and index entry for ``key``; missing ones are ignored."""
        with self._lock:
            try:
                self._path_for(key).unlink()
            except FileNotFoundError:
                pass
            self._memory_index.pop(key, None)

    def clear(self):
        """Delete every ``*.cache`` file in the directory and empty the index."""
        with self._lock:
            for f in self._cache_dir.glob('*.cache'):
                try:
                    f.unlink()
                except FileNotFoundError:
                    pass
            self._memory_index.clear()

    def _check_size(self):
        """Evict the oldest entries once the indexed size exceeds the limit.

        Entries are removed in creation order until the total is at most 80%
        of the limit. Called with the lock already held.
        """
        total_size = sum(m['size'] for m in self._memory_index.values())

        if total_size > self._max_size:
            sorted_items = sorted(
                self._memory_index.items(),
                key=lambda x: x[1]['created_at']
            )

            for key, meta in sorted_items:
                if total_size <= self._max_size * 0.8:
                    break
                self.delete(key)
                total_size -= meta['size']


class MultiLevelCache:
    """Two-tier cache: a fast first tier (L1) in front of a larger second tier (L2)."""

    def __init__(self,
                 l1: Optional['CacheLevel'] = None,
                 l2: Optional['CacheLevel'] = None):
        """Set up both tiers.

        Args:
            l1: First tier; defaults to ``MemoryCache(max_size=500, default_ttl=60)``.
            l2: Second tier; defaults to
                ``DiskCache(cache_dir='./cache', max_size_mb=200)``.
        """
        self.l1 = l1 if l1 is not None else MemoryCache(max_size=500, default_ttl=60)
        self.l2 = l2 if l2 is not None else DiskCache(cache_dir='./cache', max_size_mb=200)

    def get(self, key: str) -> Optional[Any]:
        """Return the value for ``key`` from L1, falling back to L2.

        An L2 hit is copied back into L1 without a TTL argument, so L1 applies
        its own default lifetime to the copy. Returns None when neither tier
        has the key.
        """
        value = self.l1.get(key)
        if value is not None:
            return value

        value = self.l2.get(key)
        if value is not None:
            self.l1.set(key, value)
        return value

    def set(self, key: str, value: Any, ttl: Optional[float] = None):
        """Write ``value`` to both tiers with the same ``ttl``."""
        self.l1.set(key, value, ttl)
        self.l2.set(key, value, ttl)

    def delete(self, key: str):
        """Remove ``key`` from both tiers."""
        self.l1.delete(key)
        self.l2.delete(key)

    def clear(self):
        """Empty both tiers."""
        self.l1.clear()
        self.l2.clear()


# Module-wide cache instance shared by the decorator and demo code below.
# NOTE: building it at import time creates the ./cache directory, because
# MultiLevelCache constructs a DiskCache whose __init__ calls mkdir.
_global_cache = MultiLevelCache()


def multi_cached(ttl: Optional[float] = 300):
    """Return a decorator that memoizes results in the shared ``_global_cache``.

    Args:
        ttl: Lifetime in seconds passed to the cache for every stored result.

    The key is the MD5 digest of the function's module, name and the text form
    of its arguments; keyword arguments are sorted so their order in the call
    does not matter. Results are stored wrapped in a one-element tuple so that
    a result of None is served from the cache instead of being recomputed on
    every call.

    The decorated function gains ``clear_cache``, which clears the WHOLE
    shared cache, not only this function's entries.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            key = f"{func.__module__}.{func.__name__}"
            key += f":{args}:{sorted(kwargs.items())}"
            key = hashlib.md5(key.encode()).hexdigest()

            # The cache reports a miss as None, so a hit is always a tuple.
            boxed = _global_cache.get(key)
            if boxed is not None:
                return boxed[0]

            result = func(*args, **kwargs)
            _global_cache.set(key, (result,), ttl)
            return result

        wrapper.clear_cache = _global_cache.clear
        return wrapper

    return decorator

三、缓存防穿透

缓存穿透是个严重问题——大量请求不存在的数据,每次都打到后端。

"""缓存防穿透"""

import time
import threading
import hashlib
from typing import Any, Optional, Callable, Set
from dataclasses import dataclass
import random

@dataclass
class Sentinel:
    """Marker type cached in place of a value to record that no data exists."""

# Shared Sentinel instance that is cached to represent a "no data" result.
NOT_FOUND = Sentinel()


class AntiPenetrationCache:
    """Cache wrapper that shields the data source from lookups of missing keys.

    Two mechanisms are combined:

    * Null caching: when the fetch function returns None, ``NOT_FOUND`` is
      cached for ``null_cache_ttl`` seconds, so repeated lookups of the same
      missing key do not reach the data source.
    * An optional Bloom filter of known keys: once keys have been loaded
      through ``known_keys`` or ``preload()``, a key the filter has never seen
      is rejected without calling the fetch function.

    The filter is consulted only after it has been loaded. An empty filter
    would reject every key, so no fetch would ever happen and nothing would
    ever be cached.
    """

    def __init__(self,
                 cache: Optional['MultiLevelCache'] = None,
                 bloom_filter_size: int = 10000,
                 null_cache_ttl: float = 60,
                 known_keys=None):
        """Set up the wrapper.

        Args:
            cache: Backing cache; a new ``MultiLevelCache`` is created when None.
            bloom_filter_size: Number of bits in the Bloom filter.
            null_cache_ttl: Lifetime in seconds of cached "not found" results.
            known_keys: Optional iterable of keys known to exist in the data
                source; passing it enables Bloom-filter rejection.
        """
        self._cache = cache if cache is not None else MultiLevelCache()

        self._bloom_filter = BloomFilter(bloom_filter_size)
        # False until preload() has supplied the set of existing keys.
        self._bloom_enabled = False

        self._null_ttl = null_cache_ttl

        # Not used by any method of this class; kept from the original.
        self._lock = threading.Lock()

        if known_keys is not None:
            self.preload(known_keys)

    def preload(self, keys):
        """Register ``keys`` as existing and enable Bloom-filter rejection.

        After this call, ``get_or_fetch`` returns None without fetching for
        any uncached key the filter does not recognise, so keys created later
        in the data source must be registered here as well.
        """
        for key in keys:
            self._bloom_filter.add(key)
        self._bloom_enabled = True

    def get_or_fetch(self,
                     key: str,
                     fetch_func: Callable[[], Any],
                     ttl: Optional[float] = None) -> Any:
        """Return the value for ``key``, fetching and caching it on a miss.

        Args:
            key: Cache key.
            fetch_func: Zero-argument callable that loads the value from the
                data source; None means "does not exist".
            ttl: Lifetime in seconds for a successfully fetched value.

        Returns:
            The cached or fetched value, or None when the key is known or
            found to be missing.
        """
        value = self._cache.get(key)
        if value is not None:
            # A cached Sentinel records an earlier "not found" result.
            return None if isinstance(value, Sentinel) else value

        # A Bloom filter can report false positives but never false negatives,
        # so "not contained" is a safe reason to skip the data source.
        if self._bloom_enabled and not self._bloom_filter.might_contain(key):
            return None

        value = fetch_func()

        if value is None:
            self._cache.set(key, NOT_FOUND, self._null_ttl)
            return None

        # Only keys that really exist are added to the filter.
        self._bloom_filter.add(key)
        self._cache.set(key, value, ttl)
        return value

    def invalidate(self, key: str):
        """Remove ``key`` from the backing cache."""
        self._cache.delete(key)


class BloomFilter:
    """Minimal Bloom filter over strings.

    ``might_contain`` can return True for an item that was never added
    (false positive) but never returns False for an item that was added.
    """

    def __init__(self, size: int = 10000, hash_count: int = 3):
        """Create an empty filter.

        Args:
            size: Number of bit positions; must be positive.
            hash_count: Number of hash positions per item; must be positive.

        Raises:
            ValueError: If ``size`` or ``hash_count`` is not positive.
        """
        if size <= 0:
            raise ValueError("size must be a positive integer")
        if hash_count <= 0:
            raise ValueError("hash_count must be a positive integer")
        self._size = size
        # One byte per position: far smaller than a list of bool objects.
        self._bits = bytearray(size)
        self._hash_count = hash_count

    def _hashes(self, item: str) -> list:
        """Return the ``hash_count`` bit positions for ``item``.

        Each position is the MD5 digest of the item salted with the hash
        index, reduced modulo the filter size.
        """
        return [
            int(hashlib.md5(f"{item}:{i}".encode()).hexdigest(), 16) % self._size
            for i in range(self._hash_count)
        ]

    def add(self, item: str):
        """Mark every bit position of ``item``."""
        for pos in self._hashes(item):
            self._bits[pos] = 1

    def might_contain(self, item: str) -> bool:
        """Return True if ``item`` may have been added, False if it never was."""
        return all(self._bits[pos] for pos in self._hashes(item))


# 使用示例
if __name__ == '__main__':
    cache = AntiPenetrationCache()
    
    # Stand-in for a database query.
    def fetch_from_db(user_id: str) -> Optional[dict]:
        """Simulate a database lookup; prints a line every time it is called."""
        print(f"查询数据库: {user_id}")
        # Only user_1 exists in this simulated database.
        if user_id == "user_1":
            return {"id": user_id, "name": "张三", "age": 25}
        return None
    
    # Exercise the anti-penetration behaviour.
    print("=" * 50)
    
    # user_1 is present in the simulated database.
    result1 = cache.get_or_fetch("user:user_1", 
                                  lambda: fetch_from_db("user_1"))
    print(f"user_1 结果: {result1}")
    
    # Second lookup of the same key.
    result1_again = cache.get_or_fetch("user:user_1",
                                       lambda: fetch_from_db("user_1"))
    print(f"user_1 再次获取: {result1_again}")
    
    print("=" * 50)
    
    # user_2 is absent from the simulated database.
    result2 = cache.get_or_fetch("user:user_2",
                                lambda: fetch_from_db("user_2"))
    print(f"user_2 结果: {result2}")
    
    # Second lookup of the same missing key.
    result2_again = cache.get_or_fetch("user:user_2",
                                       lambda: fetch_from_db("user_2"))
    print(f"user_2 再次获取: {result2_again}")
    
    print("=" * 50)
    
    # user_999: another key that is absent from the simulated database.
    result3 = cache.get_or_fetch("user:user_999",
                                 lambda: fetch_from_db("user_999"))
    print(f"user_999 结果: {result3}")

四、缓存一致性

缓存与数据源之间的一致性问题(以下示例是单进程内存实现,仅用于演示 Cache Aside 与 CAS 的思路):

"""缓存一致性处理"""

import time
import threading
from typing import Any, Optional, Callable, List
from dataclasses import dataclass
from enum import Enum

class CacheStrategy(Enum):
    """Names of cache write strategies."""
    CACHE_ASIDE = "cache_aside"      # cache-aside (noted by the author as the most common)
    WRITE_THROUGH = "write_through"  # write-through
    WRITE_BACK = "write_back"       # write-back


@dataclass
class CacheEntry:
    """A cached value plus bookkeeping metadata."""
    value: Any
    version: int  # update counter; the author's note says it is intended for CAS
    created_at: float
    updated_at: float


class ConsistentCache:
    """Lock-protected in-memory store whose entries carry a version counter."""

    def __init__(self, strategy: CacheStrategy = CacheStrategy.CACHE_ASIDE):
        """Create an empty store; ``strategy`` is recorded but not read by any method here."""
        self._cache = {}
        self._lock = threading.RLock()
        self._strategy = strategy
        self._pending_updates = {}  # write-back buffer; unused by the methods below

    def get(self, key: str) -> Optional[Any]:
        """Return the value stored under ``key``, or None when there is no entry."""
        with self._lock:
            entry = self._cache.get(key)
            if entry is None:
                return None
            return entry.value

    def set(self, key: str, value: Any):
        """Store ``value``: a new key starts at version 1, an existing one is bumped."""
        with self._lock:
            timestamp = time.time()
            existing = self._cache.get(key)

            if existing is None:
                self._cache[key] = CacheEntry(
                    value=value,
                    version=1,
                    created_at=timestamp,
                    updated_at=timestamp
                )
                return

            existing.value = value
            existing.version += 1
            existing.updated_at = timestamp

    def setnx(self, key: str, value: Any) -> bool:
        """Store ``value`` only if ``key`` has no entry yet.

        Returns True when the value was stored, False when the key existed.
        The check and the write happen under the same lock.
        """
        with self._lock:
            if key in self._cache:
                return False
            self.set(key, value)
            return True

    def cas(self, key: str, old_value: Any, new_value: Any) -> bool:
        """Compare-and-swap: replace the value only if it currently equals ``old_value``.

        Returns True when the swap happened; False when the key is missing or
        the current value differs.
        """
        with self._lock:
            entry = self._cache.get(key)
            if entry is None or not (entry.value == old_value):
                return False
            entry.value = new_value
            entry.version += 1
            entry.updated_at = time.time()
            return True

    def delete(self, key: str):
        """Remove ``key``; a missing key is ignored."""
        with self._lock:
            self._cache.pop(key, None)

    def get_version(self, key: str) -> Optional[int]:
        """Return the entry's version number, or None when there is no entry."""
        with self._lock:
            entry = self._cache.get(key)
            if entry is None:
                return None
            return entry.version


class CacheAsideManager:
    """Read and write helper implementing the cache-aside pattern."""

    def __init__(self):
        self._cache = ConsistentCache()

    def read(self,
            key: str,
            db_fetch: Callable[[], Any],
            ttl: Optional[float] = None) -> Any:
        """Return the value for ``key``, loading it through ``db_fetch`` on a miss.

        A non-None result from ``db_fetch`` is written back to the cache; a
        None result is returned without being cached. ``ttl`` is accepted but
        not used by this implementation.
        """
        cached_value = self._cache.get(key)
        if cached_value is not None:
            return cached_value

        fresh = db_fetch()
        if fresh is None:
            return fresh

        self._cache.set(key, fresh)
        return fresh

    def write(self,
              key: str,
              value: Any,
              db_write: Callable[[], None],
              ttl: Optional[float] = None):
        """Run ``db_write`` and then drop ``key`` from the cache.

        The cached entry is deleted rather than updated; the next read reloads
        it from the data source. ``value`` and ``ttl`` are accepted but not
        used by this implementation.
        """
        db_write()
        self._cache.delete(key)

    def delete(self,
               key: str,
               db_delete: Callable[[], None]):
        """Run ``db_delete`` and then drop ``key`` from the cache."""
        db_delete()
        self._cache.delete(key)


# 使用示例
if __name__ == '__main__':
    cache_manager = CacheAsideManager()
    
    # Dict standing in for the database.
    # NOTE: fetch_user returns the dict stored in db_data itself, so the value
    # that ends up in the cache is the same object as the "database" row.
    db_data = {"user_1": {"name": "张三", "age": 25}}
    
    def fetch_user(user_id: str):
        return db_data.get(user_id)
    
    # First read: cache miss, loaded through fetch_user and cached.
    user = cache_manager.read("user:user_1", lambda: fetch_user("user_1"))
    print(f"读取用户: {user}")
    
    # Second read: served from the cache.
    user = cache_manager.read("user:user_1", lambda: fetch_user("user_1"))
    print(f"再次读取: {user}")
    
    # Update the "database" directly, then go through write() to invalidate the cache.
    db_data["user_1"]["name"] = "李四"
    cache_manager.write(
        "user:user_1",
        {"name": "李四", "age": 25},
        lambda: None  # stand-in for the real database write
    )
    
    # Read again: the entry was invalidated, so it is reloaded via fetch_user.
    user = cache_manager.read("user:user_1", lambda: fetch_user("user_1"))
    print(f"更新后读取: {user}")

五、缓存监控

监控缓存效果:

"""缓存监控"""

import time
import threading
from typing import Dict, List
from dataclasses import dataclass, field
from collections import defaultdict

@dataclass
class CacheStats:
    """Counters describing cache activity."""
    hits: int = 0
    misses: int = 0
    sets: int = 0
    deletes: int = 0
    errors: int = 0

    @property
    def hit_rate(self) -> float:
        """Fraction of lookups that were hits; 0.0 when there were no lookups."""
        lookups = self.hits + self.misses
        if lookups <= 0:
            return 0.0
        return self.hits / lookups

class CacheMonitor:
    """缓存监控"""
    
    def __init__(self):
        self._stats = CacheStats()
        self._lock = threading.Lock()
        self._access_log: List[dict] = []
        self._max_log_size = 1000
    
    def record_hit(self, key: str):
        with self._lock:
            self._stats.hits += 1
            self._log_access(key, 'hit')
    
    def record_miss(self, key: str):
        with self._lock:
            self._stats.misses += 1
            self._log_access(key, 'miss')
    
    def record_set(self, key: str):
        with self._lock:
            self._stats.sets += 1
            self._log_access(key, 'set')
    
    def record_delete(self, key: str):
        with self._lock:
            self._stats.deletes += 1
            self._log_access(key, 'delete')
    
    def record_error(self, key: str):
        with self._lock:
            self._stats.errors += 1
            self._log_access(key, 'error')
    
    def _log_access(self, key: str, action: str):
        """记录访问日志"""
        self._access_log.append({
            'time': time.time(),
            'key': key,
            'action': action
        })
        
        # 限制日志大小
        if len(self._access_log) > self._max_log_size:
            self._access_log = self._access_log[-self._max_log_size:]
    
    def get_stats(self) -> CacheStats:
        with self._lock:
            return CacheStats(
                hits=self._stats.hits,
                misses=self._stats.misses,
                sets=self._stats.sets,
                deletes=self._stats.deletes,
                errors=self._stats.errors
            )
    
    def get_recent_accesses(self, limit: int = 10) -> List[dict]:
        with self._lock:
            return self._access_log[-limit:]
    
    def report(self) -> str:
        """生成报告"""
        stats = self.get_stats()
        return f"""
========== 缓存报告 ==========
命中: {stats.hits}
未命中: {stats.misses}
设置: {stats.sets}
删除: {stats.deletes}
错误: {stats.errors}
命中率: {stats.hit_rate:.2%}
=============================
"""


# Module-wide monitor instance that MonitoredCache reports to.
_global_monitor = CacheMonitor()


class MonitoredCache:
    """Cache wrapper that reports every operation to a monitor."""

    def __init__(self,
                 cache: 'MultiLevelCache',
                 monitor: Optional['CacheMonitor'] = None):
        """Wrap ``cache``.

        Args:
            cache: The cache whose operations are forwarded and recorded.
            monitor: Receiver of the records; when None, the module-level
                ``_global_monitor`` is looked up at the time of each operation.
        """
        self._cache = cache
        self._monitor = monitor

    def _active_monitor(self):
        """Return the injected monitor, or the module-level one if none was given."""
        return self._monitor if self._monitor is not None else _global_monitor

    def get(self, key: str):
        """Return the cached value and record a hit, a miss or an error.

        A result of None is recorded as a miss. Exceptions from the wrapped
        cache are recorded as errors and re-raised.
        """
        try:
            value = self._cache.get(key)
        except Exception:
            self._active_monitor().record_error(key)
            raise

        if value is not None:
            self._active_monitor().record_hit(key)
        else:
            self._active_monitor().record_miss(key)
        return value

    def set(self, key: str, value: Any, ttl: Optional[float] = None):
        """Store ``value`` and record the write; failures are recorded and re-raised."""
        try:
            self._cache.set(key, value, ttl)
        except Exception:
            self._active_monitor().record_error(key)
            raise
        self._active_monitor().record_set(key)

    def delete(self, key: str):
        """Delete ``key`` and record the deletion; failures are recorded and re-raised."""
        try:
            self._cache.delete(key)
        except Exception:
            self._active_monitor().record_error(key)
            raise
        self._active_monitor().record_delete(key)


# 使用示例
if __name__ == '__main__':
    monitored = MonitoredCache(_global_cache)
    
    # Simulate some lookups.
    for i in range(5):
        monitored.get("key_1")  # key_1 is never stored in this demo, so each lookup is recorded as a miss
    
    monitored.set("key_2", "value2")
    
    # Print the collected statistics.
    print(_global_monitor.report())

总结

  1. 基础缓存:内存缓存 + LRU淘汰
  2. 多级缓存:L1内存 + L2磁盘
  3. 防穿透:空值缓存 + 布隆过滤器
  4. 一致性:CAS操作 + Cache Aside模式
  5. 监控:命中率统计 + 访问日志

缓存是个大话题,这里只是入门级介绍。

更多推荐