Linear Regression Using Dask DataFrames
• 10 min read
This post includes code from Scalable-Data-Analysis-in-Python-with-Dask and coiled-examples.
import numpy as np
import dask.array as da
import pandas as pd
import sqlalchemy as db
from sqlalchemy import create_engine
import sqlite3
import pandas as pd
# Load the fiscal dataset from the local SQLite database into a pandas frame.
engine = db.create_engine("sqlite:///fiscal.db")

# Columns used by the rest of the analysis.
sql = """
SELECT year
, region
, province
, gdp
, fdi
, it
, specific
FROM fiscal_table
"""

# Open the connection as a context manager so it is closed as soon as the
# query result has been read, instead of staying open for the whole session.
with engine.connect() as connection:
    df = pd.read_sql(sql, connection)
from dask.distributed import Client

# In-process (thread-based) cluster: 3 workers, 2 threads each, 4GB memory limit per worker.
client = Client(
    n_workers=3,
    threads_per_worker=2,
    processes=False,
    memory_limit='4GB',
)
client  # notebook display of the client/cluster summary
client.restart()
Client
|
Cluster
|
from dask import dataframe as dd

# Wrap the in-memory pandas frame as a Dask DataFrame split into 5 partitions,
# then print its lazy structure (divisions and dtypes, no data computed).
ddf = dd.from_pandas(
    df,
    npartitions=5,
)
print(ddf)
Dask DataFrame Structure:
                year  region  province      gdp    fdi     it  specific
npartitions=5
0              int64  object    object  float64  int64  int64   float64
72               ...     ...       ...      ...    ...    ...       ...
...              ...     ...       ...      ...    ...    ...       ...
288              ...     ...       ...      ...    ...    ...       ...
359              ...     ...       ...      ...    ...    ...       ...
Dask Name: from_pandas, 5 tasks
# Preview the first 5 rows (5 is also head()'s default).
ddf.head(5)
| | year | region | province | gdp | fdi | it | specific |
---|---|---|---|---|---|---|---|
0 | 1996 | East China | Anhui | 2093.30 | 50661 | 631930 | 147002.0 |
1 | 1997 | East China | Anhui | 2347.32 | 43443 | 657860 | 151981.0 |
2 | 1998 | East China | Anhui | 2542.96 | 27673 | 889463 | 174930.0 |
3 | 1999 | East China | Anhui | 2712.34 | 26131 | 1227364 | 285324.0 |
4 | 2000 | East China | Anhui | 2902.09 | 31847 | 1499110 | 195580.0 |
# Display the identifier of the Dask Client created above.
client.id
'Client-4c6edbde-0e23-11eb-a5c8-4b14c8b4f4db'
# Feature configuration: numeric inputs, categorical inputs and regression target.
num_feat_list = ["year", "fdi"]
cat_feat_list = ["region", "province"]
target = ["gdp"]

# Normalise dtypes before modelling ("it" is cast as well, though it is not
# used as a feature below).
ddf["year"] = ddf["year"].astype(int)
ddf["fdi"] = ddf["fdi"].astype(float)
ddf["gdp"] = ddf["gdp"].astype(float)
ddf["it"] = ddf["it"].astype(float)

# OHE: categorize() converts the categorical columns to known categories
# (presumably required by dask-ml's OneHotEncoder), then each one is expanded
# into dense 0/1 indicator columns.
from dask_ml.preprocessing import OneHotEncoder
ddf = ddf.categorize(cat_feat_list)
ohe = OneHotEncoder(sparse=False)
ohe_ddf = ohe.fit_transform(ddf[cat_feat_list])

# Rebuild the feature list from the fixed numeric base every time. The original
# code appended to the previous feat_list, which is not idempotent. The recorded
# output shows every one-hot column duplicated: 74 feature names,
# ddf_processed with 75 columns and X with 146.
# dict.fromkeys drops duplicates while keeping order. The raw categorical columns
# are excluded because their one-hot encodings replace them.
feat_list = [
    f
    for f in dict.fromkeys(num_feat_list + ohe_ddf.columns.tolist())
    if f not in cat_feat_list
]

