10: Databricks – Spark ML – Linear Regression

Prerequisite: Extends Databricks getting started – Spark, Shell, SQL.

Spark ML (Machine Learning)

You can try these tutorials in Scala using Databricks Notebook. There are Scala tutorials covered in Spark using Scala on Zeppelin Notebook.

Problem statement: Predict the land prices by land area in square feet based on a given set of known prices.

Create train & predict/test data

The “train_data” is used to train and build the model, and the “predict_data” is then used to predict prices with the trained model.


Linear graph

Then create a “Line” graph of the data:

[Figure: Databricks – ML Linear Regression]

train_data to build the model

Q: What is predictive modelling?
A: Predictive modelling means predicting the value of one variable, in this case “land_price” in dollars, with the help of an independent variable, here “land_area” in square feet.

A Linear Regression model is built on the “train_data”. Two terms come up often in ML: “Feature” and “Label”.

Features are the independent variables (e.g. land_area) that we think will help predict the value of the dependent variable. A model can have one or more features. VectorAssembler combines all the features into a single vector.

source: https://pt.slideshare.net/dredmonds/an-overview-of-simple-linear-regression?smtNoRedir=1

The label is the dependent variable whose value the model predicts (e.g. land_price).
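The idea of assembling features into one vector next to a label can be sketched in plain Python. This is only a conceptual illustration, not the PySpark VectorAssembler API; the assemble helper and the “distance_to_city_km” column are hypothetical names introduced for the example.

```python
def assemble(row, feature_columns, label_column):
    """Collect the chosen feature columns into one vector and return it with the label."""
    features = [float(row[name]) for name in feature_columns]
    label = float(row[label_column])
    return features, label

# Hypothetical input row; "distance_to_city_km" is an invented second feature.
row = {"land_area": 1500, "distance_to_city_km": 12.5, "land_price": 148000}

features, label = assemble(row, ["land_area", "distance_to_city_km"], "land_price")
print(features, label)
```

With a single feature such as land_area the vector simply has one element.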

LinearRegression from the ml package of PySpark is the algorithm used for predictive modelling. Its fit() method takes the train_data and returns the trained model.
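To get a feel for what fitting produces, here is a minimal plain-Python sketch of ordinary least squares for a single feature. It is not the PySpark implementation, and the numbers in train_data are made up for illustration. It assumes no regularisation, so the result is the usual least-squares coefficient and intercept.

```python
# Hypothetical training data: (land_area in sq ft, land_price in dollars).
train_data = [
    (1000.0, 100000.0),
    (2000.0, 200000.0),
    (3000.0, 290000.0),
    (4000.0, 410000.0),
]

def fit_simple_linear_regression(rows):
    """Return (coefficient, intercept) of the least-squares line through rows."""
    n = len(rows)
    mean_x = sum(x for x, _ in rows) / n
    mean_y = sum(y for _, y in rows) / n
    # Covariance-like and variance-like sums (the 1/n factors cancel out).
    sxy = sum((x - mean_x) * (y - mean_y) for x, y in rows)
    sxx = sum((x - mean_x) ** 2 for x, _ in rows)
    coefficient = sxy / sxx
    intercept = mean_y - coefficient * mean_x
    return coefficient, intercept

coefficient, intercept = fit_simple_linear_regression(train_data)
print(coefficient, intercept)
```

The coefficient is the price change per extra square foot, and the intercept is the price the line gives at zero area.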


Model coefficients & intercept


model is used to predict prices

Our trained model will predict the values for predict_data.
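For a single feature, prediction is just the fitted line applied to each new land_area. A minimal sketch, assuming a hypothetical coefficient and intercept rather than values from a real trained model:

```python
# Hypothetical values standing in for a trained model's coefficient and intercept.
coefficient = 102.0   # dollars per square foot
intercept = -5000.0   # dollars

def predict(land_area):
    """Apply the fitted line: price = coefficient * land_area + intercept."""
    return coefficient * land_area + intercept

# Hypothetical areas to predict for (square feet).
predict_data = [1500.0, 2500.0]
predictions = [predict(area) for area in predict_data]
print(predictions)
```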


Verify the model

Verify how well the model predicts the label values, and whether linear regression was a good choice of algorithm for this data.

R-Squared, RMSE, and MAE

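“r2”, “rmse”, and “mae” are different metrics. R-squared (i.e. r2) is a statistical measure of how close the data are to the fitted regression line. The r2 value normally lies between 0 and 1 (aka 0% to 100%), although it can turn negative on unseen data when the model does worse than always predicting the mean.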

Mean Absolute Error (MAE) and Root Mean Squared Error (RMSE) are two of the most common metrics used to measure accuracy for continuous variables. Both express the average model prediction error in the units of the variable of interest (e.g. land price). Both metrics can range from 0 to ∞, and lower is better.

RMSE (Root Mean Squared Error) is a quadratic scoring rule that also measures the average magnitude of the error. It is the square root of the average of the squared differences between prediction and actual observation.

MAE (Mean Absolute Error) measures the average magnitude of the errors in a set of predictions, without considering their direction. It is the average over the test sample of the absolute differences between prediction and actual observation, where all individual differences have equal weight.


For the “r2” metric, a value close to 1 (aka 100%) means a good fit. The rmse is always greater than or equal to the mae, because squaring gives large errors more weight. If all of the errors have the same magnitude, then rmse = mae.
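The three metrics can be written out directly from their definitions. The following is a plain-Python sketch with made-up numbers, not the PySpark evaluator. It also shows the rmse = mae case when every error has the same magnitude.

```python
import math

def regression_metrics(actual, predicted):
    """Compute r2, rmse and mae from paired actual and predicted values."""
    n = len(actual)
    errors = [a - p for a, p in zip(actual, predicted)]
    mae = sum(abs(e) for e in errors) / n
    rmse = math.sqrt(sum(e * e for e in errors) / n)
    mean_actual = sum(actual) / n
    ss_res = sum(e * e for e in errors)                    # unexplained variation
    ss_tot = sum((a - mean_actual) ** 2 for a in actual)   # total variation
    r2 = 1.0 - ss_res / ss_tot
    return {"r2": r2, "rmse": rmse, "mae": mae}

actual = [100.0, 200.0, 300.0, 400.0]  # hypothetical observed values

# Every error has magnitude 10, so rmse equals mae.
equal = regression_metrics(actual, [110.0, 190.0, 310.0, 390.0])

# One large error of 40: mae is still 10, but rmse grows to 20.
uneven = regression_metrics(actual, [100.0, 200.0, 300.0, 440.0])

print(equal)
print(uneven)
```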

Note: Instead of using two separate sets of data, you can split the same input data into two sets, one for training and one for predicting (for example with the DataFrame randomSplit method).
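The splitting idea, sketched in plain Python with a fixed seed so the split is repeatable. The random_split helper is a hypothetical name for this illustration. Unlike this sketch, Spark's randomSplit assigns rows probabilistically, so its split sizes are only approximately in the requested proportions.

```python
import random

def random_split(rows, train_fraction, seed):
    """Shuffle a copy of rows and cut it into (train, test) parts."""
    rng = random.Random(seed)   # seeded so the same split is produced every run
    shuffled = rows[:]
    rng.shuffle(shuffled)
    cut = int(len(shuffled) * train_fraction)
    return shuffled[:cut], shuffled[cut:]

# Hypothetical rows: (land_area, land_price).
rows = [(float(area), float(area) * 100.0) for area in range(1000, 2000, 100)]

train_rows, test_rows = random_split(rows, 0.8, seed=42)
print(len(train_rows), len(test_rows))
```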

Q: What if your feature is a string value like “city” instead of a numeric value like “land_area”?
A: Non-numeric features must be converted to numeric values using a “StringIndexer” and a “OneHotEncoder” as ML pipeline stages before the VectorAssembler.

The “StringIndexer” assigns a numeric index to each distinct string value based on how often it occurs. For example, the value that occurs most often gets “0.0”, the second most frequent gets “1.0”, and so on.
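The frequency-based indexing and the one-hot step can be illustrated in plain Python. This is a conceptual sketch, not the PySpark classes: fit_string_index and one_hot are hypothetical helpers, the city names are invented, and breaking ties alphabetically is an assumption of the sketch. Spark's OneHotEncoder also drops the last category by default, which this sketch does not do.

```python
from collections import Counter

def fit_string_index(values):
    """Map each distinct string to an index, most frequent value first."""
    counts = Counter(values)
    # Sort by descending count; ties are broken alphabetically (sketch assumption).
    ordered = sorted(counts, key=lambda v: (-counts[v], v))
    return {label: float(i) for i, label in enumerate(ordered)}

def one_hot(index, size):
    """Turn a category index into a vector with a single 1.0."""
    vector = [0.0] * size
    vector[int(index)] = 1.0
    return vector

cities = ["Sydney", "Melbourne", "Sydney", "Brisbane", "Sydney", "Melbourne"]

city_index = fit_string_index(cities)
encoded = [one_hot(city_index[c], len(city_index)) for c in cities]
print(city_index)
print(encoded[0], encoded[3])
```

One-hot encoding avoids implying an order between categories, which a bare numeric index would do.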


Residual = Observed – Predicted
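A minimal plain-Python sketch of that formula, using made-up observed and predicted prices. The resulting (land_area, residual) pairs are the points a residual plot would show.

```python
# Hypothetical rows: (land_area, observed land_price, predicted land_price).
rows = [
    (1000.0, 100000.0, 97000.0),
    (2000.0, 200000.0, 199000.0),
    (3000.0, 290000.0, 301000.0),
    (4000.0, 410000.0, 403000.0),
]

# Residual = Observed - Predicted, kept alongside the independent variable.
residual_points = [(area, observed - predicted) for area, observed, predicted in rows]

for area, residual in residual_points:
    print(area, residual)
```

A positive residual means the model under-predicted that observation, and a negative one means it over-predicted.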


Residuals plot

A residual plot is a graph that shows the residuals on the Y-axis and the independent variable (e.g. land_area) on the X-axis. If the points in a residual plot are randomly dispersed around the horizontal axis, a linear regression model is appropriate for the data. If the points are NOT randomly dispersed and instead cluster in certain areas or form a pattern, a nonlinear model (e.g. a neural network) may be a better choice.

Let’s use matplotlib to draw a scatter plot of the residuals.

Residuals Scatter Plot

What is LinearRegressionTrainingSummary?

The training results of a model are captured in “pyspark.ml.regression.LinearRegressionTrainingSummary”, which is available as the summary attribute of the fitted model. Refer: PySpark ML API.


Note: The concepts are explained at Big Data – Data Science.
