温馨提示: 豌豆仅提供国内节点,不提供境外节点,不能用于任何非法用途,不能访问境外网站及跨境联网。

免费领取1万IP!

时间序列预测(一)—— 数据预处理

发布时间:

时间序列预测(一)—— 数据预处理

  最近在做时间序列的预测问题,这里就稍微总结回顾一下,便于以后查阅,也希望能给大家提供到帮助,有什么问题欢迎多多交流。
  这是一个系列的文章,主要从代码的角度分析问题(争取做到代码片段的随用随取),不涉及太多的模型原理(我会尽可能讲一下自己的理解),本系列文章包含了数据预处理和基本时间序列分析预测模型:

(一)数据预处理

(二)AR模型(自回归模型)

(三)Xgboost模型

(四)LSTM模型

(五)Prophet模型(自回归模型)


数据预处理(pre-processing)

数据预处理在数据分析中占据了重要的地位,这里主要介绍几种常见的预处理方法

1. 归一化(反归一化)

  归一化可以说是数据预处理里面最常用的方法之一了,在模型训练中不同的数据取值范围假如相差过大很容易造成模型错误的分配权重,所以很多时候归一化必不可少。

def normalize(data, method="MinMax", feature_range=(0, 1)):
    """
    normalize the data
    :param data: list of data
    :param method: support MinMax scaler or Z-Score scaler
    :param feature_range: use in MinMax scaler
    :return: normalized data(list), scaler
    """
    data = np.array(data)
    if len(data.shape) == 1 or data.shape[1] != 1:
        # reshape(-1, 1) --> reshape to a one column n rows matrix(-1 means not sure how many row)
        data = data.reshape(-1, 1)
    if method == "MinMax":
        scaler = MinMaxScaler(feature_range=feature_range)
    elif method == "Z-Score":
        scaler = StandardScaler()
    else:
        raise ValueError("only support MinMax scaler and Z-Score scaler")
    scaler.fit(data)
    # scaler transform apply to each column respectively
    # (which means that if we want to transform a 1-D data, we must reshape it to n x 1 matrix)
    return scaler.transform(data).reshape(-1), scaler


def denormalize(data, scaler):
    """
    denormalize data by scaler
    :param data:
    :param scaler:
    :return: denormalized data
    """
    data = np.array(data)
    if len(data.shape) == 1 or data.shape[1] != 1:
        data = data.reshape(-1, 1)
    # max, min, mean, variance are all store in scaler, so we need it to perform inverse transform
    return scaler.inverse_transform(data).reshape(-1)

2.重采样

  有时候原始数据的采样频率可能太高,导致噪声比较大,重采样可以在一定程度上降低噪声,同时数据量较大的时候还可以起到减小数据量提高模型的迭代速度。

def resample(data, period="W"):
    """
    resample the original data to reduce noise
    :param data:
    :param period: the period of data e.g. B - business day, D - calendar day, W - weekly, Y - yearly etc.
    (reference: pandas DateOffset Objects'http://pandas.pydata.org/pandas-docs/stable/user_guide/timeseries.html')
    :return:
    """
    data = data.set_index(pd.DatetimeIndex(data['ds']))
    return data.resample(period, label="right").mean().reset_index()

3.数据划分

  用于将原始数据划分为训练机,测试集和验证集,可以自行调整分配权重。

def split_data(observed_data, split_ratio=(8, 2, 1)):
    """
    split the observed data into train-evaluation-test three part
    :param observed_data:
    :param split_ratio: relative proportion among train,evaluation,test
    :return: train, evaluation, test data
    """
    total = split_ratio[0] + split_ratio[1] + split_ratio[2]
    length = len(observed_data)
    train_cnt = int((split_ratio[0] / total) * length)
    test_cnt = int((split_ratio[2] / total) * length)
    return observed_data[:train_cnt], observed_data[train_cnt:-test_cnt], observed_data[-test_cnt:]

