Python自动化脚本断点续传下载实战:大文件处理完整指南
·
大家好,我是扣扣。今天来聊聊一个很实用的功能——断点续传下载。
为什么要关心断点续传?
你有没有遇到过这些情况:
- 下载一个大文件,下载到99%的时候网络断了,得从头再来
- 公司网络不稳定,几十MB的文件死活下载不下来
- 凌晨跑个定时任务下载数据,结果因为网络波动失败了
断点续传就是来解决这些问题的。原理很简单:记住下载到哪了,下次从断点继续。
一、基础实现
先从最简单的方式开始:
"""基础断点续传下载"""
import requests
from pathlib import Path
import os
def download_with_resume(url: str, dest: str, chunk_size: int = 8192) -> bool:
    """Download ``url`` to ``dest``, resuming a partial file when possible.

    How it works:
      1. Probe the server with HEAD for the total size and Range support.
      2. If a partial file exists and ranges are supported, request only the
         missing tail with a ``Range`` header.
      3. Append to the file only when the server really answers with
         ``206 Partial Content``; any other success status carries the whole
         file, so the destination is rewritten from scratch.

    Args:
        url: Address of the file to download.
        dest: Local path to write to; missing parent directories are created.
        chunk_size: Number of bytes read from the response per iteration.

    Returns:
        True if the file was already complete or the body was written in
        full; False if the transfer ended short of the advertised size.

    Raises:
        requests.RequestException: On connection errors, timeouts, or an
            HTTP error status of the download request.
    """
    dest_path = Path(dest)
    dest_path.parent.mkdir(parents=True, exist_ok=True)

    # Byte offsets are only meaningful against the unencoded representation,
    # so ask the server not to compress the body.
    base_headers = {'Accept-Encoding': 'identity'}

    # Probe the server; a failed probe just means "size unknown, no ranges".
    response_head = requests.head(
        url, headers=base_headers, allow_redirects=True, timeout=60
    )
    if response_head.ok:
        supports_ranges = response_head.headers.get('Accept-Ranges') == 'bytes'
        total_size = int(response_head.headers.get('Content-Length', 0))
    else:
        supports_ranges = False
        total_size = 0
    print(f"文件大小: {total_size / 1024 / 1024:.2f} MB")
    print(f"支持断点续传: {supports_ranges}")

    # Size of whatever is already on disk.
    existing_size = 0
    if dest_path.exists():
        existing_size = dest_path.stat().st_size
        print(f"已有文件大小: {existing_size / 1024 / 1024:.2f} MB")

    # Nothing to do when the local file already has the full length.
    if existing_size >= total_size and total_size > 0:
        print("文件已下载完成")
        return True

    headers = dict(base_headers)
    if supports_ranges and existing_size > 0:
        headers['Range'] = f'bytes={existing_size}-'
        print(f"从断点 {existing_size / 1024 / 1024:.2f} MB 处继续下载")

    response = requests.get(url, headers=headers, stream=True, timeout=60)

    # 416 means the requested offset is not valid: start over without a range.
    if response.status_code == 416:
        print("服务器不支持该范围,从头开始下载")
        response.close()
        response = requests.get(url, headers=base_headers, stream=True, timeout=60)

    with response:
        # Never write an error page into the destination file.
        response.raise_for_status()

        # Only a 206 reply contains just the missing tail. A 200 reply is the
        # whole file, and appending it to the partial file would corrupt it.
        resuming = 'Range' in headers and response.status_code == 206
        mode = 'ab' if resuming else 'wb'
        downloaded = existing_size if resuming else 0

        with open(dest_path, mode) as f:
            for chunk in response.iter_content(chunk_size=chunk_size):
                if not chunk:
                    continue
                f.write(chunk)
                downloaded += len(chunk)
                if total_size > 0:
                    progress = (downloaded / total_size) * 100
                    print(f"\r进度: {progress:.1f}% ({downloaded / 1024 / 1024:.2f} MB)", end='')

    print()  # finish the progress line
    # With an unknown size there is nothing to compare against.
    return total_size <= 0 or downloaded >= total_size
if __name__ == '__main__':
    # Demo run: download a sample archive into ./downloads.
    url = "https://example.com/largefile.zip"
    download_with_resume(url=url, dest="./downloads/largefile.zip")
