
Does model.train() call forward() in nn.Module? I thought forward() is what runs when we call the model. Why do we need to call train()?

Answers

model.train() tells your model that you are training it. This informs layers such as Dropout and BatchNorm, which are designed to behave differently during training and evaluation. For instance, in training mode BatchNorm updates its running statistics (mean and variance) with each new batch, whereas in evaluation mode these statistics are frozen and used as-is. Similarly, Dropout randomly zeroes activations in training mode but passes its input through unchanged in evaluation mode.
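A minimal sketch of that difference, assuming a recent PyTorch install. The tensor shapes, seed and dropout probability are arbitrary choices for illustration. It checks that BatchNorm's running_mean buffer moves after a forward pass in training mode but stays fixed in evaluation mode, and that Dropout only zeroes values in training mode:

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

bn = nn.BatchNorm1d(num_features=3)
drop = nn.Dropout(p=0.5)

# Arbitrary batch of 8 samples with 3 features, shifted so the mean is far from 0.
x = torch.randn(8, 3) + 5.0

# Training mode: each forward pass updates the running statistics.
bn.train()
before = bn.running_mean.clone()
bn(x)
after_train = bn.running_mean.clone()

# Evaluation mode: running statistics are used for normalization but not updated.
bn.eval()
bn(x)
after_eval = bn.running_mean.clone()

print(torch.equal(before, after_train))      # False: stats moved in train mode
print(torch.equal(after_train, after_eval))  # True: stats frozen in eval mode

# Dropout: random zeros in training mode, identity in evaluation mode.
ones = torch.ones(1000)
drop.train()
train_out = drop(ones)
drop.eval()
eval_out = drop(ones)

print((train_out == 0).any().item())  # True: some elements were dropped
print(torch.equal(eval_out, ones))    # True: input passed through unchanged
```

Note that the same bn and drop objects are reused throughout. Only the mode flag changes between the two halves of each experiment.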

More details: model.train() sets the mode to train (see source code). It does not call forward(); the forward pass still runs only when you call the model on an input. To tell the model that you are testing, call either model.eval() or model.train(mode=False). It is somewhat intuitive to expect a function named train to train the model, but it does not do that. It only sets the mode.
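The sketch below illustrates that train() only flips a flag. It assumes a recent PyTorch install. Probe is a hypothetical module written for this example, and it counts how often forward() runs. The example shows that train() returns the module without running forward(), that the flag propagates to submodules, and that eval() and train(mode=False) both switch to evaluation mode:

```python
import torch
import torch.nn as nn


class Probe(nn.Module):
    """Hypothetical module that counts how many times forward() runs."""

    def __init__(self):
        super().__init__()
        self.calls = 0
        self.drop = nn.Dropout(p=0.5)

    def forward(self, x):
        self.calls += 1
        return self.drop(x)


model = Probe()

# train() sets model.training = True on the module and all its submodules.
returned = model.train()
calls_after_train = model.calls
print(returned is model)                     # True: train() returns the module itself
print(calls_after_train)                     # 0: forward() was not invoked
print(model.training, model.drop.training)   # True True

# Calling the model is what actually runs forward().
model(torch.ones(4))
print(model.calls)                           # 1

# eval() and train(mode=False) both switch to evaluation mode.
model.eval()
print(model.training, model.drop.training)   # False False
model.train()
model.train(mode=False)
print(model.training, model.drop.training)   # False False
```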
