I would like to know if using a DataLoader connected to a MongoDB is a sensible thing to do and how this could be implemented.
Background
I have about 20 million documents sitting in a (local) MongoDB. Way more documents than fit in memory. I would like to train a deep neural net on the data. So far, I have been exporting the data to the file system first, with subfolders named as the classes of the documents. But I find this approach nonsensical. Why export first (and later delete) if the data is already well maintained in a DB?
Question 1:
Am I right? Would it make sense to directly connect to the MongoDB? Or are there reasons not to do it (e.g. DBs generally being too slow etc.)? If DBs are too slow (why?), can one prefetch the data somehow?
Question 2:
How would one implement a PyTorch DataLoader? I have found only very few code snippets online ([1] and [2]), which makes me doubt my approach.
Code snippet
The general way I access MongoDB is shown below. Nothing special about this, I think.
import pymongo
from pymongo import MongoClient
myclient = pymongo.MongoClient("mongodb://localhost:27017/")
mydb = myclient["xyz"]
mycol = mydb["xyz_documents"]
query = {
# some filters
}
results = mycol.find(query)
# results is now a cursor that can run through all docs
# Assume, for the sake of this example, that each doc contains a class name and some image that I want to train a classifier on
Introduction
This one is a little open-ended, but let's try; also, please correct me if I'm wrong somewhere.
So far, I have been exporting the data to the file system first, with subfolders named as the classes of the documents.
IMO this isn't sensible because:
- you are essentially duplicating data
- any time you want to train anew given only the code and the database, this export would have to be repeated
- with the database you can access multiple datapoints at once and cache them in RAM for later reuse without reading from the hard drive multiple times (which is quite heavy)
Am I right? Would it make sense to directly connect to the MongoDB?
Given the above, probably yes (especially when it comes to a clear and portable implementation).
Or are there reasons not to do it (e.g. DBs generally being too slow etc.)?
AFAIK the DB shouldn't be slower in this case, as it will cache access to the data, but I'm no DB expert unfortunately. Many tricks for faster access are implemented out of the box by databases.
can one prefetch the data somehow?
Yes, if you just want to get data, you could load a larger part of the data (say 1024 records) in one go and return batches of data from that (say batch_size=128).
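As a minimal sketch of that idea (the chunk_size, the helper name batches_from_cursor, and the example connection parameters are just assumptions for this illustration, not anything prescribed by pymongo):

import itertools
import pymongo

def batches_from_cursor(cursor, chunk_size=1024, batch_size=128):
    # Pull a large chunk of documents from the DB in one go ...
    while True:
        chunk = list(itertools.islice(cursor, chunk_size))
        if not chunk:
            break
        # ... and serve smaller training batches from RAM, without going
        # back to the database for every single batch
        for i in range(0, len(chunk), batch_size):
            yield chunk[i:i + batch_size]

client = pymongo.MongoClient("mongodb://localhost:27017/")
collection = client["xyz"]["xyz_documents"]
for batch in batches_from_cursor(collection.find({}), chunk_size=1024, batch_size=128):
    ...  # build tensors from the documents in `batch`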
Implementation
How would one implement a PyTorch DataLoader? I have found only very few code snippets online ([1] and [2]) which makes me doubt my approach.
I'm not sure why you would want to do that. What you should go for is torch.utils.data.Dataset, as shown in the examples you've listed.
I would start with a simple, non-optimized approach similar to the one here, so:
- open the connection to the DB in __init__ and keep it as long as it's used (I would create a context manager from torch.utils.data.Dataset so the connection is closed after the epochs are finished)
- I would not transform the results to a list (especially since you cannot fit it in RAM for obvious reasons), as it misses the point of generators
- I would perform batching inside this Dataset (there is an argument batch_size here)
- I am not sure about the __getitem__ function, but it seems it can return multiple datapoints at once, hence I'd use that, and it should allow us to use num_workers>0 (given that mycol.find(query) returns data in the same order every time)
Given that, something along those lines is what I'd do:
import pymongo
import torch


class DatabaseDataset(torch.utils.data.Dataset):
    def __init__(self, query, batch_size, path: str, database: str, collection: str):
        self.batch_size = batch_size
        client = pymongo.MongoClient(path)
        # find() and the document counts live on the collection, not the database
        self.collection = client[database][collection]
        self.query = query
        # Or a non-approximate method (count_documents); if the approximate method
        # returns a smaller number of items you should be fine
        self.length = self.collection.estimated_document_count()
        self.cursor = None

    def __enter__(self):
        # Ensure that this find returns the query results in the same order every time.
        # If not, you might get duplicated data.
        # It is rather unlikely (depending on batch size), shouldn't be a problem
        # for 20 million samples anyway
        self.cursor = self.collection.find(self.query)
        return self

    def shuffle(self):
        # Find a way to shuffle the data so it is returned in a different order.
        # If that happens out of the box you might be fine without it actually
        pass

    def __exit__(self, *_, **__):
        # Or anything else to close the connection
        self.cursor.close()

    def __len__(self):
        # Number of batches, since __getitem__ returns whole batches
        return self.length // self.batch_size

    def __getitem__(self, index):
        # Reads take long, hence loading a whole batch of documents at once should speed things up.
        # Clone the cursor so the skip/limit of one batch does not affect the next one.
        examples = list(
            self.cursor.clone()[index * self.batch_size : (index + 1) * self.batch_size]
        )
        # Do something with this data: build `data` and `labels` tensors from `examples`
        ...
        # Return the whole batch
        return data, labels
Now batching is taken care of by DatabaseDataset, hence torch.utils.data.DataLoader can have batch_size=1. You might need to squeeze the additional dimension.
As MongoDB uses locks (which is no surprise, but see here), num_workers>0 shouldn't be a problem.
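For example, a small collate_fn can unwrap the single pre-batched element instead of stacking it (a sketch; unwrap_collate is just an illustrative name, and dataset stands for an opened DatabaseDataset):

def unwrap_collate(samples):
    # With batch_size=1 the DataLoader passes a list with exactly one element,
    # which is already a full (data, labels) batch built by the Dataset,
    # so return it directly instead of stacking and adding a leading dimension
    return samples[0]

# Plugged into the usage below it would look like:
# dataloader = torch.utils.data.DataLoader(dataset, batch_size=1,
#                                          collate_fn=unwrap_collate)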
Possible usage (schematically):
with DatabaseDataset(...) as dataset:
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=1)
    for epoch in range(epochs):
        for batch in dataloader:
            # And all the stuff
            ...
        dataset.shuffle()  # after each epoch
Remember about the shuffling implementation in such a case! (Shuffling can also be done inside the context manager, and you might want to close the connection manually, or something along those lines.)
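One possible (approximate) way to implement it, as a sketch only: redraw a permutation of batch indices each epoch and remap the requested index through it, so the batches are read in a different order. ShuffledDatabaseDataset is just an illustrative name, not part of the code above.

import random

class ShuffledDatabaseDataset(DatabaseDataset):
    def shuffle(self):
        # Draw a new batch order for the next epoch
        n_batches = len(self)
        self.permutation = random.sample(range(n_batches), n_batches)

    def __getitem__(self, index):
        if getattr(self, "permutation", None) is not None:
            index = self.permutation[index]  # read batches in shuffled order
        # Optionally also shuffle the documents within the returned batch
        return super().__getitem__(index)

Note that this only shuffles the order of whole batches (documents inside a batch stay together), which is often good enough at this scale; a full shuffle would need a random sort on the database side.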