from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType, IntegerType, DoubleType, FloatType

data_schema = [
    StructField("_c0", IntegerType(), True),
    StructField("province", StringType(), True),
    StructField("specific", DoubleType(), True),
    StructField("general", DoubleType(), True),
    StructField("year", IntegerType(), True),
    StructField("gdp", FloatType(), True),
    StructField("fdi", FloatType(), True),
    StructField("rnr", DoubleType(), True),
    StructField("rr", FloatType(), True),
    StructField("i", FloatType(), True),
    StructField("fr", IntegerType(), True),
    StructField("reg", StringType(), True),
    StructField("it", IntegerType(), True),
]

final_struc = StructType(fields=data_schema)

file_location = "/FileStore/tables/df_panel_fix.csv"

# On Databricks a SparkSession is already available as `spark`;
# elsewhere, create (or reuse) one explicitly.
spark = SparkSession.builder.getOrCreate()

df = spark.read.format("csv").schema(final_struc).option("header", True).load(file_location)

#df.printSchema()

df.show()
+---+--------+---------+--------+----+-------+--------+----+---------+---------+-------+-----------+-------+
|_c0|province| specific| general|year|    gdp|     fdi| rnr|       rr|        i|     fr|        reg|     it|
+---+--------+---------+--------+----+-------+--------+----+---------+---------+-------+-----------+-------+
|  0|   Anhui| 147002.0|    null|1996| 2093.3| 50661.0| 0.0|      0.0|      0.0|1128873| East China| 631930|
|  1|   Anhui| 151981.0|    null|1997|2347.32| 43443.0| 0.0|      0.0|      0.0|1356287| East China| 657860|
|  2|   Anhui| 174930.0|    null|1998|2542.96| 27673.0| 0.0|      0.0|      0.0|1518236| East China| 889463|
|  3|   Anhui| 285324.0|    null|1999|2712.34| 26131.0|null|     null|     null|1646891| East China|1227364|
|  4|   Anhui| 195580.0| 32100.0|2000|2902.09| 31847.0| 0.0|      0.0|      0.0|1601508| East China|1499110|
|  5|   Anhui| 250898.0|    null|2001|3246.71| 33672.0| 0.0|      0.0|      0.0|1672445| East China|2165189|
|  6|   Anhui| 434149.0| 66529.0|2002|3519.72| 38375.0| 0.0|      0.0|      0.0|1677840| East China|2404936|
|  7|   Anhui| 619201.0| 52108.0|2003|3923.11| 36720.0| 0.0|      0.0|      0.0|1896479| East China|2815820|
|  8|   Anhui| 898441.0|349699.0|2004| 4759.3| 54669.0| 0.0|      0.0|      0.0|   null| East China|3422176|
|  9|   Anhui| 898441.0|    null|2005|5350.17| 69000.0| 0.0|      0.0|0.3243243|   null| East China|3874846|
| 10|   Anhui|1457872.0|279052.0|2006| 6112.5|139354.0| 0.0|      0.0|0.3243243|3434548| East China|5167300|
| 11|   Anhui|2213991.0|178705.0|2007|7360.92|299892.0| 0.0|      0.0|0.3243243|4468640| East China|7040099|
| 12| Beijing| 165957.0|    null|1996| 1789.2|155290.0|null|     null|     null| 634562|North China| 508135|
| 13| Beijing| 165957.0|    null|1997|2077.09|159286.0| 0.0|      0.0|      0.6| 634562|North China| 569283|
| 14| Beijing| 245198.0|    null|1998|2377.18|216800.0| 0.0|      0.0|     0.53| 938788|North China| 695528|
| 15| Beijing| 388083.0|    null|1999|2678.82|197525.0| 0.0|      0.0|     0.53|   null|North China| 944047|
| 16| Beijing| 281769.0|188633.0|2000|3161.66|168368.0| 0.0|      0.0|     0.53|1667114|North China| 757990|
| 17| Beijing| 441923.0|    null|2001|3707.96|176818.0| 0.0|      0.0|     0.53|2093925|North China|1194728|
| 18| Beijing| 558569.0|280277.0|2002| 4315.0|172464.0| 0.0|      0.0|     0.53|2511249|North China|1078754|
| 19| Beijing| 642581.0|269596.0|2003|5007.21|219126.0| 0.0|0.7948718|      0.0|2823366|North China|1426600|
+---+--------+---------+--------+----+-------+--------+----+---------+---------+-------+-----------+-------+
only showing top 20 rows
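Several columns in this sample contain nulls, so it is worth counting them before assembling features. A minimal check, reusing the df loaded above:

from pyspark.sql.functions import col, count, when

# Count the null entries in every column of df.
df.select([count(when(col(c).isNull(), c)).alias(c) for c in df.columns]).show()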
from pyspark.ml.feature import VectorAssembler

# Candidate feature columns; this example narrows the list down to 'gdp' only.
# feat_cols = ['year', 'gdp', 'fdi', 'fr']
feat_cols = ['gdp']

# Note: VectorAssembler raises an error on null inputs by default.
vec_assembler = VectorAssembler(inputCols=feat_cols, outputCol='features')
final_df = vec_assembler.transform(df)
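To confirm the assembler produced one vector per row, a quick look at the new column helps (illustrative only, reusing final_df from above):

# 'features' should hold a one-element vector containing gdp for each row.
final_df.select('gdp', 'features').show(5, truncate=False)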

Using the StandardScaler

from pyspark.ml.feature import StandardScaler
scaler = StandardScaler(inputCol="features", outputCol="scaledFeatures", withStd=True, withMean=False)

Fitting the StandardScaler

# Compute summary statistics by fitting the StandardScaler
scalerModel = scaler.fit(final_df)
# Normalize each feature to have unit standard deviation.
cluster_final_data = scalerModel.transform(final_df)
from pyspark.ml.clustering import KMeans

kmeans3 = KMeans(featuresCol='scaledFeatures', k=3)
kmeans2 = KMeans(featuresCol='scaledFeatures', k=2)
model_k3 = kmeans3.fit(cluster_final_data)
model_k2 = kmeans2.fit(cluster_final_data)
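The fitted models expose their centroids via clusterCenters(); note that these live in the scaled feature space, not in the original gdp units:

# One centroid (as a NumPy array) per cluster, in scaled-feature space.
for center in model_k3.clusterCenters():
    print(center)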
model_k3.transform(cluster_final_data).groupBy('prediction').count().show()
+----------+-----+
|prediction|count|
+----------+-----+
|         1|   15|
|         2|   86|
|         0|  259|
+----------+-----+
model_k2.transform(cluster_final_data).groupBy('prediction').count().show()
+----------+-----+
|prediction|count|
+----------+-----+
|         1|  308|
|         0|   52|
+----------+-----+
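The group counts alone do not say which k fits better. One way to compare the two models, sketched here assuming Spark 2.3+ where ClusteringEvaluator is available, is the silhouette score (higher is better):

from pyspark.ml.evaluation import ClusteringEvaluator

# Silhouette with squared Euclidean distance, computed on the scaled features.
evaluator = ClusteringEvaluator(featuresCol='scaledFeatures', metricName='silhouette')
for k, model in [(2, model_k2), (3, model_k3)]:
    preds = model.transform(cluster_final_data)
    print("k={}  silhouette={:.4f}".format(k, evaluator.evaluate(preds)))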