Commit 36df546c authored by harshavardhan.c

feat: initial commit

parent 1b6e4cf1
BASE_PATH=data
MOUNT_DIR=batch-process-app
#POSTGRES_URI=oracledb://admin:UtAdm%23Post271486@192.168.0.220:30178
LOG_LEVEL=DEBUG
SECURE_ACCESS=True
SECURE_COOKIE=False
SW_DOCS_URL=/docs
SW_REDOC_URL=/redoc
SW_OPENAPI_URL=/openapi.json
MODULE_NAME=batch-process-app
DEFER_GEN_REFRESH=True
VERIFY_SIGNATURE=False
##DEV
#MONGO_URI=mongodb://ilens:ilens4321@192.168.0.220:31589?directConnection=true
#REDIS_URI=redis://admin:iLensDevRedis@192.168.0.220:32642
#POSTGRES_URI=postgresql://admin:UtAdm%23Post271486@192.168.0.220:30178
#METADATA_SERVICES_URL=https://dev.unifytwin.com/ilens_api
#HIERARCHY_SERVICES_URL=https://dev.unifytwin.com/hry
####DMPC-DEV
MONGO_URI=mongodb://admin:iLens#DMPCDEVv765@4.213.201.118:27887/?authSource=admin&directConnection=true
POSTGRES_URI=postgresql://admin:UtiLens%23DmpcDev0824@135.235.212.79:5421
REDIS_URI=redis://20.235.217.124:6345
METADATA_SERVICES_URL=https://dmpc-dev.unifytwin.com/ilens_api
HIERARCHY_SERVICES_URL=https://dmpc-dev.unifytwin.com/hry
KAIROS_URI=http://4.224.167.17:8097
MODEL_MANAGEMENT_URL=https://dmpc-dev.unifytwin.com/model_mgmt
BATCH_PROCESS_APP_URL=https://dmpc-dev.unifytwin.com/bpa_app
ARANGO_URI=http://root:root123@14.213.204.58:8886
RULES_ALERTS_SERVICES_URL=https://dmpc-dev.unifytwin.com/awb
KAFKA_URI=kafka://135.235.228.243:9094
DIGITAL_TWIN_SERVICE_URL=https://dmpc-dev.unifytwin.com/scada_dt
MQTT_HOST=192.168.0.220
MQTT_PORT=1883
DATABRICKS_HOST=adb-416418955412087.7.azuredatabricks.net
DATABRICKS_PORT=443
DATABRICKS_HTTP_PATH=sql/protocolv1/o/416418955412087/0702-121224-xsstn0c7
DATABRICKS_ACCESS_TOKEN=dapi72a54657606877a3f7a6d92dd573df28
# #DMPC-QA
#MONGO_URI=mongodb://admin:UtAdm%23Mong539608@20.235.201.97:4623/?directConnection=true
#POSTGRES_URI=postgresql://writeuser:writeuser890@135.235.147.99:5438
#KAIROS_URI=http://20.235.212.64:8076
#REDIS_URI=redis://admin:iLensProdRedis@135.235.226.201:6388
#METADATA_SERVICES_URL=http://192.168.0.221:8712
#HIERARCHY_SERVICES_URL=http://192.168.0.221:8711
#MODEL_MANAGEMENT_URL=http://192.168.0.221:8713
#BATCH_PROCESS_APP_URL=http://localhost:7879
#RULES_ALERTS_SERVICES_URL=https://dmpc-qa.unifytwin.com/awb
#KAFKA_URI=192.168.0.220:9094
# ##DMPC-prod
# MONGO_URI=mongodb://global_read_user:read%23456@192.168.0.207:8098/?directConnection=true
# # POSTGRES_URI=postgresql://readwrite_user:readwrite%23456@192.168.0.207:5349
# POSTGRES_URI=postgresql://global_read_user:read%23956@192.168.0.207:5349
# KAIROS_URI=http://192.168.0.207:5334
# REDIS_URI=redis://admin:iLensProdRedis@192.168.0.207:8213
# METADATA_SERVICES_URL=http://192.168.0.207:9712
# HIERARCHY_SERVICES_URL=http://192.168.0.207:9713
# MODEL_MANAGEMENT_URL=http://192.168.0.207:9714
# RULES_ALERTS_SERVICES_URL=https://dmpc-qa.unifytwin.com/awb
# BATCH_PROCESS_APP_URL=http://localhost:7879
# KAFKA_URI=192.168.0.220:9094
# # readwrite_user/readwrite#456
##DMPC-Pentest
#MONGO_URI=mongodb://global_read_user:read#456@192.168.0.221:2719/?authSource=admin&directConnection=true
#POSTGRES_URI=postgresql://global_read_user:read#956@192.168.0.221:5433
#REDIS_URI=redis://admin:LOKe1d63eOwN@192.168.0.221:6789
#METADATA_SERVICES_URL=http://192.168.0.221:7111
#HIERARCHY_SERVICES_URL=http://192.168.0.221:7112
#MODEL_MANAGEMENT_URL=http://192.168.0.221:7113
#BATCH_PROCESS_APP_URL=http://localhost:7879
import logging
from faststream.confluent import KafkaBroker
from scripts.config import KafkaConfig
from scripts.engines.agents.model_creator_agent import ModelCreatorAgent
from scripts.schemas import ModelCreatorSchema
broker = KafkaBroker(f'{KafkaConfig.KAFKA_HOST}:{KafkaConfig.KAFKA_PORT}', client_id="model_creator_agent")
@broker.subscriber(KafkaConfig.KAFKA_MODEL_CREATION_TOPIC, group_id="databricks_model_creator_agent", max_workers=2)
async def consume_stream_for_processing_dependencies(message: dict):
try:
await ModelCreatorAgent.model_trigger_agent(message=ModelCreatorSchema(meta=message))
return True
except Exception as e:
logging.error(f"Exception occurred while creating model in Databricks: {e}")
return False
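# Illustrative note: the subscriber above wraps the raw Kafka payload as the `meta` field of
# ModelCreatorSchema, so the message is expected to be a dict that MetaInfoSchema (from
# ut_security_util) can validate. The shape below is an assumption for illustration only;
# at minimum a project_id is required downstream:
#
#   {"project_id": "project_787", "user_id": "user_099", "language": "en", "tz": "Asia/Kolkata"}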
# app.py
import asyncio
import logging as logger
import sys
from dotenv import load_dotenv
load_dotenv()
from agent_subscribers import broker
from faststream import FastStream
from ut_dev_utils import configure_logger
configure_logger()
# Create FastStream app
app = FastStream(broker)
async def run_app():
try:
logger.info("Starting FastStream application...")
await app.run()
except KeyboardInterrupt:
logger.info("Application interrupted by user")
except Exception as e:
logger.error(f"Application error: {e}")
raise
finally:
logger.info("Application shutdown complete")
# Main execution
if __name__ == "__main__":
try:
# For better performance on Linux/Mac, use uvloop if available
if sys.platform != "win32":
try:
import uvloop
asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
logger.info("Using uvloop for better performance")
except ImportError:
logger.info("uvloop not available, using default event loop")
# Run the application
asyncio.run(run_app())
except KeyboardInterrupt:
print("\nApplication stopped by user")
except Exception as e:
logger.error(f"Failed to start application: {e}")
sys.exit(1)
faststream[confluent]==0.5.48
ut-dev-utils[sql,essentials]==1.2
uvloop==0.21.0
import pathlib
from typing import Optional
from urllib.parse import urlparse
from pydantic import Field, model_validator
from pydantic_settings import BaseSettings
from typing_extensions import Annotated
PROJECT_NAME = "model_management_databricks"
from pydantic.functional_validators import BeforeValidator
def no_trailing_slash(v: str):
return v.rstrip("/")
SafeURLType = Annotated[str, BeforeValidator(no_trailing_slash)]
class _Services(BaseSettings):
# Module Required
MODULE_NAME: str = Field(default=PROJECT_NAME)
PORT: int = Field(default=7124, validation_alias="service_port")
HOST: str = Field(default="0.0.0.0", validation_alias="service_host")
class _RedisConfig(BaseSettings):
REDIS_URI: str
REDIS_PROJECT_TAGS_DB: int = 18
REDIS_DATABRICKS_DB: int = 73
REDIS_GRAPHQL_DB: int = 37
class _ExternalServices(BaseSettings):
METADATA_SERVICES_URL: SafeURLType
HIERARCHY_SERVICES_URL: SafeURLType
class _PathToStorage(BaseSettings):
BASE_PATH: pathlib.Path = Field("/code/data", validation_alias="BASE_PATH")
class _KafkaConfig(BaseSettings):
KAFKA_HOST: Optional[str] = None
KAFKA_PORT: Optional[int] = None
KAFKA_URI: str
KAFKA_MODEL_CREATION_TOPIC: str = Field(default="model_creator")
KAFKA_MODEL_INSTANCE_TOPIC: str = Field(default="model_instance")
ENABLE_KAFKA_PARTITION: bool = Field(default=True)
ROUND_ROBIN_ENABLE: bool = Field(default=False)
PARTITION_DB: int = Field(default=13)
@model_validator(mode="before")
def parse_kafka_uri(cls, values):
        if values.get("KAFKA_URI"):
streaming_url = urlparse(url=values["KAFKA_URI"])
values["KAFKA_HOST"] = streaming_url.hostname
values["KAFKA_PORT"] = streaming_url.port
return values
class _DatabricksConfig(BaseSettings):
DATABRICKS_HOST: str
DATABRICKS_PORT: int = Field(default=443)
DATABRICKS_URI: str
DATABRICKS_HTTP_PATH: str
DATABRICKS_ACCESS_TOKEN: str
DATABRICKS_CATALOG_NAME: str = Field(default="unified_model")
DATABRICKS_PUBLIC_SCHEMA_NAME: str = Field(default="public")
DATABRICKS_ANALYTICAL_SCHEMA_NAME: str = Field(default="analytical")
DATABRICKS_STORAGE_FORMAT: str = Field(default="PARQUET")
@model_validator(mode="before")
def prepare_databricks_uri(cls, values):
        values["DATABRICKS_URI"] = (
            f"databricks://token:{values['DATABRICKS_ACCESS_TOKEN']}"
            f"@{values['DATABRICKS_HOST']}:{values['DATABRICKS_PORT']}"
            f"?http_path={values['DATABRICKS_HTTP_PATH']}"
        )
return values
Services = _Services()
RedisConfig = _RedisConfig()
ExternalServices = _ExternalServices()
PathToStorage = _PathToStorage()
KafkaConfig = _KafkaConfig()
DatabricksConfig = _DatabricksConfig()
__all__ = ["Services", "RedisConfig", "ExternalServices", "PathToStorage", "KafkaConfig", "DatabricksConfig"]
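# A minimal sketch of how the derived settings resolve at import time (values are illustrative;
# the real ones come from the environment / .env file):
#
#   KAFKA_URI=kafka://135.235.228.243:9094
#       -> KafkaConfig.KAFKA_HOST == "135.235.228.243", KafkaConfig.KAFKA_PORT == 9094
#   DatabricksConfig.DATABRICKS_URI
#       == "databricks://token:<DATABRICKS_ACCESS_TOKEN>@<DATABRICKS_HOST>:<DATABRICKS_PORT>?http_path=<DATABRICKS_HTTP_PATH>"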
# Databricks notebook source
import json
from datetime import datetime
from pyspark.sql.functions import *
from pyspark.sql.types import *
# COMMAND ----------
# Sample Input
#input_message = json.dumps({'data': [{'id': 'l1_100', 'name': 'HM 1', 'description': 'HM 1', 'meta': {'created_by': 'user_099', 'created_on': 1747054186650, 'last_updated_by': 'user_099', 'last_updated_on': 1750163411541}, 'project_id': 'project_787', 'type': 'enterprise', 'latitude': None, 'parameters': [], 'longitude': None, 'is_child': None, 'multi_select_dependent_length': None, 'schema': 'public', 'offline': {'timestamp': 1753769799940}, 'project_type': 'graph_model', 'tz': 'Asia/Kolkata', 'resolution': 'lg', 'language': 'en', 'user_id': 'user_099', 'action_type': 'save'},
#{'id': 'l1_100', 'name': 'HM 2', 'description': 'HM 2', 'meta': {'created_by': 'user_099', 'created_on': 1747054186650, 'last_updated_by': 'user_099', 'last_updated_on': 1750163411541}, 'project_id': 'project_787', 'type': 'enterprise', 'latitude': None, 'parameters': [], 'longitude': None, 'is_child': None, 'multi_select_dependent_length': None, 'schema': 'public', 'offline': {'timestamp': 1753769799940}, 'project_type': 'graph_model', 'tz': 'Asia/Kolkata', 'resolution': 'lg', 'language': 'en', 'user_id': 'user_099', 'action_type': 'save'}], 'project_id': 'project_787',
#'table_properties': {
# 'table_name': 'unified_model.public.enterprise',
# 'table_path': 'abfss://unity-catalog-storage@dbstoragenzxfhpgsipt5a.dfs.core.windows.net/416418955412087/unified_model/public.enterprise'
# }
#}
#)
# COMMAND ----------
dbutils.widgets.text("input_message", "", "Input Message JSON")
input_message = dbutils.widgets.get("input_message")
# COMMAND ----------
def extract_table_info(input_message_str: str):
"""
Extract table name and data from input message
Args:
input_message_str (str): JSON string containing the message
Returns:
dict: Extracted information
"""
try:
message_data = json.loads(input_message_str)
# Extract table name from data.type
table_name = message_data['table_properties']['table_name'] # 'enterprise'
project_id = message_data['project_id'] # 'project_787'
data_payload = message_data['data'] # Full data object
table_properties = message_data['table_properties'] # Fetch table properties
print(f"Extracted Info:")
print(f"Table Name: {table_name}")
print(f"Project ID: {project_id}")
print(f"Table Prop Keys: {list(table_properties.keys())}")
return {
'table_name': table_name,
'project_id': project_id,
'data_payload': data_payload,
'raw_message': message_data,
'table_properties': table_properties
}
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON in input_message: {str(e)}")
except KeyError as e:
raise ValueError(f"Missing required field in input_message: {str(e)}")
# COMMAND ----------
def detect_external_table_schema(table_name):
"""
Detect schema of external Delta or Parquet table
Args:
table_name (str): Name of the table (e.g., 'enterprise')
Returns:
pyspark.sql.types.StructType: Schema of the table
"""
try:
# Try to get schema from catalog
table_df = spark.table(table_name)
schema = table_df.schema
print(f"✓ Schema found in metastore for table: {table_name}")
return schema
except Exception as e:
print(f"X Failed to get schema from metastore for table: {table_name}")
return None
# COMMAND ----------
table_info = extract_table_info(input_message)
# COMMAND ----------
schema = detect_external_table_schema(table_info['table_name'])
if schema is None:
    raise ValueError(f"Schema not found for table: {table_info['table_name']}")
# COMMAND ----------
data_df = spark.createDataFrame(table_info['data_payload'], schema=schema)
#data_df.show()
# COMMAND ----------
data_df.write \
.mode("append") \
.format("parquet") \
.save(table_info['table_properties']['table_path'])
import json
import logging
from sqlalchemy import MetaData
from sqlalchemy.orm import declarative_base
from ut_sql_utils.asyncio.declarative_utils import DeclarativeUtils
from scripts.config import DatabricksConfig
from scripts.db.databricks import DataBricksSQLLayer
from scripts.db.databricks.job_manager import DatabricksJobManager
from scripts.db.databricks.notebook_manager import NotebookManager
from scripts.db.redis.databricks_details import databricks_details_db
from scripts.db.redis.project_details import fetch_level_details, project_template_keys
from scripts.schemas import ModelCreatorSchema
from scripts.utils.model_convertor_utils import ModelConverter
class ModelCreatorHandler:
def __init__(self, message: ModelCreatorSchema, declarative_utils: DeclarativeUtils):
self.declarative_utils = declarative_utils
self.meta = message.meta
self.message = message
self.model_convertor = ModelConverter()
self.job_manager = DatabricksJobManager(
databricks_host=message.databricks_host,
access_token=message.databricks_access_token
)
self.notebook_manager = NotebookManager(
databricks_host=message.databricks_host,
access_token=message.databricks_access_token
)
self.databricks_sql_obj = DataBricksSQLLayer(
catalog_name=DatabricksConfig.DATABRICKS_CATALOG_NAME,
project_id=self.meta.project_id,
schema=message.schema
)
self.external_location = f"abfss://unity-catalog-storage@dbstoragenzxfhpgsipt5a.dfs.core.windows.net/416418955412087"
@staticmethod
def create_schema_base(schema_name: str):
"""Create a Base class for a specific schema"""
metadata = MetaData(schema=schema_name)
return declarative_base(metadata=metadata)
async def create_models_in_unity_catalog(self):
overall_tables = self.get_overall_tables()
project_levels = project_template_keys(self.meta.project_id, levels=True)
schema = f'{self.databricks_sql_obj.catalog_name}.{self.message.schema}'
base = self.create_schema_base(schema)
try:
self.databricks_sql_obj.connect_to_databricks()
# _ = self.setup_dependencies_for_unity_catalog()
table_properties = self.fetch_table_properties()
for table in overall_tables:
table_class = self.declarative_utils.get_declarative_class(table)
if not table_class:
logging.error(f"Table class not found for table: {table}")
return False
new_model = self.model_convertor.convert_model(
table_class,
base_class=base,
new_schema=self.message.schema,
)
self.databricks_sql_obj.create_external_table_from_structure(
table=new_model.__table__,
file_format="DELTA",
external_location=self.external_location,
table_properties=table_properties
)
self.databricks_sql_obj.create_timeseries_table(columns=project_levels,
external_location=self.external_location)
            self.setup_notepads_and_jobs()
            return True
except Exception as e:
logging.error(f"Error occurred while creating models in Unity Catalog: {e}")
return False
finally:
self.databricks_sql_obj.__del__()
def get_physical_tables(self):
return fetch_level_details(project_id=self.meta.project_id, keys=True)
def get_overall_tables(self):
tables = self.get_physical_tables()
return tables
def setup_dependencies_for_unity_catalog(self, analytical=False):
"""
Complete setup of catalog and schema
Args:
analytical (bool): Flag to indicate if the setup is for analytical or not
"""
logging.info(f"Setting up catalog '{DatabricksConfig.DATABRICKS_CATALOG_NAME}'")
self.databricks_sql_obj.connect_to_databricks()
# Create catalog
catalog_success = self.databricks_sql_obj.create_catalog(
managed_location=f'{self.external_location}/{self.databricks_sql_obj.catalog_name}',
)
if not catalog_success:
return False
# Create schema
schema_success = self.databricks_sql_obj.create_schema(DatabricksConfig.DATABRICKS_PUBLIC_SCHEMA_NAME)
if not schema_success:
return False
if analytical:
schema_success = self.databricks_sql_obj.create_schema(DatabricksConfig.DATABRICKS_ANALYTICAL_SCHEMA_NAME)
if not schema_success:
return False
return True
def setup_notepads_and_jobs(self):
logging.info("Setting up notepads and jobs")
with open(r"scripts/constants/notebooks/metadata_ingestion.txt", "r") as f:
notebook_code = f.read()
self.notebook_manager.create_notebook(
notebook_path=f"/Users/{self.message.databricks_user_email}/metadata_ingestion_notebook",
notebook_code=notebook_code,
overwrite=True
)
#Job for metadata ingestion used by model management
job_id = self.job_manager.create_job(job_config=self.job_manager.create_job_config_for_serverless(
job_name="metadata_ingestion_job",
notebook_path=f"/Users/{self.message.databricks_user_email}/metadata_ingestion_notebook",
))
redis_dict = {"metadata_ingestion_job": job_id}
        databricks_details_db.hset(self.meta.project_id, mapping=redis_dict)
@staticmethod
def fetch_table_properties(file_format: str = 'DELTA'):
if file_format.lower() == 'delta':
return {
# Performance optimization (Essential)
"delta.autoOptimize.optimizeWrite": "true",
"delta.autoOptimize.autoCompact": "true",
"delta.targetFileSize": "134217728", # 128MB
'delta.enableChangeDataFeed': 'true', # If you need CDC
# Checkpoint optimization (Performance boost)
"delta.checkpoint.writeStatsAsStruct": "true",
"delta.checkpoint.writeStatsAsJson": "false"
# Note: Retention properties removed - using defaults:
# delta.deletedFileRetentionDuration = 7 days (default)
# delta.logRetentionDuration = 30 days (default)
}
elif file_format.lower() == 'parquet':
return {"parquet.compression": "snappy",
"parquet.page.size": "1048576", # 1MB - standard for mixed queries
"parquet.block.size": "134217728", # 128MB - balanced performance
"serialization.format": "1"}
else:
return {}
from typing import Dict, List
from sqlalchemy import Table, Column, String, BigInteger, DateTime, MetaData, Integer, Date
from scripts.utils.databricks_utils import DatabricksSQLUtility
from scripts.utils.model_convertor_utils import TypeMapper
class DataBricksSQLLayer(DatabricksSQLUtility):
def __init__(self, catalog_name: str, project_id: str, schema: str):
super().__init__(catalog_name, project_id)
self.catalog_name = catalog_name
self.schema = schema
def create_external_table_from_structure(self, table: Table,
external_location: str,
file_format: str = "PARQUET",
table_properties: Dict[str, str] = None,
partition_columns: list = None):
"""
Create an external table from a model class.
Args:
table: The model class to create the external table from.
external_location: The external location path.
file_format: The file format of the data files.
table_properties: Additional table properties.
partition_columns: List of columns to partition the table by.
"""
schema_table = f"{table.schema}.{table.name}" if table.schema else table.name
columns_sql = TypeMapper().extract_columns_without_constraints(table)
external_location = f"{external_location}/{self.catalog_name}/{file_format}/{schema_table}"
sql_parts = [
f"CREATE TABLE IF NOT EXISTS {schema_table}",
f"({columns_sql})",
f"USING {file_format}",
f"LOCATION '{external_location}'"
]
if partition_columns:
partition_clause = ", ".join(partition_columns)
sql_parts.append(f"PARTITIONED BY ({partition_clause})")
if table_properties:
props = [f"'{k}' = '{v}'" for k, v in table_properties.items()]
props_sql = ",\n ".join(props)
sql_parts.append(f"TBLPROPERTIES (\n {props_sql}\n)")
create_sql = "\n".join(sql_parts)
self.execute_sql_statement(create_sql)
return True
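    # Illustrative shape of the DDL assembled above (identifiers and paths are examples only):
    #
    #   CREATE TABLE IF NOT EXISTS public.enterprise
    #   (    id STRING NOT NULL,
    #        name STRING)
    #   USING DELTA
    #   LOCATION 'abfss://<container>@<account>.dfs.core.windows.net/<base>/unified_model/DELTA/public.enterprise'
    #   TBLPROPERTIES (
    #     'delta.autoOptimize.optimizeWrite' = 'true')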
def create_timeseries_table(self, columns: List[str], external_location: str):
"""
Create a timeseries table model and all columns will be of type String
Args:
columns: List of columns in the table
external_location: The external location path
Example:
columns = [l1,l2,enterprise]
Returns:
Timeseries Table model
"""
table_columns = [
Column('timestamp', BigInteger, nullable=False),
Column('dt_timestamp', DateTime, nullable=False),
Column('dt_date', Date, nullable=False),
Column('dt_hour', Integer, nullable=False),
Column('value', String, nullable=False),
Column('value_type', String, nullable=False, default='float'),
Column("c3", String, nullable=False)
]
default_columns = ["c1", "c5", "Q", "T", "D", "P", "A", "B", *columns]
table_columns.extend([Column(col_name, String, nullable=True) for col_name in default_columns])
partition_columns = ['dt_date', 'dt_hour', 'c3']
table_properties = {
"parquet.compression": "snappy", # Fast decompression for frequent queries
"parquet.page.size": "524288", # 512KB - better time-range filtering
"parquet.block.size": "268435456", # 256MB - efficient sequential reads
"serialization.format": "1" # Support for arrays/complex types
}
table_obj = Table(
"timeseries_data",
MetaData(),
*table_columns,
schema=self.schema
)
self.create_external_table_from_structure(
table=table_obj,
external_location=external_location,
partition_columns=partition_columns,
table_properties=table_properties
)
return table_obj
import logging
from ut_security_util.security_tools.auth_util import HTTPXRequestHandler
from scripts.utils.httpx_util import HTTPXRequestUtil
class DatabricksJobManager:
def __init__(self, databricks_host: str, access_token: str):
"""
Initialize Databricks job manager
Args:
databricks_host: Your Databricks workspace URL
access_token: Personal access token or service principal token
"""
self.host = databricks_host if "https://" in databricks_host else f"https://{databricks_host}"
self.headers = {
'Authorization': f'Bearer {access_token}',
'Content-Type': 'application/json'
}
def create_job(self, job_config: dict):
"""
Create a new job in Databricks
Args:
job_config: Dictionary containing job configuration
"""
url = f"{self.host}/api/2.1/jobs/create"
response = HTTPXRequestUtil(url).post(headers=self.headers, json=job_config)
if response.status_code == 200:
job_id = response.json()['job_id']
logging.info(f"Job created successfully with ID: {job_id}")
return job_id
else:
logging.error(f"Error creating job: {response.status_code} - {response.text}")
return None
def run_job(self, job_id: str, parameters=None):
"""
Run a job with optional parameters
Args:
job_id: The ID of the job to run
parameters: Dictionary of parameters to pass to the job
"""
url = f"{self.host}/api/2.1/jobs/run-now"
payload = {"job_id": job_id}
if parameters:
payload["notebook_params"] = parameters
response = HTTPXRequestHandler(url).post(headers=self.headers, json=payload)
if response.status_code == 200:
run_id = response.json()['run_id']
logging.info(f"Job run started with ID: {run_id}")
return run_id
else:
logging.error(f"Error running job: {response.status_code} - {response.text}")
return None
def get_run_status(self, run_id):
"""
Get the status of a job run
Args:
run_id: The ID of the job run
"""
url = f"{self.host}/api/2.1/jobs/runs/get"
params = {"run_id": run_id}
        response = HTTPXRequestHandler(url).get(headers=self.headers, params=params)
if response.status_code == 200:
return response.json()
else:
logging.error(f"Error getting run status: {response.status_code} - {response.text}")
return None
@staticmethod
def create_job_config_for_serverless(notebook_path: str, job_name: str):
"""
Create job configuration for a parameterized notebook
Args:
notebook_path: Path to the notebook in Databricks workspace
job_name: Name of the job
"""
return {
"name": job_name,
"tasks": [
{
"task_key": "table_update_task",
"notebook_task": {
"notebook_path": notebook_path,
"base_parameters": {
"input_message": "default_value"
}
},
"timeout_seconds": 3600
}
],
"max_concurrent_runs": 10,
"tags": {
"purpose": "metadata_ingestion",
"compute_type": "serverless"
}
}
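# A minimal usage sketch (illustrative; host, token and notebook path are placeholders):
#
#   manager = DatabricksJobManager(databricks_host="adb-<workspace-id>.azuredatabricks.net",
#                                  access_token="<personal-access-token>")
#   job_config = manager.create_job_config_for_serverless(
#       notebook_path="/Users/<user>/metadata_ingestion_notebook",
#       job_name="metadata_ingestion_job")
#   job_id = manager.create_job(job_config=job_config)
#   if job_id:
#       run_id = manager.run_job(job_id, parameters={"input_message": "{}"})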
import base64
import logging
from scripts.utils.httpx_util import HTTPXRequestUtil
class NotebookManager:
def __init__(self, databricks_host, access_token):
"""
Initialize Databricks connection
Args:
databricks_host: Your Databricks workspace URL (e.g., 'https://your-workspace.cloud.databricks.com')
access_token: Personal access token or service principal token
"""
self.host = databricks_host if "https://" in databricks_host else f"https://{databricks_host}"
self.headers = {
'Authorization': f'Bearer {access_token}',
'Content-Type': 'application/json'
}
def create_notebook(self, notebook_path, notebook_code: str, language='PYTHON', overwrite=True):
"""
Create a notebook in Databricks workspace
Args:
notebook_path: Path where notebook will be created (e.g., '/Users/your-email/my-notebook')
notebook_code: Python code as string
language: Notebook language ('PYTHON', 'SQL', 'SCALA', 'R')
overwrite: Whether to overwrite existing notebook
"""
url = f"{self.host}/api/2.0/workspace/import"
# Encode the notebook content in base64
encoded_content = base64.b64encode(notebook_code.encode('utf-8')).decode('utf-8')
payload = {
"path": notebook_path,
"format": "SOURCE",
"language": language,
"content": encoded_content,
"overwrite": overwrite
}
response = HTTPXRequestUtil(url=url).post(json=payload, headers=self.headers)
if response.status_code == 200:
logging.info(f"Notebook created successfully at: {notebook_path}")
return True
else:
logging.error(f"Error creating notebook: {response.status_code} - {response.text}")
return False
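# Example usage (illustrative; workspace host, token and paths are placeholders):
#
#   nb_manager = NotebookManager(databricks_host="adb-<workspace-id>.azuredatabricks.net",
#                                access_token="<personal-access-token>")
#   with open("scripts/constants/notebooks/metadata_ingestion.txt") as f:
#       nb_manager.create_notebook("/Users/<user>/metadata_ingestion_notebook", f.read(), overwrite=True)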
from ut_sql_utils.asyncio import SQLSessionManager
from scripts.db.redis.project_details import project_details_db
session_manager = SQLSessionManager(project_details_db)
from ut_redis_connector import RedisConnector
from scripts.config import RedisConfig
redis_connector = RedisConnector(redis_uri=RedisConfig.REDIS_URI)
import orjson
from scripts.config import RedisConfig
from scripts.db.redis import redis_connector
databricks_details_db = redis_connector.connect(db=RedisConfig.REDIS_DATABRICKS_DB, decode_responses=True)
from typing import Annotated, Any, Dict
import orjson
from fastapi import Query
from pydantic import ValidationInfo
from ut_dev_utils import ILensErrors
from ut_sql_utils.config import PostgresConfig
from scripts.config import RedisConfig
from scripts.db.redis import redis_connector
graphql_details_db = redis_connector.connect(db=RedisConfig.REDIS_GRAPHQL_DB, decode_responses=True)
def get_models(
model: str,
info: ValidationInfo,
schema: Annotated[str, Query()] = PostgresConfig.PG_DEFAULT_SCHEMA,
no_error: bool = False,
*__args__,
**__kwargs__,
) -> str | None:
"""
Fetches the GraphQL details of the specific project and checks if a specific model/table is present in the schema.
Args:
model (str): The name of the table/model to check.
info (ValidationInfo): Validation information containing project details.
schema (str, optional): The schema name, default is PostgresConfig.PG_DEFAULT_SCHEMA.
__args__ (tuple): Additional positional arguments (not used).
__kwargs__ (dict): Additional keyword arguments (not used).
Returns:
str: The model name if found in the schema.
Raises:
ILensErrors: If the model is not found in the schema or if project data is unavailable.
"""
tables_data = graphql_details_db.hget(info.data["project_id"], "schema_mapper")
if tables_data is None:
raise ILensErrors(f"No GraphQL schema data found for project {info.data['project_id']}")
tables: Dict[str, Any] = orjson.loads(tables_data) or {}
if (
"overall_mapping" in tables
and isinstance(tables["overall_mapping"], dict)
and tables["overall_mapping"].get(schema)
):
tables = tables["overall_mapping"][schema]
if model not in tables:
if no_error:
return None
raise ILensErrors(f"Model '{model}' not found in schema '{schema}'")
return model
def fetch_cache_info_for_graph_model(project_id: str, expected_key: str = "hierarchy"):
data = graphql_details_db.hgetall(project_id)
if not data:
return {}
    return orjson.loads(data.get(expected_key, "{}"))
import orjson
from scripts.config import RedisConfig
from scripts.db.redis import redis_connector
project_details_db = redis_connector.connect(db=RedisConfig.REDIS_PROJECT_TAGS_DB, decode_responses=True)
def get_project_time_zone(project_id: str):
"""
Function to get project time zone
Uses redis project details cache db (db18) and fetches the time zone
"""
project_details = project_details_db.get(project_id)
if project_details:
project_details = orjson.loads(project_details)
return project_details.get("time_zone")
else:
return "UTC"
def fetch_level_details(project_id: str, keys: bool = False, raw: bool = False) -> dict | list:
"""
Function to fetch level details from project details
Uses redis project details cache db (db18) and fetches the level details
Returns a dictionary of level details or a list of keys if requested.
"""
project_details = project_details_db.get(project_id)
if not project_details:
return {}
project_details = orjson.loads(project_details)
if raw:
return project_details
if keys:
return list(project_details.get("counter_levels", {}).keys())
return project_details.get("counter_levels", {})
def fetch_ast_level(project_id: str):
"""
Function to fetch ast level from project details
Uses redis project details cache db (db18) and fetches the ast level
Returns the ast level
"""
project_details = project_details_db.get(project_id)
if not project_details:
return {}
project_details = orjson.loads(project_details)
counter_levels = project_details.get("counter_levels", {})
for k, v in counter_levels.items():
if v == "ast":
return k
return ""
def fetch_asset_level(project_id: str) -> str:
project_details = project_details_db.get(project_id)
if not project_details:
return ""
project_details = orjson.loads(project_details)
counter_levels = project_details.get("counter_levels", {})
asset_level = (
counter_levels.get("asset", counter_levels.get("equipment")) if isinstance(counter_levels, dict) else None
)
if asset_level:
return asset_level
key_list = project_details.get("key_list")
if key_list and isinstance(key_list, list):
return key_list[-1]
return ""
def fetch_asset_level_with_mapping(project_id: str) -> tuple[str, str]:
project_details = project_details_db.get(project_id)
if not project_details:
return "", ""
project_details = orjson.loads(project_details)
counter_levels = project_details.get("counter_levels", {})
if not counter_levels:
return "", ""
swapped_dict = {v: k for k, v in counter_levels.items()}
return swapped_dict.get("ast", ""), "ast"
def project_template_keys(project_id: str, levels=False):
val = project_details_db.get(project_id)
if val is None:
        raise ValueError(f"Unknown project: project ID {project_id} not found")
val = orjson.loads(val)
return val.get("levels", {}) if levels else list(val.get("levels", {}).keys())
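# The helpers above assume the cached project details (Redis db REDIS_PROJECT_TAGS_DB) are a JSON
# document containing at least "time_zone", "counter_levels", "levels" and "key_list".
# An illustrative (assumed) shape:
#
#   {"time_zone": "UTC",
#    "counter_levels": {"l1": "enterprise", "l4": "ast"},
#    "levels": {"l1": "enterprise", "l2": "site"},
#    "key_list": ["l1", "l2", "l3", "l4"]}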
from faststream.confluent import KafkaBroker
from scripts.config import KafkaConfig
broker = KafkaBroker(f'{KafkaConfig.KAFKA_HOST}:{KafkaConfig.KAFKA_PORT}', client_id="model_creator_agent")
from ut_sql_utils.asyncio.declarative_utils import DeclarativeUtilsFactory
from scripts.core.handlers.model_creator_handler import ModelCreatorHandler
from scripts.db.psql import session_manager
from scripts.schemas import ModelCreatorSchema
class ModelCreatorAgent:
def __init__(self):
...
@staticmethod
async def model_trigger_agent(message: ModelCreatorSchema):
declarative_utils = await DeclarativeUtilsFactory.get_declarative_utils(
raw_database="unified_model",
project_id=message.meta.project_id,
session_manager=session_manager,
schema=message.schema,
)
model_cal_obj = ModelCreatorHandler(message=message, declarative_utils=declarative_utils)
await model_cal_obj.create_models_in_unity_catalog()
from typing import Optional
from pydantic import BaseModel
from ut_security_util import MetaInfoSchema
from scripts.config import DatabricksConfig
class ModelCreatorSchema(BaseModel):
meta: MetaInfoSchema
schema: Optional[str] = DatabricksConfig.DATABRICKS_PUBLIC_SCHEMA_NAME
databricks_host: str = DatabricksConfig.DATABRICKS_HOST
databricks_port: int = DatabricksConfig.DATABRICKS_PORT
databricks_access_token: str = DatabricksConfig.DATABRICKS_ACCESS_TOKEN
databricks_http_path: str = DatabricksConfig.DATABRICKS_HTTP_PATH
databricks_user_email: str = "aniket.dhale@ilenscloud.onmicrosoft.com"
import logging as logger
from typing import Optional
from sqlalchemy import create_engine, text
from ut_dev_utils import get_db_name
from scripts.config import DatabricksConfig
class DatabricksSQLUtility:
def __init__(self, catalog_name: str, project_id: str):
"""
Initialize Databricks connectivity setup
Args:
catalog_name: Name of the catalog to create
project_id: Project ID
"""
# self.catalog_name = get_db_name(project_id=project_id, database=catalog_name)
self.catalog_name = catalog_name
# self.catalog_name = catalog_name
self.engine = None
def connect_to_databricks(self):
"""
Connect to Databricks using sqlalchemy-databricks
"""
try:
# Build connection string for sqlalchemy-databricks
self.engine = create_engine(
DatabricksConfig.DATABRICKS_URI,
pool_pre_ping=True,
pool_recycle=3600,
echo=False
)
# Test connection
with self.engine.connect() as conn:
result = conn.execute(text("SELECT current_user() as user, current_catalog() as catalog"))
user_info = result.fetchone()
logger.info(f"Connected as user: {user_info[0]}, current catalog: {user_info[1]}")
logger.info("Successfully connected to Databricks")
return True
except Exception as e:
logger.error(f"Failed to connect to Databricks: {str(e)}")
return False
def __del__(self):
if self.engine:
self.engine.dispose()
def create_catalog(self, managed_location: Optional[str] = None, comment: Optional[str] = None,
properties: Optional[dict] = None):
"""
Create a new catalog in Unity Catalog
Args:
managed_location: Optional managed storage location path
comment: Optional description
properties: Optional catalog properties
Returns:
Name of the created catalog
"""
try:
ddl = f"CREATE CATALOG IF NOT EXISTS `{self.catalog_name}`"
if managed_location:
ddl += f"\nMANAGED LOCATION '{managed_location}'"
if comment:
ddl += f"\nCOMMENT '{comment}'"
if properties:
props = ", ".join([f"'{k}' = '{v}'" for k, v in properties.items()])
ddl += f"\nWITH DBPROPERTIES ({props})"
self.execute_sql_statement(ddl)
logger.info(f"Catalog '{self.catalog_name}' created successfully")
use_catalog = f"USE CATALOG `{self.catalog_name}`"
self.execute_sql_statement(use_catalog)
logger.info(f"Switched to catalog '{self.catalog_name}'")
return self.catalog_name
except Exception as e:
logger.error(f"Failed to create catalog '{self.catalog_name}': {str(e)}")
raise
def create_schema(self, schema_name: str, managed_location: Optional[str] = None, comment: Optional[str] = None,
properties: Optional[dict] = None):
"""
Create a new schema within a catalog
Args:
schema_name: Name of the schema to create
managed_location: Optional managed storage location path
comment: Optional description
properties: Optional schema properties
Returns:
Name of the created schema
"""
try:
full_schema_name = f"`{self.catalog_name}`.`{schema_name}`"
ddl = f"CREATE SCHEMA IF NOT EXISTS {full_schema_name}"
if managed_location:
ddl += f"\nMANAGED LOCATION '{managed_location}'"
if comment:
ddl += f"\nCOMMENT '{comment}'"
if properties:
props = ", ".join([f"'{k}' = '{v}'" for k, v in properties.items()])
ddl += f"\nWITH DBPROPERTIES ({props})"
self.execute_sql_statement(ddl)
logger.info(f"Schema '{self.catalog_name}.{schema_name}' created successfully")
return full_schema_name
except Exception as e:
logger.error(f"Failed to create schema '{self.catalog_name}.{schema_name}': {str(e)}")
raise
def create_external_location(
self,
location_name: str,
storage_path: str,
credential_name: str,
comment: Optional[str] = None
) -> str:
"""
Create an external location in Unity Catalog
Args:
location_name: Name for the external location
storage_path: Cloud storage path (e.g., 'abfss://container@storageaccount.dfs.core.windows.net/path')
credential_name: Name of the storage credential to use
comment: Optional description
Returns:
Name of the created external location
"""
# Build the CREATE EXTERNAL LOCATION statement
ddl = f"CREATE EXTERNAL LOCATION IF NOT EXISTS `{location_name}`"
ddl += f"\nURL '{storage_path}'"
ddl += f"\nWITH (CREDENTIAL `{credential_name}`)"
if comment:
ddl += f"\nCOMMENT '{comment}'"
try:
self.execute_sql_statement(ddl)
logger.info(f"External location '{location_name}' created successfully")
return location_name
except Exception as e:
logger.error(f"Failed to create external location '{location_name}': {str(e)}")
raise
def execute_sql_statement(self, query: str):
try:
with self.engine.connect() as conn:
conn.execute(text(query))
conn.commit()
logger.info(f"Query '{query}' executed successfully")
except Exception as e:
logger.error(f"Failed to execute query '{query}': {str(e)}")
raise
import logging
from urllib.parse import urlparse
import httpx
class HTTPXRequestUtil:
def __init__(self, url, time_out=None) -> None:
self.time_out = time_out
self.url = url
self.verify = False
self.verify_request()
@property
def get_timeout(self):
return self.time_out
def delete(self, path="", params=None, **kwargs) -> httpx.Response:
url = self.get_url(path)
logging.info(url)
with httpx.Client(verify=self.verify, headers=kwargs.get("headers"), cookies=kwargs.get("cookies")) as client:
response: httpx.Response = client.delete(url=url, params=params)
return response
def put(self, path="", json=None, data=None, **kwargs) -> httpx.Response:
url = self.get_url(path)
logging.info(url)
with httpx.Client(verify=self.verify, headers=kwargs.get("headers"), cookies=kwargs.get("cookies")) as client:
response: httpx.Response = client.put(url=url, data=data, json=json)
return response
def post(self, path="", json=None, data=None, **kwargs) -> httpx.Response:
"""
:param path:
:param json:
:param data:
:param kwargs:
:return:
"""
url = self.get_url(path)
logging.info(url)
with httpx.Client(verify=self.verify, headers=kwargs.get("headers"), cookies=kwargs.get("cookies")) as client:
response: httpx.Response = client.post(url=url, data=data, json=json)
return response
def get(self, path="", params=None, **kwargs) -> httpx.Response:
url = self.get_url(path)
logging.info(url)
with httpx.Client(verify=self.verify, headers=kwargs.get("headers"), cookies=kwargs.get("cookies")) as client:
response: httpx.Response = client.get(url=url, params=params)
return response
def get_url(self, path):
if path:
return f"{self.url.rstrip('/')}/{path.lstrip('/').rstrip('/')}"
return self.url.rstrip("/")
def verify_request(self):
if self.url_scheme(self.url) == "https":
self.verify = True
return self.verify
@staticmethod
def url_scheme(url):
return urlparse(url).scheme
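# Example usage (illustrative endpoint and token):
#
#   resp = HTTPXRequestUtil("https://example.com/api").post(
#       path="jobs/create", json={"name": "demo"},
#       headers={"Authorization": "Bearer <token>"})
#   resp.raise_for_status()
#
# TLS verification is enabled only when the base URL uses https (see verify_request above).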
import logging
from typing import Any, Type, Optional, Dict, Union, Tuple
from databricks.sqlalchemy import TIMESTAMP
from sqlalchemy import (
Integer, BigInteger, SmallInteger, String, Text, Boolean,
Date, DateTime, Time, Numeric, Float, DECIMAL, CHAR, VARCHAR,
LargeBinary, JSON, Column, PrimaryKeyConstraint, UniqueConstraint, ForeignKeyConstraint, Double, Index, Table
)
from sqlalchemy.dialects import postgresql
from sqlalchemy.orm import class_mapper, mapped_column, Mapped
from sqlalchemy.types import UserDefinedType
class DatabricksArrayType(UserDefinedType):
"""
Custom SQLAlchemy type for Databricks ARRAY<type> columns.
This type generates the correct DDL for Databricks array columns
and handles Python list serialization/deserialization.
"""
def __init__(self, element_type: str = "STRING"):
"""
Initialize Databricks array type.
Args:
element_type: Databricks element type (STRING, INT, BOOLEAN, etc.)
"""
self.element_type = element_type.upper()
def get_col_spec(self):
"""Return the DDL type specification for Databricks."""
return f"ARRAY<{self.element_type}>"
def bind_processor(self, dialect):
"""Process values before sending to database."""
def process(value):
if value is not None:
# Ensure it's a list
if not isinstance(value, list):
return [value]
return value
return value
return process
def result_processor(self, dialect, coltype):
"""Process values when retrieving from database."""
def process(value):
# Databricks returns arrays as lists already
return value
return process
def __repr__(self):
return f"DatabricksArrayType({self.element_type})"
class TypeMapper:
# Core type mapping dictionary
POSTGRES_TO_DATABRICKS_MAPPING = {
postgresql.SMALLINT: SmallInteger,
postgresql.INTEGER: Integer,
postgresql.BIGINT: BigInteger,
postgresql.NUMERIC: Numeric,
postgresql.REAL: Float,
postgresql.DOUBLE_PRECISION: Float,
SmallInteger: SmallInteger,
Integer: Integer,
BigInteger: BigInteger,
Numeric: Numeric,
DECIMAL: DECIMAL,
Float: Float,
Double: Float,
# String types
postgresql.CHAR: CHAR,
postgresql.VARCHAR: VARCHAR,
postgresql.TEXT: String,
CHAR: CHAR,
VARCHAR: VARCHAR,
Text: String,
String: String,
# DateTime types
postgresql.DATE: Date,
postgresql.TIME: String,
postgresql.TIMESTAMP: TIMESTAMP,
postgresql.INTERVAL: String,
Date: Date,
Time: String,
DateTime: DateTime,
# Boolean
postgresql.BOOLEAN: Boolean,
Boolean: Boolean,
# Binary
postgresql.BYTEA: LargeBinary,
LargeBinary: LargeBinary,
# JSON
postgresql.JSON: String,
postgresql.JSONB: String,
JSON: String,
# Array
postgresql.ARRAY: String,
# PostgreSQL specific
postgresql.UUID: String,
postgresql.INET: String,
postgresql.CIDR: String,
postgresql.MACADDR: String,
postgresql.TSVECTOR: String,
postgresql.TSQUERY: String,
postgresql.BIT: String,
postgresql.HSTORE: String,
}
SQL_TO_DATABRICKS_MAPPING = {
'VARCHAR': 'STRING',
'INTEGER': 'INT',
'BIGINT': 'BIGINT', # Keep as is
'FLOAT': 'DOUBLE',
'BOOLEAN': 'BOOLEAN', # Keep as is
'TIMESTAMP': 'TIMESTAMP', # Keep as is
'DATETIME': 'TIMESTAMP', # Change this mapping
'TEXT': 'STRING',
# Arrays and complex types are already correct, no replacement needed
}
@classmethod
def get_databricks_type(cls, sql_type: Any) -> Any:
"""
Convert SQLAlchemy type to Databricks equivalent.
Args:
sql_type: SQLAlchemy type instance
Returns:
Databricks SQLAlchemy type instance
"""
if sql_type is None:
return String()
base_type = type(sql_type)
# Handle special cases first
if base_type == postgresql.ARRAY or 'ARRAY' in str(sql_type):
return cls._convert_array_type_fallback(sql_type)
# Get the mapped type
if base_type in cls.POSTGRES_TO_DATABRICKS_MAPPING:
return cls.POSTGRES_TO_DATABRICKS_MAPPING[base_type]()
logging.info(f'Defaulting to String() for type: {type(sql_type)}')
return String()
@classmethod
def _get_databricks_array_element_type(cls, postgres_element_type: Any) -> str:
"""
Map PostgreSQL array element type to Databricks array element type string.
Args:
postgres_element_type: PostgreSQL element type
Returns:
Databricks type string for array elements
"""
element_type_name = type(postgres_element_type).__name__.upper()
# Map PostgreSQL types to Databricks array element types
if any(t in element_type_name for t in ['VARCHAR', 'TEXT', 'STRING', 'CHAR']):
return "STRING"
        elif 'BIGINT' in element_type_name:
            return "BIGINT"
        elif 'SMALLINT' in element_type_name:
            return "SMALLINT"
        elif 'INT' in element_type_name:
            return "INT"
elif any(t in element_type_name for t in ['BOOLEAN', 'BOOL']):
return "BOOLEAN"
elif any(t in element_type_name for t in ['FLOAT', 'REAL', 'DOUBLE']):
return "DOUBLE"
elif any(t in element_type_name for t in ['NUMERIC', 'DECIMAL']):
return "DECIMAL"
elif 'DATE' in element_type_name:
return "DATE"
elif 'TIMESTAMP' in element_type_name:
return "TIMESTAMP"
else:
return "STRING" # Default fallback
@classmethod
def _convert_array_type_fallback(cls, array_type: postgresql.ARRAY):
"""
Fallback method for array conversion when DatabricksArray is not available.
Args:
array_type: PostgreSQL ARRAY type instance
Returns:
String representation of Databricks array type (e.g., "ARRAY<INT>")
"""
# Get the item type from the array
item_type = array_type.item_type
databricks_element_type = cls._get_databricks_array_element_type(item_type)
return DatabricksArrayType(databricks_element_type)
@classmethod
def extract_columns_without_constraints(cls, table: Table) -> str:
"""
Extract column definitions without any constraints from SQLAlchemy Table
"""
column_definitions = []
for column in table.columns:
# Get column name
col_name = column.name
# Get column type and convert using mapping
col_type = str(column.type).upper()
databricks_type = cls.SQL_TO_DATABRICKS_MAPPING.get(col_type, col_type)
# Handle nullable (skip primary key constraint)
nullable_clause = "" if column.nullable else " NOT NULL"
default_clause = ""
if column.default is not None:
default_value = column.default.arg if hasattr(column.default, 'arg') else column.default
if isinstance(default_value, str):
default_clause = f" DEFAULT '{default_value}'"
elif isinstance(default_value, bool):
default_clause = f" DEFAULT {str(default_value).upper()}"
else:
default_clause = f" DEFAULT {default_value}"
# Build column definition (without primary key constraint)
col_def = f"\t{col_name} {databricks_type}{nullable_clause}{default_clause}"
column_definitions.append(col_def)
return ",\n".join(column_definitions)
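# Illustrative output of TypeMapper.extract_columns_without_constraints for a table with a
# non-nullable VARCHAR "id", a BIGINT "created_on" and a BOOLEAN "active" column (types are
# rewritten through SQL_TO_DATABRICKS_MAPPING, constraints other than NOT NULL are dropped):
#
#       id STRING NOT NULL,
#       created_on BIGINT,
#       active BOOLEAN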
class ColumnConverter:
"""Handles individual column conversion from PostgreSQL to Databricks."""
def __init__(self, type_mapper: Optional[TypeMapper] = None):
"""Initialize with optional custom type mapper."""
self.type_mapper = type_mapper or TypeMapper()
def extract_column_info(self, model_class) -> Dict[str, Dict[str, Any]]:
"""
Extract column information from a SQLAlchemy model class.
Args:
model_class: SQLAlchemy model class
Returns:
Dictionary with column information
"""
columns_info = {}
# # Check for modern SQLAlchemy with annotations and mapped_column
if hasattr(model_class, '__annotations__'):
columns_info.update(self.extract_from_annotations(model_class))
# Fallback: Try mapper approach for traditional models
if not columns_info:
try:
mapper = class_mapper(model_class)
for column_name, column in mapper.columns.items():
if column_name not in columns_info:
columns_info[column_name] = self.extract_column_properties(column)
except Exception as e:
logging.error(f"Failed to extract column info using mapper: {e}")
# Final fallback: inspect class attributes directly
columns_info.update(self._extract_from_class_attributes(model_class))
return columns_info
@staticmethod
def extract_column_properties(column: Any) -> Dict[str, Any]:
"""Extract properties from a column object."""
return {
'type': getattr(column, 'type', None),
'primary_key': getattr(column, 'primary_key', False),
'nullable': getattr(column, 'nullable', True),
'default': getattr(column, 'default', None),
'server_default': getattr(column, 'server_default', None),
'uses_mapped_column': False,
}
def _extract_from_class_attributes(self, model_class: type) -> Dict[str, Dict[str, Any]]:
"""Extract column info from class attributes."""
columns_info = {}
for attr_name in dir(model_class):
if attr_name.startswith('_'):
continue
attr = getattr(model_class, attr_name, None)
if attr is None:
continue
# Check for Column objects
if isinstance(attr, Column):
columns_info[attr_name] = self.extract_column_properties(attr)
# Check for mapped_column objects
elif hasattr(attr, 'type') and hasattr(attr, 'nullable'):
columns_info[attr_name] = {
'type': getattr(attr, 'type', None),
'primary_key': getattr(attr, 'primary_key', False),
'nullable': getattr(attr, 'nullable', True),
'default': getattr(attr, 'default', None),
'server_default': getattr(attr, 'server_default', None),
'uses_mapped_column': True,
}
return columns_info
@staticmethod
def extract_from_annotations(model_class: type) -> Dict[str, Dict[str, Any]]:
"""Extract column info from type annotations (modern SQLAlchemy)."""
columns_info = {}
annotations = getattr(model_class, '__annotations__', {})
for attr_name, annotation in annotations.items():
if hasattr(model_class, attr_name):
attr = getattr(model_class, attr_name)
# Check if it's a mapped_column
if hasattr(attr, 'type'):
columns_info[attr_name] = {
'type': attr.type,
'primary_key': getattr(attr, 'primary_key', False),
'nullable': getattr(attr, 'nullable', True),
'default': getattr(attr, 'default', None),
'server_default': getattr(attr, 'server_default', None),
'annotation': annotation,
'uses_mapped_column': True,
}
return columns_info
def convert_column(self, column_info: Dict[str, Any]) -> tuple:
"""
Convert a single column from PostgreSQL to Databricks format.
Args:
column_info: Column information dictionary
Returns:
Tuple of (column_object, annotation_if_any)
"""
# Convert the column type
new_type = self.type_mapper.get_databricks_type(column_info['type'])
# Check if this uses type annotations (modern approach)
if column_info.get('uses_mapped_column', False):
new_column = mapped_column(
new_type,
primary_key=column_info.get('primary_key', False),
nullable=column_info.get('nullable', True),
default=column_info.get('default'),
server_default=column_info.get('server_default'),
)
# Convert annotation
annotation = self.convert_annotation(column_info.get('annotation'), new_type)
return new_column, annotation
else:
# Traditional Column approach
new_column = Column(
new_type,
primary_key=column_info.get('primary_key', False),
nullable=column_info.get('nullable', True),
default=column_info.get('default'),
server_default=column_info.get('server_default'),
)
return new_column, None
@staticmethod
def convert_annotation(annotation: Any, databricks_type: Any = None) -> Any:
"""Convert type annotations for Databricks compatibility."""
from typing import Optional, List
if annotation is None:
return Mapped[Optional[str]]
annotation_str = str(annotation)
# Handle array/list types -> convert to List annotation
if any(keyword in annotation_str for keyword in ['list', 'List']):
if 'Optional' in annotation_str or 'Union' in annotation_str:
return Mapped[Optional[List[str]]] # Default to List[str]
else:
return Mapped[List[str]]
# Handle JSON/JSONB/dict types -> convert to string
if any(keyword in annotation_str for keyword in ['dict', 'Dict', 'json', 'Json']):
if 'Optional' in annotation_str or 'Union' in annotation_str:
return Mapped[Optional[str]]
else:
return Mapped[str]
# Handle basic types based on the databricks type
if databricks_type:
type_str = str(type(databricks_type).__name__).lower()
if 'array' in type_str:
# For array types, use List annotation
if 'Optional' in annotation_str:
return Mapped[Optional[List[str]]] # Could be more specific based on element type
return Mapped[List[str]]
elif 'integer' in type_str or 'biginteger' in type_str or 'smallinteger' in type_str:
if 'Optional' in annotation_str:
return Mapped[Optional[int]]
return Mapped[int]
elif 'boolean' in type_str:
if 'Optional' in annotation_str:
return Mapped[Optional[bool]]
return Mapped[bool]
elif 'float' in type_str or 'numeric' in type_str:
if 'Optional' in annotation_str:
return Mapped[Optional[float]]
return Mapped[float]
elif 'datetime' in type_str:
from datetime import datetime
if 'Optional' in annotation_str:
return Mapped[Optional[datetime]]
return Mapped[datetime]
# Default to string
if 'Optional' in annotation_str or 'Union' in annotation_str:
return Mapped[Optional[str]]
return Mapped[str]
class SchemaProcessor:
"""Handles schema and table arguments processing for model conversion."""
def __init__(self):
self.type_mapper = TypeMapper()
@staticmethod
def process_table_args(
original_table_args: Any,
new_schema: Optional[str] = None
) -> Union[Tuple, Dict, None]:
"""
Process table arguments, handling constraints and schema conversion.
Args:
original_table_args: Original __table_args__ from PostgreSQL model
new_schema: Optional new schema for Unity Catalog
Returns:
Processed table arguments
"""
if not original_table_args:
if new_schema:
return {'schema': new_schema}
return None
new_table_args = []
table_kwargs = {}
# Handle tuple/list format: (constraint1, constraint2, {...})
if isinstance(original_table_args, (tuple, list)):
for arg in original_table_args:
if isinstance(arg, dict):
# Process dictionary part
processed_kwargs = SchemaProcessor._process_table_kwargs(arg, new_schema)
table_kwargs.update(processed_kwargs)
elif isinstance(arg, (Index, ForeignKeyConstraint)):
continue
elif isinstance(arg, (PrimaryKeyConstraint, UniqueConstraint)):
# Keep constraints (though they may need adjustment)
new_table_args.append(arg)
else:
# Keep other arguments
new_table_args.append(arg)
# Handle dictionary format: {'schema': 'public', 'extend_existing': True}
elif isinstance(original_table_args, dict):
table_kwargs = SchemaProcessor._process_table_kwargs(original_table_args, new_schema)
# Add new schema if specified and not already set
if new_schema is not None and 'schema' not in table_kwargs:
table_kwargs['schema'] = new_schema
# Construct result
if new_table_args and table_kwargs:
new_table_args.append(table_kwargs)
return tuple(new_table_args)
elif table_kwargs:
return table_kwargs
elif new_table_args:
return tuple(new_table_args)
return None
@staticmethod
def _process_table_kwargs(kwargs: Dict[str, Any], new_schema: Optional[str]) -> Dict[str, Any]:
"""Process table keyword arguments."""
processed = {}
for key, value in kwargs.items():
if key == 'schema':
# Use new_schema if provided, otherwise keep original unless it's 'public'
if new_schema is not None:
processed[key] = new_schema
elif value != 'public':
processed[key] = value
# Skip 'public' schema (default)
else:
# Keep other kwargs (like extend_existing)
processed[key] = value
return processed
class ModelConverter:
"""Main class for converting PostgreSQL models to Databricks models."""
    def __init__(self):
        """Initialize with the default type mapper and column converter."""
self.type_mapper = TypeMapper()
self.column_converter = ColumnConverter(self.type_mapper)
def convert_model(self,
postgres_model_class: Type,
base_class: Type,
new_table_name: Optional[str] = None,
new_schema: Optional[str] = None,
) -> Type:
"""
Convert a PostgreSQL SQLAlchemy model to a Databricks SQLAlchemy model.
Args:
postgres_model_class: The PostgreSQL SQLAlchemy model class
new_table_name: Optional new table name
new_schema: Optional new schema (Unity Catalog schema)
            base_class: Declarative base class the converted model is attached to
Returns:
New Databricks SQLAlchemy model class
"""
        # Get table information
original_table_name = getattr(postgres_model_class, '__tablename__', 'unknown_table')
table_name = new_table_name or original_table_name
table_name = f'{table_name}'
schema_processor = SchemaProcessor()
# Create new model attributes
new_attrs = {
'__tablename__': table_name,
'__module__': postgres_model_class.__module__,
}
# Process table arguments
if hasattr(postgres_model_class, '__table_args__'):
processed_table_args = schema_processor.process_table_args(
postgres_model_class.__table_args__,
new_schema
)
if processed_table_args:
new_attrs['__table_args__'] = processed_table_args
elif new_schema:
new_attrs['__table_args__'] = {'schema': new_schema}
# Extract and convert columns
columns_info = self.column_converter.extract_column_info(postgres_model_class)
annotations = {}
for column_name, column_info in columns_info.items():
new_column, annotation = self.column_converter.convert_column(column_info)
new_attrs[column_name] = new_column
if annotation is not None:
annotations[column_name] = annotation
# Add annotations if any
if annotations:
new_attrs['__annotations__'] = annotations
# Create the new model class
new_class_name = f"{postgres_model_class.__name__}Databricks"
new_model_class = type(new_class_name, (base_class,), new_attrs)
return new_model_class
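# A minimal usage sketch (illustrative; model, schema and catalog names are placeholders):
#
#   from sqlalchemy import MetaData
#   from sqlalchemy.orm import declarative_base
#
#   PgBase = declarative_base()
#
#   class Enterprise(PgBase):
#       __tablename__ = "enterprise"
#       id = Column(String, primary_key=True)
#       created_on = Column(BigInteger)
#
#   DbxBase = declarative_base(metadata=MetaData(schema="unified_model.public"))
#   EnterpriseDbx = ModelConverter().convert_model(Enterprise, base_class=DbxBase, new_schema="public")
#   # EnterpriseDbx.__table__ can then be handed to
#   # DataBricksSQLLayer.create_external_table_from_structure(...)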