# Databricks notebook source
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
import json

spark = SparkSession.builder.appName("StreamingIoTPipeline").getOrCreate()
spark.sparkContext.setLogLevel("WARN")

# COMMAND ----------
# Input Parameters (template placeholders, substituted before the notebook runs)
event_hub_connection_string = {{event_hub_connection_string}}
timeseries_table_path = {{timeseries_table_path}}
project_levels = {{project_levels}}
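
# Illustrative values (hypothetical, shown only to document the expected shapes):
# event_hub_connection_string = "Endpoint=sb://<namespace>.servicebus.windows.net/;SharedAccessKeyName=<policy>;SharedAccessKey=<key>;EntityPath=<event-hub>"
# timeseries_table_path = "/mnt/delta/timeseries_data"
# project_levels = 4  # maximum number of hierarchy levels in a tag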

# COMMAND ----------
event_hub_conf = {
    'eventhubs.connectionString': spark._jvm.org.apache.spark.eventhubs.EventHubsUtils.encrypt(event_hub_connection_string),
    'eventhubs.consumerGroup': '$Default'
}
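
# Optional (sketch): pin where the stream starts reading. The JSON position format below
# follows the azure-event-hubs-spark connector's documented PySpark usage; offset "-1"
# means the start of the stream. Leave this out to keep the connector's default behaviour.
# event_hub_conf['eventhubs.startingPosition'] = json.dumps({
#     "offset": "-1",
#     "seqNo": -1,
#     "enqueuedTime": None,
#     "isInclusive": True
# })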


# COMMAND ----------

message_schema = StructType([
    StructField("data", StructType([
        StructField("tag", StringType(), False),
        StructField("dq", IntegerType(), True),
        StructField("ta", StringType(), True),
        StructField("val", DoubleType(), False)
    ]), True),
    StructField("a_id", StringType(), True),
    StructField("d_id", StringType(), True),
    StructField("gw_id", StringType(), True), 
    StructField("msg_id", IntegerType(), True),
    StructField("p_id", StringType(), True),
    StructField("pd_id", StringType(), True),
    StructField("retain_flag", BooleanType(), True),
    StructField("site_id", StringType(), True),
    StructField("timestamp", LongType(), False),
    StructField("ver", DoubleType(), True)
])
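
# Illustrative message shape this schema expects (field values are hypothetical):
# {
#   "data": {"tag": "siteA$area1$ast42$temperature", "dq": 192, "ta": "...", "val": 21.5},
#   "a_id": "...", "d_id": "...", "gw_id": "...", "msg_id": 1, "p_id": "...", "pd_id": "...",
#   "retain_flag": false, "site_id": "...", "timestamp": 1700000000000, "ver": 1.0
# }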


# COMMAND ----------

def safe_get_item(array_col, index):
    """Return array_col[index], or NULL when the array is shorter than index + 1."""
    return when(size(array_col) > index, array_col.getItem(index)).otherwise(lit(None))


def transform_timeseries_data_fully_dynamic(df, max_tag_parts=4):
    """Flatten parsed IoT readings into the target time-series schema.

    Splits data.tag on '$' into hierarchy levels, separates the asset ('ast') level from
    the remaining levels (l1..lN), and derives timestamp, date, hour and value-type columns.
    """
    print(f"Transforming to target schema with up to {max_tag_parts} tag parts...")
    
    df_with_split = df.withColumn("tag_parts", split(col("data.tag"), "\\$"))
    df_with_split = df_with_split.withColumn("tag_parts_count", size(col("tag_parts")))
    df_with_split = df_with_split.withColumn("hierarchy_levels", slice(col("tag_parts"), 1, size(col("tag_parts")) - 1))
    df_with_split = df_with_split.withColumn("levels_without_ast", expr("filter(hierarchy_levels, x -> NOT x LIKE '%ast%')"))
    df_with_split = df_with_split.withColumn("ast", expr("filter(hierarchy_levels, x -> x LIKE '%ast%')[0]"))

    value_type_logic = when(
        col("data.val").cast("float").isNotNull() & ~isnan(col("data.val").cast("float")),
        lit("float")
    ).otherwise(lit("string"))

    select_columns = [
        col("timestamp").alias("timestamp"),
        from_unixtime(col("timestamp") / 1000).cast("timestamp").alias("dt_timestamp"),
        to_date(from_unixtime(col("timestamp") / 1000)).alias("dt_date"),
        hour(from_unixtime(col("timestamp") / 1000)).alias("dt_hour"),
        col("data.val").cast("string").alias("value"),
        value_type_logic.alias("value_type"),
        col("data.tag").alias("c3"),
        safe_get_item(col("tag_parts"), 0).alias("c1"),
        when(col("tag_parts_count") > 0, col("tag_parts").getItem(col("tag_parts_count") - 1)).otherwise(lit(None)).alias("c5"),
        col("data.dq").cast("string").alias("Q"),
        col("data.ta").alias("T"),
        col("d_id").alias("D"),
        col("p_id").alias("P"),
        col("a_id").alias("A"),
        lit(None).cast("string").alias("B")
    ]

    select_columns += [
        safe_get_item(col("levels_without_ast"), i).alias(f"l{i+1}")
        for i in range(max_tag_parts)
    ] + [col("ast").alias("ast")]

    return df_with_split.select(*select_columns)
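
# COMMAND ----------

# Optional sanity check (sketch, with hypothetical sample values): run the transform on a
# tiny batch DataFrame to preview the output columns before starting the stream.
# from pyspark.sql import Row
# sample_df = spark.createDataFrame(
#     [Row(data=Row(tag="siteA$area1$ast42$temperature", dq=192, ta=None, val=21.5),
#          a_id="a1", d_id="d1", gw_id="gw1", msg_id=1, p_id="p1", pd_id="pd1",
#          retain_flag=False, site_id="s1", timestamp=1700000000000, ver=1.0)],
#     schema=message_schema,
# )
# display(transform_timeseries_data_fully_dynamic(sample_df, max_tag_parts=project_levels))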


# COMMAND ----------

raw_stream_df = spark.readStream \
    .format("eventhubs") \
    .options(**event_hub_conf) \
    .load()

#Binary -> String
json_df = raw_stream_df.withColumn("json_string", col("body").cast("string"))

#JSON -> Struct
parsed_stream_df = json_df.select(
    from_json(col("json_string"), message_schema).alias("parsed_data")
).select("parsed_data.*")

# The 'data' field is a single struct per message (see message_schema), so no explode is
# needed; select it alongside the metadata columns and let the transform read its nested
# fields (data.tag, data.val, ...).
df_flat = parsed_stream_df.select(
    col("data"),
    col("a_id"),
    col("d_id"),
    col("gw_id"),
    col("msg_id"),
    col("p_id"),
    col("pd_id"),
    col("retain_flag"),
    col("site_id"),
    col("timestamp"),
    col("ver")
)

# COMMAND ----------

transformed_df = transform_timeseries_data_fully_dynamic(df_flat, max_tag_parts=project_levels)


# COMMAND ----------

# Option A: Write to Delta
query = transformed_df.writeStream \
    .format("delta") \
    .outputMode("append") \
    .partitionBy("dt_date", "dt_hour", "c3") \
    .option("checkpointLocation", "/mnt/checkpoints/timeseries_data") \
    .start(timeseries_table_path)
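
# The handle returned by start() can be used to monitor or stop the stream from the notebook:
# query.status              # current state of the streaming query
# query.lastProgress        # metrics for the most recent micro-batch
# query.awaitTermination()  # optionally block this notebook until the stream stops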

# COMMAND ----------

# # Option B: Write to Parquet (same layout as the batch pipeline). If this sink is used
# # instead of Option A, give it a checkpoint location of its own, e.g.:
# transformed_df.writeStream \
#     .format("parquet") \
#     .outputMode("append") \
#     .partitionBy("dt_date", "dt_hour", "c3") \
#     .option("checkpointLocation", "/mnt/checkpoints/timeseries_data_parquet") \
#     .start(timeseries_table_path)


