

 从补丁 198 开始，Amazon Redshift 将不再支持创建新的 Python UDF。现有的 Python UDF 将继续正常运行至 2026 年 6 月 30 日。有关更多信息，请参阅[博客文章](https://aws.amazon.com/blogs/big-data/amazon-redshift-python-user-defined-functions-will-reach-end-of-support-after-june-30-2026/)。

# 教程：构建回归模型
<a name="tutorial_regression"></a>

在本教程中，您使用 Amazon Redshift ML 创建机器学习回归模型，并对该模型运行预测查询。回归模型允许您预测数值结果，例如房屋价格，或有多少人将使用城市的自行车租赁服务。您可以在 Amazon Redshift 中使用 CREATE MODEL 命令来处理您的训练数据。然后，Amazon Redshift ML 编译模型，将经过训练的模型导入到 Redshift 中，并准备一个 SQL 预测函数。您可以在 Amazon Redshift 的 SQL 查询中使用预测函数。

在本教程中，您将使用 Amazon Redshift ML 构建回归模型，该模型可预测在一天中任何给定小时内使用多伦多市自行车共享服务的人数。模型的输入包括节假日和天气状况。您将使用回归模型，因为您需要此问题的数值结果。

您可以使用 CREATE MODEL 命令导出训练数据，训练模型，并使模型在 Amazon Redshift 中可用作 SQL 函数。使用 CREATE MODEL 操作将训练数据指定为表或 SELECT 语句。

## 使用案例示例
<a name="tutorial_regression_tasks"></a>

您可以使用 Amazon Redshift ML 解决其他回归问题，例如预测客户的生命周期价值。还可以使用 Redshift ML 来预测商品的最能赢利的价格和相应的收入。

**任务**：
+ 先决条件
+ 步骤 1：将数据从 Amazon S3 加载到 Amazon Redshift
+ 步骤 2：创建机器学习模型
+ 步骤 3：验证模型

## 先决条件
<a name="tutorial_regression_prereqs"></a>

要完成此教程，必须完成 Amazon Redshift ML 的[管理设置](https://docs.aws.amazon.com/redshift/latest/dg/admin-setup.html)。

## 步骤 1：将数据从 Amazon S3 加载到 Amazon Redshift
<a name="tutorial_regression_step_load"></a>

使用 [Amazon Redshift 查询器 v2](https://docs.aws.amazon.com/redshift/latest/mgmt/query-editor-v2-using.html) 运行以下查询。

1. 您必须创建三个表才能将三个公有数据集加载到 Amazon Redshift 中。这些数据集是[多伦多自行车骑行数据](https://open.toronto.ca/dataset/bike-share-toronto-ridership-data/)、[历史天气数据](https://climate.weather.gc.ca/historical_data/search_historic_data_e.html)和[历史节假日数据](https://github.com/uWaterloo/Datasets/blob/master/Holidays/holidays.csv)。在 Amazon Redshift 查询编辑器中运行以下查询，以便创建名为 `ridership`、`weather` 和 `holiday` 的表。

   ```
   CREATE TABLE IF NOT EXISTS ridership (
       trip_id INT,
       trip_duration_seconds INT,
       trip_start_time timestamp,
       trip_stop_time timestamp,
       from_station_name VARCHAR(50),
       to_station_name VARCHAR(50),
       from_station_id SMALLINT,
       to_station_id SMALLINT,
       user_type VARCHAR(20)
   );
   
   CREATE TABLE IF NOT EXISTS weather (
       longitude_x DECIMAL(5, 2),
       latitude_y DECIMAL(5, 2),
       station_name VARCHAR(20),
       climate_id BIGINT,
       datetime_utc TIMESTAMP,
       weather_year SMALLINT,
       weather_month SMALLINT,
       weather_day SMALLINT,
       time_utc VARCHAR(5),
       temp_c DECIMAL(5, 2),
       temp_flag VARCHAR(1),
       dew_point_temp_c DECIMAL(5, 2),
       dew_point_temp_flag VARCHAR(1),
       rel_hum SMALLINT,
       rel_hum_flag VARCHAR(1),
       precip_amount_mm DECIMAL(5, 2),
       precip_amount_flag VARCHAR(1),
       wind_dir_10s_deg VARCHAR(10),
       wind_dir_flag VARCHAR(1),
       wind_spd_kmh VARCHAR(10),
       wind_spd_flag VARCHAR(1),
       visibility_km VARCHAR(10),
       visibility_flag VARCHAR(1),
       stn_press_kpa DECIMAL(5, 2),
       stn_press_flag VARCHAR(1),
       hmdx SMALLINT,
       hmdx_flag VARCHAR(1),
       wind_chill VARCHAR(10),
       wind_chill_flag VARCHAR(1),
       weather VARCHAR(10)
   );
   
   CREATE TABLE IF NOT EXISTS holiday (holiday_date DATE, description VARCHAR(100));
   ```

1. 以下查询将示例数据加载到您在上一步中创建的各表中。

   ```
   COPY ridership
   FROM
       's3://redshift-ml-bikesharing-data/bike-sharing-data/ridership/' 
       IAM_ROLE default 
       FORMAT CSV 
       IGNOREHEADER 1 
       DATEFORMAT 'auto' 
       TIMEFORMAT 'auto' 
       REGION 'us-west-2' 
       gzip;
   
   COPY weather
   FROM
       's3://redshift-ml-bikesharing-data/bike-sharing-data/weather/' 
       IAM_ROLE default 
       FORMAT csv 
       IGNOREHEADER 1 
       DATEFORMAT 'auto' 
       TIMEFORMAT 'auto' 
       REGION 'us-west-2' 
       gzip;
   
   COPY holiday
   FROM
       's3://redshift-ml-bikesharing-data/bike-sharing-data/holiday/' 
       IAM_ROLE default 
       FORMAT csv 
       IGNOREHEADER 1 
       DATEFORMAT 'auto' 
       TIMEFORMAT 'auto' 
       REGION 'us-west-2' 
       gzip;
   ```

1. 以下查询对 `ridership` 和 `weather` 数据集执行转换以消除偏差或异常。消除偏差和异常可提高模型准确性。查询通过创建两个名为 `ridership_view` 和 `weather_view` 的新视图来简化表。

   ```
   CREATE
   OR REPLACE VIEW ridership_view AS
   SELECT
       trip_time,
       trip_count,
       TO_CHAR(trip_time, 'hh24') :: INT trip_hour,
       TO_CHAR(trip_time, 'dd') :: INT trip_day,
       TO_CHAR(trip_time, 'mm') :: INT trip_month,
       TO_CHAR(trip_time, 'yy') :: INT trip_year,
       TO_CHAR(trip_time, 'q') :: INT trip_quarter,
       TO_CHAR(trip_time, 'w') :: INT trip_month_week,
       TO_CHAR(trip_time, 'd') :: INT trip_week_day
   FROM
       (
           SELECT
               CASE
                   WHEN TRUNC(r.trip_start_time) < '2017-07-01' :: DATE THEN CONVERT_TIMEZONE(
                       'US/Eastern',
                       DATE_TRUNC('hour', r.trip_start_time)
                   )
                   ELSE DATE_TRUNC('hour', r.trip_start_time)
               END trip_time,
               COUNT(1) trip_count
           FROM
               ridership r
           WHERE
               r.trip_duration_seconds BETWEEN 60
               AND 60 * 60 * 24
           GROUP BY
               1
       );
   
   CREATE
   OR REPLACE VIEW weather_view AS
   SELECT
       CONVERT_TIMEZONE(
           'US/Eastern',
           DATE_TRUNC('hour', datetime_utc)
       ) daytime,
       ROUND(AVG(temp_c)) temp_c,
       ROUND(AVG(precip_amount_mm)) precip_amount_mm
   FROM
       weather
   GROUP BY
       1;
   ```

1. 以下查询创建了一个表，该表将来自 `ridership_view` 和 `weather_view` 的所有相关输入属性组合到 `trip_data` 表中。

   ```
   CREATE TABLE trip_data AS
   SELECT
       r.trip_time,
       r.trip_count,
       r.trip_hour,
       r.trip_day,
       r.trip_month,
       r.trip_year,
       r.trip_quarter,
       r.trip_month_week,
       r.trip_week_day,
       w.temp_c,
       w.precip_amount_mm,CASE
           WHEN h.holiday_date IS NOT NULL THEN 1
           WHEN TO_CHAR(r.trip_time, 'D') :: INT IN (1, 7) THEN 1
           ELSE 0
       END is_holiday,
       ROW_NUMBER() OVER (
           ORDER BY
               RANDOM()
       ) serial_number
   FROM
       ridership_view r
       JOIN weather_view w ON (r.trip_time = w.daytime)
       LEFT OUTER JOIN holiday h ON (TRUNC(r.trip_time) = h.holiday_date);
   ```

### 查看示例数据（可选）
<a name="tutorial_regression_view_data"></a>

以下查询显示表中的条目。您可以运行此操作以确保正确创建了此表。

```
SELECT * 
FROM trip_data 
LIMIT 5;
```

以下是上一个操作的输出示例。

```
+---------------------+------------+-----------+----------+------------+-----------+--------------+-----------------+---------------+--------+------------------+------------+---------------+
|      trip_time      | trip_count | trip_hour | trip_day | trip_month | trip_year | trip_quarter | trip_month_week | trip_week_day | temp_c | precip_amount_mm | is_holiday | serial_number |
+---------------------+------------+-----------+----------+------------+-----------+--------------+-----------------+---------------+--------+------------------+------------+---------------+
| 2017-03-21 22:00:00 |         47 |        22 |       21 |          3 |        17 |            1 |               3 |             3 |      1 |                0 |          0 |             1 |
| 2018-05-04 01:00:00 |         19 |         1 |        4 |          5 |        18 |            2 |               1 |             6 |     12 |                0 |          0 |             3 |
| 2018-01-11 10:00:00 |         93 |        10 |       11 |          1 |        18 |            1 |               2 |             5 |      9 |                0 |          0 |             5 |
| 2017-10-28 04:00:00 |         20 |         4 |       28 |         10 |        17 |            4 |               4 |             7 |     11 |                0 |          1 |             7 |
| 2017-12-31 21:00:00 |         11 |        21 |       31 |         12 |        17 |            4 |               5 |             1 |    -15 |                0 |          1 |             9 |
+---------------------+------------+-----------+----------+------------+-----------+--------------+-----------------+---------------+--------+------------------+------------+---------------+
```

### 显示属性之间的相关性（可选）
<a name="tutorial_regression_show_correlation"></a>

确定相关性有助于衡量属性之间的关联强度。关联级别可以帮助您确定影响目标输出的因素。在本教程中，目标输出为 `trip_count`。

以下查询将创建或替换 `sp_correlation` 过程。您可以使用名为 `sp_correlation` 的存储过程，以显示 Amazon Redshift 的表中某个属性与其他属性之间的相关性。

```
CREATE OR REPLACE PROCEDURE sp_correlation(source_schema_name in varchar(255), source_table_name in varchar(255), target_column_name in varchar(255), output_temp_table_name inout varchar(255)) AS $$
DECLARE
  v_sql varchar(max);
  v_generated_sql varchar(max);
  v_source_schema_name varchar(255)=lower(source_schema_name);
  v_source_table_name varchar(255)=lower(source_table_name);
  v_target_column_name varchar(255)=lower(target_column_name);
BEGIN
  EXECUTE 'DROP TABLE IF EXISTS ' || output_temp_table_name;
  v_sql = '
SELECT
  ''CREATE temp table '|| output_temp_table_name||' AS SELECT ''|| outer_calculation||
  '' FROM (SELECT COUNT(1) number_of_items, SUM('||v_target_column_name||') sum_target, SUM(POW('||v_target_column_name||',2)) sum_square_target, POW(SUM('||v_target_column_name||'),2) square_sum_target,''||
  inner_calculation||
  '' FROM (SELECT ''||
  column_name||
  '' FROM '||v_source_table_name||'))''
FROM
  (
  SELECT
    DISTINCT
    LISTAGG(outer_calculation,'','') OVER () outer_calculation
    ,LISTAGG(inner_calculation,'','') OVER () inner_calculation
    ,LISTAGG(column_name,'','') OVER () column_name
  FROM
    (
    SELECT
      CASE WHEN atttypid=16 THEN ''DECODE(''||column_name||'',true,1,0)'' ELSE column_name END column_name
      ,atttypid
      ,''CAST(DECODE(number_of_items * sum_square_''||rn||'' - square_sum_''||rn||'',0,null,(number_of_items*sum_target_''||rn||'' - sum_target * sum_''||rn||
        '')/SQRT((number_of_items * sum_square_target - square_sum_target) * (number_of_items * sum_square_''||rn||
        '' - square_sum_''||rn||''))) AS numeric(5,2)) ''||column_name outer_calculation
      ,''sum(''||column_name||'') sum_''||rn||'',''||
            ''SUM(trip_count*''||column_name||'') sum_target_''||rn||'',''||
            ''SUM(POW(''||column_name||'',2)) sum_square_''||rn||'',''||
            ''POW(SUM(''||column_name||''),2) square_sum_''||rn inner_calculation
    FROM
      (
      SELECT
        row_number() OVER (order by a.attnum) rn
        ,a.attname::VARCHAR column_name
        ,a.atttypid
      FROM pg_namespace AS n
        INNER JOIN pg_class AS c ON n.oid = c.relnamespace
        INNER JOIN pg_attribute AS a ON c.oid = a.attrelid
      WHERE a.attnum > 0
        AND n.nspname = '''||v_source_schema_name||'''
        AND c.relname = '''||v_source_table_name||'''
        AND a.atttypid IN (16,20,21,23,700,701,1700)
      )
    )
)';
  EXECUTE v_sql INTO v_generated_sql;
  EXECUTE  v_generated_sql;
END;
$$ LANGUAGE plpgsql;
```

下面的查询显示数据集中的目标列、`trip_count` 和其他数值属性之间的相关性。

```
call sp_correlation(
    'public',
    'trip_data',
    'trip_count',
    'tmp_corr_table'
);

SELECT
    *
FROM
    tmp_corr_table;
```

下例显示了上一个 `sp_correlation` 操作的输出。

```
+------------+-----------+----------+------------+-----------+--------------+-----------------+---------------+--------+------------------+------------+---------------+
| trip_count | trip_hour | trip_day | trip_month | trip_year | trip_quarter | trip_month_week | trip_week_day | temp_c | precip_amount_mm | is_holiday | serial_number |
+------------+-----------+----------+------------+-----------+--------------+-----------------+---------------+--------+------------------+------------+---------------+
|          1 |      0.32 |     0.01 |       0.18 |      0.12 |         0.18 |               0 |          0.02 |   0.53 |            -0.07 |      -0.13 |             0 |
+------------+-----------+----------+------------+-----------+--------------+-----------------+---------------+--------+------------------+------------+---------------+
```

## 步骤 2：创建机器学习模型
<a name="tutorial_regression_create_model"></a>

1. 以下查询通过指定 80% 的数据集用于训练和 20% 的数据集用于验证，将数据拆分为训练集和验证集。训练集是 ML 模型的输入，用于确定模型的最佳可能算法。创建模型后，您可以使用验证集来验证模型的准确性。

   ```
   CREATE TABLE training_data AS
   SELECT
       trip_count,
       trip_hour,
       trip_day,
       trip_month,
       trip_year,
       trip_quarter,
       trip_month_week,
       trip_week_day,
       temp_c,
       precip_amount_mm,
       is_holiday
   FROM
       trip_data
   WHERE
       serial_number > (
           SELECT
               COUNT(1) * 0.2
           FROM
               trip_data
       );
   
   CREATE TABLE validation_data AS
   SELECT
       trip_count,
       trip_hour,
       trip_day,
       trip_month,
       trip_year,
       trip_quarter,
       trip_month_week,
       trip_week_day,
       temp_c,
       precip_amount_mm,
       is_holiday,
       trip_time
   FROM
       trip_data
   WHERE
       serial_number <= (
           SELECT
               COUNT(1) * 0.2
           FROM
               trip_data
       );
   ```

1. 以下查询创建一个回归模型来预测任何输入日期和时间的 `trip_count` 值。在以下示例中，将 amzn-s3-demo-bucket 替换为您自己的 S3 存储桶。

   ```
   CREATE MODEL predict_rental_count
   FROM
       training_data TARGET trip_count FUNCTION predict_rental_count 
       IAM_ROLE default 
       PROBLEM_TYPE regression 
       OBJECTIVE 'mse' 
       SETTINGS (
           s3_bucket 'amzn-s3-demo-bucket',
           s3_garbage_collect off,
           max_runtime 5000
       );
   ```

## 步骤 3：验证模型
<a name="tutorial_regression_step_validate"></a>

1. 使用以下查询输出模型的多个方面，并在输出中找出均方误差指标。均方误差是回归问题的典型准确性指标。

   ```
   show model predict_rental_count;
   ```

1. 根据验证数据运行以下预测查询，以将预测的行程计数与实际行程计数进行比较。

   ```
   SELECT
       trip_time,
       actual_count,
       predicted_count,
       (actual_count - predicted_count) difference
   FROM
       (
           SELECT
               trip_time,
               trip_count AS actual_count,
               PREDICT_RENTAL_COUNT (
                   trip_hour,
                   trip_day,
                   trip_month,
                   trip_year,
                   trip_quarter,
                   trip_month_week,
                   trip_week_day,
                   temp_c,
                   precip_amount_mm,
                   is_holiday
               ) predicted_count
           FROM
               validation_data
       )
   LIMIT
       5;
   ```

1. 以下查询根据您的验证数据计算均方误差和均方根误差。您可以使用均方误差和均方根误差，来测量预测的数值目标与实际数值答案之间的差距。一个好的模型在这两个指标中的分数都很低。下面的查询返回这两个指标的值。

   ```
   SELECT
       ROUND(
           AVG(POWER((actual_count - predicted_count), 2)),
           2
       ) mse,
       ROUND(
           SQRT(AVG(POWER((actual_count - predicted_count), 2))),
           2
       ) rmse
   FROM
       (
           SELECT
               trip_time,
               trip_count AS actual_count,
               PREDICT_RENTAL_COUNT (
                   trip_hour,
                   trip_day,
                   trip_month,
                   trip_year,
                   trip_quarter,
                   trip_month_week,
                   trip_week_day,
                   temp_c,
                   precip_amount_mm,
                   is_holiday
               ) predicted_count
           FROM
               validation_data
       );
   ```

1. 以下查询计算 2017 年 1 月 1 日每次行程时间的行程计数误差百分比。该查询对行程时间进行排序，采用的顺序为从误差百分比最低的时间到误差百分比最高的时间。

   ```
   SELECT
       trip_time,
       CAST(ABS(((actual_count - predicted_count) / actual_count)) * 100 AS DECIMAL (7,2)) AS pct_error
   FROM
       (
           SELECT
               trip_time,
               trip_count AS actual_count,
               PREDICT_RENTAL_COUNT (
                   trip_hour,
                   trip_day,
                   trip_month,
                   trip_year,
                   trip_quarter,
                   trip_month_week,
                   trip_week_day,
                   temp_c,
                   precip_amount_mm,
                   is_holiday
               ) predicted_count
           FROM
               validation_data
       )
   WHERE
      trip_time LIKE '2017-01-01 %%:%%:%%'
   ORDER BY
      2 ASC;
   ```

## 相关主题
<a name="tutorial_regression_related_topics"></a>

有关 Amazon Redshift ML 的更多信息，请参阅以下文档：
+ [使用 Amazon Redshift ML 的成本](https://docs.aws.amazon.com/redshift/latest/dg/cost.html)
+ [CREATE MODEL 操作](https://docs.aws.amazon.com/redshift/latest/dg/r_CREATE_MODEL.html)
+ [EXPLAIN\$1MODEL 函数](https://docs.aws.amazon.com/redshift/latest/dg/r_explain_model_function.html)

有关机器学习的更多信息，请参阅以下文档：
+ [机器学习概览](https://docs.aws.amazon.com/redshift/latest/dg/machine_learning_overview.html)
+ [面向新手和专家的机器学习](https://docs.aws.amazon.com/redshift/latest/dg/novice_expert.html)
+ [机器学习预测的公平性和模型可解释性是什么？](https://docs.aws.amazon.com/sagemaker/latest/dg/clarify-fairness-and-explainability.html)