"""recis.framework.model_bank."""

import fnmatch
import os
import re
from copy import deepcopy
from dataclasses import dataclass, field, fields
from typing import Any, Dict, Optional, Set

import torch
from safetensors.torch import load_file

from recis.framework.filesystem import get_file_system
from recis.serialize.checkpoint_reader import CheckpointReader
from recis.utils.logger import Logger
from recis.utils.mos import Mos


logger = Logger(__name__)
tag = "[ModelBank]"
for level in ("info", "warning", "error"):
    old_func = getattr(logger, level)
    setattr(
        logger,
        level,
        lambda msg, *args, _old=old_func, **kwargs: _old(
            f"{tag} {msg}", *args, **kwargs
        ),
    )


@dataclass
class ModelBankEntry:
    path: str = field(default="")
    load: Set[str] = field(default_factory=lambda: {"*"})
    exclude: Set[str] = field(default_factory=set)

    is_dynamic: bool = False
    hashtable_clear: bool = True
    ignore_error: bool = False
    skip: bool = False
    oname: list[dict] = field(default_factory=list)

    @classmethod
    def from_dict(cls, d: Dict[str, Any]) -> "ModelBankEntry":
        if "path" not in d:
            raise ValueError("Missing required field: 'path'")

        allowed_keys = {f.name for f in fields(cls)}
        filtered_data = {k: v for k, v in d.items() if k in allowed_keys}

        return cls(**filtered_data)

    def __post_init__(self):
        if self.skip:
            logger.warning(f"'skip' is True, skip this model bank: {self.path}.")
            return

        if not isinstance(self.path, str):
            raise TypeError(f"'path' must be a string, got {type(self.path).__name__}")
        if not self.path.strip():
            raise RuntimeError("'path' is empty; no model will be loaded.")

        if isinstance(self.load, list):
            object.__setattr__(self, "load", set(self.load))
        if isinstance(self.exclude, list):
            object.__setattr__(self, "exclude", set(self.exclude))

        if not isinstance(self.load, set):
            raise TypeError(f"'load' must be a set, got {type(self.load).__name__}")
        if not isinstance(self.exclude, set):
            raise TypeError(
                f"'exclude' must be a set, got {type(self.exclude).__name__}"
            )

        if not isinstance(self.hashtable_clear, bool):
            raise TypeError(
                f"'hashtable_clear' must be a bool, got {type(self.hashtable_clear).__name__}"
            )

        if not isinstance(self.is_dynamic, bool):
            raise TypeError(
                f"'is_dynamic' must be a bool, got {type(self.is_dynamic).__name__}"
            )

        if not isinstance(self.ignore_error, bool):
            raise TypeError(
                f"'ignore_error' must be a bool, got {type(self.ignore_error).__name__}"
            )

        if not isinstance(self.oname, list):
            raise TypeError(f"'oname' must be a list, got {type(self.oname).__name__}")
        if not all(isinstance(item, dict) for item in self.oname):
            bad = next(item for item in self.oname if not isinstance(item, dict))
            raise TypeError(
                f"'oname' must be a list of dictionaries, got item of type {type(bad).__name__}"
            )
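
    # Illustrative usage (a sketch, not part of the original module; the
    # path is hypothetical): unknown keys are dropped by from_dict, and list
    # values for 'load'/'exclude' are normalized to sets in __post_init__.
    #
    #   >>> entry = ModelBankEntry.from_dict(
    #   ...     {"path": "/ckpt/base", "load": ["model.*"], "unknown": 1}
    #   ... )
    #   >>> entry.load
    #   {'model.*'}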


class DensePatternMatcher:
    def __init__(self):
        self.regex_cache = {}

    def _get_regex(self, pattern: str):
        # Translate a glob-style pattern into a compiled regex: "." and "?"
        # are escaped literally, and each "*" becomes a capture group so the
        # matched segments can be re-injected into the replacement string.
        if pattern not in self.regex_cache:
            escaped_pattern = pattern.replace(".", r"\.").replace("?", r"\?")
            regex_pattern = "^" + escaped_pattern.replace("*", "(.*)") + "$"
            self.regex_cache[pattern] = re.compile(regex_pattern)
        return self.regex_cache[pattern]

    def apply_mapping(self, key: str, oname_rules: list) -> Optional[str]:
        """
        mapping key from source to target

        Args:
            key: source key
            oname_rules: oname rules list, each rule is a dictionary {pattern: replacement}

        Returns:
            mapped key, if not mapped, return None
        """
        for rule in oname_rules:
            for pattern, replacement in rule.items():
                if fnmatch.fnmatch(key, pattern):
                    if "*" in pattern and "*" in replacement:
                        regex = self._get_regex(pattern)
                        match = regex.match(key)
                        if match:
                            captured_groups = match.groups()
                            result = replacement
                            for group in captured_groups:
                                result = result.replace("*", group, 1)
                            return result
                    elif "*" in pattern:
                        regex = self._get_regex(pattern)
                        match = regex.match(key)
                        if match:
                            captured_groups = match.groups()
                            if "*" in replacement:
                                pattern_prefix = pattern.split("*")[0]
                                replacement_prefix = replacement.split("*")[0]
                                if (
                                    pattern_prefix
                                    and replacement_prefix
                                    and key.startswith(pattern_prefix)
                                ):
                                    suffix = key[len(pattern_prefix) :]
                                    return replacement_prefix + suffix
                            elif pattern.endswith("*") and not replacement.endswith(
                                "*"
                            ):
                                prefix = pattern.replace("*", "")
                                if key.startswith(prefix):
                                    suffix = key[len(prefix) :]
                                    return replacement + suffix
                            return replacement
                    else:
                        if key == pattern:
                            return replacement

        return None
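
    # Illustrative example (hypothetical key and pattern names): a wildcard
    # rule re-injects the matched segment into the replacement.
    #
    #   >>> matcher = DensePatternMatcher()
    #   >>> matcher.apply_mapping(
    #   ...     "encoder.layer.0.weight", [{"encoder.*": "backbone.*"}]
    #   ... )
    #   'backbone.layer.0.weight'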


class MBC:
    """String constants shared by the model bank configuration and parser."""
    PATH = "path"
    LOAD = "load"
    EXCLUDE = "exclude"
    IS_DYNAMIC = "is_dynamic"
    HASHTABLE_CLEAR = "hashtable_clear"
    ONAME = "oname"
    SYMBOL_ALL = "*"
    SYMBOL_EMPTY = ""
    SPECIFIC = "specific"
    COMMON = "common"
    FINAL = "final"
    VARIABLE = "variable"
    IGNORE_ERROR = "ignore_error"


def maybe_get_latest_version(path, force_sub_version=False):
    ckpt_id = None
    fs = get_file_system(path)
    if fs.isdir(path):
        if fs.exists(os.path.join(path, "checkpoint")):
            # The "checkpoint" file lists versions one per line; take the
            # last non-empty line as the latest version.
            with fs.open(os.path.join(path, "checkpoint"), "r") as f:
                content = f.read()
            versions = content.split("\n")[::-1]
            for version in versions:
                if len(version) == 0:
                    continue
                ckpt_id = version.strip()
                break
        logger.warning(f"Get latest checkpoint version {ckpt_id} from path {path}.")
    else:
        logger.warning(f"Checkpoint not found in path: {path}")
    if ckpt_id is not None:
        real_path = os.path.join(path, ckpt_id)
    else:
        real_path = path
        if force_sub_version:
            real_path = ""
    logger.warning(f"Get real ckpt path {real_path} from {path}")
    return real_path
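
# Illustrative example (hypothetical paths; a sketch of the layout this
# function assumes): if /ckpt/job contains a text file named "checkpoint"
# whose last non-empty line is "ckpt_100", the call resolves to that
# version's sub-directory.
#
#   >>> maybe_get_latest_version("/ckpt/job")
#   '/ckpt/job/ckpt_100'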


def get_update_path(path, is_bank=True) -> str:
    if len(path) == 0:
        logger.warning("get_update_path: path is empty")
        return ""

    if path.startswith("model."):
        mos = Mos(path, is_bank)
        path = mos.real_physical_path
    path = maybe_get_latest_version(path, (not is_bank))
    return path


def show_model_bank_format(name: str, model_bank):
    if len(model_bank) == 0:
        logger.warning(f"No {name} model bank to show")
        return

    res = f"============= {name} =============\n"
    all_names = []
    all_dyn_strs = []
    all_clear_strs = []
    all_oname_strs = []
    for tensors in model_bank.values():
        for tensor_name, meta in tensors.items():
            all_names.append(tensor_name)
            all_dyn_strs.append(str(meta.get("is_dynamic", "")))
            all_clear_strs.append(str(meta.get("hashtable_clear", "")))
            all_oname_strs.append(str(meta.get("oname", "")))

    name_width = max([len(n) for n in all_names] + [len("Tensor Name")])
    dyn_width = max([len(s) for s in all_dyn_strs] + [len("is_dynamic")])
    clear_width = max([len(s) for s in all_clear_strs] + [len("hashtable_clear")])
    oname_width = max([len(s) for s in all_oname_strs] + [len("oname")])
    header = (
        f"{'Tensor Name'.ljust(name_width)}  "
        f"{'is_dynamic'.ljust(dyn_width)}  "
        f"{'hashtable_clear'.ljust(clear_width)}  "
        f"{'oname'.ljust(oname_width)}"
    )
    sep_line = "-" * len(header)

    for path, tensors in model_bank.items():
        res += f"Checkpoint: {path}\n"
        res += "=" * len(header) + "\n"
        res += header + "\n"
        res += sep_line + "\n"
        for tensor_name in sorted(tensors, key=lambda x: ("@" not in x, x)):
            meta = tensors[tensor_name]
            dyn = str(meta.get("is_dynamic", ""))
            clear = str(meta.get("hashtable_clear", ""))
            oname = str(meta.get("oname", ""))
            res += f"{tensor_name.ljust(name_width)}  {dyn.ljust(dyn_width)}  {clear.ljust(clear_width)}  {oname.ljust(oname_width)}"
            res += "\n"
        res += "=" * len(header) + "\n"
        res += "\n"
    logger.info(res)


def raise_error(core_text: str, message: str, ignore_error: bool):
    """Warn for wildcard patterns; for exact names, raise unless ignore_error is set."""
    if "*" in core_text:
        logger.warning(message)
    else:
        if not ignore_error:
            raise ValueError(message)
        else:
            logger.warning(message)


def get_match_by_pattern(pattern: str, var_list: set[str]):
    """Resolve a pattern to the set of matching variable names.

    Supported forms:
        "*"           -> all variables
        "model.var_*" -> variables matching the glob pattern
        "model.var_1" -> the exact name, if present
    """
    if pattern == MBC.SYMBOL_ALL:
        return var_list
    elif MBC.SYMBOL_ALL in pattern and len(pattern) > 1:
        return {var for var in var_list if fnmatch.fnmatch(var, pattern)}
    elif pattern in var_list:
        return {pattern}
    raise ValueError(f"Bad pattern: {pattern} couldn't match any variable")
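
# Illustrative examples (hypothetical variable names) of the pattern forms
# accepted above:
#
#   >>> names = {"model.var_1", "model.var_2", "other.bias"}
#   >>> get_match_by_pattern("*", names) == names
#   True
#   >>> sorted(get_match_by_pattern("model.var_*", names))
#   ['model.var_1', 'model.var_2']
#   >>> get_match_by_pattern("other.bias", names)
#   {'other.bias'}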


def load_pt_file(ckpt_dir: str, file_name: str):
    """Load a tensor dict from <file_name>.pt, falling back to <file_name>.safetensors."""
    pt_path = os.path.join(ckpt_dir, file_name + ".pt")
    safe_path = os.path.join(ckpt_dir, file_name + ".safetensors")
    fs = get_file_system(os.path.join(ckpt_dir, "index"))
    data = {}
    if fs.exists(pt_path):
        with fs.open(pt_path, "rb") as f:
            data = torch.load(f=f)
    elif fs.exists(safe_path):
        data = load_file(safe_path)
    return data


def parse_sparse_oname(
    onames: list,
    src_names: set[str],
    dst_names: set[str],
    ignore_error: bool,
    oname_success: list,
) -> dict:
    sparse_oname = {}
    for idx, oname in enumerate(onames):
        # Each oname rule is a single-entry dict {src_table: dst_table}.
        src_table = next(iter(oname.keys()))
        dst_table = next(iter(oname.values()))

        matched_src_names = get_match_by_pattern(src_table, src_names)
        if not matched_src_names:
            raise_error(
                src_table,
                f"[sparse_oname] Bad oname, src table {src_table} not found in src_names",
                True,
            )
            continue

        matched_dst_names = get_match_by_pattern(dst_table, dst_names)
        if not matched_dst_names:
            raise_error(
                dst_table,
                f"[sparse_oname] Bad oname, dst table {dst_table} not found in dst_names",
                True,
            )
            continue

        if len(matched_dst_names) != len(matched_src_names):
            raise_error(
                "",
                f"[sparse_oname] Bad oname, Dst table {matched_dst_names} has different number of variables than src table {matched_src_names}",
                False,
            )
            continue

        # Strip the "@slice" suffix and any trailing "*" to get the bare table name.
        src_table_name = src_table.split("@")[0].rsplit("*", 1)[0]
        dst_table_name = dst_table.split("@")[0].rsplit("*", 1)[0]

        oname_success[idx] = 1
        for src_name in matched_src_names:
            dst_name = src_name.replace(src_table_name, dst_table_name)
            if dst_name not in dst_names:
                raise_error(
                    src_table,
                    f"[sparse_oname] Bad oname, Dst name {dst_name} not found in dst_names",
                    ignore_error,
                )
                continue
            sparse_oname[src_name] = dst_name

    return sparse_oname
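
# Illustrative example (hypothetical table names): sparse names carry an
# "@slice" suffix, so a rule completed to {"user_emb@*": "item_emb@*"} maps
# every slice of the source table onto the same slice of the target table.
#
#   success = [0]
#   parse_sparse_oname(
#       [{"user_emb@*": "item_emb@*"}],
#       {"user_emb@0", "user_emb@1"},
#       {"item_emb@0", "item_emb@1"},
#       False,
#       success,
#   )
#   # -> {"user_emb@0": "item_emb@0", "user_emb@1": "item_emb@1"}; success == [1]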


def apply_oname_mapping(
    pattern_matcher: DensePatternMatcher, key: str, oname_rules: list
) -> Optional[str]:
    """
    mapping key from source to target (use cached PatternMatcher)

    Args:
        key: source key
        oname_rules: oname rules list, each rule is a dictionary {pattern: replacement}

    Returns:
        mapped key, if not mapped, return None
    """
    return pattern_matcher.apply_mapping(key, oname_rules)


def parse_dense_oname(
    pattern_matcher: DensePatternMatcher,
    oname: list,
    src_keys: set[str],
    dst_keys: set[str],
    ignore_error: bool,
    oname_success: list,
) -> dict:
    """
    mapping key from source to target model

    Optimize:
    - convert dst_keys to set, make lookup from O(n) to O(1)
    - use PatternMatcher to cache regex, avoid duplicate compilation

    Args:
        src_keys: source model state dict keys
        dst_keys: target model state dict keys
        oname: oname rules dict, format: {"oname": [{"pattern": "replacement"}, ...]}

    """
    dense_oname = {}
    oname_rules = oname
    dst_keys_set = set(dst_keys)

    for key in src_keys:
        if "@" in key:
            # Keys containing "@" are sparse table slices; those are handled
            # by parse_sparse_oname.
            continue
        mapped_key = None
        for idx, rule in enumerate(oname_rules):
            for pattern in rule.keys():
                if fnmatch.fnmatch(key, pattern):
                    candidate = apply_oname_mapping(pattern_matcher, key, [rule])
                    if candidate and candidate in dst_keys_set:  # O(1)
                        mapped_key = candidate
                        break
            if mapped_key:
                oname_success[idx] = 1
                break

        if not mapped_key:
            mapped_key = apply_oname_mapping(pattern_matcher, key, oname_rules)

        if mapped_key:
            if mapped_key in dst_keys_set:
                dense_oname[key] = mapped_key
                logger.warning(f"[dense_oname] T {key} <- {mapped_key} (from dst_sd)")
            else:
                raise_error(
                    key,
                    f"[dense_oname] F {key} -> {mapped_key} (not found in dst_sd)",
                    True,
                )

    return dense_oname
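
# Illustrative example (hypothetical key names): a single wildcard rule
# renames a dense key when, and only when, the candidate exists in dst_keys.
#
#   matcher = DensePatternMatcher()
#   success = [0]
#   parse_dense_oname(
#       matcher,
#       [{"encoder.*": "backbone.*"}],
#       {"encoder.layer.0.weight"},
#       {"backbone.layer.0.weight"},
#       False,
#       success,
#   )
#   # -> {"encoder.layer.0.weight": "backbone.layer.0.weight"}; success == [1]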


def parse_oname(
    dense_pattern_matcher: DensePatternMatcher,
    oname: list,
    src_sparse_names: set[str],
    dst_sparse_names: set[str],
    src_dense_names: set[str],
    dst_dense_names: set[str],
    ignore_error: bool,
):
    oname_success = [0 for _ in range(len(oname))]
    dense_oname = parse_dense_oname(
        dense_pattern_matcher,
        oname,
        src_dense_names,
        dst_dense_names,
        ignore_error,
        oname_success,
    )
    sparse_oname = parse_sparse_oname(
        oname,
        src_sparse_names,
        dst_sparse_names,
        ignore_error,
        oname_success,
    )

    for idx, success in enumerate(oname_success):
        if success == 0:
            raise_error(
                next(iter(oname[idx].keys())),
                f"Oname {oname[idx]} failed",
                ignore_error,
            )

    return dense_oname, sparse_oname


class ModelBankParser:
    def __init__(
        self,
        output_dir: str,
        model_bank_content: list[Dict[str, Any]],
        model_names: set[str],
        sparse_model_names: set[str],
        sparse_tables: set[str],
        dense_model_names: set[str],
        extra_fields,
    ):
        self._output_dir = output_dir
        self._model_bank_content = model_bank_content
        self._extra_fields = extra_fields
        self._original_model_names = deepcopy(model_names)
        self._original_dense_model_names = deepcopy(dense_model_names)
        self._original_sparse_model_names = deepcopy(sparse_model_names)
        self._original_sparse_tables = deepcopy(sparse_tables)
        self._dense_oname = {}
        self._sparse_oname = {}
        self._dense_pattern_matcher = DensePatternMatcher()
        self._reset_work_state()
        logger.warning("checking model bank...")
        self._is_model_bank_valid()

    def _reset_work_state(self):
        self._model_names = deepcopy(self._original_model_names)
        self._dense_model_names = deepcopy(self._original_dense_model_names)
        self._sparse_model_names = deepcopy(self._original_sparse_model_names)
        self._sparse_tables = deepcopy(self._original_sparse_tables)
        self._dense_oname = {}
        self._sparse_oname = {}

    def _is_load_valid(self):
        for bank in self._model_bank_content:
            for name in bank[MBC.LOAD]:
                if "*" not in name and name not in self._model_names:
                    raise_error(
                        name,
                        f"Variable {name} not found in model names",
                        False,
                    )

    def has_bank(self):
        return len(self._model_bank) > 0

    def _is_model_bank_valid(self):
        self._complete_model_bank()
        self._is_load_valid()
        self._model_bank = [
            ModelBankEntry.from_dict(bank)
            for bank in self._model_bank_content
            if not bank.get("skip", False)
        ]
        self._complete_sparse_name()
        self._replace_io_fields()

    def _replace_io_fields(self):
        for bank in self._model_bank:
            if self._extra_fields.io_state in bank.load:
                bank.load.discard(self._extra_fields.io_state)
                bank.load.update(self._extra_fields.get_io_fields())
            if self._extra_fields.io_state in bank.exclude:
                bank.exclude.discard(self._extra_fields.io_state)
                bank.exclude.update(self._extra_fields.get_io_fields())

    def _get_dst_names(self, path: str):
        """read index file, model file, extra file to get dst vars"""
        sparse_names = set()
        dense_names = set()
        extra_names = set()
        ckpt_path = get_update_path(path)
        if ckpt_path == "":
            raise RuntimeError(f"No update path found in {path}")
        logger.info(f"get ckpt names from ckpt_path: {ckpt_path}")
        fs = get_file_system(os.path.join(ckpt_path, "index"))
        reader = CheckpointReader(ckpt_path)
        sparse_names.update(reader.tensor_names())
        if fs.exists(os.path.join(ckpt_path, "model.pt")) or fs.exists(
            os.path.join(ckpt_path, "model.safetensors")
        ):
            data = load_pt_file(ckpt_path, "model")
            dense_names.update(data.keys())
        else:
            logger.warning(f"Dense model file not found in {ckpt_path}")
        if fs.exists(os.path.join(ckpt_path, "extra.pt")) or fs.exists(
            os.path.join(ckpt_path, "extra.safetensors")
        ):
            data = load_pt_file(ckpt_path, "extra")
            extra_names.update(data.keys())
            if self._extra_fields.prev_optim in data:
                extra_names.discard(self._extra_fields.prev_optim)
                extra_names.add(self._extra_fields.recis_dense_optim)
        else:
            logger.warning(f"Extra model file not found in {ckpt_path}")
        if fs.exists(os.path.join(ckpt_path, "io_state_0.pt")):
            extra_names.update(self._extra_fields.get_io_fields())
        return sparse_names, dense_names, extra_names

    def get_sparse_oname(self) -> dict:
        return self._sparse_oname

    def get_dense_oname(self) -> dict:
        return self._dense_oname

    def _check_dst_valid(
        self,
        name: str,
        bank_load: set[str],
        dst_names: set[str],
        sparse_oname: dict,
        dense_oname: dict,
        path: str,
        ignore_error: bool,
    ):
        cond_1 = name in dst_names
        cond_2 = sparse_oname.get(name, name) in dst_names
        cond_3 = dense_oname.get(name, name) in dst_names
        if not (cond_1 or cond_2 or cond_3):
            if name in bank_load:
                raise_error(
                    name,
                    f"No var {name} found in dst_names, ckpt path: {path}",
                    ignore_error,
                )
            else:
                raise_error(
                    name,
                    f"No var {name} found in dst_names, ckpt path: {path}",
                    True,
                )
        return cond_1 or cond_2 or cond_3

    def _get_names_set(self, names: Set[str]) -> set[str]:
        data = set()
        for name in names:
            data.update(get_match_by_pattern(name, self._model_names))
        return data

    def _add_dense_optim_names(self, names: set[str]):
        """
        if add dense modules, add recis.dense.optim to names automatically
        """
        has_dense_module = False
        for name in names:
            if name in self._dense_model_names:
                has_dense_module = True
                break
        if has_dense_module:
            names.add(self._extra_fields.recis_dense_optim)

    def _travel_model_bank_reversely(self, model_bank: list[ModelBankEntry]):
        var_dict = {}
        for bank in reversed(model_bank):
            if len(self._model_names) == 0:
                logger.warning("all variables are loaded, break parse model bank.")
                break
            path = bank.path
            dst_sparse_names, dst_dense_names, extra_names = self._get_dst_names(path)
            dst_names = dst_sparse_names | dst_dense_names | extra_names
            if len(dst_names) == 0:
                logger.warning(f"No dst vars found in ckpt: {path}")
                continue
            exclude_names_set = self._get_names_set(bank.exclude)
            load_names_set = self._get_names_set(bank.load)
            self._add_dense_optim_names(load_names_set)
            need_load_names = load_names_set - exclude_names_set
            if len(need_load_names) == 0:
                logger.warning(
                    f"No need to load vars in {path} because all vars are excluded"
                )
                continue
            # parse oname
            oname = bank.oname
            dense_oname, sparse_oname = parse_oname(
                self._dense_pattern_matcher,
                oname,
                {k for k in self._sparse_model_names if k in need_load_names},
                dst_sparse_names,
                {k for k in self._dense_model_names if k in need_load_names},
                dst_dense_names,
                bank.ignore_error,
            )
            for name in need_load_names:
                # check if the variable is in the ckpt list
                add_var = self._check_dst_valid(
                    name,
                    bank.load,
                    dst_names,
                    sparse_oname,
                    dense_oname,
                    path,
                    bank.ignore_error,
                )
                if add_var:
                    var_dict.setdefault(name, {}).update(
                        {
                            MBC.LOAD: path,
                            MBC.IS_DYNAMIC: bank.is_dynamic,
                            MBC.HASHTABLE_CLEAR: bank.hashtable_clear,
                            MBC.IGNORE_ERROR: bank.ignore_error,
                        }
                    )
                    self._model_names.discard(name)
            self._dense_oname.setdefault(path, {}).update(dense_oname)
            self._sparse_oname.setdefault(path, {}).update(sparse_oname)
        return var_dict

    def parse_all_model_bank(self):
        logger.info("parse all model bank")
        self._reset_work_state()
        return self._get_parse_result(self._model_bank)

    def parse_dynamic_model_bank(self):
        logger.info("parse dynamic model bank")
        self._reset_work_state()
        dynamic_model_bank = []
        for bank in self._model_bank:
            if bank.is_dynamic is True:
                dynamic_model_bank.append(bank)
        return self._get_parse_result(dynamic_model_bank)

    def _get_parse_result(self, model_bank: list[ModelBankEntry]):
        var_dict = self._travel_model_bank_reversely(model_bank)
        return self._combine_bank_by_path(var_dict)

    def _combine_bank_by_path(self, var_dict: dict):
        path_dict = {}
        for var in var_dict:
            path = var_dict[var][MBC.LOAD]
            if path not in path_dict:
                path_dict[path] = {}
            path_dict[path][var] = {
                MBC.IS_DYNAMIC: var_dict[var][MBC.IS_DYNAMIC],
                MBC.IGNORE_ERROR: var_dict[var][MBC.IGNORE_ERROR],
            }
            if var in self._sparse_model_names:
                path_dict[path][var][MBC.HASHTABLE_CLEAR] = var_dict[var][
                    MBC.HASHTABLE_CLEAR
                ]
            if var in self._dense_oname[path]:
                path_dict[path][var][MBC.ONAME] = self._dense_oname[path][var]
            if var in self._sparse_oname[path]:
                path_dict[path][var][MBC.ONAME] = self._sparse_oname[path][var]
        return path_dict

    def _complete_sparse_name(self):
        for bank in self._model_bank:
            remove_vars = set()
            added_vars = set()
            if bank.load:
                for var in bank.load:
                    if var in self._sparse_tables:
                        remove_vars.add(var)
                        added_vars.add(var + "*")
                for remove_var in remove_vars:
                    bank.load.discard(remove_var)
                for add_var in added_vars:
                    bank.load.add(add_var)
            onames = []
            for oname in bank.oname:
                src, dst = next(iter(oname.items()))
                if ("*" in src and "*" not in dst) or ("*" not in src and "*" in dst):
                    raise ValueError(
                        f"Bad oname, src {src} and dst {dst} must have the same number of *"
                    )
                if "*" not in src and src in self._sparse_tables:
                    onames.append({src + "@*": dst + "@*"})
                else:
                    onames.append(oname)
            bank.oname = onames

    def _complete_model_bank(self):
        path = get_update_path(self._output_dir, False)
        if path != "":
            self._model_bank_content.append(
                {
                    MBC.PATH: path,
                    MBC.LOAD: {"*"},
                    MBC.EXCLUDE: set(),
                    MBC.IS_DYNAMIC: False,
                    MBC.HASHTABLE_CLEAR: True,
                    MBC.IGNORE_ERROR: False,
                    MBC.ONAME: [],
                }
            )