Commit 7845d0b6 authored by harshavardhan.c

feat: Add functionality for delete and add jobs in Databricks based on the catalog.

parent c741f979
......@@ -72,4 +72,4 @@ DATABRICKS_ACCESS_TOKEN=dapi72a54657606877a3f7a6d92dd573df28
#METADATA_SERVICES_URL=http://192.168.0.221:7111
#HIERARCHY_SERVICES_URL=http://192.168.0.221:7112
#MODEL_MANAGEMENT_URL=http://192.168.0.221:7113
#BATCH_PROCESS_APP_URL=http://localhost:7879
\ No newline at end of file
#BATCH_PROCESS_APP_URL=http://localhost:7879
# Created by .ignore support plugin (hsz.mobi)
### Python template
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
/.idea
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
.vscode/settings.json
.vscode
data
.env
assets
__pycache__
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
hooks:
- id: end-of-file-fixer
- id: trailing-whitespace
- id: requirements-txt-fixer
- repo: https://github.com/asottile/pyupgrade
rev: v3.16.0
hooks:
- id: pyupgrade
args:
- --py3-plus
- --keep-runtime-typing
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.4.8
hooks:
- id: ruff
args:
- --fix
- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
- id: isort
name: isort (python)
- id: isort
name: isort (cython)
types: [cython]
- id: isort
name: isort (pyi)
types: [pyi]
- repo: https://github.com/psf/black
rev: 24.4.2
hooks:
- id: black
# It is recommended to specify the latest version of Python
# supported by your project here, or alternatively use
# pre-commit's default_language_version, see
# https://pre-commit.com/#top_level-default_language_version
language_version: python3.11
# model-management-databricks
......@@ -6,23 +6,37 @@ from scripts.config import KafkaConfig
from scripts.engines.agents.model_creator_agent import ModelCreatorAgent
from scripts.schemas import ModelCreatorSchema, ModelInstanceSchema
broker = KafkaBroker(f'{KafkaConfig.KAFKA_HOST}:{KafkaConfig.KAFKA_PORT}', client_id="model_creator_agent")
broker = KafkaBroker(
f"{KafkaConfig.KAFKA_HOST}:{KafkaConfig.KAFKA_PORT}",
client_id="model_creator_agent",
)
@broker.subscriber(KafkaConfig.KAFKA_MODEL_CREATION_TOPIC, group_id="databricks_model_creator_agent", max_workers=2)
@broker.subscriber(
KafkaConfig.KAFKA_MODEL_CREATION_TOPIC,
group_id="databricks_model_creator_agent",
max_workers=2,
)
async def consume_stream_for_processing_dependencies(message: dict):
try:
await ModelCreatorAgent.model_creator_agent(message=ModelCreatorSchema(meta=message))
await ModelCreatorAgent.model_creator_agent(
message=ModelCreatorSchema(meta=message)
)
return True
except Exception as e:
logging.error(f"Exception occurred while creating model in Databricks: {e}")
return False
@broker.subscriber(KafkaConfig.KAFKA_MODEL_INSTANCE_TOPIC, group_id="databricks_instance_agent", max_workers=2)
@broker.subscriber(
KafkaConfig.KAFKA_MODEL_INSTANCE_TOPIC,
group_id="databricks_instance_agent",
max_workers=2,
)
async def consume_stream_for_processing_instances(message: dict):
try:
await ModelCreatorAgent.model_instance_agent(ModelInstanceSchema(**message))
return True
except Exception as e:
logging.error(f"Exception occurred while creating model in Databricks: {e}")
return False
\ No newline at end of file
return False
......@@ -6,11 +6,11 @@ import sys
from dotenv import load_dotenv
load_dotenv()
from agent_subscribers import broker
from faststream import FastStream
from ut_dev_utils import configure_logger
from agent_subscribers import broker
configure_logger()
# Create FastStream app
......
faststream[confluent]==0.5.48
ut-dev-utils[sql,essentials]==1.2
uvloop==0.21.0
\ No newline at end of file
uvloop==0.21.0
......@@ -70,13 +70,16 @@ class _DatabricksConfig(BaseSettings):
DATABRICKS_PUBLIC_SCHEMA_NAME: str = Field(default="public")
DATABRICKS_ANALYTICAL_SCHEMA_NAME: str = Field(default="analytical")
DATABRICKS_STORAGE_FORMAT: str = Field(default="PARQUET")
DATABRICKS_STORAGE_PATH: str = Field(default="abfss://unity-catalog-storage@dbstoragenzxfhpgsipt5a.dfs.core.windows.net/416418955412087")
DATABRICKS_STORAGE_PATH: str = Field(
default="abfss://unity-catalog-storage@dbstoragenzxfhpgsipt5a.dfs.core.windows.net/416418955412087"
)
@model_validator(mode="before")
def prepare_databricks_uri(cls, values):
values[
'DATABRICKS_URI'] = (f"databricks://token:{values['DATABRICKS_ACCESS_TOKEN']}@{values['DATABRICKS_HOST']}:{values['DATABRICKS_PORT']}"
f"?http_path={values['DATABRICKS_HTTP_PATH']}")
values["DATABRICKS_URI"] = (
f"databricks://token:{values['DATABRICKS_ACCESS_TOKEN']}@{values['DATABRICKS_HOST']}:{values['DATABRICKS_PORT']}"
f"?http_path={values['DATABRICKS_HTTP_PATH']}"
)
return values
......@@ -88,4 +91,11 @@ PathToStorage = _PathToStorage()
KafkaConfig = _KafkaConfig()
DatabricksConfig = _DatabricksConfig()
__all__ = ["Services", "RedisConfig", "ExternalServices", "PathToStorage", "KafkaConfig", "DatabricksConfig"]
__all__ = [
"Services",
"RedisConfig",
"ExternalServices",
"PathToStorage",
"KafkaConfig",
"DatabricksConfig",
]
class DatabricksConstants:
METADATA_INGESTION_JOB_NAME = "metadata_ingestion_job"
METADATA_DELETION_JOB_NAME = "metadata_deletion_job"
METADATA_INGESTION_NOTEBOOK_NAME = "metadata_ingestion_notebook"
TIMESERIES_INGESTION_NOTEBOOK_NAME = "timeseries_ingestion_notebook"
\ No newline at end of file
METADATA_DELETION_NOTEBOOK_NAME = "metadata_deletion_notebook"
TIMESERIES_INGESTION_NOTEBOOK_NAME = "timeseries_ingestion_notebook"
class NotebookConstants:
METADATA_INGESTION_NOTEBOOK_PATH = (
"scripts/constants/notebooks/metadata_ingestion.txt"
)
METADATA_DELETION_NOTEBOOK_PATH = (
"scripts/constants/notebooks/metadata_deletion.txt"
)
TIMESERIES_INGESTION_NOTEBOOK_PATH = (
"scripts/constants/notebooks/timeseries_ingestion.txt"
)
# Databricks notebook source
from delta.tables import DeltaTable
from pyspark.sql.functions import expr
import json
# COMMAND ----------
dbutils.widgets.text("input_message", "", "Input Message JSON")
dbutils.widgets.text("id_column", "id", "ID Column Name")
input_message = dbutils.widgets.get("input_message")
delete_column = dbutils.widgets.get("id_column")
# COMMAND ----------
def extract_table_info(input_message_str: str, delete_column: str = "id"):
"""
Extract table name and data from input message
Args:
input_message_str (str): JSON string containing the message
delete_column (str): Column whose values identify the rows to delete (default: "id")
Returns:
dict: Extracted information
"""
try:
message_data = json.loads(input_message_str)
# Extract table name from data.type
table_name = message_data['table_properties']['table_name'] # 'enterprise'
project_id = message_data['project_id'] # 'project_787'
delete_values = [msg[delete_column] for msg in message_data['data'] if delete_column in msg]
table_properties = message_data['table_properties'] # Fetch table properties
print(f"Extracted Info:")
print(f"Table Name: {table_name}")
print(f"Project ID: {project_id}")
print(f'Deleting rows: {delete_values}')
print(f"Table Prop Keys: {list(table_properties.keys())}")
return {
'table_name': table_name,
'project_id': project_id,
delete_column: delete_values,
'raw_message': message_data,
'table_properties': table_properties
}
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON in input_message: {str(e)}")
except KeyError as e:
raise ValueError(f"Missing required field in input_message: {str(e)}")
# COMMAND ----------
def delete_records_by_ids(table_name, ids, id_column="id"):
"""
Delete records from external table (Delta or Parquet) using list of IDs
Args:
table_name (str): Full table name (catalog.schema.table)
ids (list): List of IDs to delete
id_column (str): Column name containing IDs (default: "id")
Returns:
bool: True if successful, False otherwise
"""
try:
if not ids:
print("No IDs provided")
return False
# Format IDs for SQL IN clause
if isinstance(ids[0], str):
id_values = "(" + ",".join([f"'{id}'" for id in ids]) + ")"
else:
id_values = "(" + ",".join([str(id) for id in ids]) + ")"
# Use Delta table DELETE operation
delta_table = DeltaTable.forName(spark, table_name)
condition = f"{id_column} IN {id_values}"
delta_table.delete(condition=expr(condition))
print(f"Successfully deleted {len(ids)} records from {table_name}")
return True
except Exception as e:
print(f"Error: {str(e)}")
return False
# COMMAND ----------
table_info = extract_table_info(input_message, delete_column=delete_column)
# COMMAND ----------
result = delete_records_by_ids(table_name=table_info['table_name'], ids=table_info[delete_column], id_column=delete_column)
print(f"Deletion completed: {result}")
......@@ -31,18 +31,18 @@ def extract_table_info(input_message_str: str):
"""
try:
message_data = json.loads(input_message_str)
# Extract table name from data.type
table_name = message_data['table_properties']['table_name'] # 'enterprise'
project_id = message_data['project_id'] # 'project_787'
data_payload = message_data['data'] # Full data object
table_properties = message_data['table_properties'] # Fetch table properties
print(f"Extracted Info:")
print(f"Table Name: {table_name}")
print(f"Project ID: {project_id}")
print(f"Table Prop Keys: {list(table_properties.keys())}")
return {
'table_name': table_name,
'project_id': project_id,
......@@ -50,7 +50,7 @@ def extract_table_info(input_message_str: str):
'raw_message': message_data,
'table_properties': table_properties
}
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON in input_message: {str(e)}")
except KeyError as e:
......@@ -62,13 +62,13 @@ def extract_table_info(input_message_str: str):
def detect_external_table_schema(table_name):
"""
Detect schema of external Delta or Parquet table
Args:
table_name (str): Name of the table (e.g., 'enterprise')
Returns:
pyspark.sql.types.StructType: Schema of the table
"""
try:
# Try to get schema from catalog
table_df = spark.table(table_name)
......@@ -96,4 +96,4 @@ data_df = spark.createDataFrame(table_info['data_payload'], schema=schema)
# COMMAND ----------
data_df.write.mode("append").saveAsTable(table_info['table_name'])
\ No newline at end of file
data_df.write.mode("append").saveAsTable(table_info['table_name'])
......@@ -4,7 +4,7 @@ from pyspark.sql.functions import *
from pyspark.sql.types import *
import json
spark = SparkSession.builder.appName("StreamingIoTPipeline").getOrCreate()
spark = SparkSession.builder.appName("StreamingTimeseriesPipeline").getOrCreate()
spark.sparkContext.setLogLevel("WARN")
# COMMAND ----------
......@@ -31,7 +31,7 @@ message_schema = StructType([
]), True),
StructField("a_id", StringType(), True),
StructField("d_id", StringType(), True),
StructField("gw_id", StringType(), True),
StructField("gw_id", StringType(), True),
StructField("msg_id", IntegerType(), True),
StructField("p_id", StringType(), True),
StructField("pd_id", StringType(), True),
......@@ -50,7 +50,7 @@ def safe_get_item(array_col, index):
def transform_timeseries_data_fully_dynamic(df, max_tag_parts=4):
print(f"Transforming to target schema with up to {max_tag_parts} tag parts...")
df_with_split = df.withColumn("tag_parts", split(col("data.tag"), "\\$"))
df_with_split = df_with_split.withColumn("tag_parts_count", size(col("tag_parts")))
df_with_split = df_with_split.withColumn("hierarchy_levels", slice(col("tag_parts"), 1, size(col("tag_parts")) - 1))
......@@ -146,4 +146,3 @@ transformed_df.writeStream \
# COMMAND ----------
......@@ -3,6 +3,7 @@ import json
from ut_dev_utils import get_db_name
from scripts.config import DatabricksConfig
from scripts.constants import DatabricksConstants
from scripts.db.databricks.job_manager import DatabricksJobManager
from scripts.db.redis.databricks_details import databricks_details_db
from scripts.schemas import ModelInstanceSchema
......@@ -12,29 +13,44 @@ class ModelInstanceHandler:
def __init__(self, project_id: str, payload: ModelInstanceSchema):
self.project_id = project_id
self.payload = payload
self.catalog_name = get_db_name(project_id=project_id, database=DatabricksConfig.DATABRICKS_CATALOG_NAME)
self.catalog_name = get_db_name(
project_id=project_id, database=DatabricksConfig.DATABRICKS_CATALOG_NAME
)
self.job_manager = DatabricksJobManager(
databricks_host=payload.databricks_host,
access_token=payload.databricks_access_token
access_token=payload.databricks_access_token,
)
def upload_instances_to_unity_catalog(self):
job_id = databricks_details_db.hget(self.project_id, "metadata_ingestion_job")
async def upload_instances_to_unity_catalog(self):
if self.payload.action_type == "delete":
job_id = databricks_details_db.hget(
self.project_id, DatabricksConstants.METADATA_DELETION_JOB_NAME
)
else:
job_id = databricks_details_db.hget(
self.project_id, DatabricksConstants.METADATA_INGESTION_JOB_NAME
)
if not job_id:
raise ValueError("No job id found for metadata ingestion job, skipping upload to unity catalog")
run_id = self.job_manager.run_job(job_id=job_id,
parameters={"input_message": json.dumps(self.get_job_trigger_payload())})
raise ValueError(
f"No job id found for {self.payload.action_type}, skipping upload to unity catalog"
)
run_id = self.job_manager.run_job(
job_id=job_id,
parameters={"input_message": json.dumps(self.get_job_trigger_payload())},
)
if not run_id:
raise ValueError("Failed to run metadata ingestion job, skipping upload to unity catalog")
raise ValueError(
"Failed to run metadata ingestion job, skipping upload to unity catalog"
)
def get_job_trigger_payload(self):
table_name = self.payload.data[0]['type']
table_name = self.payload.node_type
schema_table = f"{DatabricksConfig.DATABRICKS_PUBLIC_SCHEMA_NAME}.{table_name}"
return {
"table_properties": {
"table_name": f'{self.catalog_name}.{schema_table}',
"table_path": f'{self.payload.databricks_storage_path}/{self.catalog_name}/DELTA/{schema_table}',
},
"table_name": f"{self.catalog_name}.{schema_table}",
"table_path": f"{self.payload.databricks_storage_path}/{self.catalog_name}/DELTA/{schema_table}",
},
"project_id": self.project_id,
"data": self.payload.data
"data": self.payload.data,
}
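For illustration, the trigger payload serialized into the job's input_message parameter takes roughly this shape (the catalog, storage path, and row values below are placeholders, not real configuration):
# Hypothetical example of get_job_trigger_payload() output; names and paths are placeholders.
example_trigger_payload = {
    "table_properties": {
        "table_name": "catalog_project_787.public.enterprise",
        "table_path": "abfss://container@account.dfs.core.windows.net/catalog_project_787/DELTA/public.enterprise",
    },
    "project_id": "project_787",
    "data": [{"id": "row_001", "name": "Plant A"}],
}
# json.dumps(example_trigger_payload) is what run_job() receives as the "input_message" parameter.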
......@@ -5,7 +5,7 @@ from sqlalchemy.orm import declarative_base
from ut_sql_utils.asyncio.declarative_utils import DeclarativeUtils
from scripts.config import DatabricksConfig
from scripts.constants import DatabricksConstants
from scripts.constants import DatabricksConstants, NotebookConstants
from scripts.db.databricks import DataBricksSQLLayer
from scripts.db.databricks.job_manager import DatabricksJobManager
from scripts.db.databricks.notebook_manager import NotebookManager
......@@ -16,23 +16,25 @@ from scripts.utils.model_convertor_utils import ModelConverter
class ModelCreatorHandler:
def __init__(self, message: ModelCreatorSchema, declarative_utils: DeclarativeUtils):
def __init__(
self, message: ModelCreatorSchema, declarative_utils: DeclarativeUtils
):
self.declarative_utils = declarative_utils
self.meta = message.meta
self.message = message
self.model_convertor = ModelConverter()
self.job_manager = DatabricksJobManager(
databricks_host=message.databricks_host,
access_token=message.databricks_access_token
access_token=message.databricks_access_token,
)
self.notebook_manager = NotebookManager(
databricks_host=message.databricks_host,
access_token=message.databricks_access_token
access_token=message.databricks_access_token,
)
self.databricks_sql_obj = DataBricksSQLLayer(
catalog_name=DatabricksConfig.DATABRICKS_CATALOG_NAME,
project_id=self.meta.project_id,
schema=message.schema
schema=message.schema,
)
self.external_location = self.message.databricks_storage_path
......@@ -47,32 +49,37 @@ class ModelCreatorHandler:
overall_tables = self.get_overall_tables()
project_levels = project_template_keys(self.meta.project_id, levels=True)
base = self.create_schema_base(schema_name=f'{self.databricks_sql_obj.catalog_name}.{self.message.schema}')
base = self.create_schema_base(
schema_name=f"{self.databricks_sql_obj.catalog_name}.{self.message.schema}"
)
try:
# self.databricks_sql_obj.connect_to_databricks()
_ = self.setup_dependencies_for_unity_catalog()
table_properties = self.fetch_table_properties()
# for table in overall_tables:
# table_class = self.declarative_utils.get_declarative_class(table)
# if not table_class:
# logging.error(f"Table class not found for table: {table}")
# return False
# new_model = self.model_convertor.convert_model(
# table_class,
# base_class=base,
# new_schema=self.message.schema,
# )
#
# self.databricks_sql_obj.create_external_table_from_structure(
# table=new_model.__table__,
# file_format="DELTA",
# external_location=self.external_location,
# table_properties=table_properties
# )
ts_external_table = self.databricks_sql_obj.create_timeseries_table(columns=project_levels,
external_location=self.external_location)
self.setup_notepads_and_jobs(timeseries_table_path=ts_external_table, project_levels=project_levels)
for table in overall_tables:
table_class = self.declarative_utils.get_declarative_class(table)
if not table_class:
logging.error(f"Table class not found for table: {table}")
return False
new_model = self.model_convertor.convert_model(
table_class,
base_class=base,
new_schema=self.message.schema,
)
self.databricks_sql_obj.create_external_table_from_structure(
table=new_model.__table__,
file_format="DELTA",
external_location=self.external_location,
table_properties=table_properties,
)
ts_external_table = self.databricks_sql_obj.create_timeseries_table(
columns=project_levels, external_location=self.external_location
)
self.setup_notepads_and_jobs(
timeseries_table_path=ts_external_table, project_levels=project_levels
)
return True
except Exception as e:
logging.error(f"Error occurred while creating models in Unity Catalog: {e}")
......@@ -95,20 +102,25 @@ class ModelCreatorHandler:
analytical (bool): Flag to indicate if the setup is for analytical or not
"""
logging.info(
f"Setting up catalog '{DatabricksConfig.DATABRICKS_CATALOG_NAME}' for project '{self.meta.project_id}'")
f"Setting up catalog '{DatabricksConfig.DATABRICKS_CATALOG_NAME}' for project '{self.meta.project_id}'"
)
self.databricks_sql_obj.connect_to_databricks()
# Create catalog
catalog_success = self.databricks_sql_obj.create_catalog(
managed_location=f'{self.external_location}/{self.databricks_sql_obj.catalog_name}',
managed_location=f"{self.external_location}/{self.databricks_sql_obj.catalog_name}",
)
if not catalog_success:
return False
# Create schema
schema_success = self.databricks_sql_obj.create_schema(DatabricksConfig.DATABRICKS_PUBLIC_SCHEMA_NAME)
schema_success = self.databricks_sql_obj.create_schema(
DatabricksConfig.DATABRICKS_PUBLIC_SCHEMA_NAME
)
if not schema_success:
return False
if analytical:
schema_success = self.databricks_sql_obj.create_schema(DatabricksConfig.DATABRICKS_ANALYTICAL_SCHEMA_NAME)
schema_success = self.databricks_sql_obj.create_schema(
DatabricksConfig.DATABRICKS_ANALYTICAL_SCHEMA_NAME
)
if not schema_success:
return False
return True
......@@ -120,59 +132,112 @@ class ModelCreatorHandler:
project_levels: List of project levels
"""
logging.info("Setting up notepads and jobs")
with open(r"scripts/constants/notebooks/metadata_ingestion.txt", "r") as f:
notebook_code = f.read()
# # Notebook for metadata ingestion
# self.notebook_manager.create_notebook(
# notebook_path=f"/Users/{self.message.databricks_user_email}/{self.meta.project_id}_{DatabricksConstants.METADATA_INGESTION_NOTEBOOK_NAME}",
# notebook_code=notebook_code,
# overwrite=True
# )
# # Job for metadata ingestion used by model management
# job_id = self.job_manager.create_job(job_config=self.job_manager.create_job_config_for_serverless(
# job_name=f'{self.meta.project_id}_{DatabricksConstants.METADATA_INGESTION_JOB_NAME}',
# notebook_path=f"/Users/{self.message.databricks_user_email}/metadata_ingestion_notebook",
# ))
#
# databricks_details_db.hset(self.meta.project_id, DatabricksConstants.METADATA_INGESTION_JOB_NAME, job_id)
# Timeseries DataPush Notebook
with open(r"scripts/constants/notebooks/timeseries_ingestion.txt", "r") as f:
notebook_code_for_timeseries = f.read()
notebook_code_for_timeseries = notebook_code_for_timeseries.replace("{{timeseries_table_path}}",
f'"{timeseries_table_path}"')
notebook_code_for_timeseries = notebook_code_for_timeseries.replace("{{project_levels}}", str(len(project_levels) - 1))
notebook_code_for_timeseries = notebook_code_for_timeseries.replace("{{event_hub_connection_string}}", f'"{self.meta.project_id}"')
meta_ingestion_notebook_path = f"/Users/{self.message.databricks_user_email}/{self.meta.project_id}_{DatabricksConstants.METADATA_INGESTION_NOTEBOOK_NAME}"
meta_deletion_notebook_path = f"/Users/{self.message.databricks_user_email}/{self.meta.project_id}_{DatabricksConstants.METADATA_DELETION_NOTEBOOK_NAME}"
timeseries_notebook_path = f"/Users/{self.message.databricks_user_email}/{self.meta.project_id}_{DatabricksConstants.TIMESERIES_INGESTION_NOTEBOOK_NAME}"
self.notebook_manager.create_notebook(
notebook_path=f"/Users/{self.message.databricks_user_email}/{self.meta.project_id}_{DatabricksConstants.TIMESERIES_INGESTION_NOTEBOOK_NAME}",
notebook_code=notebook_code_for_timeseries,
overwrite=True
# Setting up of Metadata Ingestion Notebook
existing_job_id = databricks_details_db.hget(
self.meta.project_id, DatabricksConstants.METADATA_INGESTION_JOB_NAME
)
if not existing_job_id:
self.create_notebook(
notebook_path=meta_ingestion_notebook_path,
source_notebook_path=NotebookConstants.METADATA_INGESTION_NOTEBOOK_PATH,
)
ingestion_job_id = self.create_job(
job_name=f"{self.meta.project_id}_{DatabricksConstants.METADATA_INGESTION_JOB_NAME}",
notebook_path=meta_ingestion_notebook_path,
)
databricks_details_db.hset(
self.meta.project_id,
DatabricksConstants.METADATA_INGESTION_JOB_NAME,
ingestion_job_id,
)
existing_job_id = databricks_details_db.hget(
self.meta.project_id, DatabricksConstants.METADATA_DELETION_JOB_NAME
)
if not existing_job_id:
# Setting up of Metadata Deletion Notebook
self.create_notebook(
notebook_path=meta_deletion_notebook_path,
source_notebook_path=NotebookConstants.METADATA_DELETION_NOTEBOOK_PATH,
)
deletion_job_id = self.create_job(
job_name=f"{self.meta.project_id}_{DatabricksConstants.METADATA_DELETION_JOB_NAME}",
notebook_path=meta_deletion_notebook_path,
)
databricks_details_db.hset(
self.meta.project_id,
DatabricksConstants.METADATA_DELETION_JOB_NAME,
deletion_job_id,
)
# Setting up of Timeseries Ingestion Notebook
replace_mapping = {
"{{timeseries_table_path}}": f'"{timeseries_table_path}"',
"{{project_levels}}": str(len(project_levels) - 1),
"{{event_hub_connection_string}}": f'"{self.meta.project_id}"',
}
self.create_notebook(
notebook_path=timeseries_notebook_path,
source_notebook_path=NotebookConstants.TIMESERIES_INGESTION_NOTEBOOK_PATH,
replace_mapping=replace_mapping,
)
@staticmethod
def fetch_table_properties(file_format: str = 'DELTA'):
if file_format.lower() == 'delta':
def fetch_table_properties(file_format: str = "DELTA"):
if file_format.lower() == "delta":
return {
# Performance optimization (Essential)
"delta.autoOptimize.optimizeWrite": "true",
"delta.autoOptimize.autoCompact": "true",
"delta.targetFileSize": "134217728", # 128MB
'delta.enableChangeDataFeed': 'true', # If you need CDC
"delta.enableChangeDataFeed": "true", # If you need CDC
# Checkpoint optimization (Performance boost)
"delta.checkpoint.writeStatsAsStruct": "true",
"delta.checkpoint.writeStatsAsJson": "false"
"delta.checkpoint.writeStatsAsJson": "false",
# Note: Retention properties removed - using defaults:
# delta.deletedFileRetentionDuration = 7 days (default)
# delta.logRetentionDuration = 30 days (default)
}
elif file_format.lower() == 'parquet':
return {"parquet.compression": "snappy",
"parquet.page.size": "1048576", # 1MB - standard for mixed queries
"parquet.block.size": "134217728", # 128MB - balanced performance
"serialization.format": "1"}
elif file_format.lower() == "parquet":
return {
"parquet.compression": "snappy",
"parquet.page.size": "1048576", # 1MB - standard for mixed queries
"parquet.block.size": "134217728", # 128MB - balanced performance
"serialization.format": "1",
}
else:
return {}
@staticmethod
def read_data_from_file(note_path: str):
with open(note_path) as f:
notebook_code = f.read()
return notebook_code
def create_notebook(
self,
notebook_path: str,
source_notebook_path: str,
replace_mapping: dict = None,
):
logging.info(f"Creating notebook {notebook_path}")
notebook_code = self.read_data_from_file(source_notebook_path)
if replace_mapping is not None:
for key, value in replace_mapping.items():
notebook_code = notebook_code.replace(key, value)
self.notebook_manager.create_notebook(
notebook_path=notebook_path, notebook_code=notebook_code, overwrite=True
)
return True
def create_job(self, job_name: str, notebook_path: str):
logging.info(f"Creating job {job_name}")
job_id = self.job_manager.create_job(
job_config=self.job_manager.create_job_config_for_serverless(
job_name=job_name,
notebook_path=notebook_path,
)
)
return job_id
from typing import Dict, List
from sqlalchemy import Table, Column, String, BigInteger, DateTime, MetaData, Integer, Date
from sqlalchemy import (
BigInteger,
Column,
Date,
DateTime,
Integer,
MetaData,
String,
Table,
)
from scripts.utils.databricks_utils import DatabricksSQLUtility
from scripts.utils.model_convertor_utils import TypeMapper
......@@ -11,11 +20,14 @@ class DataBricksSQLLayer(DatabricksSQLUtility):
super().__init__(catalog_name, project_id)
self.schema = schema
def create_external_table_from_structure(self, table: Table,
external_location: str,
file_format: str = "PARQUET",
table_properties: Dict[str, str] = None,
partition_columns: list = None) -> str:
def create_external_table_from_structure(
self,
table: Table,
external_location: str,
file_format: str = "PARQUET",
table_properties: Dict[str, str] = None,
partition_columns: list = None,
) -> str:
"""
Create an external table from a model class.
......@@ -31,12 +43,14 @@ class DataBricksSQLLayer(DatabricksSQLUtility):
"""
schema_table = f"{table.schema}.{table.name}" if table.schema else table.name
columns_sql = TypeMapper().extract_columns_without_constraints(table)
external_location = f"{external_location}/{self.catalog_name}/{file_format}/{schema_table}"
external_location = (
f"{external_location}/{self.catalog_name}/{file_format}/{schema_table}"
)
sql_parts = [
f"CREATE TABLE IF NOT EXISTS {schema_table}",
f"({columns_sql})",
f"USING {file_format}",
f"LOCATION '{external_location}'"
f"LOCATION '{external_location}'",
]
if partition_columns:
partition_clause = ", ".join(partition_columns)
......@@ -64,33 +78,32 @@ class DataBricksSQLLayer(DatabricksSQLUtility):
"""
table_columns = [
Column('timestamp', BigInteger, nullable=False),
Column('dt_timestamp', DateTime, nullable=False),
Column('dt_date', Date, nullable=False),
Column('dt_hour', Integer, nullable=False),
Column('value', String, nullable=False),
Column('value_type', String, nullable=False, default='float'),
Column("c3", String, nullable=False)
Column("timestamp", BigInteger, nullable=False),
Column("dt_timestamp", DateTime, nullable=False),
Column("dt_date", Date, nullable=False),
Column("dt_hour", Integer, nullable=False),
Column("value", String, nullable=False),
Column("value_type", String, nullable=False, default="float"),
Column("c3", String, nullable=False),
]
default_columns = ["c1", "c5", "Q", "T", "D", "P", "A", "B", *columns]
table_columns.extend([Column(col_name, String, nullable=True) for col_name in default_columns])
partition_columns = ['dt_date', 'dt_hour', 'c3']
table_columns.extend(
[Column(col_name, String, nullable=True) for col_name in default_columns]
)
partition_columns = ["dt_date", "dt_hour", "c3"]
table_properties = {
"parquet.compression": "snappy", # Fast decompression for frequent queries
"parquet.page.size": "524288", # 512KB - better time-range filtering
"parquet.block.size": "268435456", # 256MB - efficient sequential reads
"serialization.format": "1" # Support for arrays/complex types
"serialization.format": "1", # Support for arrays/complex types
}
table_obj = Table(
"timeseries_data",
MetaData(),
*table_columns,
schema=self.schema
"timeseries_data", MetaData(), *table_columns, schema=self.schema
)
self.create_external_table_from_structure(
table=table_obj,
external_location=external_location,
partition_columns=partition_columns,
table_properties=table_properties
table_properties=table_properties,
)
return external_location
......@@ -14,10 +14,14 @@ class DatabricksJobManager:
databricks_host: Your Databricks workspace URL
access_token: Personal access token or service principal token
"""
self.host = databricks_host if "https://" in databricks_host else f"https://{databricks_host}"
self.host = (
databricks_host
if "https://" in databricks_host
else f"https://{databricks_host}"
)
self.headers = {
'Authorization': f'Bearer {access_token}',
'Content-Type': 'application/json'
"Authorization": f"Bearer {access_token}",
"Content-Type": "application/json",
}
def create_job(self, job_config: dict):
......@@ -32,11 +36,13 @@ class DatabricksJobManager:
response = HTTPXRequestUtil(url).post(headers=self.headers, json=job_config)
if response.status_code == 200:
job_id = response.json()['job_id']
job_id = response.json()["job_id"]
logging.info(f"Job created successfully with ID: {job_id}")
return job_id
else:
logging.error(f"Error creating job: {response.status_code} - {response.text}")
logging.error(
f"Error creating job: {response.status_code} - {response.text}"
)
return None
def run_job(self, job_id: str, parameters=None):
......@@ -57,11 +63,13 @@ class DatabricksJobManager:
response = HTTPXRequestUtil(url).post(headers=self.headers, json=payload)
if response.status_code == 200:
run_id = response.json()['run_id']
run_id = response.json()["run_id"]
logging.info(f"Job run started with ID: {run_id}")
return run_id
else:
logging.error(f"Error running job: {response.status_code} - {response.text}")
logging.error(
f"Error running job: {response.status_code} - {response.text}"
)
return None
def get_run_status(self, run_id):
......@@ -73,12 +81,16 @@ class DatabricksJobManager:
url = f"{self.host}/api/2.1/jobs/runs/get"
params = {"run_id": run_id}
response = HTTPXRequestHandler(url).get(url, headers=self.headers, params=params)
response = HTTPXRequestHandler(url).get(
url, headers=self.headers, params=params
)
if response.status_code == 200:
return response.json()
else:
logging.error(f"Error getting run status: {response.status_code} - {response.text}")
logging.error(
f"Error getting run status: {response.status_code} - {response.text}"
)
return None
@staticmethod
......@@ -98,16 +110,18 @@ class DatabricksJobManager:
"task_key": "table_update_task",
"notebook_task": {
"notebook_path": notebook_path,
"base_parameters": {
"input_message": "default_value"
}
"base_parameters": {"input_message": "default_value"},
},
"timeout_seconds": 3600
"timeout_seconds": 3600,
}
],
"max_concurrent_runs": 10,
"tags": {
"purpose": "metadata_ingestion",
"compute_type": "serverless"
}
"purpose": (
"metadata_ingestion"
if "ingestion" in job_name
else "metadata_deletion"
),
"compute_type": "serverless",
},
}
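A minimal usage sketch for the manager above, assuming placeholder credentials and paths; each call returns None when the underlying REST request fails:
# Hypothetical end-to-end usage of DatabricksJobManager; host, token and paths are placeholders.
manager = DatabricksJobManager(
    databricks_host="adb-1234567890123456.7.azuredatabricks.net",
    access_token="<personal-access-token>",
)
job_config = manager.create_job_config_for_serverless(
    job_name="project_787_metadata_deletion_job",
    notebook_path="/Users/user@example.com/project_787_metadata_deletion_notebook",
)
job_id = manager.create_job(job_config=job_config)
if job_id:
    run_id = manager.run_job(job_id=job_id, parameters={"input_message": "{}"})
    if run_id:
        status = manager.get_run_status(run_id)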
......@@ -13,13 +13,19 @@ class NotebookManager:
databricks_host: Your Databricks workspace URL (e.g., 'https://your-workspace.cloud.databricks.com')
access_token: Personal access token or service principal token
"""
self.host = databricks_host if "https://" in databricks_host else f"https://{databricks_host}"
self.host = (
databricks_host
if "https://" in databricks_host
else f"https://{databricks_host}"
)
self.headers = {
'Authorization': f'Bearer {access_token}',
'Content-Type': 'application/json'
"Authorization": f"Bearer {access_token}",
"Content-Type": "application/json",
}
def create_notebook(self, notebook_path, notebook_code: str, language='PYTHON', overwrite=True):
def create_notebook(
self, notebook_path, notebook_code: str, language="PYTHON", overwrite=True
):
"""
Create a notebook in Databricks workspace
......@@ -31,18 +37,22 @@ class NotebookManager:
"""
url = f"{self.host}/api/2.0/workspace/import"
# Encode the notebook content in base64
encoded_content = base64.b64encode(notebook_code.encode('utf-8')).decode('utf-8')
encoded_content = base64.b64encode(notebook_code.encode("utf-8")).decode(
"utf-8"
)
payload = {
"path": notebook_path,
"format": "SOURCE",
"language": language,
"content": encoded_content,
"overwrite": overwrite
"overwrite": overwrite,
}
response = HTTPXRequestUtil(url=url).post(json=payload, headers=self.headers)
if response.status_code == 200:
logging.info(f"Notebook created successfully at: {notebook_path}")
return True
else:
logging.error(f"Error creating notebook: {response.status_code} - {response.text}")
logging.error(
f"Error creating notebook: {response.status_code} - {response.text}"
)
return False
import orjson
from scripts.config import RedisConfig
from scripts.db.redis import redis_connector
databricks_details_db = redis_connector.connect(db=RedisConfig.REDIS_DATABRICKS_DB, decode_responses=True)
\ No newline at end of file
databricks_details_db = redis_connector.connect(
db=RedisConfig.REDIS_DATABRICKS_DB, decode_responses=True
)
......@@ -9,7 +9,9 @@ from ut_sql_utils.config import PostgresConfig
from scripts.config import RedisConfig
from scripts.db.redis import redis_connector
graphql_details_db = redis_connector.connect(db=RedisConfig.REDIS_GRAPHQL_DB, decode_responses=True)
graphql_details_db = redis_connector.connect(
db=RedisConfig.REDIS_GRAPHQL_DB, decode_responses=True
)
def get_models(
......@@ -38,7 +40,9 @@ def get_models(
"""
tables_data = graphql_details_db.hget(info.data["project_id"], "schema_mapper")
if tables_data is None:
raise ILensErrors(f"No GraphQL schema data found for project {info.data['project_id']}")
raise ILensErrors(
f"No GraphQL schema data found for project {info.data['project_id']}"
)
tables: Dict[str, Any] = orjson.loads(tables_data) or {}
if (
......
......@@ -3,7 +3,9 @@ import orjson
from scripts.config import RedisConfig
from scripts.db.redis import redis_connector
project_details_db = redis_connector.connect(db=RedisConfig.REDIS_PROJECT_TAGS_DB, decode_responses=True)
project_details_db = redis_connector.connect(
db=RedisConfig.REDIS_PROJECT_TAGS_DB, decode_responses=True
)
def get_project_time_zone(project_id: str):
......@@ -19,7 +21,9 @@ def get_project_time_zone(project_id: str):
return "UTC"
def fetch_level_details(project_id: str, keys: bool = False, raw: bool = False) -> dict | list:
def fetch_level_details(
project_id: str, keys: bool = False, raw: bool = False
) -> dict | list:
"""
Function to fetch level details from project details
Uses redis project details cache db (db18) and fetches the level details
......@@ -60,7 +64,9 @@ def fetch_asset_level(project_id: str) -> str:
project_details = orjson.loads(project_details)
counter_levels = project_details.get("counter_levels", {})
asset_level = (
counter_levels.get("asset", counter_levels.get("equipment")) if isinstance(counter_levels, dict) else None
counter_levels.get("asset", counter_levels.get("equipment"))
if isinstance(counter_levels, dict)
else None
)
if asset_level:
return asset_level
......@@ -81,9 +87,10 @@ def fetch_asset_level_with_mapping(project_id: str) -> tuple[str, str]:
swapped_dict = {v: k for k, v in counter_levels.items()}
return swapped_dict.get("ast", ""), "ast"
def project_template_keys(project_id: str, levels=False):
val = project_details_db.get(project_id)
if val is None:
raise ValueError(f"Unknown Project, Project ID:{project_id}Not Found!!!")
val = orjson.loads(val)
return val.get("levels", {}) if levels else list(val.get("levels", {}).keys())
\ No newline at end of file
return val.get("levels", {}) if levels else list(val.get("levels", {}).keys())
......@@ -2,4 +2,7 @@ from faststream.confluent import KafkaBroker
from scripts.config import KafkaConfig
broker = KafkaBroker(f'{KafkaConfig.KAFKA_HOST}:{KafkaConfig.KAFKA_PORT}', client_id="model_creator_agent")
broker = KafkaBroker(
f"{KafkaConfig.KAFKA_HOST}:{KafkaConfig.KAFKA_PORT}",
client_id="model_creator_agent",
)
......@@ -7,8 +7,7 @@ from scripts.schemas import ModelCreatorSchema, ModelInstanceSchema
class ModelCreatorAgent:
def __init__(self):
...
def __init__(self): ...
@staticmethod
async def model_creator_agent(message: ModelCreatorSchema):
......@@ -18,10 +17,14 @@ class ModelCreatorAgent:
session_manager=session_manager,
schema=message.schema,
)
model_cal_obj = ModelCreatorHandler(message=message, declarative_utils=declarative_utils)
model_cal_obj = ModelCreatorHandler(
message=message, declarative_utils=declarative_utils
)
await model_cal_obj.create_models_in_unity_catalog()
@staticmethod
async def model_instance_agent(message: ModelInstanceSchema):
model_instance_obj = ModelInstanceHandler(project_id=message.project_id, payload=message)
await model_instance_obj.upload_instances_to_unity_catalog()
\ No newline at end of file
model_instance_obj = ModelInstanceHandler(
project_id=message.project_id, payload=message
)
await model_instance_obj.upload_instances_to_unity_catalog()
from typing import Optional, Union, Dict, Any, List
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, model_validator
from pydantic import BaseModel, Field, model_validator
from ut_security_util import MetaInfoSchema
from scripts.config import DatabricksConfig
......@@ -20,7 +20,11 @@ class ModelCreatorSchema(BaseModel):
class ModelInstanceSchema(BaseModel):
data: Union[Dict[str, Any], List[Dict[str, Any]]]
project_id: str
schema: Optional[str] = DatabricksConfig.DATABRICKS_PUBLIC_SCHEMA_NAME
action_type: str = "save"
node_type: str
sql_schema: Optional[str] = Field(
default=DatabricksConfig.DATABRICKS_PUBLIC_SCHEMA_NAME, alias="schema"
)
databricks_host: str = DatabricksConfig.DATABRICKS_HOST
databricks_port: int = DatabricksConfig.DATABRICKS_PORT
databricks_access_token: str = DatabricksConfig.DATABRICKS_ACCESS_TOKEN
......@@ -30,6 +34,6 @@ class ModelInstanceSchema(BaseModel):
@model_validator(mode="before")
def validate_data(cls, values: Dict[str, Any]) -> Dict[str, Any]:
if 'data' in values and isinstance(values['data'], dict):
values['data'] = [values['data']]
if "data" in values and isinstance(values["data"], dict):
values["data"] = [values["data"]]
return values
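For illustration, a hypothetical message as consumed from Kafka, showing the schema alias and the dict-to-list coercion performed by the validator above (values are made up, and the remaining Databricks connection fields are assumed to fall back to their config defaults):
from scripts.schemas import ModelInstanceSchema

# Hypothetical message; values are illustrative only.
message = {
    "project_id": "project_787",
    "node_type": "enterprise",
    "action_type": "delete",
    "schema": "public",         # fills sql_schema through the alias
    "data": {"id": "row_001"},  # a single dict gets wrapped into a list
}
instance = ModelInstanceSchema(**message)
assert instance.data == [{"id": "row_001"}]
assert instance.sql_schema == "public"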
......@@ -29,13 +29,17 @@ class DatabricksSQLUtility:
DatabricksConfig.DATABRICKS_URI,
pool_pre_ping=True,
pool_recycle=3600,
echo=False
echo=False,
)
# Test connection
with self.engine.connect() as conn:
result = conn.execute(text("SELECT current_user() as user, current_catalog() as catalog"))
result = conn.execute(
text("SELECT current_user() as user, current_catalog() as catalog")
)
user_info = result.fetchone()
logger.info(f"Connected as user: {user_info[0]}, current catalog: {user_info[1]}")
logger.info(
f"Connected as user: {user_info[0]}, current catalog: {user_info[1]}"
)
logger.info("Successfully connected to Databricks")
return True
except Exception as e:
......@@ -46,8 +50,12 @@ class DatabricksSQLUtility:
if self.engine:
self.engine.dispose()
def create_catalog(self, managed_location: Optional[str] = None, comment: Optional[str] = None,
properties: Optional[dict] = None):
def create_catalog(
self,
managed_location: Optional[str] = None,
comment: Optional[str] = None,
properties: Optional[dict] = None,
):
"""
Create a new catalog in Unity Catalog
Args:
......@@ -77,8 +85,13 @@ class DatabricksSQLUtility:
logger.error(f"Failed to create catalog '{self.catalog_name}': {str(e)}")
raise
def create_schema(self, schema_name: str, managed_location: Optional[str] = None, comment: Optional[str] = None,
properties: Optional[dict] = None):
def create_schema(
self,
schema_name: str,
managed_location: Optional[str] = None,
comment: Optional[str] = None,
properties: Optional[dict] = None,
):
"""
Create a new schema within a catalog
Args:
......@@ -103,18 +116,22 @@ class DatabricksSQLUtility:
ddl += f"\nWITH DBPROPERTIES ({props})"
self.execute_sql_statement(ddl)
logger.info(f"Schema '{self.catalog_name}.{schema_name}' created successfully")
logger.info(
f"Schema '{self.catalog_name}.{schema_name}' created successfully"
)
return full_schema_name
except Exception as e:
logger.error(f"Failed to create schema '{self.catalog_name}.{schema_name}': {str(e)}")
logger.error(
f"Failed to create schema '{self.catalog_name}.{schema_name}': {str(e)}"
)
raise
def create_external_location(
self,
location_name: str,
storage_path: str,
credential_name: str,
comment: Optional[str] = None
self,
location_name: str,
storage_path: str,
credential_name: str,
comment: Optional[str] = None,
) -> str:
"""
Create an external location in Unity Catalog
......@@ -138,7 +155,9 @@ class DatabricksSQLUtility:
logger.info(f"External location '{location_name}' created successfully")
return location_name
except Exception as e:
logger.error(f"Failed to create external location '{location_name}': {str(e)}")
logger.error(
f"Failed to create external location '{location_name}': {str(e)}"
)
raise
def execute_sql_statement(self, query: str):
......
......@@ -18,7 +18,11 @@ class HTTPXRequestUtil:
def delete(self, path="", params=None, **kwargs) -> httpx.Response:
url = self.get_url(path)
logging.info(url)
with httpx.Client(verify=self.verify, headers=kwargs.get("headers"), cookies=kwargs.get("cookies")) as client:
with httpx.Client(
verify=self.verify,
headers=kwargs.get("headers"),
cookies=kwargs.get("cookies"),
) as client:
response: httpx.Response = client.delete(url=url, params=params)
return response
......@@ -27,7 +31,11 @@ class HTTPXRequestUtil:
url = self.get_url(path)
logging.info(url)
with httpx.Client(verify=self.verify, headers=kwargs.get("headers"), cookies=kwargs.get("cookies")) as client:
with httpx.Client(
verify=self.verify,
headers=kwargs.get("headers"),
cookies=kwargs.get("cookies"),
) as client:
response: httpx.Response = client.put(url=url, data=data, json=json)
return response
......@@ -42,7 +50,11 @@ class HTTPXRequestUtil:
"""
url = self.get_url(path)
logging.info(url)
with httpx.Client(verify=self.verify, headers=kwargs.get("headers"), cookies=kwargs.get("cookies")) as client:
with httpx.Client(
verify=self.verify,
headers=kwargs.get("headers"),
cookies=kwargs.get("cookies"),
) as client:
response: httpx.Response = client.post(url=url, data=data, json=json)
return response
......@@ -52,7 +64,11 @@ class HTTPXRequestUtil:
url = self.get_url(path)
logging.info(url)
with httpx.Client(verify=self.verify, headers=kwargs.get("headers"), cookies=kwargs.get("cookies")) as client:
with httpx.Client(
verify=self.verify,
headers=kwargs.get("headers"),
cookies=kwargs.get("cookies"),
) as client:
response: httpx.Response = client.get(url=url, params=params)
return response
......
import logging
from typing import Any, Type, Optional, Dict, Union, Tuple
from typing import Any, Dict, Optional, Tuple, Type, Union
from databricks.sqlalchemy import TIMESTAMP
from sqlalchemy import (
Integer, BigInteger, SmallInteger, String, Text, Boolean,
Date, DateTime, Time, Numeric, Float, DECIMAL, CHAR, VARCHAR,
LargeBinary, JSON, Column, PrimaryKeyConstraint, UniqueConstraint, ForeignKeyConstraint, Double, Index, Table
CHAR,
DECIMAL,
JSON,
VARCHAR,
BigInteger,
Boolean,
Column,
Date,
DateTime,
Double,
Float,
ForeignKeyConstraint,
Index,
Integer,
LargeBinary,
Numeric,
PrimaryKeyConstraint,
SmallInteger,
String,
Table,
Text,
Time,
UniqueConstraint,
)
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import class_mapper, mapped_column, Mapped
from sqlalchemy.orm import Mapped, class_mapper, mapped_column
from sqlalchemy.types import UserDefinedType
......@@ -83,7 +103,6 @@ class TypeMapper:
VARCHAR: VARCHAR,
Text: String,
String: String,
# DateTime types
postgresql.DATE: Date,
postgresql.TIME: String,
......@@ -92,23 +111,18 @@ class TypeMapper:
Date: Date,
Time: String,
DateTime: DateTime,
# Boolean
postgresql.BOOLEAN: Boolean,
Boolean: Boolean,
# Binary
postgresql.BYTEA: LargeBinary,
LargeBinary: LargeBinary,
# JSON
postgresql.JSON: String,
postgresql.JSONB: String,
JSON: String,
# Array
postgresql.ARRAY: String,
# PostgreSQL specific
postgresql.UUID: String,
postgresql.INET: String,
......@@ -121,14 +135,14 @@ class TypeMapper:
}
SQL_TO_DATABRICKS_MAPPING = {
'VARCHAR': 'STRING',
'INTEGER': 'INT',
'BIGINT': 'BIGINT', # Keep as is
'FLOAT': 'DOUBLE',
'BOOLEAN': 'BOOLEAN', # Keep as is
'TIMESTAMP': 'TIMESTAMP', # Keep as is
'DATETIME': 'TIMESTAMP', # Change this mapping
'TEXT': 'STRING',
"VARCHAR": "STRING",
"INTEGER": "INT",
"BIGINT": "BIGINT", # Keep as is
"FLOAT": "DOUBLE",
"BOOLEAN": "BOOLEAN", # Keep as is
"TIMESTAMP": "TIMESTAMP", # Keep as is
"DATETIME": "TIMESTAMP", # Change this mapping
"TEXT": "STRING",
# Arrays and complex types are already correct, no replacement needed
}
......@@ -147,14 +161,14 @@ class TypeMapper:
base_type = type(sql_type)
# Handle special cases first
if base_type == postgresql.ARRAY or 'ARRAY' in str(sql_type):
if base_type == postgresql.ARRAY or "ARRAY" in str(sql_type):
return cls._convert_array_type_fallback(sql_type)
# Get the mapped type
if base_type in cls.POSTGRES_TO_DATABRICKS_MAPPING:
return cls.POSTGRES_TO_DATABRICKS_MAPPING[base_type]()
logging.info(f'Defaulting to String() for type: {type(sql_type)}')
logging.info(f"Defaulting to String() for type: {type(sql_type)}")
return String()
@classmethod
......@@ -171,21 +185,21 @@ class TypeMapper:
element_type_name = type(postgres_element_type).__name__.upper()
# Map PostgreSQL types to Databricks array element types
if any(t in element_type_name for t in ['VARCHAR', 'TEXT', 'STRING', 'CHAR']):
if any(t in element_type_name for t in ["VARCHAR", "TEXT", "STRING", "CHAR"]):
return "STRING"
elif any(t in element_type_name for t in ['INTEGER', 'BIGINT', 'SMALLINT']):
return "INT" if 'SMALLINT' not in element_type_name else "SMALLINT"
elif 'BIGINT' in element_type_name:
elif any(t in element_type_name for t in ["INTEGER", "BIGINT", "SMALLINT"]):
return "INT" if "SMALLINT" not in element_type_name else "SMALLINT"
elif "BIGINT" in element_type_name:
return "BIGINT"
elif any(t in element_type_name for t in ['BOOLEAN', 'BOOL']):
elif any(t in element_type_name for t in ["BOOLEAN", "BOOL"]):
return "BOOLEAN"
elif any(t in element_type_name for t in ['FLOAT', 'REAL', 'DOUBLE']):
elif any(t in element_type_name for t in ["FLOAT", "REAL", "DOUBLE"]):
return "DOUBLE"
elif any(t in element_type_name for t in ['NUMERIC', 'DECIMAL']):
elif any(t in element_type_name for t in ["NUMERIC", "DECIMAL"]):
return "DECIMAL"
elif 'DATE' in element_type_name:
elif "DATE" in element_type_name:
return "DATE"
elif 'TIMESTAMP' in element_type_name:
elif "TIMESTAMP" in element_type_name:
return "TIMESTAMP"
else:
return "STRING" # Default fallback
......@@ -225,7 +239,11 @@ class TypeMapper:
default_clause = ""
if column.default is not None:
default_value = column.default.arg if hasattr(column.default, 'arg') else column.default
default_value = (
column.default.arg
if hasattr(column.default, "arg")
else column.default
)
if isinstance(default_value, str):
default_clause = f" DEFAULT '{default_value}'"
elif isinstance(default_value, bool):
......@@ -258,7 +276,7 @@ class ColumnConverter:
columns_info = {}
# # Check for modern SQLAlchemy with annotations and mapped_column
if hasattr(model_class, '__annotations__'):
if hasattr(model_class, "__annotations__"):
columns_info.update(self.extract_from_annotations(model_class))
# Fallback: Try mapper approach for traditional models
......@@ -267,7 +285,9 @@ class ColumnConverter:
mapper = class_mapper(model_class)
for column_name, column in mapper.columns.items():
if column_name not in columns_info:
columns_info[column_name] = self.extract_column_properties(column)
columns_info[column_name] = self.extract_column_properties(
column
)
except Exception as e:
logging.error(f"Failed to extract column info using mapper: {e}")
# Final fallback: inspect class attributes directly
......@@ -279,20 +299,22 @@ class ColumnConverter:
def extract_column_properties(column: Any) -> Dict[str, Any]:
"""Extract properties from a column object."""
return {
'type': getattr(column, 'type', None),
'primary_key': getattr(column, 'primary_key', False),
'nullable': getattr(column, 'nullable', True),
'default': getattr(column, 'default', None),
'server_default': getattr(column, 'server_default', None),
'uses_mapped_column': False,
"type": getattr(column, "type", None),
"primary_key": getattr(column, "primary_key", False),
"nullable": getattr(column, "nullable", True),
"default": getattr(column, "default", None),
"server_default": getattr(column, "server_default", None),
"uses_mapped_column": False,
}
def _extract_from_class_attributes(self, model_class: type) -> Dict[str, Dict[str, Any]]:
def _extract_from_class_attributes(
self, model_class: type
) -> Dict[str, Dict[str, Any]]:
"""Extract column info from class attributes."""
columns_info = {}
for attr_name in dir(model_class):
if attr_name.startswith('_'):
if attr_name.startswith("_"):
continue
attr = getattr(model_class, attr_name, None)
......@@ -303,14 +325,14 @@ class ColumnConverter:
if isinstance(attr, Column):
columns_info[attr_name] = self.extract_column_properties(attr)
# Check for mapped_column objects
elif hasattr(attr, 'type') and hasattr(attr, 'nullable'):
elif hasattr(attr, "type") and hasattr(attr, "nullable"):
columns_info[attr_name] = {
'type': getattr(attr, 'type', None),
'primary_key': getattr(attr, 'primary_key', False),
'nullable': getattr(attr, 'nullable', True),
'default': getattr(attr, 'default', None),
'server_default': getattr(attr, 'server_default', None),
'uses_mapped_column': True,
"type": getattr(attr, "type", None),
"primary_key": getattr(attr, "primary_key", False),
"nullable": getattr(attr, "nullable", True),
"default": getattr(attr, "default", None),
"server_default": getattr(attr, "server_default", None),
"uses_mapped_column": True,
}
return columns_info
......@@ -320,22 +342,22 @@ class ColumnConverter:
"""Extract column info from type annotations (modern SQLAlchemy)."""
columns_info = {}
annotations = getattr(model_class, '__annotations__', {})
annotations = getattr(model_class, "__annotations__", {})
for attr_name, annotation in annotations.items():
if hasattr(model_class, attr_name):
attr = getattr(model_class, attr_name)
# Check if it's a mapped_column
if hasattr(attr, 'type'):
if hasattr(attr, "type"):
columns_info[attr_name] = {
'type': attr.type,
'primary_key': getattr(attr, 'primary_key', False),
'nullable': getattr(attr, 'nullable', True),
'default': getattr(attr, 'default', None),
'server_default': getattr(attr, 'server_default', None),
'annotation': annotation,
'uses_mapped_column': True,
"type": attr.type,
"primary_key": getattr(attr, "primary_key", False),
"nullable": getattr(attr, "nullable", True),
"default": getattr(attr, "default", None),
"server_default": getattr(attr, "server_default", None),
"annotation": annotation,
"uses_mapped_column": True,
}
return columns_info
......@@ -351,36 +373,38 @@ class ColumnConverter:
Tuple of (column_object, annotation_if_any)
"""
# Convert the column type
new_type = self.type_mapper.get_databricks_type(column_info['type'])
new_type = self.type_mapper.get_databricks_type(column_info["type"])
# Check if this uses type annotations (modern approach)
if column_info.get('uses_mapped_column', False):
if column_info.get("uses_mapped_column", False):
new_column = mapped_column(
new_type,
primary_key=column_info.get('primary_key', False),
nullable=column_info.get('nullable', True),
default=column_info.get('default'),
server_default=column_info.get('server_default'),
primary_key=column_info.get("primary_key", False),
nullable=column_info.get("nullable", True),
default=column_info.get("default"),
server_default=column_info.get("server_default"),
)
# Convert annotation
annotation = self.convert_annotation(column_info.get('annotation'), new_type)
annotation = self.convert_annotation(
column_info.get("annotation"), new_type
)
return new_column, annotation
else:
# Traditional Column approach
new_column = Column(
new_type,
primary_key=column_info.get('primary_key', False),
nullable=column_info.get('nullable', True),
default=column_info.get('default'),
server_default=column_info.get('server_default'),
primary_key=column_info.get("primary_key", False),
nullable=column_info.get("nullable", True),
default=column_info.get("default"),
server_default=column_info.get("server_default"),
)
return new_column, None
@staticmethod
def convert_annotation(annotation: Any, databricks_type: Any = None) -> Any:
"""Convert type annotations for Databricks compatibility."""
from typing import Optional, List
from typing import List, Optional
if annotation is None:
return Mapped[Optional[str]]
......@@ -388,15 +412,17 @@ class ColumnConverter:
annotation_str = str(annotation)
# Handle array/list types -> convert to List annotation
if any(keyword in annotation_str for keyword in ['list', 'List']):
if 'Optional' in annotation_str or 'Union' in annotation_str:
if any(keyword in annotation_str for keyword in ["list", "List"]):
if "Optional" in annotation_str or "Union" in annotation_str:
return Mapped[Optional[List[str]]] # Default to List[str]
else:
return Mapped[List[str]]
# Handle JSON/JSONB/dict types -> convert to string
if any(keyword in annotation_str for keyword in ['dict', 'Dict', 'json', 'Json']):
if 'Optional' in annotation_str or 'Union' in annotation_str:
if any(
keyword in annotation_str for keyword in ["dict", "Dict", "json", "Json"]
):
if "Optional" in annotation_str or "Union" in annotation_str:
return Mapped[Optional[str]]
else:
return Mapped[str]
......@@ -405,31 +431,38 @@ class ColumnConverter:
if databricks_type:
type_str = str(type(databricks_type).__name__).lower()
if 'array' in type_str:
if "array" in type_str:
# For array types, use List annotation
if 'Optional' in annotation_str:
return Mapped[Optional[List[str]]] # Could be more specific based on element type
if "Optional" in annotation_str:
return Mapped[
Optional[List[str]]
] # Could be more specific based on element type
return Mapped[List[str]]
elif 'integer' in type_str or 'biginteger' in type_str or 'smallinteger' in type_str:
if 'Optional' in annotation_str:
elif (
"integer" in type_str
or "biginteger" in type_str
or "smallinteger" in type_str
):
if "Optional" in annotation_str:
return Mapped[Optional[int]]
return Mapped[int]
elif 'boolean' in type_str:
if 'Optional' in annotation_str:
elif "boolean" in type_str:
if "Optional" in annotation_str:
return Mapped[Optional[bool]]
return Mapped[bool]
elif 'float' in type_str or 'numeric' in type_str:
if 'Optional' in annotation_str:
elif "float" in type_str or "numeric" in type_str:
if "Optional" in annotation_str:
return Mapped[Optional[float]]
return Mapped[float]
elif 'datetime' in type_str:
elif "datetime" in type_str:
from datetime import datetime
if 'Optional' in annotation_str:
if "Optional" in annotation_str:
return Mapped[Optional[datetime]]
return Mapped[datetime]
# Default to string
if 'Optional' in annotation_str or 'Union' in annotation_str:
if "Optional" in annotation_str or "Union" in annotation_str:
return Mapped[Optional[str]]
return Mapped[str]
......@@ -442,8 +475,7 @@ class SchemaProcessor:
@staticmethod
def process_table_args(
original_table_args: Any,
new_schema: Optional[str] = None
original_table_args: Any, new_schema: Optional[str] = None
) -> Union[Tuple, Dict, None]:
"""
Process table arguments, handling constraints and schema conversion.
......@@ -457,7 +489,7 @@ class SchemaProcessor:
"""
if not original_table_args:
if new_schema:
return {'schema': new_schema}
return {"schema": new_schema}
return None
new_table_args = []
......@@ -468,7 +500,9 @@ class SchemaProcessor:
for arg in original_table_args:
if isinstance(arg, dict):
# Process dictionary part
processed_kwargs = SchemaProcessor._process_table_kwargs(arg, new_schema)
processed_kwargs = SchemaProcessor._process_table_kwargs(
arg, new_schema
)
table_kwargs.update(processed_kwargs)
elif isinstance(arg, (Index, ForeignKeyConstraint)):
continue
......@@ -481,11 +515,13 @@ class SchemaProcessor:
# Handle dictionary format: {'schema': 'public', 'extend_existing': True}
elif isinstance(original_table_args, dict):
table_kwargs = SchemaProcessor._process_table_kwargs(original_table_args, new_schema)
table_kwargs = SchemaProcessor._process_table_kwargs(
original_table_args, new_schema
)
# Add new schema if specified and not already set
if new_schema is not None and 'schema' not in table_kwargs:
table_kwargs['schema'] = new_schema
if new_schema is not None and "schema" not in table_kwargs:
table_kwargs["schema"] = new_schema
# Construct result
if new_table_args and table_kwargs:
......@@ -499,16 +535,18 @@ class SchemaProcessor:
return None
@staticmethod
def _process_table_kwargs(kwargs: Dict[str, Any], new_schema: Optional[str]) -> Dict[str, Any]:
def _process_table_kwargs(
kwargs: Dict[str, Any], new_schema: Optional[str]
) -> Dict[str, Any]:
"""Process table keyword arguments."""
processed = {}
for key, value in kwargs.items():
if key == 'schema':
if key == "schema":
# Use new_schema if provided, otherwise keep original unless it's 'public'
if new_schema is not None:
processed[key] = new_schema
elif value != 'public':
elif value != "public":
processed[key] = value
# Skip 'public' schema (default)
else:
......@@ -532,12 +570,13 @@ class ModelConverter:
self.type_mapper = TypeMapper()
self.column_converter = ColumnConverter(self.type_mapper)
def convert_model(self,
postgres_model_class: Type,
base_class: Type,
new_table_name: Optional[str] = None,
new_schema: Optional[str] = None,
) -> Type:
def convert_model(
self,
postgres_model_class: Type,
base_class: Type,
new_table_name: Optional[str] = None,
new_schema: Optional[str] = None,
) -> Type:
"""
Convert a PostgreSQL SQLAlchemy model to a Databricks SQLAlchemy model.
......@@ -552,27 +591,28 @@ class ModelConverter:
"""
# Create base class if not provided
# Get table information
original_table_name = getattr(postgres_model_class, '__tablename__', 'unknown_table')
original_table_name = getattr(
postgres_model_class, "__tablename__", "unknown_table"
)
table_name = new_table_name or original_table_name
table_name = f'{table_name}'
table_name = f"{table_name}"
schema_processor = SchemaProcessor()
# Create new model attributes
new_attrs = {
'__tablename__': table_name,
'__module__': postgres_model_class.__module__,
"__tablename__": table_name,
"__module__": postgres_model_class.__module__,
}
# Process table arguments
if hasattr(postgres_model_class, '__table_args__'):
if hasattr(postgres_model_class, "__table_args__"):
processed_table_args = schema_processor.process_table_args(
postgres_model_class.__table_args__,
new_schema
postgres_model_class.__table_args__, new_schema
)
if processed_table_args:
new_attrs['__table_args__'] = processed_table_args
new_attrs["__table_args__"] = processed_table_args
elif new_schema:
new_attrs['__table_args__'] = {'schema': new_schema}
new_attrs["__table_args__"] = {"schema": new_schema}
# Extract and convert columns
columns_info = self.column_converter.extract_column_info(postgres_model_class)
......@@ -587,7 +627,7 @@ class ModelConverter:
# Add annotations if any
if annotations:
new_attrs['__annotations__'] = annotations
new_attrs["__annotations__"] = annotations
# Create the new model class
new_class_name = f"{postgres_model_class.__name__}Databricks"
......