大家好,我是扣扣。今天来聊聊如何让自动化脚本更加健壮——自动重试机制。

为什么要重试?

网络请求总是充满了不确定性:

  • 网络抖动导致的临时断开
  • 服务器过载返回503
  • DNS解析偶尔失败
  • 并发太高被限流

如果每次失败就直接报错,脚本的稳定性会很差。重试机制就是:遇到临时性错误时,给服务器一点恢复时间,然后重试。

一、基础重试装饰器

先从简单的开始:

"""基础重试装饰器"""

import time
import functools
from typing import Callable, Type, Tuple

def retry(max_attempts: int = 3, 
          delay: float = 1.0, 
          exceptions: Tuple[Type[Exception], ...] = (Exception,),
          backoff: float = 2.0):
    """
    Retry decorator with exponential backoff.

    The wrapped function is called up to ``max_attempts`` times in total
    (the first call plus ``max_attempts - 1`` retries). After each failed
    attempt it sleeps, starting at ``delay`` seconds and multiplying the
    wait by ``backoff`` after every failure.

    Args:
        max_attempts: total number of calls, first call included; must be >= 1.
        delay: wait in seconds before the first retry; must be >= 0.
        exceptions: exception types that trigger a retry. Any other exception
            propagates immediately without retrying.
        backoff: factor applied to the wait after each failed attempt.

    Raises:
        ValueError: if ``max_attempts`` < 1 or ``delay`` < 0. This is raised
            when the decorator is created, not when the function is called.
    """
    # Validate eagerly: with max_attempts < 1 the loop below would never call
    # the function, and a negative delay would make time.sleep() fail mid-retry.
    if max_attempts < 1:
        raise ValueError(f"max_attempts must be >= 1, got {max_attempts}")
    if delay < 0:
        raise ValueError(f"delay must be >= 0, got {delay}")

    def decorator(func: Callable):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            current_delay = delay

            for attempt in range(max_attempts):
                try:
                    return func(*args, **kwargs)
                except exceptions as e:
                    if attempt == max_attempts - 1:
                        raise  # out of attempts: re-raise the last error unchanged

                    print(f"第 {attempt + 1} 次尝试失败: {e}")
                    print(f"{current_delay} 秒后重试...")
                    time.sleep(current_delay)
                    current_delay *= backoff

            # Unreachable: max_attempts >= 1 is validated above, so the final
            # iteration always returns or re-raises. Kept as a defensive guard.
            raise RuntimeError("重试逻辑异常")

        return wrapper
    return decorator


# Usage example
import requests


@retry(max_attempts=3, delay=1, backoff=2,
       exceptions=(ConnectionError, TimeoutError,
                   # requests' own exception classes do NOT inherit from the
                   # builtin ConnectionError / TimeoutError, so they must be
                   # listed explicitly or the decorator never retries.
                   requests.exceptions.ConnectionError,
                   requests.exceptions.Timeout))
def fetch_data(url: str):
    """Fetch ``url`` and return its decoded JSON body.

    Connection failures and timeouts are retried: 3 attempts, waiting 1 s
    and then 2 s. HTTP error statuses raise ``requests.exceptions.HTTPError``
    via ``raise_for_status()``. That class is not in the retry list, so these
    errors propagate immediately.
    """
    response = requests.get(url, timeout=10)
    response.raise_for_status()
    return response.json()


# Actual use: guarded so that importing this module does not fire a network request
if __name__ == '__main__':
    try:
        data = fetch_data("https://api.example.com/data")
    except Exception as e:
        print(f"获取数据失败: {e}")

二、requests专用重试Session

对于网络请求,封装一个专用类更方便:

"""requests重试封装"""

import requests
from requests.adapters import HTTPAdapter
from urllib3.util.retry import Retry
from typing import Optional, Dict, Any
import time

