diff --git a/src/heretic/main.py b/src/heretic/main.py index b99b1ac..20895b9 100644 --- a/src/heretic/main.py +++ b/src/heretic/main.py @@ -69,6 +69,7 @@ from .reproduce import collect_reproducibles from .system import empty_cache, get_accelerator_info from .utils import ( format_duration, + format_exception, get_readme_intro, get_trial_parameters, is_hf_path, @@ -364,7 +365,11 @@ def run(): # We cannot recover from this. raise - print(f"[red]Failed[/] ({error})") + formatted = format_exception(error) + if "\n" in formatted: + print(f"[red]Failed[/]:\n{formatted}") + else: + print(f"[red]Failed[/] ({formatted})") break response_lengths = [ @@ -1120,7 +1125,11 @@ def run(): print(table) except Exception as error: - print(f"[red]Error: {error}[/]") + formatted = format_exception(error) + if "\n" in formatted: + print(f"[red]Error:[/]\n{formatted}") + else: + print(f"[red]Error: {formatted}[/]") def main(): diff --git a/src/heretic/model.py b/src/heretic/model.py index 92eb98c..cb4c103 100644 --- a/src/heretic/model.py +++ b/src/heretic/model.py @@ -33,7 +33,7 @@ from transformers.generation import ( from .config import QuantizationMethod, RowNormalization, Settings from .system import empty_cache -from .utils import Prompt, batchify, print +from .utils import Prompt, batchify, format_exception, print def get_model_class( @@ -150,7 +150,11 @@ class Model: except Exception as error: self.model = None # ty:ignore[invalid-assignment] empty_cache() - print(f"* [red]Failed[/] ({error})") + formatted = format_exception(error) + if "\n" in formatted: + print(f"* [red]Failed[/]:\n{formatted}") + else: + print(f"* [red]Failed[/] ({formatted})") continue if settings.quantization == QuantizationMethod.BNB_4BIT: diff --git a/src/heretic/utils.py b/src/heretic/utils.py index 778d52e..3d2d788 100644 --- a/src/heretic/utils.py +++ b/src/heretic/utils.py @@ -7,6 +7,7 @@ import os import platform import random import tempfile +import traceback from dataclasses import dataclass from datetime import datetime, timezone from importlib.metadata import version @@ -746,3 +747,16 @@ def upload_reproduce_folder( repo_id=repo_id, token=token, ) + + +def format_exception(error: Exception) -> str: + # Walk causal chain to find a non-empty message. + current = error + while current is not None: + message = str(current).strip() + if message: + return message + current = current.__cause__ or current.__context__ + + # If there is no message in the entire causal chain, fall back to the complete traceback. + return traceback.format_exc().strip()