Python自动化脚本缓存策略实现:多级缓存与防穿透实战
·
大家好,今天来聊聊如何给自动化脚本加上高效的缓存机制。
为什么要缓存?
看几个典型场景:
- API数据重复请求:同样的数据反复请求,浪费带宽还容易被限流
- 数据库频繁查询:热点数据反复读取,增加数据库压力
- 文件解析重复执行:大文件每次都重新解析,CPU飙升
- 计算结果重复计算:同样的输入反复计算,纯属浪费
合理的缓存可以提升几倍甚至几十倍的性能。
一、基础内存缓存
从最简单的开始:
"""基础内存缓存"""
import time
import functools
from typing import Any, Callable, Optional, Tuple
from dataclasses import dataclass
import hashlib
@dataclass
class CacheItem:
    """One cached value together with the data needed to expire it.

    ``created_at`` is a ``time.time()`` timestamp; ``ttl`` is a lifetime in
    seconds, or ``None`` for an item that never expires.
    """
    value: Any
    created_at: float
    ttl: Optional[float] = None  # lifetime in seconds; None = never expires

    def is_expired(self) -> bool:
        """Return True once more than ``ttl`` seconds have elapsed since ``created_at``."""
        if self.ttl is not None:
            age = time.time() - self.created_at
            return age > self.ttl
        return False
class SimpleCache:
    """Small in-process LRU cache with an optional TTL per item (no locking).

    Entries live in an ``OrderedDict`` ordered from least to most recently
    used, so lookups, inserts and evictions are O(1); maintaining the same
    order in a plain list costs O(n) ``remove``/``pop(0)`` calls.
    """

    def __init__(self, max_size: int = 128):
        from collections import OrderedDict  # local import keeps this snippet self-contained

        self._cache = OrderedDict()  # key -> CacheItem, least recently used first
        self._max_size = max_size

    def get(self, key: str, default: Optional[Any] = None) -> Optional[Any]:
        """Return the value stored under ``key``.

        Args:
            key: cache key.
            default: value returned on a miss or when the item has expired.

        A hit marks ``key`` as most recently used; an expired item is removed.
        """
        item = self._cache.get(key)
        if item is None:
            return default
        if item.is_expired():
            self.delete(key)
            return default
        self._cache.move_to_end(key)
        return item.value

    def set(self, key: str, value: Any, ttl: Optional[float] = None):
        """Store ``value`` under ``key``, expiring after ``ttl`` seconds if given.

        Inserting a new key into a full cache first evicts the least recently
        used entry; overwriting an existing key never evicts anything. A
        non-positive ``max_size`` disables storage entirely.
        """
        if self._max_size <= 0:
            return
        if key in self._cache:
            self._cache.move_to_end(key)
        elif len(self._cache) >= self._max_size:
            self._cache.popitem(last=False)  # drop the least recently used entry
        self._cache[key] = CacheItem(value=value, created_at=time.time(), ttl=ttl)

    def delete(self, key: str):
        """Remove ``key`` if present."""
        self._cache.pop(key, None)

    def clear(self):
        """Remove every entry."""
        self._cache.clear()

    def cleanup_expired(self) -> int:
        """Remove all expired items and return how many were removed."""
        expired_keys = [k for k, item in self._cache.items() if item.is_expired()]
        for key in expired_keys:
            self.delete(key)
        return len(expired_keys)
# Caching decorator
def cached(ttl: Optional[float] = None, max_size: int = 128):
    """Memoize a function's results in a private SimpleCache.

    Args:
        ttl: lifetime of each cached result in seconds (None = never expires).
        max_size: maximum number of distinct argument combinations kept (LRU).

    The decorated function gains ``cache`` (the underlying SimpleCache) and
    ``clear_cache()`` attributes. A ``None`` result is cached like any other
    value. Applying one ``cached(...)`` decorator object to several functions
    makes them share one cache; keys include the function's module and
    qualified name, so they do not collide.
    """
    _cache = SimpleCache(max_size)

    def decorator(func: Callable):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            key = _generate_key(f"{func.__module__}.{func.__qualname__}", args, kwargs)
            result = _cache.get(key)
            # get() returns None both for a miss and for a cached None result;
            # the membership test tells them apart (get() has already removed
            # an expired entry, so it is correctly treated as a miss).
            if result is not None or key in _cache._cache:
                return result
            result = func(*args, **kwargs)
            _cache.set(key, result, ttl)
            return result

        # cache-control handles exposed on the wrapper
        wrapper.cache = _cache
        wrapper.clear_cache = _cache.clear
        return wrapper

    return decorator
