
I have to stack some of my own layers on top of different kinds of PyTorch models, which may live on different devices.

For example, A is a CUDA model and B is a CPU model (but I don't know which until I get the device type). The new models are then C and D, respectively, where

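import torch

class NewModule(torch.nn.Module):
    def __init__(self, base):
        super(NewModule, self).__init__()
        self.base = base          # the wrapped model (may be on CPU or CUDA)
        self.extra = my_layer()   # my own layer, e.g. torch.nn.Linear(...)

    def forward(self, x):
        y = self.base(x)          # run the wrapped model first
        z = self.extra(y)         # then apply the extra layer to its output
        return z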

...

C = NewModule(A) # cuda
D = NewModule(B) # cpu

However, base and extra must be on the same device: for C both must be on CUDA, and for D both must be on the CPU. So I tried this __init__:

def __init__(self, base):
    super(NewModule, self).__init__()
    self.base = base
    self.extra = my_layer().to(base.device)

Unfortunately, torch.nn.Module has no device attribute (accessing base.device raises an AttributeError).

How can I get the device of base? Alternatively, is there another way to put base and extra on the same device automatically, even when the structure of base is not known in advance?

Answers

This question has been asked many times (1, 2). Quoting the reply from a PyTorch developer:

That’s not possible. Modules can hold parameters of different types on different devices, and so it’s not always possible to unambiguously determine the device.
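To see what the quote means without needing a GPU, here is a small sketch. The Mixed class and the layer sizes are hypothetical; one layer stays on the CPU in float32, and the other is created on PyTorch's "meta" device in float64, which is used here only as a stand-in for a second real device such as CUDA.

```python
import torch

class Mixed(torch.nn.Module):
    """Hypothetical module whose parameters live on two devices with two dtypes."""

    def __init__(self):
        super().__init__()
        # Ordinary float32 layer on the CPU.
        self.a = torch.nn.Linear(2, 2)
        # float64 layer on the "meta" device (shape-only tensors, no data),
        # standing in for a second device such as a GPU.
        self.b = torch.nn.Linear(2, 2, device="meta", dtype=torch.float64)

m = Mixed()
print({p.device for p in m.parameters()})  # two different devices
print({p.dtype for p in m.parameters()})   # two different dtypes
print(next(m.parameters()).device)         # only the first parameter's device
```

Since parameters() yields parameters in registration order, the last line reports the CPU even though half of the parameters live elsewhere, which is exactly why there is no single unambiguous "device" for a module.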

The recommended workflow (as described on the PyTorch blog) is to create the device object separately and use it everywhere. Copy-pasting the example from the blog here:

# at beginning of the script
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

...

# then whenever you get a new Tensor or Module
# this won't copy if they are already on the desired device
input = data.to(device)
model = MyModule(...).to(device)
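Applied to the question, a minimal sketch of this workflow might look as follows. It assumes torch.nn.Linear as a stand-in for my_layer(), and the base model and the in_features/out_features sizes are hypothetical. Because .to() on a parent module recursively moves every registered submodule, both base and extra end up on the chosen device without __init__ ever needing to know where base lives.

```python
import torch

# Pick the device once, up front (same pattern as the blog snippet).
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class NewModule(torch.nn.Module):
    def __init__(self, base, in_features, out_features):
        super().__init__()
        self.base = base
        # Hypothetical stand-in for my_layer(); sizes are illustrative assumptions.
        self.extra = torch.nn.Linear(in_features, out_features)

    def forward(self, x):
        return self.extra(self.base(x))

base = torch.nn.Linear(4, 3)                # hypothetical base model
model = NewModule(base, 3, 2).to(device)    # moves base AND extra together
x = torch.randn(5, 4, device=device)        # inputs created on the same device
y = model(x)
```

The trade-off is that the wrapper must be moved after construction, but it works for any base, including one whose parameters started on a different device.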

Do note that there is nothing stopping you from adding a .device property to the models.
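One possible shape for such a property, sketched under the assumption that you control the model classes (DeviceAwareModule and Wrapper are hypothetical names, not part of PyTorch), is to collect the devices of all parameters and buffers and refuse to answer when there is not exactly one:

```python
import torch

class DeviceAwareModule(torch.nn.Module):
    """Hypothetical base class that adds a read-only `device` property."""

    @property
    def device(self) -> torch.device:
        # Gather devices of all parameters and buffers, including submodules.
        devices = {t.device for t in self.parameters()}
        devices |= {t.device for t in self.buffers()}
        if len(devices) != 1:
            # No tensors or mixed devices: there is no single correct answer.
            raise RuntimeError(f"cannot determine a single device, found: {devices}")
        return next(iter(devices))

class Wrapper(DeviceAwareModule):
    def __init__(self):
        super().__init__()
        self.lin = torch.nn.Linear(2, 2)

w = Wrapper()
print(w.device)  # cpu
```

RuntimeError is raised rather than AttributeError on purpose: an AttributeError inside a property makes Python fall back to nn.Module.__getattr__, which would replace the message with a generic "has no attribute 'device'" error.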

As mentioned by Kani (in the comments), if all the parameters in the model are on the same device, one can use next(model.parameters()).device.
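Plugged into the question's __init__, that suggestion could look roughly like this. It is a sketch that again assumes torch.nn.Linear in place of my_layer() with hypothetical sizes. The fallback to the CPU when base has no parameters is also an assumption of this sketch, since a bare next() would raise StopIteration in that case.

```python
import torch

class NewModule(torch.nn.Module):
    def __init__(self, base, in_features, out_features):
        super().__init__()
        self.base = base
        # Assumes all of base's parameters share one device; use the first one's.
        # Falls back to the CPU if base has no parameters (an assumption).
        first = next(base.parameters(), None)
        device = first.device if first is not None else torch.device("cpu")
        # Hypothetical stand-in for my_layer(), moved to base's device.
        self.extra = torch.nn.Linear(in_features, out_features).to(device)

    def forward(self, x):
        return self.extra(self.base(x))

A = torch.nn.Linear(4, 3)       # hypothetical CPU base model
C = NewModule(A, 3, 2)
print(C.extra.weight.device)    # same device as A.weight
```

If base itself mixes devices, this silently picks whichever device its first parameter is on, so it only holds under the same-device assumption stated above.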