class RetrySession:
    """requests Session wrapper whose transport retries transient failures.

    Retrying is delegated to urllib3's ``Retry`` policy. It is mounted through
    an ``HTTPAdapter`` on both ``http://`` and ``https://``. Every request gets
    ``self.timeout`` unless the caller passes its own ``timeout=``.

    Can be used as a context manager so the underlying session is always closed::

        with RetrySession(retries=5) as s:
            s.get("https://example.com")
    """

    def __init__(self,
                 retries: int = 3,
                 backoff_factor: float = 0.5,
                 status_forcelist: tuple = (500, 502, 503, 504),
                 allowed_methods: tuple = ("HEAD", "GET", "PUT", "DELETE", 
                                          "OPTIONS", "TRACE", "PATCH"),
                 timeout: float = 30):
        """
        Build the session and mount the retrying adapter.

        Args:
            retries: total retry budget, passed as urllib3 ``Retry(total=...)``.
            backoff_factor: urllib3 sleeps roughly
                ``backoff_factor * 2 ** (n - 1)`` seconds before the n-th
                retry (0.5 -> 0.5s, 1s, 2s, ...). NOTE(review): urllib3 1.x
                skips the sleep before the first retry; check the installed
                version.
            status_forcelist: HTTP status codes that trigger a retry.
            allowed_methods: HTTP methods urllib3 is allowed to retry.
                NOTE(review): PATCH is not guaranteed idempotent (RFC 5789),
                and urllib3's own default list omits it. Drop it if a repeated
                PATCH could duplicate side effects.
            timeout: default per-request timeout in seconds.
        """
        self.session = requests.Session()
        self.timeout = timeout

        retry_strategy = Retry(
            total=retries,
            backoff_factor=backoff_factor,
            status_forcelist=status_forcelist,
            allowed_methods=allowed_methods,
            # Once status retries are exhausted, return the last response
            # instead of raising, so callers can inspect the status code.
            raise_on_status=False,
        )

        adapter = HTTPAdapter(max_retries=retry_strategy)
        self.session.mount("http://", adapter)
        self.session.mount("https://", adapter)

    def __enter__(self) -> "RetrySession":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        # Always release pooled connections; never suppress the exception.
        self.close()

    def request(self, method: str, url: str, **kwargs) -> requests.Response:
        """Send ``method`` to ``url``; ``self.timeout`` applies unless ``timeout=`` is given."""
        kwargs.setdefault('timeout', self.timeout)
        return self.session.request(method, url, **kwargs)

    def get(self, url: str, **kwargs) -> requests.Response:
        """Send a GET request."""
        return self.request('GET', url, **kwargs)

    def post(self, url: str, **kwargs) -> requests.Response:
        """Send a POST request (not in the default retryable method list)."""
        return self.request('POST', url, **kwargs)

    def put(self, url: str, **kwargs) -> requests.Response:
        """Send a PUT request."""
        return self.request('PUT', url, **kwargs)

    def delete(self, url: str, **kwargs) -> requests.Response:
        """Send a DELETE request."""
        return self.request('DELETE', url, **kwargs)

    def close(self):
        """Close the underlying requests Session."""
        self.session.close()


# Usage example: fetch one URL through the retrying session
if __name__ == '__main__':
    client = RetrySession(retries=5, backoff_factor=0.5)

    try:
        resp = client.get("https://api.example.com/data")
        print(f"状态码: {resp.status_code}")
        print(f"内容: {resp.text[:100]}")
    except requests.exceptions.RequestException as err:
        print(f"请求失败: {err}")
    finally:
        # Release pooled connections whether or not the request succeeded.
        client.close()

三、指数退避完整实现

更智能的重试策略,根据不同情况调整:

"""智能重试策略"""

import time
import random
import logging
from typing import Callable, Any, Optional, List
from dataclasses import dataclass
from enum import Enum
import requests

# Logger for this section; SmartRetry reports failed attempts and retry waits through it.
logger = logging.getLogger(__name__)


class RetryStrategy(Enum):
    """How the wait between retries grows, as interpreted by SmartRetry.calculate_delay."""

    FIXED = "fixed"              # the same base delay every time
    LINEAR = "linear"            # base delay times (attempt + 1)
    EXPONENTIAL = "exponential"  # base delay doubled per attempt
    FIBONACCI = "fibonacci"      # base delay times a Fibonacci number


@dataclass
class RetryConfig:
    """Parameters consumed by SmartRetry (field order is the constructor order)."""

    # --- timing ---
    max_attempts: int = 3          # total attempts, first call included
    base_delay: float = 1.0        # seconds; scaled according to ``strategy``
    max_delay: float = 60.0        # seconds; cap on the computed backoff delay
    strategy: RetryStrategy = RetryStrategy.EXPONENTIAL
    jitter: bool = True            # randomise each delay
    jitter_factor: float = 0.3     # jitter amplitude as a fraction of the delay

    # --- exceptions treated as transient ---
    retryable_exceptions: tuple = (
        ConnectionError,
        TimeoutError,
        requests.exceptions.ConnectionError,
        requests.exceptions.Timeout,
        requests.exceptions.ChunkedEncodingError,
    )

    # --- HTTP status codes treated as transient ---
    retryable_status_codes: tuple = (408, 429, 500, 502, 503, 504)


