Commit 7845d0b6 authored by harshavardhan.c

feat: add functionality to create and delete Databricks jobs based on the catalog.

parent c741f979
@@ -72,4 +72,4 @@ DATABRICKS_ACCESS_TOKEN=dapi72a54657606877a3f7a6d92dd573df28
#METADATA_SERVICES_URL=http://192.168.0.221:7111
#HIERARCHY_SERVICES_URL=http://192.168.0.221:7112
#MODEL_MANAGEMENT_URL=http://192.168.0.221:7113
#BATCH_PROCESS_APP_URL=http://localhost:7879
\ No newline at end of file
# Created by .ignore support plugin (hsz.mobi)
### Python template
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
/.idea
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
.vscode/settings.json
.vscode
data
.env
assets
__pycache__
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
hooks:
- id: end-of-file-fixer
- id: trailing-whitespace
- id: requirements-txt-fixer
- repo: https://github.com/asottile/pyupgrade
rev: v3.16.0
hooks:
- id: pyupgrade
args:
- --py3-plus
- --keep-runtime-typing
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.4.8
hooks:
- id: ruff
args:
- --fix
- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
- id: isort
name: isort (python)
- id: isort
name: isort (cython)
types: [cython]
- id: isort
name: isort (pyi)
types: [pyi]
- repo: https://github.com/psf/black
rev: 24.4.2
hooks:
- id: black
# It is recommended to specify the latest version of Python
# supported by your project here, or alternatively use
# pre-commit's default_language_version, see
# https://pre-commit.com/#top_level-default_language_version
language_version: python3.11
# model-managament-databricks
@@ -6,23 +6,37 @@ from scripts.config import KafkaConfig
from scripts.engines.agents.model_creator_agent import ModelCreatorAgent
from scripts.schemas import ModelCreatorSchema, ModelInstanceSchema

broker = KafkaBroker(
    f"{KafkaConfig.KAFKA_HOST}:{KafkaConfig.KAFKA_PORT}",
    client_id="model_creator_agent",
)


@broker.subscriber(
    KafkaConfig.KAFKA_MODEL_CREATION_TOPIC,
    group_id="databricks_model_creator_agent",
    max_workers=2,
)
async def consume_stream_for_processing_dependencies(message: dict):
    try:
        await ModelCreatorAgent.model_creator_agent(
            message=ModelCreatorSchema(meta=message)
        )
        return True
    except Exception as e:
        logging.error(f"Exception occurred while creating model in Databricks: {e}")
        return False


@broker.subscriber(
    KafkaConfig.KAFKA_MODEL_INSTANCE_TOPIC,
    group_id="databricks_instance_agent",
    max_workers=2,
)
async def consume_stream_for_processing_instances(message: dict):
    try:
        await ModelCreatorAgent.model_instance_agent(ModelInstanceSchema(**message))
        return True
    except Exception as e:
        logging.error(f"Exception occurred while creating model in Databricks: {e}")
        return False
\ No newline at end of file
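For reference, a minimal producer sketch (not part of this commit) showing how an upstream service could trigger the instance consumer above; the payload keys mirror ModelInstanceSchema, and all concrete values are illustrative assumptions.

# Hypothetical producer sketch; topic and broker settings come from KafkaConfig,
# but every payload value below is a made-up example.
import asyncio

from faststream.confluent import KafkaBroker

from scripts.config import KafkaConfig


async def publish_delete_request() -> None:
    async with KafkaBroker(f"{KafkaConfig.KAFKA_HOST}:{KafkaConfig.KAFKA_PORT}") as broker:
        await broker.publish(
            {
                "project_id": "project_787",   # example project id
                "node_type": "enterprise",     # resolves to the target table name
                "action_type": "delete",       # "save" (default) or "delete"
                "data": [{"id": "node_001"}],  # rows to remove, keyed by the id column
            },
            topic=KafkaConfig.KAFKA_MODEL_INSTANCE_TOPIC,
        )


if __name__ == "__main__":
    asyncio.run(publish_delete_request())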
@@ -6,11 +6,11 @@ import sys
from dotenv import load_dotenv

load_dotenv()

from faststream import FastStream
from ut_dev_utils import configure_logger

from agent_subscribers import broker

configure_logger()

# Create FastStream app
...
faststream[confluent]==0.5.48
ut-dev-utils[sql,essentials]==1.2
uvloop==0.21.0
\ No newline at end of file
@@ -70,13 +70,16 @@ class _DatabricksConfig(BaseSettings):
    DATABRICKS_PUBLIC_SCHEMA_NAME: str = Field(default="public")
    DATABRICKS_ANALYTICAL_SCHEMA_NAME: str = Field(default="analytical")
    DATABRICKS_STORAGE_FORMAT: str = Field(default="PARQUET")
    DATABRICKS_STORAGE_PATH: str = Field(
        default="abfss://unity-catalog-storage@dbstoragenzxfhpgsipt5a.dfs.core.windows.net/416418955412087"
    )

    @model_validator(mode="before")
    def prepare_databricks_uri(cls, values):
        values["DATABRICKS_URI"] = (
            f"databricks://token:{values['DATABRICKS_ACCESS_TOKEN']}@{values['DATABRICKS_HOST']}:{values['DATABRICKS_PORT']}"
            f"?http_path={values['DATABRICKS_HTTP_PATH']}"
        )
        return values
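For illustration, the validator above produces a SQLAlchemy-style URI of the following shape; host, token, and http_path are placeholders, not real values.

# Placeholder values only; shows the URI shape built by prepare_databricks_uri.
example_uri = (
    "databricks://token:dapiXXXX@adb-1234567890123456.7.azuredatabricks.net:443"
    "?http_path=/sql/1.0/warehouses/0123456789abcdef"
)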
@@ -88,4 +91,11 @@ PathToStorage = _PathToStorage()
KafkaConfig = _KafkaConfig()
DatabricksConfig = _DatabricksConfig()

__all__ = [
    "Services",
    "RedisConfig",
    "ExternalServices",
    "PathToStorage",
    "KafkaConfig",
    "DatabricksConfig",
]
class DatabricksConstants:
    METADATA_INGESTION_JOB_NAME = "metadata_ingestion_job"
    METADATA_DELETION_JOB_NAME = "metadata_deletion_job"
    METADATA_INGESTION_NOTEBOOK_NAME = "metadata_ingestion_notebook"
    METADATA_DELETION_NOTEBOOK_NAME = "metadata_deletion_notebook"
    TIMESERIES_INGESTION_NOTEBOOK_NAME = "timeseries_ingestion_notebook"
class NotebookConstants:
    METADATA_INGESTION_NOTEBOOK_PATH = (
        "scripts/constants/notebooks/metadata_ingestion.txt"
    )
    METADATA_DELETION_NOTEBOOK_PATH = (
        "scripts/constants/notebooks/metadata_deletion.txt"
    )
    TIMESERIES_INGESTION_NOTEBOOK_PATH = (
        "scripts/constants/notebooks/timeseries_ingestion.txt"
    )
# Databricks notebook source
from delta.tables import DeltaTable
from pyspark.sql.functions import expr
import json
# COMMAND ----------
dbutils.widgets.text("input_message", "", "Input Message JSON")
dbutils.widgets.text("id_column", "id", "ID Column Name")
input_message = dbutils.widgets.get("input_message")
delete_column = dbutils.widgets.get("id_column")
# COMMAND ----------
def extract_table_info(input_message_str: str, delete_column: str = "id"):
    """
    Extract table name and data from input message
    Args:
        input_message_str (str): JSON string containing the message
        delete_column (str): Column name used to identify rows to delete
    Returns:
        dict: Extracted information
    """
    try:
        message_data = json.loads(input_message_str)

        # Extract table name and project id from the message
        table_name = message_data['table_properties']['table_name']  # 'enterprise'
        project_id = message_data['project_id']  # 'project_787'
        delete_values = [msg[delete_column] for msg in message_data['data'] if delete_column in msg]
        table_properties = message_data['table_properties']  # Fetch table properties

        print("Extracted Info:")
        print(f"Table Name: {table_name}")
        print(f"Project ID: {project_id}")
        print(f"Deleting rows: {delete_values}")
        print(f"Table Prop Keys: {list(table_properties.keys())}")

        return {
            'table_name': table_name,
            'project_id': project_id,
            delete_column: delete_values,
            'raw_message': message_data,
            'table_properties': table_properties
        }
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON in input_message: {str(e)}")
    except KeyError as e:
        raise ValueError(f"Missing required field in input_message: {str(e)}")
# COMMAND ----------
def delete_records_by_ids(table_name, ids, id_column="id"):
    """
    Delete records from an external Delta table using a list of IDs
    Args:
        table_name (str): Full table name (catalog.schema.table)
        ids (list): List of IDs to delete
        id_column (str): Column name containing IDs (default: "id")
    Returns:
        bool: True if successful, False otherwise
    """
    try:
        if not ids:
            print("No IDs provided")
            return False

        # Format IDs for SQL IN clause
        if isinstance(ids[0], str):
            id_values = "(" + ",".join([f"'{id}'" for id in ids]) + ")"
        else:
            id_values = "(" + ",".join([str(id) for id in ids]) + ")"

        # Use Delta table DELETE operation
        delta_table = DeltaTable.forName(spark, table_name)
        condition = f"{id_column} IN {id_values}"
        delta_table.delete(condition=expr(condition))

        print(f"Successfully deleted {len(ids)} records from {table_name}")
        return True
    except Exception as e:
        print(f"Error: {str(e)}")
        return False
# COMMAND ----------
table_info = extract_table_info(input_message, delete_column=delete_column)
# COMMAND ----------
result = delete_records_by_ids(table_name=table_info['table_name'], ids=table_info[delete_column], id_column=delete_column)
print(f"Deletion completed: {result}")
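A minimal sketch of the input_message JSON this notebook expects; only the keys read by extract_table_info are taken from the code, the values are illustrative.

# Illustrative payload; keys mirror what extract_table_info reads above.
example_input_message = {
    "project_id": "project_787",
    "table_properties": {
        "table_name": "catalog_project_787.public.enterprise",
        "table_path": "abfss://<container>@<account>.dfs.core.windows.net/catalog_project_787/DELTA/public.enterprise",
    },
    "data": [{"id": "node_001"}, {"id": "node_002"}],
}
# With id_column="id", delete_records_by_ids builds the condition
# "id IN ('node_001','node_002')" and applies it through DeltaTable.delete().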
@@ -31,18 +31,18 @@ def extract_table_info(input_message_str: str):
    """
    try:
        message_data = json.loads(input_message_str)

        # Extract table name from data.type
        table_name = message_data['table_properties']['table_name']  # 'enterprise'
        project_id = message_data['project_id']  # 'project_787'
        data_payload = message_data['data']  # Full data object
        table_properties = message_data['table_properties']  # Fetch table properties

        print(f"Extracted Info:")
        print(f"Table Name: {table_name}")
        print(f"Project ID: {project_id}")
        print(f"Table Prop Keys: {list(table_properties.keys())}")

        return {
            'table_name': table_name,
            'project_id': project_id,
@@ -50,7 +50,7 @@ def extract_table_info(input_message_str: str):
            'raw_message': message_data,
            'table_properties': table_properties
        }
    except json.JSONDecodeError as e:
        raise ValueError(f"Invalid JSON in input_message: {str(e)}")
    except KeyError as e:
@@ -62,13 +62,13 @@ def extract_table_info(input_message_str: str):
def detect_external_table_schema(table_name):
    """
    Detect schema of external Delta or Parquet table
    Args:
        table_name (str): Name of the table (e.g., 'enterprise')
    Returns:
        pyspark.sql.types.StructType: Schema of the table
    """
    try:
        # Try to get schema from catalog
        table_df = spark.table(table_name)
@@ -96,4 +96,4 @@ data_df = spark.createDataFrame(table_info['data_payload'], schema=schema)

# COMMAND ----------

data_df.write.mode("append").saveAsTable(table_info['table_name'])
\ No newline at end of file
...@@ -4,7 +4,7 @@ from pyspark.sql.functions import * ...@@ -4,7 +4,7 @@ from pyspark.sql.functions import *
from pyspark.sql.types import * from pyspark.sql.types import *
import json import json
spark = SparkSession.builder.appName("StreamingIoTPipeline").getOrCreate() spark = SparkSession.builder.appName("StreamingTimeseriesPipeline").getOrCreate()
spark.sparkContext.setLogLevel("WARN") spark.sparkContext.setLogLevel("WARN")
# COMMAND ---------- # COMMAND ----------
...@@ -31,7 +31,7 @@ message_schema = StructType([ ...@@ -31,7 +31,7 @@ message_schema = StructType([
]), True), ]), True),
StructField("a_id", StringType(), True), StructField("a_id", StringType(), True),
StructField("d_id", StringType(), True), StructField("d_id", StringType(), True),
StructField("gw_id", StringType(), True), StructField("gw_id", StringType(), True),
StructField("msg_id", IntegerType(), True), StructField("msg_id", IntegerType(), True),
StructField("p_id", StringType(), True), StructField("p_id", StringType(), True),
StructField("pd_id", StringType(), True), StructField("pd_id", StringType(), True),
...@@ -50,7 +50,7 @@ def safe_get_item(array_col, index): ...@@ -50,7 +50,7 @@ def safe_get_item(array_col, index):
def transform_timeseries_data_fully_dynamic(df, max_tag_parts=4): def transform_timeseries_data_fully_dynamic(df, max_tag_parts=4):
print(f"Transforming to target schema with up to {max_tag_parts} tag parts...") print(f"Transforming to target schema with up to {max_tag_parts} tag parts...")
df_with_split = df.withColumn("tag_parts", split(col("data.tag"), "\\$")) df_with_split = df.withColumn("tag_parts", split(col("data.tag"), "\\$"))
df_with_split = df_with_split.withColumn("tag_parts_count", size(col("tag_parts"))) df_with_split = df_with_split.withColumn("tag_parts_count", size(col("tag_parts")))
df_with_split = df_with_split.withColumn("hierarchy_levels", slice(col("tag_parts"), 1, size(col("tag_parts")) - 1)) df_with_split = df_with_split.withColumn("hierarchy_levels", slice(col("tag_parts"), 1, size(col("tag_parts")) - 1))
...@@ -146,4 +146,3 @@ transformed_df.writeStream \ ...@@ -146,4 +146,3 @@ transformed_df.writeStream \
# COMMAND ---------- # COMMAND ----------
@@ -3,6 +3,7 @@ import json
from ut_dev_utils import get_db_name

from scripts.config import DatabricksConfig
from scripts.constants import DatabricksConstants
from scripts.db.databricks.job_manager import DatabricksJobManager
from scripts.db.redis.databricks_details import databricks_details_db
from scripts.schemas import ModelInstanceSchema
@@ -12,29 +13,44 @@ class ModelInstanceHandler:
    def __init__(self, project_id: str, payload: ModelInstanceSchema):
        self.project_id = project_id
        self.payload = payload
        self.catalog_name = get_db_name(
            project_id=project_id, database=DatabricksConfig.DATABRICKS_CATALOG_NAME
        )
        self.job_manager = DatabricksJobManager(
            databricks_host=payload.databricks_host,
            access_token=payload.databricks_access_token,
        )

    async def upload_instances_to_unity_catalog(self):
        if self.payload.action_type == "delete":
            job_id = databricks_details_db.hget(
                self.project_id, DatabricksConstants.METADATA_DELETION_JOB_NAME
            )
        else:
            job_id = databricks_details_db.hget(
                self.project_id, DatabricksConstants.METADATA_INGESTION_JOB_NAME
            )
        if not job_id:
            raise ValueError(
                f"No job id found for {self.payload.action_type}, skipping upload to unity catalog"
            )
        run_id = self.job_manager.run_job(
            job_id=job_id,
            parameters={"input_message": json.dumps(self.get_job_trigger_payload())},
        )
        if not run_id:
            raise ValueError(
                "Failed to run metadata ingestion job, skipping upload to unity catalog"
            )

    def get_job_trigger_payload(self):
        table_name = self.payload.node_type
        schema_table = f"{DatabricksConfig.DATABRICKS_PUBLIC_SCHEMA_NAME}.{table_name}"
        return {
            "table_properties": {
                "table_name": f"{self.catalog_name}.{schema_table}",
                "table_path": f"{self.payload.databricks_storage_path}/{self.catalog_name}/DELTA/{schema_table}",
            },
            "project_id": self.project_id,
            "data": self.payload.data,
        }
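A hedged sketch of the Redis hash consulted by hget() above: per project, the job ids are assumed to be stored under the DatabricksConstants field names; the ids shown are placeholders.

# Assumed layout of the databricks_details_db hash; the job ids are placeholders.
from scripts.constants import DatabricksConstants
from scripts.db.redis.databricks_details import databricks_details_db

databricks_details_db.hset(
    "project_787",
    mapping={
        DatabricksConstants.METADATA_INGESTION_JOB_NAME: "245170938044006",
        DatabricksConstants.METADATA_DELETION_JOB_NAME: "245170938044007",
    },
)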
from typing import Dict, List

from sqlalchemy import (
    BigInteger,
    Column,
    Date,
    DateTime,
    Integer,
    MetaData,
    String,
    Table,
)

from scripts.utils.databricks_utils import DatabricksSQLUtility
from scripts.utils.model_convertor_utils import TypeMapper
@@ -11,11 +20,14 @@ class DataBricksSQLLayer(DatabricksSQLUtility):
        super().__init__(catalog_name, project_id)
        self.schema = schema

    def create_external_table_from_structure(
        self,
        table: Table,
        external_location: str,
        file_format: str = "PARQUET",
        table_properties: Dict[str, str] = None,
        partition_columns: list = None,
    ) -> str:
        """
        Create an external table from a model class.
@@ -31,12 +43,14 @@ class DataBricksSQLLayer(DatabricksSQLUtility):
        """
        schema_table = f"{table.schema}.{table.name}" if table.schema else table.name
        columns_sql = TypeMapper().extract_columns_without_constraints(table)
        external_location = (
            f"{external_location}/{self.catalog_name}/{file_format}/{schema_table}"
        )

        sql_parts = [
            f"CREATE TABLE IF NOT EXISTS {schema_table}",
            f"({columns_sql})",
            f"USING {file_format}",
            f"LOCATION '{external_location}'",
        ]

        if partition_columns:
            partition_clause = ", ".join(partition_columns)
@@ -64,33 +78,32 @@ class DataBricksSQLLayer(DatabricksSQLUtility):
        """
        table_columns = [
            Column("timestamp", BigInteger, nullable=False),
            Column("dt_timestamp", DateTime, nullable=False),
            Column("dt_date", Date, nullable=False),
            Column("dt_hour", Integer, nullable=False),
            Column("value", String, nullable=False),
            Column("value_type", String, nullable=False, default="float"),
            Column("c3", String, nullable=False),
        ]
        default_columns = ["c1", "c5", "Q", "T", "D", "P", "A", "B", *columns]
        table_columns.extend(
            [Column(col_name, String, nullable=True) for col_name in default_columns]
        )
        partition_columns = ["dt_date", "dt_hour", "c3"]
        table_properties = {
            "parquet.compression": "snappy",  # Fast decompression for frequent queries
            "parquet.page.size": "524288",  # 512KB - better time-range filtering
            "parquet.block.size": "268435456",  # 256MB - efficient sequential reads
            "serialization.format": "1",  # Support for arrays/complex types
        }
        table_obj = Table(
            "timeseries_data", MetaData(), *table_columns, schema=self.schema
        )
        self.create_external_table_from_structure(
            table=table_obj,
            external_location=external_location,
            partition_columns=partition_columns,
            table_properties=table_properties,
        )
        return external_location
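To make the DDL assembly above concrete, a sketch of the statement create_external_table_from_structure would emit for the timeseries table; the column types come from TypeMapper and the remaining clauses sit in truncated code, so everything below is an assumption.

# Hypothetical DDL produced by the builder above (columns abbreviated,
# storage path and type mapping are assumptions).
example_ddl = """
CREATE TABLE IF NOT EXISTS analytical.timeseries_data
(timestamp BIGINT NOT NULL, dt_timestamp TIMESTAMP NOT NULL, dt_date DATE NOT NULL, ...)
USING PARQUET
LOCATION 'abfss://<container>@<account>.dfs.core.windows.net/<catalog>/PARQUET/analytical.timeseries_data'
"""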
@@ -14,10 +14,14 @@ class DatabricksJobManager:
            databricks_host: Your Databricks workspace URL
            access_token: Personal access token or service principal token
        """
        self.host = (
            databricks_host
            if "https://" in databricks_host
            else f"https://{databricks_host}"
        )
        self.headers = {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
        }

    def create_job(self, job_config: dict):
@@ -32,11 +36,13 @@ class DatabricksJobManager:
        response = HTTPXRequestUtil(url).post(headers=self.headers, json=job_config)

        if response.status_code == 200:
            job_id = response.json()["job_id"]
            logging.info(f"Job created successfully with ID: {job_id}")
            return job_id
        else:
            logging.error(
                f"Error creating job: {response.status_code} - {response.text}"
            )
            return None

    def run_job(self, job_id: str, parameters=None):
@@ -57,11 +63,13 @@ class DatabricksJobManager:
        response = HTTPXRequestUtil(url).post(headers=self.headers, json=payload)

        if response.status_code == 200:
            run_id = response.json()["run_id"]
            logging.info(f"Job run started with ID: {run_id}")
            return run_id
        else:
            logging.error(
                f"Error running job: {response.status_code} - {response.text}"
            )
            return None
    def get_run_status(self, run_id):
@@ -73,12 +81,16 @@ class DatabricksJobManager:
        url = f"{self.host}/api/2.1/jobs/runs/get"
        params = {"run_id": run_id}

        response = HTTPXRequestHandler(url).get(
            url, headers=self.headers, params=params
        )

        if response.status_code == 200:
            return response.json()
        else:
            logging.error(
                f"Error getting run status: {response.status_code} - {response.text}"
            )
            return None

    @staticmethod
@@ -98,16 +110,18 @@ class DatabricksJobManager:
                    "task_key": "table_update_task",
                    "notebook_task": {
                        "notebook_path": notebook_path,
                        "base_parameters": {"input_message": "default_value"},
                    },
                    "timeout_seconds": 3600,
                }
            ],
            "max_concurrent_runs": 10,
            "tags": {
                "purpose": (
                    "metadata_ingestion"
                    if "ingestion" in job_name
                    else "metadata_deletion"
                ),
                "compute_type": "serverless",
            },
        }
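A brief usage sketch of the manager above; host, token, and the job_config payload are placeholders that only mirror the visible structure, not the full config assembled by this module.

# Hypothetical wiring of DatabricksJobManager; all literal values are placeholders.
import json

from scripts.db.databricks.job_manager import DatabricksJobManager

manager = DatabricksJobManager(
    databricks_host="adb-1234567890123456.7.azuredatabricks.net",
    access_token="dapiXXXX",
)
job_config = {
    "name": "metadata_deletion_job",  # illustrative job name
    "tasks": [
        {
            "task_key": "table_update_task",
            "notebook_task": {
                "notebook_path": "/Shared/metadata_deletion_notebook",  # assumed workspace path
                "base_parameters": {"input_message": "default_value"},
            },
            "timeout_seconds": 3600,
        }
    ],
    "max_concurrent_runs": 10,
    "tags": {"purpose": "metadata_deletion", "compute_type": "serverless"},
}
job_id = manager.create_job(job_config)
run_id = manager.run_job(
    job_id=job_id,
    parameters={"input_message": json.dumps({"project_id": "project_787"})},
)
status = manager.get_run_status(run_id)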
@@ -13,13 +13,19 @@ class NotebookManager:
            databricks_host: Your Databricks workspace URL (e.g., 'https://your-workspace.cloud.databricks.com')
            access_token: Personal access token or service principal token
        """
        self.host = (
            databricks_host
            if "https://" in databricks_host
            else f"https://{databricks_host}"
        )
        self.headers = {
            "Authorization": f"Bearer {access_token}",
            "Content-Type": "application/json",
        }

    def create_notebook(
        self, notebook_path, notebook_code: str, language="PYTHON", overwrite=True
    ):
        """
        Create a notebook in Databricks workspace
@@ -31,18 +37,22 @@ class NotebookManager:
        """
        url = f"{self.host}/api/2.0/workspace/import"

        # Encode the notebook content in base64
        encoded_content = base64.b64encode(notebook_code.encode("utf-8")).decode(
            "utf-8"
        )

        payload = {
            "path": notebook_path,
            "format": "SOURCE",
            "language": language,
            "content": encoded_content,
            "overwrite": overwrite,
        }

        response = HTTPXRequestUtil(url=url).post(json=payload, headers=self.headers)

        if response.status_code == 200:
            logging.info(f"Notebook created successfully at: {notebook_path}")
            return True
        else:
            logging.error(
                f"Error creating notebook: {response.status_code} - {response.text}"
            )
            return False
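A short sketch of how the deletion notebook source referenced by NotebookConstants could be pushed to the workspace with this class; the NotebookManager module path and the target workspace path are assumptions.

# Hypothetical upload of the deletion notebook source; module and workspace
# paths below are assumptions, not taken from this commit.
from scripts.constants import NotebookConstants
from scripts.db.databricks.notebook_manager import NotebookManager  # assumed module path

with open(NotebookConstants.METADATA_DELETION_NOTEBOOK_PATH, encoding="utf-8") as fh:
    notebook_code = fh.read()

manager = NotebookManager(
    databricks_host="adb-1234567890123456.7.azuredatabricks.net",  # placeholder
    access_token="dapiXXXX",                                       # placeholder
)
created = manager.create_notebook(
    notebook_path="/Shared/metadata_deletion_notebook",  # assumed path
    notebook_code=notebook_code,
    language="PYTHON",
    overwrite=True,
)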
from scripts.config import RedisConfig
from scripts.db.redis import redis_connector

databricks_details_db = redis_connector.connect(
    db=RedisConfig.REDIS_DATABRICKS_DB, decode_responses=True
)
@@ -9,7 +9,9 @@ from ut_sql_utils.config import PostgresConfig
from scripts.config import RedisConfig
from scripts.db.redis import redis_connector

graphql_details_db = redis_connector.connect(
    db=RedisConfig.REDIS_GRAPHQL_DB, decode_responses=True
)


def get_models(
@@ -38,7 +40,9 @@ def get_models(
    """
    tables_data = graphql_details_db.hget(info.data["project_id"], "schema_mapper")
    if tables_data is None:
        raise ILensErrors(
            f"No GraphQL schema data found for project {info.data['project_id']}"
        )
    tables: Dict[str, Any] = orjson.loads(tables_data) or {}
    if (
...
@@ -3,7 +3,9 @@ import orjson
from scripts.config import RedisConfig
from scripts.db.redis import redis_connector

project_details_db = redis_connector.connect(
    db=RedisConfig.REDIS_PROJECT_TAGS_DB, decode_responses=True
)


def get_project_time_zone(project_id: str):
@@ -19,7 +21,9 @@ def get_project_time_zone(project_id: str):
    return "UTC"


def fetch_level_details(
    project_id: str, keys: bool = False, raw: bool = False
) -> dict | list:
    """
    Function to fetch level details from project details
    Uses redis project details cache db (db18) and fetches the level details
@@ -60,7 +64,9 @@ def fetch_asset_level(project_id: str) -> str:
    project_details = orjson.loads(project_details)
    counter_levels = project_details.get("counter_levels", {})
    asset_level = (
        counter_levels.get("asset", counter_levels.get("equipment"))
        if isinstance(counter_levels, dict)
        else None
    )
    if asset_level:
        return asset_level
@@ -81,9 +87,10 @@ def fetch_asset_level_with_mapping(project_id: str) -> tuple[str, str]:
    swapped_dict = {v: k for k, v in counter_levels.items()}
    return swapped_dict.get("ast", ""), "ast"


def project_template_keys(project_id: str, levels=False):
    val = project_details_db.get(project_id)
    if val is None:
        raise ValueError(f"Unknown Project, Project ID:{project_id}Not Found!!!")
    val = orjson.loads(val)
    return val.get("levels", {}) if levels else list(val.get("levels", {}).keys())
\ No newline at end of file
@@ -2,4 +2,7 @@ from faststream.confluent import KafkaBroker
from scripts.config import KafkaConfig

broker = KafkaBroker(
    f"{KafkaConfig.KAFKA_HOST}:{KafkaConfig.KAFKA_PORT}",
    client_id="model_creator_agent",
)
@@ -7,8 +7,7 @@ from scripts.schemas import ModelCreatorSchema, ModelInstanceSchema
class ModelCreatorAgent:
    def __init__(self): ...

    @staticmethod
    async def model_creator_agent(message: ModelCreatorSchema):
@@ -18,10 +17,14 @@ class ModelCreatorAgent:
            session_manager=session_manager,
            schema=message.schema,
        )
        model_cal_obj = ModelCreatorHandler(
            message=message, declarative_utils=declarative_utils
        )
        await model_cal_obj.create_models_in_unity_catalog()

    @staticmethod
    async def model_instance_agent(message: ModelInstanceSchema):
        model_instance_obj = ModelInstanceHandler(
            project_id=message.project_id, payload=message
        )
        await model_instance_obj.upload_instances_to_unity_catalog()
from typing import Any, Dict, List, Optional, Union

from pydantic import BaseModel, Field, model_validator
from ut_security_util import MetaInfoSchema

from scripts.config import DatabricksConfig
@@ -20,7 +20,11 @@ class ModelCreatorSchema(BaseModel):
class ModelInstanceSchema(BaseModel):
    data: Union[Dict[str, Any], List[Dict[str, Any]]]
    project_id: str
    action_type: str = "save"
    node_type: str
    sql_schema: Optional[str] = Field(
        default=DatabricksConfig.DATABRICKS_PUBLIC_SCHEMA_NAME, alias="schema"
    )
    databricks_host: str = DatabricksConfig.DATABRICKS_HOST
    databricks_port: int = DatabricksConfig.DATABRICKS_PORT
    databricks_access_token: str = DatabricksConfig.DATABRICKS_ACCESS_TOKEN
@@ -30,6 +34,6 @@ class ModelInstanceSchema(BaseModel):

    @model_validator(mode="before")
    def validate_data(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        if "data" in values and isinstance(values["data"], dict):
            values["data"] = [values["data"]]
        return values
...@@ -29,13 +29,17 @@ class DatabricksSQLUtility: ...@@ -29,13 +29,17 @@ class DatabricksSQLUtility:
DatabricksConfig.DATABRICKS_URI, DatabricksConfig.DATABRICKS_URI,
pool_pre_ping=True, pool_pre_ping=True,
pool_recycle=3600, pool_recycle=3600,
echo=False echo=False,
) )
# Test connection # Test connection
with self.engine.connect() as conn: with self.engine.connect() as conn:
result = conn.execute(text("SELECT current_user() as user, current_catalog() as catalog")) result = conn.execute(
text("SELECT current_user() as user, current_catalog() as catalog")
)
user_info = result.fetchone() user_info = result.fetchone()
logger.info(f"Connected as user: {user_info[0]}, current catalog: {user_info[1]}") logger.info(
f"Connected as user: {user_info[0]}, current catalog: {user_info[1]}"
)
logger.info("Successfully connected to Databricks") logger.info("Successfully connected to Databricks")
return True return True
except Exception as e: except Exception as e:
...@@ -46,8 +50,12 @@ class DatabricksSQLUtility: ...@@ -46,8 +50,12 @@ class DatabricksSQLUtility:
if self.engine: if self.engine:
self.engine.dispose() self.engine.dispose()
def create_catalog(self, managed_location: Optional[str] = None, comment: Optional[str] = None, def create_catalog(
properties: Optional[dict] = None): self,
managed_location: Optional[str] = None,
comment: Optional[str] = None,
properties: Optional[dict] = None,
):
""" """
Create a new catalog in Unity Catalog Create a new catalog in Unity Catalog
Args: Args:
...@@ -77,8 +85,13 @@ class DatabricksSQLUtility: ...@@ -77,8 +85,13 @@ class DatabricksSQLUtility:
logger.error(f"Failed to create catalog '{self.catalog_name}': {str(e)}") logger.error(f"Failed to create catalog '{self.catalog_name}': {str(e)}")
raise raise
def create_schema(self, schema_name: str, managed_location: Optional[str] = None, comment: Optional[str] = None, def create_schema(
properties: Optional[dict] = None): self,
schema_name: str,
managed_location: Optional[str] = None,
comment: Optional[str] = None,
properties: Optional[dict] = None,
):
""" """
Create a new schema within a catalog Create a new schema within a catalog
Args: Args:
...@@ -103,18 +116,22 @@ class DatabricksSQLUtility: ...@@ -103,18 +116,22 @@ class DatabricksSQLUtility:
ddl += f"\nWITH DBPROPERTIES ({props})" ddl += f"\nWITH DBPROPERTIES ({props})"
self.execute_sql_statement(ddl) self.execute_sql_statement(ddl)
logger.info(f"Schema '{self.catalog_name}.{schema_name}' created successfully") logger.info(
f"Schema '{self.catalog_name}.{schema_name}' created successfully"
)
return full_schema_name return full_schema_name
except Exception as e: except Exception as e:
logger.error(f"Failed to create schema '{self.catalog_name}.{schema_name}': {str(e)}") logger.error(
f"Failed to create schema '{self.catalog_name}.{schema_name}': {str(e)}"
)
raise raise
def create_external_location( def create_external_location(
self, self,
location_name: str, location_name: str,
storage_path: str, storage_path: str,
credential_name: str, credential_name: str,
comment: Optional[str] = None comment: Optional[str] = None,
) -> str: ) -> str:
""" """
Create an external location in Unity Catalog Create an external location in Unity Catalog
...@@ -138,7 +155,9 @@ class DatabricksSQLUtility: ...@@ -138,7 +155,9 @@ class DatabricksSQLUtility:
logger.info(f"External location '{location_name}' created successfully") logger.info(f"External location '{location_name}' created successfully")
return location_name return location_name
except Exception as e: except Exception as e:
logger.error(f"Failed to create external location '{location_name}': {str(e)}") logger.error(
f"Failed to create external location '{location_name}': {str(e)}"
)
raise raise
def execute_sql_statement(self, query: str): def execute_sql_statement(self, query: str):
......
...@@ -18,7 +18,11 @@ class HTTPXRequestUtil: ...@@ -18,7 +18,11 @@ class HTTPXRequestUtil:
def delete(self, path="", params=None, **kwargs) -> httpx.Response: def delete(self, path="", params=None, **kwargs) -> httpx.Response:
url = self.get_url(path) url = self.get_url(path)
logging.info(url) logging.info(url)
with httpx.Client(verify=self.verify, headers=kwargs.get("headers"), cookies=kwargs.get("cookies")) as client: with httpx.Client(
verify=self.verify,
headers=kwargs.get("headers"),
cookies=kwargs.get("cookies"),
) as client:
response: httpx.Response = client.delete(url=url, params=params) response: httpx.Response = client.delete(url=url, params=params)
return response return response
...@@ -27,7 +31,11 @@ class HTTPXRequestUtil: ...@@ -27,7 +31,11 @@ class HTTPXRequestUtil:
url = self.get_url(path) url = self.get_url(path)
logging.info(url) logging.info(url)
with httpx.Client(verify=self.verify, headers=kwargs.get("headers"), cookies=kwargs.get("cookies")) as client: with httpx.Client(
verify=self.verify,
headers=kwargs.get("headers"),
cookies=kwargs.get("cookies"),
) as client:
response: httpx.Response = client.put(url=url, data=data, json=json) response: httpx.Response = client.put(url=url, data=data, json=json)
return response return response
...@@ -42,7 +50,11 @@ class HTTPXRequestUtil: ...@@ -42,7 +50,11 @@ class HTTPXRequestUtil:
""" """
url = self.get_url(path) url = self.get_url(path)
logging.info(url) logging.info(url)
with httpx.Client(verify=self.verify, headers=kwargs.get("headers"), cookies=kwargs.get("cookies")) as client: with httpx.Client(
verify=self.verify,
headers=kwargs.get("headers"),
cookies=kwargs.get("cookies"),
) as client:
response: httpx.Response = client.post(url=url, data=data, json=json) response: httpx.Response = client.post(url=url, data=data, json=json)
return response return response
...@@ -52,7 +64,11 @@ class HTTPXRequestUtil: ...@@ -52,7 +64,11 @@ class HTTPXRequestUtil:
url = self.get_url(path) url = self.get_url(path)
logging.info(url) logging.info(url)
with httpx.Client(verify=self.verify, headers=kwargs.get("headers"), cookies=kwargs.get("cookies")) as client: with httpx.Client(
verify=self.verify,
headers=kwargs.get("headers"),
cookies=kwargs.get("cookies"),
) as client:
response: httpx.Response = client.get(url=url, params=params) response: httpx.Response = client.get(url=url, params=params)
return response return response
......