This tutorial extends Getting started with Databricks. At the time of writing, Spark SQL does not support recursive queries of the kind you can write in other SQL dialects with a recursive “Common Table Expression” (CTE).
Step 1: Log in to a Databricks notebook:
https://community.cloud.databricks.com/login.html
Step 2: Create a cluster. It takes a few minutes to come up, and it will go down after 2 hours.
Step 3: Create simple hierarchical data with 3 levels, as shown below: level-0, level-1 and level-2. Level-0 is the top parent.
Hierarchy Example
    from pyspark.sql.types import IntegerType, StructField, StructType

    # (node_id, parent_node_id) pairs; the root has no parent
    node_list = (
        (1, None),
        (2, 1),
        (3, 1),
        (4, 2),
        (5, 2),
        (6, 3),
        (7, 3),
    )

    data_schema = [StructField('node_id', IntegerType(), False),
                   StructField('parent_node_id', IntegerType(), True)]
    final_struc = StructType(fields=data_schema)

    rdd = spark.sparkContext.parallelize(node_list)
    node_df = spark.createDataFrame(rdd, final_struc)
    node_df.show()
Outputs:
    +-------+--------------+
    |node_id|parent_node_id|
    +-------+--------------+
    |      1|          null|
    |      2|             1|
    |      3|             1|
    |      4|             2|
    |      5|             2|
    |      6|             3|
    |      7|             3|
    +-------+--------------+
Spark SQL does not support recursive CTEs, as discussed later in this post. In most hierarchical data the depth is not known in advance, so you can instead walk the hierarchy with a WHILE loop that repeatedly joins the current level back to the full node table until a level returns no rows, as shown below.
Step 4: Loop through the hierarchy breadth first (i.e. one level at a time, starting from the root node), as shown below.
    node_df.createOrReplaceTempView("node_rec")  # temp view over the full node table

    records_cnt = 1
    level = 1

    node_rec_df = node_df.filter("node_id = 1")  # initial (root) node
    node_rec_df.show()
    node_rec_df.createOrReplaceTempView("vt_level_0")  # temp view for level 0

    while records_cnt != 0:
        curr_lvl_tbl = "vt_level_" + str(level - 1)
        next_lvl_tbl = "vt_level_" + str(level)
        # children of the current level = rows whose parent is in the current level
        query = "SELECT b.node_id, b.parent_node_id FROM {} a INNER JOIN node_rec b ON a.node_id = b.parent_node_id".format(curr_lvl_tbl)
        print("level " + str(level) + " " + query)
        df_node_rec = spark.sql(query)
        records_cnt = df_node_rec.count()
        print("level " + str(level) + " has " + str(records_cnt) + " records")
        if records_cnt != 0:
            df_node_rec.createOrReplaceTempView(next_lvl_tbl)
        # incremented on every pass, including the last (empty) one
        level = level + 1
Outputs:
    +-------+--------------+
    |node_id|parent_node_id|
    +-------+--------------+
    |      1|          null|
    +-------+--------------+

    level 1 SELECT b.node_id, b.parent_node_id FROM vt_level_0 a INNER JOIN node_rec b ON a.node_id = b.parent_node_id
    level 1 has 2 records
    level 2 SELECT b.node_id, b.parent_node_id FROM vt_level_1 a INNER JOIN node_rec b ON a.node_id = b.parent_node_id
    level 2 has 4 records
    level 3 SELECT b.node_id, b.parent_node_id FROM vt_level_2 a INNER JOIN node_rec b ON a.node_id = b.parent_node_id
    level 3 has 0 records
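To see what the loop above computes without a cluster, here is a minimal plain-Python sketch of the same level-by-level walk over the tutorial's node list. The function name `levels_from_root` and the `seen` set are illustrative assumptions, not part of the original notebook; the `seen` check is an extra guard the Spark loop does not have, added because data containing a cycle would otherwise never produce an empty level.

```python
from collections import defaultdict

# Same (node_id, parent_node_id) pairs as the tutorial's node_list.
node_list = [(1, None), (2, 1), (3, 1), (4, 2), (5, 2), (6, 3), (7, 3)]


def levels_from_root(rows, root_id):
    """Return a list of levels; each level is a list of (node_id, parent_node_id)."""
    children = defaultdict(list)  # parent_node_id -> rows of its direct children
    parent_of = {}
    for node_id, parent_id in rows:
        children[parent_id].append((node_id, parent_id))
        parent_of[node_id] = parent_id

    current = [(root_id, parent_of[root_id])]  # plays the role of vt_level_0
    seen = {root_id}                           # hypothetical cycle guard
    levels = []
    while current:                             # stop when a level has no records
        levels.append(current)
        next_level = []
        for node_id, _parent in current:
            for child in children[node_id]:    # the "join" to the node table
                if child[0] not in seen:
                    seen.add(child[0])
                    next_level.append(child)
        current = next_level
    return levels


levels = levels_from_root(node_list, root_id=1)
for depth, rows in enumerate(levels):
    print("level", depth, "has", len(rows), "records:", rows)
```

Each list in `levels` corresponds to one `vt_level_N` view, and the loop ends on the first empty level, just as the WHILE loop ends when `records_cnt` reaches 0.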
Step 5: Combine the 3 level views created above (vt_level_0, vt_level_1 and vt_level_2) into a single DataFrame.
    query = ""

    # level is one past the last empty level, so range(level - 1) covers vt_level_0 .. vt_level_2
    for x in range(level - 1):
        if x == 0:
            query = "SELECT node_id, parent_node_id from vt_level_{}".format(x)
        else:
            query = query + " union select node_id, parent_node_id from vt_level_{}".format(x)

    print(query)
    df_result = spark.sql(query)
    df_result.show()
Outputs:
    SELECT node_id, parent_node_id from vt_level_0 union select node_id, parent_node_id from vt_level_1 union select node_id, parent_node_id from vt_level_2

    +-------+--------------+
    |node_id|parent_node_id|
    +-------+--------------+
    |      1|          null|
    |      3|             1|
    |      2|             1|
    |      4|             2|
    |      5|             2|
    |      6|             3|
    |      7|             3|
    +-------+--------------+
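The string concatenation in Step 5 can also be written with a join over a list of SELECT statements. The sketch below is one possible way to do that; the helper name `build_union_query` and its `distinct` parameter are assumptions introduced for illustration. `UNION` removes duplicate rows while `UNION ALL` keeps them, so for a clean tree, where each node appears in exactly one level, both should return the same rows.

```python
def build_union_query(num_levels, view_prefix="vt_level_", distinct=True):
    """Build one query that unions the per-level temp views 0 .. num_levels - 1."""
    if num_levels < 1:
        raise ValueError("num_levels must be at least 1")
    operator = " UNION " if distinct else " UNION ALL "
    selects = [
        "SELECT node_id, parent_node_id FROM {}{}".format(view_prefix, i)
        for i in range(num_levels)
    ]
    return operator.join(selects)


query = build_union_query(3)
print(query)
```

In the notebook, the result could be passed to Spark in place of the hand-built string, for example `spark.sql(build_union_query(level - 1))`.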
Spark SQL does not support recursive CTE
Spark SQL does not support recursive CTEs (i.e. Common Table Expressions), as shown below. The query fails with an error at the “RECURSIVE” keyword.
    %sql
    WITH RECURSIVE cte_nodes (node_id) as (
      SELECT root.node_id
      FROM node_rec root
      WHERE root.node_id = 1
      union
      SELECT a.node_id
      FROM node_rec a inner join node_rec b
      WHERE a.node_id = b.parent_node_id
    )
    SELECT *
    FROM cte_nodes
    ORDER BY node_id;
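If you run the same query without the “RECURSIVE” keyword, it executes, but nothing is recursive: the second SELECT never refers back to cte_nodes, so it simply returns every node that is the parent of some other node. For this data set that is the root plus the nodes one level down, as shown below.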
    %sql
    WITH cte_nodes (node_id) as (
      SELECT root.node_id
      FROM node_rec root
      WHERE root.node_id = 1
      union
      SELECT a.node_id
      FROM node_rec a inner join node_rec b
      WHERE a.node_id = b.parent_node_id
    )
    SELECT *
    FROM cte_nodes
    ORDER BY node_id;
Outputs:
    node_id
    --------
    1
    2
    3
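To see why the non-recursive query returns exactly 1, 2 and 3, the sketch below emulates its two branches in plain Python. The function name `non_recursive_cte` and the extra node `(8, 4)` are assumptions made for illustration. The second branch keeps every node that is some other node's parent, so the result tracks "has children" rather than distance from the root.

```python
# Same (node_id, parent_node_id) pairs as the tutorial's node_list.
node_list = [(1, None), (2, 1), (3, 1), (4, 2), (5, 2), (6, 3), (7, 3)]


def non_recursive_cte(rows, root_id):
    """Emulate the CTE without RECURSIVE: root branch UNION self-join branch."""
    # Branch 1: SELECT root.node_id FROM node_rec root WHERE root.node_id = <root_id>
    root_branch = {node_id for node_id, _parent in rows if node_id == root_id}
    # Branch 2: a JOIN b WHERE a.node_id = b.parent_node_id
    # keeps each node a that is the parent of at least one row b.
    join_branch = {
        a_id
        for a_id, _a_parent in rows
        for _b_id, b_parent in rows
        if a_id == b_parent
    }
    # UNION removes duplicates; ORDER BY node_id sorts the result.
    return sorted(root_branch | join_branch)


print(non_recursive_cte(node_list, 1))

# Hypothetical extra row: giving node 4 a child makes node 4 appear as well,
# even though it sits two levels below the root.
print(non_recursive_cte(node_list + [(8, 4)], 1))
```

With the tutorial's data, the nodes that have children happen to be the root and its direct children, which is why the output looks like one level of recursion.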