class SmartRetry:
    """Run a callable and retry transient failures according to a RetryConfig."""

    def __init__(self, config: Optional[RetryConfig] = None):
        # Fall back to the default configuration when none is supplied.
        self.config = config or RetryConfig()

    def calculate_delay(self, attempt: int) -> float:
        """Return the wait, in seconds, before the retry that follows ``attempt``.

        Args:
            attempt: zero-based index of the attempt that just failed.

        The base wait follows ``config.strategy``. When ``config.jitter`` is
        set, a random offset of up to ``±jitter_factor`` of the wait is added.
        The result is clamped to ``[0, config.max_delay]`` *after* jitter, so
        ``max_delay`` is a real upper bound.
        """
        cfg = self.config

        if cfg.strategy == RetryStrategy.FIXED:
            delay = cfg.base_delay
        elif cfg.strategy == RetryStrategy.LINEAR:
            delay = cfg.base_delay * (attempt + 1)
        elif cfg.strategy == RetryStrategy.EXPONENTIAL:
            delay = cfg.base_delay * (2 ** attempt)
        elif cfg.strategy == RetryStrategy.FIBONACCI:
            # Multipliers 1, 1, 2, 3, 5, ... indexed by attempt
            a, b = 1, 1
            for _ in range(attempt):
                a, b = b, a + b
            delay = cfg.base_delay * a
        else:
            delay = cfg.base_delay

        # Cap first so the jitter amplitude is proportional to the capped value.
        delay = min(delay, cfg.max_delay)

        if cfg.jitter:
            # Spread out retries from many clients so they don't fire in lock-step.
            jitter_range = delay * cfg.jitter_factor
            delay += random.uniform(-jitter_range, jitter_range)

        # Clamp again: jitter must not push the wait above max_delay or below 0.
        return min(max(0.0, delay), cfg.max_delay)

    def should_retry(self, exception: Exception, status_code: Optional[int] = None) -> bool:
        """Return True if ``exception`` or ``status_code`` is configured as retryable."""
        if isinstance(exception, self.config.retryable_exceptions):
            return True

        if status_code and status_code in self.config.retryable_status_codes:
            return True

        return False

    def execute(self, func: Callable, *args, **kwargs) -> Any:
        """Call ``func(*args, **kwargs)``, retrying retryable failures.

        Returns:
            Whatever ``func`` returns on the first successful attempt.

        Raises:
            ValueError: if ``config.max_attempts`` is less than 1.
            Exception: the first non-retryable exception, or the exception
                from the final attempt, re-raised unchanged.
        """
        max_attempts = self.config.max_attempts
        if max_attempts < 1:
            # Without this guard the loop never runs and nothing can be raised.
            raise ValueError(f"max_attempts must be >= 1, got {max_attempts}")

        for attempt in range(max_attempts):
            try:
                return func(*args, **kwargs)
            except Exception as e:
                status_code = self._status_code_of(e)

                if not self.should_retry(e, status_code) or attempt == max_attempts - 1:
                    logger.error(f"执行失败且不重试: {e}")
                    raise

                delay = self.calculate_delay(attempt)
                logger.warning(f"第 {attempt + 1} 次尝试失败: {e}")
                logger.info(f"等待 {delay:.2f} 秒后重试...")

                time.sleep(delay)

        # Unreachable: the final iteration always returns or re-raises.
        raise RuntimeError("retry loop exited unexpectedly")

    @staticmethod
    def _status_code_of(exception: Exception) -> Optional[int]:
        """Return the HTTP status code carried by ``exception``, or None.

        requests' ``HTTPError`` exposes the failing Response as ``.response``.
        ``Response.__bool__`` is False for 4xx/5xx responses, so the response
        must be compared with None rather than tested for truthiness.
        """
        response = getattr(exception, 'response', None)
        if response is None:
            return None
        return getattr(response, 'status_code', None)