def _generate_key(func_name: str, args: tuple, kwargs: dict) -> str:
"""生成缓存key"""
# 简单实现:序列化参数
key_parts = [func_name]
for arg in args:
if hasattr(arg, '__dict__'):
# 对象:用id
key_parts.append(str(id(arg)))
else:
key_parts.append(str(arg))
for k, v in sorted(kwargs.items()):
key_parts.append(f"{k}={v}")
key_str = '|'.join(key_parts)
# hash化(太长的话)
if len(key_str) > 100:
return hashlib.md5(key_str.encode()).hexdigest()
return key_str
# Usage example
if __name__ == '__main__':
    @cached(ttl=60, max_size=100)
    def fetch_data(api_url: str, params: dict = None):
        """Pretend API call: prints the URL, waits one second, returns fixed data."""
        print(f"正在请求: {api_url}")
        time.sleep(1)  # simulated network latency
        return {"status": "success", "data": [1, 2, 3]}

    url = "https://api.example.com/data"

    # First call runs the function (prints the "requesting" line)
    first = fetch_data(url)
    # Second call is answered from the cache (nothing printed)
    second = fetch_data(url)
    print(f"结果相同: {first == second}")

    # Drop everything cached, so the next call executes the function again
    fetch_data.clear_cache()
    third = fetch_data(url)
二、多级缓存实现
更实用的多级缓存策略:
"""多级缓存"""
import time
import threading
import pickle
from pathlib import Path
from typing import Any, Optional, TypeVar, Generic
from abc import ABC, abstractmethod
T = TypeVar('T')


class CacheLevel(ABC):
    """Interface that every cache tier (memory, disk, ...) implements."""

    @abstractmethod
    def get(self, key: str) -> Optional[Any]:
        """Return the value stored for ``key`` (implementations return None on a miss)."""

    @abstractmethod
    def set(self, key: str, value: Any, ttl: Optional[float] = None):
        """Store ``value`` under ``key``; ``ttl`` is a lifetime in seconds."""

    @abstractmethod
    def delete(self, key: str):
        """Remove ``key`` from this tier."""

    @abstractmethod
    def clear(self):
        """Remove every entry from this tier."""
class MemoryCache(CacheLevel):
    """Thread-safe in-process LRU cache tier with a TTL per entry.

    Entries are kept in an ``OrderedDict`` ordered from least to most recently
    used, so lookups, inserts and LRU evictions are all O(1).
    """

    def __init__(self, max_size: int = 1000, default_ttl: float = 300):
        from collections import OrderedDict  # local import keeps this snippet self-contained

        # key -> {'value', 'created_at', 'ttl'}; least recently used first
        self._cache = OrderedDict()
        self._max_size = max_size
        self._default_ttl = default_ttl
        self._lock = threading.RLock()  # re-entrant: get() calls delete() under the lock

    def get(self, key: str) -> Optional[Any]:
        """Return the value for ``key``, or None if missing or expired.

        A hit marks the key as most recently used; an expired entry is removed.
        """
        with self._lock:
            item = self._cache.get(key)
            if item is None:
                return None
            if self._is_expired(item):
                self.delete(key)
                return None
            self._cache.move_to_end(key)
            return item['value']

    def set(self, key: str, value: Any, ttl: Optional[float] = None):
        """Store ``value`` under ``key``.

        A falsy ``ttl`` (None or 0) falls back to ``default_ttl``. Overwriting
        an existing key never evicts another entry; inserting a new key into a
        full cache evicts the least recently used one.
        """
        with self._lock:
            if key in self._cache:
                self._cache.move_to_end(key)
            elif len(self._cache) >= self._max_size:
                self._evict_lru()
            self._cache[key] = {
                'value': value,
                'created_at': time.time(),
                'ttl': ttl or self._default_ttl
            }

    def delete(self, key: str):
        """Remove ``key`` if present."""
        with self._lock:
            self._cache.pop(key, None)

    def clear(self):
        """Remove every entry."""
        with self._lock:
            self._cache.clear()

    def _is_expired(self, item: dict) -> bool:
        """True if ``item`` has a TTL and is older than it."""
        if item['ttl'] is None:
            return False
        return time.time() - item['created_at'] > item['ttl']

    def _evict_lru(self):
        """Drop the least recently used entry (no-op when empty); caller holds the lock."""
        if self._cache:
            self._cache.popitem(last=False)
