Commit 7845d0b6 authored by harshavardhan.c's avatar harshavardhan.c

feat: Add functionality to create and delete Databricks jobs based on the catalog.

parent c741f979
# Created by .ignore support plugin (hsz.mobi)
### Python template
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
/.idea
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
.vscode/settings.json
.vscode
data
.env
assets
__pycache__
repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.6.0
    hooks:
      - id: end-of-file-fixer
      - id: trailing-whitespace
      - id: requirements-txt-fixer
  - repo: https://github.com/asottile/pyupgrade
    rev: v3.16.0
    hooks:
      - id: pyupgrade
        args:
          - --py3-plus
          - --keep-runtime-typing
  - repo: https://github.com/charliermarsh/ruff-pre-commit
    rev: v0.4.8
    hooks:
      - id: ruff
        args:
          - --fix
  - repo: https://github.com/pycqa/isort
    rev: 5.13.2
    hooks:
      - id: isort
        name: isort (python)
      - id: isort
        name: isort (cython)
        types: [cython]
      - id: isort
        name: isort (pyi)
        types: [pyi]
  - repo: https://github.com/psf/black
    rev: 24.4.2
    hooks:
      - id: black
        # It is recommended to specify the latest version of Python
        # supported by your project here, or alternatively use
        # pre-commit's default_language_version, see
        # https://pre-commit.com/#top_level-default_language_version
        language_version: python3.11
# model-managament-databricks
@@ -6,19 +6,33 @@ from scripts.config import KafkaConfig
from scripts.engines.agents.model_creator_agent import ModelCreatorAgent
from scripts.schemas import ModelCreatorSchema, ModelInstanceSchema

broker = KafkaBroker(
    f"{KafkaConfig.KAFKA_HOST}:{KafkaConfig.KAFKA_PORT}",
    client_id="model_creator_agent",
)


@broker.subscriber(
    KafkaConfig.KAFKA_MODEL_CREATION_TOPIC,
    group_id="databricks_model_creator_agent",
    max_workers=2,
)
async def consume_stream_for_processing_dependencies(message: dict):
    try:
        await ModelCreatorAgent.model_creator_agent(
            message=ModelCreatorSchema(meta=message)
        )
        return True
    except Exception as e:
        logging.error(f"Exception occurred while creating model in Databricks: {e}")
        return False


@broker.subscriber(
    KafkaConfig.KAFKA_MODEL_INSTANCE_TOPIC,
    group_id="databricks_instance_agent",
    max_workers=2,
)
async def consume_stream_for_processing_instances(message: dict):
    try:
        await ModelCreatorAgent.model_instance_agent(ModelInstanceSchema(**message))
...
@@ -6,11 +6,11 @@ import sys
from dotenv import load_dotenv

load_dotenv()

from faststream import FastStream
from ut_dev_utils import configure_logger

from agent_subscribers import broker

configure_logger()

# Create FastStream app
...
@@ -70,13 +70,16 @@ class _DatabricksConfig(BaseSettings):
    DATABRICKS_PUBLIC_SCHEMA_NAME: str = Field(default="public")
    DATABRICKS_ANALYTICAL_SCHEMA_NAME: str = Field(default="analytical")
    DATABRICKS_STORAGE_FORMAT: str = Field(default="PARQUET")
    DATABRICKS_STORAGE_PATH: str = Field(
        default="abfss://unity-catalog-storage@dbstoragenzxfhpgsipt5a.dfs.core.windows.net/416418955412087"
    )

    @model_validator(mode="before")
    def prepare_databricks_uri(cls, values):
        values["DATABRICKS_URI"] = (
            f"databricks://token:{values['DATABRICKS_ACCESS_TOKEN']}@{values['DATABRICKS_HOST']}:{values['DATABRICKS_PORT']}"
            f"?http_path={values['DATABRICKS_HTTP_PATH']}"
        )
        return values

@@ -88,4 +91,11 @@ PathToStorage = _PathToStorage()
KafkaConfig = _KafkaConfig()
DatabricksConfig = _DatabricksConfig()

__all__ = [
    "Services",
    "RedisConfig",
    "ExternalServices",
    "PathToStorage",
    "KafkaConfig",
    "DatabricksConfig",
]
class DatabricksConstants:
    METADATA_INGESTION_JOB_NAME = "metadata_ingestion_job"
    METADATA_DELETION_JOB_NAME = "metadata_deletion_job"
    METADATA_INGESTION_NOTEBOOK_NAME = "metadata_ingestion_notebook"
    METADATA_DELETION_NOTEBOOK_NAME = "metadata_deletion_notebook"
    TIMESERIES_INGESTION_NOTEBOOK_NAME = "timeseries_ingestion_notebook"


class NotebookConstants:
    METADATA_INGESTION_NOTEBOOK_PATH = (
        "scripts/constants/notebooks/metadata_ingestion.txt"
    )
    METADATA_DELETION_NOTEBOOK_PATH = (
        "scripts/constants/notebooks/metadata_deletion.txt"
    )
    TIMESERIES_INGESTION_NOTEBOOK_PATH = (
        "scripts/constants/notebooks/timeseries_ingestion.txt"
    )
# Databricks notebook source
from delta.tables import DeltaTable
from pyspark.sql.functions import expr
import json

# COMMAND ----------

dbutils.widgets.text("input_message", "", "Input Message JSON")
dbutils.widgets.text("id_column", "id", "ID Column Name")

input_message = dbutils.widgets.get("input_message")
delete_column = dbutils.widgets.get("id_column")

# COMMAND ----------

def extract_table_info(input_message_str: str, delete_column: str = "id"):
    """
    Extract the table name, project id and row ids to delete from the input message.

    Args:
        input_message_str (str): JSON string containing the message
        delete_column (str): Column whose values identify the rows to delete

    Returns:
        dict: Extracted information
    """
    try:
        message_data = json.loads(input_message_str)

        # Extract the fully qualified table name from table_properties
        table_name = message_data["table_properties"]["table_name"]  # e.g. 'enterprise'
        project_id = message_data["project_id"]  # e.g. 'project_787'
        delete_values = [
            msg[delete_column] for msg in message_data["data"] if delete_column in msg
        ]
        table_properties = message_data["table_properties"]  # Fetch table properties

        print("Extracted Info:")
        print(f"Table Name: {table_name}")
        print(f"Project ID: {project_id}")
        print(f"Deleting rows: {delete_values}")
        print(f"Table Prop Keys: {list(table_properties.keys())}")

        return {
            "table_name": table_name,
            "project_id": project_id,
            delete_column: delete_values,
            "raw_message": message_data,
            "table_properties": table_properties,
        }
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON in input_message: {str(e)}")
    except KeyError as e:
        raise ValueError(f"Missing required field in input_message: {str(e)}")

# COMMAND ----------

def delete_records_by_ids(table_name, ids, id_column="id"):
    """
    Delete records from an external Delta table using a list of IDs.

    Args:
        table_name (str): Full table name (catalog.schema.table)
        ids (list): List of IDs to delete
        id_column (str): Column name containing IDs (default: "id")

    Returns:
        bool: True if successful, False otherwise
    """
    try:
        if not ids:
            print("No IDs provided")
            return False

        # Format IDs for SQL IN clause
        if isinstance(ids[0], str):
            id_values = "(" + ",".join([f"'{id}'" for id in ids]) + ")"
        else:
            id_values = "(" + ",".join([str(id) for id in ids]) + ")"

        # Use Delta table DELETE operation
        delta_table = DeltaTable.forName(spark, table_name)
        condition = f"{id_column} IN {id_values}"
        delta_table.delete(condition=expr(condition))
        print(f"Successfully deleted {len(ids)} records from {table_name}")
        return True
    except Exception as e:
        print(f"Error: {str(e)}")
        return False

# COMMAND ----------

table_info = extract_table_info(input_message, delete_column=delete_column)

# COMMAND ----------

result = delete_records_by_ids(
    table_name=table_info["table_name"],
    ids=table_info[delete_column],
    id_column=delete_column,
)
print(f"Deletion completed: {result}")
@@ -4,7 +4,7 @@ from pyspark.sql.functions import *
from pyspark.sql.types import *
import json

spark = SparkSession.builder.appName("StreamingTimeseriesPipeline").getOrCreate()
spark.sparkContext.setLogLevel("WARN")

# COMMAND ----------
@@ -146,4 +146,3 @@ transformed_df.writeStream \
# COMMAND ----------
@@ -3,6 +3,7 @@ import json
from ut_dev_utils import get_db_name

from scripts.config import DatabricksConfig
from scripts.constants import DatabricksConstants
from scripts.db.databricks.job_manager import DatabricksJobManager
from scripts.db.redis.databricks_details import databricks_details_db
from scripts.schemas import ModelInstanceSchema

@@ -12,29 +13,44 @@ class ModelInstanceHandler:
    def __init__(self, project_id: str, payload: ModelInstanceSchema):
        self.project_id = project_id
        self.payload = payload
        self.catalog_name = get_db_name(
            project_id=project_id, database=DatabricksConfig.DATABRICKS_CATALOG_NAME
        )
        self.job_manager = DatabricksJobManager(
            databricks_host=payload.databricks_host,
            access_token=payload.databricks_access_token,
        )

    async def upload_instances_to_unity_catalog(self):
        if self.payload.action_type == "delete":
            job_id = databricks_details_db.hget(
                self.project_id, DatabricksConstants.METADATA_DELETION_JOB_NAME
            )
        else:
            job_id = databricks_details_db.hget(
                self.project_id, DatabricksConstants.METADATA_INGESTION_JOB_NAME
            )
        if not job_id:
            raise ValueError(
                f"No job id found for {self.payload.action_type}, skipping upload to unity catalog"
            )
        run_id = self.job_manager.run_job(
            job_id=job_id,
            parameters={"input_message": json.dumps(self.get_job_trigger_payload())},
        )
        if not run_id:
            raise ValueError(
                "Failed to run metadata ingestion job, skipping upload to unity catalog"
            )

    def get_job_trigger_payload(self):
        table_name = self.payload.node_type
        schema_table = f"{DatabricksConfig.DATABRICKS_PUBLIC_SCHEMA_NAME}.{table_name}"
        return {
            "table_properties": {
                "table_name": f"{self.catalog_name}.{schema_table}",
                "table_path": f"{self.payload.databricks_storage_path}/{self.catalog_name}/DELTA/{schema_table}",
            },
            "project_id": self.project_id,
            "data": self.payload.data,
        }
@@ -5,7 +5,7 @@ from sqlalchemy.orm import declarative_base
from ut_sql_utils.asyncio.declarative_utils import DeclarativeUtils

from scripts.config import DatabricksConfig
from scripts.constants import DatabricksConstants, NotebookConstants
from scripts.db.databricks import DataBricksSQLLayer
from scripts.db.databricks.job_manager import DatabricksJobManager
from scripts.db.databricks.notebook_manager import NotebookManager
@@ -16,23 +16,25 @@ from scripts.utils.model_convertor_utils import ModelConverter

class ModelCreatorHandler:
    def __init__(
        self, message: ModelCreatorSchema, declarative_utils: DeclarativeUtils
    ):
        self.declarative_utils = declarative_utils
        self.meta = message.meta
        self.message = message
        self.model_convertor = ModelConverter()
        self.job_manager = DatabricksJobManager(
            databricks_host=message.databricks_host,
            access_token=message.databricks_access_token,
        )
        self.notebook_manager = NotebookManager(
            databricks_host=message.databricks_host,
            access_token=message.databricks_access_token,
        )
        self.databricks_sql_obj = DataBricksSQLLayer(
            catalog_name=DatabricksConfig.DATABRICKS_CATALOG_NAME,
            project_id=self.meta.project_id,
            schema=message.schema,
        )
        self.external_location = self.message.databricks_storage_path
@@ -47,32 +49,37 @@ class ModelCreatorHandler:
        overall_tables = self.get_overall_tables()
        project_levels = project_template_keys(self.meta.project_id, levels=True)
        base = self.create_schema_base(
            schema_name=f"{self.databricks_sql_obj.catalog_name}.{self.message.schema}"
        )
        try:
            # self.databricks_sql_obj.connect_to_databricks()
            _ = self.setup_dependencies_for_unity_catalog()
            table_properties = self.fetch_table_properties()
            for table in overall_tables:
                table_class = self.declarative_utils.get_declarative_class(table)
                if not table_class:
                    logging.error(f"Table class not found for table: {table}")
                    return False
                new_model = self.model_convertor.convert_model(
                    table_class,
                    base_class=base,
                    new_schema=self.message.schema,
                )

                self.databricks_sql_obj.create_external_table_from_structure(
                    table=new_model.__table__,
                    file_format="DELTA",
                    external_location=self.external_location,
                    table_properties=table_properties,
                )
            ts_external_table = self.databricks_sql_obj.create_timeseries_table(
                columns=project_levels, external_location=self.external_location
            )
            self.setup_notepads_and_jobs(
                timeseries_table_path=ts_external_table, project_levels=project_levels
            )
            return True
        except Exception as e:
            logging.error(f"Error occurred while creating models in Unity Catalog: {e}")
@@ -95,20 +102,25 @@ class ModelCreatorHandler:
            analytical (bool): Flag to indicate if the setup is for analytical or not
        """
        logging.info(
            f"Setting up catalog '{DatabricksConfig.DATABRICKS_CATALOG_NAME}' for project '{self.meta.project_id}'"
        )
        self.databricks_sql_obj.connect_to_databricks()

        # Create catalog
        catalog_success = self.databricks_sql_obj.create_catalog(
            managed_location=f"{self.external_location}/{self.databricks_sql_obj.catalog_name}",
        )
        if not catalog_success:
            return False

        # Create schema
        schema_success = self.databricks_sql_obj.create_schema(
            DatabricksConfig.DATABRICKS_PUBLIC_SCHEMA_NAME
        )
        if not schema_success:
            return False
        if analytical:
            schema_success = self.databricks_sql_obj.create_schema(
                DatabricksConfig.DATABRICKS_ANALYTICAL_SCHEMA_NAME
            )
            if not schema_success:
                return False
        return True
@@ -120,59 +132,112 @@ class ModelCreatorHandler:
            project_levels: List of project levels
        """
        logging.info("Setting up notepads and jobs")
        meta_ingestion_notebook_path = f"/Users/{self.message.databricks_user_email}/{self.meta.project_id}_{DatabricksConstants.METADATA_INGESTION_NOTEBOOK_NAME}"
        meta_deletion_notebook_path = f"/Users/{self.message.databricks_user_email}/{self.meta.project_id}_{DatabricksConstants.METADATA_DELETION_NOTEBOOK_NAME}"
        timeseries_notebook_path = f"/Users/{self.message.databricks_user_email}/{self.meta.project_id}_{DatabricksConstants.TIMESERIES_INGESTION_NOTEBOOK_NAME}"

        # Setting up of Metadata Ingestion Notebook
        existing_job_id = databricks_details_db.hget(
            self.meta.project_id, DatabricksConstants.METADATA_INGESTION_JOB_NAME
        )
        if not existing_job_id:
            self.create_notebook(
                notebook_path=meta_ingestion_notebook_path,
                source_notebook_path=NotebookConstants.METADATA_INGESTION_NOTEBOOK_PATH,
            )
            ingestion_job_id = self.create_job(
                job_name=f"{self.meta.project_id}_{DatabricksConstants.METADATA_INGESTION_JOB_NAME}",
                notebook_path=meta_ingestion_notebook_path,
            )
            databricks_details_db.hset(
                self.meta.project_id,
                DatabricksConstants.METADATA_INGESTION_JOB_NAME,
                ingestion_job_id,
            )

        existing_job_id = databricks_details_db.hget(
            self.meta.project_id, DatabricksConstants.METADATA_DELETION_JOB_NAME
        )
        if not existing_job_id:
            # Setting up of Metadata Deletion Notebook
            self.create_notebook(
                notebook_path=meta_deletion_notebook_path,
                source_notebook_path=NotebookConstants.METADATA_DELETION_NOTEBOOK_PATH,
            )
            deletion_job_id = self.create_job(
                job_name=f"{self.meta.project_id}_{DatabricksConstants.METADATA_DELETION_JOB_NAME}",
                notebook_path=meta_deletion_notebook_path,
            )
            databricks_details_db.hset(
                self.meta.project_id,
                DatabricksConstants.METADATA_DELETION_JOB_NAME,
                deletion_job_id,
            )

        # Setting up of Timeseries Ingestion Notebook
        replace_mapping = {
            "{{timeseries_table_path}}": f'"{timeseries_table_path}"',
            "{{project_levels}}": str(len(project_levels) - 1),
            "{{event_hub_connection_string}}": f'"{self.meta.project_id}"',
        }
        self.create_notebook(
            notebook_path=timeseries_notebook_path,
            source_notebook_path=NotebookConstants.TIMESERIES_INGESTION_NOTEBOOK_PATH,
            replace_mapping=replace_mapping,
        )

    @staticmethod
    def fetch_table_properties(file_format: str = "DELTA"):
        if file_format.lower() == "delta":
            return {
                # Performance optimization (Essential)
                "delta.autoOptimize.optimizeWrite": "true",
                "delta.autoOptimize.autoCompact": "true",
                "delta.targetFileSize": "134217728",  # 128MB
                "delta.enableChangeDataFeed": "true",  # If you need CDC
                # Checkpoint optimization (Performance boost)
                "delta.checkpoint.writeStatsAsStruct": "true",
                "delta.checkpoint.writeStatsAsJson": "false",
                # Note: Retention properties removed - using defaults:
                # delta.deletedFileRetentionDuration = 7 days (default)
                # delta.logRetentionDuration = 30 days (default)
            }
        elif file_format.lower() == "parquet":
            return {
                "parquet.compression": "snappy",
                "parquet.page.size": "1048576",  # 1MB - standard for mixed queries
                "parquet.block.size": "134217728",  # 128MB - balanced performance
                "serialization.format": "1",
            }
        else:
            return {}

    @staticmethod
    def read_data_from_file(note_path: str):
        with open(note_path) as f:
            notebook_code = f.read()
        return notebook_code

    def create_notebook(
        self,
        notebook_path: str,
        source_notebook_path: str,
        replace_mapping: dict = None,
    ):
        logging.info(f"Creating notebook {notebook_path}")
        notebook_code = self.read_data_from_file(source_notebook_path)
        if replace_mapping is not None:
            for key, value in replace_mapping.items():
                notebook_code = notebook_code.replace(key, value)
        self.notebook_manager.create_notebook(
            notebook_path=notebook_path, notebook_code=notebook_code, overwrite=True
        )
        return True

    def create_job(self, job_name: str, notebook_path: str):
        logging.info(f"Creating job {job_name}")
        job_id = self.job_manager.create_job(
            job_config=self.job_manager.create_job_config_for_serverless(
                job_name=job_name,
                notebook_path=notebook_path,
            )
        )
        return job_id
from typing import Dict, List

from sqlalchemy import (
    BigInteger,
    Column,
    Date,
    DateTime,
    Integer,
    MetaData,
    String,
    Table,
)

from scripts.utils.databricks_utils import DatabricksSQLUtility
from scripts.utils.model_convertor_utils import TypeMapper
@@ -11,11 +20,14 @@ class DataBricksSQLLayer(DatabricksSQLUtility):
        super().__init__(catalog_name, project_id)
        self.schema = schema

    def create_external_table_from_structure(
        self,
        table: Table,
        external_location: str,
        file_format: str = "PARQUET",
        table_properties: Dict[str, str] = None,
        partition_columns: list = None,
    ) -> str:
        """
        Create an external table from a model class.
@@ -31,12 +43,14 @@ class DataBricksSQLLayer(DatabricksSQLUtility):
        """
        schema_table = f"{table.schema}.{table.name}" if table.schema else table.name
        columns_sql = TypeMapper().extract_columns_without_constraints(table)
        external_location = (
            f"{external_location}/{self.catalog_name}/{file_format}/{schema_table}"
        )
        sql_parts = [
            f"CREATE TABLE IF NOT EXISTS {schema_table}",
            f"({columns_sql})",
            f"USING {file_format}",
            f"LOCATION '{external_location}'",
        ]
        if partition_columns:
            partition_clause = ", ".join(partition_columns)
@@ -64,33 +78,32 @@ class DataBricksSQLLayer(DatabricksSQLUtility):
        """
        table_columns = [
            Column("timestamp", BigInteger, nullable=False),
            Column("dt_timestamp", DateTime, nullable=False),
            Column("dt_date", Date, nullable=False),
            Column("dt_hour", Integer, nullable=False),
            Column("value", String, nullable=False),
            Column("value_type", String, nullable=False, default="float"),
            Column("c3", String, nullable=False),
        ]
        default_columns = ["c1", "c5", "Q", "T", "D", "P", "A", "B", *columns]
        table_columns.extend(
            [Column(col_name, String, nullable=True) for col_name in default_columns]
        )
        partition_columns = ["dt_date", "dt_hour", "c3"]
        table_properties = {
            "parquet.compression": "snappy",  # Fast decompression for frequent queries
            "parquet.page.size": "524288",  # 512KB - better time-range filtering
            "parquet.block.size": "268435456",  # 256MB - efficient sequential reads
            "serialization.format": "1",  # Support for arrays/complex types
        }
        table_obj = Table(
            "timeseries_data", MetaData(), *table_columns, schema=self.schema
        )
        self.create_external_table_from_structure(
            table=table_obj,
            external_location=external_location,
            partition_columns=partition_columns,
            table_properties=table_properties,
        )
        return external_location
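As a rough, self-contained sketch of how the CREATE TABLE statement above is assembled (the real column DDL comes from TypeMapper, and the trailing PARTITIONED BY / TBLPROPERTIES handling falls outside this hunk), with placeholder names throughout:

# Illustrative only: mirrors the sql_parts assembly for a small hypothetical table.
schema_table = "public.enterprise"          # placeholder schema.table
columns_sql = "id STRING, name STRING"      # placeholder column DDL
file_format = "DELTA"
external_location = f"abfss://<container>/<catalog>/{file_format}/{schema_table}"

sql_parts = [
    f"CREATE TABLE IF NOT EXISTS {schema_table}",
    f"({columns_sql})",
    f"USING {file_format}",
    f"LOCATION '{external_location}'",
]
print("\n".join(sql_parts))
# CREATE TABLE IF NOT EXISTS public.enterprise
# (id STRING, name STRING)
# USING DELTA
# LOCATION 'abfss://<container>/<catalog>/DELTA/public.enterprise'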
@@ -14,10 +14,14 @@ class DatabricksJobManager:
            databricks_host: Your Databricks workspace URL
            access_token: Personal access token or service principal token
        """
        self.host = (
            databricks_host
            if "https://" in databricks_host
            else f"https://{databricks_host}"
        )
        self.headers = {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
        }

    def create_job(self, job_config: dict):
@@ -32,11 +36,13 @@ class DatabricksJobManager:
        response = HTTPXRequestUtil(url).post(headers=self.headers, json=job_config)

        if response.status_code == 200:
            job_id = response.json()["job_id"]
            logging.info(f"Job created successfully with ID: {job_id}")
            return job_id
        else:
            logging.error(
                f"Error creating job: {response.status_code} - {response.text}"
            )
            return None

    def run_job(self, job_id: str, parameters=None):
@@ -57,11 +63,13 @@ class DatabricksJobManager:
        response = HTTPXRequestUtil(url).post(headers=self.headers, json=payload)

        if response.status_code == 200:
            run_id = response.json()["run_id"]
            logging.info(f"Job run started with ID: {run_id}")
            return run_id
        else:
            logging.error(
                f"Error running job: {response.status_code} - {response.text}"
            )
            return None

    def get_run_status(self, run_id):
@@ -73,12 +81,16 @@ class DatabricksJobManager:
        url = f"{self.host}/api/2.1/jobs/runs/get"
        params = {"run_id": run_id}

        response = HTTPXRequestHandler(url).get(
            url, headers=self.headers, params=params
        )

        if response.status_code == 200:
            return response.json()
        else:
            logging.error(
                f"Error getting run status: {response.status_code} - {response.text}"
            )
            return None

    @staticmethod
@@ -98,16 +110,18 @@ class DatabricksJobManager:
                    "task_key": "table_update_task",
                    "notebook_task": {
                        "notebook_path": notebook_path,
                        "base_parameters": {"input_message": "default_value"},
                    },
                    "timeout_seconds": 3600,
                }
            ],
            "max_concurrent_runs": 10,
            "tags": {
                "purpose": (
                    "metadata_ingestion"
                    if "ingestion" in job_name
                    else "metadata_deletion"
                ),
                "compute_type": "serverless",
            },
        }
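A minimal usage sketch of the job manager, assuming placeholder host, token, job name and notebook path (the leading fields of the serverless job config are cut off by the hunk above, so the exact head of that dict is not shown here):

# Minimal sketch; every value below is a placeholder, not a real workspace value.
manager = DatabricksJobManager(
    databricks_host="adb-1234567890123456.7.azuredatabricks.net",
    access_token="<personal-access-token>",
)
job_config = manager.create_job_config_for_serverless(
    job_name="project_787_metadata_deletion_job",
    notebook_path="/Users/user@example.com/project_787_metadata_deletion_notebook",
)
job_id = manager.create_job(job_config=job_config)  # returns the Databricks job id or None
run_id = manager.run_job(job_id=job_id, parameters={"input_message": "{}"})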
@@ -13,13 +13,19 @@ class NotebookManager:
            databricks_host: Your Databricks workspace URL (e.g., 'https://your-workspace.cloud.databricks.com')
            access_token: Personal access token or service principal token
        """
        self.host = (
            databricks_host
            if "https://" in databricks_host
            else f"https://{databricks_host}"
        )
        self.headers = {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
        }

    def create_notebook(
        self, notebook_path, notebook_code: str, language="PYTHON", overwrite=True
    ):
        """
        Create a notebook in Databricks workspace
@@ -31,18 +37,22 @@ class NotebookManager:
        """
        url = f"{self.host}/api/2.0/workspace/import"

        # Encode the notebook content in base64
        encoded_content = base64.b64encode(notebook_code.encode("utf-8")).decode(
            "utf-8"
        )

        payload = {
            "path": notebook_path,
            "format": "SOURCE",
            "language": language,
            "content": encoded_content,
            "overwrite": overwrite,
        }
        response = HTTPXRequestUtil(url=url).post(json=payload, headers=self.headers)

        if response.status_code == 200:
            logging.info(f"Notebook created successfully at: {notebook_path}")
            return True
        else:
            logging.error(
                f"Error creating notebook: {response.status_code} - {response.text}"
            )
            return False
import orjson

from scripts.config import RedisConfig
from scripts.db.redis import redis_connector

databricks_details_db = redis_connector.connect(
    db=RedisConfig.REDIS_DATABRICKS_DB, decode_responses=True
)
@@ -9,7 +9,9 @@ from ut_sql_utils.config import PostgresConfig
from scripts.config import RedisConfig
from scripts.db.redis import redis_connector

graphql_details_db = redis_connector.connect(
    db=RedisConfig.REDIS_GRAPHQL_DB, decode_responses=True
)


def get_models(
@@ -38,7 +40,9 @@ def get_models(
    """
    tables_data = graphql_details_db.hget(info.data["project_id"], "schema_mapper")
    if tables_data is None:
        raise ILensErrors(
            f"No GraphQL schema data found for project {info.data['project_id']}"
        )
    tables: Dict[str, Any] = orjson.loads(tables_data) or {}
    if (
...
@@ -3,7 +3,9 @@ import orjson
from scripts.config import RedisConfig
from scripts.db.redis import redis_connector

project_details_db = redis_connector.connect(
    db=RedisConfig.REDIS_PROJECT_TAGS_DB, decode_responses=True
)


def get_project_time_zone(project_id: str):
@@ -19,7 +21,9 @@ def get_project_time_zone(project_id: str):
    return "UTC"


def fetch_level_details(
    project_id: str, keys: bool = False, raw: bool = False
) -> dict | list:
    """
    Function to fetch level details from project details
    Uses redis project details cache db (db18) and fetches the level details
@@ -60,7 +64,9 @@ def fetch_asset_level(project_id: str) -> str:
    project_details = orjson.loads(project_details)
    counter_levels = project_details.get("counter_levels", {})
    asset_level = (
        counter_levels.get("asset", counter_levels.get("equipment"))
        if isinstance(counter_levels, dict)
        else None
    )
    if asset_level:
        return asset_level
@@ -81,6 +87,7 @@ def fetch_asset_level_with_mapping(project_id: str) -> tuple[str, str]:
    swapped_dict = {v: k for k, v in counter_levels.items()}
    return swapped_dict.get("ast", ""), "ast"


def project_template_keys(project_id: str, levels=False):
    val = project_details_db.get(project_id)
    if val is None:
...
@@ -2,4 +2,7 @@ from faststream.confluent import KafkaBroker
from scripts.config import KafkaConfig

broker = KafkaBroker(
    f"{KafkaConfig.KAFKA_HOST}:{KafkaConfig.KAFKA_PORT}",
    client_id="model_creator_agent",
)
@@ -7,8 +7,7 @@ from scripts.schemas import ModelCreatorSchema, ModelInstanceSchema

class ModelCreatorAgent:
    def __init__(self): ...

    @staticmethod
    async def model_creator_agent(message: ModelCreatorSchema):
@@ -18,10 +17,14 @@ class ModelCreatorAgent:
            session_manager=session_manager,
            schema=message.schema,
        )
        model_cal_obj = ModelCreatorHandler(
            message=message, declarative_utils=declarative_utils
        )
        await model_cal_obj.create_models_in_unity_catalog()

    @staticmethod
    async def model_instance_agent(message: ModelInstanceSchema):
        model_instance_obj = ModelInstanceHandler(
            project_id=message.project_id, payload=message
        )
        await model_instance_obj.upload_instances_to_unity_catalog()
from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel, Field, model_validator
from ut_security_util import MetaInfoSchema

from scripts.config import DatabricksConfig
@@ -20,7 +20,11 @@ class ModelCreatorSchema(BaseModel):

class ModelInstanceSchema(BaseModel):
    data: Union[Dict[str, Any], List[Dict[str, Any]]]
    project_id: str
    action_type: str = "save"
    node_type: str
    sql_schema: Optional[str] = Field(
        default=DatabricksConfig.DATABRICKS_PUBLIC_SCHEMA_NAME, alias="schema"
    )
    databricks_host: str = DatabricksConfig.DATABRICKS_HOST
    databricks_port: int = DatabricksConfig.DATABRICKS_PORT
    databricks_access_token: str = DatabricksConfig.DATABRICKS_ACCESS_TOKEN
@@ -30,6 +34,6 @@ class ModelInstanceSchema(BaseModel):
    @model_validator(mode="before")
    def validate_data(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        if "data" in values and isinstance(values["data"], dict):
            values["data"] = [values["data"]]
        return values
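A minimal sketch of a message the instance agent would accept, assuming illustrative values and that the fields omitted by the hunk above have defaults; note that the "schema" key populates sql_schema through the alias, and the before-validator wraps a single data dict into a list:

# Hypothetical payload; all values are illustrative.
message = {
    "project_id": "project_787",
    "node_type": "enterprise",
    "action_type": "delete",
    "schema": "public",        # feeds sql_schema via the alias
    "data": {"id": "row_1"},   # normalised to [{"id": "row_1"}] by validate_data
}
instance = ModelInstanceSchema(**message)
assert instance.data == [{"id": "row_1"}]
assert instance.sql_schema == "public"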
@@ -29,13 +29,17 @@ class DatabricksSQLUtility:
                DatabricksConfig.DATABRICKS_URI,
                pool_pre_ping=True,
                pool_recycle=3600,
                echo=False,
            )

            # Test connection
            with self.engine.connect() as conn:
                result = conn.execute(
                    text("SELECT current_user() as user, current_catalog() as catalog")
                )
                user_info = result.fetchone()
                logger.info(
                    f"Connected as user: {user_info[0]}, current catalog: {user_info[1]}"
                )

            logger.info("Successfully connected to Databricks")
            return True
        except Exception as e:
@@ -46,8 +50,12 @@ class DatabricksSQLUtility:
        if self.engine:
            self.engine.dispose()

    def create_catalog(
        self,
        managed_location: Optional[str] = None,
        comment: Optional[str] = None,
        properties: Optional[dict] = None,
    ):
        """
        Create a new catalog in Unity Catalog
        Args:
@@ -77,8 +85,13 @@ class DatabricksSQLUtility:
            logger.error(f"Failed to create catalog '{self.catalog_name}': {str(e)}")
            raise

    def create_schema(
        self,
        schema_name: str,
        managed_location: Optional[str] = None,
        comment: Optional[str] = None,
        properties: Optional[dict] = None,
    ):
        """
        Create a new schema within a catalog
        Args:
@@ -103,10 +116,14 @@ class DatabricksSQLUtility:
                ddl += f"\nWITH DBPROPERTIES ({props})"

            self.execute_sql_statement(ddl)
            logger.info(
                f"Schema '{self.catalog_name}.{schema_name}' created successfully"
            )
            return full_schema_name
        except Exception as e:
            logger.error(
                f"Failed to create schema '{self.catalog_name}.{schema_name}': {str(e)}"
            )
            raise

    def create_external_location(
@@ -114,7 +131,7 @@ class DatabricksSQLUtility:
        location_name: str,
        storage_path: str,
        credential_name: str,
        comment: Optional[str] = None,
    ) -> str:
        """
        Create an external location in Unity Catalog
@@ -138,7 +155,9 @@ class DatabricksSQLUtility:
            logger.info(f"External location '{location_name}' created successfully")
            return location_name
        except Exception as e:
            logger.error(
                f"Failed to create external location '{location_name}': {str(e)}"
            )
            raise

    def execute_sql_statement(self, query: str):
...
@@ -18,7 +18,11 @@ class HTTPXRequestUtil:
    def delete(self, path="", params=None, **kwargs) -> httpx.Response:
        url = self.get_url(path)
        logging.info(url)
        with httpx.Client(
            verify=self.verify,
            headers=kwargs.get("headers"),
            cookies=kwargs.get("cookies"),
        ) as client:
            response: httpx.Response = client.delete(url=url, params=params)
            return response
@@ -27,7 +31,11 @@ class HTTPXRequestUtil:
        url = self.get_url(path)
        logging.info(url)
        with httpx.Client(
            verify=self.verify,
            headers=kwargs.get("headers"),
            cookies=kwargs.get("cookies"),
        ) as client:
            response: httpx.Response = client.put(url=url, data=data, json=json)
            return response
@@ -42,7 +50,11 @@ class HTTPXRequestUtil:
        """
        url = self.get_url(path)
        logging.info(url)
        with httpx.Client(
            verify=self.verify,
            headers=kwargs.get("headers"),
            cookies=kwargs.get("cookies"),
        ) as client:
            response: httpx.Response = client.post(url=url, data=data, json=json)
            return response
@@ -52,7 +64,11 @@ class HTTPXRequestUtil:
        url = self.get_url(path)
        logging.info(url)
        with httpx.Client(
            verify=self.verify,
            headers=kwargs.get("headers"),
            cookies=kwargs.get("cookies"),
        ) as client:
            response: httpx.Response = client.get(url=url, params=params)
            return response
...
import logging import logging
from typing import Any, Type, Optional, Dict, Union, Tuple from typing import Any, Dict, Optional, Tuple, Type, Union
from databricks.sqlalchemy import TIMESTAMP from databricks.sqlalchemy import TIMESTAMP
from sqlalchemy import ( from sqlalchemy import (
Integer, BigInteger, SmallInteger, String, Text, Boolean, CHAR,
Date, DateTime, Time, Numeric, Float, DECIMAL, CHAR, VARCHAR, DECIMAL,
LargeBinary, JSON, Column, PrimaryKeyConstraint, UniqueConstraint, ForeignKeyConstraint, Double, Index, Table JSON,
VARCHAR,
BigInteger,
Boolean,
Column,
Date,
DateTime,
Double,
Float,
ForeignKeyConstraint,
Index,
Integer,
LargeBinary,
Numeric,
PrimaryKeyConstraint,
SmallInteger,
String,
Table,
Text,
Time,
UniqueConstraint,
) )
from sqlalchemy.dialects import postgresql from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import class_mapper, mapped_column, Mapped from sqlalchemy.orm import Mapped, class_mapper, mapped_column
from sqlalchemy.types import UserDefinedType from sqlalchemy.types import UserDefinedType
...@@ -83,7 +103,6 @@ class TypeMapper: ...@@ -83,7 +103,6 @@ class TypeMapper:
VARCHAR: VARCHAR, VARCHAR: VARCHAR,
Text: String, Text: String,
String: String, String: String,
# DateTime types # DateTime types
postgresql.DATE: Date, postgresql.DATE: Date,
postgresql.TIME: String, postgresql.TIME: String,
...@@ -92,23 +111,18 @@ class TypeMapper: ...@@ -92,23 +111,18 @@ class TypeMapper:
Date: Date, Date: Date,
Time: String, Time: String,
DateTime: DateTime, DateTime: DateTime,
# Boolean # Boolean
postgresql.BOOLEAN: Boolean, postgresql.BOOLEAN: Boolean,
Boolean: Boolean, Boolean: Boolean,
# Binary # Binary
postgresql.BYTEA: LargeBinary, postgresql.BYTEA: LargeBinary,
LargeBinary: LargeBinary, LargeBinary: LargeBinary,
# JSON # JSON
postgresql.JSON: String, postgresql.JSON: String,
postgresql.JSONB: String, postgresql.JSONB: String,
JSON: String, JSON: String,
# Array # Array
postgresql.ARRAY: String, postgresql.ARRAY: String,
# PostgreSQL specific # PostgreSQL specific
postgresql.UUID: String, postgresql.UUID: String,
postgresql.INET: String, postgresql.INET: String,
...@@ -121,14 +135,14 @@ class TypeMapper: ...@@ -121,14 +135,14 @@ class TypeMapper:
} }
SQL_TO_DATABRICKS_MAPPING = { SQL_TO_DATABRICKS_MAPPING = {
'VARCHAR': 'STRING', "VARCHAR": "STRING",
'INTEGER': 'INT', "INTEGER": "INT",
'BIGINT': 'BIGINT', # Keep as is "BIGINT": "BIGINT", # Keep as is
'FLOAT': 'DOUBLE', "FLOAT": "DOUBLE",
'BOOLEAN': 'BOOLEAN', # Keep as is "BOOLEAN": "BOOLEAN", # Keep as is
'TIMESTAMP': 'TIMESTAMP', # Keep as is "TIMESTAMP": "TIMESTAMP", # Keep as is
'DATETIME': 'TIMESTAMP', # Change this mapping "DATETIME": "TIMESTAMP", # Change this mapping
'TEXT': 'STRING', "TEXT": "STRING",
# Arrays and complex types are already correct, no replacement needed # Arrays and complex types are already correct, no replacement needed
} }
...@@ -147,14 +161,14 @@ class TypeMapper: ...@@ -147,14 +161,14 @@ class TypeMapper:
base_type = type(sql_type) base_type = type(sql_type)
# Handle special cases first # Handle special cases first
if base_type == postgresql.ARRAY or 'ARRAY' in str(sql_type): if base_type == postgresql.ARRAY or "ARRAY" in str(sql_type):
return cls._convert_array_type_fallback(sql_type) return cls._convert_array_type_fallback(sql_type)
# Get the mapped type # Get the mapped type
if base_type in cls.POSTGRES_TO_DATABRICKS_MAPPING: if base_type in cls.POSTGRES_TO_DATABRICKS_MAPPING:
return cls.POSTGRES_TO_DATABRICKS_MAPPING[base_type]() return cls.POSTGRES_TO_DATABRICKS_MAPPING[base_type]()
logging.info(f'Defaulting to String() for type: {type(sql_type)}') logging.info(f"Defaulting to String() for type: {type(sql_type)}")
return String() return String()
    @classmethod
@@ -171,21 +185,21 @@ class TypeMapper:
        element_type_name = type(postgres_element_type).__name__.upper()

        # Map PostgreSQL types to Databricks array element types
        if any(t in element_type_name for t in ["VARCHAR", "TEXT", "STRING", "CHAR"]):
            return "STRING"
        # Check BIGINT before the generic integer branch; otherwise
        # "BIGINTEGER" also matches "INTEGER" and would be narrowed to INT.
        elif "BIGINT" in element_type_name:
            return "BIGINT"
        elif any(t in element_type_name for t in ["INTEGER", "SMALLINT"]):
            return "SMALLINT" if "SMALLINT" in element_type_name else "INT"
        elif any(t in element_type_name for t in ["BOOLEAN", "BOOL"]):
            return "BOOLEAN"
        elif any(t in element_type_name for t in ["FLOAT", "REAL", "DOUBLE"]):
            return "DOUBLE"
        elif any(t in element_type_name for t in ["NUMERIC", "DECIMAL"]):
            return "DECIMAL"
        elif "DATE" in element_type_name:
            return "DATE"
        elif "TIMESTAMP" in element_type_name:
            return "TIMESTAMP"
        else:
            return "STRING"  # Default fallback
@@ -225,7 +239,11 @@ class TypeMapper:
        default_clause = ""
        if column.default is not None:
            default_value = (
                column.default.arg
                if hasattr(column.default, "arg")
                else column.default
            )
            if isinstance(default_value, str):
                default_clause = f" DEFAULT '{default_value}'"
            elif isinstance(default_value, bool):
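A hedged illustration of the DEFAULT-clause branch above (Column defaults are wrapped in SQLAlchemy's ColumnDefault, hence the .arg unwrapping):

from sqlalchemy import Column, String

col = Column("status", String, default="active")
value = col.default.arg if hasattr(col.default, "arg") else col.default
print(f" DEFAULT '{value}'")  # -> " DEFAULT 'active'"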
@@ -258,7 +276,7 @@ class ColumnConverter:
        columns_info = {}

        # Check for modern SQLAlchemy with annotations and mapped_column
        if hasattr(model_class, "__annotations__"):
            columns_info.update(self.extract_from_annotations(model_class))

        # Fallback: try mapper approach for traditional models
@@ -267,7 +285,9 @@ class ColumnConverter:
            mapper = class_mapper(model_class)
            for column_name, column in mapper.columns.items():
                if column_name not in columns_info:
                    columns_info[column_name] = self.extract_column_properties(
                        column
                    )
        except Exception as e:
            logging.error(f"Failed to extract column info using mapper: {e}")

        # Final fallback: inspect class attributes directly
@@ -279,20 +299,22 @@ class ColumnConverter:
    def extract_column_properties(column: Any) -> Dict[str, Any]:
        """Extract properties from a column object."""
        return {
            "type": getattr(column, "type", None),
            "primary_key": getattr(column, "primary_key", False),
            "nullable": getattr(column, "nullable", True),
            "default": getattr(column, "default", None),
            "server_default": getattr(column, "server_default", None),
            "uses_mapped_column": False,
        }
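A hedged call of the extractor above (this assumes it is a @staticmethod, as its self-less signature suggests, and that this module is importable):

from sqlalchemy import Column, Integer

props = ColumnConverter.extract_column_properties(
    Column("id", Integer, primary_key=True)
)
print(props["primary_key"], props["nullable"])  # True False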
    def _extract_from_class_attributes(
        self, model_class: type
    ) -> Dict[str, Dict[str, Any]]:
        """Extract column info from class attributes."""
        columns_info = {}

        for attr_name in dir(model_class):
            if attr_name.startswith("_"):
                continue

            attr = getattr(model_class, attr_name, None)
@@ -303,14 +325,14 @@ class ColumnConverter:
            if isinstance(attr, Column):
                columns_info[attr_name] = self.extract_column_properties(attr)
            # Check for mapped_column objects
            elif hasattr(attr, "type") and hasattr(attr, "nullable"):
                columns_info[attr_name] = {
                    "type": getattr(attr, "type", None),
                    "primary_key": getattr(attr, "primary_key", False),
                    "nullable": getattr(attr, "nullable", True),
                    "default": getattr(attr, "default", None),
                    "server_default": getattr(attr, "server_default", None),
                    "uses_mapped_column": True,
                }

        return columns_info
@@ -320,22 +342,22 @@ class ColumnConverter:
        """Extract column info from type annotations (modern SQLAlchemy)."""
        columns_info = {}

        annotations = getattr(model_class, "__annotations__", {})

        for attr_name, annotation in annotations.items():
            if hasattr(model_class, attr_name):
                attr = getattr(model_class, attr_name)

                # Check if it's a mapped_column
                if hasattr(attr, "type"):
                    columns_info[attr_name] = {
                        "type": attr.type,
                        "primary_key": getattr(attr, "primary_key", False),
                        "nullable": getattr(attr, "nullable", True),
                        "default": getattr(attr, "default", None),
                        "server_default": getattr(attr, "server_default", None),
                        "annotation": annotation,
                        "uses_mapped_column": True,
                    }

        return columns_info
@@ -351,36 +373,38 @@ class ColumnConverter:
        Tuple of (column_object, annotation_if_any)
        """
        # Convert the column type
        new_type = self.type_mapper.get_databricks_type(column_info["type"])

        # Check if this uses type annotations (modern approach)
        if column_info.get("uses_mapped_column", False):
            new_column = mapped_column(
                new_type,
                primary_key=column_info.get("primary_key", False),
                nullable=column_info.get("nullable", True),
                default=column_info.get("default"),
                server_default=column_info.get("server_default"),
            )

            # Convert annotation
            annotation = self.convert_annotation(
                column_info.get("annotation"), new_type
            )
            return new_column, annotation
        else:
            # Traditional Column approach
            new_column = Column(
                new_type,
                primary_key=column_info.get("primary_key", False),
                nullable=column_info.get("nullable", True),
                default=column_info.get("default"),
                server_default=column_info.get("server_default"),
            )
            return new_column, None
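A hedged sketch of calling this conversion routine; the method name convert_column and the exact dict shape are assumptions, since the def line falls outside the visible hunk:

from sqlalchemy.dialects import postgresql

converter = ColumnConverter(TypeMapper())
info = {
    "type": postgresql.JSONB(),  # JSONB is mapped to String for Databricks
    "nullable": False,
    "uses_mapped_column": False,
}
column, annotation = converter.convert_column(info)  # assumed method name
print(type(column.type).__name__, annotation)  # String None (expected)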
    @staticmethod
    def convert_annotation(annotation: Any, databricks_type: Any = None) -> Any:
        """Convert type annotations for Databricks compatibility."""
        from typing import List, Optional

        if annotation is None:
            return Mapped[Optional[str]]
@@ -388,15 +412,17 @@ class ColumnConverter:
        annotation_str = str(annotation)

        # Handle array/list types -> convert to List annotation
        if any(keyword in annotation_str for keyword in ["list", "List"]):
            if "Optional" in annotation_str or "Union" in annotation_str:
                return Mapped[Optional[List[str]]]  # Default to List[str]
            else:
                return Mapped[List[str]]

        # Handle JSON/JSONB/dict types -> convert to string
        if any(
            keyword in annotation_str for keyword in ["dict", "Dict", "json", "Json"]
        ):
            if "Optional" in annotation_str or "Union" in annotation_str:
                return Mapped[Optional[str]]
            else:
                return Mapped[str]
@@ -405,31 +431,38 @@ class ColumnConverter:
        if databricks_type:
            type_str = str(type(databricks_type).__name__).lower()

            if "array" in type_str:
                # For array types, use List annotation
                if "Optional" in annotation_str:
                    return Mapped[
                        Optional[List[str]]
                    ]  # Could be more specific based on element type
                return Mapped[List[str]]
            elif (
                "integer" in type_str
                or "biginteger" in type_str
                or "smallinteger" in type_str
            ):
                if "Optional" in annotation_str:
                    return Mapped[Optional[int]]
                return Mapped[int]
            elif "boolean" in type_str:
                if "Optional" in annotation_str:
                    return Mapped[Optional[bool]]
                return Mapped[bool]
            elif "float" in type_str or "numeric" in type_str:
                if "Optional" in annotation_str:
                    return Mapped[Optional[float]]
                return Mapped[float]
            elif "datetime" in type_str:
                from datetime import datetime

                if "Optional" in annotation_str:
                    return Mapped[Optional[datetime]]
                return Mapped[datetime]

        # Default to string
        if "Optional" in annotation_str or "Union" in annotation_str:
            return Mapped[Optional[str]]
        return Mapped[str]
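Hedged examples of the annotation rewrite (assumes ColumnConverter is importable from this module):

from typing import Optional

from sqlalchemy.orm import Mapped

print(ColumnConverter.convert_annotation(Mapped[Optional[dict]]))
# -> Mapped[Optional[str]] (dict/JSON collapses to string)
print(ColumnConverter.convert_annotation(Mapped[list[str]]))
# -> Mapped[List[str]]
print(ColumnConverter.convert_annotation(None))
# -> Mapped[Optional[str]] (default)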
@@ -442,8 +475,7 @@ class SchemaProcessor:

    @staticmethod
    def process_table_args(
        original_table_args: Any, new_schema: Optional[str] = None
    ) -> Union[Tuple, Dict, None]:
        """
        Process table arguments, handling constraints and schema conversion.
@@ -457,7 +489,7 @@ class SchemaProcessor:
        """
        if not original_table_args:
            if new_schema:
                return {"schema": new_schema}
            return None

        new_table_args = []
@@ -468,7 +500,9 @@ class SchemaProcessor:
        for arg in original_table_args:
            if isinstance(arg, dict):
                # Process dictionary part
                processed_kwargs = SchemaProcessor._process_table_kwargs(
                    arg, new_schema
                )
                table_kwargs.update(processed_kwargs)
            elif isinstance(arg, (Index, ForeignKeyConstraint)):
                continue
@@ -481,11 +515,13 @@ class SchemaProcessor:
        # Handle dictionary format: {'schema': 'public', 'extend_existing': True}
        elif isinstance(original_table_args, dict):
            table_kwargs = SchemaProcessor._process_table_kwargs(
                original_table_args, new_schema
            )

        # Add new schema if specified and not already set
        if new_schema is not None and "schema" not in table_kwargs:
            table_kwargs["schema"] = new_schema

        # Construct result
        if new_table_args and table_kwargs:
@@ -499,16 +535,18 @@ class SchemaProcessor:
            return None

    @staticmethod
    def _process_table_kwargs(
        kwargs: Dict[str, Any], new_schema: Optional[str]
    ) -> Dict[str, Any]:
        """Process table keyword arguments."""
        processed = {}

        for key, value in kwargs.items():
            if key == "schema":
                # Use new_schema if provided, otherwise keep original unless it's 'public'
                if new_schema is not None:
                    processed[key] = new_schema
                elif value != "public":
                    processed[key] = value
                # Skip 'public' schema (default)
            else:
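A hedged example of the schema rewrite (assumes SchemaProcessor is importable from this module; the output shape is the expected one, since part of the result construction sits outside the visible hunk):

args = ({"schema": "public", "extend_existing": True},)
print(SchemaProcessor.process_table_args(args, new_schema="analytics"))
# -> {'schema': 'analytics', 'extend_existing': True} (expected)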
@@ -532,7 +570,8 @@ class ModelConverter:
        self.type_mapper = TypeMapper()
        self.column_converter = ColumnConverter(self.type_mapper)

    def convert_model(
        self,
        postgres_model_class: Type,
        base_class: Type,
        new_table_name: Optional[str] = None,
@@ -552,27 +591,28 @@ class ModelConverter:
        """
        # Create base class if not provided
        # Get table information
        original_table_name = getattr(
            postgres_model_class, "__tablename__", "unknown_table"
        )
        table_name = new_table_name or original_table_name
        table_name = f"{table_name}"

        schema_processor = SchemaProcessor()

        # Create new model attributes
        new_attrs = {
            "__tablename__": table_name,
            "__module__": postgres_model_class.__module__,
        }

        # Process table arguments
        if hasattr(postgres_model_class, "__table_args__"):
            processed_table_args = schema_processor.process_table_args(
                postgres_model_class.__table_args__, new_schema
            )
            if processed_table_args:
                new_attrs["__table_args__"] = processed_table_args
        elif new_schema:
            new_attrs["__table_args__"] = {"schema": new_schema}

        # Extract and convert columns
        columns_info = self.column_converter.extract_column_info(postgres_model_class)
@@ -587,7 +627,7 @@ class ModelConverter:
        # Add annotations if any
        if annotations:
            new_attrs["__annotations__"] = annotations

        # Create the new model class
        new_class_name = f"{postgres_model_class.__name__}Databricks"