# Patch notes (leading commentary; everything from the first "diff --git"
# header onward is the patch itself).
#
# services/hwfit/fit.py, _native_quant:
#   - GPTQ branch: the label changes from "GPTQ-<N>bit" to "GPTQ-Int<N>", and
#     the no-digit fallback changes from bare "GPTQ" to "GPTQ-Int4".
#   - AWQ branch: the "AWQ-<N>bit" form is kept; the no-digit fallback changes
#     from bare "AWQ" to "AWQ-4bit".
#   - The FP8 and mlx context lines are unchanged.
#
# tests/test_hwfit_native_quant_labels.py (new file, 42 lines):
#   - Five tests call _native_quant with a {"name": ...} dict and assert both
#     the exact label and its membership in QUANT_BPP and
#     QUANT_QUALITY_PENALTY.
#
# NOTE(review): a parsed width other than 4 or 8 (e.g. "...-gptq-3bit") still
#   yields a label such as "GPTQ-Int3". Whether such labels are keys of the
#   two maps cannot be seen from this patch -- confirm against models.py.
# NOTE(review): the leading whitespace of the hunk lines below was
#   reconstructed as conventional 4-space Python indentation -- confirm the
#   context lines against services/hwfit/fit.py before applying.
diff --git a/services/hwfit/fit.py b/services/hwfit/fit.py
index fb0e006..6dd3b89 100644
--- a/services/hwfit/fit.py
+++ b/services/hwfit/fit.py
@@ -280,10 +280,14 @@ def _native_quant(model):
         return "FP8"
     if "gptq" in text:
         m = re.search(r"(?:gptq|int|w)(?:[-_]?)(\d{1,2})(?:bit)?", text)
-        return f"GPTQ-{m.group(1)}bit" if m else "GPTQ"
+        # Canonical catalog label is "GPTQ-Int4"/"GPTQ-Int8" (see models.py
+        # QUANT_BPP / QUANT_QUALITY_PENALTY keys); "GPTQ-4bit" misses both
+        # maps, so BPP and the quality penalty silently fall to defaults.
+        return f"GPTQ-Int{m.group(1)}" if m else "GPTQ-Int4"
     if "awq" in text:
         m = re.search(r"(?:awq|int|w)(?:[-_]?)(\d{1,2})(?:bit)?", text)
-        return f"AWQ-{m.group(1)}bit" if m else "AWQ"
+        # Catalog keys are "AWQ-4bit"/"AWQ-8bit"; bare "AWQ" misses the maps.
+        return f"AWQ-{m.group(1)}bit" if m else "AWQ-4bit"
     if "mlx" in text:
         m = re.search(r"mlx[-_]?(\d{1,2})bit", text)
         return f"mlx-{m.group(1)}bit" if m else native_quant
diff --git a/tests/test_hwfit_native_quant_labels.py b/tests/test_hwfit_native_quant_labels.py
new file mode 100644
index 0000000..c73f979
--- /dev/null
+++ b/tests/test_hwfit_native_quant_labels.py
@@ -0,0 +1,42 @@
+"""_native_quant must emit canonical quant labels that key the cost maps.
+
+services/hwfit/models.py keys QUANT_BPP and QUANT_QUALITY_PENALTY on
+"GPTQ-Int4"/"GPTQ-Int8" and "AWQ-4bit"/"AWQ-8bit". _native_quant returned
+"GPTQ-4bit" (and bare "AWQ" when no digit), which miss both maps, so a
+pre-quantized GPTQ/AWQ model fell back to the default BPP (0.58 instead of
+0.50) and a zero quality penalty, over-estimating VRAM and inflating the
+score. The label is also shown in the UI and disagreed with the catalog.
+"""
+from services.hwfit.fit import _native_quant
+from services.hwfit.models import QUANT_BPP, QUANT_QUALITY_PENALTY
+
+
+def test_gptq_int4_label_is_canonical():
+    label = _native_quant({"name": "Qwen2.5-32B-Instruct-GPTQ-Int4"})
+    assert label == "GPTQ-Int4"
+    assert label in QUANT_BPP and label in QUANT_QUALITY_PENALTY
+
+
+def test_gptq_int8_label_is_canonical():
+    label = _native_quant({"name": "x-GPTQ-Int8"})
+    assert label == "GPTQ-Int8"
+    assert label in QUANT_BPP and label in QUANT_QUALITY_PENALTY
+
+
+def test_awq_no_digit_falls_back_to_canonical():
+    label = _native_quant({"name": "SomeModel-AWQ"})
+    assert label == "AWQ-4bit"
+    assert label in QUANT_BPP and label in QUANT_QUALITY_PENALTY
+
+
+def test_awq_with_digit_is_canonical():
+    label = _native_quant({"name": "x-AWQ-8bit"})
+    assert label == "AWQ-8bit"
+    assert label in QUANT_BPP and label in QUANT_QUALITY_PENALTY
+
+
+def test_gptq_fallback_label_is_in_maps():
+    # GPTQ mentioned with no parseable bit-width
+    label = _native_quant({"name": "model-gptq", "format": ""})
+    assert label == "GPTQ-Int4"
+    assert label in QUANT_BPP and label in QUANT_QUALITY_PENALTY