RuntimeError: generator raised StopIteration #17

Open
n0obcoder opened this issue Sep 23, 2020 · 2 comments

Comments

@n0obcoder

I am getting the following error while training the model with this command:

python train.py -d path/to/dataset/
  File "D:\daftar\recomendation systems\sequence-based-recommendations-master\neural_networks\rnn_base.py", line 405, in _gen_mini_batch
    sequence, user_id = next(sequence_generator)
StopIteration

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "train.py", line 60, in <module>
    main()
  File "train.py", line 57, in main
    validation_metrics=args.metrics.split(','))
  File "D:\daftar\recomendation systems\sequence-based-recommendations-master\neural_networks\rnn_base.py", line 324, in train
    metrics = self._compute_validation_metrics(metrics)
  File "D:\daftar\recomendation systems\sequence-based-recommendations-master\neural_networks\rnn_base.py", line 365, in _compute_validation_metrics
    for batch_input, goal in self._gen_mini_batch(self.dataset.validation_set(epochs=1), test=True):
RuntimeError: generator raised StopIteration

What might be causing this error?

@n0obcoder
Author

I noticed that this error occurs once all 500 test examples have been yielded: the next call to next() on the exhausted generator raises StopIteration, which then surfaces as this error. But I still don't know how to resolve it.
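The behaviour can be reproduced in isolation. Since Python 3.7 (PEP 479), a StopIteration that escapes from a generator's body is converted into a RuntimeError, and the original StopIteration is kept as the new exception's `__cause__`, which is why the traceback says "The above exception was the direct cause of the following exception". The sketch below is a minimal reproduction; `batches` and `source` are hypothetical stand-ins for `_gen_mini_batch` and `sequence_generator`, not code from the repository.

```python
def batches(source):
    """Hypothetical stand-in for _gen_mini_batch: pulls items with next()
    inside an infinite loop, like the original `while True` loop."""
    while True:
        item = next(source)  # raises StopIteration once source is exhausted
        yield item


def demo():
    source = iter([1, 2, 3])  # stands in for the validation sequences
    seen = []
    try:
        for item in batches(source):
            seen.append(item)
    except RuntimeError as exc:
        # PEP 479: the StopIteration leaking out of the generator body is
        # re-raised as RuntimeError, with the original kept as __cause__.
        return seen, type(exc.__cause__).__name__, str(exc)
    return seen, None, None


if __name__ == "__main__":
    print(demo())
```

All three items are consumed normally; the error only appears when the loop asks for a fourth item that does not exist.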

@n0obcoder
Author

I fixed it with the help of https://stackoverflow.com/questions/51700960/runtimeerror-generator-raised-stopiteration-every-time-i-try-to-run-app

I changed

while True:
	j = 0
	sequences = []
	batch_size = self.batch_size
	if test:
		batch_size = 1
	while j < batch_size:
		
		sequence, user_id = next(sequence_generator)

		# print('sequence: ', sequence)
		# print('user_id: ', user_id, 'test: ', test)
		# pdb.set_trace()

		# finds the lengths of the different subsequences
		if not test:
			seq_lengths = sorted(random.sample(range(2, len(sequence)), min([batch_size - j, len(sequence) - 2, max_reuse_sequence])))
		else:
			seq_lengths = [int(len(sequence) / 2)] 

		skipped_seq = 0
		for l in seq_lengths:
			target = self.target_selection(sequence[l:], test=test)
			if len(target) == 0:
				skipped_seq += 1
				continue
			start = max(0, l - self.max_length) # sequences cannot be longer than self.max_length
			sequences.append([user_id, sequence[start:l], target])

		j += len(seq_lengths) - skipped_seq

	if test:
		# sequence[seq_lengths[0]:] is the sequence (ratings included) GT here
		# [i[0] for i in sequence[seq_lengths[0]:]]  is the GT here (ratings excluded)
		yield self._prepare_input(sequences), [i[0] for i in sequence[seq_lengths[0]:]] 
	else:
		yield self._prepare_input(sequences)

to

while True:
	try:
		j = 0
		sequences = []
		batch_size = self.batch_size
		if test:
			batch_size = 1
		while j < batch_size:
			
			sequence, user_id = next(sequence_generator)

			# print('sequence: ', sequence)
			# print('user_id: ', user_id, 'test: ', test)
			# pdb.set_trace()

			# finds the lengths of the different subsequences
			if not test:
				seq_lengths = sorted(random.sample(range(2, len(sequence)), min([batch_size - j, len(sequence) - 2, max_reuse_sequence])))
			else:
				seq_lengths = [int(len(sequence) / 2)] 

			skipped_seq = 0
			for l in seq_lengths:
				target = self.target_selection(sequence[l:], test=test)
				if len(target) == 0:
					skipped_seq += 1
					continue
				start = max(0, l - self.max_length) # sequences cannot be longer than self.max_length
				sequences.append([user_id, sequence[start:l], target])

			j += len(seq_lengths) - skipped_seq

		if test:
			# sequence[seq_lengths[0]:] is the sequence (ratings included) GT here
			# [i[0] for i in sequence[seq_lengths[0]:]]  is the GT here (ratings excluded)
			yield self._prepare_input(sequences), [i[0] for i in sequence[seq_lengths[0]:]] 
		else:
			yield self._prepare_input(sequences)
	except StopIteration:
		return
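A variant of the same fix, sketched here under the assumption that only the `next(sequence_generator)` call is expected to signal exhaustion, detects the end of the data right at that call instead of wrapping the whole loop body in `try`. This avoids silently swallowing a StopIteration raised somewhere else (for example inside a helper). It uses the built-in `next(iterator, default)` form, which returns the default instead of raising. `mini_batches`, `source` and `_SENTINEL` are hypothetical, simplified names, not the repository's code.

```python
_SENTINEL = object()  # unique marker that no real item can be equal to


def mini_batches(source, batch_size):
    """Hypothetical, simplified stand-in for _gen_mini_batch.

    Yields lists of (sequence, user_id) pairs of length batch_size and
    stops cleanly when `source` runs out.
    """
    while True:
        batch = []
        while len(batch) < batch_size:
            # next() with a default never raises StopIteration,
            # so exhaustion is checked explicitly instead.
            item = next(source, _SENTINEL)
            if item is _SENTINEL:
                return  # ends the generator normally; a partial batch is dropped
            sequence, user_id = item
            batch.append((sequence, user_id))
        yield batch
```

As with the `try`/`except StopIteration: return` fix, a plain `return` ends the generator normally, so a consumer such as the `for` loop in `_compute_validation_metrics` simply stops after the last batch. With `batch_size = 1` (the test mode above) no partial batch can be lost.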