二、进阶实现:多线程断点续传
单线程下载大文件还是太慢?我们来加多线程:
"""多线程断点续传下载器"""
import requests
import threading
import queue
import os
import time
from pathlib import Path
from typing import List, Tuple
from dataclasses import dataclass
@dataclass
class DownloadTask:
    """One byte range of the remote file, assigned to a single worker."""

    thread_id: int  # index of the worker that handles this range
    start: int      # first byte offset, as sent in the Range header
    end: int        # last byte offset, as sent in the Range header
    url: str        # address the range is requested from
    dest: str       # path of the part file this range is written to
@dataclass
class DownloadResult:
    """Outcome reported by one worker after it handled its byte range."""

    thread_id: int         # index of the worker that produced this result
    success: bool          # True when the range was downloaded
    bytes_downloaded: int  # bytes counted by the worker (0 on failure)
    error: str = None      # failure description; None when there was no error
class MultiThreadDownloader:
    """Download one file with several threads, each fetching a byte range.

    Every worker writes its range to a part file inside a ``.temp`` directory
    next to the destination. Once all workers succeed, the parts are
    concatenated into the destination. If the destination already holds a
    partial download and the server supports ranges, only the missing tail is
    fetched and appended to it.
    """

    def __init__(self, url: str, dest: str, num_threads: int = 4):
        """Set up the downloader and create the temporary directory.

        Args:
            url: Address of the file to download.
            dest: Final path of the downloaded file.
            num_threads: Maximum number of parallel range requests; values
                below 1 are treated as 1.
        """
        self.url = url
        self.dest = Path(dest)
        self.num_threads = max(1, num_threads)
        self.chunk_size = 8192
        self.total_size = 0
        self.downloaded_size = 0
        self.lock = threading.Lock()
        self.tasks: queue.Queue = queue.Queue()
        self.results: List["DownloadResult"] = []
        # Part files live next to the destination; create missing parents too.
        self.temp_dir = self.dest.parent / '.temp'
        self.temp_dir.mkdir(parents=True, exist_ok=True)

    def get_file_info(self) -> dict:
        """Query the server with HEAD and return size, range support,
        content type and a suggested filename."""
        response = requests.head(
            self.url,
            # Byte ranges refer to the unencoded body, so disable compression.
            headers={'Accept-Encoding': 'identity'},
            allow_redirects=True,
            timeout=30,
        )
        return {
            'total_size': int(response.headers.get('Content-Length', 0)),
            'supports_ranges': response.headers.get('Accept-Ranges') == 'bytes',
            'content_type': response.headers.get('Content-Type', ''),
            'filename': self._get_filename(response)
        }

    def _get_filename(self, response) -> str:
        """Return the filename from Content-Disposition, else from the URL."""
        content_disposition = response.headers.get('Content-Disposition')
        if content_disposition:
            import re
            match = re.search(r'filename="?([^";\n]+)"?', content_disposition)
            if match:
                return match.group(1)
        # Fall back to the last path segment without the query string.
        return self.url.split('/')[-1].split('?')[0]

    def _get_existing_size(self) -> int:
        """Return the size of the destination file, or 0 if it is missing."""
        if self.dest.exists():
            return self.dest.stat().st_size
        return 0

    @staticmethod
    def _split_ranges(start: int, end: int, parts: int) -> List[Tuple[int, int]]:
        """Split the bytes ``start .. end - 1`` into contiguous ranges.

        Args:
            start: First byte offset to cover.
            end: Offset one past the last byte to cover.
            parts: Desired number of ranges; capped at the number of bytes.

        Returns:
            ``(first, last)`` tuples with *inclusive* bounds, which is what
            the HTTP ``Range`` header expects. The ranges do not overlap and
            the last one absorbs the remainder. Empty when there is nothing
            to cover.
        """
        size = end - start
        if size <= 0:
            return []
        parts = max(1, min(parts, size))
        step = size // parts
        ranges = []
        for i in range(parts):
            first = start + i * step
            last = end - 1 if i == parts - 1 else first + step - 1
            ranges.append((first, last))
        return ranges

    def _create_tasks(self, start: int, end: int, num_parts: int = 0) -> List["DownloadTask"]:
        """Build one task per byte range covering ``start .. end - 1``.

        Args:
            start: First byte offset still to download.
            end: Offset one past the last byte (i.e. the total file size).
            num_parts: Number of ranges to create; 0 means ``num_threads``.
        """
        parts = num_parts if num_parts > 0 else self.num_threads
        return [
            DownloadTask(
                thread_id=i,
                start=first,
                end=last,
                url=self.url,
                dest=str(self.temp_dir / f'.part_{i}'),
            )
            for i, (first, last) in enumerate(self._split_ranges(start, end, parts))
        ]

    def _download_chunk(self, task: "DownloadTask") -> "DownloadResult":
        """Download one range into its part file and report the outcome.

        Failures are returned as an unsuccessful result instead of being
        raised, so a worker thread never dies with an exception.
        """
        headers = {
            'Range': f'bytes={task.start}-{task.end}',
            'Accept-Encoding': 'identity',
        }
        expected = task.end - task.start + 1
        # A 200 reply carries the whole file; that is only acceptable when the
        # task itself covers the whole file.
        covers_whole_file = task.start == 0 and expected >= self.total_size

        def failed(message: str):
            return DownloadResult(
                thread_id=task.thread_id,
                success=False,
                bytes_downloaded=0,
                error=message
            )

        try:
            with requests.get(task.url, headers=headers, stream=True, timeout=60) as response:
                if response.status_code not in (200, 206):
                    return failed(f"HTTP {response.status_code}")
                if response.status_code == 200 and not covers_whole_file:
                    return failed("服务器忽略了Range请求")

                downloaded = 0
                with open(task.dest, 'wb') as f:
                    for chunk in response.iter_content(chunk_size=self.chunk_size):
                        if chunk:
                            f.write(chunk)
                            downloaded += len(chunk)
                            # Shared counter is updated from several threads.
                            with self.lock:
                                self.downloaded_size += len(chunk)

            # A short or long part would silently corrupt the merged file.
            if downloaded != expected:
                return failed(f"分块不完整: 期望 {expected} 字节, 实际 {downloaded} 字节")

            return DownloadResult(
                thread_id=task.thread_id,
                success=True,
                bytes_downloaded=downloaded
            )
        except Exception as e:
            return failed(str(e))

    def _merge_files(self, append: bool = False, part_count: int = 0):
        """Concatenate the part files into the destination and delete them.

        Args:
            append: Append to the existing destination instead of replacing
                it. Used when the parts only hold the missing tail of a
                partially downloaded file.
            part_count: Number of parts to merge; 0 means ``num_threads``.
        """
        import shutil

        count = part_count if part_count > 0 else self.num_threads
        with open(self.dest, 'ab' if append else 'wb') as dest_file:
            for i in range(count):
                temp_file = self.temp_dir / f'.part_{i}'
                if temp_file.exists():
                    with open(temp_file, 'rb') as f:
                        # Stream the copy so a large part is never fully in memory.
                        shutil.copyfileobj(f, dest_file)
                    temp_file.unlink()
        # Remove the temporary directory once it is empty.
        if not any(self.temp_dir.iterdir()):
            self.temp_dir.rmdir()

    def download(self) -> bool:
        """Run the download.

        Returns:
            True when the file is complete on disk, False when the size could
            not be determined or at least one worker failed.
        """
        info = self.get_file_info()
        self.total_size = info['total_size']
        if self.total_size == 0:
            print("无法获取文件大小")
            return False
        print(f"文件大小: {self.total_size / 1024 / 1024:.2f} MB")

        existing_size = self._get_existing_size()
        if existing_size >= self.total_size:
            print("文件已下载完成")
            return True

        # A partial file is useless if the server cannot serve the remainder.
        if existing_size > 0 and not info['supports_ranges']:
            print("服务器不支持断点续传,删除旧文件重新下载")
            self.dest.unlink()
            existing_size = 0

        # Without range support every request returns the whole file, so a
        # single worker is the only correct choice.
        num_parts = self.num_threads if info['supports_ranges'] else 1
        tasks = self._create_tasks(existing_size, self.total_size, num_parts)
        print(f"使用 {len(tasks)} 个线程下载")

        # Reset per-run state so the instance can be reused; the temporary
        # directory may have been removed by an earlier merge.
        self.temp_dir.mkdir(parents=True, exist_ok=True)
        self.results = []
        self.downloaded_size = existing_size

        def run(task):
            self.results.append(self._download_chunk(task))

        threads = [threading.Thread(target=run, args=(task,)) for task in tasks]
        for t in threads:
            t.start()

        # Report progress while the workers are running.
        start_time = time.time()
        while any(t.is_alive() for t in threads):
            time.sleep(0.5)
            elapsed = time.time() - start_time
            fetched = self.downloaded_size - existing_size
            speed = fetched / elapsed / 1024 if elapsed > 0 else 0
            progress = (self.downloaded_size / self.total_size) * 100 if self.total_size > 0 else 0
            print(f"\r进度: {progress:.1f}% | 速度: {speed:.1f} KB/s", end='', flush=True)
        for t in threads:
            t.join()
        print()

        if len(self.results) == len(tasks) and all(r.success for r in self.results):
            # When resuming, the parts hold only the tail of the file.
            self._merge_files(append=existing_size > 0, part_count=len(tasks))
            print(f"下载完成: {self.dest}")
            return True

        failed = [r for r in self.results if not r.success]
        print(f"下载失败: {len(failed)} 个线程出错")
        for r in failed:
            print(f"  线程 {r.thread_id}: {r.error}")
        return False