# Convenience helper
def retry_request(method: str, url: str, **kwargs) -> requests.Response:
    """Send one HTTP request, retrying transient failures with SmartRetry.

    Uses 5 attempts with exponential backoff starting at 1 s. Only raised
    exceptions are retried, e.g. connection errors and timeouts. No
    ``raise_for_status()`` is called, so an error status such as 503 is
    returned to the caller as a normal Response.

    Args:
        method: requests.Session method name, case-insensitive (e.g. ``'get'``).
        url: target URL.
        **kwargs: forwarded to the session method.

    Returns:
        The ``requests.Response`` of the first attempt that did not raise.
    """
    config = RetryConfig(
        max_attempts=5,
        base_delay=1.0,
        strategy=RetryStrategy.EXPONENTIAL
    )

    retryer = SmartRetry(config)

    def request_func():
        session = requests.Session()
        try:
            return getattr(session, method.lower())(url, **kwargs)
        finally:
            # Close the per-attempt session so its connection pool is not leaked.
            # With stream=True the body is read later, so keep the session open.
            if not kwargs.get('stream'):
                session.close()

    return retryer.execute(request_func)


# Usage example
if __name__ == '__main__':
    # Option 1: drive SmartRetry directly around a single call
    retryer = SmartRetry(RetryConfig(
        max_attempts=5,
        base_delay=1.0,
        max_delay=30.0,
        strategy=RetryStrategy.EXPONENTIAL,
        jitter=True,
    ))

    def fetch_once():
        return requests.get("https://api.example.com/data", timeout=10)

    try:
        result = retryer.execute(fetch_once)
        print(f"成功: {result.status_code}")
    except Exception as e:
        print(f"最终失败: {e}")

    # Option 2: the convenience helper
    print("\n使用便捷函数:")
    try:
        response = retry_request('get', 'https://httpbin.org/delay/2')
        print(f"状态码: {response.status_code}")
    except Exception as e:
        print(f"请求失败: {e}")

四、限流与退避

针对API限流(429状态码)有特殊处理:

"""限流退避处理"""

import time
import requests
from typing import Optional
from dataclasses import dataclass

@dataclass
class RateLimitConfig:
    """Settings for RateLimitedSession's retry and back-off behaviour."""

    # Attempts per request, first try included.
    max_retries: int = 5
    # Seconds; starting point for the exponential back-off.
    base_delay: float = 1.0
    # Seconds; intended ceiling for a single wait (5 minutes).
    max_delay: float = 300.0
    # When True, a 429 response's Retry-After header decides the wait.
    respect_retry_after: bool = True