4.各种滤波(中值,均值,巴特沃斯)

  这个主要用于信号处理,一般的数据分析可能用的不多。

def median_filter(datas, length=3):
    """
    median filter, length must be odd number
    :param datas:
    :param length:
    :return:
    """
    return signal.medfilt(datas, length).tolist()


def average_filter(datas, length=3):
    """
    average filter, length should not greater than 5
    :param datas:
    :param length:
    :return:
    """
    if isinstance(datas, np.ndarray):
        datas = datas.tolist()
    updated = []
    if length == 2:
        for d1, d2 in zip(datas[:-1], datas[1:]):
            updated.append(np.average([d1, d2]))
    elif length == 3:
        for d1, d2, d3 in zip(datas[:-2], datas[1:-1], datas[2:]):
            updated.append(np.average([d1, d2, d3]))
    elif length == 4:
        for d1, d2, d3, d4 in zip(datas[:-3], datas[1:-2], datas[2:-1], datas[3:]):
            updated.append(np.average([d1, d2, d3, d4]))
    else:
        for d1, d2, d3, d4, d5 in zip(datas[:-4], datas[1:-3], datas[2:-2], datas[3:-1], datas[4:]):
            updated.append(np.average([d1, d2, d3, d4, d5]))
    return updated


# noinspection PyTupleAssignmentBalance
def butter_filter(datas, hc=0.1):
    b, a = signal.butter(5, hc, btype="low")
    return signal.filtfilt(b, a, datas)

5.数据可视化

  数据可视化也是预处理中很重要的一个手段,可以用于分析数据的各项属性。

  • 折线图:用于分析数据的变化趋势
  • 箱型图:用于分析数据异常值
  • 正态分布图:用于分析数据是否符合正态分布
  • 散点图:用于分析数据的分布情况(线性度,聚合度等)
  • 热力图:用于分析多组数据间的相关系数(注意,一定是多组数据)
  • 柱状图:用于分析数据的相对大小/占比
  • 饼状图:用于分析数据的占比
def plot_data(y_val, x_val=None, legend="", x_label="", y_label="", title="", file_name=""):
    if x_val is None:
        x_val = range(len(y_val))
    plt.figure()
    plt.plot(x_val, y_val, label=legend)
    if legend:
        # show legend(line label)
        plt.legend()
    # show x label
    plt.xlabel(x_label)
    # show y label
    plt.ylabel(y_label)
    # show title
    plt.title(title)
    if file_name:
        plt.savefig('./result/%s.png' % file_name, bbox_inches='tight')
    plt.tight_layout()
    plt.show()


def box_plot(data, file_name=""):
    """
    box plot —— use to see the distribution of the datas
    :param data: list of data
    :param file_name: file name of the plot figure
    :return: None
    """
    plt.boxplot(np.array(data), sym="o")
    if file_name:
        plt.savefig('./result/%s.png' % file_name, bbox_inches='tight')
    plt.show()


def distribution_plot(data, file_name=""):
    """
    distribution plot —— use to see the distribution of the data
    :param data: list of data
    :param file_name: file name of the plot figure
    :return: None
    """
    plt.figure(figsize=(8, 5))
    sns.set_style('whitegrid')
    sns.distplot(np.array(data), rug=True, color='b')
    plt.title("distribution")
    if file_name:
        plt.savefig("./result/%s.png" % file_name, bbox_inches='tight')
    plt.show()


def scatter_plot(y_val, x_val=None, x_label="", y_label="", title="", file_name=""):
    """
    scatter plot —— use to see the distribution of the data
    :param y_val: y value
    :param x_val: x value, if None will use 0 ~ range(len(y_val))
    :param x_label:
    :param y_label:
    :param title:
    :param file_name: file name of the plot figure
    :return: None
    """
    if x_val is None:
        x_val = range(len(y_val))
    plt.scatter(x_val, y_val, marker='o', color='black', s=10)
    plt.xlabel(x_label)
    plt.ylabel(y_label)
    plt.title(title)
    if file_name:
        plt.savefig('./result/%s.png' % file_name, bbox_inches='tight')
    plt.show()