class DiskCache(CacheLevel):
    """Cache tier that pickles one file per key into ``cache_dir``.

    Expiry metadata and file sizes are tracked only in an in-memory index, so
    a new process does not see entries written by an earlier one (their files
    remain on disk until ``clear()`` removes every ``*.cache`` file).
    """

    def __init__(self, cache_dir: str = './cache', max_size_mb: int = 500):
        self._cache_dir = Path(cache_dir)
        self._cache_dir.mkdir(parents=True, exist_ok=True)
        self._max_size = max_size_mb * 1024 * 1024  # budget in bytes
        # Re-entrant lock: get() and _check_size() call delete(), which takes
        # the lock again; a plain Lock would deadlock on expiry or eviction.
        self._lock = threading.RLock()
        self._memory_index = {}  # key -> {'created_at', 'ttl', 'size'}

    def _path(self, key: str) -> Path:
        """Return the file that stores ``key``.

        The key is hashed, so characters such as ``/``, ``:`` or ``..`` can
        neither escape ``cache_dir`` nor produce an invalid file name.
        """
        digest = hashlib.md5(key.encode('utf-8', 'surrogatepass')).hexdigest()
        return self._cache_dir / f"{digest}.cache"

    def get(self, key: str) -> Optional[Any]:
        """Return the unpickled value, or None if unknown, expired or unreadable."""
        with self._lock:
            meta = self._memory_index.get(key)
            if meta is None:
                return None
            # a falsy ttl (None or 0) means "never expires"
            if meta['ttl'] and time.time() - meta['created_at'] > meta['ttl']:
                self.delete(key)
                return None
            filepath = self._path(key)
            if not filepath.exists():
                # file removed behind our back: forget the stale index entry
                self._memory_index.pop(key, None)
                return None
            try:
                with open(filepath, 'rb') as f:
                    # SECURITY: unpickling can execute code; cache_dir must be trusted.
                    return pickle.load(f)
            except Exception:
                return None  # best effort: an unreadable file counts as a miss

    def set(self, key: str, value: Any, ttl: Optional[float] = None):
        """Pickle ``value`` to disk and index it; failures are printed, not raised."""
        with self._lock:
            filepath = self._path(key)
            try:
                with open(filepath, 'wb') as f:
                    pickle.dump(value, f)
                self._memory_index[key] = {
                    'created_at': time.time(),
                    'ttl': ttl,
                    'size': filepath.stat().st_size
                }
                self._check_size()
            except Exception as e:
                print(f"磁盘缓存写入失败: {e}")

    def delete(self, key: str):
        """Remove the file and index entry for ``key`` (missing file is fine)."""
        with self._lock:
            try:
                self._path(key).unlink()
            except FileNotFoundError:
                pass
            self._memory_index.pop(key, None)

    def clear(self):
        """Delete every ``*.cache`` file in ``cache_dir`` and empty the index."""
        with self._lock:
            for f in self._cache_dir.glob('*.cache'):
                f.unlink()
            self._memory_index.clear()

    def _check_size(self):
        """If the indexed total exceeds the budget, delete oldest entries down to 80%.

        Called from set() with the lock already held.
        """
        total_size = sum(m['size'] for m in self._memory_index.values())
        if total_size > self._max_size:
            oldest_first = sorted(
                self._memory_index.items(),
                key=lambda x: x[1]['created_at']
            )
            for key, meta in oldest_first:
                if total_size <= self._max_size * 0.8:  # shrink to 80% of the budget
                    break
                self.delete(key)
                total_size -= meta['size']
