Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -7,4 +7,4 @@ shscript: run.sh
copy_all_files: true
models_to_track:
- model
model_tracker_style: proxy
model_tracker_style: proxy # [Optional] "proxy" (default), "subclass", or "sampler".
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ pyscript: mnist.py # The Python entry point of your training program.
shscript: run.sh # [Optional] Shell script to launch the program with custom arguments or environment setup.
models_to_track: # [Optional] List of variable names for models you want to track. If omitted, model tracking is disabled.
- model
model_tracker_style: proxy # [Optional] Method for model tracking. Choose between "proxy" (default) or "sampler".
model_tracker_style: proxy # [Optional] Method for model tracking. Choose between "proxy" (default), "subclass", or "sampler".
copy_all_files: false # [Optional] Set to true if your code uses relative paths (e.g., loading local datasets or configs).
# This ensures TrainCheck copies the entire working directory before execution.
# Note: TrainCheck automatically handles PYTHONPATH. Default is false.
# Note: TrainCheck automatically handles PYTHONPATH. Default is false.
4 changes: 2 additions & 2 deletions docs/instr.md
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ modules_to_instr: # Libraries to instrument. Defaults to ['torch'] if
- torch
models_to_track: # [Optional] Variable names of models to track. Leave empty to disable model tracking.
- model
model_tracker_style: proxy # [Optional] Tracking method: "proxy" (default) or "sampler".
model_tracker_style: proxy # [Optional] Tracking method: "proxy" (default), "subclass", or "sampler".
copy_all_files: false # [Optional] Set true if your code relies on relative paths (e.g., local datasets/configs).
```
Expand Down Expand Up @@ -140,4 +140,4 @@ Instructions for defining and injecting meta variables into traces will be provi

## Instrumentation Mechanisms
📌 **[To Be Documented]**
Details about TrainCheck’s instrumentation strategies, supported APIs, and limitations will be covered here later.
Details about TrainCheck’s instrumentation strategies, supported APIs, and limitations will be covered here later.
2 changes: 1 addition & 1 deletion traincheck/checker.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,7 +153,7 @@ def main():
trace_parent_folders = []
if args.traces is not None:
logger.info("Reading traces from %s", "\n".join(args.traces))
trace_parent_folders = [os.path.basename(os.path.commonpath(args.traces[0]))]
trace_parent_folders = [os.path.basename(os.path.commonpath(args.traces))]
traces.append(read_trace_file(args.traces))
if args.trace_folders is not None:
for trace_folder in args.trace_folders:
Expand Down
11 changes: 9 additions & 2 deletions traincheck/collect_trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -350,7 +350,7 @@ def main():
parser.add_argument(
"--model-tracker-style",
type=str,
choices=["sampler", "proxy"],
choices=["sampler", "proxy", "subclass"],
default="proxy",
)
parser.add_argument(
Expand All @@ -371,6 +371,11 @@ def main():
action="store_true",
help="Disable automatic variable instrumentation, necessary when the default behavior of the instrumentor is not desired (e.g. cause segmentation fault)",
)
parser.add_argument(
"--use-torch-compile",
action="store_true",
help="Indicate wthether use torch.compile to speed the model, necessary to realize compatibility",
)

args = parser.parse_args()

Expand Down Expand Up @@ -444,7 +449,7 @@ def main():
scan_proxy_in_args = not args.disable_scan_proxy_in_args

# if no proxy tracking specified in the arguments, disable the scan_proxy_in_args
if not args.models_to_track or args.model_tracker_style != "proxy":
if not args.models_to_track or args.model_tracker_style == "sampler":
scan_proxy_in_args = False

if args.invariants:
Expand Down Expand Up @@ -481,6 +486,7 @@ def main():
output_dir=output_dir,
instr_descriptors=args.instr_descriptors,
no_auto_var_instr=args.no_auto_var_instr,
use_torch_compile=args.use_torch_compile,
)
else:
source_code = instrumentor.instrument_file(
Expand All @@ -496,6 +502,7 @@ def main():
output_dir=output_dir,
instr_descriptors=args.instr_descriptors,
no_auto_var_instr=args.no_auto_var_instr,
use_torch_compile=args.use_torch_compile,
)

if args.copy_all_files:
Expand Down
14 changes: 13 additions & 1 deletion traincheck/config/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -89,6 +89,7 @@
]

INSTR_OPTS = None # TODO: set defaults for this variable
MODEL_TRACKER_STYLE: str | None = None

# var dumper related error-backoff configs
TYPE_ERR_THRESHOLD = 3
Expand All @@ -105,8 +106,9 @@ def __init__(
assert model_tracker_style in [
"sampler",
"proxy",
"subclass",
None,
], "model_tracker_style should be one of ['sampler', 'proxy', None]"
], "model_tracker_style should be one of ['sampler', 'proxy', 'subclass', None]"

self.funcs_instr_opts: dict[str, dict[str, bool | dict]] = func_instr_opts
self.model_tracker_style = model_tracker_style
Expand Down Expand Up @@ -238,6 +240,7 @@ def should_disable_proxy_dumping() -> bool:


INSTR_DESCRIPTORS = False
USE_TORCH_COMPILE = False

ALL_STAGE_NAMES = {
"init",
Expand All @@ -249,3 +252,12 @@ def should_disable_proxy_dumping() -> bool:
"preprocessing",
"postprocessing",
}

COMPILE_INTERNAL_MODULE = (
"torch.fx",
# "torch._dynamo",
"torch._inductor",
"torch._subclasses",
"torch._higher_order_ops",
"torch.utils._sympy",
)
85 changes: 76 additions & 9 deletions traincheck/instrumentor/dumper.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,12 +18,14 @@

# if torch.cuda.is_available():
from traincheck.proxy_wrapper.hash import tensor_hash
from traincheck.proxy_wrapper.proxy_basics import is_fake_tensor
from traincheck.proxy_wrapper.proxy_config import (
attribute_black_list,
primitive_types,
proxy_attribute,
tensor_dump_format,
)
from traincheck.utils import get_timestamp_ns, typename
from traincheck.utils import get_timestamp_ns, typename, typename_compile

DEBUG = os.environ.get("ML_DAIKON_DEBUG", False)
THREAD_DATA = threading.local()
Expand All @@ -44,12 +46,48 @@
logger = logging.getLogger(__name__)


def _json_default(o):
try:
if type(o).__name__ in ("SymInt", "SymFloat", "SymBool"):
return str(o)

if isinstance(o, torch.device):
return str(o)
if isinstance(o, torch.dtype):
return str(o)
if isinstance(o, torch.Size):
out = []
for d in o:
try:
out.append(int(d))
except Exception:
out.append(str(d))
return out
except Exception:
pass

if isinstance(o, set):
return list(o)
if isinstance(o, tuple):
return list(o)

try:
import numpy as np

if isinstance(o, (np.generic,)):
return o.item()
except Exception:
pass

return repr(o)


def serialize(obj_dict: dict[str, object | str]) -> str:
try:
return orjson.dumps(obj_dict).decode("utf-8")
return orjson.dumps(obj_dict, default=_json_default).decode("utf-8")
except Exception:
# if orjson fails (e.g. cannot handle ints larger than 64-bit), fallback to json
return json.dumps(obj_dict)
return json.dumps(obj_dict, default=_json_default)


def monitor_main_thread(main_thread, stop_event):
Expand Down Expand Up @@ -335,6 +373,9 @@ def convert_var_to_dict(var, include_tensor_data=True, dump_config=None) -> dict
):
continue

if attr_name in proxy_attribute:
continue

if attr_name in attribute_black_list:
continue

Expand All @@ -346,12 +387,17 @@ def convert_var_to_dict(var, include_tensor_data=True, dump_config=None) -> dict

attr = safe_getattr(var, attr_name)
if attr is NOT_FOUND:
logger.warning(
f"Failed to get attribute {attr_name} of object type {type(var)}, skipping it for all following dumps for this attribute."
)
if var_type not in skip_attrs_due_to_errs:
skip_attrs_due_to_errs[var_type] = set()
skip_attrs_due_to_errs[var_type].add(attr_name)
if not (
attr_name == "data"
and isinstance(var, torch.Tensor)
and not include_tensor_data
):
logger.warning(
f"Failed to get attribute {attr_name} of object type {type(var)}, skipping it for all following dumps for this attribute."
)
if var_type not in skip_attrs_due_to_errs:
skip_attrs_due_to_errs[var_type] = set()
skip_attrs_due_to_errs[var_type].add(attr_name)
continue

attr_name = str(attr_name)
Expand Down Expand Up @@ -395,7 +441,25 @@ def convert_var_to_dict(var, include_tensor_data=True, dump_config=None) -> dict
return result


def convert_fake_tensor_to_dict(var):
try:
shape = tuple(var.shape)
except Exception:
shape = None
try:
dtype = str(var.dtype)
except Exception:
dtype = None
return {
"fake": True,
"shape": shape,
"dtype": dtype,
}


def obj_to_serializable(obj, dump_config=None) -> dict[str, object]:
if is_fake_tensor(obj):
return {typename_compile(obj): convert_fake_tensor_to_dict(obj)}
if (
type(obj) in skip_type_due_to_recursion
and skip_type_due_to_recursion[type(obj)] > RECURSION_ERR_THRESHOLD
Expand Down Expand Up @@ -429,6 +493,9 @@ def var_to_serializable(obj, dump_config=None) -> dict[str, object]:
If you want to dump the `data` attribute of a tensor, use `convert_var_to_dict` and set `include_tensor_data=True`.
"""

if is_fake_tensor(obj):
return {typename_compile(obj): convert_fake_tensor_to_dict(obj)}

if issubclass(type(obj), dict) and type(obj) != dict: # noqa E721
return obj_to_serializable(obj, dump_config=dump_config)

Expand Down
34 changes: 28 additions & 6 deletions traincheck/instrumentor/source_file.py
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,7 @@ def get_child_parent_map(root) -> dict[ast.AST, ast.AST]:


def instrument_all_model_assignments(
source_code: str, model_name: str, mode: str
source_code: str, model_name: str, mode: str | None
) -> str:
"""
Finds all assignment statements to `model` and inserts a Proxy statement or a VarSampler statement
Expand All @@ -292,8 +292,15 @@ def instrument_all_model_assignments(
instr_statement = ast.parse(
f"{model_name}_sampler = VarSampler({model_name}, var_name='{model_name}')"
)
elif mode == "subclass":
instr_statement = ast.parse(
f"proxy_parameter({model_name}, logdir=proxy_config.proxy_log_dir, parent_name='{model_name}')"
)

else:
raise ValueError(f"Invalid mode: {mode}. Must be one of ['proxy', 'sampler']")
raise ValueError(
f"Invalid mode: {mode}. Must be one of ['proxy', 'sampler', 'subclass']"
)

# find all assignment statements to `model`
assignments = []
Expand Down Expand Up @@ -348,6 +355,7 @@ def instrument_model_tracker_proxy(
models_to_track: list[str],
adjusted_proxy_config: list[dict[str, int | bool | str]],
no_auto_var_instr: bool,
model_tracker_style: str | None,
):
auto_observer_config: dict[str, int | bool | str] = adjusted_proxy_config[0]
proxy_basic_config: dict[str, int | bool | str] = adjusted_proxy_config[1]
Expand All @@ -373,8 +381,13 @@ def instrument_model_tracker_proxy(
tensor_dump_format.update({tensor_dump_format})
"""

proxy_start_code += """
if model_tracker_style == "proxy":
proxy_start_code += """
from traincheck.proxy_wrapper.proxy import Proxy
"""
else:
proxy_start_code += """
from traincheck.proxy_wrapper.subclass import proxy_parameter
"""

if auto_observer_config["enable_auto_observer"]:
Expand Down Expand Up @@ -435,7 +448,7 @@ def instrument_model_tracker_proxy(
if not no_auto_var_instr:
for model in models_to_track:
instrumented_source = instrument_all_model_assignments(
instrumented_source, model, "proxy"
instrumented_source, model, model_tracker_style
)

code_head, code_tail = get_code_head_and_tail(instrumented_source)
Expand Down Expand Up @@ -797,6 +810,7 @@ def instrument_file(
output_dir: str,
instr_descriptors: bool,
no_auto_var_instr: bool,
use_torch_compile: bool,
) -> str:
"""
Instruments the given file and returns the instrumented source code.
Expand Down Expand Up @@ -833,20 +847,28 @@ def instrument_file(
general_config_update = f"""
import traincheck.config.config as general_config
general_config.INSTR_DESCRIPTORS = {instr_descriptors}
general_config.MODEL_TRACKER_STYLE = {model_tracker_style!r}
"""
if use_torch_compile:
torch_compile_config_update = """
general_config.USE_TORCH_COMPILE = True
"""
general_config_update = general_config_update + torch_compile_config_update
# TODO: move the INSTR_DESCRIPTORS to the instr_opts file

if models_to_track:
assert model_tracker_style in [
"proxy",
"sampler",
], f"Invalid model tracker style: {model_tracker_style}, must be one of ['proxy', 'sampler']"
if model_tracker_style == "proxy":
"subclass",
], f"Invalid model tracker style: {model_tracker_style}, must be one of ['proxy', 'sampler', 'subclass']"
if model_tracker_style == "proxy" or model_tracker_style == "subclass":
instrumented_source = instrument_model_tracker_proxy(
instrumented_source,
models_to_track,
adjusted_proxy_config,
no_auto_var_instr,
model_tracker_style,
)
else:
instrumented_source = instrument_model_tracker_sampler(
Expand Down
Loading