diff --git a/docs/assets/examples/traincheck-collect/gpt2-pretrain-config/config.yml b/docs/assets/examples/traincheck-collect/gpt2-pretrain-config/config.yml index 628b21b3..a014282d 100644 --- a/docs/assets/examples/traincheck-collect/gpt2-pretrain-config/config.yml +++ b/docs/assets/examples/traincheck-collect/gpt2-pretrain-config/config.yml @@ -7,4 +7,4 @@ shscript: run.sh copy_all_files: true models_to_track: - model -model_tracker_style: proxy \ No newline at end of file +model_tracker_style: proxy # [Optional] "proxy" (default), "subclass", or "sampler". diff --git a/docs/assets/examples/traincheck-collect/mnist-config/config.yml b/docs/assets/examples/traincheck-collect/mnist-config/config.yml index 34119a94..4aa57483 100644 --- a/docs/assets/examples/traincheck-collect/mnist-config/config.yml +++ b/docs/assets/examples/traincheck-collect/mnist-config/config.yml @@ -4,7 +4,7 @@ pyscript: mnist.py # The Python entry point of your training program. shscript: run.sh # [Optional] Shell script to launch the program with custom arguments or environment setup. models_to_track: # [Optional] List of variable names for models you want to track. If omitted, model tracking is disabled. - model -model_tracker_style: proxy # [Optional] Method for model tracking. Choose between "proxy" (default) or "sampler". +model_tracker_style: proxy # [Optional] Method for model tracking. Choose between "proxy" (default), "subclass", or "sampler". copy_all_files: false # [Optional] Set to true if your code uses relative paths (e.g., loading local datasets or configs). # This ensures TrainCheck copies the entire working directory before execution. - # Note: TrainCheck automatically handles PYTHONPATH. Default is false. \ No newline at end of file + # Note: TrainCheck automatically handles PYTHONPATH. Default is false. diff --git a/docs/instr.md b/docs/instr.md index d2781316..6f9be6e7 100644 --- a/docs/instr.md +++ b/docs/instr.md @@ -48,7 +48,7 @@ modules_to_instr: # Libraries to instrument. Defaults to ['torch'] if - torch models_to_track: # [Optional] Variable names of models to track. Leave empty to disable model tracking. - model -model_tracker_style: proxy # [Optional] Tracking method: "proxy" (default) or "sampler". +model_tracker_style: proxy # [Optional] Tracking method: "proxy" (default), "subclass", or "sampler". copy_all_files: false # [Optional] Set true if your code relies on relative paths (e.g., local datasets/configs). ``` @@ -140,4 +140,4 @@ Instructions for defining and injecting meta variables into traces will be provi ## Instrumentation Mechanisms 📌 **[To Be Documented]** -Details about TrainCheck’s instrumentation strategies, supported APIs, and limitations will be covered here later. \ No newline at end of file +Details about TrainCheck’s instrumentation strategies, supported APIs, and limitations will be covered here later. diff --git a/traincheck/checker.py b/traincheck/checker.py index 62045468..dd8816f5 100644 --- a/traincheck/checker.py +++ b/traincheck/checker.py @@ -153,7 +153,7 @@ def main(): trace_parent_folders = [] if args.traces is not None: logger.info("Reading traces from %s", "\n".join(args.traces)) - trace_parent_folders = [os.path.basename(os.path.commonpath(args.traces[0]))] + trace_parent_folders = [os.path.basename(os.path.commonpath(args.traces))] traces.append(read_trace_file(args.traces)) if args.trace_folders is not None: for trace_folder in args.trace_folders: diff --git a/traincheck/collect_trace.py b/traincheck/collect_trace.py index 48bcfe87..854735f4 100644 --- a/traincheck/collect_trace.py +++ b/traincheck/collect_trace.py @@ -350,7 +350,7 @@ def main(): parser.add_argument( "--model-tracker-style", type=str, - choices=["sampler", "proxy"], + choices=["sampler", "proxy", "subclass"], default="proxy", ) parser.add_argument( @@ -371,6 +371,11 @@ def main(): action="store_true", help="Disable automatic variable instrumentation, necessary when the default behavior of the instrumentor is not desired (e.g. cause segmentation fault)", ) + parser.add_argument( + "--use-torch-compile", + action="store_true", + help="Indicate wthether use torch.compile to speed the model, necessary to realize compatibility", + ) args = parser.parse_args() @@ -444,7 +449,7 @@ def main(): scan_proxy_in_args = not args.disable_scan_proxy_in_args # if no proxy tracking specified in the arguments, disable the scan_proxy_in_args - if not args.models_to_track or args.model_tracker_style != "proxy": + if not args.models_to_track or args.model_tracker_style == "sampler": scan_proxy_in_args = False if args.invariants: @@ -481,6 +486,7 @@ def main(): output_dir=output_dir, instr_descriptors=args.instr_descriptors, no_auto_var_instr=args.no_auto_var_instr, + use_torch_compile=args.use_torch_compile, ) else: source_code = instrumentor.instrument_file( @@ -496,6 +502,7 @@ def main(): output_dir=output_dir, instr_descriptors=args.instr_descriptors, no_auto_var_instr=args.no_auto_var_instr, + use_torch_compile=args.use_torch_compile, ) if args.copy_all_files: diff --git a/traincheck/config/config.py b/traincheck/config/config.py index 51c457af..cc8e5e2c 100644 --- a/traincheck/config/config.py +++ b/traincheck/config/config.py @@ -89,6 +89,7 @@ ] INSTR_OPTS = None # TODO: set defaults for this variable +MODEL_TRACKER_STYLE: str | None = None # var dumper related error-backoff configs TYPE_ERR_THRESHOLD = 3 @@ -105,8 +106,9 @@ def __init__( assert model_tracker_style in [ "sampler", "proxy", + "subclass", None, - ], "model_tracker_style should be one of ['sampler', 'proxy', None]" + ], "model_tracker_style should be one of ['sampler', 'proxy', 'subclass', None]" self.funcs_instr_opts: dict[str, dict[str, bool | dict]] = func_instr_opts self.model_tracker_style = model_tracker_style @@ -238,6 +240,7 @@ def should_disable_proxy_dumping() -> bool: INSTR_DESCRIPTORS = False +USE_TORCH_COMPILE = False ALL_STAGE_NAMES = { "init", @@ -249,3 +252,12 @@ def should_disable_proxy_dumping() -> bool: "preprocessing", "postprocessing", } + +COMPILE_INTERNAL_MODULE = ( + "torch.fx", + # "torch._dynamo", + "torch._inductor", + "torch._subclasses", + "torch._higher_order_ops", + "torch.utils._sympy", +) diff --git a/traincheck/instrumentor/dumper.py b/traincheck/instrumentor/dumper.py index f6bf03fc..04935e8a 100644 --- a/traincheck/instrumentor/dumper.py +++ b/traincheck/instrumentor/dumper.py @@ -18,12 +18,14 @@ # if torch.cuda.is_available(): from traincheck.proxy_wrapper.hash import tensor_hash +from traincheck.proxy_wrapper.proxy_basics import is_fake_tensor from traincheck.proxy_wrapper.proxy_config import ( attribute_black_list, primitive_types, + proxy_attribute, tensor_dump_format, ) -from traincheck.utils import get_timestamp_ns, typename +from traincheck.utils import get_timestamp_ns, typename, typename_compile DEBUG = os.environ.get("ML_DAIKON_DEBUG", False) THREAD_DATA = threading.local() @@ -44,12 +46,48 @@ logger = logging.getLogger(__name__) +def _json_default(o): + try: + if type(o).__name__ in ("SymInt", "SymFloat", "SymBool"): + return str(o) + + if isinstance(o, torch.device): + return str(o) + if isinstance(o, torch.dtype): + return str(o) + if isinstance(o, torch.Size): + out = [] + for d in o: + try: + out.append(int(d)) + except Exception: + out.append(str(d)) + return out + except Exception: + pass + + if isinstance(o, set): + return list(o) + if isinstance(o, tuple): + return list(o) + + try: + import numpy as np + + if isinstance(o, (np.generic,)): + return o.item() + except Exception: + pass + + return repr(o) + + def serialize(obj_dict: dict[str, object | str]) -> str: try: - return orjson.dumps(obj_dict).decode("utf-8") + return orjson.dumps(obj_dict, default=_json_default).decode("utf-8") except Exception: # if orjson fails (e.g. cannot handle ints larger than 64-bit), fallback to json - return json.dumps(obj_dict) + return json.dumps(obj_dict, default=_json_default) def monitor_main_thread(main_thread, stop_event): @@ -335,6 +373,9 @@ def convert_var_to_dict(var, include_tensor_data=True, dump_config=None) -> dict ): continue + if attr_name in proxy_attribute: + continue + if attr_name in attribute_black_list: continue @@ -346,12 +387,17 @@ def convert_var_to_dict(var, include_tensor_data=True, dump_config=None) -> dict attr = safe_getattr(var, attr_name) if attr is NOT_FOUND: - logger.warning( - f"Failed to get attribute {attr_name} of object type {type(var)}, skipping it for all following dumps for this attribute." - ) - if var_type not in skip_attrs_due_to_errs: - skip_attrs_due_to_errs[var_type] = set() - skip_attrs_due_to_errs[var_type].add(attr_name) + if not ( + attr_name == "data" + and isinstance(var, torch.Tensor) + and not include_tensor_data + ): + logger.warning( + f"Failed to get attribute {attr_name} of object type {type(var)}, skipping it for all following dumps for this attribute." + ) + if var_type not in skip_attrs_due_to_errs: + skip_attrs_due_to_errs[var_type] = set() + skip_attrs_due_to_errs[var_type].add(attr_name) continue attr_name = str(attr_name) @@ -395,7 +441,25 @@ def convert_var_to_dict(var, include_tensor_data=True, dump_config=None) -> dict return result +def convert_fake_tensor_to_dict(var): + try: + shape = tuple(var.shape) + except Exception: + shape = None + try: + dtype = str(var.dtype) + except Exception: + dtype = None + return { + "fake": True, + "shape": shape, + "dtype": dtype, + } + + def obj_to_serializable(obj, dump_config=None) -> dict[str, object]: + if is_fake_tensor(obj): + return {typename_compile(obj): convert_fake_tensor_to_dict(obj)} if ( type(obj) in skip_type_due_to_recursion and skip_type_due_to_recursion[type(obj)] > RECURSION_ERR_THRESHOLD @@ -429,6 +493,9 @@ def var_to_serializable(obj, dump_config=None) -> dict[str, object]: If you want to dump the `data` attribute of a tensor, use `convert_var_to_dict` and set `include_tensor_data=True`. """ + if is_fake_tensor(obj): + return {typename_compile(obj): convert_fake_tensor_to_dict(obj)} + if issubclass(type(obj), dict) and type(obj) != dict: # noqa E721 return obj_to_serializable(obj, dump_config=dump_config) diff --git a/traincheck/instrumentor/source_file.py b/traincheck/instrumentor/source_file.py index 4de57416..d0279fff 100644 --- a/traincheck/instrumentor/source_file.py +++ b/traincheck/instrumentor/source_file.py @@ -271,7 +271,7 @@ def get_child_parent_map(root) -> dict[ast.AST, ast.AST]: def instrument_all_model_assignments( - source_code: str, model_name: str, mode: str + source_code: str, model_name: str, mode: str | None ) -> str: """ Finds all assignment statements to `model` and inserts a Proxy statement or a VarSampler statement @@ -292,8 +292,15 @@ def instrument_all_model_assignments( instr_statement = ast.parse( f"{model_name}_sampler = VarSampler({model_name}, var_name='{model_name}')" ) + elif mode == "subclass": + instr_statement = ast.parse( + f"proxy_parameter({model_name}, logdir=proxy_config.proxy_log_dir, parent_name='{model_name}')" + ) + else: - raise ValueError(f"Invalid mode: {mode}. Must be one of ['proxy', 'sampler']") + raise ValueError( + f"Invalid mode: {mode}. Must be one of ['proxy', 'sampler', 'subclass']" + ) # find all assignment statements to `model` assignments = [] @@ -348,6 +355,7 @@ def instrument_model_tracker_proxy( models_to_track: list[str], adjusted_proxy_config: list[dict[str, int | bool | str]], no_auto_var_instr: bool, + model_tracker_style: str | None, ): auto_observer_config: dict[str, int | bool | str] = adjusted_proxy_config[0] proxy_basic_config: dict[str, int | bool | str] = adjusted_proxy_config[1] @@ -373,8 +381,13 @@ def instrument_model_tracker_proxy( tensor_dump_format.update({tensor_dump_format}) """ - proxy_start_code += """ + if model_tracker_style == "proxy": + proxy_start_code += """ from traincheck.proxy_wrapper.proxy import Proxy +""" + else: + proxy_start_code += """ +from traincheck.proxy_wrapper.subclass import proxy_parameter """ if auto_observer_config["enable_auto_observer"]: @@ -435,7 +448,7 @@ def instrument_model_tracker_proxy( if not no_auto_var_instr: for model in models_to_track: instrumented_source = instrument_all_model_assignments( - instrumented_source, model, "proxy" + instrumented_source, model, model_tracker_style ) code_head, code_tail = get_code_head_and_tail(instrumented_source) @@ -797,6 +810,7 @@ def instrument_file( output_dir: str, instr_descriptors: bool, no_auto_var_instr: bool, + use_torch_compile: bool, ) -> str: """ Instruments the given file and returns the instrumented source code. @@ -833,20 +847,28 @@ def instrument_file( general_config_update = f""" import traincheck.config.config as general_config general_config.INSTR_DESCRIPTORS = {instr_descriptors} +general_config.MODEL_TRACKER_STYLE = {model_tracker_style!r} +""" + if use_torch_compile: + torch_compile_config_update = """ +general_config.USE_TORCH_COMPILE = True """ + general_config_update = general_config_update + torch_compile_config_update # TODO: move the INSTR_DESCRIPTORS to the instr_opts file if models_to_track: assert model_tracker_style in [ "proxy", "sampler", - ], f"Invalid model tracker style: {model_tracker_style}, must be one of ['proxy', 'sampler']" - if model_tracker_style == "proxy": + "subclass", + ], f"Invalid model tracker style: {model_tracker_style}, must be one of ['proxy', 'sampler', 'subclass']" + if model_tracker_style == "proxy" or model_tracker_style == "subclass": instrumented_source = instrument_model_tracker_proxy( instrumented_source, models_to_track, adjusted_proxy_config, no_auto_var_instr, + model_tracker_style, ) else: instrumented_source = instrument_model_tracker_sampler( diff --git a/traincheck/instrumentor/tracer.py b/traincheck/instrumentor/tracer.py index cf28785a..feb63f4f 100644 --- a/traincheck/instrumentor/tracer.py +++ b/traincheck/instrumentor/tracer.py @@ -29,7 +29,11 @@ funcs_to_be_replaced, is_funcs_to_be_unproxied, ) -from traincheck.proxy_wrapper.proxy_basics import is_proxied, unproxy_func +from traincheck.proxy_wrapper.proxy_basics import ( + is_proxied, + is_proxyparameter, + unproxy_func, +) from traincheck.proxy_wrapper.proxy_config import enable_C_level_observer from traincheck.proxy_wrapper.proxy_registry import get_global_registry from traincheck.utils import get_timestamp_ns, get_unique_id, typename @@ -136,20 +140,16 @@ def to_dict_return_value(result) -> dict | list[dict]: return result_dict -def global_wrapper( +def global_wrapper_subclass( original_function: Callable, original_function_name: str, is_bound_method: bool, - is_builtin: bool, scan_proxy_in_args: bool, dump_stack_trace: bool, dump_args: bool, dump_args_config, dump_ret: bool, dump_ret_config, - handle_proxy: bool, - trigger_proxy_state_dump: bool, - proxy_state_dump_config: dict, *args, **kwargs, ): @@ -157,8 +157,6 @@ def global_wrapper( Pre-call Phase 1. Log the pre-call information - 2. Unproxy the arguments if the function is a C level function -- Proxy objects passed to built-in functions will cause segfault - 3. Add additional 'observer' (monitoring whether the input arguments have changed after the function call) to the function if specified Call Phase 1. Calls the original function @@ -168,6 +166,216 @@ def global_wrapper( 1. Log the post-call information """ + global DISABLE_WRAPPER + global PROCESS_ID + + if DISABLE_WRAPPER: + return original_function(*args, **kwargs) + + if COLLECT_OVERHEAD_METRICS: + ENTER_PERF_TIME = time.perf_counter() + + func_call_id = get_unique_id() + process_id, thread_id = get_process_thread_id() + increment_step_if_needed( + original_function, original_function_name, is_bound_method, args + ) + + pre_meta_vars = get_meta_vars() + + if IS_INSTRUMENTING: + return original_function( + *args, **kwargs + ) # don't instrument while instrumenting + + pre_record = { + "func_call_id": func_call_id, + "thread_id": thread_id, + "process_id": process_id, + "meta_vars": pre_meta_vars, + "type": TraceLineType.FUNC_CALL_PRE, + "function": original_function_name, + "is_bound_method": is_bound_method, + "obj_id": None if not is_bound_method else id(args[0]), + } + + if dump_stack_trace: + pre_record["stack_trace"] = traceback.format_stack() + + if scan_proxy_in_args: + proxy_in_args = [] + + def find_proxy_in_args(args): + for i, arg in enumerate(args): + if is_proxied(arg) or is_proxyparameter(arg): + proxy_in_args.append(arg) + elif type(arg) in [list, tuple]: + find_proxy_in_args(arg) + elif isinstance(arg, types.GeneratorType) and not isinstance( + arg, tuple + ): + arg_list = list(arg) + args[i] = iter(arg_list) + find_proxy_in_args(arg_list) + + args = list(args) # type: ignore[assignment] + find_proxy_in_args(args) + args = tuple(args) + + if proxy_in_args: + if "proxy_obj_names" not in pre_record: + pre_record["proxy_obj_names"] = [] + for proxy in proxy_in_args: + if is_proxyparameter(proxy): + pre_record["proxy_obj_names"].append( + [proxy.__dict__["var_name"], "Parameter"] + ) + else: + pre_record["proxy_obj_names"].append( + [proxy.__dict__["var_name"], type(proxy._obj).__name__] + ) + if dump_args: + dict_args_kwargs = to_dict_args_kwargs(args, kwargs, dump_args_config) + pre_record["args"] = dict_args_kwargs["args"] + pre_record["kwargs"] = dict_args_kwargs["kwargs"] + dump_trace_API(pre_record) + + try: + if COLLECT_OVERHEAD_METRICS: + ORIG_ENTER_PERF_TIME = time.perf_counter() + result = original_function(*args, **kwargs) + if COLLECT_OVERHEAD_METRICS: + ORIG_EXIT_PERF_TIME = time.perf_counter() + except Exception as e: + if COLLECT_OVERHEAD_METRICS: + ORIG_EXIT_PERF_TIME = time.perf_counter() + + dump_trace_API( + { + "func_call_id": func_call_id, + "thread_id": thread_id, + "process_id": process_id, + "meta_vars": pre_meta_vars, + "type": TraceLineType.FUNC_CALL_POST_EXCEPTION, + "function": original_function_name, + "exception": typename(e, is_runtime=True), + "exception_msg": str(e), + "is_bound_method": is_bound_method, + "obj_id": None if not is_bound_method else id(args[0]), + }, + ) + + if COLLECT_OVERHEAD_METRICS: + EXIT_PERF_TIME = time.perf_counter() + print( + f"WRAPPER TIME: {original_function_name},{ORIG_EXIT_PERF_TIME - ORIG_ENTER_PERF_TIME},{EXIT_PERF_TIME - ENTER_PERF_TIME}" + ) + raise e + + post_record = { + "func_call_id": func_call_id, + "thread_id": thread_id, + "process_id": process_id, + "meta_vars": pre_meta_vars, + "type": TraceLineType.FUNC_CALL_POST, + "function": original_function_name, + "is_bound_method": is_bound_method, + "obj_id": None if not is_bound_method else id(args[0]), + } + + result_to_dump = result + + # if the current function name is transformers.generate, then we will dump the response tokens only, let's see. + # a concrete name: "transformers.models.whisper.modeling_whisper.WhisperForConditionalGeneration.generate" + # we want a pattern that abstracts the specific model name + pattern = "transformers.models.*.generate" + # find matches in the pattern + import re + + if ( + GENERATE_START_TOKEN_ID is not None + and re.match(pattern, original_function_name) + and isinstance(result, torch.Tensor) + ): + print(f"Found match for {original_function_name}") + # the first dimension is the batch size, and each corresponds to a separate response, let's try to match the batch size with the start token ids first + response_starting_indices = [] + for i in range(result.size(0)): + # try to find the match of the start token ids in the response + response = result[i] + # Find all indices where the start_token_id matches + matches = (response == GENERATE_START_TOKEN_ID).nonzero(as_tuple=True)[0] + indexes = matches.tolist() + if len(indexes) == 0: + # No occurrences found + print( + f"start_token_id ({GENERATE_START_TOKEN_ID}) not found in response {i}" + ) + start_index = -1 # Handle case where token is not found + elif len(indexes) > 1: + # Multiple occurrences found, raise an error + raise ValueError( + f"Multiple occurrences of start_token_id ({GENERATE_START_TOKEN_ID}) found in response {i}: {matches.tolist()}" + ) + else: + # Single occurrence found, get the index + start_index = indexes[0] + if not GENERATE_START_TOKEN_ID_INCLUDE_START_TOKEN: + start_index += 1 + + response_starting_indices.append(start_index) + + # compute the length of each response + response_lengths = [] + for i in range(result.size(0)): + response = result[i] + start_index = response_starting_indices[i] + if start_index == -1: + response_lengths.append(0) + else: + response_lengths.append(response.size(0) - start_index) + + result_to_dump = result.detach() + setattr( + result_to_dump, + "_ML_DAIKON_RESPONSE_STARTING_INDICES", + response_starting_indices, + ) + setattr(result_to_dump, "_ML_DAIKON_RESPONSE_LENGTHS", response_lengths) + + print(response_starting_indices) + print(response_lengths) + if dump_ret: + post_record["return_values"] = to_dict_return_value(result_to_dump) + dump_trace_API(post_record) + + if COLLECT_OVERHEAD_METRICS: + EXIT_PERF_TIME = time.perf_counter() + print( + f"WRAPPER TIME: {original_function_name},{ORIG_EXIT_PERF_TIME - ORIG_ENTER_PERF_TIME},{EXIT_PERF_TIME - ENTER_PERF_TIME}" + ) + return result + + +def global_wrapper_proxy( + original_function: Callable, + original_function_name: str, + is_bound_method: bool, + is_builtin: bool, + scan_proxy_in_args: bool, + dump_stack_trace: bool, + dump_args: bool, + dump_args_config, + dump_ret: bool, + dump_ret_config, + handle_proxy: bool, + trigger_proxy_state_dump: bool, + proxy_state_dump_config: dict, + *args, + **kwargs, +): + """Instrumentation for APIs with proxy-specific handling.""" + # if "step" in original_function_name and not "scheduler" in original_function_name: # print("step function called" + original_function_name) # print(trigger_proxy_state_dump) @@ -215,7 +423,7 @@ def global_wrapper( def find_proxy_in_args(args): for i, arg in enumerate(args): - if is_proxied(arg): + if is_proxied(arg) or is_proxyparameter(arg): proxy_in_args.append(arg) elif type(arg) in [list, tuple]: find_proxy_in_args(arg) @@ -234,9 +442,14 @@ def find_proxy_in_args(args): if "proxy_obj_names" not in pre_record: pre_record["proxy_obj_names"] = [] for proxy in proxy_in_args: - pre_record["proxy_obj_names"].append( - [proxy.__dict__["var_name"], type(proxy._obj).__name__] - ) + if is_proxyparameter(proxy): + pre_record["proxy_obj_names"].append( + [proxy.__dict__["var_name"], "Parameter"] + ) + else: + pre_record["proxy_obj_names"].append( + [proxy.__dict__["var_name"], type(proxy._obj).__name__] + ) if dump_args: dict_args_kwargs = to_dict_args_kwargs(args, kwargs, dump_args_config) pre_record["args"] = dict_args_kwargs["args"] @@ -392,9 +605,11 @@ def find_proxy_in_args(args): return result -def core_wrapper(original_function, is_builtin, handle_proxy, *args, **kwargs): - """same as global_wrapper but without the logging, will have lower overhead than global_wrapper - We use this wrapper on the functions that are not helpful for invariant inference, but still needs to be instrumented to handle proxy classes +def core_wrapper_proxy(original_function, is_builtin, handle_proxy, *args, **kwargs): + """Same as global_wrapper_proxy but without logging. + + We use this wrapper on functions that are not helpful for invariant inference, + but still need proxy-safe handling. """ global DISABLE_WRAPPER if DISABLE_WRAPPER: @@ -427,20 +642,34 @@ def wrapper( @functools.wraps(original_function) def wrapped(*args, **kwargs): - return global_wrapper( # the wrapper cannot be invoked with named parameters as *args has to be after the named parameters + if handle_proxy: + return global_wrapper_proxy( + original_function, + original_function_name, + is_bound_method, + is_builtin, + scan_proxy_in_args, + dump_stack_trace, + dump_args, + dump_args_config, + dump_ret, + dump_ret_config, + handle_proxy, + trigger_proxy_state_dump, + proxy_state_dump_config, + *args, + **kwargs, + ) + return global_wrapper_subclass( original_function, original_function_name, is_bound_method, - is_builtin, scan_proxy_in_args, dump_stack_trace, dump_args, dump_args_config, dump_ret, dump_ret_config, - handle_proxy, - trigger_proxy_state_dump, - proxy_state_dump_config, *args, **kwargs, ) @@ -451,7 +680,7 @@ def wrapped(*args, **kwargs): @functools.wraps(original_function) def wrapped(*args, **kwargs): - return core_wrapper( + return core_wrapper_proxy( original_function, is_builtin, handle_proxy, *args, **kwargs ) @@ -841,9 +1070,13 @@ def get_wrapped_function(self, func_obj: Callable) -> Callable: """Get the wrapped function for the provided function object, based on the instrumentation options provided in instr_opts.json. """ - used_proxy = True # TODO: dump instr_opts when doing full instr as well so we can determine whether to handle proxy based on the specific instrumentation args + tracker_style = ( + self.instr_opts.model_tracker_style + if self.instr_opts is not None + else config.MODEL_TRACKER_STYLE + ) + used_proxy = tracker_style == "proxy" if self.instr_opts is not None: - used_proxy = self.instr_opts.model_tracker_style == "proxy" func_name = typename(func_obj) if func_name not in self.instr_opts.funcs_instr_opts: return wrapper( diff --git a/traincheck/invariant/precondition.py b/traincheck/invariant/precondition.py index ad507dcb..b2040d83 100644 --- a/traincheck/invariant/precondition.py +++ b/traincheck/invariant/precondition.py @@ -537,7 +537,6 @@ def find_precondition_from_single_group( if len(example) == 0: raise ValueError("Empty example found in positive examples") - # HACK: in ConsistencyRelation in order to avoid the field used in the invariant, we need to skip the field in the precondition. It is up to the caller to provide the keys to skip. We should try to refactor this to have a more generic solution. earliest_time = example[0]["time"] process_id = example[0]["process_id"] thread_id = example[0]["thread_id"] diff --git a/traincheck/proxy_wrapper/proxy_basics.py b/traincheck/proxy_wrapper/proxy_basics.py index 11f8162b..dd3014bb 100644 --- a/traincheck/proxy_wrapper/proxy_basics.py +++ b/traincheck/proxy_wrapper/proxy_basics.py @@ -4,9 +4,52 @@ import astor +import traincheck.config.config as config + + +def is_compile_internal_module(obj): + mod = getattr(type(obj), "__module__", "") or "" + if any(mod.startswith(p) for p in config.COMPILE_INTERNAL_MODULE): + return True + name = type(obj).__name__ + if mod.startswith("torch._dynamo") and name != "OptimizedModule": + return True + return False + + +def is_fake_tensor(x) -> bool: + if not config.USE_TORCH_COMPILE: + return False + try: + from torch._subclasses.fake_tensor import FakeTensor + from torch.fx import Proxy as FxProxy + + if isinstance(x, FakeTensor): + return True + if isinstance(x, FxProxy): + return True + except Exception: + pass + + try: + if is_compile_internal_module(x): + return True + except Exception: + return True + + try: + if x.device.type == "meta": + return True + except Exception: + return True + + return False + def is_proxied(obj): try: + if is_fake_tensor(obj): + return False if obj is not None and "is_traincheck_proxied_obj" in obj.__dict__: return True except Exception: @@ -14,6 +57,17 @@ def is_proxied(obj): return False +def is_proxyparameter(obj): + try: + if is_fake_tensor(obj): + return False + if obj is not None and "is_traincheck_proxyparameter" in obj.__dict__: + return True + except Exception: + return False + return False + + def unproxy_arg(arg, inspect_torch_module=False): if is_proxied(arg): diff --git a/traincheck/proxy_wrapper/proxy_config.py b/traincheck/proxy_wrapper/proxy_config.py index 57c2d4d1..66ce6d7c 100644 --- a/traincheck/proxy_wrapper/proxy_config.py +++ b/traincheck/proxy_wrapper/proxy_config.py @@ -49,3 +49,14 @@ "real", ] attribute_black_list = tensor_attribute_black_list +# TODO +proxy_attribute = [ + "process_id", + "thread_id", + "logdir", + "log_level", + "loglevel", + "is_traincheck_proxyparameter", + "var_name", + "last_update_timestamp", +] diff --git a/traincheck/proxy_wrapper/proxy_observer.py b/traincheck/proxy_wrapper/proxy_observer.py index 06afcc2b..5316fed6 100644 --- a/traincheck/proxy_wrapper/proxy_observer.py +++ b/traincheck/proxy_wrapper/proxy_observer.py @@ -2,18 +2,22 @@ import typing from traincheck.config.config import should_disable_proxy_dumping +from traincheck.proxy_wrapper.subclass import ProxyParameter from traincheck.utils import typename if typing.TYPE_CHECKING: from traincheck.proxy_wrapper.proxy import Proxy -from .proxy_basics import is_proxied, unproxy_func + from traincheck.proxy_wrapper.subclass import ProxyParameter + +from .proxy_basics import is_proxied, is_proxyparameter, unproxy_func def observe_proxy_var( - var: "Proxy", + var: typing.Union["Proxy", "ProxyParameter"], phase, observe_api_name: str, ): + # update the proxy object's timestamp var.update_timestamp() @@ -37,9 +41,9 @@ def wrapper(*args, **kwargs): # if the arg is list or tuple, check if it contains proxied object if type(arg) in [list, tuple]: for element in arg: - if is_proxied(element): + if is_proxied(element) or is_proxyparameter(element): proxied_vars.append(element) - if is_proxied(arg): + if is_proxied(arg) or is_proxyparameter(arg): proxied_vars.append(arg) # pre observe diff --git a/traincheck/proxy_wrapper/subclass.py b/traincheck/proxy_wrapper/subclass.py new file mode 100644 index 00000000..19a0ebf0 --- /dev/null +++ b/traincheck/proxy_wrapper/subclass.py @@ -0,0 +1,239 @@ +import logging +import os +import threading + +import torch +from torch import nn + +from traincheck.config.config import should_disable_proxy_dumping +from traincheck.instrumentor.dumper import dump_trace_VAR +from traincheck.instrumentor.tracer import TraceLineType +from traincheck.proxy_wrapper.dumper import dump_attributes, get_meta_vars +from traincheck.utils import get_timestamp_ns + +from .proxy_basics import is_fake_tensor + +# from .proxy_registry import get_global_registry +# from .utils import print_debug + + +def in_dynamo() -> bool: + try: + import torch._dynamo as dynamo + + return bool(dynamo.is_compiling()) + except Exception: + return False + + +class ProxyParameter(torch.nn.Parameter): + loglevel = logging.INFO + + def __new__( + cls, + data, + logdir="proxy_log.log", + log_level=logging.INFO, + # TODO + # recurse=False, + var_name="", + should_dump_trace=True, + from_call=False, + from_iter=False, + # TODO + # from_copy=False, + ): + if isinstance(data, ProxyParameter): + return data + + if in_dynamo() or is_fake_tensor(data): + # we do not proxy the parameter if we are in dynamo or the tensor is a fake tensor + if isinstance(data, nn.Parameter): + return data + return nn.Parameter(data, requires_grad=data.requires_grad) + + requires_grad = getattr(data, "requires_grad", False) + tensor_grad = getattr(data, "grad", None) + + # When wrapping an existing Parameter we need to preserve any Python level + # attributes (e.g. hooks, user defined flags, ``grad``) so that the proxy + # behaves identically to the original parameter. ``Parameter.__new__`` + # returns a fresh instance, so we snapshot the metadata from ``data`` and + # replay it on the new ProxyParameter via the base Tensor ``__setattr__`` + # to avoid triggering the logging logic implemented in this class. + snapshot: dict = {} + + if isinstance(data, nn.Parameter): + snapshot = dict(getattr(data, "__dict__", {})) + base_tensor = data.detach() + elif isinstance(data, torch.Tensor): + base_tensor = data.detach() + else: + base_tensor = torch.as_tensor(data) + + proxied = super().__new__(cls, base_tensor, requires_grad=requires_grad) + + if snapshot: + tensor_setattr = torch.Tensor.__setattr__ + for name, value in snapshot.items(): + if name == "grad": + continue + try: + tensor_setattr(proxied, name, value) + except AttributeError: + # Some slots (e.g. torch internals) are read-only; skip them. + continue + + if tensor_grad is not None: + torch.Tensor.__setattr__(proxied, "grad", tensor_grad) + + return proxied + + def __init__( + self, + data, + logdir="proxy_log.log", + log_level=logging.INFO, + # TODO + # recurse=False, + var_name="", + should_dump_trace=True, + from_call=False, + from_iter=False, + # TODO + # from_copy=False, + ): + super().__init__() + # Access proxy attribute: since we are wrapping the getattr method, we need to access the attribute directly + self.__dict__["process_id"] = os.getpid() + self.__dict__["thread_id"] = threading.current_thread().ident + self.__dict__["logdir"] = logdir + self.__dict__["log_level"] = log_level + # TODO + # self.__dict__["meta_vars"] = {} + # self.__dict__["is_traincheck_proxied_obj"] = True + self.__dict__["is_traincheck_proxyparameter"] = True + # TODO + # self.__dict__["recurse"] = recurse + self.__dict__["var_name"] = var_name + # TODO + # self.__dict__["old_value"] = None + # self.__dict__["old_meta_vars"] = None + + current_time = get_timestamp_ns() + + self.__dict__["last_update_timestamp"] = current_time + + # print(f"init: {self.var_name}") + if should_dump_trace and not should_disable_proxy_dumping(): + if from_call: + phase = "call" + + if from_iter: + phase = "iter" + # if the object is generated from getattr, then do not dump it + else: + phase = "update" + self.dump_trace(phase=phase, dump_loc="initing") + + def __setattr__(self, name, value): + # print(f"paremeter: {self.var_name}, name = {name}, value = {value}") + super().__setattr__(name, value) + self.update_timestamp() + if should_disable_proxy_dumping(): + return + self.dump_trace( + phase="update", + dump_loc=f"__setattr__ (attribute '{name}')", + ) + + def __deepcopy__(self, memo): + data = self.detach().clone(memory_format=torch.preserve_format) + data.requires_grad_(self.requires_grad) + if in_dynamo() or is_fake_tensor(self): + return self + return type(self)( + data, + var_name=self.var_name, + ) + + def update_timestamp(self): + # Update the timestamp of the object, should be called when the object is updated, e.g. __setattr__ and observer + current_time = get_timestamp_ns() + self.__dict__["last_update_timestamp"] = current_time + # TODO: + # Proxy.var_dict[self.__dict__["var_name"]].last_update_timestamp = current_time + + def register_object(self): + # get_global_registry().add_var(self, self.__dict__["var_name"]) + # TODO: implement the registry, we will need to make sure the registerred timestamp is updated and is consistent with the timestamp in the object + pass + + def dump_trace(self, phase, dump_loc): + # print(f"parameter: {self.var_name}, phase = {phase}, dump_loc = {dump_loc}") + # TODO + var_name = self.__dict__["var_name"] + # assert var_name is not None # '' is allowed as a var_name (root object) + # filter_by_tensor_version = proxy_config.dump_info_config[ + # "filter_by_tensor_version" + # ] + # if filter_by_tensor_version and phase == "update": + # if hasattr(obj, "_version"): + # if obj._version == Proxy.var_dict[self.__dict__["var_name"]].version: + # return + + last_update_timestamp = self.__dict__["last_update_timestamp"] + + # TODO + # if not isinstance(obj, torch.nn.Module): + dump_trace_VAR( + { + "process_id": self.process_id, + "thread_id": self.thread_id, + "time": last_update_timestamp, + "meta_vars": get_meta_vars(self), + "var_name": var_name, + "var_type": "torch.nn.Parameter", + "mode": phase, + "dump_loc": dump_loc, + "attributes": dump_attributes(self, self), + "type": TraceLineType.STATE_CHANGE, + } + ) + + +def proxy_parameter( + module: nn.Module, + logdir="proxy_log.log", + log_level=logging.INFO, + # TODO + # recurse=False, + parent_name="", + should_dump_trace=True, + from_call=False, + from_iter=False, + # TODO + # from_copy=False, +): + if in_dynamo(): + return + for name, t in list(module.named_parameters(recurse=False)): + module._parameters[name] = ProxyParameter( + t, + logdir, + log_level, + parent_name + "." + name, + should_dump_trace, + from_call, + from_iter, + ) + for name, child in module.named_children(): + proxy_parameter( + child, + logdir, + log_level, + parent_name + "." + name, + should_dump_trace, + from_call, + from_iter, + ) diff --git a/traincheck/utils.py b/traincheck/utils.py index 9e332094..944fd989 100644 --- a/traincheck/utils.py +++ b/traincheck/utils.py @@ -35,6 +35,14 @@ def safe_getattr(obj, attr, default=None): raise +def typename_compile(o): + try: + mod = getattr(type(o), "__module__", "") or "" + return f"{mod}.{type(o).__name__}" + except Exception: + return "compile_stage" + + def typename(o, is_runtime=False): if isinstance(o, torch.nn.Parameter): return "torch.nn.Parameter"