# Databricks notebook source
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
import json

spark = SparkSession.builder.appName("StreamingTimeseriesPipeline").getOrCreate()
spark.sparkContext.setLogLevel("WARN")

# COMMAND ----------

print("🚀 Applying Spark optimizations for high-volume streaming...")

# Adaptive Query Execution
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.coalescePartitions.enabled", "true")
spark.conf.set("spark.sql.adaptive.advisoryPartitionSizeInBytes", "128MB")

# Rate limiting: Structured Streaming has no backpressure settings (those belong to the
# legacy DStream API); the per-batch ingest rate is capped via "maxEventsPerTrigger"
# in the Event Hubs options further below.

# Delta Lake Optimizations
spark.conf.set("spark.databricks.delta.autoCompact.enabled", "true")
spark.conf.set("spark.databricks.delta.optimizeWrite.enabled", "true")
spark.conf.set("spark.databricks.delta.merge.repartitionBeforeWrite.enabled", "true")

# Streaming State Management
spark.conf.set("spark.sql.streaming.stateStore.maintenanceInterval", "300s")
spark.conf.set("spark.sql.streaming.ui.retainedBatches", "200")
print("✅ Spark optimizations applied")

# COMMAND ----------

# Parameters - will be set when job runs
dbutils.widgets.text("eventhub_connection_string", "", "Event Hub Connection String")
dbutils.widgets.text("output_table", "catalog.schema.sensor_data", "Output Table")
dbutils.widgets.text("consumer_group", "$Default", "Consumer Group")
dbutils.widgets.text("checkpoint_location", "", "Checkpoint Location")
dbutils.widgets.text("batch_interval", "10 seconds", "Batch Processing Interval")
dbutils.widgets.text("project_levels", "4", "Project Template Levels")

# COMMAND ----------

# Get parameters
eventhub_conn_str = dbutils.widgets.get("eventhub_connection_string")
output_table = dbutils.widgets.get("output_table")
consumer_group = dbutils.widgets.get("consumer_group")
checkpoint_location = dbutils.widgets.get("checkpoint_location")
batch_interval = dbutils.widgets.get("batch_interval")
project_levels = int(dbutils.widgets.get("project_levels"))

# COMMAND ----------

message_schema = StructType([
    StructField("data", MapType(StringType(), StructType([
        StructField("dq", IntegerType(), True),
        StructField("ta", StringType(), True),
        StructField("val", StringType(), True)
    ])), True),
    StructField("a_id", StringType(), True),
    StructField("d_id", StringType(), True),
    StructField("gw_id", StringType(), True),
    StructField("msg_id", IntegerType(), True),
    StructField("p_id", StringType(), True),
    StructField("pd_id", StringType(), True),
    StructField("retain_flag", BooleanType(), True),
    StructField("site_id", StringType(), True),
    StructField("timestamp", LongType(), False),
    StructField("ver", DoubleType(), True)
])

# COMMAND ----------
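
# Hedged sanity check: a minimal sample payload matching message_schema. The tag and
# field values below are illustrative only; this cell simply confirms that from_json
# parses the message as expected.
_sample_json = json.dumps({
    "data": {
        "siteA$areaB$lineC$ast_pump_01$temperature": {"dq": 192, "ta": "2024-01-01T00:00:00Z", "val": "21.5"}
    },
    "a_id": "asset-01", "d_id": "device-01", "gw_id": "gw-01", "msg_id": 1,
    "p_id": "project-01", "pd_id": "pd-01", "retain_flag": False,
    "site_id": "site-01", "timestamp": 1700000000000, "ver": 1.0
})
spark.createDataFrame([(_sample_json,)], ["json_string"]) \
    .select(from_json(col("json_string"), message_schema).alias("m")) \
    .select("m.*") \
    .show(truncate=False)

# COMMAND ----------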

def safe_get_item(array_col, index):
    """Return the array element at `index`, or NULL when the array is shorter."""
    return when(size(array_col) > index, array_col.getItem(index)).otherwise(lit(None))


def transform_timeseries_data_fully_dynamic(df, project_levels=4):
    """
    Fully dynamic version where you can specify max number of project_levels
    """
    from pyspark.sql.functions import col, lit, from_unixtime, to_date, hour, split, size, when, isnan, isnull, filter as spark_filter
    from pyspark.sql.types import FloatType
    print(f"Transforming to target schema with up to {project_levels} project levels...")
    # First, let's create a column to split the tag and get the size
    df_with_split = df.withColumn("tag_parts", split(col("tag"), "\\$"))
    df_with_split = df_with_split.withColumn("tag_parts_count", size(col("tag_parts")))
    #Remove last index
    df_with_split = df_with_split.withColumn(
    "hierarchy_levels", slice(col("tag_parts"), 1, size(col("tag_parts")) - 1))

    # Remove parts containing "ast" from hierarchy for l1,l2,l3 columns
    df_with_split = df_with_split.withColumn(
    "levels_without_ast", spark_filter(col("hierarchy_levels"), lambda x: ~x.contains("ast")))

    # Find the part containing "ast" for ast column
    df_with_split = df_with_split.withColumn(
    "ast",
    expr("filter(hierarchy_levels, x -> x like '%ast%')[0]")
    )

    # Determine value_type based on data.val content
    value_type_logic = when(
        col("value.val").cast(FloatType()).isNotNull() &
        ~isnan(col("value.val").cast(FloatType())),
        lit("float")
    ).otherwise(lit("string"))
    # Build the select columns list dynamically
    select_columns = []

    # Fixed columns first
    fixed_columns = [
        col("timestamp").alias("timestamp"),
        from_unixtime(col("timestamp") / 1000).cast("timestamp").alias("dt_timestamp"),
        to_date(from_unixtime(col("timestamp") / 1000)).alias("dt_date"),
        hour(from_unixtime(col("timestamp") / 1000)).alias("dt_hour"),
        col("value.val").cast("string").alias("value"),
        value_type_logic.alias("value_type"),
        col("tag").alias("c3"),
        safe_get_item(col("tag_parts"), 0).alias("c1"),
        when(col("tag_parts_count") > 0,
             col("tag_parts").getItem(col("tag_parts_count") - 1)
        ).otherwise(lit(None)).alias("c5"),
        col("value.dq").cast("string").alias("Q"),
        col("value.ta").alias("T"),
        col("d_id").alias("D"),
        col("p_id").alias("P"),
        col("a_id").alias("A"),
        lit(None).cast("string").alias("B")
    ]

    # Add fixed columns
    select_columns.extend(fixed_columns)

    # Dynamically create l1, l2, ..., lN columns plus the ast column
    tag_part_columns = [
        safe_get_item(col("levels_without_ast"), i).alias(f"l{i+1}")
        for i in range(project_levels)
    ] + [col("ast").alias("ast")]
    select_columns.extend(tag_part_columns)
    # Apply the transformation
    transformed_df = df_with_split.select(*select_columns)

    return transformed_df

# COMMAND ----------
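
# Hedged sanity check for the transform: a tiny batch DataFrame with an illustrative
# tag. For "siteA$areaB$lineC$ast_pump_01$temperature" the expected output is
# l1=siteA, l2=areaB, l3=lineC, l4=NULL, ast=ast_pump_01, c1=siteA, c5=temperature.
from pyspark.sql import Row

_sample_df = spark.createDataFrame([
    Row(
        tag="siteA$areaB$lineC$ast_pump_01$temperature",
        value=Row(dq=192, ta="2024-01-01T00:00:00Z", val="21.5"),
        timestamp=1700000000000,
        d_id="device-01",
        p_id="project-01",
        a_id="asset-01",
    )
])
transform_timeseries_data_fully_dynamic(_sample_df, project_levels=4).show(truncate=False)

# COMMAND ----------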

# Event Hub configuration
eventhub_config = {
    "eventhubs.connectionString": spark._jvm.org.apache.spark.eventhubs.EventHubsUtils.encrypt(eventhub_conn_str),
    "eventhubs.consumerGroup": consumer_group,
    "eventhubs.maxEventsPerTrigger": "10000" #Processing records based on the batch Size
}
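
# Optional (hedged): the azure-event-hubs-spark connector also accepts an explicit
# starting position for the very first run (ignored once a checkpoint exists). The
# JSON shape below follows the connector's PySpark examples; verify it against the
# connector version on the cluster before uncommenting.
# eventhub_config["eventhubs.startingPosition"] = json.dumps(
#     {"offset": "-1", "seqNo": -1, "enqueuedTime": None, "isInclusive": True}
# )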

# COMMAND ----------

print("📡 Connecting to Event Hub...")
try:
    raw_stream_df = spark.readStream.format("eventhubs").options(**eventhub_config).load()
    print("Successfully connected to Event Hub stream")
except Exception as e:
    print(f"Failed to connect to Event Hub: {e}")
    dbutils.notebook.exit(f"FAILED: Event Hub connection error - {e}")

# COMMAND ----------

# The Event Hub body arrives as binary; cast it to a JSON string
json_df = raw_stream_df.withColumn("json_string", col("body").cast("string"))

# Parse the JSON string into the typed struct defined by message_schema
parsed_stream_df = json_df.select(
     from_json(col("json_string"), message_schema).alias("parsed_data")
).select("parsed_data.*")

# Explode the data map into one row per (tag, value) pair
df_exploded = parsed_stream_df.select(
    explode(col("data")).alias("tag", "value"),
    col("a_id"),
    col("d_id"),
    col("gw_id"),
    col("msg_id"),
    col("p_id"),
    col("pd_id"),
    col("retain_flag"),
    col("site_id"),
    col("timestamp"),
    col("ver")
)
#display(df_exploded)

# COMMAND ----------

transformed_df = transform_timeseries_data_fully_dynamic(df_exploded, project_levels=project_levels)
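
# Optional: print the derived schema (valid on a streaming DataFrame; no data is read)
transformed_df.printSchema()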


# COMMAND ----------

# CRITICAL: Start the CONTINUOUS streaming query
print("STARTING CONTINUOUS STREAMING QUERY...")
print("This will run INDEFINITELY until manually stopped!")

try:
    streaming_query = transformed_df.writeStream \
        .format("parquet") \
        .outputMode("append") \
        .option("checkpointLocation", checkpoint_location) \
        .option("mergeSchema", "true") \
        .trigger(processingTime=batch_interval) \
        .toTable(output_table)

    print("STREAMING QUERY STARTED SUCCESSFULLY!")
    print(f"Processing Event Hub → {output_table}")
    print(f"Batch interval: {batch_interval}")
    print(f"Checkpoint: {checkpoint_location}")

except Exception as e:
    print(f"Failed to start streaming: {e}")
    dbutils.notebook.exit(f"FAILED: Streaming start error - {e}")

# COMMAND ----------

# Monitor the streaming query continuously
print("📊 Streaming pipeline is now running continuously...")
print("🔄 Processing Event Hub messages in real-time...")
print("⏹️  To stop: Cancel this notebook or stop the job")

try:
    # This will run indefinitely until the notebook is cancelled
    streaming_query.awaitTermination()

except Exception as e:
    print(f"❌ Streaming pipeline error: {e}")
    if streaming_query.isActive:
        streaming_query.stop()
    raise
