diff --git a/src/heretic/main.py b/src/heretic/main.py index fcc7e3d..1e5cad0 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -966,6 +966,7 @@ def run(): hflm = HFLM( pretrained=model.model, # ty:ignore[invalid-argument-type] tokenizer=model.tokenizer, # ty:ignore[invalid-argument-type] + batch_size="auto", ) table = Table() @@ -989,7 +990,6 @@ def run(): results = lm_eval.simple_evaluate( model=hflm, tasks=[benchmark.task], - batch_size="auto", ) return results["results"][benchmark.task]