Source code for recis.hooks.trace_to_odps_hook

import os
import time
from functools import wraps
from multiprocessing import Process, Queue
from typing import Dict, List, Union

import numpy as np
import pandas as pd
import pyarrow as pa
import torch

from recis.hooks import Hook
from recis.utils.logger import Logger


if os.environ.get("BUILD_DOCUMENT", None) != "1":
    from odps import ODPS
    from odps.models import Schema
    from odps.tunnel.io.writer import ArrowWriter
    from odps.tunnel.tabletunnel import TableTunnel


logger = Logger(__name__)

TRACE_MAP = {}

rank = int(os.environ.get("RANK", 0))


def retry(retry_count, interval):
    """Decorator for adding retry logic to functions.

    This decorator automatically retries a function if it raises an exception,
    with configurable retry count and interval between attempts.

    Args:
        retry_count (int): Maximum number of retry attempts.
        interval (float): Time interval (in seconds) between retry attempts.

    Returns:
        callable: Decorated function with retry logic.

    Example:
        >>> @retry(retry_count=3, interval=1.0)
        ... def unreliable_function():
        ...     # Function that might fail
        ...     pass
    """

    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            for i in range(retry_count):
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    if i == retry_count - 1:
                        raise e
                    time.sleep(interval)

        return wrapper

    return decorator
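

# --- Illustrative sketch (not part of the original module) ---
# A minimal demonstration of the retry semantics above, assuming nothing
# beyond this file: failed attempts sleep for `interval` seconds before the
# next try, and only the exception from the final attempt propagates. The
# helper names are hypothetical and exist only for illustration.
def _example_retry_usage():
    attempts = {"count": 0}

    @retry(retry_count=3, interval=0.1)
    def _flaky_upload():
        attempts["count"] += 1
        if attempts["count"] < 3:
            raise IOError("transient tunnel error")
        return "ok"

    # Succeeds on the third attempt; the two IOErrors are swallowed and retried.
    return _flaky_upload()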


def add_to_trace(name: str, tensor: Union[torch.Tensor, np.ndarray, list] = None):
    """Adds data to the trace map for ODPS logging.

    This function adds training data to the global trace map that will be
    uploaded to ODPS tables. Supports tensors, numpy arrays, and lists.

    Args:
        name (str): Name/key for the data being traced.
        tensor (Union[torch.Tensor, np.ndarray, list]): Data to be traced.
            Must be one of the supported types.

    Raises:
        ValueError: If the tensor type is not supported.

    Example:
        >>> import torch
        >>> import numpy as np
        >>> # Add tensor data
        >>> embeddings = torch.randn(100, 64)
        >>> add_to_trace("user_embeddings", embeddings)
        >>> # Add numpy array
        >>> features = np.random.rand(100, 32)
        >>> add_to_trace("item_features", features)
        >>> # Add list data
        >>> user_ids = [1, 2, 3, 4, 5]
        >>> add_to_trace("user_ids", user_ids)

    Note:
        Tensor data is automatically converted to numpy arrays for
        compatibility with ODPS. A warning is logged if data with the same
        name already exists.
    """
    if not isinstance(tensor, (torch.Tensor, np.ndarray, list)):
        raise ValueError(
            f"Trace data must be torch.Tensor, np.ndarray, or list, not {type(tensor)}"
        )
    global TRACE_MAP
    if name in TRACE_MAP:
        logger.warning(f"Trace data {name} already exists")
    if isinstance(tensor, torch.Tensor):
        TRACE_MAP[name] = tensor.detach().cpu().numpy()
    else:
        TRACE_MAP[name] = tensor

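
# --- Illustrative sketch (not part of the original module) ---
# Typical per-step usage of add_to_trace() inside a training loop, assuming
# only the helpers defined in this file. Tensors are detached and converted
# to numpy arrays; lists are stored as-is. Field names and shapes below are
# made up for illustration.
def _example_add_to_trace():
    user_emb = torch.randn(4, 8)  # e.g. a batch of user embeddings
    scores = torch.sigmoid(torch.randn(4))
    add_to_trace("user_embeddings", user_emb)       # stored as np.ndarray
    add_to_trace("scores", scores)                  # stored as np.ndarray
    add_to_trace("user_ids", [101, 102, 103, 104])  # lists are kept as lists
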
def get_trace_map():
    """Gets the global trace map containing data to be uploaded to ODPS.

    Returns:
        Dict: Global trace map containing key-value pairs of data to be traced.
    """
    global TRACE_MAP
    return TRACE_MAP


def clear_trace_map():
    """Clears the global trace map.

    This function is typically called after uploading data to ODPS to prepare
    for the next batch of trace data.
    """
    global TRACE_MAP
    TRACE_MAP = {}


def patch_flush(self):
    """Patches the flush method for optimized ODPS uploads.

    This function modifies the default flush behavior to support chunked
    uploads for large data transfers to ODPS.

    Args:
        self: The writer instance to patch.
    """
    checksum = self._crccrc.getvalue()
    self._write_unint32(checksum)
    self._crccrc.reset()
    chunk_size = 1 << 27

    def gen():
        # synchronize chunk upload
        data = self._out.getvalue()
        while data:
            to_send = data[:chunk_size]
            data = data[chunk_size:]
            yield to_send

    self._request_callback(gen())


class TraceWriter(Process):
    """Multiprocess writer for uploading trace data to ODPS tables.

    The TraceWriter runs as a separate process to handle ODPS uploads without
    blocking the main training process. It supports buffering, batching, and
    automatic retry mechanisms for reliable data transfer.

    Args:
        config (Dict): ODPS configuration containing access credentials and
            table info. Required keys: access_id, access_key, project,
            end_point, table_name. Optional keys: partition.
        fields (List[str]): List of field names for the ODPS table schema.
        types (List[str]): List of field types corresponding to the fields.
        writer_id (int): Unique identifier for this writer process.
        queue (Queue): Multiprocessing queue for receiving data to write.
        size_threshold (int): Buffer size threshold in bytes for triggering
            flushes. Defaults to 50 MiB.

    Attributes:
        table_name (str): Name of the ODPS table to write to.
        fields (List[str]): Field names for the table schema.
        types (List[str]): Field types for the table schema.
        partition (str): Partition specification for the table.
        write_count (int): Total number of rows written.
        buffer (List[Dict]): Internal buffer for batching data.
        buffered_size (int): Current size of buffered data in bytes.

    Example:
        >>> config = {
        ...     "access_id": "your_access_id",
        ...     "access_key": "your_access_key",
        ...     "project": "your_project",
        ...     "end_point": "your_endpoint",
        ...     "table_name": "training_traces",
        ...     "partition": "dt=20231201",
        ... }
        >>> fields = ["user_id", "item_id", "score"]
        >>> types = ["bigint", "bigint", "double"]
        >>> queue = Queue()
        >>> writer = TraceWriter(config, fields, types, 0, queue)
        >>> writer.start()
    """

    def __init__(
        self,
        config: Dict,
        fields: List[str],
        types: List[str],
        writer_id: int,
        queue: Queue,
        size_threshold: int = 50 * 1024 * 1024,  # 50 MiB
    ):
        super().__init__()
        self._block_id = 0
        odps = ODPS(
            config["access_id"],
            config["access_key"],
            config["project"],
            config["end_point"],
        )
        self.table_name = config["table_name"]
        self.fields = fields
        self.types = types
        self.partition = config.get("partition", None)
        partitions = []
        part_types = []
        for s in self.partition.split(","):
            partitions.append(s.split("=")[0])
            part_types.append("string")
        table = odps.create_table(
            self.table_name,
            schema=Schema.from_lists(fields, types, partitions, part_types),
            if_not_exists=True,
            lifecycle=365,
            table_properties={"columnar.nested.type": "true"},
        )
        table.create_partition(self.partition, if_not_exists=True)
        self._tunnel_client = TableTunnel(odps)
        self._writer_session = self._tunnel_client.create_upload_session(
            table.name, partition_spec=self.partition
        )
        self.write_count = 0
        self.write_id = writer_id
        self._block_id = 0
        self.daemon = True
        self.queue = queue
        # Buffering
        self.buffer = []  # list of dicts
        self.buffered_size = 0
        self.size_threshold = size_threshold

    def run(self) -> None:
        """Main process loop for handling data writes.

        Continuously processes data from the queue until a None sentinel value
        is received, indicating shutdown.
        """
        while True:
            data = self.queue.get()
            if data is None:
                self.flush(force=True)
                break
            self.write(data)

    @retry(retry_count=3, interval=10)
    def write(self, data: Dict[str, np.ndarray]):
        """Writes data to the internal buffer and flushes when the threshold is reached.

        Args:
            data (Dict[str, np.ndarray]): Dictionary mapping field names to
                data arrays.

        Note:
            Data is automatically converted to lists for consistency and
            buffered until the size threshold is reached, at which point it is
            flushed to ODPS.
        """
        # Convert all values to lists for consistency
        for key in data.keys():
            if not isinstance(data[key], list):
                data[key] = data[key].tolist()
        self.buffer.append(data)
        # Estimate size using DataFrame
        df = pd.DataFrame(data)
        self.buffered_size += df.memory_usage(deep=True).sum()
        # Check threshold
        if self.buffered_size >= self.size_threshold:
            self._flush_buffer()

    def _flush_buffer(self):
        """Flushes the internal buffer to ODPS.

        Merges all buffered data into a single batch and uploads it to ODPS
        using Arrow format for efficient transfer.
        """
        if not self.buffer:
            return
        # Concatenate all dicts in buffer
        merged = {k: [] for k in self.buffer[0].keys()}
        for d in self.buffer:
            for k in merged:
                merged[k].extend(d[k])
        ArrowWriter._flush = patch_flush
        df = pd.DataFrame(merged)
        row_num = len(df)
        write_data = pa.RecordBatch.from_pandas(df)
        writer = self._writer_session.open_arrow_writer(self._block_id)
        writer.write(write_data)
        writer.close()
        self.write_count += row_num
        self._block_id += 1
        self.buffer = []
        self.buffered_size = 0
        # flush ODPS every 2 blocks
        if self._block_id == 2:
            self.flush()

    @retry(retry_count=3, interval=10)
    def flush(self, force=False):
        """Flushes data to ODPS and commits the upload session.

        Args:
            force (bool): If True, forces flushing of any remaining buffer
                data. Defaults to False.

        Note:
            This method commits the current upload session and creates a new
            one for subsequent writes. It is called automatically every 2
            blocks or when forced during shutdown.
        """
        # Flush any remaining buffer first
        if force:
            self._flush_buffer()
        # update writer_session
        if self._block_id > 0:
            self._writer_session.commit(list(range(self._block_id)))
            self._writer_session = self._tunnel_client.create_upload_session(
                self.table_name, partition_spec=self.partition
            )
            self._block_id = 0

    def __del__(self):
        """Cleanup method called when the writer process is destroyed.

        Logs the total write count and ensures any remaining data is flushed.
        """
        logger.info(
            f"[rank-{rank}] [writer-{self.write_id}] write_count = {self.write_count}"
        )
        if self._block_id > 0 or self.buffer:
            self.flush(force=True)

class TraceToOdpsHook(Hook):
    """Hook for tracing training data to ODPS tables.

    The TraceToOdpsHook provides high-performance data collection and upload
    capabilities for training traces. It uses multiprocessing to avoid
    blocking the main training process and supports configurable batching and
    buffering.

    Args:
        config (Dict): ODPS configuration dictionary containing connection
            details. Required keys: access_id, access_key, project, end_point,
            table_name. Optional keys: partition.
        fields (List[str]): List of field names for the ODPS table schema.
        types (List[str]): List of field types corresponding to the fields.
        worker_num (int): Number of worker processes for parallel uploads.
            Defaults to 1.
        size_threshold (int): Buffer size threshold in bytes for triggering
            flushes. Defaults to 50 MiB.

    Attributes:
        queue (Queue): Multiprocessing queue for data transfer.
        writer_num (int): Number of writer processes.
        writers (List[TraceWriter]): List of writer process instances.

    Example:
        >>> from recis.hooks import TraceToOdpsHook, add_to_trace
        >>> # Configure ODPS connection
        >>> config = {
        ...     "access_id": "your_access_id",
        ...     "access_key": "your_access_key",
        ...     "project": "your_project",
        ...     "end_point": "your_endpoint",
        ...     "table_name": "training_traces",
        ...     "partition": "dt=20231201",
        ... }
        >>> # Define table schema
        >>> fields = ["user_id", "item_id", "embedding", "score"]
        >>> types = ["bigint", "bigint", "string", "double"]
        >>> # Create hook
        >>> odps_hook = TraceToOdpsHook(
        ...     config=config, fields=fields, types=types, worker_num=2
        ... )
        >>> trainer.add_hook(odps_hook)
        >>> # During training, add data to be traced
        >>> add_to_trace("user_embeddings", user_embeddings)
        >>> add_to_trace("item_scores", item_scores)
        >>> # The hook will automatically upload data after each step

    Note:
        This hook is only available in internal environments where ODPS access
        is configured. Use add_to_trace() to add data that should be uploaded
        to ODPS tables.
    """

    def __init__(
        self,
        config: Dict,
        fields: List[str],
        types: List[str],
        worker_num: int = 1,
        size_threshold: int = 50 * 1024 * 1024,
    ) -> None:
        super().__init__()
        self.queue = Queue(maxsize=worker_num)
        self.writer_num = worker_num
        self.writers = []
        for i in range(self.writer_num):
            self.writers.append(
                TraceWriter(
                    config, fields, types, i, self.queue, size_threshold=size_threshold
                )
            )
        for writer in self.writers:
            writer.start()

    def check_alive(self):
        """Checks whether all writer processes are still alive.

        Returns:
            bool: True if all writer processes are alive, False otherwise.
        """
        alive = True
        for writer in self.writers:
            if not writer.is_alive():
                alive = False
        return alive

    def after_step(self, is_train=True, *args, **kw):
        """Called after each training step to upload accumulated trace data.

        This method retrieves all data from the trace map and sends it to the
        writer processes for upload to ODPS. After sending, the trace map is
        cleared to prepare for the next step.

        Args:
            is_train (bool): Whether the step was a training step. Defaults to True.
            *args: Variable length argument list (unused).
            **kw: Arbitrary keyword arguments (unused).

        Raises:
            ValueError: If any writer subprocess has encountered an error.

        Note:
            The method checks that all writer processes are still alive before
            sending data. If any process has died, an error is raised.
        """
        if not self.check_alive():
            raise ValueError("TraceToOdpsHook sub-process raised an error")
        data = get_trace_map()
        self.queue.put(data)
        clear_trace_map()

    def end(self, is_train=True, *args, **kwargs):
        """Called at the end of training to properly shut down writer processes.

        This method sends shutdown signals to all writer processes and waits
        for them to complete their work and terminate gracefully.
        """
        for writer in self.writers:
            self.queue.put(None)
        for writer in self.writers:
            writer.join()
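
# --- Illustrative sketch (not part of the original module) ---
# Driving the hook lifecycle by hand, without a Trainer, assuming valid ODPS
# credentials; all config values and field names below are placeholders. In
# real training, trainer.add_hook(hook) invokes after_step() and end()
# automatically, as shown in the class docstring.
def _example_hook_lifecycle():
    config = {
        "access_id": "your_access_id",
        "access_key": "your_access_key",
        "project": "your_project",
        "end_point": "your_endpoint",
        "table_name": "training_traces",
        "partition": "dt=20231201",
    }
    hook = TraceToOdpsHook(
        config, fields=["user_ids", "scores"], types=["bigint", "double"], worker_num=1
    )
    for _ in range(10):  # stand-in for the training loop
        add_to_trace("user_ids", [1, 2, 3])
        add_to_trace("scores", [0.1, 0.2, 0.3])
        hook.after_step()  # ship the current trace map to the writer processes
    hook.end()  # send shutdown sentinels and join the writers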