def heatmap(data, file_name="", method="pearson"):
    """
    draw the heat map of sets of data
    :param data: DataFrame format
    :param file_name:
    :param method: method used to calculate the correlation
                   pearson - range -1 ~ 1(only two variable in perfect linear relation, it will be ±1)
                   spearman - range -1 ~ 1(it will be ±1 when
                   the relation between two variable can be described by a monotonic function)
    reference
    < 0.1 : no relation
    0.10 ~ 0.29: weak relation
    0.30 ~ 0.49: medium relation
    > 0.5: strong relation
    :return: None
    """
    sns.heatmap(
        data.corr(method=method),
        xticklabels=data.corr(method=method).columns,
        yticklabels=data.corr(method=method).columns,
        annot=True, annot_kws={'weight': 'bold'},
        vmin=-0.5, vmax=1, cmap="YlGnBu"
    )
    plt.tight_layout()
    if file_name:
        file_name = file_name.split(".")[0]
        plt.savefig("./result/heatmap/%s.png" % file_name)
    plt.show()
    plt.close()


def bar_plot(y_val, x_val=None, x_label=None, horizontal=False, file_name=""):
    if x_val is None:
        x_val = range(len(y_val))
    if x_label is None:
        x_label = x_val
    if horizontal:
        plt.barh(y=x_val, width=y_val, tick_label=x_label)
        for a, b in zip(y_val, x_val):
            plt.text(a + 0.01, b, '%.0f' % a, ha='left', va='center', fontsize=11)
    else:
        plt.bar(x=x_val, height=y_val, tick_label=x_label)
        for a, b in zip(x_val, y_val):
            plt.text(a, b + 0.01, '%.0f' % b, ha='center', va='bottom', fontsize=11)
    plt.tight_layout()
    plt.savefig('./result/%s.png' % file_name, bbox_inches='tight')
    # if file_name:
    #     plt.savefig('./result/%s.png' % file_name, bbox_inches='tight')
    plt.show()


def pie_plot(data, labels=None):
    """
    draw data in pie figure
    :param data:list(pure data) or dict(use keys as labels and values as data)
    :param labels: None if data is a dict
    :return:
    """
    X = []
    if isinstance(data, dict):
        labels = []
        for k in data.keys():
            labels.append(k)
            X.append(data[k])

    if labels is None:
        raise ValueError("labels should be specify")

    plt.pie(X, labels=labels, autopct='%1.2f%%')  # 画饼图(数据,数据对应的标签,百分数保留两位小数点)
    plt.title("Pie chart")

    plt.show()

模型结果衡量

  当有多个模型的时候,我们需要有一个指标衡量模型建表现的好坏,针对时间序列(连续),可以选取均方误差(mse)和均方根误差(rmse)。

def rmse(predictions, targets):
    """
    root-mean-square error
    :param predictions:
    :param targets:
    :return:
    """
    predictions = np.array(predictions)
    targets = np.array(targets)
    return np.sqrt(((predictions - targets) ** 2).mean())


def mse(predictions, targets):
    """
    mean-square error
    :param predictions:
    :param targets:
    :return:
    """
    predictions = np.array(predictions)
    targets = np.array(targets)
    return ((predictions - targets) ** 2).mean()

相关文章


解决Python模块报错:ModuleNotFoundError: No module name ' 时间序列预测(五)—— Prophet模型 Java代码 - 基于广度优先遍历实现迷宫路径搜索(搜索迷宫的最短路径信息) 便宜的ssl数字证书,可靠的ssl证书,低价的ssl证书原因 Flutter——UI布局 《Spark The Definitive Guide》Chapter 4:结构化API预览 使用expect实现自动2步登录 demo10 关于JS Tree Shaking

上一篇:透视相机(PerspectiveCamera)
下一篇:xcodebuild自动打包,发布应用

咨询·合作