SQLAlchemy,array_agg,并匹配一个输入列表
问题:SQLAlchemy,array_agg,并匹配一个输入列表
我正在尝试更充分地使用 SQLAlchemy,而不是在遇到困境的第一个迹象时退回到纯 SQL。在这种情况下,我在 Postgres 数据库 (9.5) 中有一个表,它通过将单个项目atom_id与组标识符group_id相关联来将一组整数存储为一个组。
给定一个atom_ids的列表,我希望能够弄清楚那组atom_ids属于哪个group_id,如果有的话。只用group_id和atom_id列解决这个问题很简单。
现在我试图概括这样一个“组”不仅由atom_ids的列表组成,而且还由其他上下文组成。在下面的示例中,列表通过包含sequence列进行排序,但从概念上讲,可以使用其他列,例如weight列,它为每个atom_id提供一个 [0,1] 浮点值,表示该原子的“份额”该组的。
以下是演示我的问题的大部分单元测试。
首先,一些设置:
def test_multi_column_grouping(self):
class MultiColumnGroups(base.Base):
__tablename__ = 'multi_groups'
group_id = Column(Integer)
atom_id = Column(Integer)
sequence = Column(Integer) # arbitrary 'other' column. In this case, an integer, but it could be a float (e.g. weighting factor)
base.Base.metadata.create_all(self.engine)
# Insert 6 rows representing 2 different 'groups' of values
vals = [
# Group 1
{'group_id': 1, 'atom_id': 1, 'sequence': 1},
{'group_id': 1, 'atom_id': 2, 'sequence': 2},
{'group_id': 1, 'atom_id': 3, 'sequence': 3},
# Group 2
{'group_id': 2, 'atom_id': 1, 'sequence': 3},
{'group_id': 2, 'atom_id': 2, 'sequence': 2},
{'group_id': 2, 'atom_id': 3, 'sequence': 1},
]
self.session.bulk_save_objects(
[MultiColumnGroups(**x) for x in vals])
self.session.flush()
self.assertEqual(6, len(self.session.query(MultiColumnGroups).all()))
现在,我想查询上表以查找一组特定的输入属于哪个组。我正在使用(命名的)元组列表来表示查询参数。
from collections import namedtuple
Entity = namedtuple('Entity', ['atom_id', 'sequence'])
values_to_match = [
# (atom_id, sequence)
Entity(1, 3),
Entity(2, 2),
Entity(3, 1),
]
# The above list _should_ match with `group_id == 2`
原始 SQL 解决方案。我_宁愿_不要依赖于此,因为本练习的一部分是学习更多 SQLAlchemy。
r = self.session.execute('''
select group_id
from multi_groups
group by group_id
having array_agg((atom_id, sequence)) = :query_tuples
''', {'query_tuples': values_to_match}).fetchone()
print(r) # > (2,)
self.assertEqual(2, r[0])
这是上面的原始 SQL 解决方案,直接转换为损坏的 SQLAlchemy 查询。运行此程序会产生 psycopg2 错误:(psycopg2.ProgrammingError) operator does not exist: record[] = integer[]。我相信我需要将array_agg转换为int[]?只要分组列都是整数(如果需要,这是一个可接受的限制),这将起作用,但理想情况下,这将适用于混合类型的输入元组/表列。
from sqlalchemy import tuple_
from sqlalchemy.dialects.postgresql import array_agg
existing_group = self.session.query(MultiColumnGroups).\
with_entities(MultiColumnGroups.group_id).\
group_by(MultiColumnGroups.group_id).\
having(array_agg(tuple_(MultiColumnGroups.atom_id, MultiColumnGroups.sequence)) == values_to_match).\
one_or_none()
self.assertIsNotNone(existing_group)
print('|{}|'.format(existing_group))
上面的session.query()关闭了吗?我是否在这里蒙蔽了自己,并且错过了可以以其他方式解决此问题的超级明显的东西?
解答
我认为您的解决方案会产生不确定的结果,因为组中的行未指定顺序,因此数组聚合和给定数组之间的比较可能会基于此产生真或假:
[local]:5432 u@sopython*=> select group_id
[local] u@sopython- > from multi_groups
[local] u@sopython- > group by group_id
[local] u@sopython- > having array_agg((atom_id, sequence)) = ARRAY[(1,3),(2,2),(3,1)];
group_id
----------
2
(1 row)
[local]:5432 u@sopython*=> update multi_groups set atom_id = atom_id where atom_id = 2;
UPDATE 2
[local]:5432 u@sopython*=> select group_id
from multi_groups
group by group_id
having array_agg((atom_id, sequence)) = ARRAY[(1,3),(2,2),(3,1)];
group_id
----------
(0 rows)
您可以对两者应用排序,或者尝试完全不同的方法:您可以使用[关系除法](https://en.wikipedia.org/wiki/Relational_algebra#Division_(%C3%B7)代替数组比较。
为了划分你必须从你的Entity记录列表中形成一个临时关系。同样,有很多方法可以解决这个问题。这是使用嵌套数组的一个:
In [112]: vtm = select([
...: func.unnest(postgresql.array([
...: getattr(e, f) for e in values_to_match
...: ])).label(f)
...: for f in Entity._fields
...: ]).alias()
另一个使用联合:
In [114]: vtm = union_all(*[
...: select([literal(e.atom_id).label('atom_id'),
...: literal(e.sequence).label('sequence')])
...: for e in values_to_match
...: ]).alias()
临时表也可以。
有了手头的新关系,您想找到“找到那些不存在不存在于组中的实体的那些multi_groups”的答案。这是一个可怕的句子,但有道理:
In [117]: mg = aliased(MultiColumnGroups)
In [119]: session.query(MultiColumnGroups.group_id).\
...: filter(~exists().
...: select_from(vtm).
...: where(~exists().
...: where(MultiColumnGroups.group_id == mg.group_id).
...: where(tuple_(vtm.c.atom_id, vtm.c.sequence) ==
...: tuple_(mg.atom_id, mg.sequence)).
...: correlate_except(mg))).\
...: distinct().\
...: all()
...:
Out[119]: [(2)]
另一方面,您也可以只选择具有给定实体的组的交集:
In [19]: gs = intersect(*[
...: session.query(MultiColumnGroups.group_id).
...: filter(MultiColumnGroups.atom_id == vtm.atom_id,
...: MultiColumnGroups.sequence == vtm.sequence)
...: for vtm in values_to_match
...: ])
In [20]: session.execute(gs).fetchall()
Out[20]: [(2,)]
错误
ProgrammingError: (psycopg2.ProgrammingError) operator does not exist: record[] = integer[]
LINE 3: ...gg((multi_groups.atom_id, multi_groups.sequence)) = ARRAY[AR...
^
HINT: No operator matches the given name and argument type(s). You might need to add explicit type casts.
[SQL: 'SELECT multi_groups.group_id AS multi_groups_group_id \nFROM multi_groups GROUP BY multi_groups.group_id \nHAVING array_agg((multi_groups.atom_id, multi_groups.sequence)) = %(array_agg_1)s'] [parameters: {'array_agg_1': [[1, 3], [2, 2], [3, 1]]}] (Background on this error at: http://sqlalche.me/e/f405)
是您的values_to_match如何首先转换为列表列表(原因未知)然后通过 DB-API 驱动程序](http://initd.org/psycopg/docs/usage.html#lists-adaptation)将[转换为数组的结果。它产生一个整数数组数组,而不是记录数组(int,int)。使用原始 DB-API 连接和游标,传递元组列表可以按预期工作。
在 SQLAlchemy 中,如果你用sqlalchemy.dialects.postgresql.array()包装列表values_to_match,它会按照你的意思工作,但请记住结果是不确定的。
更多推荐
所有评论(0)