# NOTE(review): each province appears to belong to a single region, so the
# region dummies are likely collinear with the province dummies. Confirm this
# before interpreting individual coefficients.
ddf_processed = dd.concat([ddf, ohe_ddf], axis=1)[feat_list + target]
ddf_processed.compute()
| | year | fdi | region_East China | region_North China | region_Southwest China | region_Northwest China | region_South Central China | region_Northeast China | province_Anhui | province_Beijing | ... | province_Shandong | province_Shanghai | province_Shanxi | province_Sichuan | province_Tianjin | province_Tibet | province_Xinjiang | province_Yunnan | province_Zhejiang | gdp |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1996 | 50661.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 2093.30 |
1 | 1997 | 43443.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 2347.32 |
2 | 1998 | 27673.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 2542.96 |
3 | 1999 | 26131.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 2712.34 |
4 | 2000 | 31847.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 2902.09 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
355 | 2003 | 498055.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 9705.02 |
356 | 2004 | 668128.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 11648.70 |
357 | 2005 | 772000.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 13417.68 |
358 | 2006 | 888935.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 15718.47 |
359 | 2007 | 1036576.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 18753.73 |
360 rows × 75 columns
# Display the final list of feature column names used to build X.
feat_list
['year', 'fdi', 'region_East China', 'region_North China', 'region_Southwest China', 'region_Northwest China', 'region_South Central China', 'region_Northeast China', 'province_Anhui', 'province_Beijing', 'province_Chongqing', 'province_Fujian', 'province_Gansu', 'province_Guangdong', 'province_Guangxi', 'province_Guizhou', 'province_Hainan', 'province_Hebei', 'province_Heilongjiang', 'province_Henan', 'province_Hubei', 'province_Hunan', 'province_Jiangsu', 'province_Jiangxi', 'province_Jilin', 'province_Liaoning', 'province_Ningxia', 'province_Qinghai', 'province_Shaanxi', 'province_Shandong', 'province_Shanghai', 'province_Shanxi', 'province_Sichuan', 'province_Tianjin', 'province_Tibet', 'province_Xinjiang', 'province_Yunnan', 'province_Zhejiang', 'region_East China', 'region_North China', 'region_Southwest China', 'region_Northwest China', 'region_South Central China', 'region_Northeast China', 'province_Anhui', 'province_Beijing', 'province_Chongqing', 'province_Fujian', 'province_Gansu', 'province_Guangdong', 'province_Guangxi', 'province_Guizhou', 'province_Hainan', 'province_Hebei', 'province_Heilongjiang', 'province_Henan', 'province_Hubei', 'province_Hunan', 'province_Jiangsu', 'province_Jiangxi', 'province_Jilin', 'province_Liaoning', 'province_Ningxia', 'province_Qinghai', 'province_Shaanxi', 'province_Shandong', 'province_Shanghai', 'province_Shanxi', 'province_Sichuan', 'province_Tianjin', 'province_Tibet', 'province_Xinjiang', 'province_Yunnan', 'province_Zhejiang']
from sklearn.datasets import make_regression  # NOTE(review): not used below
from sklearn.linear_model import LinearRegression, Ridge

