Project: model-managament-databricks

Commit 7845d0b6, authored Aug 04, 2025 by harshavardhan.c
feat: Functionality added for delete and add jobs in databricks based on the catalog.
Parent: c741f979

Showing 26 changed files with 826 additions and 306 deletions (+826 / -306)
.env                                                      +1    -1
.gitignore                                                +134  -0
.pre-commit-config.yaml                                   +41   -0
README.md                                                 +0    -1
agent_subscribers.py                                      +19   -5
app.py                                                    +2    -2
requirements.txt                                          +1    -1
scripts/config/__init__.py                                +15   -5
scripts/constants/__init__.py                             +15   -1
scripts/constants/notebooks/metadata_deletion.txt         +95   -0
scripts/constants/notebooks/metadata_ingestion.txt        +7    -7
scripts/constants/notebooks/timeseries_ingestion.txt      +3    -4
scripts/core/handlers/instance_handler.py                 +29   -13
scripts/core/handlers/model_creator_handler.py            +135  -70
scripts/db/databricks/__init__.py                         +36   -23
scripts/db/databricks/job_manager.py                      +30   -16
scripts/db/databricks/notebook_manager.py                 +17   -7
scripts/db/redis/databricks_details.py                    +3    -3
scripts/db/redis/graphql.py                               +6    -2
scripts/db/redis/project_details.py                       +11   -4
scripts/engines/agents/__init__.py                        +4    -1
scripts/engines/agents/model_creator_agent.py             +8    -5
scripts/schemas/__init__.py                               +9    -5
scripts/utils/databricks_utils.py                         +34   -15
scripts/utils/httpx_util.py                               +20   -4
scripts/utils/model_convertor_utils.py                    +151  -111
.env
.gitignore
# Created by .ignore support plugin (hsz.mobi)
### Python template
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
/.idea
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
.vscode/settings.json
.vscode
data
.env
assets
__pycache__
.pre-commit-config.yaml
(new file, mode 100644)

repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.6.0
    hooks:
      - id: end-of-file-fixer
      - id: trailing-whitespace
      - id: requirements-txt-fixer
  - repo: https://github.com/asottile/pyupgrade
    rev: v3.16.0
    hooks:
      - id: pyupgrade
        args:
          - --py3-plus
          - --keep-runtime-typing
  - repo: https://github.com/charliermarsh/ruff-pre-commit
    rev: v0.4.8
    hooks:
      - id: ruff
        args:
          - --fix
  - repo: https://github.com/pycqa/isort
    rev: 5.13.2
    hooks:
      - id: isort
        name: isort (python)
      - id: isort
        name: isort (cython)
        types: [cython]
      - id: isort
        name: isort (pyi)
        types: [pyi]
  - repo: https://github.com/psf/black
    rev: 24.4.2
    hooks:
      - id: black
        # It is recommended to specify the latest version of Python
        # supported by your project here, or alternatively use
        # pre-commit's default_language_version, see
        # https://pre-commit.com/#top_level-default_language_version
        language_version: python3.11
README.md
# model-managament-databricks
agent_subscribers.py
...
@@ -6,19 +6,33 @@ from scripts.config import KafkaConfig
from scripts.engines.agents.model_creator_agent import ModelCreatorAgent
from scripts.schemas import ModelCreatorSchema, ModelInstanceSchema

broker = KafkaBroker(
    f"{KafkaConfig.KAFKA_HOST}:{KafkaConfig.KAFKA_PORT}",
    client_id="model_creator_agent",
)


@broker.subscriber(
    KafkaConfig.KAFKA_MODEL_CREATION_TOPIC,
    group_id="databricks_model_creator_agent",
    max_workers=2,
)
async def consume_stream_for_processing_dependencies(message: dict):
    try:
        await ModelCreatorAgent.model_creator_agent(
            message=ModelCreatorSchema(meta=message)
        )
        return True
    except Exception as e:
        logging.error(f"Exception occurred while creating model in Databricks: {e}")
        return False


@broker.subscriber(
    KafkaConfig.KAFKA_MODEL_INSTANCE_TOPIC,
    group_id="databricks_instance_agent",
    max_workers=2,
)
async def consume_stream_for_processing_instances(message: dict):
    try:
        await ModelCreatorAgent.model_instance_agent(ModelInstanceSchema(**message))
...
app.py
...
@@ -6,11 +6,11 @@ import sys
from dotenv import load_dotenv

load_dotenv()

from faststream import FastStream
from ut_dev_utils import configure_logger

from agent_subscribers import broker

configure_logger()

# Create FastStream app
...
requirements.txt
scripts/config/__init__.py
...
@@ -70,13 +70,16 @@ class _DatabricksConfig(BaseSettings):
    DATABRICKS_PUBLIC_SCHEMA_NAME: str = Field(default="public")
    DATABRICKS_ANALYTICAL_SCHEMA_NAME: str = Field(default="analytical")
    DATABRICKS_STORAGE_FORMAT: str = Field(default="PARQUET")
    DATABRICKS_STORAGE_PATH: str = Field(
        default="abfss://unity-catalog-storage@dbstoragenzxfhpgsipt5a.dfs.core.windows.net/416418955412087"
    )

    @model_validator(mode="before")
    def prepare_databricks_uri(cls, values):
        values["DATABRICKS_URI"] = (
            f"databricks://token:{values['DATABRICKS_ACCESS_TOKEN']}@{values['DATABRICKS_HOST']}:{values['DATABRICKS_PORT']}"
            f"?http_path={values['DATABRICKS_HTTP_PATH']}"
        )
        return values
...
@@ -88,4 +91,11 @@ PathToStorage = _PathToStorage()
KafkaConfig = _KafkaConfig()
DatabricksConfig = _DatabricksConfig()

__all__ = [
    "Services",
    "RedisConfig",
    "ExternalServices",
    "PathToStorage",
    "KafkaConfig",
    "DatabricksConfig",
]
scripts/constants/__init__.py

class DatabricksConstants:
    METADATA_INGESTION_JOB_NAME = "metadata_ingestion_job"
    METADATA_DELETION_JOB_NAME = "metadata_deletion_job"
    METADATA_INGESTION_NOTEBOOK_NAME = "metadata_ingestion_notebook"
    METADATA_DELETION_NOTEBOOK_NAME = "metadata_deletion_notebook"
    TIMESERIES_INGESTION_NOTEBOOK_NAME = "timeseries_ingestion_notebook"


class NotebookConstants:
    METADATA_INGESTION_NOTEBOOK_PATH = (
        "scripts/constants/notebooks/metadata_ingestion.txt"
    )
    METADATA_DELETION_NOTEBOOK_PATH = (
        "scripts/constants/notebooks/metadata_deletion.txt"
    )
    TIMESERIES_INGESTION_NOTEBOOK_PATH = (
        "scripts/constants/notebooks/timeseries_ingestion.txt"
    )
scripts/constants/notebooks/metadata_deletion.txt
(new file, mode 100644)

# Databricks notebook source
from delta.tables import DeltaTable
from pyspark.sql.functions import expr
import json

# COMMAND ----------

dbutils.widgets.text("input_message", "", "Input Message JSON")
dbutils.widgets.text("id_column", "id", "ID Column Name")

input_message = dbutils.widgets.get("input_message")
delete_column = dbutils.widgets.get("id_column")

# COMMAND ----------

def extract_table_info(input_message_str: str, delete_column: str = "id"):
    """
    Extract table name and data from input message

    Args:
        input_message_str (str): JSON string containing the message
        delete_column (str): Column whose values identify the rows to delete

    Returns:
        dict: Extracted information
    """
    try:
        message_data = json.loads(input_message_str)
        # Extract table name from data.type
        table_name = message_data['table_properties']['table_name']  # 'enterprise'
        project_id = message_data['project_id']  # 'project_787'
        delete_values = [msg[delete_column] for msg in message_data['data'] if delete_column in msg]
        table_properties = message_data['table_properties']  # Fetch table properties

        print("Extracted Info:")
        print(f"Table Name: {table_name}")
        print(f"Project ID: {project_id}")
        print(f"Deleting rows: {delete_values}")
        print(f"Table Prop Keys: {list(table_properties.keys())}")

        return {
            'table_name': table_name,
            'project_id': project_id,
            delete_column: delete_values,
            'raw_message': message_data,
            'table_properties': table_properties
        }
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON in input_message: {str(e)}")
    except KeyError as e:
        raise ValueError(f"Missing required field in input_message: {str(e)}")

# COMMAND ----------

def delete_records_by_ids(table_name, ids, id_column="id"):
    """
    Delete records from an external table (Delta or Parquet) using a list of IDs

    Args:
        table_name (str): Full table name (catalog.schema.table)
        ids (list): List of IDs to delete
        id_column (str): Column name containing IDs (default: "id")

    Returns:
        bool: True if successful, False otherwise
    """
    try:
        if not ids:
            print("No IDs provided")
            return False

        # Format IDs for SQL IN clause
        if isinstance(ids[0], str):
            id_values = "(" + ",".join([f"'{id}'" for id in ids]) + ")"
        else:
            id_values = "(" + ",".join([str(id) for id in ids]) + ")"

        # Use Delta table DELETE operation
        delta_table = DeltaTable.forName(spark, table_name)
        condition = f"{id_column} IN {id_values}"
        delta_table.delete(condition=expr(condition))

        print(f"Successfully deleted {len(ids)} records from {table_name}")
        return True
    except Exception as e:
        print(f"Error: {str(e)}")
        return False

# COMMAND ----------

table_info = extract_table_info(input_message, delete_column=delete_column)

# COMMAND ----------

result = delete_records_by_ids(table_name=table_info['table_name'], ids=table_info[delete_column], id_column=delete_column)
print(f"Deletion completed: {result}")
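For reference, the shape of the input_message this notebook expects can be read off the lookups in extract_table_info above. The sketch below is a hypothetical payload; the project id, table name, path, and row ids are illustrative placeholders, not values from the repository.

import json

# Hypothetical widget payload for the deletion notebook (all values are placeholders).
example_message = {
    "project_id": "project_787",
    "table_properties": {
        "table_name": "catalog_project_787.public.enterprise",
        "table_path": "<storage-path>/catalog_project_787/DELTA/public.enterprise",
    },
    # One entry per row to delete; the key must match the "id_column" widget.
    "data": [{"id": "row-001"}, {"id": "row-002"}],
}
# The triggering job passes this string as the "input_message" base parameter.
input_message_json = json.dumps(example_message)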
scripts/constants/notebooks/metadata_ingestion.txt
scripts/constants/notebooks/timeseries_ingestion.txt
...
@@ -4,7 +4,7 @@ from pyspark.sql.functions import *
from pyspark.sql.types import *
import json

spark = SparkSession.builder.appName("Streaming Timeseries Pipeline").getOrCreate()
spark.sparkContext.setLogLevel("WARN")

# COMMAND ----------
...
@@ -146,4 +146,3 @@ transformed_df.writeStream \
# COMMAND ----------
scripts/core/handlers/instance_handler.py
...
@@ -3,6 +3,7 @@ import json
from ut_dev_utils import get_db_name

from scripts.config import DatabricksConfig
from scripts.constants import DatabricksConstants
from scripts.db.databricks.job_manager import DatabricksJobManager
from scripts.db.redis.databricks_details import databricks_details_db
from scripts.schemas import ModelInstanceSchema
...
@@ -12,29 +13,44 @@ class ModelInstanceHandler:
    def __init__(self, project_id: str, payload: ModelInstanceSchema):
        self.project_id = project_id
        self.payload = payload
        self.catalog_name = get_db_name(
            project_id=project_id, database=DatabricksConfig.DATABRICKS_CATALOG_NAME
        )
        self.job_manager = DatabricksJobManager(
            databricks_host=payload.databricks_host,
            access_token=payload.databricks_access_token,
        )

    async def upload_instances_to_unity_catalog(self):
        if self.payload.action_type == "delete":
            job_id = databricks_details_db.hget(
                self.project_id, DatabricksConstants.METADATA_DELETION_JOB_NAME
            )
        else:
            job_id = databricks_details_db.hget(
                self.project_id, DatabricksConstants.METADATA_INGESTION_JOB_NAME
            )
        if not job_id:
            raise ValueError(
                f"No job id found for {self.payload.action_type}, skipping upload to unity catalog"
            )
        run_id = self.job_manager.run_job(
            job_id=job_id,
            parameters={"input_message": json.dumps(self.get_job_trigger_payload())},
        )
        if not run_id:
            raise ValueError(
                "Failed to run metadata ingestion job, skipping upload to unity catalog"
            )

    def get_job_trigger_payload(self):
        table_name = self.payload.node_type
        schema_table = f"{DatabricksConfig.DATABRICKS_PUBLIC_SCHEMA_NAME}.{table_name}"
        return {
            "table_properties": {
                "table_name": f"{self.catalog_name}.{schema_table}",
                "table_path": f"{self.payload.databricks_storage_path}/{self.catalog_name}/DELTA/{schema_table}",
            },
            "project_id": self.project_id,
            "data": self.payload.data,
        }
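As a rough illustration of the new routing, a Kafka message like the sketch below (field values are hypothetical) would satisfy ModelInstanceSchema and take the delete branch above, so the metadata deletion job is run instead of the ingestion job.

# Hypothetical instance message for the delete path (values are placeholders).
delete_message = {
    "project_id": "project_787",
    "node_type": "enterprise",                       # becomes the target table name
    "action_type": "delete",                         # anything other than "delete" falls back to the ingestion job
    "data": [{"id": "row-001"}, {"id": "row-002"}],  # rows whose ids should be removed
}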
scripts/core/handlers/model_creator_handler.py
...
@@ -5,7 +5,7 @@ from sqlalchemy.orm import declarative_base
from ut_sql_utils.asyncio.declarative_utils import DeclarativeUtils

from scripts.config import DatabricksConfig
from scripts.constants import DatabricksConstants, NotebookConstants
from scripts.db.databricks import DataBricksSQLLayer
from scripts.db.databricks.job_manager import DatabricksJobManager
from scripts.db.databricks.notebook_manager import NotebookManager
...
@@ -16,23 +16,25 @@ from scripts.utils.model_convertor_utils import ModelConverter
class ModelCreatorHandler:
    def __init__(
        self, message: ModelCreatorSchema, declarative_utils: DeclarativeUtils
    ):
        self.declarative_utils = declarative_utils
        self.meta = message.meta
        self.message = message
        self.model_convertor = ModelConverter()
        self.job_manager = DatabricksJobManager(
            databricks_host=message.databricks_host,
            access_token=message.databricks_access_token,
        )
        self.notebook_manager = NotebookManager(
            databricks_host=message.databricks_host,
            access_token=message.databricks_access_token,
        )
        self.databricks_sql_obj = DataBricksSQLLayer(
            catalog_name=DatabricksConfig.DATABRICKS_CATALOG_NAME,
            project_id=self.meta.project_id,
            schema=message.schema,
        )
        self.external_location = self.message.databricks_storage_path
...
@@ -47,32 +49,37 @@ class ModelCreatorHandler:
        overall_tables = self.get_overall_tables()
        project_levels = project_template_keys(self.meta.project_id, levels=True)
        base = self.create_schema_base(
            schema_name=f"{self.databricks_sql_obj.catalog_name}.{self.message.schema}"
        )
        try:
            # self.databricks_sql_obj.connect_to_databricks()
            _ = self.setup_dependencies_for_unity_catalog()
            table_properties = self.fetch_table_properties()
            for table in overall_tables:
                table_class = self.declarative_utils.get_declarative_class(table)
                if not table_class:
                    logging.error(f"Table class not found for table: {table}")
                    return False
                new_model = self.model_convertor.convert_model(
                    table_class,
                    base_class=base,
                    new_schema=self.message.schema,
                )
                self.databricks_sql_obj.create_external_table_from_structure(
                    table=new_model.__table__,
                    file_format="DELTA",
                    external_location=self.external_location,
                    table_properties=table_properties,
                )
            ts_external_table = self.databricks_sql_obj.create_timeseries_table(
                columns=project_levels, external_location=self.external_location
            )
            self.setup_notepads_and_jobs(
                timeseries_table_path=ts_external_table, project_levels=project_levels
            )
            return True
        except Exception as e:
            logging.error(f"Error occurred while creating models in Unity Catalog: {e}")
...
@@ -95,20 +102,25 @@ class ModelCreatorHandler:
            analytical (bool): Flag to indicate if the setup is for analytical or not
        """
        logging.info(
            f"Setting up catalog '{DatabricksConfig.DATABRICKS_CATALOG_NAME}' for project '{self.meta.project_id}'"
        )
        self.databricks_sql_obj.connect_to_databricks()
        # Create catalog
        catalog_success = self.databricks_sql_obj.create_catalog(
            managed_location=f"{self.external_location}/{self.databricks_sql_obj.catalog_name}",
        )
        if not catalog_success:
            return False
        # Create schema
        schema_success = self.databricks_sql_obj.create_schema(
            DatabricksConfig.DATABRICKS_PUBLIC_SCHEMA_NAME
        )
        if not schema_success:
            return False
        if analytical:
            schema_success = self.databricks_sql_obj.create_schema(
                DatabricksConfig.DATABRICKS_ANALYTICAL_SCHEMA_NAME
            )
            if not schema_success:
                return False
        return True
...
@@ -120,59 +132,112 @@ class ModelCreatorHandler:
            project_levels: List of project levels
        """
        logging.info("Setting up notepads and jobs")
        meta_ingestion_notebook_path = f"/Users/{self.message.databricks_user_email}/{self.meta.project_id}_{DatabricksConstants.METADATA_INGESTION_NOTEBOOK_NAME}"
        meta_deletion_notebook_path = f"/Users/{self.message.databricks_user_email}/{self.meta.project_id}_{DatabricksConstants.METADATA_DELETION_NOTEBOOK_NAME}"
        timeseries_notebook_path = f"/Users/{self.message.databricks_user_email}/{self.meta.project_id}_{DatabricksConstants.TIMESERIES_INGESTION_NOTEBOOK_NAME}"

        # Setting up of Metadata Ingestion Notebook
        existing_job_id = databricks_details_db.hget(
            self.meta.project_id, DatabricksConstants.METADATA_INGESTION_JOB_NAME
        )
        if not existing_job_id:
            self.create_notebook(
                notebook_path=meta_ingestion_notebook_path,
                source_notebook_path=NotebookConstants.METADATA_INGESTION_NOTEBOOK_PATH,
            )
            ingestion_job_id = self.create_job(
                job_name=f"{self.meta.project_id}_{DatabricksConstants.METADATA_INGESTION_JOB_NAME}",
                notebook_path=meta_ingestion_notebook_path,
            )
            databricks_details_db.hset(
                self.meta.project_id,
                DatabricksConstants.METADATA_INGESTION_JOB_NAME,
                ingestion_job_id,
            )

        existing_job_id = databricks_details_db.hget(
            self.meta.project_id, DatabricksConstants.METADATA_DELETION_JOB_NAME
        )
        if not existing_job_id:
            # Setting up of Metadata Deletion Notebook
            self.create_notebook(
                notebook_path=meta_deletion_notebook_path,
                source_notebook_path=NotebookConstants.METADATA_DELETION_NOTEBOOK_PATH,
            )
            deletion_job_id = self.create_job(
                job_name=f"{self.meta.project_id}_{DatabricksConstants.METADATA_DELETION_JOB_NAME}",
                notebook_path=meta_deletion_notebook_path,
            )
            databricks_details_db.hset(
                self.meta.project_id,
                DatabricksConstants.METADATA_DELETION_JOB_NAME,
                deletion_job_id,
            )

        # Setting up of Timeseries Ingestion Notebook
        replace_mapping = {
            "{{timeseries_table_path}}": f'"{timeseries_table_path}"',
            "{{project_levels}}": str(len(project_levels) - 1),
            "{{event_hub_connection_string}}": f'"{self.meta.project_id}"',
        }
        self.create_notebook(
            notebook_path=timeseries_notebook_path,
            source_notebook_path=NotebookConstants.TIMESERIES_INGESTION_NOTEBOOK_PATH,
            replace_mapping=replace_mapping,
        )

    @staticmethod
    def fetch_table_properties(file_format: str = "DELTA"):
        if file_format.lower() == "delta":
            return {
                # Performance optimization (Essential)
                "delta.autoOptimize.optimizeWrite": "true",
                "delta.autoOptimize.autoCompact": "true",
                "delta.targetFileSize": "134217728",  # 128MB
                "delta.enableChangeDataFeed": "true",  # If you need CDC
                # Checkpoint optimization (Performance boost)
                "delta.checkpoint.writeStatsAsStruct": "true",
                "delta.checkpoint.writeStatsAsJson": "false",
                # Note: Retention properties removed - using defaults:
                # delta.deletedFileRetentionDuration = 7 days (default)
                # delta.logRetentionDuration = 30 days (default)
            }
        elif file_format.lower() == "parquet":
            return {
                "parquet.compression": "snappy",
                "parquet.page.size": "1048576",  # 1MB - standard for mixed queries
                "parquet.block.size": "134217728",  # 128MB - balanced performance
                "serialization.format": "1",
            }
        else:
            return {}

    @staticmethod
    def read_data_from_file(note_path: str):
        with open(note_path) as f:
            notebook_code = f.read()
        return notebook_code

    def create_notebook(
        self,
        notebook_path: str,
        source_notebook_path: str,
        replace_mapping: dict = None,
    ):
        logging.info(f"Creating notebook {notebook_path}")
        notebook_code = self.read_data_from_file(source_notebook_path)
        if replace_mapping is not None:
            for key, value in replace_mapping.items():
                notebook_code = notebook_code.replace(key, value)
        self.notebook_manager.create_notebook(
            notebook_path=notebook_path, notebook_code=notebook_code, overwrite=True
        )
        return True

    def create_job(self, job_name: str, notebook_path: str):
        logging.info(f"Creating job {job_name}")
        job_id = self.job_manager.create_job(
            job_config=self.job_manager.create_job_config_for_serverless(
                job_name=job_name,
                notebook_path=notebook_path,
            )
        )
        return job_id
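A small sketch of how the replace_mapping is applied by create_notebook above; the template text and substituted values here are made up for illustration, only the substitution loop mirrors the handler.

# Illustrative placeholder substitution (template text and values are hypothetical).
template = "table_path = {{timeseries_table_path}}\nlevels = {{project_levels}}"
replace_mapping = {
    "{{timeseries_table_path}}": '"<storage-path>/catalog_project_787/PARQUET/public.timeseries_data"',
    "{{project_levels}}": "4",
}
for key, value in replace_mapping.items():
    template = template.replace(key, value)
print(template)  # the rendered notebook source that gets uploaded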
scripts/db/databricks/__init__.py

from typing import Dict, List

from sqlalchemy import (
    BigInteger,
    Column,
    Date,
    DateTime,
    Integer,
    MetaData,
    String,
    Table,
)

from scripts.utils.databricks_utils import DatabricksSQLUtility
from scripts.utils.model_convertor_utils import TypeMapper
...
@@ -11,11 +20,14 @@ class DataBricksSQLLayer(DatabricksSQLUtility):
        super().__init__(catalog_name, project_id)
        self.schema = schema

    def create_external_table_from_structure(
        self,
        table: Table,
        external_location: str,
        file_format: str = "PARQUET",
        table_properties: Dict[str, str] = None,
        partition_columns: list = None,
    ) -> str:
        """
        Create an external table from a model class.
...
@@ -31,12 +43,14 @@ class DataBricksSQLLayer(DatabricksSQLUtility):
        """
        schema_table = f"{table.schema}.{table.name}" if table.schema else table.name
        columns_sql = TypeMapper().extract_columns_without_constraints(table)
        external_location = (
            f"{external_location}/{self.catalog_name}/{file_format}/{schema_table}"
        )
        sql_parts = [
            f"CREATE TABLE IF NOT EXISTS {schema_table}",
            f"({columns_sql})",
            f"USING {file_format}",
            f"LOCATION '{external_location}'",
        ]
        if partition_columns:
            partition_clause = ", ".join(partition_columns)
...
@@ -64,33 +78,32 @@ class DataBricksSQLLayer(DatabricksSQLUtility):
        """
        table_columns = [
            Column("timestamp", BigInteger, nullable=False),
            Column("dt_timestamp", DateTime, nullable=False),
            Column("dt_date", Date, nullable=False),
            Column("dt_hour", Integer, nullable=False),
            Column("value", String, nullable=False),
            Column("value_type", String, nullable=False, default="float"),
            Column("c3", String, nullable=False),
        ]
        default_columns = ["c1", "c5", "Q", "T", "D", "P", "A", "B", *columns]
        table_columns.extend(
            [Column(col_name, String, nullable=True) for col_name in default_columns]
        )
        partition_columns = ["dt_date", "dt_hour", "c3"]
        table_properties = {
            "parquet.compression": "snappy",  # Fast decompression for frequent queries
            "parquet.page.size": "524288",  # 512KB - better time-range filtering
            "parquet.block.size": "268435456",  # 256MB - efficient sequential reads
            "serialization.format": "1",  # Support for arrays/complex types
        }
        table_obj = Table(
            "timeseries_data", MetaData(), *table_columns, schema=self.schema
        )
        self.create_external_table_from_structure(
            table=table_obj,
            external_location=external_location,
            partition_columns=partition_columns,
            table_properties=table_properties,
        )
        return external_location
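For orientation, the statement assembled from sql_parts for the timeseries table would look roughly like the string below. This is a hypothetical rendering: the catalog name, storage path, and column list are placeholders, the real column SQL comes from TypeMapper, and the partitioning and table-properties clauses appended later are omitted here.

# Rough, hypothetical shape of the generated DDL (not the exact output of the layer).
example_ddl = (
    "CREATE TABLE IF NOT EXISTS public.timeseries_data "
    "(timestamp BIGINT NOT NULL, dt_timestamp TIMESTAMP NOT NULL, dt_date DATE NOT NULL, "
    "dt_hour INT NOT NULL, value STRING NOT NULL, value_type STRING NOT NULL, c3 STRING NOT NULL, ...) "
    "USING PARQUET "
    "LOCATION '<storage-path>/catalog_project_787/PARQUET/public.timeseries_data'"
)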
scripts/db/databricks/job_manager.py
...
@@ -14,10 +14,14 @@ class DatabricksJobManager:
            databricks_host: Your Databricks workspace URL
            access_token: Personal access token or service principal token
        """
        self.host = (
            databricks_host
            if "https://" in databricks_host
            else f"https://{databricks_host}"
        )
        self.headers = {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
        }

    def create_job(self, job_config: dict):
...
@@ -32,11 +36,13 @@ class DatabricksJobManager:
        response = HTTPXRequestUtil(url).post(headers=self.headers, json=job_config)
        if response.status_code == 200:
            job_id = response.json()["job_id"]
            logging.info(f"Job created successfully with ID: {job_id}")
            return job_id
        else:
            logging.error(
                f"Error creating job: {response.status_code} - {response.text}"
            )
            return None

    def run_job(self, job_id: str, parameters=None):
...
@@ -57,11 +63,13 @@ class DatabricksJobManager:
        response = HTTPXRequestUtil(url).post(headers=self.headers, json=payload)
        if response.status_code == 200:
            run_id = response.json()["run_id"]
            logging.info(f"Job run started with ID: {run_id}")
            return run_id
        else:
            logging.error(
                f"Error running job: {response.status_code} - {response.text}"
            )
            return None

    def get_run_status(self, run_id):
...
@@ -73,12 +81,16 @@ class DatabricksJobManager:
        url = f"{self.host}/api/2.1/jobs/runs/get"
        params = {"run_id": run_id}
        response = HTTPXRequestHandler(url).get(
            url, headers=self.headers, params=params
        )
        if response.status_code == 200:
            return response.json()
        else:
            logging.error(
                f"Error getting run status: {response.status_code} - {response.text}"
            )
            return None

    @staticmethod
...
@@ -98,16 +110,18 @@ class DatabricksJobManager:
                    "task_key": "table_update_task",
                    "notebook_task": {
                        "notebook_path": notebook_path,
                        "base_parameters": {"input_message": "default_value"},
                    },
                    "timeout_seconds": 3600,
                }
            ],
            "max_concurrent_runs": 10,
            "tags": {
                "purpose": (
                    "metadata_ingestion"
                    if "ingestion" in job_name
                    else "metadata_deletion"
                ),
                "compute_type": "serverless",
            },
        }
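A minimal usage sketch of this manager, mirroring how ModelCreatorHandler and ModelInstanceHandler call it above; the workspace host, token, job name, and notebook path are placeholders, not real values.

from scripts.db.databricks.job_manager import DatabricksJobManager

# Hypothetical usage of DatabricksJobManager (host, token and paths are placeholders).
manager = DatabricksJobManager(
    databricks_host="adb-1234567890123456.7.azuredatabricks.net",
    access_token="<personal-access-token>",
)
config = manager.create_job_config_for_serverless(
    job_name="project_787_metadata_deletion_job",
    notebook_path="/Users/user@example.com/project_787_metadata_deletion_notebook",
)
job_id = manager.create_job(job_config=config)
run_id = manager.run_job(job_id=job_id, parameters={"input_message": "{...}"})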
scripts/db/databricks/notebook_manager.py
...
@@ -13,13 +13,19 @@ class NotebookManager:
            databricks_host: Your Databricks workspace URL (e.g., 'https://your-workspace.cloud.databricks.com')
            access_token: Personal access token or service principal token
        """
        self.host = (
            databricks_host
            if "https://" in databricks_host
            else f"https://{databricks_host}"
        )
        self.headers = {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
        }

    def create_notebook(
        self, notebook_path, notebook_code: str, language="PYTHON", overwrite=True
    ):
        """
        Create a notebook in Databricks workspace
...
@@ -31,18 +37,22 @@ class NotebookManager:
        """
        url = f"{self.host}/api/2.0/workspace/import"
        # Encode the notebook content in base64
        encoded_content = base64.b64encode(notebook_code.encode("utf-8")).decode("utf-8")
        payload = {
            "path": notebook_path,
            "format": "SOURCE",
            "language": language,
            "content": encoded_content,
            "overwrite": overwrite,
        }
        response = HTTPXRequestUtil(url=url).post(json=payload, headers=self.headers)
        if response.status_code == 200:
            logging.info(f"Notebook created successfully at: {notebook_path}")
            return True
        else:
            logging.error(
                f"Error creating notebook: {response.status_code} - {response.text}"
            )
            return False
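A short usage sketch of NotebookManager, following the same calling pattern as ModelCreatorHandler.create_notebook above; the workspace host, token, path, and notebook source are placeholders.

from scripts.db.databricks.notebook_manager import NotebookManager

# Hypothetical usage (workspace URL, email and code are placeholders).
nm = NotebookManager(
    databricks_host="adb-1234567890123456.7.azuredatabricks.net",
    access_token="<personal-access-token>",
)
nm.create_notebook(
    notebook_path="/Users/user@example.com/project_787_metadata_deletion_notebook",
    notebook_code="# Databricks notebook source\nprint('hello')\n",
    overwrite=True,
)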
scripts/db/redis/databricks_details.py

import orjson

from scripts.config import RedisConfig
from scripts.db.redis import redis_connector

databricks_details_db = redis_connector.connect(
    db=RedisConfig.REDIS_DATABRICKS_DB, decode_responses=True
)
scripts/db/redis/graphql.py
...
@@ -9,7 +9,9 @@ from ut_sql_utils.config import PostgresConfig
from scripts.config import RedisConfig
from scripts.db.redis import redis_connector

graphql_details_db = redis_connector.connect(
    db=RedisConfig.REDIS_GRAPHQL_DB, decode_responses=True
)


def get_models(
...
@@ -38,7 +40,9 @@ def get_models(
    """
    tables_data = graphql_details_db.hget(info.data["project_id"], "schema_mapper")
    if tables_data is None:
        raise ILensErrors(
            f"No GraphQL schema data found for project {info.data['project_id']}"
        )
    tables: Dict[str, Any] = orjson.loads(tables_data) or {}
    if (
...
scripts/db/redis/project_details.py
...
@@ -3,7 +3,9 @@ import orjson
from scripts.config import RedisConfig
from scripts.db.redis import redis_connector

project_details_db = redis_connector.connect(
    db=RedisConfig.REDIS_PROJECT_TAGS_DB, decode_responses=True
)


def get_project_time_zone(project_id: str):
...
@@ -19,7 +21,9 @@ def get_project_time_zone(project_id: str):
    return "UTC"


def fetch_level_details(
    project_id: str, keys: bool = False, raw: bool = False
) -> dict | list:
    """
    Function to fetch level details from project details
    Uses redis project details cache db (db18) and fetches the level details
...
@@ -60,7 +64,9 @@ def fetch_asset_level(project_id: str) -> str:
    project_details = orjson.loads(project_details)
    counter_levels = project_details.get("counter_levels", {})
    asset_level = (
        counter_levels.get("asset", counter_levels.get("equipment"))
        if isinstance(counter_levels, dict)
        else None
    )
    if asset_level:
        return asset_level
...
@@ -81,6 +87,7 @@ def fetch_asset_level_with_mapping(project_id: str) -> tuple[str, str]:
    swapped_dict = {v: k for k, v in counter_levels.items()}
    return swapped_dict.get("ast", ""), "ast"


def project_template_keys(project_id: str, levels=False):
    val = project_details_db.get(project_id)
    if val is None:
...
scripts/engines/agents/__init__.py
...
@@ -2,4 +2,7 @@ from faststream.confluent import KafkaBroker
from scripts.config import KafkaConfig

broker = KafkaBroker(
    f"{KafkaConfig.KAFKA_HOST}:{KafkaConfig.KAFKA_PORT}",
    client_id="model_creator_agent",
)
scripts/engines/agents/model_creator_agent.py
...
@@ -7,8 +7,7 @@ from scripts.schemas import ModelCreatorSchema, ModelInstanceSchema
class ModelCreatorAgent:
    def __init__(self): ...

    @staticmethod
    async def model_creator_agent(message: ModelCreatorSchema):
...
@@ -18,10 +17,14 @@ class ModelCreatorAgent:
            session_manager=session_manager,
            schema=message.schema,
        )
        model_cal_obj = ModelCreatorHandler(
            message=message, declarative_utils=declarative_utils
        )
        await model_cal_obj.create_models_in_unity_catalog()

    @staticmethod
    async def model_instance_agent(message: ModelInstanceSchema):
        model_instance_obj = ModelInstanceHandler(
            project_id=message.project_id, payload=message
        )
        await model_instance_obj.upload_instances_to_unity_catalog()
scripts/schemas/__init__.py

from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel, Field, model_validator
from ut_security_util import MetaInfoSchema

from scripts.config import DatabricksConfig
...
@@ -20,7 +20,11 @@ class ModelCreatorSchema(BaseModel):
class ModelInstanceSchema(BaseModel):
    data: Union[Dict[str, Any], List[Dict[str, Any]]]
    project_id: str
    action_type: str = "save"
    node_type: str
    sql_schema: Optional[str] = Field(
        default=DatabricksConfig.DATABRICKS_PUBLIC_SCHEMA_NAME, alias="schema"
    )
    databricks_host: str = DatabricksConfig.DATABRICKS_HOST
    databricks_port: int = DatabricksConfig.DATABRICKS_PORT
    databricks_access_token: str = DatabricksConfig.DATABRICKS_ACCESS_TOKEN
...
@@ -30,6 +34,6 @@ class ModelInstanceSchema(BaseModel):
    @model_validator(mode="before")
    def validate_data(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        if "data" in values and isinstance(values["data"], dict):
            values["data"] = [values["data"]]
        return values
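A quick sketch of the before-validator in action, assuming the Databricks settings in scripts.config resolve from the environment; the field values below are hypothetical.

from scripts.schemas import ModelInstanceSchema

# A single record passed as a dict is normalized into a list by validate_data.
msg = ModelInstanceSchema(
    project_id="project_787",
    node_type="enterprise",
    data={"id": "row-001"},
)
assert msg.data == [{"id": "row-001"}]
assert msg.action_type == "save"  # default when the producer omits it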
scripts/utils/databricks_utils.py
...
@@ -29,13 +29,17 @@ class DatabricksSQLUtility:
                DatabricksConfig.DATABRICKS_URI,
                pool_pre_ping=True,
                pool_recycle=3600,
                echo=False,
            )
            # Test connection
            with self.engine.connect() as conn:
                result = conn.execute(
                    text("SELECT current_user() as user, current_catalog() as catalog")
                )
                user_info = result.fetchone()
                logger.info(
                    f"Connected as user: {user_info[0]}, current catalog: {user_info[1]}"
                )
            logger.info("Successfully connected to Databricks")
            return True
        except Exception as e:
...
@@ -46,8 +50,12 @@ class DatabricksSQLUtility:
        if self.engine:
            self.engine.dispose()

    def create_catalog(
        self,
        managed_location: Optional[str] = None,
        comment: Optional[str] = None,
        properties: Optional[dict] = None,
    ):
        """
        Create a new catalog in Unity Catalog
        Args:
...
@@ -77,8 +85,13 @@ class DatabricksSQLUtility:
            logger.error(f"Failed to create catalog '{self.catalog_name}': {str(e)}")
            raise

    def create_schema(
        self,
        schema_name: str,
        managed_location: Optional[str] = None,
        comment: Optional[str] = None,
        properties: Optional[dict] = None,
    ):
        """
        Create a new schema within a catalog
        Args:
...
@@ -103,10 +116,14 @@ class DatabricksSQLUtility:
                ddl += f"\nWITH DBPROPERTIES ({props})"
            self.execute_sql_statement(ddl)
            logger.info(
                f"Schema '{self.catalog_name}.{schema_name}' created successfully"
            )
            return full_schema_name
        except Exception as e:
            logger.error(
                f"Failed to create schema '{self.catalog_name}.{schema_name}': {str(e)}"
            )
            raise

    def create_external_location(
...
@@ -114,7 +131,7 @@ class DatabricksSQLUtility:
        location_name: str,
        storage_path: str,
        credential_name: str,
        comment: Optional[str] = None,
    ) -> str:
        """
        Create an external location in Unity Catalog
...
@@ -138,7 +155,9 @@ class DatabricksSQLUtility:
            logger.info(f"External location '{location_name}' created successfully")
            return location_name
        except Exception as e:
            logger.error(
                f"Failed to create external location '{location_name}': {str(e)}"
            )
            raise

    def execute_sql_statement(self, query: str):
...
scripts/utils/httpx_util.py
...
@@ -18,7 +18,11 @@ class HTTPXRequestUtil:
    def delete(self, path="", params=None, **kwargs) -> httpx.Response:
        url = self.get_url(path)
        logging.info(url)
        with httpx.Client(
            verify=self.verify,
            headers=kwargs.get("headers"),
            cookies=kwargs.get("cookies"),
        ) as client:
            response: httpx.Response = client.delete(url=url, params=params)
            return response
...
@@ -27,7 +31,11 @@ class HTTPXRequestUtil:
        url = self.get_url(path)
        logging.info(url)
        with httpx.Client(
            verify=self.verify,
            headers=kwargs.get("headers"),
            cookies=kwargs.get("cookies"),
        ) as client:
            response: httpx.Response = client.put(url=url, data=data, json=json)
            return response
...
@@ -42,7 +50,11 @@ class HTTPXRequestUtil:
        """
        url = self.get_url(path)
        logging.info(url)
        with httpx.Client(
            verify=self.verify,
            headers=kwargs.get("headers"),
            cookies=kwargs.get("cookies"),
        ) as client:
            response: httpx.Response = client.post(url=url, data=data, json=json)
            return response
...
@@ -52,7 +64,11 @@ class HTTPXRequestUtil:
        url = self.get_url(path)
        logging.info(url)
        with httpx.Client(
            verify=self.verify,
            headers=kwargs.get("headers"),
            cookies=kwargs.get("cookies"),
        ) as client:
            response: httpx.Response = client.get(url=url, params=params)
            return response
...
scripts/utils/model_convertor_utils.py
View file @
7845d0b6
import
logging
from
typing
import
Any
,
Type
,
Optional
,
Dict
,
Union
,
Tuple
from
typing
import
Any
,
Dict
,
Optional
,
Tuple
,
Type
,
Union
from
databricks.sqlalchemy
import
TIMESTAMP
from
sqlalchemy
import
(
Integer
,
BigInteger
,
SmallInteger
,
String
,
Text
,
Boolean
,
Date
,
DateTime
,
Time
,
Numeric
,
Float
,
DECIMAL
,
CHAR
,
VARCHAR
,
LargeBinary
,
JSON
,
Column
,
PrimaryKeyConstraint
,
UniqueConstraint
,
ForeignKeyConstraint
,
Double
,
Index
,
Table
CHAR
,
DECIMAL
,
JSON
,
VARCHAR
,
BigInteger
,
Boolean
,
Column
,
Date
,
DateTime
,
Double
,
Float
,
ForeignKeyConstraint
,
Index
,
Integer
,
LargeBinary
,
Numeric
,
PrimaryKeyConstraint
,
SmallInteger
,
String
,
Table
,
Text
,
Time
,
UniqueConstraint
,
)
from
sqlalchemy.dialects
import
postgresql
from
sqlalchemy.orm
import
class_mapper
,
mapped_column
,
Mapped
from
sqlalchemy.orm
import
Mapped
,
class_mapper
,
mapped_column
from
sqlalchemy.types
import
UserDefinedType
...
...
@@ -83,7 +103,6 @@ class TypeMapper:
VARCHAR
:
VARCHAR
,
Text
:
String
,
String
:
String
,
# DateTime types
postgresql
.
DATE
:
Date
,
postgresql
.
TIME
:
String
,
...
...
@@ -92,23 +111,18 @@ class TypeMapper:
Date
:
Date
,
Time
:
String
,
DateTime
:
DateTime
,
# Boolean
postgresql
.
BOOLEAN
:
Boolean
,
Boolean
:
Boolean
,
# Binary
postgresql
.
BYTEA
:
LargeBinary
,
LargeBinary
:
LargeBinary
,
# JSON
postgresql
.
JSON
:
String
,
postgresql
.
JSONB
:
String
,
JSON
:
String
,
# Array
postgresql
.
ARRAY
:
String
,
# PostgreSQL specific
postgresql
.
UUID
:
String
,
postgresql
.
INET
:
String
,
...
...
@@ -121,14 +135,14 @@ class TypeMapper:
}
SQL_TO_DATABRICKS_MAPPING
=
{
'VARCHAR'
:
'STRING'
,
'INTEGER'
:
'INT'
,
'BIGINT'
:
'BIGINT'
,
# Keep as is
'FLOAT'
:
'DOUBLE'
,
'BOOLEAN'
:
'BOOLEAN'
,
# Keep as is
'TIMESTAMP'
:
'TIMESTAMP'
,
# Keep as is
'DATETIME'
:
'TIMESTAMP'
,
# Change this mapping
'TEXT'
:
'STRING'
,
"VARCHAR"
:
"STRING"
,
"INTEGER"
:
"INT"
,
"BIGINT"
:
"BIGINT"
,
# Keep as is
"FLOAT"
:
"DOUBLE"
,
"BOOLEAN"
:
"BOOLEAN"
,
# Keep as is
"TIMESTAMP"
:
"TIMESTAMP"
,
# Keep as is
"DATETIME"
:
"TIMESTAMP"
,
# Change this mapping
"TEXT"
:
"STRING"
,
# Arrays and complex types are already correct, no replacement needed
}
...
...
@@ -147,14 +161,14 @@ class TypeMapper:
base_type
=
type
(
sql_type
)
# Handle special cases first
if
base_type
==
postgresql
.
ARRAY
or
'ARRAY'
in
str
(
sql_type
):
if
base_type
==
postgresql
.
ARRAY
or
"ARRAY"
in
str
(
sql_type
):
return
cls
.
_convert_array_type_fallback
(
sql_type
)
# Get the mapped type
if
base_type
in
cls
.
POSTGRES_TO_DATABRICKS_MAPPING
:
return
cls
.
POSTGRES_TO_DATABRICKS_MAPPING
[
base_type
]()
logging
.
info
(
f
'Defaulting to String() for type: {type(sql_type)}'
)
logging
.
info
(
f
"Defaulting to String() for type: {type(sql_type)}"
)
return
String
()
@
classmethod
...
...
@@ -171,21 +185,21 @@ class TypeMapper:
element_type_name
=
type
(
postgres_element_type
)
.
__name__
.
upper
()
# Map PostgreSQL types to Databricks array element types
if
any
(
t
in
element_type_name
for
t
in
[
'VARCHAR'
,
'TEXT'
,
'STRING'
,
'CHAR'
]):
if
any
(
t
in
element_type_name
for
t
in
[
"VARCHAR"
,
"TEXT"
,
"STRING"
,
"CHAR"
]):
return
"STRING"
elif
any
(
t
in
element_type_name
for
t
in
[
'INTEGER'
,
'BIGINT'
,
'SMALLINT'
]):
return
"INT"
if
'SMALLINT'
not
in
element_type_name
else
"SMALLINT"
elif
'BIGINT'
in
element_type_name
:
elif
any
(
t
in
element_type_name
for
t
in
[
"INTEGER"
,
"BIGINT"
,
"SMALLINT"
]):
return
"INT"
if
"SMALLINT"
not
in
element_type_name
else
"SMALLINT"
elif
"BIGINT"
in
element_type_name
:
return
"BIGINT"
elif
any
(
t
in
element_type_name
for
t
in
[
'BOOLEAN'
,
'BOOL'
]):
elif
any
(
t
in
element_type_name
for
t
in
[
"BOOLEAN"
,
"BOOL"
]):
return
"BOOLEAN"
elif
any
(
t
in
element_type_name
for
t
in
[
'FLOAT'
,
'REAL'
,
'DOUBLE'
]):
elif
any
(
t
in
element_type_name
for
t
in
[
"FLOAT"
,
"REAL"
,
"DOUBLE"
]):
return
"DOUBLE"
elif
any
(
t
in
element_type_name
for
t
in
[
'NUMERIC'
,
'DECIMAL'
]):
elif
any
(
t
in
element_type_name
for
t
in
[
"NUMERIC"
,
"DECIMAL"
]):
return
"DECIMAL"
elif
'DATE'
in
element_type_name
:
elif
"DATE"
in
element_type_name
:
return
"DATE"
elif
'TIMESTAMP'
in
element_type_name
:
elif
"TIMESTAMP"
in
element_type_name
:
return
"TIMESTAMP"
else
:
return
"STRING"
# Default fallback
...
...
@@ -225,7 +239,11 @@ class TypeMapper:
default_clause
=
""
if
column
.
default
is
not
None
:
default_value
=
column
.
default
.
arg
if
hasattr
(
column
.
default
,
'arg'
)
else
column
.
default
default_value
=
(
column
.
default
.
arg
if
hasattr
(
column
.
default
,
"arg"
)
else
column
.
default
)
if
isinstance
(
default_value
,
str
):
default_clause
=
f
" DEFAULT '{default_value}'"
elif
isinstance
(
default_value
,
bool
):
...
...
@@ -258,7 +276,7 @@ class ColumnConverter:
columns_info
=
{}
# # Check for modern SQLAlchemy with annotations and mapped_column
if
hasattr
(
model_class
,
'__annotations__'
):
if
hasattr
(
model_class
,
"__annotations__"
):
columns_info
.
update
(
self
.
extract_from_annotations
(
model_class
))
# Fallback: Try mapper approach for traditional models
...
...
@@ -267,7 +285,9 @@ class ColumnConverter:
mapper
=
class_mapper
(
model_class
)
for
column_name
,
column
in
mapper
.
columns
.
items
():
if
column_name
not
in
columns_info
:
columns_info
[
column_name
]
=
self
.
extract_column_properties
(
column
)
columns_info
[
column_name
]
=
self
.
extract_column_properties
(
column
)
except
Exception
as
e
:
logging
.
error
(
f
"Failed to extract column info using mapper: {e}"
)
# Final fallback: inspect class attributes directly
...
...
@@ -279,20 +299,22 @@ class ColumnConverter:
def
extract_column_properties
(
column
:
Any
)
->
Dict
[
str
,
Any
]:
"""Extract properties from a column object."""
return
{
'type'
:
getattr
(
column
,
'type'
,
None
),
'primary_key'
:
getattr
(
column
,
'primary_key'
,
False
),
'nullable'
:
getattr
(
column
,
'nullable'
,
True
),
'default'
:
getattr
(
column
,
'default'
,
None
),
'server_default'
:
getattr
(
column
,
'server_default'
,
None
),
'uses_mapped_column'
:
False
,
"type"
:
getattr
(
column
,
"type"
,
None
),
"primary_key"
:
getattr
(
column
,
"primary_key"
,
False
),
"nullable"
:
getattr
(
column
,
"nullable"
,
True
),
"default"
:
getattr
(
column
,
"default"
,
None
),
"server_default"
:
getattr
(
column
,
"server_default"
,
None
),
"uses_mapped_column"
:
False
,
}
def
_extract_from_class_attributes
(
self
,
model_class
:
type
)
->
Dict
[
str
,
Dict
[
str
,
Any
]]:
def
_extract_from_class_attributes
(
self
,
model_class
:
type
)
->
Dict
[
str
,
Dict
[
str
,
Any
]]:
"""Extract column info from class attributes."""
columns_info
=
{}
for
attr_name
in
dir
(
model_class
):
if
attr_name
.
startswith
(
'_'
):
if
attr_name
.
startswith
(
"_"
):
continue
attr
=
getattr
(
model_class
,
attr_name
,
None
)
...
...
@@ -303,14 +325,14 @@ class ColumnConverter:
if
isinstance
(
attr
,
Column
):
columns_info
[
attr_name
]
=
self
.
extract_column_properties
(
attr
)
# Check for mapped_column objects
elif
hasattr
(
attr
,
'type'
)
and
hasattr
(
attr
,
'nullable'
):
elif
hasattr
(
attr
,
"type"
)
and
hasattr
(
attr
,
"nullable"
):
columns_info
[
attr_name
]
=
{
'type'
:
getattr
(
attr
,
'type'
,
None
),
'primary_key'
:
getattr
(
attr
,
'primary_key'
,
False
),
'nullable'
:
getattr
(
attr
,
'nullable'
,
True
),
'default'
:
getattr
(
attr
,
'default'
,
None
),
'server_default'
:
getattr
(
attr
,
'server_default'
,
None
),
'uses_mapped_column'
:
True
,
"type"
:
getattr
(
attr
,
"type"
,
None
),
"primary_key"
:
getattr
(
attr
,
"primary_key"
,
False
),
"nullable"
:
getattr
(
attr
,
"nullable"
,
True
),
"default"
:
getattr
(
attr
,
"default"
,
None
),
"server_default"
:
getattr
(
attr
,
"server_default"
,
None
),
"uses_mapped_column"
:
True
,
}
return
columns_info
...
...
@@ -320,22 +342,22 @@ class ColumnConverter:
"""Extract column info from type annotations (modern SQLAlchemy)."""
columns_info
=
{}
annotations
=
getattr
(
model_class
,
'__annotations__'
,
{})
annotations
=
getattr
(
model_class
,
"__annotations__"
,
{})
for
attr_name
,
annotation
in
annotations
.
items
():
if
hasattr
(
model_class
,
attr_name
):
attr
=
getattr
(
model_class
,
attr_name
)
# Check if it's a mapped_column
if
hasattr
(
attr
,
'type'
):
if
hasattr
(
attr
,
"type"
):
columns_info
[
attr_name
]
=
{
'type'
:
attr
.
type
,
'primary_key'
:
getattr
(
attr
,
'primary_key'
,
False
),
'nullable'
:
getattr
(
attr
,
'nullable'
,
True
),
'default'
:
getattr
(
attr
,
'default'
,
None
),
'server_default'
:
getattr
(
attr
,
'server_default'
,
None
),
'annotation'
:
annotation
,
'uses_mapped_column'
:
True
,
"type"
:
attr
.
type
,
"primary_key"
:
getattr
(
attr
,
"primary_key"
,
False
),
"nullable"
:
getattr
(
attr
,
"nullable"
,
True
),
"default"
:
getattr
(
attr
,
"default"
,
None
),
"server_default"
:
getattr
(
attr
,
"server_default"
,
None
),
"annotation"
:
annotation
,
"uses_mapped_column"
:
True
,
}
return
columns_info
...
...
@@ -351,36 +373,38 @@ class ColumnConverter:
             Tuple of (column_object, annotation_if_any)
         """
         # Convert the column type
-        new_type = self.type_mapper.get_databricks_type(column_info['type'])
+        new_type = self.type_mapper.get_databricks_type(column_info["type"])

         # Check if this uses type annotations (modern approach)
-        if column_info.get('uses_mapped_column', False):
+        if column_info.get("uses_mapped_column", False):
             new_column = mapped_column(
                 new_type,
-                primary_key=column_info.get('primary_key', False),
-                nullable=column_info.get('nullable', True),
-                default=column_info.get('default'),
-                server_default=column_info.get('server_default'),
+                primary_key=column_info.get("primary_key", False),
+                nullable=column_info.get("nullable", True),
+                default=column_info.get("default"),
+                server_default=column_info.get("server_default"),
             )
             # Convert annotation
-            annotation = self.convert_annotation(column_info.get('annotation'), new_type)
+            annotation = self.convert_annotation(
+                column_info.get("annotation"), new_type
+            )
             return new_column, annotation
         else:
             # Traditional Column approach
             new_column = Column(
                 new_type,
-                primary_key=column_info.get('primary_key', False),
-                nullable=column_info.get('nullable', True),
-                default=column_info.get('default'),
-                server_default=column_info.get('server_default'),
+                primary_key=column_info.get("primary_key", False),
+                nullable=column_info.get("nullable", True),
+                default=column_info.get("default"),
+                server_default=column_info.get("server_default"),
             )
             return new_column, None

     @staticmethod
     def convert_annotation(annotation: Any, databricks_type: Any = None) -> Any:
         """Convert type annotations for Databricks compatibility."""
-        from typing import Optional, List
+        from typing import List, Optional

         if annotation is None:
             return Mapped[Optional[str]]
...
@@ -388,15 +412,17 @@ class ColumnConverter:
         annotation_str = str(annotation)

         # Handle array/list types -> convert to List annotation
-        if any(keyword in annotation_str for keyword in ['list', 'List']):
-            if 'Optional' in annotation_str or 'Union' in annotation_str:
+        if any(keyword in annotation_str for keyword in ["list", "List"]):
+            if "Optional" in annotation_str or "Union" in annotation_str:
                 return Mapped[Optional[List[str]]]
             # Default to List[str]
             else:
                 return Mapped[List[str]]

         # Handle JSON/JSONB/dict types -> convert to string
-        if any(keyword in annotation_str for keyword in ['dict', 'Dict', 'json', 'Json']):
-            if 'Optional' in annotation_str or 'Union' in annotation_str:
+        if any(
+            keyword in annotation_str for keyword in ["dict", "Dict", "json", "Json"]
+        ):
+            if "Optional" in annotation_str or "Union" in annotation_str:
                 return Mapped[Optional[str]]
             else:
                 return Mapped[str]
...
@@ -405,31 +431,38 @@ class ColumnConverter:
         if databricks_type:
             type_str = str(type(databricks_type).__name__).lower()
-            if 'array' in type_str:
+            if "array" in type_str:
                 # For array types, use List annotation
-                if 'Optional' in annotation_str:
-                    return Mapped[Optional[List[str]]]  # Could be more specific based on element type
+                if "Optional" in annotation_str:
+                    return Mapped[
+                        Optional[List[str]]
+                    ]  # Could be more specific based on element type
                 return Mapped[List[str]]
-            elif 'integer' in type_str or 'biginteger' in type_str or 'smallinteger' in type_str:
-                if 'Optional' in annotation_str:
+            elif (
+                "integer" in type_str
+                or "biginteger" in type_str
+                or "smallinteger" in type_str
+            ):
+                if "Optional" in annotation_str:
                     return Mapped[Optional[int]]
                 return Mapped[int]
-            elif 'boolean' in type_str:
-                if 'Optional' in annotation_str:
+            elif "boolean" in type_str:
+                if "Optional" in annotation_str:
                     return Mapped[Optional[bool]]
                 return Mapped[bool]
-            elif 'float' in type_str or 'numeric' in type_str:
-                if 'Optional' in annotation_str:
+            elif "float" in type_str or "numeric" in type_str:
+                if "Optional" in annotation_str:
                     return Mapped[Optional[float]]
                 return Mapped[float]
-            elif 'datetime' in type_str:
+            elif "datetime" in type_str:
                 from datetime import datetime
-                if 'Optional' in annotation_str:
+                if "Optional" in annotation_str:
                     return Mapped[Optional[datetime]]
                 return Mapped[datetime]

         # Default to string
-        if 'Optional' in annotation_str or 'Union' in annotation_str:
+        if "Optional" in annotation_str or "Union" in annotation_str:
             return Mapped[Optional[str]]
         return Mapped[str]
...
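A quick, hedged sanity check of the conversion rules above. The expected values in the comments follow directly from the branches shown in this diff; the import path scripts.utils.model_convertor_utils and convert_annotation remaining a staticmethod are assumptions.

from typing import Optional

from sqlalchemy import BigInteger, Boolean

from scripts.utils.model_convertor_utils import ColumnConverter

# String-based heuristics: dict/JSON-like annotations collapse to str, lists to List[str].
print(ColumnConverter.convert_annotation(Optional[dict]))       # expected: Mapped[Optional[str]]
print(ColumnConverter.convert_annotation(Optional[list[str]]))  # expected: Mapped[Optional[List[str]]]

# Fallback on the Databricks column type's class name when the annotation gives no hint.
print(ColumnConverter.convert_annotation(int, BigInteger()))    # expected: Mapped[int]
print(ColumnConverter.convert_annotation(bool, Boolean()))      # expected: Mapped[bool]

# No annotation at all defaults to the safest option.
print(ColumnConverter.convert_annotation(None))                 # expected: Mapped[Optional[str]]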
@@ -442,8 +475,7 @@ class SchemaProcessor:
     @staticmethod
     def process_table_args(
-        original_table_args: Any,
-        new_schema: Optional[str] = None
+        original_table_args: Any, new_schema: Optional[str] = None
     ) -> Union[Tuple, Dict, None]:
         """
         Process table arguments, handling constraints and schema conversion.
...
@@ -457,7 +489,7 @@ class SchemaProcessor:
"""
if
not
original_table_args
:
if
new_schema
:
return
{
'schema'
:
new_schema
}
return
{
"schema"
:
new_schema
}
return
None
new_table_args
=
[]
...
...
@@ -468,7 +500,9 @@ class SchemaProcessor:
         for arg in original_table_args:
             if isinstance(arg, dict):
                 # Process dictionary part
-                processed_kwargs = SchemaProcessor._process_table_kwargs(arg, new_schema)
+                processed_kwargs = SchemaProcessor._process_table_kwargs(
+                    arg, new_schema
+                )
                 table_kwargs.update(processed_kwargs)
             elif isinstance(arg, (Index, ForeignKeyConstraint)):
                 continue
...
@@ -481,11 +515,13 @@ class SchemaProcessor:
         # Handle dictionary format: {'schema': 'public', 'extend_existing': True}
         elif isinstance(original_table_args, dict):
-            table_kwargs = SchemaProcessor._process_table_kwargs(original_table_args, new_schema)
+            table_kwargs = SchemaProcessor._process_table_kwargs(
+                original_table_args, new_schema
+            )

         # Add new schema if specified and not already set
-        if new_schema is not None and 'schema' not in table_kwargs:
-            table_kwargs['schema'] = new_schema
+        if new_schema is not None and "schema" not in table_kwargs:
+            table_kwargs["schema"] = new_schema

         # Construct result
         if new_table_args and table_kwargs:
...
@@ -499,16 +535,18 @@ class SchemaProcessor:
         return None

     @staticmethod
-    def _process_table_kwargs(kwargs: Dict[str, Any], new_schema: Optional[str]) -> Dict[str, Any]:
+    def _process_table_kwargs(
+        kwargs: Dict[str, Any], new_schema: Optional[str]
+    ) -> Dict[str, Any]:
         """Process table keyword arguments."""
         processed = {}
         for key, value in kwargs.items():
-            if key == 'schema':
+            if key == "schema":
                 # Use new_schema if provided, otherwise keep original unless it's 'public'
                 if new_schema is not None:
                     processed[key] = new_schema
-                elif value != 'public':
+                elif value != "public":
                     processed[key] = value
                 # Skip 'public' schema (default)
             else:
...
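To make the schema handling concrete, a hedged sketch of how the dictionary form of __table_args__ is expected to be rewritten. The final return branch for the dict-only case is collapsed in this diff, so the outputs below are expectations, not verified results, and the import path is an assumption as above.

from scripts.utils.model_convertor_utils import SchemaProcessor

# Typical Postgres-era table args: the 'public' schema is dropped (it is the Postgres
# default), other kwargs pass through, and a target schema can be injected.
original = {"schema": "public", "extend_existing": True}

print(SchemaProcessor.process_table_args(original, new_schema="bronze"))
# expected: {'schema': 'bronze', 'extend_existing': True}

print(SchemaProcessor.process_table_args(original))
# expected: {'extend_existing': True}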
@@ -532,7 +570,8 @@ class ModelConverter:
         self.type_mapper = TypeMapper()
         self.column_converter = ColumnConverter(self.type_mapper)

-    def convert_model(self,
+    def convert_model(
+        self,
         postgres_model_class: Type,
         base_class: Type,
         new_table_name: Optional[str] = None,
...
@@ -552,27 +591,28 @@ class ModelConverter:
"""
# Create base class if not provided
# Get table information
original_table_name
=
getattr
(
postgres_model_class
,
'__tablename__'
,
'unknown_table'
)
original_table_name
=
getattr
(
postgres_model_class
,
"__tablename__"
,
"unknown_table"
)
table_name
=
new_table_name
or
original_table_name
table_name
=
f
'{table_name}'
table_name
=
f
"{table_name}"
schema_processor
=
SchemaProcessor
()
# Create new model attributes
new_attrs
=
{
'__tablename__'
:
table_name
,
'__module__'
:
postgres_model_class
.
__module__
,
"__tablename__"
:
table_name
,
"__module__"
:
postgres_model_class
.
__module__
,
}
# Process table arguments
if
hasattr
(
postgres_model_class
,
'__table_args__'
):
if
hasattr
(
postgres_model_class
,
"__table_args__"
):
processed_table_args
=
schema_processor
.
process_table_args
(
postgres_model_class
.
__table_args__
,
new_schema
postgres_model_class
.
__table_args__
,
new_schema
)
if
processed_table_args
:
new_attrs
[
'__table_args__'
]
=
processed_table_args
new_attrs
[
"__table_args__"
]
=
processed_table_args
elif
new_schema
:
new_attrs
[
'__table_args__'
]
=
{
'schema'
:
new_schema
}
new_attrs
[
"__table_args__"
]
=
{
"schema"
:
new_schema
}
# Extract and convert columns
columns_info
=
self
.
column_converter
.
extract_column_info
(
postgres_model_class
)
...
...
@@ -587,7 +627,7 @@ class ModelConverter:
         # Add annotations if any
         if annotations:
-            new_attrs['__annotations__'] = annotations
+            new_attrs["__annotations__"] = annotations

         # Create the new model class
         new_class_name = f"{postgres_model_class.__name__}Databricks"
...
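Finally, an end-to-end sketch of how the converter is presumably driven. The LegacyAsset model, the DatabricksBase declarative base, and the new_schema keyword are illustrative assumptions; only the first parameters of convert_model are visible in this diff.

# Illustrative only: wiring the converter end-to-end against a hypothetical legacy model.
from sqlalchemy import Column, Integer, String
from sqlalchemy.orm import DeclarativeBase, declarative_base

from scripts.utils.model_convertor_utils import ModelConverter

PostgresBase = declarative_base()


class LegacyAsset(PostgresBase):
    __tablename__ = "legacy_asset"
    __table_args__ = {"schema": "public"}
    id = Column(Integer, primary_key=True)
    name = Column(String(255), nullable=False)


class DatabricksBase(DeclarativeBase):
    pass


converter = ModelConverter()
LegacyAssetDatabricks = converter.convert_model(
    LegacyAsset,
    base_class=DatabricksBase,
    new_table_name="legacy_asset",
    new_schema="bronze",  # assumed parameter; the body above routes it into __table_args__
)
print(LegacyAssetDatabricks.__tablename__, LegacyAssetDatabricks.__table_args__)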