# Databricks notebook source
from delta.tables import DeltaTable
from pyspark.sql.functions import expr
import json

# COMMAND ----------

# Notebook parameters: the JSON payload describing the deletion request and
# the name of the column whose values identify the rows to delete.
for _name, _default, _label in (
    ("input_message", "", "Input Message JSON"),
    ("id_column", "id", "ID Column Name"),
):
    dbutils.widgets.text(_name, _default, _label)

input_message = dbutils.widgets.get("input_message")
delete_column = dbutils.widgets.get("id_column")

# COMMAND ----------

def extract_table_info(input_message_str: str, delete_column: str = "id") -> dict:
    """
    Extract table name, project id, and deletion key values from an input message.

    Args:
        input_message_str (str): JSON string containing the message. Must hold
            'table_properties' (a dict with a 'table_name' entry), 'project_id',
            and a 'data' list of record dicts.
        delete_column (str): Key to collect from each record in 'data'
            (default: "id"). Records missing the key are skipped.

    Returns:
        dict: Keys 'table_name', 'project_id', 'raw_message', 'table_properties',
            plus one dynamic key named after ``delete_column`` holding the list
            of collected values.

    Raises:
        ValueError: If the message is not valid JSON, a required field is
            missing, or ``delete_column`` collides with one of the fixed
            return keys (which would silently clobber it).
    """
    # Guard against the dynamic key overwriting a fixed key in the result dict;
    # callers read both table_info['table_name'] and table_info[delete_column].
    reserved_keys = {'table_name', 'project_id', 'raw_message', 'table_properties'}
    if delete_column in reserved_keys:
        raise ValueError(f"delete_column must not be one of {sorted(reserved_keys)}")

    try:
        message_data = json.loads(input_message_str)

        table_name = message_data['table_properties']['table_name']
        project_id = message_data['project_id']
        # Skip records that lack the delete column rather than failing outright.
        delete_values = [msg[delete_column] for msg in message_data['data'] if delete_column in msg]
        table_properties = message_data['table_properties']

        print("Extracted Info:")
        print(f"Table Name: {table_name}")
        print(f"Project ID: {project_id}")
        print(f'Deleting rows: {delete_values}')
        print(f"Table Prop Keys: {list(table_properties.keys())}")

        return {
            'table_name': table_name,
            'project_id': project_id,
            delete_column: delete_values,
            'raw_message': message_data,
            'table_properties': table_properties
        }

    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON in input_message: {str(e)}")
    except KeyError as e:
        raise ValueError(f"Missing required field in input_message: {str(e)}")


# COMMAND ----------

def delete_records_by_ids(table_name, ids, id_column="id"):
    """
    Delete records from a Delta table using a list of key values.

    Args:
        table_name (str): Full table name (catalog.schema.table).
        ids (list): Values identifying the rows to delete. All values are
            treated as strings if the first element is a string.
        id_column (str): Column containing the key values (default: "id").
            NOTE(review): this is interpolated into the predicate unquoted —
            it must come from trusted configuration, not user input.

    Returns:
        bool: True if the delete succeeded, False otherwise (including when
        ``ids`` is empty or the Delta operation fails).
    """
    try:
        if not ids:
            print("No IDs provided")
            return False

        # Build the SQL IN-list. String values are single-quoted with any
        # embedded single quotes doubled, so values like "O'Brien" neither
        # break nor inject into the generated predicate.
        if isinstance(ids[0], str):
            quoted = (u"'" + str(value).replace("'", "''") + "'" for value in ids)
        else:
            quoted = (str(value) for value in ids)
        id_values = "(" + ",".join(quoted) + ")"

        # Delta DELETE with a predicate on the key column.
        delta_table = DeltaTable.forName(spark, table_name)
        condition = f"{id_column} IN {id_values}"
        delta_table.delete(condition=expr(condition))

        print(f"Successfully deleted {len(ids)} records from {table_name}")
        return True

    except Exception as e:
        # Deliberately broad: the notebook reports failure through the return
        # value so the run can surface the error instead of aborting mid-cell.
        print(f"Error: {str(e)}")
        return False

# COMMAND ----------

# Parse the incoming payload into its component fields.
table_info = extract_table_info(input_message, delete_column=delete_column)

# COMMAND ----------

# Run the deletion and report the outcome.
deletion_succeeded = delete_records_by_ids(
    table_name=table_info['table_name'],
    ids=table_info[delete_column],
    id_column=delete_column,
)
print(f"Deletion completed: {deletion_succeeded}")