if __name__ == '__main__':
    # Demo run: download a sample archive with four worker threads.
    downloader = MultiThreadDownloader(
        "https://example.com/largefile.zip",
        "./downloads/largefile.zip",
        num_threads=4,
    )
    downloader.download()
三、带重试的稳健下载器
网络不稳定的环境需要更强的容错能力:
"""带重试和断点续传的稳健下载器"""
import requests
import time
import random
from pathlib import Path
from typing import Optional, Callable
from dataclasses import dataclass, field
import json
@dataclass
class DownloadState:
    """Progress record of one download, persisted as JSON between runs."""

    url: str
    dest: str
    total_size: int
    downloaded_size: int
    attempts: int = 0
    max_speed: float = 0
    start_time: float = field(default_factory=time.time)

    def save(self, state_file: str):
        """Write the state to ``state_file`` as JSON.

        The data is written to a sibling ``.tmp`` file first and then moved
        over the target, so an interrupted write never leaves a truncated
        state file behind.
        """
        target = Path(state_file)
        tmp_path = target.with_name(target.name + '.tmp')
        with open(tmp_path, 'w', encoding='utf-8') as f:
            json.dump({
                'url': self.url,
                'dest': self.dest,
                'total_size': self.total_size,
                'downloaded_size': self.downloaded_size,
                'attempts': self.attempts,
                'max_speed': self.max_speed,
                'start_time': self.start_time
            }, f)
        tmp_path.replace(target)

    @classmethod
    def load(cls, state_file: str) -> Optional['DownloadState']:
        """Read a state previously written by :meth:`save`.

        Returns:
            The restored state, or None when the file is missing, is not
            valid JSON, or does not match the fields of this class. An
            unusable state file is treated the same as no state file.
        """
        path = Path(state_file)
        if not path.exists():
            return None
        try:
            with open(path, 'r', encoding='utf-8') as f:
                data = json.load(f)
            return cls(**data)
        except (ValueError, TypeError):
            # ValueError: malformed JSON; TypeError: wrong shape or keys.
            return None