# Select the design matrix and target from the processed frame, and persist both
# so they are computed once and kept on the cluster for the fits below.
X = ddf_processed[feat_list].persist()
y = ddf_processed[target].persist()
# Display the lazy Dask structure of X (column names and dtypes, no data computed).
X
Dask DataFrame Structure:
year | fdi | region_East China | region_East China | region_North China | region_North China | region_Southwest China | region_Southwest China | region_Northwest China | region_Northwest China | region_South Central China | region_South Central China | region_Northeast China | region_Northeast China | province_Anhui | province_Anhui | province_Beijing | province_Beijing | province_Chongqing | province_Chongqing | province_Fujian | province_Fujian | province_Gansu | province_Gansu | province_Guangdong | province_Guangdong | province_Guangxi | province_Guangxi | province_Guizhou | province_Guizhou | province_Hainan | province_Hainan | province_Hebei | province_Hebei | province_Heilongjiang | province_Heilongjiang | province_Henan | province_Henan | province_Hubei | province_Hubei | province_Hunan | province_Hunan | province_Jiangsu | province_Jiangsu | province_Jiangxi | province_Jiangxi | province_Jilin | province_Jilin | province_Liaoning | province_Liaoning | province_Ningxia | province_Ningxia | province_Qinghai | province_Qinghai | province_Shaanxi | province_Shaanxi | province_Shandong | province_Shandong | province_Shanghai | province_Shanghai | province_Shanxi | province_Shanxi | province_Sichuan | province_Sichuan | province_Tianjin | province_Tianjin | province_Tibet | province_Tibet | province_Xinjiang | province_Xinjiang | province_Yunnan | province_Yunnan | province_Zhejiang | province_Zhejiang | region_East China | region_East China | region_North China | region_North China | region_Southwest China | region_Southwest China | region_Northwest China | region_Northwest China | region_South Central China | region_South Central China | region_Northeast China | region_Northeast China | province_Anhui | province_Anhui | province_Beijing | province_Beijing | province_Chongqing | province_Chongqing | province_Fujian | province_Fujian | province_Gansu | province_Gansu | province_Guangdong | province_Guangdong | province_Guangxi | province_Guangxi | 
province_Guizhou | province_Guizhou | province_Hainan | province_Hainan | province_Hebei | province_Hebei | province_Heilongjiang | province_Heilongjiang | province_Henan | province_Henan | province_Hubei | province_Hubei | province_Hunan | province_Hunan | province_Jiangsu | province_Jiangsu | province_Jiangxi | province_Jiangxi | province_Jilin | province_Jilin | province_Liaoning | province_Liaoning | province_Ningxia | province_Ningxia | province_Qinghai | province_Qinghai | province_Shaanxi | province_Shaanxi | province_Shandong | province_Shandong | province_Shanghai | province_Shanghai | province_Shanxi | province_Shanxi | province_Sichuan | province_Sichuan | province_Tianjin | province_Tianjin | province_Tibet | province_Tibet | province_Xinjiang | province_Xinjiang | province_Yunnan | province_Yunnan | province_Zhejiang | province_Zhejiang | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
npartitions=5 | ||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||||
0 | int64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 | float64 |
72 | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
288 | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
359 | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
Dask Name: getitem, 5 tasks
# Materialise X as a pandas DataFrame for inspection.
X.compute()
| | year | fdi | region_East China | region_East China | region_North China | region_North China | region_Southwest China | region_Southwest China | region_Northwest China | region_Northwest China | ... | province_Tianjin | province_Tianjin | province_Tibet | province_Tibet | province_Xinjiang | province_Xinjiang | province_Yunnan | province_Yunnan | province_Zhejiang | province_Zhejiang |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 1996 | 50661.0 | 1.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
1 | 1997 | 43443.0 | 1.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
2 | 1998 | 27673.0 | 1.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
3 | 1999 | 26131.0 | 1.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
4 | 2000 | 31847.0 | 1.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 |
... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... | ... |
355 | 2003 | 498055.0 | 1.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 1.0 |
356 | 2004 | 668128.0 | 1.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 1.0 |
357 | 2005 | 772000.0 | 1.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 1.0 |
358 | 2006 | 888935.0 | 1.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 1.0 |
359 | 2007 | 1036576.0 | 1.0 | 1.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | ... | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 0.0 | 1.0 | 1.0 |
360 rows × 146 columns
# Display the lazy Dask structure of the target frame y.
y
Dask DataFrame Structure:
| | gdp |
---|---|
npartitions=5 | |
0 | float64 |
72 | ... |
... | ... |
288 | ... |
359 | ... |
Dask Name: getitem, 5 tasks
# Ordinary least-squares baseline, fitted on the persisted feature/target frames.
LinReg = LinearRegression()
LinReg.fit(X=X, y=y)
LinearRegression()
# Ridge (L2-regularised) regression on the same features as the OLS model.
# The design matrix is `X`; the lowercase `x` used previously is never defined
# and raised NameError.
RidgeReg = Ridge()
RidgeReg.fit(X, y)
Ridge()
# First five in-sample predictions of the OLS model. They are made on the
# training matrix `X`; the undefined lowercase `x` raised NameError.
LinReg.predict(X)[:5]
array([[1830.87851079], [2076.99855135], [2220.28956053], [2534.65768132], [2936.29581027]])
# First five in-sample predictions of the Ridge model. They are made on the
# training matrix `X`; the undefined lowercase `x` raised NameError.
RidgeReg.predict(X)[:5]
array([[1804.41754025], [2053.19939587], [2200.05297844], [2516.48507702], [2919.42271884]])
# Restart the cluster through the Client before shutting it down.
client.restart()
Client
|
Cluster
|
# Close the Dask Client created at the start of the notebook.
client.close()