Skip to content

Commit

Permalink
Set N for ngrams as argument (#1774)
Browse files: browse the repository at this point in the history
Co-authored-by: Andy Z <[email protected]>
  • Loading branch information
andyzorigin and Andy Z authored Aug 9, 2023
1 parent 3857be3 commit f5ef0fc
Show file tree
Hide file tree
Showing 3 changed files with 10 additions and 7 deletions.
7 changes: 7 additions & 0 deletions scripts/data_overlap/common/arguments.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,4 +23,11 @@ def get_data_overlap_args() -> Any:
parser.add_argument(
"--normalization", type=str, default="default", help="What normalization and tokenization strategy to apply"
)
parser.add_argument(
"--N",
type=int,
nargs="*",
default=[5, 9, 13],
help="N for ngrams that we want to run, defaults to 5, 9, 13",
)
return parser.parse_args()
5 changes: 1 addition & 4 deletions scripts/data_overlap/compute_data_overlap_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,9 +25,6 @@
from scenarios.scenario import ScenarioSpec


# The n values of the ngrams to be computed
N_VALUES: List[int] = [5, 9, 13] # TODO: Pick the N values

PART_INPUT: str = "input"
PART_REF: str = "references"

Expand Down Expand Up @@ -211,7 +208,7 @@ def compute_document_data_overlap(
with htrack_block("Initializing the stats, ngram_index, and ngram_counter"):
ngram_index: NgramIndex
ngram_index = create_ngram_index(
light_scenarios=light_scenarios, n_values=N_VALUES, tokenizer=tokenizer, stats_key_counts=stats_key_counts
light_scenarios=light_scenarios, n_values=args.N, tokenizer=tokenizer, stats_key_counts=stats_key_counts
)

# DataOverlapStatsKey -> Set[str] for ids
Expand Down
5 changes: 2 additions & 3 deletions scripts/data_overlap/run_data_overlap_beam.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import apache_beam as beam

from typing import Callable, List
from typing import Callable

from data_overlap_beam import ComputeAndWriteDataOverlapStats
from common.arguments import get_data_overlap_args
Expand All @@ -25,7 +25,6 @@ def extract_text_from_raw_document(document: str) -> str:
def main():
args = get_data_overlap_args()

n_values: List[int] = [5, 9, 13] # TODO: Pick the N values
extract_text_from_document: Callable[[str], str] = get_extract_text_function(args.input_format)

# The model developer should pass in the appropriate PipelineOptions here.
Expand All @@ -39,7 +38,7 @@ def main():
| "ComputeAndWriteDataOverlapStats"
>> ComputeAndWriteDataOverlapStats(
scenario_data_path=args.scenario_data,
n_values=n_values,
n_values=args.N,
normalization=args.normalization,
tags={"tags:": args.tags},
output_stats=args.output_stats,
Expand Down

0 comments on commit f5ef0fc

Please sign in to comment.