class RobustDownloader:
    """Resumable downloader with automatic retries and exponential backoff.

    The number of bytes already in the destination file is the single source
    of truth for the resume offset. It is re-read before every attempt, so
    bytes written by a failed attempt are kept and never requested twice.
    A state file next to the destination records the URL and total size so a
    later run can continue the same download.
    """

    def __init__(self,
                 url: str,
                 dest: str,
                 chunk_size: int = 8192,
                 max_retries: int = 5,
                 timeout: int = 60):
        """
        Args:
            url: Address of the file to download.
            dest: Local path to write to.
            chunk_size: Bytes read from the response per iteration.
            max_retries: Maximum number of download attempts.
            timeout: Timeout in seconds for the download request.
        """
        self.url = url
        self.dest = Path(dest)
        self.chunk_size = chunk_size
        self.max_retries = max_retries
        self.timeout = timeout
        self.state_file = str(self.dest.parent / f'.{self.dest.name}.state')
        self.state: Optional["DownloadState"] = None
        # Optional progress hook, see set_progress_callback().
        self.progress_callback: Optional[Callable] = None

    def set_progress_callback(self, callback: Callable[[int, int, float], None]):
        """Register ``callback(downloaded, total, speed)``.

        ``downloaded`` and ``total`` are byte counts for the whole file
        (``total`` is 0 when unknown); ``speed`` is bytes per second of the
        current attempt.
        """
        self.progress_callback = callback

    def _calculate_delay(self, attempt: int) -> float:
        """Return the retry delay: exponential backoff plus random jitter."""
        base_delay = min(2 ** attempt, 60)  # capped at 60 seconds
        jitter = random.uniform(0, base_delay * 0.3)  # up to 30% extra
        return base_delay + jitter

    def _resume_position(self) -> int:
        """Return the number of bytes currently in the destination file."""
        try:
            return self.dest.stat().st_size
        except FileNotFoundError:
            return 0

    def _discard_partial(self):
        """Delete the destination file if it exists."""
        try:
            self.dest.unlink()
        except FileNotFoundError:
            pass

    def _download(self,
                  headers: Optional[dict] = None,
                  start_pos: int = 0) -> tuple:
        """Run one download attempt.

        Args:
            headers: Request headers, including ``Range`` when resuming.
            start_pos: Offset the ``Range`` header asks for; 0 for a fresh
                download.

        Returns:
            ``(success, bytes_downloaded, error)`` where ``bytes_downloaded``
            is the number of bytes written during this attempt (also on
            failure) and ``error`` is None on success.
        """
        downloaded = 0
        try:
            with requests.get(
                self.url,
                headers=headers,
                stream=True,
                timeout=self.timeout
            ) as response:
                if response.status_code == 416:
                    # The saved offset is not valid for the remote file, so
                    # the partial data cannot be trusted: drop it and let the
                    # next attempt start from zero.
                    self._discard_partial()
                    return False, 0, "范围请求超出文件大小"
                response.raise_for_status()

                # Only a 206 reply holds just the missing tail. A 200 reply to
                # a ranged request is the whole file and must replace the
                # partial data instead of being appended to it.
                resumed = start_pos > 0 and response.status_code == 206
                base = start_pos if resumed else 0
                total = self.state.total_size if self.state else 0
                began = time.time()

                with open(self.dest, 'ab' if resumed else 'wb') as f:
                    for chunk in response.iter_content(chunk_size=self.chunk_size):
                        if not chunk:
                            continue
                        f.write(chunk)
                        downloaded += len(chunk)
                        if self.progress_callback:
                            elapsed = time.time() - began
                            speed = downloaded / elapsed if elapsed > 0 else 0.0
                            self.progress_callback(base + downloaded, total, speed)
            return True, downloaded, None
        except requests.exceptions.Timeout:
            return False, downloaded, "请求超时"
        except requests.exceptions.ConnectionError:
            return False, downloaded, "连接错误"
        except requests.exceptions.HTTPError as e:
            return False, downloaded, f"HTTP错误: {e}"
        except Exception as e:
            return False, downloaded, f"未知错误: {e}"

    def download(self) -> bool:
        """Download the file, retrying with backoff until it is complete.

        A matching state file makes the run continue from the bytes already
        on disk. Without one, any existing destination file is discarded and
        the download starts from zero.

        Returns:
            True when the file is complete, False after ``max_retries``
            failed attempts (the state file is kept for a later run).
        """
        self.dest.parent.mkdir(parents=True, exist_ok=True)
        # Ranges address the unencoded body, so compression is disabled.
        base_headers = {'Accept-Encoding': 'identity'}

        self.state = DownloadState.load(self.state_file)
        if self.state and self.state.url == self.url:
            # Trust what is actually on disk over the recorded counter.
            self.state.downloaded_size = self._resume_position()
            print(f"恢复下载,已下载: {self.state.downloaded_size / 1024 / 1024:.2f} MB")
            self.state.attempts += 1
        else:
            # New download: a leftover file of unknown origin is not a valid
            # prefix to resume from.
            self._discard_partial()
            try:
                response = requests.head(
                    self.url, headers=base_headers, allow_redirects=True, timeout=30
                )
                total_size = int(response.headers.get('Content-Length', 0))
            except (requests.exceptions.RequestException, ValueError):
                # Size unknown; the retry loop below still gets its chances.
                total_size = 0
            self.state = DownloadState(
                url=self.url,
                dest=str(self.dest),
                total_size=total_size,
                downloaded_size=0
            )
        if self.state.total_size > 0:
            print(f"文件大小: {self.state.total_size / 1024 / 1024:.2f} MB")

        total = self.state.total_size
        attempt = 0
        while attempt < self.max_retries:
            start_pos = self._resume_position()
            if total > 0 and start_pos > total:
                # More bytes than the remote file has: cannot be a valid prefix.
                self._discard_partial()
                start_pos = 0
            if total > 0 and start_pos == total:
                # Everything is already on disk; a ranged request would get 416.
                self.state.downloaded_size = start_pos
                print(f"下载完成: {self.state.downloaded_size / 1024 / 1024:.2f} MB")
                self._cleanup()
                return True

            print(f"尝试下载 (第 {attempt + 1} 次)...")
            headers = dict(base_headers)
            if start_pos > 0:
                headers['Range'] = f'bytes={start_pos}-'
            success, downloaded, error = self._download(headers, start_pos)

            # Whatever happened, the file on disk defines the new offset.
            self.state.downloaded_size = self._resume_position()
            if success and total > 0 and self.state.downloaded_size < total:
                # The connection ended early without raising an error.
                success = False
                error = "下载不完整"
            if success:
                print(f"下载完成: {self.state.downloaded_size / 1024 / 1024:.2f} MB")
                self._cleanup()
                return True

            attempt += 1
            self.state.attempts += 1
            # Persist progress so an interrupted process can pick up later.
            self.state.save(self.state_file)
            if attempt < self.max_retries:
                delay = self._calculate_delay(attempt)
                print(f"失败: {error},{delay:.1f}秒后重试...")
                time.sleep(delay)

        print(f"达到最大重试次数 ({self.max_retries})")
        self.state.save(self.state_file)
        return False

    def _cleanup(self):
        """Remove the state file once the download has finished."""
        if Path(self.state_file).exists():
            Path(self.state_file).unlink()