class MultiLevelCache:
    """Two-tier cache: a fast in-memory L1 in front of a larger on-disk L2.

    Reads try L1 first, then L2 (an L2 hit is copied back into L1). Writes and
    deletes are applied to both tiers. Because tiers return None for a miss,
    a stored None value cannot be told apart from a miss.
    """

    def __init__(self, l1: Optional['CacheLevel'] = None, l2: Optional['CacheLevel'] = None):
        """Create the two tiers.

        Args:
            l1: first tier; defaults to ``MemoryCache(max_size=500, default_ttl=60)``.
            l2: second tier; defaults to ``DiskCache(cache_dir='./cache', max_size_mb=200)``.
        """
        # L1: process memory (fast, small)
        self.l1 = l1 if l1 is not None else MemoryCache(max_size=500, default_ttl=60)
        # L2: disk (slower, larger)
        self.l2 = l2 if l2 is not None else DiskCache(cache_dir='./cache', max_size_mb=200)

    def get(self, key: str) -> Optional[Any]:
        """Return the value from L1, else from L2 (back-filling L1), else None."""
        value = self.l1.get(key)
        if value is not None:
            return value
        value = self.l2.get(key)
        if value is not None:
            # Back-fill without a ttl argument: L1 applies its own default TTL,
            # not the remaining lifetime of the L2 entry.
            self.l1.set(key, value)
        return value

    def set(self, key: str, value: Any, ttl: Optional[float] = None):
        """Write ``value`` to L1 and L2 with the same ``ttl``."""
        self.l1.set(key, value, ttl)
        self.l2.set(key, value, ttl)

    def delete(self, key: str):
        """Delete ``key`` from both tiers."""
        self.l1.delete(key)
        self.l2.delete(key)

    def clear(self):
        """Empty both tiers."""
        self.l1.clear()
        self.l2.clear()
# Process-wide cache instance used by multi_cached. Constructing it builds the
# default tiers at import time, which includes creating the ./cache directory
# for the disk tier.
_global_cache = MultiLevelCache()
def multi_cached(ttl: Optional[float] = 300):
"""多级缓存装饰器"""
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
# 生成key
key = f"{func.__module__}.{func.__name__}"
key += f":{args}:{kwargs}"
key = hashlib.md5(key.encode()).hexdigest()
# 尝试从缓存获取
result = _global_cache.get(key)
if result is not None:
return result
# 执行函数
result = func(*args, **kwargs)
# 存入缓存
_global_cache.set(key, result, ttl)
return result
wrapper.clear_cache = _global_cache.clear
return wrapper
return decorator
三、缓存防穿透
缓存穿透是个严重问题——大量请求查询根本不存在的数据,缓存永远无法命中,于是每次都直接打到后端。
"""缓存防穿透"""
import time
import threading
import hashlib
from typing import Any, Optional, Callable, Set
from dataclasses import dataclass
import random
@dataclass
class Sentinel:
    """Marker stored in the cache meaning "looked up, and the data does not exist"."""


