Commit c741f979 authored by harshavardhan.c

feat: add the timeseries ingestion script.

parent 36df546c
@@ -4,7 +4,7 @@ from faststream.confluent import KafkaBroker
from scripts.config import KafkaConfig
from scripts.engines.agents.model_creator_agent import ModelCreatorAgent
-from scripts.schemas import ModelCreatorSchema
+from scripts.schemas import ModelCreatorSchema, ModelInstanceSchema
broker = KafkaBroker(f'{KafkaConfig.KAFKA_HOST}:{KafkaConfig.KAFKA_PORT}', client_id="model_creator_agent")
@@ -12,8 +12,17 @@ broker = KafkaBroker(f'{KafkaConfig.KAFKA_HOST}:{KafkaConfig.KAFKA_PORT}', clien
@broker.subscriber(KafkaConfig.KAFKA_MODEL_CREATION_TOPIC, group_id="databricks_model_creator_agent", max_workers=2)
async def consume_stream_for_processing_dependencies(message: dict):
    try:
-        await ModelCreatorAgent.model_trigger_agent(message=ModelCreatorSchema(meta=message))
+        await ModelCreatorAgent.model_creator_agent(message=ModelCreatorSchema(meta=message))
        return True
    except Exception as e:
        logging.error(f"Exception occurred while creating model in Databricks: {e}")
        return False
@broker.subscriber(KafkaConfig.KAFKA_MODEL_INSTANCE_TOPIC, group_id="databricks_instance_agent", max_workers=2)
async def consume_stream_for_processing_instances(message: dict):
try:
await ModelCreatorAgent.model_instance_agent(ModelInstanceSchema(**message))
return True
except Exception as e:
        logging.error(f"Exception occurred while creating model instance in Databricks: {e}")
return False
\ No newline at end of file
@@ -28,7 +28,7 @@ class _Services(BaseSettings):
class _RedisConfig(BaseSettings):
    REDIS_URI: str
    REDIS_PROJECT_TAGS_DB: int = 18
-    REDIS_DATABRICKS_DB: int = 73
+    REDIS_DATABRICKS_DB: int = 57
    REDIS_GRAPHQL_DB: int = 37
@@ -70,6 +70,7 @@ class _DatabricksConfig(BaseSettings):
    DATABRICKS_PUBLIC_SCHEMA_NAME: str = Field(default="public")
    DATABRICKS_ANALYTICAL_SCHEMA_NAME: str = Field(default="analytical")
    DATABRICKS_STORAGE_FORMAT: str = Field(default="PARQUET")
    DATABRICKS_STORAGE_PATH: str = Field(default="abfss://unity-catalog-storage@dbstoragenzxfhpgsipt5a.dfs.core.windows.net/416418955412087")
    @model_validator(mode="before")
    def prepare_databricks_uri(cls, values):
...
class DatabricksConstants:
METADATA_INGESTION_JOB_NAME = "metadata_ingestion_job"
METADATA_INGESTION_NOTEBOOK_NAME = "metadata_ingestion_notebook"
TIMESERIES_INGESTION_NOTEBOOK_NAME = "timeseries_ingestion_notebook"
\ No newline at end of file
@@ -73,10 +73,10 @@ def detect_external_table_schema(table_name):
        # Try to get schema from catalog
        table_df = spark.table(table_name)
        schema = table_df.schema
        print(f"Schema found in metastore for table: {table_name}")
        return schema
    except Exception as e:
-        print(f"X Failed to get schema from metastore for table: {table_name}")
+        print(f"Failed to get schema from metastore for table: {table_name}")
        return None
# COMMAND ----------
@@ -96,8 +96,4 @@ data_df = spark.createDataFrame(table_info['data_payload'], schema=schema)
# COMMAND ----------
-data_df.write \
-    .mode("append") \
-    .format("parquet") \
-    .save(table_info['table_properties']['table_path'])
+data_df.write.mode("append").saveAsTable(table_info['table_name'])
\ No newline at end of file
# Databricks notebook source
from pyspark.sql import SparkSession
from pyspark.sql.functions import *
from pyspark.sql.types import *
import json
spark = SparkSession.builder.appName("StreamingIoTPipeline").getOrCreate()
spark.sparkContext.setLogLevel("WARN")
# COMMAND ----------
# Input Parameters
event_hub_connection_string = {{event_hub_connection_string}}
timeseries_table_path = {{timeseries_table_path}}
project_levels = {{project_levels}}
# COMMAND ----------
event_hub_conf = {
'eventhubs.connectionString': spark._jvm.org.apache.spark.eventhubs.EventHubsUtils.encrypt(event_hub_connection_string),
'eventhubs.consumerGroup': '$Default'
}
# COMMAND ----------
message_schema = StructType([
StructField("data", StructType([
StructField("tag", StringType(), False),
StructField("dq", IntegerType(), True),
StructField("ta", StringType(), True),
StructField("val", DoubleType(), False)
]), True),
StructField("a_id", StringType(), True),
StructField("d_id", StringType(), True),
StructField("gw_id", StringType(), True),
StructField("msg_id", IntegerType(), True),
StructField("p_id", StringType(), True),
StructField("pd_id", StringType(), True),
StructField("retain_flag", BooleanType(), True),
StructField("site_id", StringType(), True),
StructField("timestamp", LongType(), False),
StructField("ver", DoubleType(), True)
])
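# Illustrative (assumed) message body matching message_schema above; the field
# values are examples only, not taken from a real device:
# {"data": {"tag": "siteA$area1$ast_pump01$temperature", "dq": 192, "ta": null, "val": 72.4},
#  "a_id": "a1", "d_id": "d1", "gw_id": "gw1", "msg_id": 10, "p_id": "p1", "pd_id": "pd1",
#  "retain_flag": false, "site_id": "s1", "timestamp": 1718000000000, "ver": 1.0}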
# COMMAND ----------
def safe_get_item(array_col, index):
return when(size(array_col) > index, array_col.getItem(index)).otherwise(lit(None))
def transform_timeseries_data_fully_dynamic(df, max_tag_parts=4):
print(f"Transforming to target schema with up to {max_tag_parts} tag parts...")
df_with_split = df.withColumn("tag_parts", split(col("data.tag"), "\\$"))
df_with_split = df_with_split.withColumn("tag_parts_count", size(col("tag_parts")))
df_with_split = df_with_split.withColumn("hierarchy_levels", slice(col("tag_parts"), 1, size(col("tag_parts")) - 1))
df_with_split = df_with_split.withColumn("levels_without_ast", expr("filter(hierarchy_levels, x -> NOT x LIKE '%ast%')"))
df_with_split = df_with_split.withColumn("ast", expr("filter(hierarchy_levels, x -> x LIKE '%ast%')[0]"))
value_type_logic = when(
col("data.val").cast("float").isNotNull() & ~isnan(col("data.val").cast("float")),
lit("float")
).otherwise(lit("string"))
select_columns = [
col("timestamp").alias("timestamp"),
from_unixtime(col("timestamp") / 1000).cast("timestamp").alias("dt_timestamp"),
to_date(from_unixtime(col("timestamp") / 1000)).alias("dt_date"),
hour(from_unixtime(col("timestamp") / 1000)).alias("dt_hour"),
col("data.val").cast("string").alias("value"),
value_type_logic.alias("value_type"),
col("data.tag").alias("c3"),
safe_get_item(col("tag_parts"), 0).alias("c1"),
when(col("tag_parts_count") > 0, col("tag_parts").getItem(col("tag_parts_count") - 1)).otherwise(lit(None)).alias("c5"),
col("data.dq").cast("string").alias("Q"),
col("data.ta").alias("T"),
col("d_id").alias("D"),
col("p_id").alias("P"),
col("a_id").alias("A"),
lit(None).cast("string").alias("B")
]
select_columns += [
safe_get_item(col("levels_without_ast"), i).alias(f"l{i+1}")
for i in range(max_tag_parts)
] + [col("ast").alias("ast")]
return df_with_split.select(*select_columns)
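# The resulting frame carries: timestamp, dt_timestamp, dt_date, dt_hour, value, value_type,
# c3 (full tag), c1 (first tag part), c5 (last tag part), Q, T, D, P, A, B,
# l1..l<max_tag_parts> (hierarchy levels with the 'ast' part filtered out) and ast.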
# COMMAND ----------
raw_stream_df = spark.readStream \
.format("eventhubs") \
.options(**event_hub_conf) \
.load()
#Binary -> String
json_df = raw_stream_df.withColumn("json_string", col("body").cast("string"))
#JSON -> Struct
parsed_stream_df = json_df.select(
from_json(col("json_string"), message_schema).alias("parsed_data")
).select("parsed_data.*")
# Select fields (note: "data" is a struct per message_schema, so it is kept as-is
# rather than exploded; the transform below reads data.tag / data.val directly)
df_exploded = parsed_stream_df.select(
    col("data"),
col("a_id"),
col("d_id"),
col("gw_id"),
col("msg_id"),
col("p_id"),
col("pd_id"),
col("retain_flag"),
col("site_id"),
col("timestamp"),
col("ver")
)
# COMMAND ----------
transformed_df = transform_timeseries_data_fully_dynamic(df_exploded, max_tag_parts=project_levels)
# COMMAND ----------
# Option A: Write to Delta
transformed_df.writeStream \
.format("delta") \
.outputMode("append") \
.partitionBy("dt_date", "dt_hour", "c3") \
.option("checkpointLocation", "/mnt/checkpoints/timeseries_data") \
    .start(timeseries_table_path)
# COMMAND ----------
# # Option B: Write to Parquet (same as the batch ingestion path)
# transformed_df.writeStream \
# .format("parquet") \
# .outputMode("append") \
# .partitionBy("dt_date", "dt_hour", "c3") \
# .option("checkpointLocation", "/mnt/checkpoints/timeseries_data") \
# .start(timeseries_table_path)
# COMMAND ----------
import json
from ut_dev_utils import get_db_name
from scripts.config import DatabricksConfig
from scripts.db.databricks.job_manager import DatabricksJobManager
from scripts.db.redis.databricks_details import databricks_details_db
from scripts.schemas import ModelInstanceSchema
class ModelInstanceHandler:
def __init__(self, project_id: str, payload: ModelInstanceSchema):
self.project_id = project_id
self.payload = payload
self.catalog_name = get_db_name(project_id=project_id, database=DatabricksConfig.DATABRICKS_CATALOG_NAME)
self.job_manager = DatabricksJobManager(
databricks_host=payload.databricks_host,
access_token=payload.databricks_access_token
)
def upload_instances_to_unity_catalog(self):
job_id = databricks_details_db.hget(self.project_id, "metadata_ingestion_job")
if not job_id:
raise ValueError("No job id found for metadata ingestion job, skipping upload to unity catalog")
run_id = self.job_manager.run_job(job_id=job_id,
parameters={"input_message": json.dumps(self.get_job_trigger_payload())})
if not run_id:
raise ValueError("Failed to run metadata ingestion job, skipping upload to unity catalog")
def get_job_trigger_payload(self):
table_name = self.payload.data[0]['type']
schema_table = f"{DatabricksConfig.DATABRICKS_PUBLIC_SCHEMA_NAME}.{table_name}"
return {
"table_properties": {
"table_name": f'{self.catalog_name}.{schema_table}',
"table_path": f'{self.payload.databricks_storage_path}/{self.catalog_name}/DELTA/{schema_table}',
},
"project_id": self.project_id,
"data": self.payload.data
}
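For reference, a minimal sketch of the notebook parameters that get_job_trigger_payload() assembles for the job run; the catalog name, storage path, and data entries below are placeholders, not real configuration:
# Hypothetical example only; values are placeholders for illustration.
example_payload = {
    "table_properties": {
        "table_name": "catalog_p1.public.asset",
        "table_path": "abfss://<container>@<account>.dfs.core.windows.net/<dir>/catalog_p1/DELTA/public.asset",
    },
    "project_id": "p1",
    "data": [{"type": "asset", "name": "Pump-101"}],
}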
-import json
import logging
from sqlalchemy import MetaData
@@ -6,6 +5,7 @@ from sqlalchemy.orm import declarative_base
from ut_sql_utils.asyncio.declarative_utils import DeclarativeUtils
from scripts.config import DatabricksConfig
from scripts.constants import DatabricksConstants
from scripts.db.databricks import DataBricksSQLLayer
from scripts.db.databricks.job_manager import DatabricksJobManager
from scripts.db.databricks.notebook_manager import NotebookManager
@@ -35,7 +35,7 @@ class ModelCreatorHandler:
            schema=message.schema
        )
-        self.external_location = f"abfss://unity-catalog-storage@dbstoragenzxfhpgsipt5a.dfs.core.windows.net/416418955412087"
+        self.external_location = self.message.databricks_storage_path
    @staticmethod
    def create_schema_base(schema_name: str):
@@ -46,33 +46,34 @@ class ModelCreatorHandler:
    async def create_models_in_unity_catalog(self):
        overall_tables = self.get_overall_tables()
        project_levels = project_template_keys(self.meta.project_id, levels=True)
-        schema = f'{self.databricks_sql_obj.catalog_name}.{self.message.schema}'
-        base = self.create_schema_base(schema)
+        base = self.create_schema_base(schema_name=f'{self.databricks_sql_obj.catalog_name}.{self.message.schema}')
        try:
-            self.databricks_sql_obj.connect_to_databricks()
-            # _ = self.setup_dependencies_for_unity_catalog()
+            # self.databricks_sql_obj.connect_to_databricks()
+            _ = self.setup_dependencies_for_unity_catalog()
            table_properties = self.fetch_table_properties()
-            for table in overall_tables:
-                table_class = self.declarative_utils.get_declarative_class(table)
-                if not table_class:
-                    logging.error(f"Table class not found for table: {table}")
-                    return False
-                new_model = self.model_convertor.convert_model(
-                    table_class,
-                    base_class=base,
-                    new_schema=self.message.schema,
-                )
-                self.databricks_sql_obj.create_external_table_from_structure(
-                    table=new_model.__table__,
-                    file_format="DELTA",
-                    external_location=self.external_location,
-                    table_properties=table_properties
-                )
-            self.databricks_sql_obj.create_timeseries_table(columns=project_levels,
-                                                            external_location=self.external_location)
-            self.setup_notepads_and_jobs()
+            # for table in overall_tables:
+            #     table_class = self.declarative_utils.get_declarative_class(table)
+            #     if not table_class:
+            #         logging.error(f"Table class not found for table: {table}")
+            #         return False
+            #     new_model = self.model_convertor.convert_model(
+            #         table_class,
+            #         base_class=base,
+            #         new_schema=self.message.schema,
+            #     )
+            #
+            #     self.databricks_sql_obj.create_external_table_from_structure(
+            #         table=new_model.__table__,
+            #         file_format="DELTA",
+            #         external_location=self.external_location,
+            #         table_properties=table_properties
+            #     )
+            ts_external_table = self.databricks_sql_obj.create_timeseries_table(columns=project_levels,
                                                                                 external_location=self.external_location)
+            self.setup_notepads_and_jobs(timeseries_table_path=ts_external_table, project_levels=project_levels)
+            return True
        except Exception as e:
            logging.error(f"Error occurred while creating models in Unity Catalog: {e}")
            return False
@@ -93,7 +94,8 @@ class ModelCreatorHandler:
        Args:
            analytical (bool): Flag to indicate if the setup is for analytical or not
        """
-        logging.info(f"Setting up catalog '{DatabricksConfig.DATABRICKS_CATALOG_NAME}'")
+        logging.info(
+            f"Setting up catalog '{DatabricksConfig.DATABRICKS_CATALOG_NAME}' for project '{self.meta.project_id}'")
        self.databricks_sql_obj.connect_to_databricks()
        # Create catalog
        catalog_success = self.databricks_sql_obj.create_catalog(
@@ -111,23 +113,45 @@ class ModelCreatorHandler:
            return False
        return True
-    def setup_notepads_and_jobs(self):
+    def setup_notepads_and_jobs(self, timeseries_table_path: str, project_levels: dict):
"""
Args:
timeseries_table_path: Path for the timeseries table
            project_levels: Project hierarchy levels; their count sets the notebook's tag-part depth
"""
logging.info("Setting up notepads and jobs") logging.info("Setting up notepads and jobs")
with open(r"scripts/constants/notebooks/metadata_ingestion.txt", "r") as f: with open(r"scripts/constants/notebooks/metadata_ingestion.txt", "r") as f:
notebook_code = f.read() notebook_code = f.read()
# # Notebook for metadata ingestion
# self.notebook_manager.create_notebook(
# notebook_path=f"/Users/{self.message.databricks_user_email}/{self.meta.project_id}_{DatabricksConstants.METADATA_INGESTION_NOTEBOOK_NAME}",
# notebook_code=notebook_code,
# overwrite=True
# )
# # Job for metadata ingestion used by model management
# job_id = self.job_manager.create_job(job_config=self.job_manager.create_job_config_for_serverless(
# job_name=f'{self.meta.project_id}_{DatabricksConstants.METADATA_INGESTION_JOB_NAME}',
# notebook_path=f"/Users/{self.message.databricks_user_email}/metadata_ingestion_notebook",
# ))
#
# databricks_details_db.hset(self.meta.project_id, DatabricksConstants.METADATA_INGESTION_JOB_NAME, job_id)
# Timeseries DataPush Notebook
with open(r"scripts/constants/notebooks/timeseries_ingestion.txt", "r") as f:
notebook_code_for_timeseries = f.read()
notebook_code_for_timeseries = notebook_code_for_timeseries.replace("{{timeseries_table_path}}",
f'"{timeseries_table_path}"')
notebook_code_for_timeseries = notebook_code_for_timeseries.replace("{{project_levels}}", str(len(project_levels) - 1))
notebook_code_for_timeseries = notebook_code_for_timeseries.replace("{{event_hub_connection_string}}", f'"{self.meta.project_id}"')
        self.notebook_manager.create_notebook(
-            notebook_path=f"/Users/{self.message.databricks_user_email}/metadata_ingestion_notebook",
-            notebook_code=notebook_code,
+            notebook_path=f"/Users/{self.message.databricks_user_email}/{self.meta.project_id}_{DatabricksConstants.TIMESERIES_INGESTION_NOTEBOOK_NAME}",
+            notebook_code=notebook_code_for_timeseries,
            overwrite=True
        )
-        #Job for metadata ingestion used by model management
-        job_id = self.job_manager.create_job(job_config=self.job_manager.create_job_config_for_serverless(
-            job_name="metadata_ingestion_job",
-            notebook_path=f"/Users/{self.message.databricks_user_email}/metadata_ingestion_notebook",
-        ))
-        redis_dict = {"metadata_ingestion_job": job_id}
-        databricks_details_db.hset(self.meta.project_id, json.dumps(redis_dict))
    @staticmethod
    def fetch_table_properties(file_format: str = 'DELTA'):
...
@@ -9,14 +9,13 @@ from scripts.utils.model_convertor_utils import TypeMapper
class DataBricksSQLLayer(DatabricksSQLUtility):
    def __init__(self, catalog_name: str, project_id: str, schema: str):
        super().__init__(catalog_name, project_id)
-        self.catalog_name = catalog_name
        self.schema = schema
    def create_external_table_from_structure(self, table: Table,
                                             external_location: str,
                                             file_format: str = "PARQUET",
                                             table_properties: Dict[str, str] = None,
-                                             partition_columns: list = None):
+                                             partition_columns: list = None) -> str:
        """
        Create an external table from a model class.
@@ -26,6 +25,9 @@ class DataBricksSQLLayer(DatabricksSQLUtility):
            file_format: The file format of the data files.
            table_properties: Additional table properties.
            partition_columns: List of columns to partition the table by.
        Returns:
            str: The external location used for the table.
""" """
schema_table = f"{table.schema}.{table.name}" if table.schema else table.name schema_table = f"{table.schema}.{table.name}" if table.schema else table.name
columns_sql = TypeMapper().extract_columns_without_constraints(table) columns_sql = TypeMapper().extract_columns_without_constraints(table)
...@@ -47,7 +49,7 @@ class DataBricksSQLLayer(DatabricksSQLUtility): ...@@ -47,7 +49,7 @@ class DataBricksSQLLayer(DatabricksSQLUtility):
create_sql = "\n".join(sql_parts) create_sql = "\n".join(sql_parts)
self.execute_sql_statement(create_sql) self.execute_sql_statement(create_sql)
return True return external_location
def create_timeseries_table(self, columns: List[str], external_location: str): def create_timeseries_table(self, columns: List[str], external_location: str):
""" """
...@@ -91,4 +93,4 @@ class DataBricksSQLLayer(DatabricksSQLUtility): ...@@ -91,4 +93,4 @@ class DataBricksSQLLayer(DatabricksSQLUtility):
partition_columns=partition_columns, partition_columns=partition_columns,
table_properties=table_properties table_properties=table_properties
) )
return table_obj return external_location
@@ -54,7 +54,7 @@ class DatabricksJobManager:
        if parameters:
            payload["notebook_params"] = parameters
-        response = HTTPXRequestHandler(url).post(headers=self.headers, json=payload)
+        response = HTTPXRequestUtil(url).post(headers=self.headers, json=payload)
        if response.status_code == 200:
            run_id = response.json()['run_id']
...
from ut_sql_utils.asyncio.declarative_utils import DeclarativeUtilsFactory
from scripts.core.handlers.instance_handler import ModelInstanceHandler
from scripts.core.handlers.model_creator_handler import ModelCreatorHandler
from scripts.db.psql import session_manager
-from scripts.schemas import ModelCreatorSchema
+from scripts.schemas import ModelCreatorSchema, ModelInstanceSchema
class ModelCreatorAgent:
@@ -10,7 +11,7 @@ class ModelCreatorAgent:
    ...
    @staticmethod
-    async def model_trigger_agent(message: ModelCreatorSchema):
+    async def model_creator_agent(message: ModelCreatorSchema):
        declarative_utils = await DeclarativeUtilsFactory.get_declarative_utils(
            raw_database="unified_model",
            project_id=message.meta.project_id,
@@ -19,3 +20,8 @@ class ModelCreatorAgent:
        )
        model_cal_obj = ModelCreatorHandler(message=message, declarative_utils=declarative_utils)
        await model_cal_obj.create_models_in_unity_catalog()
@staticmethod
async def model_instance_agent(message: ModelInstanceSchema):
model_instance_obj = ModelInstanceHandler(project_id=message.project_id, payload=message)
await model_instance_obj.upload_instances_to_unity_catalog()
\ No newline at end of file
-from typing import Optional
-from pydantic import BaseModel
+from typing import Optional, Union, Dict, Any, List
+from pydantic import BaseModel, model_validator
from ut_security_util import MetaInfoSchema
from scripts.config import DatabricksConfig
@@ -14,4 +14,22 @@ class ModelCreatorSchema(BaseModel):
    databricks_access_token: str = DatabricksConfig.DATABRICKS_ACCESS_TOKEN
    databricks_http_path: str = DatabricksConfig.DATABRICKS_HTTP_PATH
    databricks_user_email: str = "aniket.dhale@ilenscloud.onmicrosoft.com"
databricks_storage_path: str = DatabricksConfig.DATABRICKS_STORAGE_PATH
class ModelInstanceSchema(BaseModel):
data: Union[Dict[str, Any], List[Dict[str, Any]]]
project_id: str
schema: Optional[str] = DatabricksConfig.DATABRICKS_PUBLIC_SCHEMA_NAME
databricks_host: str = DatabricksConfig.DATABRICKS_HOST
databricks_port: int = DatabricksConfig.DATABRICKS_PORT
databricks_access_token: str = DatabricksConfig.DATABRICKS_ACCESS_TOKEN
databricks_http_path: str = DatabricksConfig.DATABRICKS_HTTP_PATH
databricks_user_email: str = "aniket.dhale@ilenscloud.onmicrosoft.com"
databricks_storage_path: str = DatabricksConfig.DATABRICKS_STORAGE_PATH
@model_validator(mode="before")
def validate_data(cls, values: Dict[str, Any]) -> Dict[str, Any]:
if 'data' in values and isinstance(values['data'], dict):
values['data'] = [values['data']]
return values
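A quick sketch (assuming Pydantic v2 semantics) of how the validator above normalizes a single-object payload into a list before the handler indexes data[0]; the values are illustrative only:
# Illustrative values; only demonstrates the normalization done by validate_data.
msg = ModelInstanceSchema(project_id="p1", data={"type": "asset", "name": "Pump-101"})
assert isinstance(msg.data, list) and msg.data[0]["type"] == "asset"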
@@ -15,9 +15,7 @@ class DatabricksSQLUtility:
            catalog_name: Name of the catalog to create
            project_id: Project ID
        """
-        # self.catalog_name = get_db_name(project_id=project_id, database=catalog_name)
-        self.catalog_name = catalog_name
-        # self.catalog_name = catalog_name
+        self.catalog_name = get_db_name(project_id=project_id, database=catalog_name)
        self.engine = None
    def connect_to_databricks(self):
...