class RateLimitedSession:
    """requests Session wrapper that backs off on HTTP 429, transient 5xx and network errors.

    After a 429, the earliest time the next request may be sent is stored in
    ``rate_limit_reset``. It is enforced before every later request. Every
    wait is capped at ``config.max_delay``. Can be used as a context manager.
    """

    # Server errors that are worth retrying.
    _RETRYABLE_5XX = (500, 502, 503, 504)

    def __init__(self, config: RateLimitConfig = None):
        self.config = config or RateLimitConfig()
        self.session = requests.Session()
        self.request_count = 0        # requests actually sent, retries included
        self.last_request_time = 0    # time.time() when the last request was sent
        self.rate_limit_reset = 0     # Unix timestamp before which no request is sent

    def __enter__(self) -> "RateLimitedSession":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    def close(self):
        """Close the underlying requests Session."""
        self.session.close()

    def _backoff_delay(self, attempt: int) -> float:
        """Exponential back-off for zero-based ``attempt``, capped at ``config.max_delay``."""
        return min(self.config.base_delay * (2 ** attempt), self.config.max_delay)

    @staticmethod
    def _parse_retry_after(value: Optional[str]) -> Optional[float]:
        """Convert a Retry-After header value to seconds from now.

        Accepts delta-seconds (``"120"``) or an HTTP-date
        (``"Wed, 21 Oct 2015 07:28:00 GMT"``). Returns None when the header
        is absent or cannot be parsed.
        """
        if not value:
            return None
        try:
            return float(value)
        except ValueError:
            pass

        from email.utils import mktime_tz, parsedate_tz
        try:
            parsed = parsedate_tz(value)
            if parsed is None:
                return None
            return mktime_tz(parsed) - time.time()
        except (TypeError, ValueError, IndexError, OverflowError):
            return None

    def _calculate_wait_time(self, response: requests.Response, attempt: int = 0) -> float:
        """Return the seconds to wait after a 429 ``response``, clamped to ``[0, config.max_delay]``.

        Sources, in priority order:
        1. the ``Retry-After`` header, if ``config.respect_retry_after`` is set;
        2. the ``X-RateLimit-Reset`` header, read as a Unix timestamp;
        3. exponential back-off based on the zero-based ``attempt``.
        """
        wait = None

        if self.config.respect_retry_after:
            wait = self._parse_retry_after(response.headers.get('Retry-After'))

        if wait is None:
            rate_reset = response.headers.get('X-RateLimit-Reset')
            if rate_reset:
                try:
                    remaining = float(rate_reset) - time.time()
                except ValueError:
                    remaining = 0.0  # malformed header: ignore it
                if remaining > 0:
                    wait = remaining

        if wait is None:
            wait = self.config.base_delay * (2 ** attempt)

        return min(max(0.0, wait), self.config.max_delay)

    def _wait_for_rate_limit(self):
        """Sleep until ``rate_limit_reset`` has passed, if it lies in the future."""
        now = time.time()

        if self.rate_limit_reset > now:
            wait_time = self.rate_limit_reset - now
            print(f"限流中,等待 {wait_time:.1f} 秒...")
            time.sleep(wait_time)

    def request(self, method: str, url: str, **kwargs) -> requests.Response:
        """Send ``method url``, retrying on 429, transient 5xx and network errors.

        Returns:
            The first response that is neither a 429 nor a retryable 5xx, or
            the 5xx response from the final attempt.

        Raises:
            requests.exceptions.RequestException: network error on the final attempt.
            Exception: still rate-limited (429) after ``max_retries`` attempts,
                or ``max_retries`` < 1.
        """
        last_attempt = self.config.max_retries - 1

        for attempt in range(self.config.max_retries):
            self._wait_for_rate_limit()

            try:
                self.request_count += 1
                self.last_request_time = time.time()
                response = self.session.request(method, url, **kwargs)
            except requests.exceptions.RequestException as e:
                if attempt >= last_attempt:
                    raise
                delay = self._backoff_delay(attempt)
                print(f"请求异常: {e},{delay:.1f}秒后重试...")
                time.sleep(delay)
                continue

            if response.status_code == 429:
                wait_time = self._calculate_wait_time(response, attempt)
                # The actual sleep happens in _wait_for_rate_limit on the next loop.
                self.rate_limit_reset = time.time() + wait_time

                if attempt < last_attempt:
                    print(f"触发限流 (429),等待 {wait_time:.1f} 秒...")
                    continue
                raise Exception("超过最大重试次数")

            if response.status_code in self._RETRYABLE_5XX and attempt < last_attempt:
                delay = self._backoff_delay(attempt)
                print(f"服务器错误 ({response.status_code}),{delay:.1f}秒后重试...")
                time.sleep(delay)
                continue

            return response

        raise Exception("请求失败")


# Usage example
if __name__ == '__main__':
    client = RateLimitedSession()

    # 429 responses are handled transparently inside request()
    try:
        reply = client.request('GET', 'https://api.example.com/data')
        print(f"成功: {reply.status_code}")
    except Exception as exc:
        print(f"请求失败: {exc}")

五、完整爬虫重试框架

结合以上所有技术,构建健壮的爬虫:

"""健壮爬虫框架"""

import time
import random
import logging
from typing import Optional, List, Dict, Any, Callable
from dataclasses import dataclass, field
from pathlib import Path
import requests
from urllib.parse import urljoin
import json

# Enable INFO-level log output so RobustSpider's progress messages (logger.info) are shown.
logging.basicConfig(level=logging.INFO)
# Logger used by RobustSpider for progress, retry and failure messages.
logger = logging.getLogger(__name__)


@dataclass
class SpiderConfig:
    """Settings for RobustSpider."""

    # --- retry ---
    max_retries: int = 5        # attempts per URL, first request included
    base_delay: float = 1.0     # seconds; base of the exponential back-off
    max_delay: float = 60.0     # seconds; cap on the computed back-off delay

    # --- request ---
    timeout: int = 30           # per-request timeout in seconds
    headers: Dict[str, str] = field(default_factory=lambda: {
        'User-Agent': 'Mozilla/5.0 (Windows NT 10.0; Win64; x64) AppleWebKit/537.36'
    })

    # --- proxies ---
    # Proxy URLs; one is chosen at random per request. An empty list means no proxy.
    proxies: List[str] = field(default_factory=list)

    # --- throttling ---
    # Minimum interval in SECONDS between two consecutive requests. This is NOT
    # requests per second: RobustSpider._rate_limit sleeps until this many
    # seconds have passed since the previous request.
    # 1.0 -> at most 1 request/s; 0.5 -> at most 2 requests/s.
    rate_limit: float = 1.0


