iobjectspy.ml.spacetime package

Module contents

class iobjectspy.ml.spacetime.DataPreparation

基类:object

图时空预测数据准备流程入口

static create_adj_mx(input_pts_coords, output_dir, id_col='0', long_col='1', lat_col='2', dist_file=None)

邻接矩阵生成

参数:
  • input_pts_coords (str) -- 输入点数据的坐标文件
  • output_dir (str) -- 生成邻接矩阵的输出目录
  • id_col (id序号的列序号,默认为"0") -- id序号的列序号,默认为"0"
  • long_col (str) -- 经度列的列序号,默认为"1"
  • lat_col (str) -- 纬度列的列序号,默认为"2"
  • dist_file (str) -- 给定的距离信息文件路径,包含每条结点对 (from结点和to结点)的id,以及节点对之间的距离
static create_training_data(input_data, output_dir=None, train_rate=0.7, test_rate=0.2, index_col='0', period_len=3, step_rows=12, period_steps=24, period_units='D', add_time_in_period=True)

训练数据生成

参数:
  • input_data (str) -- 输入的原始表格数据路径,支持csv格式。表格中每列表示各位置点的数据,每行表示一个时间时刻。
  • output_dir (str) -- 生成训练集,测试集等数据的输出目录。
  • train_rate (float) -- 用于生成训练集的比例,默认0.7
  • test_rate (float) -- 用于生成训练集的比例,默认0.2
  • index_col (str or int) -- 时间索引列的序号,默认为"0"
  • period_len (int) -- 时间周期特征的长度,默认为3
  • step_rows (int) -- 一个时间步长包括的行数,默认为12
  • period_steps (int) -- 一个时间周期包括的时间步长,默认为24
  • period_units (str) -- 时间周期的单位,默认为"D",表示天
  • add_time_in_period -- 是否增加周期性特征,默认为True
  • add_time_in_period -- bool
class iobjectspy.ml.spacetime.Trainer(train_data_path, config, epochs=5, batch_size=1, lr=0.01, output_model_path=None, checkpoint_filename=None, **kwargs)

基类:object

图时空深度学习训练功能入口

参数:
  • train_data_path (str) -- 训练数据路径
  • config (str) -- 配置文件路径
graphst_regression_train()

图时空深度学习训练功能

生成模型将存储在输入的 ‘output_model_path’ 路径下

返回:None
class iobjectspy.ml.spacetime.Inference(input_data_dir, model_path, out_data, location_data_path=None, out_dataset_name='graph_st_predictions', add_index_before=False, fields_as_point=['longitude', 'latitude'])

基类:object

图时空回归模型推理功能入口

参数:
  • input_data_dir (str) -- 待推理的数据所在目录
  • model_path (str) -- 模型存储路径
  • out_data (str) -- 输出文件路径
  • add_index_before (bool) -- 当生成预测结果矢量文件时,是否要自动增加索引字段在location_data表格第一列.默认为False,表示location_data的第一列已有索引信息。
graph_st_regress_infer(**kwargs)

基于图时空回归的交通时空预测

输入和输出文件为numpy二进制序列化文件( *.npz )

参数:result_type (List) -- 结果返回类型
返回:若提供location_data_path,返回矢量数据集的预测结果,否则返回预测结果与GroundTruth的列表数据