# Shared marker instance; readers recognise it with isinstance(..., Sentinel)
NOT_FOUND = Sentinel()
class AntiPenetrationCache:
    """Cache front-end that shields the data source from lookups of missing keys.

    Two defences:
      * null-value caching: when ``fetch_func`` returns None, ``NOT_FOUND`` is
        cached for ``null_cache_ttl`` seconds, so repeats skip the data source;
      * Bloom filter: once populated with keys that really exist (``known_keys``
        or ``add_known_key``), keys it rejects get None without a fetch.

    The filter is consulted only after it has been populated. An empty filter
    answers "definitely absent" for every key, so gating on it unconditionally
    would stop anything from ever being fetched.
    """

    def __init__(self,
                 cache: 'MultiLevelCache' = None,
                 bloom_filter_size: int = 10000,
                 null_cache_ttl: float = 60,
                 known_keys: Optional['Iterable[str]'] = None):
        """Args:
            cache: backing cache; a new MultiLevelCache() when not given.
            bloom_filter_size: number of bits in the Bloom filter.
            null_cache_ttl: seconds a "does not exist" marker is cached.
            known_keys: keys known to exist; providing any enables filtering.
        """
        self._cache = cache if cache is not None else MultiLevelCache()
        # Bloom filter: fast "definitely absent" answers for unknown keys
        self._bloom_filter = BloomFilter(bloom_filter_size)
        self._bloom_ready = False  # True once a key has been registered via add_known_key
        self._null_ttl = null_cache_ttl
        self._lock = threading.Lock()  # not used by the methods below
        for key in known_keys or ():
            self.add_known_key(key)

    def add_known_key(self, key: str):
        """Register a key that exists in the data source and enable filtering.

        After filtering is enabled, records created in the data source must be
        registered here too, otherwise the filter rejects them.
        """
        self._bloom_filter.add(key)
        self._bloom_ready = True

    def get_or_fetch(self,
                     key: str,
                     fetch_func: Callable[[], Any],
                     ttl: Optional[float] = None) -> Any:
        """Return the cached value for ``key``, fetching and caching it on a miss.

        Args:
            key: cache key.
            fetch_func: zero-argument loader; returning None means "does not exist".
            ttl: cache lifetime for a found value.

        Returns:
            The value, or None when it does not exist (or the filter rejects the key).
        """
        cached_value = self._cache.get(key)
        if cached_value is not None:
            # a Sentinel means "known to be missing" (null-value cache)
            return None if isinstance(cached_value, Sentinel) else cached_value

        # Bloom filters never give false negatives: a rejection is definitive
        if self._bloom_ready and not self._bloom_filter.might_contain(key):
            return None

        value = fetch_func()
        if value is None:
            # cache the miss so repeated lookups do not reach the data source
            self._cache.set(key, NOT_FOUND, self._null_ttl)
            return None

        # only keys that really exist are added to the filter
        self._bloom_filter.add(key)
        self._cache.set(key, value, ttl)
        return value

    def invalidate(self, key: str):
        """Drop ``key`` (value or null marker) from the cache."""
        self._cache.delete(key)
class BloomFilter:
    """Minimal Bloom filter: may report false positives, never false negatives.

    Args:
        size: number of bit slots (must be positive).
        hash_count: number of hash positions per item (default 3).

    Raises:
        ValueError: if ``size`` or ``hash_count`` is not positive.
    """

    def __init__(self, size: int = 10000, hash_count: int = 3):
        if size <= 0:
            raise ValueError("size must be a positive integer")
        if hash_count <= 0:
            raise ValueError("hash_count must be a positive integer")
        self._size = size
        self._bits = bytearray(size)  # one byte per slot; far smaller than a list of bools
        self._hash_count = hash_count

    def _hashes(self, item: str) -> list:
        """Return the ``hash_count`` slot indices for ``item`` (MD5 salted with the index)."""
        return [
            int(hashlib.md5(f"{item}:{i}".encode()).hexdigest(), 16) % self._size
            for i in range(self._hash_count)
        ]

    def add(self, item: str):
        """Set every slot belonging to ``item``."""
        for pos in self._hashes(item):
            self._bits[pos] = 1

    def might_contain(self, item: str) -> bool:
        """False means definitely never added; True means probably added."""
        return all(self._bits[pos] for pos in self._hashes(item))

    def __contains__(self, item: str) -> bool:
        return self.might_contain(item)
# Usage example
if __name__ == '__main__':
    cache = AntiPenetrationCache()

    def fetch_from_db(user_id: str) -> Optional[dict]:
        """Simulated database query: only "user_1" exists."""
        print(f"查询数据库: {user_id}")
        if user_id != "user_1":
            return None
        return {"id": user_id, "name": "张三", "age": 25}

    def lookup(user_id: str):
        """Read ``user:<user_id>`` through the anti-penetration cache."""
        return cache.get_or_fetch(f"user:{user_id}", lambda: fetch_from_db(user_id))

    print("=" * 50)
    # user_1 exists: the database is queried at most once, and the repeated
    # lookup does not query it again.
    first = lookup("user_1")
    print(f"user_1 结果: {first}")
    first_again = lookup("user_1")
    print(f"user_1 再次获取: {first_again}")

    print("=" * 50)
    # user_2 does not exist: the repeated lookup must not reach the database
    # (the miss is null-cached or rejected up front).
    second = lookup("user_2")
    print(f"user_2 结果: {second}")
    second_again = lookup("user_2")
    print(f"user_2 再次获取: {second_again}")

    print("=" * 50)
    # user_999 does not exist: whether even the first lookup skips the
    # database depends on the Bloom filter having been populated with the keys
    # that really exist; this demo cache registers none.
    third = lookup("user_999")
    print(f"user_999 结果: {third}")
