diff --git a/config.default.toml b/config.default.toml index 16fb39d..5becae6 100644 --- a/config.default.toml +++ b/config.default.toml @@ -1,4 +1,4 @@ -dtypes = ["float32", "float16"] +dtypes = ["float32", "float16", "bfloat16"] device_map = "auto" diff --git a/src/heretic/model.py b/src/heretic/model.py index 5a3169b..ffb582e 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -47,6 +47,7 @@ class Model: self.generate([settings.test_prompt], max_new_tokens=1) except Exception as error: self.model = None + empty_cache() print(f"[red]Failed[/] ({error})") continue