在 Python 中使用 Spark,主要通过 PySpark 库来实现。它是 Apache Spark 的 Python API,语法风格与 Pandas 有相似之处,但核心思想是分布式计算惰性求值

以下是 PySpark 核心语法的系统性介绍(基于目前主流的 Spark 3.x / 4.x 版本):

1. 入口:SparkSession

所有 PySpark 程序的起点,替代了旧版的 SparkContext

from pyspark.sql import SparkSession

spark = (SparkSession.builder
    .appName("MyApp")
    .master("local[*]")          # 本地测试用,生产环境由集群管理器指定
    .config("spark.some.config", "value")
    .getOrCreate())

2. DataFrame 创建

DataFrame 是 PySpark 最核心的数据结构,等价于分布式表。

方式 代码示例
从文件读取 spark.read.parquet("/path") / .csv() / .json() / .table("db.tbl")
从列表创建 spark.createDataFrame([(1,"a"),(2,"b")], ["id","name"])
从 Pandas 转换 spark.createDataFrame(pandas_df)
SQL 查询创建 spark.sql("SELECT * FROM table WHERE dt='2026-06-30'")

3. 核心变换语法 (Transformations)

⚠️ 关键概念:惰性求值。以下操作不会立即执行,只是记录逻辑计划,直到遇到 Action 操作才真正触发计算。

3.1 列操作 (pyspark.sql.functions)

这是 PySpark 最常用的模块,通常简写为 F

from pyspark.sql import functions as F

df.select(
    F.col("name"),                          # 引用列
    F.lit(2026).alias("year"),              # 常量列
    F.concat_ws("-", F.col("y"), F.col("m")), # 字符串拼接
    F.when(F.col("age") > 18, "adult").otherwise("minor"), # 条件判断
    F.date_add(F.current_date(), -7),       # 日期函数
    F.explode(F.col("tags"))                # 数组展开
)
3.2 常用 DataFrame 方法
# 过滤
df.filter(F.col("amount") > 100)
df.where("region = 'cn' AND dt = '2026-06-30'")  # 也支持SQL表达式字符串

# 聚合
df.groupBy("region").agg(
    F.sum("amount").alias("total"),
    F.countDistinct("user_id").alias("uv")
)

# 关联
df1.join(df2, on="user_id", how="left")
df1.join(df2, on=(df1.id == df2.uid) & (df1.dt == df2.dt), how="inner")

# 窗口函数
from pyspark.sql.window import Window
w = Window.partitionBy("dept").orderBy(F.desc("salary"))
df.withColumn("rank", F.row_number().over(w))

# 其他
df.dropDuplicates(["user_id"])      # 去重
df.na.fill(0, subset=["amount"])    # 空值填充
df.repartition(10, "dt")            # 重分区(影响Shuffle)
df.cache() / df.persist()           # 缓存到内存/磁盘

4. 动作语法 (Actions)

触发实际计算并返回结果到 Driver 端:

df.show(20, truncate=False)   # 打印前N行
df.count()                    # 计数
df.collect()                  # ⚠️ 拉取全部数据到Driver,大表慎用!
df.toPandas()                 # ⚠️ 转为Pandas DataFrame,同样慎用
df.write.parquet("/output")   # 写出文件
df.createOrReplaceTempView("tmp")  # 注册临时视图供SQL使用

5. Spark SQL 语法

如果不习惯 DataFrame API,可以直接写 SQL:

df.createOrReplaceTempView("sales")

result = spark.sql("""
    SELECT region, SUM(amount) as total
    FROM sales
    WHERE dt = '2026-06-30'
    GROUP BY region
    HAVING total > 10000
    ORDER BY total DESC
""")

6. UDF(用户自定义函数)

当内置函数无法满足需求时使用,但性能较差(Python 与 JVM 间有序列化开销),应优先使用内置函数或 Pandas UDF。

# 普通 UDF(逐行处理,慢)
@F.udf("string")
def mask_phone(phone):
    return phone[:3] + "****" + phone[7:] if phone else None

# Pandas UDF / Vectorized UDF(向量化,快10-100倍)✅ 推荐
@F.pandas_udf("double")
def normalize(s: pd.Series) -> pd.Series:
    return (s - s.mean()) / s.std()

7. PySpark vs Pandas 语法速查对照

操作 Pandas PySpark
选列 df[["a","b"]] df.select("a","b")
过滤 df[df.age>18] df.filter(F.col("age")>18)
新增列 df["new"] = ... df.withColumn("new", ...)
重命名 df.rename(columns={}) df.withColumnRenamed("old","new")
分组聚合 df.groupby().agg() df.groupBy().agg()
排序 df.sort_values() df.orderBy() / df.sort()
采样 df.sample(frac=0.1) df.sample(fraction=0.1)

💡 最佳实践与避坑指南

  1. 避免 collect() / toPandas():除非确认数据量很小(< 几GB),否则会导致 Driver OOM。调试时用 show()limit(100).toPandas()
  2. 优先用内置函数pyspark.sql.functions 里的函数在 JVM 端执行,比 Python UDF 快几个数量级。
  3. 注意 ShufflejoingroupByrepartitiondistinct 都会触发 Shuffle,是性能瓶颈。尽量用广播 Join(F.broadcast(small_df))减少 Shuffle。
  4. 合理设置分区数:分区太少导致并行度不足,太多导致任务调度开销大。一般每个分区 128MB~256MB 为宜。
  5. 类型安全:PySpark 是强类型的,字符串 "123" 和整数 123 不能直接比较/关联,需用 F.col().cast() 显式转换。
  6. Spark Connect (Spark 4.x 新特性):如果你使用的是较新版本,推荐使用 Spark Connect 客户端模式,将 Driver 与 Cluster 解耦,支持更好的 IDE 补全和远程开发体验。