四、缓存一致性
分布式环境下的缓存一致性问题(下面的示例是单进程内存实现,仅用来演示 CAS 和 Cache Aside 的思路):
"""缓存一致性处理"""
import time
import threading
from typing import Any, Optional, Callable, List
from dataclasses import dataclass
from enum import Enum
class CacheStrategy(Enum):
    """Names of the cache/store write-coordination patterns."""

    CACHE_ASIDE = "cache_aside"      # cache-aside (the most common)
    WRITE_THROUGH = "write_through"  # write-through
    WRITE_BACK = "write_back"        # write-back
@dataclass
class CacheEntry:
    """A value held by ConsistentCache together with its version and timestamps."""

    value: Any
    version: int       # incremented on every update; compared by CAS-style callers
    created_at: float  # time.time() when the entry was first stored
    updated_at: float  # time.time() of the latest update
class ConsistentCache:
    """Thread-safe in-memory store that keeps a version number per key.

    A new key starts at version 1; every successful update (``set`` on an
    existing key, ``cas``, ``cas_version``) increments it. ``strategy`` is
    recorded but none of the methods below branch on it, and
    ``_pending_updates`` (intended as a write-back buffer) is not used yet.
    """

    def __init__(self, strategy: CacheStrategy = CacheStrategy.CACHE_ASIDE):
        self._cache = {}  # key -> CacheEntry
        self._lock = threading.RLock()  # re-entrant: setnx() calls set() under the lock
        self._strategy = strategy
        self._pending_updates = {}  # write-back buffer (currently unused)

    @staticmethod
    def _bump(entry: CacheEntry, value: Any) -> None:
        """Replace ``entry``'s value, increment its version and refresh ``updated_at``."""
        entry.value = value
        entry.version += 1
        entry.updated_at = time.time()

    def get(self, key: str) -> Optional[Any]:
        """Return the value for ``key``, or None if absent."""
        with self._lock:
            entry = self._cache.get(key)
            return entry.value if entry is not None else None

    def set(self, key: str, value: Any):
        """Create ``key`` at version 1, or update it and increment its version."""
        with self._lock:
            entry = self._cache.get(key)
            if entry is not None:
                self._bump(entry, value)
            else:
                now = time.time()
                self._cache[key] = CacheEntry(
                    value=value,
                    version=1,
                    created_at=now,
                    updated_at=now
                )

    def setnx(self, key: str, value: Any) -> bool:
        """Set if Not eXists: store ``value`` only when ``key`` is absent.

        Returns True if the value was stored. Atomic with respect to the other
        methods of this instance (all hold the same lock).
        """
        with self._lock:
            if key not in self._cache:
                self.set(key, value)
                return True
            return False

    def cas(self, key: str, old_value: Any, new_value: Any) -> bool:
        """Compare-and-swap on the value: update only if the current value == ``old_value``.

        Note that a value that went A -> B -> A in between still matches; use
        ``cas_version`` to detect such intermediate changes.
        """
        with self._lock:
            entry = self._cache.get(key)
            if entry is not None and entry.value == old_value:
                self._bump(entry, new_value)
                return True
            return False

    def cas_version(self, key: str, expected_version: int, new_value: Any) -> bool:
        """Compare-and-swap on the version number.

        Updates ``key`` only if it exists and its version still equals
        ``expected_version`` (as previously returned by ``get_version``).
        Returns True on success.
        """
        with self._lock:
            entry = self._cache.get(key)
            if entry is None or entry.version != expected_version:
                return False
            self._bump(entry, new_value)
            return True

    def delete(self, key: str):
        """Remove ``key`` if present."""
        with self._lock:
            self._cache.pop(key, None)

    def get_version(self, key: str) -> Optional[int]:
        """Return the current version of ``key``, or None if absent."""
        with self._lock:
            entry = self._cache.get(key)
            return entry.version if entry else None
