harshavardhan.c / model-managament-databricks / Commits

Commit c741f979, authored Aug 01, 2025 by harshavardhan.c

feat: changes for including the timeseries ingestion script.

parent 36df546c

Showing 12 changed files with 307 additions and 60 deletions (+307 / -60)
agent_subscribers.py                                      +11  -2
scripts/config/__init__.py                                 +2  -1
scripts/constants/__init__.py                              +4  -0
scripts/constants/notebooks/metadata_ingestion.txt         +3  -7
scripts/constants/notebooks/timeseries_ingestion.txt     +149  -0
scripts/core/handlers/instance_handler.py                 +40  -0
scripts/core/handlers/model_creator_handler.py            +62 -38
scripts/db/databricks/__init__.py                          +6  -4
scripts/db/databricks/job_manager.py                       +1  -1
scripts/engines/agents/model_creator_agent.py              +8  -2
scripts/schemas/__init__.py                               +20  -2
scripts/utils/databricks_utils.py                          +1  -3
agent_subscribers.py

@@ -4,7 +4,7 @@ from faststream.confluent import KafkaBroker
 from scripts.config import KafkaConfig
 from scripts.engines.agents.model_creator_agent import ModelCreatorAgent
-from scripts.schemas import ModelCreatorSchema
+from scripts.schemas import ModelCreatorSchema, ModelInstanceSchema

 broker = KafkaBroker(f'{KafkaConfig.KAFKA_HOST}:{KafkaConfig.KAFKA_PORT}', client_id="model_creator_agent")
...
@@ -12,8 +12,17 @@ broker = KafkaBroker(f'{KafkaConfig.KAFKA_HOST}:{KafkaConfig.KAFKA_PORT}', clien
 @broker.subscriber(KafkaConfig.KAFKA_MODEL_CREATION_TOPIC, group_id="databricks_model_creator_agent", max_workers=2)
 async def consume_stream_for_processing_dependencies(message: dict):
     try:
-        await ModelCreatorAgent.model_trigger_agent(message=ModelCreatorSchema(meta=message))
+        await ModelCreatorAgent.model_creator_agent(message=ModelCreatorSchema(meta=message))
         return True
     except Exception as e:
         logging.error(f"Exception occurred while creating model in Databricks: {e}")
         return False
+
+
+@broker.subscriber(KafkaConfig.KAFKA_MODEL_INSTANCE_TOPIC, group_id="databricks_instance_agent", max_workers=2)
+async def consume_stream_for_processing_instances(message: dict):
+    try:
+        await ModelCreatorAgent.model_instance_agent(ModelInstanceSchema(**message))
+        return True
+    except Exception as e:
+        logging.error(f"Exception occurred while creating model in Databricks: {e}")
+        return False
\ No newline at end of file
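For context, a minimal producer sketch for the new instance topic (the payload values here are hypothetical; the subscriber constructs ModelInstanceSchema(**message) from the consumed dict):

import asyncio

from faststream.confluent import KafkaBroker

from scripts.config import KafkaConfig


async def publish_test_instance():
    broker = KafkaBroker(f'{KafkaConfig.KAFKA_HOST}:{KafkaConfig.KAFKA_PORT}')
    async with broker:
        # Keys mirror ModelInstanceSchema: project_id plus a dict (or list of dicts) under "data".
        await broker.publish(
            {"project_id": "proj_001", "data": {"type": "asset", "name": "pump_01"}},
            topic=KafkaConfig.KAFKA_MODEL_INSTANCE_TOPIC,
        )


asyncio.run(publish_test_instance())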
scripts/config/__init__.py

@@ -28,7 +28,7 @@ class _Services(BaseSettings):
 class _RedisConfig(BaseSettings):
     REDIS_URI: str
     REDIS_PROJECT_TAGS_DB: int = 18
-    REDIS_DATABRICKS_DB: int = 73
+    REDIS_DATABRICKS_DB: int = 57
     REDIS_GRAPHQL_DB: int = 37
...
@@ -70,6 +70,7 @@ class _DatabricksConfig(BaseSettings):
     DATABRICKS_PUBLIC_SCHEMA_NAME: str = Field(default="public")
     DATABRICKS_ANALYTICAL_SCHEMA_NAME: str = Field(default="analytical")
     DATABRICKS_STORAGE_FORMAT: str = Field(default="PARQUET")
+    DATABRICKS_STORAGE_PATH: str = Field(default="abfss://unity-catalog-storage@dbstoragenzxfhpgsipt5a.dfs.core.windows.net/416418955412087")

     @model_validator(mode="before")
     def prepare_databricks_uri(cls, values):
...
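Because these are pydantic settings classes, the new DATABRICKS_STORAGE_PATH default only applies when no matching environment variable is set. A minimal sketch of that behaviour, assuming pydantic v2 with pydantic-settings (in pydantic v1, BaseSettings is imported from pydantic itself); the demo class and paths are illustrative:

import os

from pydantic import Field
from pydantic_settings import BaseSettings


class _DemoDatabricksConfig(BaseSettings):
    DATABRICKS_STORAGE_PATH: str = Field(default="abfss://demo-container@demoaccount.dfs.core.windows.net/0")


print(_DemoDatabricksConfig().DATABRICKS_STORAGE_PATH)   # falls back to the Field default
os.environ["DATABRICKS_STORAGE_PATH"] = "abfss://override@demoaccount.dfs.core.windows.net/1"
print(_DemoDatabricksConfig().DATABRICKS_STORAGE_PATH)   # the environment variable wins over the default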
scripts/constants/__init__.py

+class DatabricksConstants:
+    METADATA_INGESTION_JOB_NAME = "metadata_ingestion_job"
+    METADATA_INGESTION_NOTEBOOK_NAME = "metadata_ingestion_notebook"
+    TIMESERIES_INGESTION_NOTEBOOK_NAME = "timeseries_ingestion_notebook"
\ No newline at end of file
scripts/constants/notebooks/metadata_ingestion.txt

@@ -73,10 +73,10 @@ def detect_external_table_schema(table_name):
         # Try to get schema from catalog
         table_df = spark.table(table_name)
         schema = table_df.schema
-        print(f"✓ Schema found in metastore for table: {table_name}")
+        print(f"Schema found in metastore for table: {table_name}")
         return schema
     except Exception as e:
-        print(f"X Failed to get schema from metastore for table: {table_name}")
+        print(f"Failed to get schema from metastore for table: {table_name}")
         return None

 # COMMAND ----------
...
@@ -96,8 +96,4 @@ data_df = spark.createDataFrame(table_info['data_payload'], schema=schema)
 # COMMAND ----------
-data_df.write \
-    .mode("append") \
-    .format("parquet") \
-    .save(table_info['table_properties']['table_path'])
\ No newline at end of file
+data_df.write.mode("append").saveAsTable(table_info['table_name'])
\ No newline at end of file
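The write change above is behavioural, not cosmetic: .save(path) appends bare Parquet files at a storage location without updating the metastore, while .saveAsTable(name) appends through the catalog so the rows are immediately queryable by table name. A short sketch using the notebook's own variables:

# Old: path-based write; the metastore does not know about the new files.
data_df.write.mode("append").format("parquet").save(table_info['table_properties']['table_path'])

# New: catalog-based write; appends to the registered table by name.
data_df.write.mode("append").saveAsTable(table_info['table_name'])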
scripts/constants/notebooks/timeseries_ingestion.txt  (new file, 0 → 100644)
# Databricks notebook source
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
import json
spark = SparkSession.builder.appName("StreamingIoTPipeline").getOrCreate()
spark.sparkContext.setLogLevel("WARN")
# COMMAND ----------
# Input Parameters
event_hub_connection_string = {{event_hub_connection_string}}
timeseries_table_path = {{timeseries_table_path}}
project_levels = {{project_levels}}
# COMMAND ----------
event_hub_conf = {
'eventhubs.connectionString': spark._jvm.org.apache.spark.eventhubs.EventHubsUtils.encrypt(event_hub_connection_string),
'eventhubs.consumerGroup': '$Default'
}
# COMMAND ----------
message_schema = StructType([
StructField("data", StructType([
StructField("tag", StringType(), False),
StructField("dq", IntegerType(), True),
StructField("ta", StringType(), True),
StructField("val", DoubleType(), False)
]), True),
StructField("a_id", StringType(), True),
StructField("d_id", StringType(), True),
StructField("gw_id", StringType(), True),
StructField("msg_id", IntegerType(), True),
StructField("p_id", StringType(), True),
StructField("pd_id", StringType(), True),
StructField("retain_flag", BooleanType(), True),
StructField("site_id", StringType(), True),
StructField("timestamp", LongType(), False),
StructField("ver", DoubleType(), True)
])
# COMMAND ----------
def safe_get_item(array_col, index):
return when(size(array_col) > index, array_col.getItem(index)).otherwise(lit(None))
def transform_timeseries_data_fully_dynamic(df, max_tag_parts=4):
print(f"Transforming to target schema with up to {max_tag_parts} tag parts...")
df_with_split = df.withColumn("tag_parts", split(col("data.tag"), "\\$"))
df_with_split = df_with_split.withColumn("tag_parts_count", size(col("tag_parts")))
df_with_split = df_with_split.withColumn("hierarchy_levels", slice(col("tag_parts"), 1, size(col("tag_parts")) - 1))
df_with_split = df_with_split.withColumn("levels_without_ast", expr("filter(hierarchy_levels, x -> NOT x LIKE '%ast%')"))
df_with_split = df_with_split.withColumn("ast", expr("filter(hierarchy_levels, x -> x LIKE '%ast%')[0]"))
value_type_logic = when(
col("data.val").cast("float").isNotNull() & ~isnan(col("data.val").cast("float")),
lit("float")
).otherwise(lit("string"))
select_columns = [
col("timestamp").alias("timestamp"),
from_unixtime(col("timestamp") / 1000).cast("timestamp").alias("dt_timestamp"),
to_date(from_unixtime(col("timestamp") / 1000)).alias("dt_date"),
hour(from_unixtime(col("timestamp") / 1000)).alias("dt_hour"),
col("data.val").cast("string").alias("value"),
value_type_logic.alias("value_type"),
col("data.tag").alias("c3"),
safe_get_item(col("tag_parts"), 0).alias("c1"),
when(col("tag_parts_count") > 0, col("tag_parts").getItem(col("tag_parts_count") - 1)).otherwise(lit(None)).alias("c5"),
col("data.dq").cast("string").alias("Q"),
col("data.ta").alias("T"),
col("d_id").alias("D"),
col("p_id").alias("P"),
col("a_id").alias("A"),
lit(None).cast("string").alias("B")
]
select_columns += [
safe_get_item(col("levels_without_ast"), i).alias(f"l{i+1}")
for i in range(max_tag_parts)
] + [col("ast").alias("ast")]
return df_with_split.select(*select_columns)
# COMMAND ----------
raw_stream_df = spark.readStream \
.format("eventhubs") \
.options(**event_hub_conf) \
.load()
#Binary -> String
json_df = raw_stream_df.withColumn("json_string", col("body").cast("string"))
#JSON -> Struct
parsed_stream_df = json_df.select(
from_json(col("json_string"), message_schema).alias("parsed_data")
).select("parsed_data.*")
#Explode
df_exploded = parsed_stream_df.select(
explode(col("data")).alias("tag", "value"),
col("a_id"),
col("d_id"),
col("gw_id"),
col("msg_id"),
col("p_id"),
col("pd_id"),
col("retain_flag"),
col("site_id"),
col("timestamp"),
col("ver")
)
# COMMAND ----------
transformed_df = transform_timeseries_data_fully_dynamic(df_exploded, max_tag_parts=project_levels)
# COMMAND ----------
# Option A: Write to Delta
transformed_df.writeStream \
.format("delta") \
.outputMode("append") \
.partitionBy("dt_date", "dt_hour", "c3") \
.option("checkpointLocation", "/mnt/checkpoints/timeseries_data") \
.start(timeseries_table_path)
# COMMAND ----------
# # Option B: Write to Parquet (same as your batch)
# transformed_df.writeStream \
# .format("parquet") \
# .outputMode("append") \
# .partitionBy("dt_date", "dt_hour", "c3") \
# .option("checkpointLocation", "/mnt/checkpoints/timeseries_data") \
# .start(timeseries_table_path)
# COMMAND ----------
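To make the tag handling in transform_timeseries_data_fully_dynamic concrete, here is a small batch-mode sketch of just the split/filter steps, using a hypothetical '$'-delimited tag with an 'ast' (asset) segment:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col, expr, size, slice, split

spark = SparkSession.builder.appName("TagSplitSketch").getOrCreate()

df = spark.createDataFrame([("site01$area02$ast_17$temp_sensor",)], ["tag"])
df = (
    df.withColumn("tag_parts", split(col("tag"), "\\$"))
      # everything before the last part is the hierarchy; the last part becomes c5
      .withColumn("hierarchy_levels", slice(col("tag_parts"), 1, size(col("tag_parts")) - 1))
      .withColumn("levels_without_ast", expr("filter(hierarchy_levels, x -> NOT x LIKE '%ast%')"))
      .withColumn("ast", expr("filter(hierarchy_levels, x -> x LIKE '%ast%')[0]"))
)
df.show(truncate=False)
# tag_parts          = [site01, area02, ast_17, temp_sensor]
# hierarchy_levels   = [site01, area02, ast_17]
# levels_without_ast = [site01, area02]   -> mapped to l1, l2
# ast                = ast_17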
scripts/core/handlers/instance_handler.py  (new file, 0 → 100644)

import json

from ut_dev_utils import get_db_name

from scripts.config import DatabricksConfig
from scripts.db.databricks.job_manager import DatabricksJobManager
from scripts.db.redis.databricks_details import databricks_details_db
from scripts.schemas import ModelInstanceSchema


class ModelInstanceHandler:
    def __init__(self, project_id: str, payload: ModelInstanceSchema):
        self.project_id = project_id
        self.payload = payload
        self.catalog_name = get_db_name(project_id=project_id, database=DatabricksConfig.DATABRICKS_CATALOG_NAME)
        self.job_manager = DatabricksJobManager(databricks_host=payload.databricks_host,
                                                access_token=payload.databricks_access_token)

    def upload_instances_to_unity_catalog(self):
        job_id = databricks_details_db.hget(self.project_id, "metadata_ingestion_job")
        if not job_id:
            raise ValueError("No job id found for metadata ingestion job, skipping upload to unity catalog")
        run_id = self.job_manager.run_job(job_id=job_id,
                                          parameters={"input_message": json.dumps(self.get_job_trigger_payload())})
        if not run_id:
            raise ValueError("Failed to run metadata ingestion job, skipping upload to unity catalog")

    def get_job_trigger_payload(self):
        table_name = self.payload.data[0]['type']
        schema_table = f"{DatabricksConfig.DATABRICKS_PUBLIC_SCHEMA_NAME}.{table_name}"
        return {
            "table_properties": {
                "table_name": f'{self.catalog_name}.{schema_table}',
                "table_path": f'{self.payload.databricks_storage_path}/{self.catalog_name}/DELTA/{schema_table}',
            },
            "project_id": self.project_id,
            "data": self.payload.data
        }
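For reference, get_job_trigger_payload() produces a dict of this shape, which run_job serializes into the job's "input_message" notebook parameter (the catalog and table names below are hypothetical; the real catalog name comes from get_db_name):

expected_payload = {
    "table_properties": {
        "table_name": "proj_001_catalog.public.asset",
        "table_path": "abfss://unity-catalog-storage@dbstoragenzxfhpgsipt5a.dfs.core.windows.net/416418955412087/proj_001_catalog/DELTA/public.asset",
    },
    "project_id": "proj_001",
    "data": [{"type": "asset", "name": "pump_01"}],
}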
scripts/core/handlers/model_creator_handler.py

-import json
 import logging

 from sqlalchemy import MetaData
...
@@ -6,6 +5,7 @@ from sqlalchemy.orm import declarative_base
 from ut_sql_utils.asyncio.declarative_utils import DeclarativeUtils

 from scripts.config import DatabricksConfig
+from scripts.constants import DatabricksConstants
 from scripts.db.databricks import DataBricksSQLLayer
 from scripts.db.databricks.job_manager import DatabricksJobManager
 from scripts.db.databricks.notebook_manager import NotebookManager
...
@@ -35,7 +35,7 @@ class ModelCreatorHandler:
             schema=message.schema
         )
-        self.external_location = f"abfss://unity-catalog-storage@dbstoragenzxfhpgsipt5a.dfs.core.windows.net/416418955412087"
+        self.external_location = self.message.databricks_storage_path

     @staticmethod
     def create_schema_base(schema_name: str):
...
@@ -46,33 +46,34 @@ class ModelCreatorHandler:
     async def create_models_in_unity_catalog(self):
         overall_tables = self.get_overall_tables()
-        schema = f'{self.databricks_sql_obj.catalog_name}.{self.message.schema}'
-        base = self.create_schema_base(schema)
+        project_levels = project_template_keys(self.meta.project_id, levels=True)
+        base = self.create_schema_base(schema_name=f'{self.databricks_sql_obj.catalog_name}.{self.message.schema}')
         try:
-            self.databricks_sql_obj.connect_to_databricks()
-            # _ = self.setup_dependencies_for_unity_catalog()
+            # self.databricks_sql_obj.connect_to_databricks()
+            _ = self.setup_dependencies_for_unity_catalog()
             table_properties = self.fetch_table_properties()
-            for table in overall_tables:
-                table_class = self.declarative_utils.get_declarative_class(table)
-                if not table_class:
-                    logging.error(f"Table class not found for table: {table}")
-                    return False
-                new_model = self.model_convertor.convert_model(
-                    table_class,
-                    base_class=base,
-                    new_schema=self.message.schema,
-                )
-                self.databricks_sql_obj.create_external_table_from_structure(
-                    table=new_model.__table__,
-                    file_format="DELTA",
-                    external_location=self.external_location,
-                    table_properties=table_properties
-                )
-            self.databricks_sql_obj.create_timeseries_table(columns=project_levels, external_location=self.external_location)
-            self.setup_notepads_and_jobs()
+            # for table in overall_tables:
+            #     table_class = self.declarative_utils.get_declarative_class(table)
+            #     if not table_class:
+            #         logging.error(f"Table class not found for table: {table}")
+            #         return False
+            #     new_model = self.model_convertor.convert_model(
+            #         table_class,
+            #         base_class=base,
+            #         new_schema=self.message.schema,
+            #     )
+            #
+            #     self.databricks_sql_obj.create_external_table_from_structure(
+            #         table=new_model.__table__,
+            #         file_format="DELTA",
+            #         external_location=self.external_location,
+            #         table_properties=table_properties
+            #     )
+            ts_external_table = self.databricks_sql_obj.create_timeseries_table(columns=project_levels, external_location=self.external_location)
+            self.setup_notepads_and_jobs(timeseries_table_path=ts_external_table, project_levels=project_levels)
             return True
         except Exception as e:
             logging.error(f"Error occurred while creating models in Unity Catalog: {e}")
             return False
...
@@ -93,7 +94,8 @@ class ModelCreatorHandler:
         Args:
             analytical (bool): Flag to indicate if the setup is for analytical or not
         """
-        logging.info(f"Setting up catalog '{DatabricksConfig.DATABRICKS_CATALOG_NAME}'")
+        logging.info(f"Setting up catalog '{DatabricksConfig.DATABRICKS_CATALOG_NAME}' for project '{self.meta.project_id}'")
         self.databricks_sql_obj.connect_to_databricks()
         # Create catalog
         catalog_success = self.databricks_sql_obj.create_catalog(
...
@@ -111,23 +113,45 @@ class ModelCreatorHandler:
             return False
         return True

-    def setup_notepads_and_jobs(self):
+    def setup_notepads_and_jobs(self, timeseries_table_path: str, project_levels: dict):
+        """
+        Args:
+            timeseries_table_path: Path for the timeseries table
+            project_levels: List of project levels
+        """
         logging.info("Setting up notepads and jobs")
         with open(r"scripts/constants/notebooks/metadata_ingestion.txt", "r") as f:
             notebook_code = f.read()
+        # # Notebook for metadata ingestion
+        # self.notebook_manager.create_notebook(
+        #     notebook_path=f"/Users/{self.message.databricks_user_email}/{self.meta.project_id}_{DatabricksConstants.METADATA_INGESTION_NOTEBOOK_NAME}",
+        #     notebook_code=notebook_code,
+        #     overwrite=True
+        # )
+        # # Job for metadata ingestion used by model management
+        # job_id = self.job_manager.create_job(job_config=self.job_manager.create_job_config_for_serverless(
+        #     job_name=f'{self.meta.project_id}_{DatabricksConstants.METADATA_INGESTION_JOB_NAME}',
+        #     notebook_path=f"/Users/{self.message.databricks_user_email}/metadata_ingestion_notebook",
+        # ))
+        #
+        # databricks_details_db.hset(self.meta.project_id, DatabricksConstants.METADATA_INGESTION_JOB_NAME, job_id)

+        # Timeseries DataPush Notebook
+        with open(r"scripts/constants/notebooks/timeseries_ingestion.txt", "r") as f:
+            notebook_code_for_timeseries = f.read()
+        notebook_code_for_timeseries = notebook_code_for_timeseries.replace("{{timeseries_table_path}}", f'"{timeseries_table_path}"')
+        notebook_code_for_timeseries = notebook_code_for_timeseries.replace("{{project_levels}}", str(len(project_levels) - 1))
+        notebook_code_for_timeseries = notebook_code_for_timeseries.replace("{{event_hub_connection_string}}", f'"{self.meta.project_id}"')
         self.notebook_manager.create_notebook(
-            notebook_path=f"/Users/{self.message.databricks_user_email}/metadata_ingestion_notebook",
-            notebook_code=notebook_code,
+            notebook_path=f"/Users/{self.message.databricks_user_email}/{self.meta.project_id}_{DatabricksConstants.TIMESERIES_INGESTION_NOTEBOOK_NAME}",
+            notebook_code=notebook_code_for_timeseries,
             overwrite=True
         )
-        # Job for metadata ingestion used by model management
-        job_id = self.job_manager.create_job(job_config=self.job_manager.create_job_config_for_serverless(
-            job_name="metadata_ingestion_job",
-            notebook_path=f"/Users/{self.message.databricks_user_email}/metadata_ingestion_notebook",
-        ))
-        redis_dict = {"metadata_ingestion_job": job_id}
-        databricks_details_db.hset(self.meta.project_id, json.dumps(redis_dict))

     @staticmethod
     def fetch_table_properties(file_format: str = 'DELTA'):
...
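The {{...}} placeholders in the notebook templates are rendered with plain str.replace before upload; a standalone sketch of that step with hypothetical values:

template = (
    "event_hub_connection_string = {{event_hub_connection_string}}\n"
    "timeseries_table_path = {{timeseries_table_path}}\n"
    "project_levels = {{project_levels}}\n"
)

project_levels = ["site", "area", "line", "machine"]  # hypothetical level keys
rendered = template.replace("{{timeseries_table_path}}", '"abfss://demo/catalog/DELTA/timeseries"')
rendered = rendered.replace("{{project_levels}}", str(len(project_levels) - 1))
rendered = rendered.replace("{{event_hub_connection_string}}", '"proj_001"')
print(rendered)
# event_hub_connection_string = "proj_001"
# timeseries_table_path = "abfss://demo/catalog/DELTA/timeseries"
# project_levels = 3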
scripts/db/databricks/__init__.py

@@ -9,14 +9,13 @@ from scripts.utils.model_convertor_utils import TypeMapper
 class DataBricksSQLLayer(DatabricksSQLUtility):
     def __init__(self, catalog_name: str, project_id: str, schema: str):
         super().__init__(catalog_name, project_id)
-        self.catalog_name = catalog_name
         self.schema = schema

     def create_external_table_from_structure(self, table: Table, external_location: str, file_format: str = "PARQUET",
-                                             table_properties: Dict[str, str] = None, partition_columns: list = None):
+                                             table_properties: Dict[str, str] = None, partition_columns: list = None) -> str:
         """
         Create an external table from a model class.
...
@@ -26,6 +25,9 @@ class DataBricksSQLLayer(DatabricksSQLUtility):
             file_format: The file format of the data files.
             table_properties: Additional table properties.
             partition_columns: List of columns to partition the table by.
+        Returns:
+            External Location - Returns the external location
         """
         schema_table = f"{table.schema}.{table.name}" if table.schema else table.name
         columns_sql = TypeMapper().extract_columns_without_constraints(table)
...
@@ -47,7 +49,7 @@ class DataBricksSQLLayer(DatabricksSQLUtility):
         create_sql = "\n".join(sql_parts)
         self.execute_sql_statement(create_sql)
-        return True
+        return external_location

     def create_timeseries_table(self, columns: List[str], external_location: str):
         """
...
@@ -91,4 +93,4 @@ class DataBricksSQLLayer(DatabricksSQLUtility):
             partition_columns=partition_columns,
             table_properties=table_properties
         )
-        return table_obj
+        return external_location
scripts/db/databricks/job_manager.py

@@ -54,7 +54,7 @@ class DatabricksJobManager:
         if parameters:
             payload["notebook_params"] = parameters
-        response = HTTPXRequestHandler(url).post(headers=self.headers, json=payload)
+        response = HTTPXRequestUtil(url).post(headers=self.headers, json=payload)
         if response.status_code == 200:
             run_id = response.json()['run_id']
...
scripts/engines/agents/model_creator_agent.py

 from ut_sql_utils.asyncio.declarative_utils import DeclarativeUtilsFactory

+from scripts.core.handlers.instance_handler import ModelInstanceHandler
 from scripts.core.handlers.model_creator_handler import ModelCreatorHandler
 from scripts.db.psql import session_manager
-from scripts.schemas import ModelCreatorSchema
+from scripts.schemas import ModelCreatorSchema, ModelInstanceSchema


 class ModelCreatorAgent:
...
@@ -10,7 +11,7 @@ class ModelCreatorAgent:
     ...
     @staticmethod
-    async def model_trigger_agent(message: ModelCreatorSchema):
+    async def model_creator_agent(message: ModelCreatorSchema):
         declarative_utils = await DeclarativeUtilsFactory.get_declarative_utils(
             raw_database="unified_model",
             project_id=message.meta.project_id,
...
@@ -19,3 +20,8 @@ class ModelCreatorAgent:
         )
         model_cal_obj = ModelCreatorHandler(message=message, declarative_utils=declarative_utils)
         await model_cal_obj.create_models_in_unity_catalog()
+
+    @staticmethod
+    async def model_instance_agent(message: ModelInstanceSchema):
+        model_instance_obj = ModelInstanceHandler(project_id=message.project_id, payload=message)
+        await model_instance_obj.upload_instances_to_unity_catalog()
\ No newline at end of file
scripts/schemas/__init__.py

-from typing import Optional
+from typing import Optional, Union, Dict, Any, List

-from pydantic import BaseModel
+from pydantic import BaseModel, model_validator
 from ut_security_util import MetaInfoSchema

 from scripts.config import DatabricksConfig
...
@@ -14,4 +14,22 @@ class ModelCreatorSchema(BaseModel):
     databricks_access_token: str = DatabricksConfig.DATABRICKS_ACCESS_TOKEN
     databricks_http_path: str = DatabricksConfig.DATABRICKS_HTTP_PATH
     databricks_user_email: str = "aniket.dhale@ilenscloud.onmicrosoft.com"
+    databricks_storage_path: str = DatabricksConfig.DATABRICKS_STORAGE_PATH
+
+
+class ModelInstanceSchema(BaseModel):
+    data: Union[Dict[str, Any], List[Dict[str, Any]]]
+    project_id: str
+    schema: Optional[str] = DatabricksConfig.DATABRICKS_PUBLIC_SCHEMA_NAME
+    databricks_host: str = DatabricksConfig.DATABRICKS_HOST
+    databricks_port: int = DatabricksConfig.DATABRICKS_PORT
+    databricks_access_token: str = DatabricksConfig.DATABRICKS_ACCESS_TOKEN
+    databricks_http_path: str = DatabricksConfig.DATABRICKS_HTTP_PATH
+    databricks_user_email: str = "aniket.dhale@ilenscloud.onmicrosoft.com"
+    databricks_storage_path: str = DatabricksConfig.DATABRICKS_STORAGE_PATH
+
+    @model_validator(mode="before")
+    def validate_data(cls, values: Dict[str, Any]) -> Dict[str, Any]:
+        if 'data' in values and isinstance(values['data'], dict):
+            values['data'] = [values['data']]
+        return values
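A quick check of the new validator (field values are hypothetical, and the Databricks defaults are assumed to resolve from config): a single dict under data is coerced into a one-element list, so downstream code can always index data[0].

from scripts.schemas import ModelInstanceSchema

instance = ModelInstanceSchema(project_id="proj_001", data={"type": "asset", "name": "pump_01"})
print(instance.data)  # [{'type': 'asset', 'name': 'pump_01'}]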
scripts/utils/databricks_utils.py

@@ -15,9 +15,7 @@ class DatabricksSQLUtility:
             catalog_name: Name of the catalog to create
             project_id: Project ID
         """
-        # self.catalog_name = get_db_name(project_id=project_id, database=catalog_name)
-        self.catalog_name = catalog_name
-        # self.catalog_name = catalog_name
+        self.catalog_name = get_db_name(project_id=project_id, database=catalog_name)
         self.engine = None

     def connect_to_databricks(self):
...