diff --git a/src/gretel_synthetics/batch.py b/src/gretel_synthetics/batch.py index b8601d75..f846af6d 100644 --- a/src/gretel_synthetics/batch.py +++ b/src/gretel_synthetics/batch.py @@ -291,7 +291,7 @@ def set_batch_validator(self, batch_idx: int, validator: Callable): except KeyError: raise ValueError("invalid batch number!") - def generate_batch_lines(self, batch_idx: int, max_invalid=1000): + def generate_batch_lines(self, batch_idx: int, max_invalid=MAX_INVALID): """Generate lines for a single batch. Lines generated are added to the underlying ``Batch`` object for each batch. The lines can be accessed after generation and re-assembled into a DataFrame. @@ -312,7 +312,7 @@ def generate_batch_lines(self, batch_idx: int, max_invalid=1000): t2 = tqdm(total=max_invalid, desc="Invalid record count ") line: gen_text for line in generate_text( - batch.config, line_validator=validator, max_invalid=MAX_INVALID + batch.config, line_validator=validator, max_invalid=max_invalid ): if line.valid is None or line.valid is True: batch.add_valid_data(line) diff --git a/tests/test_batch.py b/tests/test_batch.py index a43fd123..fafae1b9 100644 --- a/tests/test_batch.py +++ b/tests/test_batch.py @@ -139,6 +139,13 @@ def good(): def bad(): return gen_text(text="1,2,3", valid=False, delimiter=",") + with patch("gretel_synthetics.batch.generate_text") as mock_gen: + mock_gen.return_value = [good(), good(), good(), bad(), bad(), good(), good()] + assert batches.generate_batch_lines(5, max_invalid=1) + check_call = mock_gen.mock_calls[0] + _, _, kwargs = check_call + assert kwargs["max_invalid"] == 1 + with patch("gretel_synthetics.batch.generate_text") as mock_gen: mock_gen.return_value = [good(), good(), good(), bad(), bad(), good(), good()] assert batches.generate_batch_lines(5) @@ -148,8 +155,11 @@ def bad(): assert not batches.generate_batch_lines(5) with patch.object(batches, "generate_batch_lines") as mock_gen: - batches.generate_all_batch_lines() + batches.generate_all_batch_lines(max_invalid=15) assert mock_gen.call_count == len(batches.batches.keys()) + check_call = mock_gen.mock_calls[0] + _, _, kwargs = check_call + assert kwargs["max_invalid"] == 15 # get synthetic df