class CacheAsideManager:
    """Cache-aside coordinator: read through the cache; write the DB, then invalidate."""

    def __init__(self, cache: Optional['ConsistentCache'] = None):
        """Args:
            cache: store for cached values (needs get/set/delete); defaults to a
                new ConsistentCache().
        """
        self._cache = cache if cache is not None else ConsistentCache()

    def read(self,
             key: str,
             db_fetch: Callable[[], Any],
             ttl: Optional[float] = None) -> Any:
        """Cache-aside read.

        1. Return the cached value if present.
        2. Otherwise call ``db_fetch()``.
        3. Back-fill the cache when the result is not None.

        ``ttl`` is accepted for API symmetry but not applied: ``set`` is called
        without a TTL, so entries stay until ``write``/``delete`` invalidates
        them. NOTE(review): a read that fetched old data before a concurrent
        write can back-fill it after the write's delete, leaving a stale entry.
        """
        value = self._cache.get(key)
        if value is not None:
            return value
        value = db_fetch()
        if value is not None:
            self._cache.set(key, value)
        return value

    def write(self,
              key: str,
              value: Any,
              db_write: Callable[[], None],
              ttl: Optional[float] = None):
        """Write to the database first, then delete (not update) the cache entry.

        Deleting instead of updating avoids racing writers leaving an older
        value in the cache; the next ``read`` repopulates it. ``value`` and
        ``ttl`` are not used by this method: ``db_write`` is expected to
        perform the actual write.
        """
        db_write()
        self._cache.delete(key)

    def delete(self,
               key: str,
               db_delete: Callable[[], None]):
        """Delete from the database first, then drop the cache entry."""
        db_delete()
        self._cache.delete(key)
# Usage example
if __name__ == '__main__':
    cache_manager = CacheAsideManager()

    # Simulated database
    db_data = {"user_1": {"name": "张三", "age": 25}}

    def fetch_user(user_id: str):
        return db_data.get(user_id)

    # Read (miss -> database -> back-fill the cache)
    user = cache_manager.read("user:user_1", lambda: fetch_user("user_1"))
    print(f"读取用户: {user}")

    # Read again (served from the cache)
    user = cache_manager.read("user:user_1", lambda: fetch_user("user_1"))
    print(f"再次读取: {user}")

    # Write: the "database write" replaces the record with a new dict.
    # Mutating the existing dict in place would also change the object held
    # by the cache, so the demo would not show whether invalidation works.
    new_user = {"name": "李四", "age": 25}

    def save_user():
        db_data["user_1"] = new_user

    cache_manager.write("user:user_1", new_user, save_user)

    # Read the fresh data (the cache entry was deleted by write())
    user = cache_manager.read("user:user_1", lambda: fetch_user("user_1"))
    print(f"更新后读取: {user}")
五、缓存监控
监控缓存效果:
"""缓存监控"""
import time
import threading
from typing import Dict, List
from dataclasses import dataclass, field
from collections import defaultdict
@dataclass
class CacheStats:
    """Counters describing cache traffic."""

    hits: int = 0
    misses: int = 0
    sets: int = 0
    deletes: int = 0
    errors: int = 0

    @property
    def hit_rate(self) -> float:
        """Fraction of lookups that were hits; 0.0 before any lookup."""
        lookups = self.hits + self.misses
        if lookups <= 0:
            return 0.0
        return self.hits / lookups
