数据处理流程

简单示例

def get_odps_dataset():
    """Build and return an ``OdpsDataset`` set up for the current worker.

    The worker index and worker count are taken from the ``RANK`` and
    ``WORLD_SIZE`` environment variables, falling back to 0 and 1 when
    they are unset. After construction, one table path is added, then
    ``item_id`` and ``user_id`` are registered through
    ``varlen_feature`` and ``rate`` and ``label`` through
    ``fixedlen_feature``.

    Returns:
        The configured ``OdpsDataset`` instance.
    """
    env = os.environ
    rank = int(env.get("RANK", 0))
    world_size = int(env.get("WORLD_SIZE", 1))

    # A single transform that returns the first element of its argument.
    transforms = [lambda x: x[0]]

    # Define the dataset.
    options = {
        "batch_size": 1024,
        "worker_idx": rank,
        "worker_num": world_size,
        "read_threads_num": 2,
        "prefetch": 1,
        "is_compressed": False,
        "drop_remainder": True,
        "transform_fn": transforms,
        "dtype": torch.float32,
        "device": "cuda",
        "save_interval": 100,
    }
    dataset = OdpsDataset(**options)

    # Add the input path.
    dataset.add_path("odps://xxx/tables/xxx/ds=xxx")

    # Add the features, in the same order as they are declared here.
    for feature_name in ("item_id", "user_id"):
        dataset.varlen_feature(name=feature_name, trans_int8=True)
    for feature_name in ("rate", "label"):
        # A fresh default list is created for every registration.
        dataset.fixedlen_feature(name=feature_name, default_value=[0.0])

    return dataset