from dataclasses import dataclass
from datetime import timedelta
from typing import List, Optional, Tuple

import torch
import torch.distributed as dist
from accelerate import (
    Accelerator,
    DistributedDataParallelKwargs,
    InitProcessGroupKwargs,
)
from torch import nn
from torch.utils.data import Dataset

from recis.framework.checkpoint_manager import CheckpointManager, Saver
from recis.framework.metrics import get_global_metrics
from recis.hooks import Hook, LoggerHook
from recis.optim import sparse_optim
from recis.utils.data_utils import copy_data_to_device
from recis.utils.logger import Logger

logger = Logger(__name__)
@dataclass
class TrainingArguments:
    """Configuration class for training parameters.
    This dataclass contains all the configuration parameters needed for training,
    including optimization settings, logging intervals, and checkpoint management.
    Attributes:
        gradient_accumulation_steps (int): Number of steps to accumulate gradients
                                         before performing an optimizer step. Defaults to 1.
        output_dir (str): Directory where checkpoints and logs will be saved.
                         Defaults to "output_dir".
        model_bank (Optional[list]): List of model bank paths for initialization.
                                   Defaults to None.
        log_steps (int): Number of training steps between logging. Defaults to 100.
        train_steps (Optional[int]): Maximum number of training steps. If None,
                                   will train for full epochs. Defaults to None.
        train_epoch (Optional[int]): Number of training epochs. Defaults to 1.
        eval_steps (Optional[int]): Number of evaluation steps. If None, evaluates
                                  on full dataset. Defaults to None.
        save_steps (Optional[int]): Number of steps between checkpoint saves.
                                  Defaults to 1000.
        max_to_keep (int): Maximum number of checkpoints to keep. Defaults to 5.
        save_concurrency_per_rank (int): Number of concurrent save operations per rank.
                                        Defaults to 4.
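
    Example:
        A minimal construction sketch; the values below are illustrative
        placeholders rather than recommended settings:

        .. code-block:: python

            from recis.framework import TrainingArguments

            args = TrainingArguments(
                output_dir="./checkpoints",
                log_steps=50,
                train_steps=5000,
                save_steps=1000,
                gradient_accumulation_steps=2,
            )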
    """
    gradient_accumulation_steps: int = 1
    output_dir: str = "output_dir"
    model_bank: Optional[list] = None
    log_steps: int = 100
    train_steps: Optional[int] = None
    train_epoch: Optional[int] = 1
    eval_steps: Optional[int] = None
    save_steps: Optional[int] = 1000
    max_to_keep: int = 5
    save_concurrency_per_rank: int = 4 
class Trainer:
    """Main training orchestrator with distributed training and checkpoint management.
    The Trainer class provides a comprehensive training framework that handles:
    - Distributed training coordination using Accelerate
    - Automatic checkpoint saving and loading
    - Training and evaluation loops with metrics tracking
    - Hook system for extensible training workflows
    - Support for both dense and sparse optimizers
    Attributes:
        args (TrainingArguments): Training configuration parameters.
        hooks (List[Hook]): List of training hooks for extensibility.
        train_dataset (Optional[Dataset]): Training dataset.
        eval_dataset (Optional[Dataset]): Evaluation dataset.
        model (nn.Module): The model to train.
        dense_optimizer (torch.optim.Optimizer): Dense parameter optimizer.
        dense_lr_scheduler: Learning rate scheduler for dense optimizer.
        sparse_optimizer (Optional[sparse_optim.SparseOptimizer]): Sparse parameter optimizer.
        data_to_cuda (bool): Whether to automatically move data to CUDA.
        accelerator (Accelerator): Accelerate instance for distributed training.
        checkpoint_manager (CheckpointManager): Handles checkpoint operations.
    Example:
    .. code-block:: python
        from recis.framework import Trainer, TrainingArguments
        from recis.optim import SparseAdamW
        from torch.optim import AdamW
        # Set training arguments
        training_args = TrainingArguments(
            output_dir="./checkpoints",
            train_steps=10000,
            eval_steps=1000,
            save_steps=2000,
            log_steps=100,
            gradient_accumulation_steps=4,
        )
        # split sparse params
        from recis.nn.modules.hashtable import filter_out_sparse_param
        sparse_params = filter_out_sparse_param(model)
        # create optimizers
        sparse_optimizer = SparseAdamW(sparse_params, lr=0.001)
        dense_optimizer = AdamW(model.parameters(), lr=0.001)
        # create trainer
        trainer = Trainer(
            model=model,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            dense_optimizers=(dense_optimizer, None),
            sparse_optimizer=sparse_optimizer,
            data_to_cuda=True,
        )
        # train the model
        trainer.train()
    """
    def __init__(
        self,
        model: Optional[nn.Module] = None,
        args: TrainingArguments = None,
        train_dataset: Optional[Dataset] = None,
        eval_dataset: Optional[Dataset] = None,
        hooks: Optional[List[Hook]] = None,
        dense_optimizers: Tuple[
            torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR
        ] = (None, None),
        sparse_optimizer: Optional[sparse_optim.SparseOptimizer] = None,
        data_to_cuda: bool = False,
        **kwargs,
    ) -> None:
        """Initialize the Trainer with model, datasets, and training configuration.
        Args:
            model (Optional[nn.Module]): The model to train.
            args (TrainingArguments): Training configuration. If None, uses default.
            train_dataset (Optional[Dataset]): Training dataset.
            eval_dataset (Optional[Dataset]): Evaluation dataset.
            hooks (Optional[List[Hook]]): List of training hooks for extensibility.
            dense_optimizers (Tuple): Tuple of (optimizer, lr_scheduler) for dense parameters.
            sparse_optimizer (Optional[sparse_optim.SparseOptimizer]): Optimizer for sparse parameters.
            data_to_cuda (bool): Whether to automatically move data to CUDA. Defaults to False.
            **kwargs: Additional arguments passed to Accelerator.
        """
        if hooks is None:
            hooks = []
        if args is None:
            args = TrainingArguments()
        self.args = args
        self.hooks = hooks
        self.train_dataset = train_dataset
        self.eval_dataset = eval_dataset
        self.model = model
        self.dense_optimizer = dense_optimizers[0]
        self.dense_lr_scheduler = dense_optimizers[1]
        self.sparse_optimizer = sparse_optimizer
        self.data_to_cuda = data_to_cuda
        ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
        init_kwargs = InitProcessGroupKwargs(timeout=timedelta(seconds=1800))
        self.accelerator = Accelerator(
            kwargs_handlers=[ddp_kwargs, init_kwargs],
            gradient_accumulation_steps=args.gradient_accumulation_steps,
            **kwargs,
        )
        self.gradient_accumulation_steps = args.gradient_accumulation_steps
        (
            self.model,
            self.dense_optimizer,
            self.dense_lr_scheduler,
        ) = self.accelerator.prepare(
            self.model, self.dense_optimizer, self.dense_lr_scheduler
        )
        if self.sparse_optimizer is not None:
            self.sparse_optimizer.set_grad_accum_steps(self.gradient_accumulation_steps)
        self._global_step = torch.scalar_tensor(0, dtype=torch.int64)
        self.build_checkpoint_manager(model, args)
        self.has_restore = False
        self.hooks.append(LoggerHook(self.args.log_steps))
        self.stop_state = torch.scalar_tensor(0, dtype=torch.int64).cuda() 
    def build_checkpoint_manager(self, model, args):
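        """Build the checkpoint manager and register stateful objects.

        Builds a Saver via build_saver, wraps it in a CheckpointManager with
        args.save_steps as the save interval, and registers dataset IO state and
        the dense optimizer (when present) for checkpointing.
        """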
        saver = self.build_saver(model, args)
        self.checkpoint_manager = CheckpointManager(
            saver=saver, save_interval=args.save_steps
        )
        if self.train_dataset is not None:
            saver.register_io_state("train_io", self.train_dataset)
            if hasattr(self.train_dataset, "_window_paths"):
                saver.register_for_checkpointing("train_window_io", self.train_dataset)
        if self.eval_dataset is not None and hasattr(
            self.eval_dataset, "_window_paths"
        ):
            saver.register_io_state("eval_io", self.eval_dataset)
            if hasattr(self.eval_dataset, "_window_paths"):
                saver.register_for_checkpointing("eval_window_io", self.eval_dataset)
        if self.dense_optimizer is not None:
            saver.register_for_checkpointing("dense_optimizer", self.dense_optimizer)
    def build_saver(self, model, args):
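        """Construct the Saver used to write checkpoints.

        The saver tracks the model and the sparse optimizer, writes to
        args.output_dir, keeps at most args.max_to_keep checkpoints, and uses
        args.save_concurrency_per_rank concurrent save operations per rank.
        """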
        saver = Saver(
            model,
            self.sparse_optimizer,
            output_dir=args.output_dir,
            max_keep=args.max_to_keep,
            concurrency=args.save_concurrency_per_rank,
        )
        return saver
    def add_hooks(self, hooks: List[Hook]):
        """Add multiple hooks to the trainer.
        Args:
            hooks (List[Hook]): List of hooks to add.
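
        Example:
            A sketch of registering a custom hook; ``PrintLossHook`` is hypothetical
            and assumes ``Hook`` subclasses only need to override the
            ``after_step``/``end`` callbacks that the trainer invokes:

            .. code-block:: python

                from recis.hooks import Hook


                class PrintLossHook(Hook):
                    def after_step(self, metrics, global_step):
                        # Called after every training/evaluation step with the
                        # current metrics dict and the global step counter.
                        if "loss" in metrics:
                            print(int(global_step), float(metrics["loss"]))

                    def end(self):
                        # Called once when the loop finishes.
                        print("training finished")


                trainer.add_hooks([PrintLossHook()])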
        """
        for hook in hooks:
            self.add_hook(hook) 
    def add_hook(self, hook: Hook):
        """Add a single hook to the trainer.
        Args:
            hook (Hook): The hook to add.
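
        Example:
            A minimal usage sketch; the logging interval is an illustrative
            placeholder and ``trainer`` is assumed to be an existing Trainer:

            .. code-block:: python

                from recis.hooks import LoggerHook

                trainer.add_hook(LoggerHook(10))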
        """
        self.hooks.append(hook) 
    def restore(self):
        """Restore model and training state from checkpoints.
        This method loads the latest checkpoint if available and restores
        the model state, optimizer state, and training progress.
        """
        if not self.has_restore:
            self.checkpoint_manager.load_model_bank(self.args.model_bank)
            self._global_step = self.checkpoint_manager.restore()
            self.has_restore = True 
    def train(self, train_steps=None):
        """Execute the training loop.
        Args:
            train_steps (Optional[int]): Override for number of training steps.
                                       If None, uses args.train_steps.
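
        Example:
            A minimal usage sketch, assuming ``trainer`` was constructed as in the
            class-level example; the step count is an illustrative placeholder:

            .. code-block:: python

                # Train with the step budget configured in TrainingArguments
                trainer.train()

                # Or override the per-epoch step limit for this call
                trainer.train(train_steps=500)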
        """
        self.restore()
        for epoch in range(self.args.train_epoch):
            if hasattr(self.train_dataset, "_window_paths"):
                self.train_dataset.reset()
            self._train_loop(
                self.args.train_steps if train_steps is None else train_steps,
                epoch=epoch,
            )
        for hook in self.hooks:
            hook.end() 
    def train_and_evaluate(self, epochs=1, train_steps=None, eval_steps=None):
        """Execute alternating training and evaluation loops.
        Args:
            epochs (int): Number of epochs to train. Defaults to 1.
            train_steps (Optional[int]): Override for number of training steps per epoch.
            eval_steps (Optional[int]): Override for number of evaluation steps.
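
        Example:
            A minimal usage sketch, assuming ``trainer`` was constructed as in the
            class-level example; epoch and step counts are illustrative:

            .. code-block:: python

                # Run 3 epochs, evaluating on up to 200 batches after each epoch
                trainer.train_and_evaluate(epochs=3, eval_steps=200)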
        """
        self.restore()
        for epoch in range(epochs):
            self._train_loop(
                self.args.train_steps if train_steps is None else train_steps,
                epoch=epoch,
            )
            self.evaluate(eval_steps=eval_steps)
        for hook in self.hooks:
            hook.end()
    def evaluate(self, eval_steps=None):
        """Execute the evaluation loop.
        Args:
            eval_steps (Optional[int]): Override for number of evaluation steps.
                                      If None, evaluates on full dataset.
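
        Example:
            A minimal usage sketch, assuming ``trainer`` was constructed as in the
            class-level example; the step count is an illustrative placeholder:

            .. code-block:: python

                # Evaluate on up to 100 batches of the evaluation dataset
                trainer.evaluate(eval_steps=100)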
        """
        self.restore()
        if hasattr(self.eval_dataset, "_window_paths"):
            iterator = self.get_new_window_iter(self.eval_dataset)
        else:
            iterator = iter(self.eval_dataset)
        lstep = 0
        while True:
            if eval_steps is not None and lstep >= eval_steps:
                break
            need_break = False
            stop_flag, data = next(iterator)
            if self.data_to_cuda:
                data = copy_data_to_device(data, "cuda")
            if stop_flag:
                if hasattr(self.eval_dataset, "_window_paths"):
                    iterator = self.get_new_window_iter(self.eval_dataset)
                    if iterator is None:
                        need_break = True
                else:
                    need_break = True
            need_break = self.sync_exit_flag(need_break)
            if need_break:
                break
            metrics = {}
            self.model.eval()
            with torch.no_grad():
                self.model(data)
            metrics.update(get_global_metrics())
            for hook in self.hooks:
                hook.after_step(metrics, self._global_step)
            lstep += 1
        for hook in self.hooks:
            hook.end() 
    def get_new_window_iter(self, dataset):
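        """Advance a window-based dataset and return an iterator over its next window.

        Windows flagged for skipping are passed over. Returns None once the dataset
        raises StopIteration, meaning all windows have been consumed.
        """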
        if not hasattr(dataset, "_window_paths"):
            raise TypeError("dataset must be window_io")
        while True:
            try:
                need_skip = dataset.next_window()
            except StopIteration:
                logger.info("Window IO Finish")
                return None
            except Exception as e:
                raise e
            read_offset = int(dataset._read_offset[0])
            if need_skip:
                logger.info(f"Skip for window, offset = {read_offset}")
            else:
                logger.info(f"Next window, offset = {read_offset}")
                break
        return iter(dataset)
    def sync_exit_flag(self, flag: bool):
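        """Synchronize a local stop flag across all ranks.

        The flag is all-reduced with MAX so that every rank leaves its loop as soon
        as any single rank has run out of data.
        """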
        self.stop_state.fill_(int(flag))
        dist.all_reduce(self.stop_state, op=dist.ReduceOp.MAX)
        return bool(self.stop_state.item())
    def _train_loop(self, max_steps=None, epoch=1):
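        """Run the training loop for one epoch.

        Pulls batches, handles window switching and synchronized exit across ranks,
        runs training steps under gradient accumulation, fires hooks, steps the
        checkpoint manager, and saves a final checkpoint when the loop ends.
        """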
        self.model.train()
        if hasattr(self.train_dataset, "_window_paths"):
            iterator = self.get_new_window_iter(self.train_dataset)
        else:
            iterator = iter(self.train_dataset)
        lstep = 0
        while True:
            if max_steps is not None and lstep >= max_steps - 1:
                break
            stop_flag, data = next(iterator)
            if self.data_to_cuda:
                data = copy_data_to_device(data, "cuda")
            need_break = False
            if stop_flag:
                if hasattr(self.train_dataset, "_window_paths"):
                    iterator = self.get_new_window_iter(self.train_dataset)
                    if iterator is None:
                        need_break = True
                else:
                    need_break = True
            need_break = self.sync_exit_flag(need_break)
            if need_break:
                break
            metrics = {}
            with self.accelerator.accumulate(self.model):
                self._train_step(data, epoch, metrics)
            for hook in self.hooks:
                hook.after_step(metrics, self._global_step)
            self.checkpoint_manager.step()
            lstep += 1
        self.checkpoint_manager.save()
    def _train_step(self, data, epoch, metrics):
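        """Run a single forward/backward pass and optimizer update.

        Zeroes gradients, computes the loss, records metrics, backpropagates, then
        steps the dense optimizer, and the sparse optimizer and LR scheduler when
        they are configured.
        """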
        self.dense_optimizer.zero_grad()
        if self.sparse_optimizer is not None:
            self.sparse_optimizer.zero_grad()
        loss = self.model(data)
        metrics.update(epoch=epoch)
        metrics.update(loss=loss)
        metrics.update(get_global_metrics())
        loss.backward()
        self.dense_optimizer.step()
        if self.sparse_optimizer is not None:
            self.sparse_optimizer.step()
        if self.dense_lr_scheduler is not None:
            self.dense_lr_scheduler.step()