harshavardhan.c / model-managament-databricks / Commits

Commit 7845d0b6
authored Aug 04, 2025 by harshavardhan.c
parent c741f979

feat: Functionality added for delete and add jobs in databricks based on the catalog.

Showing 26 changed files with 826 additions and 306 deletions (+826 / -306)
Changed files:

  .env                                                   +1    -1
  .gitignore                                             +134  -0
  .pre-commit-config.yaml                                +41   -0
  README.md                                              +0    -1
  agent_subscribers.py                                   +19   -5
  app.py                                                 +2    -2
  requirements.txt                                       +1    -1
  scripts/config/__init__.py                             +15   -5
  scripts/constants/__init__.py                          +15   -1
  scripts/constants/notebooks/metadata_deletion.txt      +95   -0
  scripts/constants/notebooks/metadata_ingestion.txt     +7    -7
  scripts/constants/notebooks/timeseries_ingestion.txt   +3    -4
  scripts/core/handlers/instance_handler.py              +29   -13
  scripts/core/handlers/model_creator_handler.py         +135  -70
  scripts/db/databricks/__init__.py                      +36   -23
  scripts/db/databricks/job_manager.py                   +30   -16
  scripts/db/databricks/notebook_manager.py              +17   -7
  scripts/db/redis/databricks_details.py                 +3    -3
  scripts/db/redis/graphql.py                            +6    -2
  scripts/db/redis/project_details.py                    +11   -4
  scripts/engines/agents/__init__.py                     +4    -1
  scripts/engines/agents/model_creator_agent.py          +8    -5
  scripts/schemas/__init__.py                            +9    -5
  scripts/utils/databricks_utils.py                      +34   -15
  scripts/utils/httpx_util.py                            +20   -4
  scripts/utils/model_convertor_utils.py                 +151  -111
.env

.gitignore
# Created by .ignore support plugin (hsz.mobi)
### Python template
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
/.idea
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
.vscode/settings.json
.vscode
data
.env
assets
__pycache__
.pre-commit-config.yaml (new file, mode 100644)

repos:
  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: v4.6.0
    hooks:
      - id: end-of-file-fixer
      - id: trailing-whitespace
      - id: requirements-txt-fixer
  - repo: https://github.com/asottile/pyupgrade
    rev: v3.16.0
    hooks:
      - id: pyupgrade
        args:
          - --py3-plus
          - --keep-runtime-typing
  - repo: https://github.com/charliermarsh/ruff-pre-commit
    rev: v0.4.8
    hooks:
      - id: ruff
        args:
          - --fix
  - repo: https://github.com/pycqa/isort
    rev: 5.13.2
    hooks:
      - id: isort
        name: isort (python)
      - id: isort
        name: isort (cython)
        types: [cython]
      - id: isort
        name: isort (pyi)
        types: [pyi]
  - repo: https://github.com/psf/black
    rev: 24.4.2
    hooks:
      - id: black
        # It is recommended to specify the latest version of Python
        # supported by your project here, or alternatively use
        # pre-commit's default_language_version, see
        # https://pre-commit.com/#top_level-default_language_version
        language_version: python3.11
README.md

# model-managament-databricks
agent_subscribers.py

...
@@ -6,19 +6,33 @@ from scripts.config import KafkaConfig
from scripts.engines.agents.model_creator_agent import ModelCreatorAgent
from scripts.schemas import ModelCreatorSchema, ModelInstanceSchema

broker = KafkaBroker(
    f"{KafkaConfig.KAFKA_HOST}:{KafkaConfig.KAFKA_PORT}",
    client_id="model_creator_agent",
)


@broker.subscriber(
    KafkaConfig.KAFKA_MODEL_CREATION_TOPIC,
    group_id="databricks_model_creator_agent",
    max_workers=2,
)
async def consume_stream_for_processing_dependencies(message: dict):
    try:
        await ModelCreatorAgent.model_creator_agent(
            message=ModelCreatorSchema(meta=message)
        )
        return True
    except Exception as e:
        logging.error(f"Exception occurred while creating model in Databricks: {e}")
        return False


@broker.subscriber(
    KafkaConfig.KAFKA_MODEL_INSTANCE_TOPIC,
    group_id="databricks_instance_agent",
    max_workers=2,
)
async def consume_stream_for_processing_instances(message: dict):
    try:
        await ModelCreatorAgent.model_instance_agent(ModelInstanceSchema(**message))
...
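
For orientation (not part of this commit), a minimal sketch of how a producer might publish to the creation topic consumed above, reusing the same broker configuration; the message body is a made-up placeholder, since only ModelCreatorSchema(meta=...) defines what the subscriber actually needs.

import asyncio

from faststream.confluent import KafkaBroker

from scripts.config import KafkaConfig


async def publish_model_creation_event():
    # Hypothetical producer: pushes one model-creation event to the topic that
    # consume_stream_for_processing_dependencies() listens on.
    async with KafkaBroker(f"{KafkaConfig.KAFKA_HOST}:{KafkaConfig.KAFKA_PORT}") as broker:
        await broker.publish(
            {"project_id": "project_787"},  # illustrative payload, not a real schema
            topic=KafkaConfig.KAFKA_MODEL_CREATION_TOPIC,
        )


if __name__ == "__main__":
    asyncio.run(publish_model_creation_event())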
app.py

...
@@ -6,11 +6,11 @@ import sys
from dotenv import load_dotenv

load_dotenv()

from faststream import FastStream
from ut_dev_utils import configure_logger

from agent_subscribers import broker

configure_logger()

# Create FastStream app
...
requirements.txt
scripts/config/__init__.py

...
@@ -70,13 +70,16 @@ class _DatabricksConfig(BaseSettings):
    DATABRICKS_PUBLIC_SCHEMA_NAME: str = Field(default="public")
    DATABRICKS_ANALYTICAL_SCHEMA_NAME: str = Field(default="analytical")
    DATABRICKS_STORAGE_FORMAT: str = Field(default="PARQUET")
    DATABRICKS_STORAGE_PATH: str = Field(
        default="abfss://unity-catalog-storage@dbstoragenzxfhpgsipt5a.dfs.core.windows.net/416418955412087"
    )

    @model_validator(mode="before")
    def prepare_databricks_uri(cls, values):
        values["DATABRICKS_URI"] = (
            f"databricks://token:{values['DATABRICKS_ACCESS_TOKEN']}@{values['DATABRICKS_HOST']}:{values['DATABRICKS_PORT']}"
            f"?http_path={values['DATABRICKS_HTTP_PATH']}"
        )
        return values
...
@@ -88,4 +91,11 @@ PathToStorage = _PathToStorage()
KafkaConfig = _KafkaConfig()
DatabricksConfig = _DatabricksConfig()

__all__ = [
    "Services",
    "RedisConfig",
    "ExternalServices",
    "PathToStorage",
    "KafkaConfig",
    "DatabricksConfig",
]
scripts/constants/__init__.py

class DatabricksConstants:
    METADATA_INGESTION_JOB_NAME = "metadata_ingestion_job"
    METADATA_DELETION_JOB_NAME = "metadata_deletion_job"
    METADATA_INGESTION_NOTEBOOK_NAME = "metadata_ingestion_notebook"
    METADATA_DELETION_NOTEBOOK_NAME = "metadata_deletion_notebook"
    TIMESERIES_INGESTION_NOTEBOOK_NAME = "timeseries_ingestion_notebook"


class NotebookConstants:
    METADATA_INGESTION_NOTEBOOK_PATH = (
        "scripts/constants/notebooks/metadata_ingestion.txt"
    )
    METADATA_DELETION_NOTEBOOK_PATH = (
        "scripts/constants/notebooks/metadata_deletion.txt"
    )
    TIMESERIES_INGESTION_NOTEBOOK_PATH = (
        "scripts/constants/notebooks/timeseries_ingestion.txt"
    )
scripts/constants/notebooks/metadata_deletion.txt (new file, mode 100644)
# Databricks notebook source
from delta.tables import DeltaTable
from pyspark.sql.functions import expr
import json
# COMMAND ----------
dbutils.widgets.text("input_message", "", "Input Message JSON")
dbutils.widgets.text("id_column", "id", "ID Column Name")
input_message = dbutils.widgets.get("input_message")
delete_column = dbutils.widgets.get("id_column")
# COMMAND ----------
def extract_table_info(input_message_str: str, delete_column:str = "id"):
"""
Extract table name and data from input message
Args:
input_message_str (str): JSON string containing the message
Returns:
dict: Extracted information
"""
try:
message_data = json.loads(input_message_str)
# Extract table name from data.type
table_name = message_data['table_properties']['table_name'] # 'enterprise'
project_id = message_data['project_id'] # 'project_787'
delete_values = [msg[delete_column] for msg in message_data['data'] if delete_column in msg]
table_properties = message_data['table_properties'] # Fetch table properties
print(f"Extracted Info:")
print(f"Table Name: {table_name}")
print(f"Project ID: {project_id}")
print(f'Deleting rows: {delete_values}')
print(f"Table Prop Keys: {list(table_properties.keys())}")
return {
'table_name': table_name,
'project_id': project_id,
delete_column: delete_values,
'raw_message': message_data,
'table_properties': table_properties
}
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON in input_message: {str(e)}")
except KeyError as e:
raise ValueError(f"Missing required field in input_message: {str(e)}")
# COMMAND ----------
def delete_records_by_ids(table_name, ids, id_column="id"):
"""
Delete records from external table (Delta or Parquet) using list of IDs
Args:
table_name (str): Full table name (catalog.schema.table)
ids (list): List of IDs to delete
id_column (str): Column name containing IDs (default: "id")
Returns:
bool: True if successful, False otherwise
"""
try:
if not ids:
print("No IDs provided")
return False
# Format IDs for SQL IN clause
if isinstance(ids[0], str):
id_values = "(" + ",".join([f"'{id}'" for id in ids]) + ")"
else:
id_values = "(" + ",".join([str(id) for id in ids]) + ")"
# Use Delta table DELETE operation
delta_table = DeltaTable.forName(spark, table_name)
condition = f"{id_column} IN {id_values}"
delta_table.delete(condition=expr(condition))
print(f"Successfully deleted {len(ids)} records from {table_name}")
return True
except Exception as e:
print(f"Error: {str(e)}")
return False
# COMMAND ----------
table_info = extract_table_info(input_message, delete_column=delete_column)
# COMMAND ----------
result = delete_records_by_ids(table_name=table_info['table_name'], ids=table_info[delete_column], id_column=delete_column)
print(f"Deletion completed: {result}")
scripts/constants/notebooks/metadata_ingestion.txt
scripts/constants/notebooks/timeseries_ingestion.txt

...
@@ -4,7 +4,7 @@ from pyspark.sql.functions import *
from pyspark.sql.types import *
import json

spark = SparkSession.builder.appName("Streaming Timeseries Pipeline").getOrCreate()
spark.sparkContext.setLogLevel("WARN")

# COMMAND ----------
...
@@ -146,4 +146,3 @@ transformed_df.writeStream \

# COMMAND ----------
...
scripts/core/handlers/instance_handler.py

...
@@ -3,6 +3,7 @@ import json
from ut_dev_utils import get_db_name

from scripts.config import DatabricksConfig
from scripts.constants import DatabricksConstants
from scripts.db.databricks.job_manager import DatabricksJobManager
from scripts.db.redis.databricks_details import databricks_details_db
from scripts.schemas import ModelInstanceSchema
...
@@ -12,29 +13,44 @@ class ModelInstanceHandler:
    def __init__(self, project_id: str, payload: ModelInstanceSchema):
        self.project_id = project_id
        self.payload = payload
        self.catalog_name = get_db_name(
            project_id=project_id, database=DatabricksConfig.DATABRICKS_CATALOG_NAME
        )
        self.job_manager = DatabricksJobManager(
            databricks_host=payload.databricks_host,
            access_token=payload.databricks_access_token,
        )

    async def upload_instances_to_unity_catalog(self):
        if self.payload.action_type == "delete":
            job_id = databricks_details_db.hget(
                self.project_id, DatabricksConstants.METADATA_DELETION_JOB_NAME
            )
        else:
            job_id = databricks_details_db.hget(
                self.project_id, DatabricksConstants.METADATA_INGESTION_JOB_NAME
            )
        if not job_id:
            raise ValueError(
                f"No job id found for {self.payload.action_type}, skipping upload to unity catalog"
            )
        run_id = self.job_manager.run_job(
            job_id=job_id,
            parameters={"input_message": json.dumps(self.get_job_trigger_payload())},
        )
        if not run_id:
            raise ValueError(
                "Failed to run metadata ingestion job, skipping upload to unity catalog"
            )

    def get_job_trigger_payload(self):
        table_name = self.payload.node_type
        schema_table = f"{DatabricksConfig.DATABRICKS_PUBLIC_SCHEMA_NAME}.{table_name}"
        return {
            "table_properties": {
                "table_name": f"{self.catalog_name}.{schema_table}",
                "table_path": f"{self.payload.databricks_storage_path}/{self.catalog_name}/DELTA/{schema_table}",
            },
            "project_id": self.project_id,
            "data": self.payload.data,
        }
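
A hedged sketch (not from the repository) of the payload that would route to the new deletion job, assuming the Databricks defaults in ModelInstanceSchema resolve from the environment-backed DatabricksConfig; all values are invented.

from scripts.core.handlers.instance_handler import ModelInstanceHandler
from scripts.schemas import ModelInstanceSchema

# With action_type="delete", the handler above looks up metadata_deletion_job
# in Redis instead of the ingestion job before triggering a run.
payload = ModelInstanceSchema(
    project_id="project_787",
    node_type="enterprise",
    action_type="delete",
    data=[{"id": "row_1"}, {"id": "row_2"}],
)
handler = ModelInstanceHandler(project_id=payload.project_id, payload=payload)
# await handler.upload_instances_to_unity_catalog()  # runs the deletion job with the trigger payload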
scripts/core/handlers/model_creator_handler.py

...
@@ -5,7 +5,7 @@ from sqlalchemy.orm import declarative_base
from ut_sql_utils.asyncio.declarative_utils import DeclarativeUtils

from scripts.config import DatabricksConfig
from scripts.constants import DatabricksConstants, NotebookConstants
from scripts.db.databricks import DataBricksSQLLayer
from scripts.db.databricks.job_manager import DatabricksJobManager
from scripts.db.databricks.notebook_manager import NotebookManager
...
@@ -16,23 +16,25 @@ from scripts.utils.model_convertor_utils import ModelConverter
class ModelCreatorHandler:
    def __init__(self, message: ModelCreatorSchema, declarative_utils: DeclarativeUtils):
        self.declarative_utils = declarative_utils
        self.meta = message.meta
        self.message = message
        self.model_convertor = ModelConverter()
        self.job_manager = DatabricksJobManager(
            databricks_host=message.databricks_host,
            access_token=message.databricks_access_token,
        )
        self.notebook_manager = NotebookManager(
            databricks_host=message.databricks_host,
            access_token=message.databricks_access_token,
        )
        self.databricks_sql_obj = DataBricksSQLLayer(
            catalog_name=DatabricksConfig.DATABRICKS_CATALOG_NAME,
            project_id=self.meta.project_id,
            schema=message.schema,
        )
        self.external_location = self.message.databricks_storage_path
...
@@ -47,32 +49,37 @@ class ModelCreatorHandler:
        overall_tables = self.get_overall_tables()
        project_levels = project_template_keys(self.meta.project_id, levels=True)
        base = self.create_schema_base(
            schema_name=f"{self.databricks_sql_obj.catalog_name}.{self.message.schema}"
        )
        try:
            # self.databricks_sql_obj.connect_to_databricks()
            _ = self.setup_dependencies_for_unity_catalog()
            table_properties = self.fetch_table_properties()
            for table in overall_tables:
                table_class = self.declarative_utils.get_declarative_class(table)
                if not table_class:
                    logging.error(f"Table class not found for table: {table}")
                    return False
                new_model = self.model_convertor.convert_model(
                    table_class,
                    base_class=base,
                    new_schema=self.message.schema,
                )
                self.databricks_sql_obj.create_external_table_from_structure(
                    table=new_model.__table__,
                    file_format="DELTA",
                    external_location=self.external_location,
                    table_properties=table_properties,
                )
            ts_external_table = self.databricks_sql_obj.create_timeseries_table(
                columns=project_levels,
                external_location=self.external_location,
            )
            self.setup_notepads_and_jobs(
                timeseries_table_path=ts_external_table, project_levels=project_levels
            )
            return True
        except Exception as e:
            logging.error(f"Error occurred while creating models in Unity Catalog: {e}")
...
@@ -95,20 +102,25 @@ class ModelCreatorHandler:
            analytical (bool): Flag to indicate if the setup is for analytical or not
        """
        logging.info(
            f"Setting up catalog '{DatabricksConfig.DATABRICKS_CATALOG_NAME}' for project '{self.meta.project_id}'"
        )
        self.databricks_sql_obj.connect_to_databricks()
        # Create catalog
        catalog_success = self.databricks_sql_obj.create_catalog(
            managed_location=f"{self.external_location}/{self.databricks_sql_obj.catalog_name}",
        )
        if not catalog_success:
            return False
        # Create schema
        schema_success = self.databricks_sql_obj.create_schema(
            DatabricksConfig.DATABRICKS_PUBLIC_SCHEMA_NAME
        )
        if not schema_success:
            return False
        if analytical:
            schema_success = self.databricks_sql_obj.create_schema(
                DatabricksConfig.DATABRICKS_ANALYTICAL_SCHEMA_NAME
            )
            if not schema_success:
                return False
        return True
...
@@ -120,59 +132,112 @@ class ModelCreatorHandler:
            project_levels: List of project levels
        """
        logging.info("Setting up notepads and jobs")
        meta_ingestion_notebook_path = f"/Users/{self.message.databricks_user_email}/{self.meta.project_id}_{DatabricksConstants.METADATA_INGESTION_NOTEBOOK_NAME}"
        meta_deletion_notebook_path = f"/Users/{self.message.databricks_user_email}/{self.meta.project_id}_{DatabricksConstants.METADATA_DELETION_NOTEBOOK_NAME}"
        timeseries_notebook_path = f"/Users/{self.message.databricks_user_email}/{self.meta.project_id}_{DatabricksConstants.TIMESERIES_INGESTION_NOTEBOOK_NAME}"

        # Setting up of Metadata Ingestion Notebook
        existing_job_id = databricks_details_db.hget(
            self.meta.project_id, DatabricksConstants.METADATA_INGESTION_JOB_NAME
        )
        if not existing_job_id:
            self.create_notebook(
                notebook_path=meta_ingestion_notebook_path,
                source_notebook_path=NotebookConstants.METADATA_INGESTION_NOTEBOOK_PATH,
            )
            ingestion_job_id = self.create_job(
                job_name=f"{self.meta.project_id}_{DatabricksConstants.METADATA_INGESTION_JOB_NAME}",
                notebook_path=meta_ingestion_notebook_path,
            )
            databricks_details_db.hset(
                self.meta.project_id,
                DatabricksConstants.METADATA_INGESTION_JOB_NAME,
                ingestion_job_id,
            )

        existing_job_id = databricks_details_db.hget(
            self.meta.project_id, DatabricksConstants.METADATA_DELETION_JOB_NAME
        )
        if not existing_job_id:
            # Setting up of Metadata Deletion Notebook
            self.create_notebook(
                notebook_path=meta_deletion_notebook_path,
                source_notebook_path=NotebookConstants.METADATA_DELETION_NOTEBOOK_PATH,
            )
            deletion_job_id = self.create_job(
                job_name=f"{self.meta.project_id}_{DatabricksConstants.METADATA_DELETION_JOB_NAME}",
                notebook_path=meta_deletion_notebook_path,
            )
            databricks_details_db.hset(
                self.meta.project_id,
                DatabricksConstants.METADATA_DELETION_JOB_NAME,
                deletion_job_id,
            )

        # Setting up of Timeseries Ingestion Notebook
        replace_mapping = {
            "{{timeseries_table_path}}": f'"{timeseries_table_path}"',
            "{{project_levels}}": str(len(project_levels) - 1),
            "{{event_hub_connection_string}}": f'"{self.meta.project_id}"',
        }
        self.create_notebook(
            notebook_path=timeseries_notebook_path,
            source_notebook_path=NotebookConstants.TIMESERIES_INGESTION_NOTEBOOK_PATH,
            replace_mapping=replace_mapping,
        )

    @staticmethod
    def fetch_table_properties(file_format: str = "DELTA"):
        if file_format.lower() == "delta":
            return {
                # Performance optimization (Essential)
                "delta.autoOptimize.optimizeWrite": "true",
                "delta.autoOptimize.autoCompact": "true",
                "delta.targetFileSize": "134217728",  # 128MB
                "delta.enableChangeDataFeed": "true",  # If you need CDC
                # Checkpoint optimization (Performance boost)
                "delta.checkpoint.writeStatsAsStruct": "true",
                "delta.checkpoint.writeStatsAsJson": "false",
                # Note: Retention properties removed - using defaults:
                # delta.deletedFileRetentionDuration = 7 days (default)
                # delta.logRetentionDuration = 30 days (default)
            }
        elif file_format.lower() == "parquet":
            return {
                "parquet.compression": "snappy",
                "parquet.page.size": "1048576",  # 1MB - standard for mixed queries
                "parquet.block.size": "134217728",  # 128MB - balanced performance
                "serialization.format": "1",
            }
        else:
            return {}

    @staticmethod
    def read_data_from_file(note_path: str):
        with open(note_path) as f:
            notebook_code = f.read()
        return notebook_code

    def create_notebook(
        self,
        notebook_path: str,
        source_notebook_path: str,
        replace_mapping: dict = None,
    ):
        logging.info(f"Creating notebook {notebook_path}")
        notebook_code = self.read_data_from_file(source_notebook_path)
        if replace_mapping is not None:
            for key, value in replace_mapping.items():
                notebook_code = notebook_code.replace(key, value)
        self.notebook_manager.create_notebook(
            notebook_path=notebook_path, notebook_code=notebook_code, overwrite=True
        )
        return True

    def create_job(self, job_name: str, notebook_path: str):
        logging.info(f"Creating job {job_name}")
        job_id = self.job_manager.create_job(
            job_config=self.job_manager.create_job_config_for_serverless(
                job_name=job_name,
                notebook_path=notebook_path,
            )
        )
        return job_id
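
To make the naming concrete, a hedged sketch of the notebook paths and job names that setup_notepads_and_jobs derives; the project id and user email below are invented for illustration.

from scripts.constants import DatabricksConstants

project_id = "project_787"                   # hypothetical project
databricks_user_email = "user@example.com"   # hypothetical workspace user

deletion_notebook = (
    f"/Users/{databricks_user_email}/"
    f"{project_id}_{DatabricksConstants.METADATA_DELETION_NOTEBOOK_NAME}"
)
deletion_job_name = f"{project_id}_{DatabricksConstants.METADATA_DELETION_JOB_NAME}"
# -> "/Users/user@example.com/project_787_metadata_deletion_notebook"
# -> "project_787_metadata_deletion_job"
# The returned job ids are cached per project in the databricks_details_db Redis hash,
# which is what makes the setup above idempotent across repeated creation events.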
scripts/db/databricks/__init__.py

from typing import Dict, List

from sqlalchemy import (
    BigInteger,
    Column,
    Date,
    DateTime,
    Integer,
    MetaData,
    String,
    Table,
)

from scripts.utils.databricks_utils import DatabricksSQLUtility
from scripts.utils.model_convertor_utils import TypeMapper
...
@@ -11,11 +20,14 @@ class DataBricksSQLLayer(DatabricksSQLUtility):
        super().__init__(catalog_name, project_id)
        self.schema = schema

    def create_external_table_from_structure(
        self,
        table: Table,
        external_location: str,
        file_format: str = "PARQUET",
        table_properties: Dict[str, str] = None,
        partition_columns: list = None,
    ) -> str:
        """
        Create an external table from a model class.
...
@@ -31,12 +43,14 @@ class DataBricksSQLLayer(DatabricksSQLUtility):
        """
        schema_table = f"{table.schema}.{table.name}" if table.schema else table.name
        columns_sql = TypeMapper().extract_columns_without_constraints(table)
        external_location = (
            f"{external_location}/{self.catalog_name}/{file_format}/{schema_table}"
        )
        sql_parts = [
            f"CREATE TABLE IF NOT EXISTS {schema_table}",
            f"({columns_sql})",
            f"USING {file_format}",
            f"LOCATION '{external_location}'",
        ]
        if partition_columns:
            partition_clause = ", ".join(partition_columns)
...
@@ -64,33 +78,32 @@ class DataBricksSQLLayer(DatabricksSQLUtility):
        """
        table_columns = [
            Column("timestamp", BigInteger, nullable=False),
            Column("dt_timestamp", DateTime, nullable=False),
            Column("dt_date", Date, nullable=False),
            Column("dt_hour", Integer, nullable=False),
            Column("value", String, nullable=False),
            Column("value_type", String, nullable=False, default="float"),
            Column("c3", String, nullable=False),
        ]
        default_columns = ["c1", "c5", "Q", "T", "D", "P", "A", "B", *columns]
        table_columns.extend(
            [Column(col_name, String, nullable=True) for col_name in default_columns]
        )
        partition_columns = ["dt_date", "dt_hour", "c3"]
        table_properties = {
            "parquet.compression": "snappy",  # Fast decompression for frequent queries
            "parquet.page.size": "524288",  # 512KB - better time-range filtering
            "parquet.block.size": "268435456",  # 256MB - efficient sequential reads
            "serialization.format": "1",  # Support for arrays/complex types
        }
        table_obj = Table(
            "timeseries_data",
            MetaData(),
            *table_columns,
            schema=self.schema,
        )
        self.create_external_table_from_structure(
            table=table_obj,
            external_location=external_location,
            partition_columns=partition_columns,
            table_properties=table_properties,
        )
        return external_location
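
As a hedged illustration of the statement create_external_table_from_structure assembles from sql_parts, here is the kind of DDL that might be produced for a made-up table; the diff does not show how the parts are joined, so the single-line rendering below is an assumption.

# Illustrative only: table name, columns and location are invented.
schema_table = "public.enterprise"
columns_sql = "id STRING, name STRING"
location = "abfss://.../project_787_catalog/DELTA/public.enterprise"

sql_parts = [
    f"CREATE TABLE IF NOT EXISTS {schema_table}",
    f"({columns_sql})",
    "USING DELTA",
    f"LOCATION '{location}'",
]
ddl = " ".join(sql_parts)  # assumed joining; not shown in this diff
# -> CREATE TABLE IF NOT EXISTS public.enterprise (id STRING, name STRING) USING DELTA LOCATION '...'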
scripts/db/databricks/job_manager.py

...
@@ -14,10 +14,14 @@ class DatabricksJobManager:
        databricks_host: Your Databricks workspace URL
        access_token: Personal access token or service principal token
        """
        self.host = (
            databricks_host
            if "https://" in databricks_host
            else f"https://{databricks_host}"
        )
        self.headers = {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
        }

    def create_job(self, job_config: dict):
...
@@ -32,11 +36,13 @@ class DatabricksJobManager:
        response = HTTPXRequestUtil(url).post(headers=self.headers, json=job_config)
        if response.status_code == 200:
            job_id = response.json()["job_id"]
            logging.info(f"Job created successfully with ID: {job_id}")
            return job_id
        else:
            logging.error(f"Error creating job: {response.status_code} - {response.text}")
            return None

    def run_job(self, job_id: str, parameters=None):
...
@@ -57,11 +63,13 @@ class DatabricksJobManager:
        response = HTTPXRequestUtil(url).post(headers=self.headers, json=payload)
        if response.status_code == 200:
            run_id = response.json()["run_id"]
            logging.info(f"Job run started with ID: {run_id}")
            return run_id
        else:
            logging.error(f"Error running job: {response.status_code} - {response.text}")
            return None

    def get_run_status(self, run_id):
...
@@ -73,12 +81,16 @@ class DatabricksJobManager:
        url = f"{self.host}/api/2.1/jobs/runs/get"
        params = {"run_id": run_id}
        response = HTTPXRequestHandler(url).get(url, headers=self.headers, params=params)
        if response.status_code == 200:
            return response.json()
        else:
            logging.error(f"Error getting run status: {response.status_code} - {response.text}")
            return None

    @staticmethod
...
@@ -98,16 +110,18 @@
                    "task_key": "table_update_task",
                    "notebook_task": {
                        "notebook_path": notebook_path,
                        "base_parameters": {"input_message": "default_value"},
                    },
                    "timeout_seconds": 3600,
                }
            ],
            "max_concurrent_runs": 10,
            "tags": {
                "purpose": (
                    "metadata_ingestion"
                    if "ingestion" in job_name
                    else "metadata_deletion"
                ),
                "compute_type": "serverless",
            },
        }
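
A minimal usage sketch of the manager above; the host, token, job name and notebook path are placeholders, and the call sequence simply mirrors how model_creator_handler.py and instance_handler.py use the class.

import json

manager = DatabricksJobManager(
    databricks_host="adb-1234567890123456.7.azuredatabricks.net",  # placeholder host
    access_token="dapiXXXXXXXXXXXX",  # placeholder token
)

job_config = manager.create_job_config_for_serverless(
    job_name="project_787_metadata_deletion_job",
    notebook_path="/Users/user@example.com/project_787_metadata_deletion_notebook",
)
job_id = manager.create_job(job_config=job_config)

# Trigger a run; the real payload comes from get_job_trigger_payload() in the handler.
run_id = manager.run_job(job_id=job_id, parameters={"input_message": json.dumps({})})
status = manager.get_run_status(run_id)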
scripts/db/databricks/notebook_manager.py

...
@@ -13,13 +13,19 @@ class NotebookManager:
        databricks_host: Your Databricks workspace URL (e.g., 'https://your-workspace.cloud.databricks.com')
        access_token: Personal access token or service principal token
        """
        self.host = (
            databricks_host
            if "https://" in databricks_host
            else f"https://{databricks_host}"
        )
        self.headers = {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
        }

    def create_notebook(
        self, notebook_path, notebook_code: str, language="PYTHON", overwrite=True
    ):
        """
        Create a notebook in Databricks workspace
...
@@ -31,18 +37,22 @@ class NotebookManager:
        """
        url = f"{self.host}/api/2.0/workspace/import"
        # Encode the notebook content in base64
        encoded_content = base64.b64encode(notebook_code.encode("utf-8")).decode("utf-8")
        payload = {
            "path": notebook_path,
            "format": "SOURCE",
            "language": language,
            "content": encoded_content,
            "overwrite": overwrite,
        }
        response = HTTPXRequestUtil(url=url).post(json=payload, headers=self.headers)
        if response.status_code == 200:
            logging.info(f"Notebook created successfully at: {notebook_path}")
            return True
        else:
            logging.error(f"Error creating notebook: {response.status_code} - {response.text}")
            return False
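
For reference, a minimal usage sketch of NotebookManager; the workspace host, token and target path are placeholders, and the notebook source is read from one of the .txt templates added under scripts/constants/notebooks in this commit.

manager = NotebookManager(
    databricks_host="adb-1234567890123456.7.azuredatabricks.net",  # placeholder host
    access_token="dapiXXXXXXXXXXXX",  # placeholder token
)
with open("scripts/constants/notebooks/metadata_deletion.txt") as f:
    created = manager.create_notebook(
        notebook_path="/Users/user@example.com/project_787_metadata_deletion_notebook",
        notebook_code=f.read(),
        overwrite=True,
    )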
scripts/db/redis/databricks_details.py

import orjson

from scripts.config import RedisConfig
from scripts.db.redis import redis_connector

databricks_details_db = redis_connector.connect(
    db=RedisConfig.REDIS_DATABRICKS_DB, decode_responses=True
)
scripts/db/redis/graphql.py

...
@@ -9,7 +9,9 @@ from ut_sql_utils.config import PostgresConfig
from scripts.config import RedisConfig
from scripts.db.redis import redis_connector

graphql_details_db = redis_connector.connect(
    db=RedisConfig.REDIS_GRAPHQL_DB, decode_responses=True
)


def get_models(
...
@@ -38,7 +40,9 @@ def get_models(
    """
    tables_data = graphql_details_db.hget(info.data["project_id"], "schema_mapper")
    if tables_data is None:
        raise ILensErrors(
            f"No GraphQL schema data found for project {info.data['project_id']}"
        )
    tables: Dict[str, Any] = orjson.loads(tables_data) or {}
    if (
...
scripts/db/redis/project_details.py

...
@@ -3,7 +3,9 @@ import orjson
from scripts.config import RedisConfig
from scripts.db.redis import redis_connector

project_details_db = redis_connector.connect(
    db=RedisConfig.REDIS_PROJECT_TAGS_DB, decode_responses=True
)


def get_project_time_zone(project_id: str):
...
@@ -19,7 +21,9 @@ def get_project_time_zone(project_id: str):
    return "UTC"


def fetch_level_details(
    project_id: str, keys: bool = False, raw: bool = False
) -> dict | list:
    """
    Function to fetch level details from project details
    Uses redis project details cache db (db18) and fetches the level details
...
@@ -60,7 +64,9 @@ def fetch_asset_level(project_id: str) -> str:
    project_details = orjson.loads(project_details)
    counter_levels = project_details.get("counter_levels", {})
    asset_level = (
        counter_levels.get("asset", counter_levels.get("equipment"))
        if isinstance(counter_levels, dict)
        else None
    )
    if asset_level:
        return asset_level
...
@@ -81,6 +87,7 @@ def fetch_asset_level_with_mapping(project_id: str) -> tuple[str, str]:
    swapped_dict = {v: k for k, v in counter_levels.items()}
    return swapped_dict.get("ast", ""), "ast"


def project_template_keys(project_id: str, levels=False):
    val = project_details_db.get(project_id)
    if val is None:
...
scripts/engines/agents/__init__.py

...
@@ -2,4 +2,7 @@ from faststream.confluent import KafkaBroker
from scripts.config import KafkaConfig

broker = KafkaBroker(
    f"{KafkaConfig.KAFKA_HOST}:{KafkaConfig.KAFKA_PORT}",
    client_id="model_creator_agent",
)
scripts/engines/agents/model_creator_agent.py

...
@@ -7,8 +7,7 @@ from scripts.schemas import ModelCreatorSchema, ModelInstanceSchema
class ModelCreatorAgent:
    def __init__(self):
...
    @staticmethod
    async def model_creator_agent(message: ModelCreatorSchema):
...
@@ -18,10 +17,14 @@ class ModelCreatorAgent:
            session_manager=session_manager,
            schema=message.schema,
        )
        model_cal_obj = ModelCreatorHandler(
            message=message, declarative_utils=declarative_utils
        )
        await model_cal_obj.create_models_in_unity_catalog()

    @staticmethod
    async def model_instance_agent(message: ModelInstanceSchema):
        model_instance_obj = ModelInstanceHandler(
            project_id=message.project_id, payload=message
        )
        await model_instance_obj.upload_instances_to_unity_catalog()
scripts/schemas/__init__.py

from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel, Field, model_validator
from ut_security_util import MetaInfoSchema

from scripts.config import DatabricksConfig
...
@@ -20,7 +20,11 @@ class ModelCreatorSchema(BaseModel):
class ModelInstanceSchema(BaseModel):
    data: Union[Dict[str, Any], List[Dict[str, Any]]]
    project_id: str
    action_type: str = "save"
    node_type: str
    sql_schema: Optional[str] = Field(
        default=DatabricksConfig.DATABRICKS_PUBLIC_SCHEMA_NAME, alias="schema"
    )
    databricks_host: str = DatabricksConfig.DATABRICKS_HOST
    databricks_port: int = DatabricksConfig.DATABRICKS_PORT
    databricks_access_token: str = DatabricksConfig.DATABRICKS_ACCESS_TOKEN
...
@@ -30,6 +34,6 @@ class ModelInstanceSchema(BaseModel):
    @model_validator(mode="before")
    def validate_data(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        if "data" in values and isinstance(values["data"], dict):
            values["data"] = [values["data"]]
        return values
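
A hedged sketch of how the validator and the "schema" alias above behave for an incoming payload; the values are invented, and the Databricks defaults are assumed to resolve from the environment-backed DatabricksConfig.

raw_message = {
    "project_id": "project_787",   # sample value
    "node_type": "enterprise",     # sample value
    "schema": "public",            # populates sql_schema via the alias
    "data": {"id": "row_1"},       # a single dict ...
}

instance = ModelInstanceSchema(**raw_message)
assert instance.data == [{"id": "row_1"}]   # ... is wrapped into a list by validate_data
assert instance.sql_schema == "public"
assert instance.action_type == "save"       # default when the producer omits it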
scripts/utils/databricks_utils.py
View file @
7845d0b6
...
@@ -29,13 +29,17 @@ class DatabricksSQLUtility:
...
@@ -29,13 +29,17 @@ class DatabricksSQLUtility:
                 DatabricksConfig.DATABRICKS_URI,
                 pool_pre_ping=True,
                 pool_recycle=3600,
-                echo=False
+                echo=False,
             )
             # Test connection
             with self.engine.connect() as conn:
-                result = conn.execute(text("SELECT current_user() as user, current_catalog() as catalog"))
+                result = conn.execute(
+                    text("SELECT current_user() as user, current_catalog() as catalog")
+                )
                 user_info = result.fetchone()
                 logger.info(f"Connected as user: {user_info[0]}, current catalog: {user_info[1]}")
                 logger.info("Successfully connected to Databricks")
             return True
         except Exception as e:
...
@@ -46,8 +50,12 @@ class DatabricksSQLUtility:
         if self.engine:
             self.engine.dispose()

-    def create_catalog(self, managed_location: Optional[str] = None, comment: Optional[str] = None,
-                       properties: Optional[dict] = None):
+    def create_catalog(
+        self,
+        managed_location: Optional[str] = None,
+        comment: Optional[str] = None,
+        properties: Optional[dict] = None,
+    ):
         """
         Create a new catalog in Unity Catalog
         Args:
...
@@ -77,8 +85,13 @@ class DatabricksSQLUtility:
             logger.error(f"Failed to create catalog '{self.catalog_name}': {str(e)}")
             raise

-    def create_schema(self, schema_name: str, managed_location: Optional[str] = None, comment: Optional[str] = None,
-                      properties: Optional[dict] = None):
+    def create_schema(
+        self,
+        schema_name: str,
+        managed_location: Optional[str] = None,
+        comment: Optional[str] = None,
+        properties: Optional[dict] = None,
+    ):
         """
         Create a new schema within a catalog
         Args:
...
@@ -103,10 +116,14 @@ class DatabricksSQLUtility:
                 ddl += f"\nWITH DBPROPERTIES ({props})"
             self.execute_sql_statement(ddl)
             logger.info(f"Schema '{self.catalog_name}.{schema_name}' created successfully")
             return full_schema_name
         except Exception as e:
             logger.error(f"Failed to create schema '{self.catalog_name}.{schema_name}': {str(e)}")
             raise

     def create_external_location(
...
@@ -114,7 +131,7 @@ class DatabricksSQLUtility:
         location_name: str,
         storage_path: str,
         credential_name: str,
-        comment: Optional[str] = None
+        comment: Optional[str] = None,
     ) -> str:
         """
         Create an external location in Unity Catalog
...
@@ -138,7 +155,9 @@ class DatabricksSQLUtility:
             logger.info(f"External location '{location_name}' created successfully")
             return location_name
         except Exception as e:
             logger.error(f"Failed to create external location '{location_name}': {str(e)}")
             raise

     def execute_sql_statement(self, query: str):
...
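Note on the pooling flags touched above: pool_pre_ping revalidates a pooled connection before it is handed out, and pool_recycle retires connections older than the given number of seconds, which avoids stale Databricks SQL sessions. A minimal sketch of the same options passed to SQLAlchemy's create_engine, with the connection smoke test used above (the URI is a placeholder, not project configuration):

# Minimal sketch, not project code; the URI below is a placeholder.
from sqlalchemy import create_engine, text

engine = create_engine(
    "databricks://token:<PAT>@<host>?http_path=<warehouse_http_path>",  # placeholder URI
    pool_pre_ping=True,   # validate pooled connections before reuse
    pool_recycle=3600,    # recycle connections older than one hour
    echo=False,
)
with engine.connect() as conn:
    row = conn.execute(
        text("SELECT current_user() as user, current_catalog() as catalog")
    ).fetchone()
    print(f"Connected as {row[0]}, current catalog: {row[1]}")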
scripts/utils/httpx_util.py
View file @ 7845d0b6
...
@@ -18,7 +18,11 @@ class HTTPXRequestUtil:
     def delete(self, path="", params=None, **kwargs) -> httpx.Response:
         url = self.get_url(path)
         logging.info(url)
-        with httpx.Client(verify=self.verify, headers=kwargs.get("headers"), cookies=kwargs.get("cookies")) as client:
+        with httpx.Client(
+            verify=self.verify,
+            headers=kwargs.get("headers"),
+            cookies=kwargs.get("cookies"),
+        ) as client:
             response: httpx.Response = client.delete(url=url, params=params)
             return response
...
@@ -27,7 +31,11 @@ class HTTPXRequestUtil:
         url = self.get_url(path)
         logging.info(url)
-        with httpx.Client(verify=self.verify, headers=kwargs.get("headers"), cookies=kwargs.get("cookies")) as client:
+        with httpx.Client(
+            verify=self.verify,
+            headers=kwargs.get("headers"),
+            cookies=kwargs.get("cookies"),
+        ) as client:
             response: httpx.Response = client.put(url=url, data=data, json=json)
             return response
...
@@ -42,7 +50,11 @@ class HTTPXRequestUtil:
         """
         url = self.get_url(path)
         logging.info(url)
-        with httpx.Client(verify=self.verify, headers=kwargs.get("headers"), cookies=kwargs.get("cookies")) as client:
+        with httpx.Client(
+            verify=self.verify,
+            headers=kwargs.get("headers"),
+            cookies=kwargs.get("cookies"),
+        ) as client:
             response: httpx.Response = client.post(url=url, data=data, json=json)
             return response
...
@@ -52,7 +64,11 @@ class HTTPXRequestUtil:
         url = self.get_url(path)
         logging.info(url)
-        with httpx.Client(verify=self.verify, headers=kwargs.get("headers"), cookies=kwargs.get("cookies")) as client:
+        with httpx.Client(
+            verify=self.verify,
+            headers=kwargs.get("headers"),
+            cookies=kwargs.get("cookies"),
+        ) as client:
             response: httpx.Response = client.get(url=url, params=params)
             return response
...
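All four verbs above now build the client the same way; only the final call differs. A minimal sketch of that shared pattern, with a placeholder URL and headers that are not taken from this repository:

# Minimal sketch, not project code; URL, params, and headers are placeholders.
import httpx

def request_get(url: str, params=None, headers=None, cookies=None, verify=True) -> httpx.Response:
    # A fresh client per call keeps the helper stateless; verify toggles TLS verification.
    with httpx.Client(verify=verify, headers=headers, cookies=cookies) as client:
        return client.get(url=url, params=params)

resp = request_get("https://example.com/api/2.1/jobs/list", params={"limit": 25})
print(resp.status_code)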
scripts/utils/model_convertor_utils.py
View file @ 7845d0b6
 import logging
-from typing import Any, Type, Optional, Dict, Union, Tuple
+from typing import Any, Dict, Optional, Tuple, Type, Union

 from databricks.sqlalchemy import TIMESTAMP
 from sqlalchemy import (
-    Integer, BigInteger, SmallInteger, String, Text, Boolean,
-    Date, DateTime, Time, Numeric, Float, DECIMAL, CHAR, VARCHAR,
-    LargeBinary, JSON, Column, PrimaryKeyConstraint, UniqueConstraint,
-    ForeignKeyConstraint, Double, Index, Table
+    CHAR,
+    DECIMAL,
+    JSON,
+    VARCHAR,
+    BigInteger,
+    Boolean,
+    Column,
+    Date,
+    DateTime,
+    Double,
+    Float,
+    ForeignKeyConstraint,
+    Index,
+    Integer,
+    LargeBinary,
+    Numeric,
+    PrimaryKeyConstraint,
+    SmallInteger,
+    String,
+    Table,
+    Text,
+    Time,
+    UniqueConstraint,
 )
 from sqlalchemy.dialects import postgresql
-from sqlalchemy.orm import class_mapper, mapped_column, Mapped
+from sqlalchemy.orm import Mapped, class_mapper, mapped_column
 from sqlalchemy.types import UserDefinedType
...
@@ -83,7 +103,6 @@ class TypeMapper:
         VARCHAR: VARCHAR,
         Text: String,
         String: String,
         # DateTime types
         postgresql.DATE: Date,
         postgresql.TIME: String,
...
@@ -92,23 +111,18 @@ class TypeMapper:
         Date: Date,
         Time: String,
         DateTime: DateTime,
         # Boolean
         postgresql.BOOLEAN: Boolean,
         Boolean: Boolean,
         # Binary
         postgresql.BYTEA: LargeBinary,
         LargeBinary: LargeBinary,
         # JSON
         postgresql.JSON: String,
         postgresql.JSONB: String,
         JSON: String,
         # Array
         postgresql.ARRAY: String,
         # PostgreSQL specific
         postgresql.UUID: String,
         postgresql.INET: String,
...
@@ -121,14 +135,14 @@ class TypeMapper:
     }

     SQL_TO_DATABRICKS_MAPPING = {
-        'VARCHAR': 'STRING',
-        'INTEGER': 'INT',
-        'BIGINT': 'BIGINT',  # Keep as is
-        'FLOAT': 'DOUBLE',
-        'BOOLEAN': 'BOOLEAN',  # Keep as is
-        'TIMESTAMP': 'TIMESTAMP',  # Keep as is
-        'DATETIME': 'TIMESTAMP',  # Change this mapping
-        'TEXT': 'STRING',
+        "VARCHAR": "STRING",
+        "INTEGER": "INT",
+        "BIGINT": "BIGINT",  # Keep as is
+        "FLOAT": "DOUBLE",
+        "BOOLEAN": "BOOLEAN",  # Keep as is
+        "TIMESTAMP": "TIMESTAMP",  # Keep as is
+        "DATETIME": "TIMESTAMP",  # Change this mapping
+        "TEXT": "STRING",
         # Arrays and complex types are already correct, no replacement needed
     }
...
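SQL_TO_DATABRICKS_MAPPING maps SQL type names, as plain strings, to their Databricks SQL equivalents. As a reference, here is a standalone sketch of one way such a name-based mapping can be applied to a type name; the helper below is hypothetical and not part of this module:

# Hypothetical helper for illustration only; not part of model_convertor_utils.
SQL_TO_DATABRICKS_MAPPING = {
    "VARCHAR": "STRING",
    "INTEGER": "INT",
    "FLOAT": "DOUBLE",
    "DATETIME": "TIMESTAMP",
    "TEXT": "STRING",
}

def to_databricks_type(sql_type: str) -> str:
    # Drop a length qualifier such as VARCHAR(255) before the lookup.
    base = sql_type.split("(")[0].strip().upper()
    return SQL_TO_DATABRICKS_MAPPING.get(base, base)

assert to_databricks_type("varchar(255)") == "STRING"
assert to_databricks_type("BIGINT") == "BIGINT"  # unmapped names pass through unchanged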
@@ -147,14 +161,14 @@ class TypeMapper:
         base_type = type(sql_type)

         # Handle special cases first
-        if base_type == postgresql.ARRAY or 'ARRAY' in str(sql_type):
+        if base_type == postgresql.ARRAY or "ARRAY" in str(sql_type):
             return cls._convert_array_type_fallback(sql_type)

         # Get the mapped type
         if base_type in cls.POSTGRES_TO_DATABRICKS_MAPPING:
             return cls.POSTGRES_TO_DATABRICKS_MAPPING[base_type]()

-        logging.info(f'Defaulting to String() for type: {type(sql_type)}')
+        logging.info(f"Defaulting to String() for type: {type(sql_type)}")
         return String()

     @classmethod
...
@@ -171,21 +185,21 @@ class TypeMapper:
             element_type_name = type(postgres_element_type).__name__.upper()

             # Map PostgreSQL types to Databricks array element types
-            if any(t in element_type_name for t in ['VARCHAR', 'TEXT', 'STRING', 'CHAR']):
+            if any(t in element_type_name for t in ["VARCHAR", "TEXT", "STRING", "CHAR"]):
                 return "STRING"
-            elif any(t in element_type_name for t in ['INTEGER', 'BIGINT', 'SMALLINT']):
-                return "INT" if 'SMALLINT' not in element_type_name else "SMALLINT"
-            elif 'BIGINT' in element_type_name:
+            elif any(t in element_type_name for t in ["INTEGER", "BIGINT", "SMALLINT"]):
+                return "INT" if "SMALLINT" not in element_type_name else "SMALLINT"
+            elif "BIGINT" in element_type_name:
                 return "BIGINT"
-            elif any(t in element_type_name for t in ['BOOLEAN', 'BOOL']):
+            elif any(t in element_type_name for t in ["BOOLEAN", "BOOL"]):
                 return "BOOLEAN"
-            elif any(t in element_type_name for t in ['FLOAT', 'REAL', 'DOUBLE']):
+            elif any(t in element_type_name for t in ["FLOAT", "REAL", "DOUBLE"]):
                 return "DOUBLE"
-            elif any(t in element_type_name for t in ['NUMERIC', 'DECIMAL']):
+            elif any(t in element_type_name for t in ["NUMERIC", "DECIMAL"]):
                 return "DECIMAL"
-            elif 'DATE' in element_type_name:
+            elif "DATE" in element_type_name:
                 return "DATE"
-            elif 'TIMESTAMP' in element_type_name:
+            elif "TIMESTAMP" in element_type_name:
                 return "TIMESTAMP"
             else:
                 return "STRING"  # Default fallback
...
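The fallback above keys off the element type's class name when an ARRAY column cannot be mapped directly. Reduced to a standalone function, the same matching idea looks roughly like this (illustrative only; the function name is not taken from the module):

# Illustrative sketch of the element-name matching idea; not the module's code.
def array_element_to_databricks(element_type_name: str) -> str:
    name = element_type_name.upper()
    if any(t in name for t in ("VARCHAR", "TEXT", "STRING", "CHAR")):
        return "STRING"
    if "SMALLINT" in name:
        return "SMALLINT"
    if "BIGINT" in name:
        return "BIGINT"
    if "INTEGER" in name or name == "INT":
        return "INT"
    if any(t in name for t in ("BOOLEAN", "BOOL")):
        return "BOOLEAN"
    if any(t in name for t in ("FLOAT", "REAL", "DOUBLE")):
        return "DOUBLE"
    if any(t in name for t in ("NUMERIC", "DECIMAL")):
        return "DECIMAL"
    if "TIMESTAMP" in name:
        return "TIMESTAMP"
    if "DATE" in name:
        return "DATE"
    return "STRING"  # default fallback

assert array_element_to_databricks("VARCHAR") == "STRING"
assert array_element_to_databricks("SMALLINT") == "SMALLINT"
assert array_element_to_databricks("UUID") == "STRING"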
@@ -225,7 +239,11 @@ class TypeMapper:
         default_clause = ""
         if column.default is not None:
-            default_value = column.default.arg if hasattr(column.default, 'arg') else column.default
+            default_value = (
+                column.default.arg if hasattr(column.default, "arg") else column.default
+            )
             if isinstance(default_value, str):
                 default_clause = f" DEFAULT '{default_value}'"
             elif isinstance(default_value, bool):
...
@@ -258,7 +276,7 @@ class ColumnConverter:
         columns_info = {}

         # # Check for modern SQLAlchemy with annotations and mapped_column
-        if hasattr(model_class, '__annotations__'):
+        if hasattr(model_class, "__annotations__"):
             columns_info.update(self.extract_from_annotations(model_class))

         # Fallback: Try mapper approach for traditional models
...
@@ -267,7 +285,9 @@ class ColumnConverter:
             mapper = class_mapper(model_class)
             for column_name, column in mapper.columns.items():
                 if column_name not in columns_info:
                     columns_info[column_name] = self.extract_column_properties(column)
         except Exception as e:
             logging.error(f"Failed to extract column info using mapper: {e}")
             # Final fallback: inspect class attributes directly
...
@@ -279,20 +299,22 @@ class ColumnConverter:
     def extract_column_properties(column: Any) -> Dict[str, Any]:
         """Extract properties from a column object."""
         return {
-            'type': getattr(column, 'type', None),
-            'primary_key': getattr(column, 'primary_key', False),
-            'nullable': getattr(column, 'nullable', True),
-            'default': getattr(column, 'default', None),
-            'server_default': getattr(column, 'server_default', None),
-            'uses_mapped_column': False,
+            "type": getattr(column, "type", None),
+            "primary_key": getattr(column, "primary_key", False),
+            "nullable": getattr(column, "nullable", True),
+            "default": getattr(column, "default", None),
+            "server_default": getattr(column, "server_default", None),
+            "uses_mapped_column": False,
         }

     def _extract_from_class_attributes(self, model_class: type) -> Dict[str, Dict[str, Any]]:
         """Extract column info from class attributes."""
         columns_info = {}

         for attr_name in dir(model_class):
-            if attr_name.startswith('_'):
+            if attr_name.startswith("_"):
                 continue

             attr = getattr(model_class, attr_name, None)
...
@@ -303,14 +325,14 @@ class ColumnConverter:
             if isinstance(attr, Column):
                 columns_info[attr_name] = self.extract_column_properties(attr)
             # Check for mapped_column objects
-            elif hasattr(attr, 'type') and hasattr(attr, 'nullable'):
+            elif hasattr(attr, "type") and hasattr(attr, "nullable"):
                 columns_info[attr_name] = {
-                    'type': getattr(attr, 'type', None),
-                    'primary_key': getattr(attr, 'primary_key', False),
-                    'nullable': getattr(attr, 'nullable', True),
-                    'default': getattr(attr, 'default', None),
-                    'server_default': getattr(attr, 'server_default', None),
-                    'uses_mapped_column': True,
+                    "type": getattr(attr, "type", None),
+                    "primary_key": getattr(attr, "primary_key", False),
+                    "nullable": getattr(attr, "nullable", True),
+                    "default": getattr(attr, "default", None),
+                    "server_default": getattr(attr, "server_default", None),
+                    "uses_mapped_column": True,
                 }

         return columns_info
...
@@ -320,22 +342,22 @@ class ColumnConverter:
         """Extract column info from type annotations (modern SQLAlchemy)."""
         columns_info = {}
-        annotations = getattr(model_class, '__annotations__', {})
+        annotations = getattr(model_class, "__annotations__", {})

         for attr_name, annotation in annotations.items():
             if hasattr(model_class, attr_name):
                 attr = getattr(model_class, attr_name)

                 # Check if it's a mapped_column
-                if hasattr(attr, 'type'):
+                if hasattr(attr, "type"):
                     columns_info[attr_name] = {
-                        'type': attr.type,
-                        'primary_key': getattr(attr, 'primary_key', False),
-                        'nullable': getattr(attr, 'nullable', True),
-                        'default': getattr(attr, 'default', None),
-                        'server_default': getattr(attr, 'server_default', None),
-                        'annotation': annotation,
-                        'uses_mapped_column': True,
+                        "type": attr.type,
+                        "primary_key": getattr(attr, "primary_key", False),
+                        "nullable": getattr(attr, "nullable", True),
+                        "default": getattr(attr, "default", None),
+                        "server_default": getattr(attr, "server_default", None),
+                        "annotation": annotation,
+                        "uses_mapped_column": True,
                     }

         return columns_info
...
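Both extraction paths above end up collecting the same handful of per-column properties (type, primary key, nullability, defaults). For reference, here is what those properties look like on a small, made-up SQLAlchemy 2.0 model, read off the table object:

# Illustrative model; the class and table names are invented for this example.
from typing import Optional
from sqlalchemy import String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class Base(DeclarativeBase):
    pass

class Device(Base):
    __tablename__ = "device"
    id: Mapped[int] = mapped_column(primary_key=True)
    name: Mapped[Optional[str]] = mapped_column(String(100), nullable=True)

for column in Device.__table__.columns:
    # These are the same properties extract_column_properties records per column.
    print(column.name, column.type, column.primary_key, column.nullable)
# id INTEGER True False
# name VARCHAR(100) False True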
@@ -351,36 +373,38 @@ class ColumnConverter:
             Tuple of (column_object, annotation_if_any)
         """
         # Convert the column type
-        new_type = self.type_mapper.get_databricks_type(column_info['type'])
+        new_type = self.type_mapper.get_databricks_type(column_info["type"])

         # Check if this uses type annotations (modern approach)
-        if column_info.get('uses_mapped_column', False):
+        if column_info.get("uses_mapped_column", False):
             new_column = mapped_column(
                 new_type,
-                primary_key=column_info.get('primary_key', False),
-                nullable=column_info.get('nullable', True),
-                default=column_info.get('default'),
-                server_default=column_info.get('server_default'),
+                primary_key=column_info.get("primary_key", False),
+                nullable=column_info.get("nullable", True),
+                default=column_info.get("default"),
+                server_default=column_info.get("server_default"),
             )
             # Convert annotation
-            annotation = self.convert_annotation(column_info.get('annotation'), new_type)
+            annotation = self.convert_annotation(column_info.get("annotation"), new_type)
             return new_column, annotation
         else:
             # Traditional Column approach
             new_column = Column(
                 new_type,
-                primary_key=column_info.get('primary_key', False),
-                nullable=column_info.get('nullable', True),
-                default=column_info.get('default'),
-                server_default=column_info.get('server_default'),
+                primary_key=column_info.get("primary_key", False),
+                nullable=column_info.get("nullable", True),
+                default=column_info.get("default"),
+                server_default=column_info.get("server_default"),
             )
             return new_column, None

     @staticmethod
     def convert_annotation(annotation: Any, databricks_type: Any = None) -> Any:
         """Convert type annotations for Databricks compatibility."""
-        from typing import Optional, List
+        from typing import List, Optional

         if annotation is None:
             return Mapped[Optional[str]]
...
@@ -388,15 +412,17 @@ class ColumnConverter:
         annotation_str = str(annotation)

         # Handle array/list types -> convert to List annotation
-        if any(keyword in annotation_str for keyword in ['list', 'List']):
-            if 'Optional' in annotation_str or 'Union' in annotation_str:
+        if any(keyword in annotation_str for keyword in ["list", "List"]):
+            if "Optional" in annotation_str or "Union" in annotation_str:
                 return Mapped[Optional[List[str]]]  # Default to List[str]
             else:
                 return Mapped[List[str]]

         # Handle JSON/JSONB/dict types -> convert to string
-        if any(keyword in annotation_str for keyword in ['dict', 'Dict', 'json', 'Json']):
-            if 'Optional' in annotation_str or 'Union' in annotation_str:
+        if any(
+            keyword in annotation_str for keyword in ["dict", "Dict", "json", "Json"]
+        ):
+            if "Optional" in annotation_str or "Union" in annotation_str:
                 return Mapped[Optional[str]]
             else:
                 return Mapped[str]
...
@@ -405,31 +431,38 @@ class ColumnConverter:
         if databricks_type:
             type_str = str(type(databricks_type).__name__).lower()
-            if 'array' in type_str:
+            if "array" in type_str:
                 # For array types, use List annotation
-                if 'Optional' in annotation_str:
-                    return Mapped[Optional[List[str]]]  # Could be more specific based on element type
+                if "Optional" in annotation_str:
+                    return Mapped[
+                        Optional[List[str]]
+                    ]  # Could be more specific based on element type
                 return Mapped[List[str]]
-            elif 'integer' in type_str or 'biginteger' in type_str or 'smallinteger' in type_str:
-                if 'Optional' in annotation_str:
+            elif (
+                "integer" in type_str
+                or "biginteger" in type_str
+                or "smallinteger" in type_str
+            ):
+                if "Optional" in annotation_str:
                     return Mapped[Optional[int]]
                 return Mapped[int]
-            elif 'boolean' in type_str:
-                if 'Optional' in annotation_str:
+            elif "boolean" in type_str:
+                if "Optional" in annotation_str:
                     return Mapped[Optional[bool]]
                 return Mapped[bool]
-            elif 'float' in type_str or 'numeric' in type_str:
-                if 'Optional' in annotation_str:
+            elif "float" in type_str or "numeric" in type_str:
+                if "Optional" in annotation_str:
                     return Mapped[Optional[float]]
                 return Mapped[float]
-            elif 'datetime' in type_str:
+            elif "datetime" in type_str:
                 from datetime import datetime
-                if 'Optional' in annotation_str:
+                if "Optional" in annotation_str:
                     return Mapped[Optional[datetime]]
                 return Mapped[datetime]

         # Default to string
-        if 'Optional' in annotation_str or 'Union' in annotation_str:
+        if "Optional" in annotation_str or "Union" in annotation_str:
             return Mapped[Optional[str]]
         return Mapped[str]
...
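convert_annotation above works by string-matching the original annotation and the Databricks type name. The same idea, reduced to a small standalone mapping from a SQLAlchemy type instance to a Mapped[...] annotation (illustrative only; annotation_for is hypothetical, not the module's API):

# Illustrative sketch; annotation_for is hypothetical and simplified.
from datetime import datetime
from typing import Optional
from sqlalchemy import Boolean, DateTime, Float, Integer, String
from sqlalchemy.orm import Mapped

def annotation_for(sa_type, optional: bool = True):
    python_type = {
        Integer: int,
        Boolean: bool,
        Float: float,
        DateTime: datetime,
        String: str,
    }.get(type(sa_type), str)
    return Mapped[Optional[python_type]] if optional else Mapped[python_type]

int_annotation = annotation_for(Integer())        # behaves like Mapped[Optional[int]]
str_annotation = annotation_for(String(), False)  # behaves like Mapped[str]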
@@ -442,8 +475,7 @@ class SchemaProcessor:
     @staticmethod
     def process_table_args(
-        original_table_args: Any,
-        new_schema: Optional[str] = None
+        original_table_args: Any, new_schema: Optional[str] = None
     ) -> Union[Tuple, Dict, None]:
         """
         Process table arguments, handling constraints and schema conversion.
...
@@ -457,7 +489,7 @@ class SchemaProcessor:
         """
         if not original_table_args:
             if new_schema:
-                return {'schema': new_schema}
+                return {"schema": new_schema}
             return None

         new_table_args = []
...
@@ -468,7 +500,9 @@ class SchemaProcessor:
             for arg in original_table_args:
                 if isinstance(arg, dict):
                     # Process dictionary part
-                    processed_kwargs = SchemaProcessor._process_table_kwargs(arg, new_schema)
+                    processed_kwargs = SchemaProcessor._process_table_kwargs(
+                        arg, new_schema
+                    )
                     table_kwargs.update(processed_kwargs)
                 elif isinstance(arg, (Index, ForeignKeyConstraint)):
                     continue
...
@@ -481,11 +515,13 @@ class SchemaProcessor:
         # Handle dictionary format: {'schema': 'public', 'extend_existing': True}
         elif isinstance(original_table_args, dict):
-            table_kwargs = SchemaProcessor._process_table_kwargs(original_table_args, new_schema)
+            table_kwargs = SchemaProcessor._process_table_kwargs(
+                original_table_args, new_schema
+            )

         # Add new schema if specified and not already set
-        if new_schema is not None and 'schema' not in table_kwargs:
-            table_kwargs['schema'] = new_schema
+        if new_schema is not None and "schema" not in table_kwargs:
+            table_kwargs["schema"] = new_schema

         # Construct result
         if new_table_args and table_kwargs:
...
@@ -499,16 +535,18 @@ class SchemaProcessor:
         return None

     @staticmethod
-    def _process_table_kwargs(kwargs: Dict[str, Any], new_schema: Optional[str]) -> Dict[str, Any]:
+    def _process_table_kwargs(
+        kwargs: Dict[str, Any], new_schema: Optional[str]
+    ) -> Dict[str, Any]:
         """Process table keyword arguments."""
         processed = {}

         for key, value in kwargs.items():
-            if key == 'schema':
+            if key == "schema":
                 # Use new_schema if provided, otherwise keep original unless it's 'public'
                 if new_schema is not None:
                     processed[key] = new_schema
-                elif value != 'public':
+                elif value != "public":
                     processed[key] = value
                 # Skip 'public' schema (default)
             else:
...
@@ -532,7 +570,8 @@ class ModelConverter:
         self.type_mapper = TypeMapper()
         self.column_converter = ColumnConverter(self.type_mapper)

-    def convert_model(self,
+    def convert_model(
+        self,
         postgres_model_class: Type,
         base_class: Type,
         new_table_name: Optional[str] = None,
...
@@ -552,27 +591,28 @@ class ModelConverter:
         """
         # Create base class if not provided
         # Get table information
-        original_table_name = getattr(postgres_model_class, '__tablename__', 'unknown_table')
+        original_table_name = getattr(postgres_model_class, "__tablename__", "unknown_table")
         table_name = new_table_name or original_table_name
-        table_name = f'{table_name}'
+        table_name = f"{table_name}"

         schema_processor = SchemaProcessor()

         # Create new model attributes
         new_attrs = {
-            '__tablename__': table_name,
-            '__module__': postgres_model_class.__module__,
+            "__tablename__": table_name,
+            "__module__": postgres_model_class.__module__,
         }

         # Process table arguments
-        if hasattr(postgres_model_class, '__table_args__'):
+        if hasattr(postgres_model_class, "__table_args__"):
             processed_table_args = schema_processor.process_table_args(
                 postgres_model_class.__table_args__,
                 new_schema
             )
             if processed_table_args:
-                new_attrs['__table_args__'] = processed_table_args
+                new_attrs["__table_args__"] = processed_table_args
         elif new_schema:
-            new_attrs['__table_args__'] = {'schema': new_schema}
+            new_attrs["__table_args__"] = {"schema": new_schema}

         # Extract and convert columns
         columns_info = self.column_converter.extract_column_info(postgres_model_class)
...
@@ -587,7 +627,7 @@ class ModelConverter:
         # Add annotations if any
         if annotations:
-            new_attrs['__annotations__'] = annotations
+            new_attrs["__annotations__"] = annotations

         # Create the new model class
         new_class_name = f"{postgres_model_class.__name__}Databricks"
...
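convert_model assembles a new_attrs dictionary and then creates the Databricks-flavoured class dynamically (the "{name}Databricks" class seen above). The underlying pattern, with made-up table and column names, is sketched below; a real run would go through the project's ModelConverter rather than hand-built attributes:

# Illustrative sketch of the dynamic declarative-class pattern; names are invented.
from typing import Optional
from sqlalchemy import Integer, String
from sqlalchemy.orm import DeclarativeBase, Mapped, mapped_column

class DatabricksBase(DeclarativeBase):
    pass

new_attrs = {
    "__tablename__": "device",                    # converted table name (example)
    "__table_args__": {"schema": "demo_schema"},  # converted schema (example)
    "id": mapped_column(Integer, primary_key=True),
    "name": mapped_column(String, nullable=True),
    "__annotations__": {"id": Mapped[int], "name": Mapped[Optional[str]]},
}

# type(name, bases, attrs) builds the new declarative model class at runtime.
DeviceDatabricks = type("DeviceDatabricks", (DatabricksBase,), new_attrs)
print(DeviceDatabricks.__table__)  # device table bound to the demo_schema schema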