# Patch summary (explanatory preamble; everything before the first "diff --git" header is not part of any hunk).
#
# Purpose: fix year extraction in _extract_entities. The year pattern changes from
#   \b(19|20)\d{2}\b   to   \b(?:19|20)\d{2}\b
# With a single capturing group, re.findall returns the text of that group, so the loop
# appended only the century prefix ("19" / "20") to entities["dates"]. Making the group
# non-capturing lets re.findall return the whole four-digit match.
#
# Files touched:
#   services/search/query.py    - same one-line change in the hunk at line 29
#   src/search/query.py         - same one-line change in the hunk at line 29
#   tests/test_search_query.py  - new file (21 lines) with three regression tests
#
# NOTE(review): both query.py entries carry the same blob ids (dbe9dd7..22f0c11), so the two
#   copies appear to be identical before and after the change - confirm the duplication is intended.
# NOTE(review): the new tests import _extract_entities from src.search.query only; nothing in this
#   patch exercises the services/ copy.
# NOTE(review): the exact-list assertion in test_extracts_multiple_years and the empty-list assertion
#   in test_no_false_year_from_other_numbers assume no code outside the visible hunk appends other
#   values to entities["dates"] for those inputs - verify against the full function.
# NOTE(review): in this copy the hunks appear with their line breaks collapsed onto long lines; the
#   patch presumably needs them restored before it can be applied - confirm against the original.
diff --git a/services/search/query.py b/services/search/query.py index dbe9dd7..22f0c11 100644 --- a/services/search/query.py +++ b/services/search/query.py @@ -29,7 +29,7 @@ def _extract_entities(query: str) -> Dict[str, List[str]]: cleaned = re.sub(rf"^{qtype}\b", "", cleaned, flags=re.I).strip() for token in re.findall(r"\b[A-Z][a-zA-Z]+\b", cleaned): entities["names"].append(token) - for year in re.findall(r"\b(19|20)\d{2}\b", cleaned): + for year in re.findall(r"\b(?:19|20)\d{2}\b", cleaned): entities["dates"].append(year) month_day_year = re.findall( r"\b(?:Jan|January|Feb|February|Mar|March|Apr|April|May|Jun|June|Jul|July|Aug|August|Sep|Sept|September|Oct|October|Nov|November|Dec|December)\s+\d{1,2},?\s*\d{4}\b", diff --git a/src/search/query.py b/src/search/query.py index dbe9dd7..22f0c11 100644 --- a/src/search/query.py +++ b/src/search/query.py @@ -29,7 +29,7 @@ def _extract_entities(query: str) -> Dict[str, List[str]]: cleaned = re.sub(rf"^{qtype}\b", "", cleaned, flags=re.I).strip() for token in re.findall(r"\b[A-Z][a-zA-Z]+\b", cleaned): entities["names"].append(token) - for year in re.findall(r"\b(19|20)\d{2}\b", cleaned): + for year in re.findall(r"\b(?:19|20)\d{2}\b", cleaned): entities["dates"].append(year) month_day_year = re.findall( r"\b(?:Jan|January|Feb|February|Mar|March|Apr|April|May|Jun|June|Jul|July|Aug|August|Sep|Sept|September|Oct|October|Nov|November|Dec|December)\s+\d{1,2},?\s*\d{4}\b", diff --git a/tests/test_search_query.py b/tests/test_search_query.py new file mode 100644 index 0000000..7de6e4d --- /dev/null +++ b/tests/test_search_query.py @@ -0,0 +1,21 @@ +"""Tests for research query entity extraction (src/search/query.py).""" + +from src.search.query import _extract_entities + + +def test_extracts_full_four_digit_year(): + # Regression: the year pattern used a capturing group `(19|20)`, so + # re.findall returned just the century ("20") instead of the full year. 
+ entities = _extract_entities("What happened to OpenAI in 2024") + assert "2024" in entities["dates"] + assert "20" not in entities["dates"] + + +def test_extracts_multiple_years(): + entities = _extract_entities("Compare revenue in 1999 and 2008") + assert entities["dates"] == ["1999", "2008"] + + +def test_no_false_year_from_other_numbers(): + entities = _extract_entities("Top 50 albums of all time") + assert entities["dates"] == []