import torch
class CheckpointReader:
    """Read-only view over a checkpoint and its tensor metadata.

    Thin Python wrapper around the native
    ``torch.classes.recis.CheckpointReader`` object. Every method delegates
    directly to that backend object, which is stored on ``self.reader``.

    Examples:
        Typical usage for inspecting and reading checkpoint contents:

        >>> reader = CheckpointReader("/path/to/checkpoint")
        >>> tensor_names = reader.tensor_names()
        >>> for name in tensor_names:
        ...     shape = reader.tensor_shape(name)
        ...     dtype = reader.tensor_dtype(name)
        ...     tensor_data = reader.read_tensor(name)

    Attributes:
        reader: The underlying ``torch.classes.recis.CheckpointReader``
            instance that performs the actual checkpoint access.
    """

    def __init__(self, path: str) -> None:
        """Create the native reader for ``path`` and initialize it.

        Args:
            path: Location of the checkpoint to read (described upstream as
                the directory containing the checkpoint files).

        Note:
            Construction calls ``init()`` on the native reader right away, so
            any work the backend does at initialization (presumably loading
            metadata -- confirm against the native implementation) happens
            here rather than on first read.
        """
        # Requires the RecIS native extension to have registered
        # ``torch.classes.recis`` before this class is instantiated.
        self.reader = torch.classes.recis.CheckpointReader(path)
        self.reader.init()

    def tensor_names(self):
        """Return the names of all tensors available in the checkpoint.

        Returns:
            Whatever the backend's ``list_tensor_names()`` returns
            (documented upstream as a list of string identifiers).
        """
        return self.reader.list_tensor_names()

    def read_tensor(self, name: str):
        """Read and return the data stored under ``name``.

        Args:
            name: Identifier of the tensor to read.

        Returns:
            The value produced by the backend's ``read_tensor(name)``.

        Raises:
            Propagates whatever the backend raises for an unknown ``name``.
            NOTE(review): upstream docs say ``KeyError``; errors from a
            TorchScript-bound class usually surface as ``RuntimeError`` --
            confirm before relying on the exception type.
        """
        return self.reader.read_tensor(name)

    def tensor_shape(self, name: str):
        """Return the shape of the tensor stored under ``name``.

        Args:
            name: Identifier of the tensor.

        Returns:
            The shape exactly as reported by the backend's
            ``tensor_shape(name)``; no conversion is applied here.
            NOTE(review): upstream docs say "tuple" -- the actual container
            type depends on the native binding; confirm.

        Raises:
            Propagates whatever the backend raises for an unknown ``name``
            (upstream docs say ``KeyError`` -- unverified, see
            ``read_tensor``-style caveat about ``RuntimeError``).
        """
        return self.reader.tensor_shape(name)

    def tensor_dtype(self, name: str):
        """Return the element data type of the tensor stored under ``name``.

        Args:
            name: Identifier of the tensor.

        Returns:
            The ``dtype`` attribute of the object returned by the backend's
            ``tensor_dtype(name)``.

        Raises:
            Propagates whatever the backend raises for an unknown ``name``
            (upstream docs say ``KeyError`` -- unverified; a TorchScript-bound
            class typically raises ``RuntimeError``).
        """
        # The backend hands back a carrier object (presumably a tensor --
        # confirm); only its ``dtype`` is exposed to callers.
        return self.reader.tensor_dtype(name).dtype