keras.utils.PyDataset / tf.keras.utils.Sequence ignoring __len__ different behavior Keras2/Keras3 (tensorflow 2.16) #19994
Comments
Hi @Pagey - the KGen class inherits from tf.keras.utils.Sequence. In tf.keras.utils.Sequence (the PyDataset class), implement ... So in the KGen class, ...
Thanks @mehtamansi29 - it looks like you changed the code in the gist between the tensorflow 2.15 and 2.16 versions?
The __getitem__ is supposed to represent an infinite data generator and is therefore not limited to the length of self.alist; it could just as well have been written as return np.random.random(). In any case, this represents a difference in behavior between the two versions, i.e. one version terminates after __len__() batches and the other does not. I saw that in the new version the method ...
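The gist with the full class is not captured in this thread. Based only on the details mentioned here (an alist attribute and the return idx # self.alist[idx] change in __getitem__), the generator presumably looks roughly like the following sketch; everything beyond those details is a guess.

```python
import tensorflow as tf


class KGen(tf.keras.utils.Sequence):
    """Hypothetical reconstruction of the generator discussed in this thread."""

    def __init__(self, alist):
        super().__init__()
        self.alist = alist  # backing list; only its length matters for __len__

    def __len__(self):
        # Keras 2 stopped each epoch after this many batches; the report here
        # is that Keras 3 (TF 2.16) keeps calling __getitem__ past this value.
        return len(self.alist)

    def __getitem__(self, idx):
        print("__getitem__ called with idx =", idx)  # trace how far iteration goes
        return idx  # self.alist[idx] -- the data is generated on the fly instead
```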
Hello, I am seeing the same problem. I was looking at the source for the PyDataset class and I think I may have found an issue.

I first noticed that more calls are made to __getitem__ than __len__ returns. I tried setting up a num_batches property, since it is part of the PyDataset class on lines 157 and 158 of the source code for py_dataset_adapter.py. I might be wrong, but when defining that property and supplying a calculation for the number of batches, I get an error:

File "/home/user/miniconda3/envs//lib/python3.12/site-packages/keras/src/trainers/data_adapters/py_dataset_adapter.py", line 288, in get_tf_dataset

The code for that section calls num_batches = self.py_dataset.num_batches. This seems off to me, but I may be wrong: that line returns the num_batches method itself, which cannot be evaluated by the min statement, and should be num_batches = self.py_dataset.num_batches(). Above that, at line 165, def num_batches has if hasattr(self, "__len__"): followed by return len(self), which should just give an infinite recursion error. The combination of these bugs might be leading to the length being None and causing the generator to be called an infinite number of times.

I have also noticed, from counting the calls to __getitem__ of the PyDataset class, that some indexes are requested more than once, as if some calls are made, something fails, and then __getitem__ is called again. When I print the number of times __getitem__ has been called, the number of batches delivered to training is always larger than what the progress bar reports when training a model in TensorFlow.

I generate my data on the fly, so perhaps I need to design my dataset better. My particular issue is that I process the data coming off the disk before sending it to training, to save system memory. Doing this with multiple workers on the I/O side causes issues where the data is not necessarily associated with a particular index. I get that this is an issue on my part, but it would be nice if the description of PyDataset made it clear that not all batches of data delivered are used in training.
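For reference, here is a rough paraphrase of the two spots that comment quotes (not verbatim Keras source; line numbers vary by version). Since num_batches is defined as a property, the adapter's self.py_dataset.num_batches without parentheses already yields an int or None, and len(self) dispatches to __len__ rather than back into the property, so neither quoted line recurses by itself:

```python
# Rough paraphrase of the quoted code, not verbatim Keras source.


class PyDatasetSketch:
    @property
    def num_batches(self):
        # Accessing the property returns a value directly, so the adapter
        # code below does not need trailing parentheses.
        if hasattr(self, "__len__"):
            return len(self)  # len() calls __len__, not num_batches, so no recursion
        return None  # unknown length: the adapter treats the dataset as open-ended


def get_tf_dataset_sketch(py_dataset, steps=None):
    # Mirrors the adapter line quoted above: read the property's current value.
    num_batches = py_dataset.num_batches
    if steps is not None and num_batches is not None:
        num_batches = min(num_batches, steps)  # the min() mentioned in the comment
    return num_batches
```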
Hi @thirstythurston, thanks for reporting this. You can try to terminate it using ...
Attaching a gist for reference. Thanks!
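The concrete suggestion above is cut off in this capture. Independent of whatever the gist contained, one general way to bound training on an open-ended PyDataset is the steps_per_epoch argument to model.fit; a minimal sketch, assuming Keras 3 / TF 2.16 (the InfiniteGen class and the numbers are made up for illustration):

```python
import numpy as np
import keras


class InfiniteGen(keras.utils.PyDataset):
    """Open-ended generator: valid for any index, no __len__ defined."""

    def __getitem__(self, idx):
        x = np.random.random((8, 4)).astype("float32")
        y = np.random.random((8, 1)).astype("float32")
        return x, y


inputs = keras.Input(shape=(4,))
model = keras.Model(inputs, keras.layers.Dense(1)(inputs))
model.compile(optimizer="adam", loss="mse")

# steps_per_epoch puts an explicit cap on how many batches are drawn per epoch,
# independent of what __len__ / num_batches report.
model.fit(InfiniteGen(), epochs=1, steps_per_epoch=10)
```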
This issue is stale because it has been open for 14 days with no activity. It will be closed if no further activity occurs. Thank you. |
This issue was closed because it has been inactive for 28 days. Please reopen if you'd like to work on this further. |
Hi there - paraphrasing an issue from 2018: the change is return idx # self.alist[idx] in __getitem__. This is relevant in cases of generated datasets, i.e. it looks as though the __len__ value is ignored, whereas it used not to be.
The above code on tensorflow 2.15 (Python 3.10.13, Ubuntu 20.04) produces this output:
and on tensorflow 2.16 (Python 3.10.13, Ubuntu 20.04) produces this output: