|
- # -*- coding: UTF-8 -*-
- """
- -----------------------------------
- @Author : Encore
- @Date : 2024/3/29
- -----------------------------------
- """
- import contextlib
- import copy
- import functools
- import glob
- import importlib.metadata
- import inspect
- import math
- import os
- import random
- import re
- import shutil
- import sys
- import tempfile
- import time
- import warnings
- from collections.abc import Mapping
- from pathlib import Path
- from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple, Union
-
-
- # Integrations must be imported before ML frameworks:
- # isort: off
- from transformers.integrations import (
- get_reporting_integration_callbacks,
- hp_params,
- )
-
- # isort: on
-
- import huggingface_hub.utils as hf_hub_utils
- import numpy as np
- import torch
- import torch.distributed as dist
- from huggingface_hub import ModelCard, create_repo, upload_folder
- from packaging import version
- from torch import nn
- from torch.utils.data import DataLoader, Dataset, RandomSampler, SequentialSampler
-
- from transformers import Trainer
- from transformers import __version__
- from transformers.configuration_utils import PretrainedConfig
- from transformers.data.data_collator import DataCollator, DataCollatorWithPadding, default_data_collator
- from transformers.debug_utils import DebugOption, DebugUnderflowOverflow
- from transformers.hyperparameter_search import ALL_HYPERPARAMETER_SEARCH_BACKENDS, default_hp_search_backend
- from transformers.integrations.deepspeed import deepspeed_init, deepspeed_load_checkpoint, is_deepspeed_available
- from transformers.integrations.tpu import tpu_spmd_dataloader
- from transformers.modelcard import TrainingSummary
- from transformers.modeling_utils import PreTrainedModel, load_sharded_checkpoint, unwrap_model
- from transformers.models.auto.modeling_auto import (
- MODEL_FOR_CAUSAL_LM_MAPPING_NAMES,
- MODEL_MAPPING_NAMES,
- )
- from transformers.optimization import Adafactor, get_scheduler
- from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS, is_torch_greater_or_equal_than_1_13
- from transformers.tokenization_utils_base import PreTrainedTokenizerBase
- from transformers.trainer_callback import (
- CallbackHandler,
- DefaultFlowCallback,
- PrinterCallback,
- ProgressCallback,
- TrainerCallback,
- TrainerControl,
- TrainerState,
- )
- from transformers.trainer_pt_utils import (
- DistributedTensorGatherer,
- IterableDatasetShard,
- LabelSmoother,
- LayerWiseDummyOptimizer,
- LengthGroupedSampler,
- SequentialDistributedSampler,
- distributed_broadcast_scalars,
- distributed_concat,
- find_batch_size,
- get_dataloader_sampler,
- get_model_param_count,
- get_module_class_from_name,
- get_parameter_names,
- nested_concat,
- nested_detach,
- nested_numpify,
- nested_xla_mesh_reduce,
- reissue_pt_warnings,
- remove_dummy_checkpoint,
- )
- from transformers.trainer_utils import (
- PREFIX_CHECKPOINT_DIR,
- BestRun,
- EvalLoopOutput,
- EvalPrediction,
- HPSearchBackend,
- HubStrategy,
- IntervalStrategy,
- PredictionOutput,
- RemoveColumnsCollator,
- TrainerMemoryTracker,
- TrainOutput,
- check_target_module_exists,
- default_compute_objective,
- denumpify_detensorize,
- enable_full_determinism,
- find_executable_batch_size,
- get_last_checkpoint,
- has_length,
- neftune_post_forward_hook,
- number_of_arguments,
- seed_worker,
- set_seed,
- speed_metrics,
- )
- from transformers.training_args import OptimizerNames, ParallelMode, TrainingArguments
- from transformers.utils import (
- ADAPTER_CONFIG_NAME,
- ADAPTER_SAFE_WEIGHTS_NAME,
- ADAPTER_WEIGHTS_NAME,
- CONFIG_NAME,
- SAFE_WEIGHTS_INDEX_NAME,
- SAFE_WEIGHTS_NAME,
- WEIGHTS_INDEX_NAME,
- WEIGHTS_NAME,
- PushInProgress,
- PushToHubMixin,
- can_return_loss,
- find_labels,
- is_accelerate_available,
- is_apex_available,
- is_bitsandbytes_available,
- is_datasets_available,
- is_galore_torch_available,
- is_in_notebook,
- is_ipex_available,
- is_peft_available,
- is_safetensors_available,
- is_sagemaker_dp_enabled,
- is_sagemaker_mp_enabled,
- is_torch_compile_available,
- is_torch_neuroncore_available,
- is_torch_npu_available,
- is_torch_xla_available,
- logging,
- strtobool,
- )
- from transformers.utils.quantization_config import QuantizationMethod
-
-
- DEFAULT_CALLBACKS = [DefaultFlowCallback]
- DEFAULT_PROGRESS_CALLBACK = ProgressCallback
-
- if is_in_notebook():
- from transformers.utils.notebook import NotebookProgressCallback
-
- DEFAULT_PROGRESS_CALLBACK = NotebookProgressCallback
-
- if is_apex_available():
- from apex import amp
-
- if is_datasets_available():
- import datasets
-
- if is_torch_xla_available():
- import torch_xla.core.xla_model as xm
- import torch_xla.debug.metrics as met
- import torch_xla.distributed.spmd as xs
- import torch_xla.runtime as xr
-
-
- if is_sagemaker_mp_enabled():
- import smdistributed.modelparallel.torch as smp
- from smdistributed.modelparallel import __version__ as SMP_VERSION
-
- IS_SAGEMAKER_MP_POST_1_10 = version.parse(SMP_VERSION) >= version.parse("1.10")
-
- from transformers.trainer_pt_utils import smp_forward_backward, smp_forward_only, smp_gather, smp_nested_concat
- else:
- IS_SAGEMAKER_MP_POST_1_10 = False
-
-
- if is_safetensors_available():
- import safetensors.torch
-
- if is_peft_available():
- from peft import PeftModel
-
-
- if is_accelerate_available():
- from accelerate import Accelerator, skip_first_batches
- from accelerate import __version__ as accelerate_version
- from accelerate.utils import (
- DistributedDataParallelKwargs,
- DistributedType,
- GradientAccumulationPlugin,
- load_fsdp_model,
- load_fsdp_optimizer,
- save_fsdp_model,
- save_fsdp_optimizer,
- )
-
- DATA_SAMPLERS = [RandomSampler]
- if version.parse(accelerate_version) > version.parse("0.23.0"):
- from accelerate.data_loader import SeedableRandomSampler
-
- DATA_SAMPLERS += [SeedableRandomSampler]
-
- if is_deepspeed_available():
- from accelerate.utils import DeepSpeedSchedulerWrapper
-
-
def _is_peft_model(model):
    """Return True if `model` is a PEFT-wrapped model.

    Checks for `PeftModel` and, on peft>=0.7.0, also for `PeftMixedModel`
    (introduced in https://github.com/huggingface/transformers/pull/28321).
    Returns False when peft is not installed.
    """
    # Guard clause: without peft there is nothing to check against.
    # (The original re-tested is_peft_available() inside this branch,
    # making the `else ()` arm dead code.)
    if not is_peft_available():
        return False

    classes_to_check = (PeftModel,)
    if version.parse(importlib.metadata.version("peft")) >= version.parse("0.7.0"):
        from peft import PeftMixedModel

        classes_to_check = (*classes_to_check, PeftMixedModel)
    return isinstance(model, classes_to_check)
-
-
def _get_fsdp_ckpt_kwargs():
    """Return extra kwargs to pass to the FSDP checkpoint save/load helpers.

    Yields `{"adapter_only": True}` when the installed `accelerate` exposes the
    `adapter_only` parameter on `save_fsdp_model`, otherwise an empty dict.
    """
    # TODO: @AjayP13, @younesbelkada replace this check with version check at the next `accelerate` release
    supports_adapter_only = (
        is_accelerate_available()
        and "adapter_only" in inspect.signature(save_fsdp_model).parameters
    )
    return {"adapter_only": True} if supports_adapter_only else {}
-
-
- if TYPE_CHECKING:
- import optuna
-
-
- logger = logging.get_logger(__name__)
-
-
- # Name of the files used for checkpointing
- TRAINING_ARGS_NAME = "training_args.bin"
- TRAINER_STATE_NAME = "trainer_state.json"
- OPTIMIZER_NAME = "optimizer.pt"
- OPTIMIZER_NAME_BIN = "optimizer.bin"
- SCHEDULER_NAME = "scheduler.pt"
- SCALER_NAME = "scaler.pt"
- FSDP_MODEL_NAME = "pytorch_model_fsdp"
-
-
class CustomTrainer(Trainer):
    """`Trainer` subclass whose `save_model` reports progress through the module
    logger instead of bare `print()` calls left over from debugging."""

    def save_model(self, output_dir: Optional[str] = None, _internal_call: bool = False):
        """
        Will save the model, so you can reload it using `from_pretrained()`.

        Will only save from the main process.

        Args:
            output_dir: Target directory; defaults to `self.args.output_dir`.
            _internal_call: True when the Trainer itself calls this method;
                suppresses the hub push so it only fires on explicit user calls.
        """
        if output_dir is None:
            output_dir = self.args.output_dir

        # Use the module logger rather than print() so output honors the
        # configured logging verbosity and handlers.
        logger.info("Saving model checkpoint to %s", output_dir)

        if is_torch_xla_available():
            self._save_tpu(output_dir)
        elif is_sagemaker_mp_enabled():
            # Calling the state_dict needs to be done on the wrapped model and on all processes.
            os.makedirs(output_dir, exist_ok=True)
            state_dict = self.model_wrapped.state_dict()
            if self.args.should_save:
                self._save(output_dir, state_dict=state_dict)
            if IS_SAGEMAKER_MP_POST_1_10:
                # 'user_content.pt' indicates model state_dict saved with smp >= 1.10
                Path(os.path.join(output_dir, "user_content.pt")).touch()
        elif self.is_fsdp_enabled:
            # Only gather a full state dict when the FSDP plugin is configured for
            # FULL_STATE_DICT and accelerate is recent enough to support it.
            if ("FULL_STATE_DICT" in str(self.accelerator.state.fsdp_plugin.state_dict_type)) and (
                version.parse(accelerate_version) > version.parse("0.24.1")
            ):
                state_dict = self.accelerator.get_state_dict(self.model)
                if self.args.should_save:
                    self._save(output_dir, state_dict=state_dict)
        elif self.is_deepspeed_enabled:
            try:
                state_dict = self.accelerator.get_state_dict(self.deepspeed)
                if self.args.should_save:
                    self._save(output_dir, state_dict=state_dict)
            except ValueError:
                logger.warning(
                    " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead, use"
                    " zero_to_fp32.py to recover weights"
                )
                # Save a placeholder, then let DeepSpeed write its own sharded
                # checkpoint that zero_to_fp32.py can later consolidate.
                if self.args.should_save:
                    self._save(output_dir, state_dict={})
                # remove the dummy state_dict
                remove_dummy_checkpoint(self.args.should_save, output_dir, [WEIGHTS_NAME, SAFE_WEIGHTS_NAME])
                self.model_wrapped.save_checkpoint(output_dir)

        elif self.args.should_save:
            self._save(output_dir)

        # Push to the Hub when `save_model` is called by the user.
        if self.args.push_to_hub and not _internal_call:
            self.push_to_hub(commit_message="Model save")
|