class RobustSpider:
    """Sequential crawler with throttling, proxy rotation and retries on transient errors.

    Counters in ``self.stats``:
        success       URLs fetched with a 2xx response
        failed        URLs given up on (each counted once)
        retries       back-off retries after a timeout, connection error or 5xx
        rate_limited  429 responses received

    Can be used as a context manager; ``close()`` runs on exit.
    """

    def __init__(self, config: SpiderConfig = None):
        self.config = config or SpiderConfig()
        self.session = requests.Session()
        self.session.headers.update(self.config.headers)

        self.stats = {
            'success': 0,
            'failed': 0,
            'retries': 0,
            'rate_limited': 0
        }

        # time.time() of the previous request; 0 means "no request yet".
        self._last_request_time = 0

    def __enter__(self) -> "RobustSpider":
        return self

    def __exit__(self, exc_type, exc, tb) -> None:
        self.close()

    def _rate_limit(self):
        """Sleep so that at least ``config.rate_limit`` seconds separate consecutive requests."""
        elapsed = time.time() - self._last_request_time
        if elapsed < self.config.rate_limit:
            time.sleep(self.config.rate_limit - elapsed)
        self._last_request_time = time.time()

    def _get_proxy(self) -> Optional[Dict[str, str]]:
        """Return a requests ``proxies`` mapping for a random configured proxy, or None."""
        if not self.config.proxies:
            return None
        proxy = random.choice(self.config.proxies)
        return {
            'http': proxy,
            'https': proxy
        }

    def _calculate_delay(self, attempt: int) -> float:
        """Exponential back-off for zero-based ``attempt``, plus 0-20% jitter, capped at max_delay."""
        delay = self.config.base_delay * (2 ** attempt)
        # Random jitter so concurrent clients do not retry in lock-step.
        jitter = delay * 0.2 * random.random()
        return min(delay + jitter, self.config.max_delay)

    def _retry_after_seconds(self, response: requests.Response) -> float:
        """Seconds to wait after a 429, taken from ``Retry-After``.

        The header may be delta-seconds or an HTTP-date. A missing or
        unparsable header falls back to 60 s. The result is clamped to
        ``[0, config.max_delay]``.
        """
        value = response.headers.get('Retry-After', '60')
        try:
            wait = float(value)
        except ValueError:
            from email.utils import mktime_tz, parsedate_tz
            try:
                parsed = parsedate_tz(value)
                wait = mktime_tz(parsed) - time.time() if parsed else 60.0
            except (TypeError, ValueError, IndexError, OverflowError):
                wait = 60.0
        return min(max(0.0, wait), self.config.max_delay)

    def fetch(self, url: str, retry_count: int = None) -> Optional[requests.Response]:
        """
        Fetch ``url`` with GET, retrying transient failures.

        Retry policy:
        1. Timeouts and connection errors -> exponential back-off, then retry.
        2. 5xx responses -> exponential back-off, then retry.
        3. 429 responses -> wait per Retry-After (capped at max_delay), then retry.
        4. Any other non-2xx status, or an unexpected exception -> give up at once.

        Args:
            url: page to fetch.
            retry_count: total attempts, first request included; defaults to
                ``config.max_retries``.

        Returns:
            The first 2xx response, or None if the URL was given up on.
        """
        if retry_count is None:
            retry_count = self.config.max_retries

        last_error = None

        for attempt in range(retry_count):
            is_last_attempt = attempt >= retry_count - 1

            try:
                self._rate_limit()

                response = self.session.get(
                    url,
                    timeout=self.config.timeout,
                    proxies=self._get_proxy()
                )
            except requests.exceptions.Timeout:
                if is_last_attempt:
                    logger.error(f"请求超时(已重试{retry_count}次): {url}")
                    self.stats['failed'] += 1
                    last_error = "Timeout"
                else:
                    delay = self._calculate_delay(attempt)
                    logger.warning(f"请求超时,{delay:.1f}秒后重试...")
                    self.stats['retries'] += 1
                    time.sleep(delay)
                continue
            except requests.exceptions.ConnectionError as e:
                if is_last_attempt:
                    logger.error(f"连接失败: {e}")
                    self.stats['failed'] += 1
                    last_error = str(e)
                else:
                    delay = self._calculate_delay(attempt)
                    logger.warning(f"连接错误,{delay:.1f}秒后重试...")
                    self.stats['retries'] += 1
                    time.sleep(delay)
                continue
            except Exception as e:
                logger.error(f"未知错误: {e}")
                self.stats['failed'] += 1
                last_error = str(e)
                break

            status = response.status_code

            if 200 <= status < 300:
                self.stats['success'] += 1
                return response

            if status == 429:
                self.stats['rate_limited'] += 1
                if is_last_attempt:
                    # Out of attempts: waiting would be pointless.
                    self.stats['failed'] += 1
                    last_error = "HTTP 429"
                    break
                wait_time = self._retry_after_seconds(response)
                logger.warning(f"触发限流,等待 {wait_time} 秒...")
                time.sleep(wait_time)
                continue

            if 500 <= status < 600:
                if is_last_attempt:
                    self.stats['failed'] += 1
                    last_error = f"HTTP {status}"
                    break
                delay = self._calculate_delay(attempt)
                logger.warning(f"服务器错误 {status},{delay:.1f}秒后重试...")
                self.stats['retries'] += 1
                time.sleep(delay)
                continue

            # 1xx / 3xx / 4xx (other than 429): retrying will not help.
            logger.error(f"HTTP {status}: {url}")
            self.stats['failed'] += 1
            return None

        logger.error(f"获取失败({last_error}): {url}")
        return None

    def crawl(self, urls: List[str], callback: Callable = None) -> List[Dict]:
        """Fetch every URL in order and collect one result per successful fetch.

        Args:
            urls: pages to fetch, in order.
            callback: optional ``callback(response)``; its return value is
                collected. Without it, a dict with the url, the status_code
                and the first 500 characters of the body is collected.
                Callback errors are logged, and that URL is skipped.

        Returns:
            Results for the URLs that were fetched (and processed) successfully.
        """
        results = []

        for index, url in enumerate(urls):
            if index:
                # Random pause between pages to avoid being blocked; skipped before
                # the first page and therefore never wasted after the last one.
                time.sleep(random.uniform(1, 3))

            logger.info(f"正在爬取: {url}")

            response = self.fetch(url)
            if response is None:
                continue

            if callback:
                try:
                    results.append(callback(response))
                except Exception as e:
                    logger.error(f"处理响应失败: {e}")
            else:
                results.append({
                    'url': url,
                    'status_code': response.status_code,
                    'content': response.text[:500]
                })

        return results

    def get_stats(self) -> Dict[str, int]:
        """Return a copy of the counters."""
        return dict(self.stats)

    def close(self):
        """Close the session and log the final counters."""
        self.session.close()
        logger.info(f"爬取完成: {self.get_stats()}")


