Question

Is there a simple, clean way to iterate over the array of axes returned by subplots, like this:

import matplotlib.pyplot as plt

nrow = ncol = 2
a = []
fig, axs = plt.subplots(nrows=nrow, ncols=ncol)
# Collect every Axes from the 2-D array into a flat list
for row in axs:
    for ax in row:
        a.append(ax)

for i, ax in enumerate(a):
    ax.set_ylabel(str(i))

but in a way that also works when nrow or ncol == 1?

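I tried a list comprehension like:

[element for tupl in tupleOfTuples for element in tupl]

but that fails if nrows or ncols == 1, because subplots then returns a 1-D array (or, for a 1x1 grid, a single Axes), so one of the loops ends up trying to iterate over an Axes object.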
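To see why, it helps to look at what plt.subplots actually returns. The sketch below (assuming the non-interactive Agg backend so it runs without a display) records the shape of the returned object for a few grid sizes. With the default squeeze=True, singleton dimensions are dropped, and a 1x1 grid gives back a bare Axes rather than an array.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; no display needed
import matplotlib.pyplot as plt

shapes = {}
for nrows, ncols in [(2, 2), (1, 2), (2, 1), (1, 1)]:
    fig, axs = plt.subplots(nrows=nrows, ncols=ncols)
    # An ndarray has .shape; a lone Axes object does not, so record None.
    shapes[(nrows, ncols)] = getattr(axs, "shape", None)
    plt.close(fig)  # free the figure

for grid, shape in shapes.items():
    print(grid, shape)
```

Only the 2x2 case is two-dimensional, so a comprehension that assumes two levels of nesting fits that case alone.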

Answers

The axs return value is a NumPy array, which can be reshaped without copying the data (reshape returns a view whenever it can). Reshaping it to one dimension gives you a linear array that you can iterate over cleanly:

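import matplotlib.pyplot as plt

nrow, ncol = 1, 2
fig, axs = plt.subplots(nrows=nrow, ncols=ncol)

# reshape(-1) flattens the array of Axes to 1-D, whatever the grid shape
for i, ax in enumerate(axs.reshape(-1)):
    ax.set_ylabel(str(i))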
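reshape(-1) is not the only option. NumPy arrays also offer ravel(), which likewise returns a view when possible, and the .flat iterator, which walks the elements in row-major order without building a new array (flatten(), by contrast, always copies). The sketch below, again assuming the Agg backend, numbers a 2x3 grid via .flat and checks that the reshaped array shares memory with the original, i.e. that no copy was made.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; no display needed
import matplotlib.pyplot as plt
import numpy as np

fig, axs = plt.subplots(nrows=2, ncols=3)

# .flat iterates row by row: axs[0, 0], axs[0, 1], ..., axs[1, 2]
for i, ax in enumerate(axs.flat):
    ax.set_ylabel(str(i))

flat = axs.reshape(-1)  # a view onto the same Axes references
print(np.shares_memory(axs, flat))  # True when reshape did not copy
print(axs[1, 0].get_ylabel())  # first Axes of the second row
```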

This doesn't work when nrows and ncols are both 1, because the return value is then a single Axes object rather than an array. For consistency you can wrap the return value in a one-element array, though it feels a bit like a kludge:

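import matplotlib.pyplot as plt
import numpy as np

nrow, ncol = 1, 1
fig, axs = plt.subplots(nrows=nrow, ncols=ncol)
axs = np.array(axs)  # wrap the single Axes in a 0-d array so reshape works

for i, ax in enumerate(axs.reshape(-1)):
    ax.set_ylabel(str(i))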
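A less kludgy alternative, assuming you control the plt.subplots call, is its squeeze=False parameter, which makes it always return a 2-D array of Axes, even for a 1x1 grid. If you only receive an already-squeezed result, np.atleast_1d does the same job as the np.array wrapping above while leaving arrays untouched. The helper names label_axes and label_axes_from are hypothetical, made up for this sketch.

```python
import matplotlib
matplotlib.use("Agg")  # headless backend; no display needed
import matplotlib.pyplot as plt
import numpy as np

def label_axes(nrow, ncol):
    """Hypothetical helper: create a grid and number every Axes."""
    # squeeze=False: the result is always a 2-D ndarray of shape (nrow, ncol)
    fig, axs = plt.subplots(nrows=nrow, ncols=ncol, squeeze=False)
    for i, ax in enumerate(axs.flat):
        ax.set_ylabel(str(i))
    return fig, axs

def label_axes_from(axs):
    """Hypothetical helper: number Axes from a squeezed subplots result."""
    # atleast_1d turns a lone Axes into a 1-element array; arrays pass through
    flat = np.atleast_1d(axs).ravel()
    for i, ax in enumerate(flat):
        ax.set_ylabel(str(i))
    return flat

fig, axs = label_axes(1, 1)
print(axs.shape)  # (1, 1)
```

squeeze=False is the cleaner choice when you create the figure yourself; the atleast_1d route is handy in functions that accept whatever plt.subplots happened to return.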

See the numpy.reshape documentation: passing -1 as a dimension tells reshape to infer that dimension's length from the array's total number of elements.
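A quick illustration of how -1 is resolved, in plain NumPy with no plotting: the unknown dimension is whatever makes the total element count match, which is also why only one -1 is allowed per call.

```python
import numpy as np

a = np.arange(6).reshape(2, 3)

print(a.reshape(-1).shape)     # (6,): one dimension, inferred as 6
print(a.reshape(3, -1).shape)  # (3, 2): 6 / 3 = 2
print(a.reshape(-1, 1).shape)  # (6, 1)
```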