# Usage example
if __name__ == '__main__':
    def progress(downloaded, total, speed):
        """Print a one-line progress report; silent when the total is unknown."""
        if not total > 0:
            return
        print(f"\r进度: {downloaded/total*100:.1f}% | 速度: {speed/1024:.1f} KB/s", end='')

    downloader = RobustDownloader(
        "https://example.com/largefile.zip",
        "./downloads/largefile.zip",
        max_retries=5,
    )
    downloader.set_progress_callback(progress)
    success = downloader.download()
    print(f"\n下载{'成功' if success else '失败'}")
四、异步下载版本
用asyncio实现高性能异步下载:
"""异步断点续传下载器"""
import asyncio
import aiohttp
import aiofiles
from pathlib import Path
from dataclasses import dataclass
from typing import List, Optional
@dataclass
class AsyncDownloader:
    """Download one file over several concurrent ranged requests.

    Use it as an async context manager: entering opens the HTTP session and
    reads the file size, leaving closes the session. Each range is written to
    a ``.part_<n>`` file next to the destination, and the parts are joined
    once all of them have been fetched.
    """

    url: str
    dest: str
    num_connections: int = 4
    chunk_size: int = 1024 * 1024  # 1MB
    _session: Optional["aiohttp.ClientSession"] = None
    _total_size: int = 0
    _downloaded: int = 0
    _lock: Optional[asyncio.Lock] = None

    async def __aenter__(self):
        """Open the session and read the total size from a HEAD request."""
        self._session = aiohttp.ClientSession()
        self._lock = asyncio.Lock()
        try:
            # aiohttp does not follow redirects for HEAD unless asked to;
            # without this the size would be read from the redirect reply.
            async with self._session.head(
                self.url,
                headers={'Accept-Encoding': 'identity'},
                allow_redirects=True,
            ) as resp:
                self._total_size = int(resp.headers.get('Content-Length', 0))
        except BaseException:
            # __aexit__ is not called when __aenter__ fails, so close here.
            await self._session.close()
            self._session = None
            raise
        return self

    async def __aexit__(self, *args):
        """Close the HTTP session."""
        if self._session:
            await self._session.close()

    @staticmethod
    def _split_ranges(total: int, parts: int) -> List[tuple]:
        """Split ``total`` bytes into contiguous ``(first, last)`` ranges.

        Both bounds are inclusive, as required by the HTTP ``Range`` header.
        The number of ranges is capped at ``total``, the last range absorbs
        the remainder, and the result is empty when ``total`` is not positive.
        """
        if total <= 0:
            return []
        parts = max(1, min(parts, total))
        step = total // parts
        ranges = []
        for i in range(parts):
            first = i * step
            last = total - 1 if i == parts - 1 else first + step - 1
            ranges.append((first, last))
        return ranges

    def _part_path(self, part_num: int) -> Path:
        """Return the path of the temporary file for part ``part_num``."""
        return Path(self.dest).parent / f'.part_{part_num}'

    async def _download_range(self, start: int, end: int, part_num: int):
        """Download bytes ``start`` to ``end`` (inclusive) into a part file.

        Raises:
            aiohttp.ClientResponseError: On an HTTP error status.
            RuntimeError: If the server ignores the range and returns the
                whole file for a part that does not cover the whole file.
        """
        headers = {
            'Range': f'bytes={start}-{end}',
            # Ranges address the unencoded body, so disable compression.
            'Accept-Encoding': 'identity',
        }
        temp_file = self._part_path(part_num)
        covers_whole_file = start == 0 and end >= self._total_size - 1
        async with self._session.get(self.url, headers=headers) as resp:
            resp.raise_for_status()
            if resp.status != 206 and not covers_whole_file:
                # Writing a full body into every part would corrupt the merge.
                raise RuntimeError(f"服务器未按Range返回分块 (HTTP {resp.status})")
            async with aiofiles.open(temp_file, 'wb') as f:
                async for chunk in resp.content.iter_chunked(self.chunk_size):
                    await f.write(chunk)
                    async with self._lock:
                        self._downloaded += len(chunk)

    async def download(self):
        """Fetch all ranges concurrently and merge them into ``dest``.

        Returns:
            True when the file has been written, False when the file size is
            unknown (no ranges can be computed in that case).
        """
        import shutil

        ranges = self._split_ranges(self._total_size, self.num_connections)
        if not ranges:
            print("无法获取文件大小")
            return False

        Path(self.dest).parent.mkdir(parents=True, exist_ok=True)
        self._downloaded = 0
        await asyncio.gather(*(
            self._download_range(first, last, i)
            for i, (first, last) in enumerate(ranges)
        ))

        # Join the parts in order, streaming so no part is held in memory.
        with open(self.dest, 'wb') as out_file:
            for i in range(len(ranges)):
                temp_file = self._part_path(i)
                with open(temp_file, 'rb') as f:
                    shutil.copyfileobj(f, out_file)
                temp_file.unlink()
        return True
# Usage
async def main():
    """Open the downloader, print the remote file size, then download."""
    downloader_ctx = AsyncDownloader(
        url="https://example.com/file.zip",
        dest="./downloads/file.zip",
        num_connections=4,
    )
    async with downloader_ctx as downloader:
        print(f"文件大小: {downloader._total_size / 1024 / 1024:.2f} MB")
        await downloader.download()
        print("下载完成")


if __name__ == '__main__':
    asyncio.run(main())
总结
- 基础版:用 `Range` 请求头实现简单断点续传
- 多线程版:分段并行下载,提高速度
- 稳健版:加重试、状态保存、指数退避
- 异步版:用 `aiohttp` 实现高性能并发
断点续传是个看起来简单但细节很多的功能。希望今天的分享对你有帮助。我是扣扣,有问题留言~🙃
更多推荐
所有评论(0)