# Usage example
if __name__ == '__main__':
    spider = RobustSpider()

    def parse_page(response: requests.Response) -> dict:
        """Turn a fetched page into a small summary dict."""
        return dict(
            url=response.url,
            title=response.text[:100],
            status=response.status_code,
        )

    targets = [
        'https://httpbin.org/get',
        'https://httpbin.org/status/503',
        'https://httpbin.org/delay/1',
    ]

    results = spider.crawl(targets, callback=parse_page)

    print("\n结果:")
    for item in results:
        print(f"  {item}")

    print(f"\n统计: {spider.get_stats()}")
    spider.close()

六、重试最佳实践

  1. 区分可重试和不可重试的错误

    • 网络超时、连接重置、5xx 服务器错误 → 可重试
    • 认证失败(401)、权限不足(403) → 不可重试
  2. 使用指数退避

    • 避免请求风暴
    • 给服务器恢复时间
  3. 添加随机抖动

    • 防止多个客户端同时重试
  4. 记录重试日志

    • 便于问题排查
  5. 设置最大延迟

    • 避免无限等待
  6. 考虑熔断器模式

    • 失败太多时暂停请求

好了,今天的分享就到这里。重试机制虽小,但做好了对脚本的稳定性帮助很大。我是扣扣,有问题欢迎留言~🙃

更多推荐