class CacheMonitor:
    """Thread-safe cache counters plus a bounded log of recent operations.

    Args:
        max_log_size: number of log entries kept; the oldest are dropped first.
    """

    def __init__(self, max_log_size: int = 1000):
        from collections import deque  # local import keeps this snippet self-contained

        self._stats = CacheStats()
        self._lock = threading.Lock()
        self._max_log_size = max_log_size
        # A bounded deque discards the oldest entry in O(1) on overflow, instead
        # of re-slicing the whole list on every append once the limit is hit.
        self._access_log = deque(maxlen=max_log_size)

    def _record(self, counter: str, key: str, action: str):
        """Increment the ``counter`` field of the stats and log ``action`` for ``key``."""
        with self._lock:
            setattr(self._stats, counter, getattr(self._stats, counter) + 1)
            self._log_access(key, action)

    def record_hit(self, key: str):
        self._record('hits', key, 'hit')

    def record_miss(self, key: str):
        self._record('misses', key, 'miss')

    def record_set(self, key: str):
        self._record('sets', key, 'set')

    def record_delete(self, key: str):
        self._record('deletes', key, 'delete')

    def record_error(self, key: str):
        self._record('errors', key, 'error')

    def _log_access(self, key: str, action: str):
        """Append one access-log entry; the caller must hold ``self._lock``."""
        self._access_log.append({
            'time': time.time(),
            'key': key,
            'action': action
        })

    def get_stats(self) -> CacheStats:
        """Return a snapshot copy of the counters."""
        with self._lock:
            return CacheStats(
                hits=self._stats.hits,
                misses=self._stats.misses,
                sets=self._stats.sets,
                deletes=self._stats.deletes,
                errors=self._stats.errors
            )

    def get_recent_accesses(self, limit: int = 10) -> List[dict]:
        """Return up to ``limit`` most recent log entries, oldest first.

        A non-positive ``limit`` returns an empty list (slicing with ``[-0:]``
        would otherwise return the whole log).
        """
        if limit <= 0:
            return []
        with self._lock:
            return list(self._access_log)[-limit:]

    def report(self) -> str:
        """Render the current counters and hit rate as a text report."""
        stats = self.get_stats()
        return f"""
========== 缓存报告 ==========
命中: {stats.hits}
未命中: {stats.misses}
设置: {stats.sets}
删除: {stats.deletes}
错误: {stats.errors}
命中率: {stats.hit_rate:.2%}
=============================
"""
# Process-wide monitor instance; MonitoredCache records its events here and the
# demo prints its report.
_global_monitor = CacheMonitor()
class MonitoredCache:
    """Wrapper that records hits, misses, sets, deletes and errors for a cache.

    Constructor args:
        cache: object providing ``get(key)``, ``set(key, value, ttl)`` and
            ``delete(key)`` (for example a MultiLevelCache).
        monitor: receiver of ``record_*`` calls; defaults to ``_global_monitor``.
    """

    def __init__(self, cache: 'MultiLevelCache', monitor: Optional['CacheMonitor'] = None):
        self._cache = cache
        self._monitor = monitor if monitor is not None else _global_monitor

    def get(self, key: str):
        """Return ``cache.get(key)``; non-None counts as a hit, None as a miss.

        An exception from the cache is recorded as an error and re-raised.
        """
        try:
            value = self._cache.get(key)
        except Exception:
            self._monitor.record_error(key)
            raise
        if value is not None:
            self._monitor.record_hit(key)
        else:
            self._monitor.record_miss(key)
        return value

    def set(self, key: str, value: Any, ttl: Optional[float] = None):
        """Store through the cache and record a set (or an error, re-raised)."""
        try:
            self._cache.set(key, value, ttl)
        except Exception:
            self._monitor.record_error(key)
            raise
        self._monitor.record_set(key)

    def delete(self, key: str):
        """Delete through the cache and record a delete (or an error, re-raised)."""
        try:
            self._cache.delete(key)
        except Exception:
            self._monitor.record_error(key)
            raise
        self._monitor.record_delete(key)
# Usage example
if __name__ == '__main__':
    monitored = MonitoredCache(_global_cache)

    # Read-through loop: nothing stores key_1 beforehand, so the first get is
    # a miss that fills the cache and the remaining four gets are hits.
    for _ in range(5):
        if monitored.get("key_1") is None:
            monitored.set("key_1", "value1")

    monitored.set("key_2", "value2")

    # Print the report
    print(_global_monitor.report())
总结
- 基础缓存:内存缓存 + LRU淘汰
- 多级缓存:L1内存 + L2磁盘
- 防穿透:空值缓存 + 布隆过滤器(布隆过滤器需要预先载入所有真实存在的key,否则空过滤器会把所有请求都拦掉)
- 一致性:CAS操作 + Cache Aside模式
- 监控:命中率统计 + 访问日志
缓存是个大话题,这里只是入门级介绍。
更多推荐
所有评论(0)