Commit 7845d0b6 authored by harshavardhan.c

feat: Add functionality to create and delete Databricks jobs based on the catalog.

parent c741f979
# Created by .ignore support plugin (hsz.mobi)
### Python template
# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class
# C extensions
*.so
/.idea
# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
pip-wheel-metadata/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST
# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec
# Installer logs
pip-log.txt
pip-delete-this-directory.txt
# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
.hypothesis/
.pytest_cache/
# Translations
*.mo
*.pot
# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal
# Flask stuff:
instance/
.webassets-cache
# Scrapy stuff:
.scrapy
# Sphinx documentation
docs/_build/
# PyBuilder
target/
# Jupyter Notebook
.ipynb_checkpoints
# IPython
profile_default/
ipython_config.py
# pyenv
.python-version
# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock
# celery beat schedule file
celerybeat-schedule
# SageMath parsed files
*.sage.py
# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/
# Spyder project settings
.spyderproject
.spyproject
# Rope project settings
.ropeproject
# mkdocs documentation
/site
# mypy
.mypy_cache/
.dmypy.json
dmypy.json
# Pyre type checker
.pyre/
.vscode/settings.json
.vscode
data
.env
assets
__pycache__
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.6.0
hooks:
- id: end-of-file-fixer
- id: trailing-whitespace
- id: requirements-txt-fixer
- repo: https://github.com/asottile/pyupgrade
rev: v3.16.0
hooks:
- id: pyupgrade
args:
- --py3-plus
- --keep-runtime-typing
- repo: https://github.com/charliermarsh/ruff-pre-commit
rev: v0.4.8
hooks:
- id: ruff
args:
- --fix
- repo: https://github.com/pycqa/isort
rev: 5.13.2
hooks:
- id: isort
name: isort (python)
- id: isort
name: isort (cython)
types: [cython]
- id: isort
name: isort (pyi)
types: [pyi]
- repo: https://github.com/psf/black
rev: 24.4.2
hooks:
- id: black
# It is recommended to specify the latest version of Python
# supported by your project here, or alternatively use
# pre-commit's default_language_version, see
# https://pre-commit.com/#top_level-default_language_version
language_version: python3.11
# model-managament-databricks
......@@ -6,19 +6,33 @@ from scripts.config import KafkaConfig
from scripts.engines.agents.model_creator_agent import ModelCreatorAgent
from scripts.schemas import ModelCreatorSchema, ModelInstanceSchema
broker = KafkaBroker(f'{KafkaConfig.KAFKA_HOST}:{KafkaConfig.KAFKA_PORT}', client_id="model_creator_agent")
broker = KafkaBroker(
f"{KafkaConfig.KAFKA_HOST}:{KafkaConfig.KAFKA_PORT}",
client_id="model_creator_agent",
)
@broker.subscriber(KafkaConfig.KAFKA_MODEL_CREATION_TOPIC, group_id="databricks_model_creator_agent", max_workers=2)
@broker.subscriber(
KafkaConfig.KAFKA_MODEL_CREATION_TOPIC,
group_id="databricks_model_creator_agent",
max_workers=2,
)
async def consume_stream_for_processing_dependencies(message: dict):
try:
await ModelCreatorAgent.model_creator_agent(message=ModelCreatorSchema(meta=message))
await ModelCreatorAgent.model_creator_agent(
message=ModelCreatorSchema(meta=message)
)
return True
except Exception as e:
logging.error(f"Exception occurred while creating model in Databricks: {e}")
return False
@broker.subscriber(KafkaConfig.KAFKA_MODEL_INSTANCE_TOPIC, group_id="databricks_instance_agent", max_workers=2)
@broker.subscriber(
KafkaConfig.KAFKA_MODEL_INSTANCE_TOPIC,
group_id="databricks_instance_agent",
max_workers=2,
)
async def consume_stream_for_processing_instances(message: dict):
try:
await ModelCreatorAgent.model_instance_agent(ModelInstanceSchema(**message))
......
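For orientation, a minimal producer sketch (not part of this commit) that would feed the creation subscriber above; the broker address comes from KafkaConfig as in scripts.config, and the message body is a purely hypothetical example of the meta payload.

import asyncio

from faststream.confluent import KafkaBroker

from scripts.config import KafkaConfig


async def publish_example():
    # Connect a short-lived broker and publish an illustrative creation message.
    producer = KafkaBroker(f"{KafkaConfig.KAFKA_HOST}:{KafkaConfig.KAFKA_PORT}")
    await producer.connect()
    await producer.publish(
        {"project_id": "project_787"},  # hypothetical payload; the real meta schema is not shown here
        topic=KafkaConfig.KAFKA_MODEL_CREATION_TOPIC,
    )
    await producer.close()


asyncio.run(publish_example())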
......@@ -6,11 +6,11 @@ import sys
from dotenv import load_dotenv
load_dotenv()
from agent_subscribers import broker
from faststream import FastStream
from ut_dev_utils import configure_logger
from agent_subscribers import broker
configure_logger()
# Create FastStream app
......
......@@ -70,13 +70,16 @@ class _DatabricksConfig(BaseSettings):
DATABRICKS_PUBLIC_SCHEMA_NAME: str = Field(default="public")
DATABRICKS_ANALYTICAL_SCHEMA_NAME: str = Field(default="analytical")
DATABRICKS_STORAGE_FORMAT: str = Field(default="PARQUET")
DATABRICKS_STORAGE_PATH: str = Field(default="abfss://unity-catalog-storage@dbstoragenzxfhpgsipt5a.dfs.core.windows.net/416418955412087")
DATABRICKS_STORAGE_PATH: str = Field(
default="abfss://unity-catalog-storage@dbstoragenzxfhpgsipt5a.dfs.core.windows.net/416418955412087"
)
@model_validator(mode="before")
def prepare_databricks_uri(cls, values):
values[
'DATABRICKS_URI'] = (f"databricks://token:{values['DATABRICKS_ACCESS_TOKEN']}@{values['DATABRICKS_HOST']}:{values['DATABRICKS_PORT']}"
f"?http_path={values['DATABRICKS_HTTP_PATH']}")
values["DATABRICKS_URI"] = (
f"databricks://token:{values['DATABRICKS_ACCESS_TOKEN']}@{values['DATABRICKS_HOST']}:{values['DATABRICKS_PORT']}"
f"?http_path={values['DATABRICKS_HTTP_PATH']}"
)
return values
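To make the validator's output concrete, a small illustration with obviously fake placeholder values (no real hosts or credentials):

values = {
    "DATABRICKS_ACCESS_TOKEN": "dapiXXXXXXXX",  # placeholder token
    "DATABRICKS_HOST": "adb-1111111111111111.1.azuredatabricks.net",  # placeholder host
    "DATABRICKS_PORT": 443,
    "DATABRICKS_HTTP_PATH": "/sql/1.0/warehouses/abcdef1234567890",  # placeholder path
}
uri = (
    f"databricks://token:{values['DATABRICKS_ACCESS_TOKEN']}@{values['DATABRICKS_HOST']}:{values['DATABRICKS_PORT']}"
    f"?http_path={values['DATABRICKS_HTTP_PATH']}"
)
# uri == "databricks://token:dapiXXXXXXXX@adb-1111111111111111.1.azuredatabricks.net:443"
#        "?http_path=/sql/1.0/warehouses/abcdef1234567890"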
......@@ -88,4 +91,11 @@ PathToStorage = _PathToStorage()
KafkaConfig = _KafkaConfig()
DatabricksConfig = _DatabricksConfig()
__all__ = ["Services", "RedisConfig", "ExternalServices", "PathToStorage", "KafkaConfig", "DatabricksConfig"]
__all__ = [
"Services",
"RedisConfig",
"ExternalServices",
"PathToStorage",
"KafkaConfig",
"DatabricksConfig",
]
class DatabricksConstants:
METADATA_INGESTION_JOB_NAME = "metadata_ingestion_job"
METADATA_DELETION_JOB_NAME = "metadata_deletion_job"
METADATA_INGESTION_NOTEBOOK_NAME = "metadata_ingestion_notebook"
METADATA_DELETION_NOTEBOOK_NAME = "metadata_deletion_notebook"
TIMESERIES_INGESTION_NOTEBOOK_NAME = "timeseries_ingestion_notebook"
class NotebookConstants:
METADATA_INGESTION_NOTEBOOK_PATH = (
"scripts/constants/notebooks/metadata_ingestion.txt"
)
METADATA_DELETION_NOTEBOOK_PATH = (
"scripts/constants/notebooks/metadata_deletion.txt"
)
TIMESERIES_INGESTION_NOTEBOOK_PATH = (
"scripts/constants/notebooks/timeseries_ingestion.txt"
)
# Databricks notebook source
from delta.tables import DeltaTable
from pyspark.sql.functions import expr
import json
# COMMAND ----------
dbutils.widgets.text("input_message", "", "Input Message JSON")
dbutils.widgets.text("id_column", "id", "ID Column Name")
input_message = dbutils.widgets.get("input_message")
delete_column = dbutils.widgets.get("id_column")
# COMMAND ----------
def extract_table_info(input_message_str: str, delete_column: str = "id"):
"""
Extract table name and data from input message
Args:
input_message_str (str): JSON string containing the message
Returns:
dict: Extracted information
"""
try:
message_data = json.loads(input_message_str)
# Extract table name from data.type
table_name = message_data['table_properties']['table_name'] # 'enterprise'
project_id = message_data['project_id'] # 'project_787'
delete_values = [msg[delete_column] for msg in message_data['data'] if delete_column in msg]
table_properties = message_data['table_properties'] # Fetch table properties
print(f"Extracted Info:")
print(f"Table Name: {table_name}")
print(f"Project ID: {project_id}")
print(f'Deleting rows: {delete_values}')
print(f"Table Prop Keys: {list(table_properties.keys())}")
return {
'table_name': table_name,
'project_id': project_id,
delete_column: delete_values,
'raw_message': message_data,
'table_properties': table_properties
}
except json.JSONDecodeError as e:
raise ValueError(f"Invalid JSON in input_message: {str(e)}")
except KeyError as e:
raise ValueError(f"Missing required field in input_message: {str(e)}")
# COMMAND ----------
def delete_records_by_ids(table_name, ids, id_column="id"):
"""
Delete records from external table (Delta or Parquet) using list of IDs
Args:
table_name (str): Full table name (catalog.schema.table)
ids (list): List of IDs to delete
id_column (str): Column name containing IDs (default: "id")
Returns:
bool: True if successful, False otherwise
"""
try:
if not ids:
print("No IDs provided")
return False
# Format IDs for SQL IN clause
if isinstance(ids[0], str):
id_values = "(" + ",".join([f"'{id}'" for id in ids]) + ")"
else:
id_values = "(" + ",".join([str(id) for id in ids]) + ")"
# Use Delta table DELETE operation
delta_table = DeltaTable.forName(spark, table_name)
condition = f"{id_column} IN {id_values}"
delta_table.delete(condition=expr(condition))
print(f"Successfully deleted {len(ids)} records from {table_name}")
return True
except Exception as e:
print(f"Error: {str(e)}")
return False
# COMMAND ----------
table_info = extract_table_info(input_message, delete_column=delete_column)
# COMMAND ----------
result = delete_records_by_ids(table_name=table_info['table_name'], ids=table_info[delete_column], id_column=delete_column)
print(f"Deletion completed: {result}")
......@@ -4,7 +4,7 @@ from pyspark.sql.functions import *
from pyspark.sql.types import *
import json
spark = SparkSession.builder.appName("StreamingIoTPipeline").getOrCreate()
spark = SparkSession.builder.appName("StreamingTimeseriesPipeline").getOrCreate()
spark.sparkContext.setLogLevel("WARN")
# COMMAND ----------
......@@ -146,4 +146,3 @@ transformed_df.writeStream \
# COMMAND ----------
......@@ -3,6 +3,7 @@ import json
from ut_dev_utils import get_db_name
from scripts.config import DatabricksConfig
from scripts.constants import DatabricksConstants
from scripts.db.databricks.job_manager import DatabricksJobManager
from scripts.db.redis.databricks_details import databricks_details_db
from scripts.schemas import ModelInstanceSchema
......@@ -12,29 +13,44 @@ class ModelInstanceHandler:
def __init__(self, project_id: str, payload: ModelInstanceSchema):
self.project_id = project_id
self.payload = payload
self.catalog_name = get_db_name(project_id=project_id, database=DatabricksConfig.DATABRICKS_CATALOG_NAME)
self.catalog_name = get_db_name(
project_id=project_id, database=DatabricksConfig.DATABRICKS_CATALOG_NAME
)
self.job_manager = DatabricksJobManager(
databricks_host=payload.databricks_host,
access_token=payload.databricks_access_token
access_token=payload.databricks_access_token,
)
def upload_instances_to_unity_catalog(self):
job_id = databricks_details_db.hget(self.project_id, "metadata_ingestion_job")
async def upload_instances_to_unity_catalog(self):
if self.payload.action_type == "delete":
job_id = databricks_details_db.hget(
self.project_id, DatabricksConstants.METADATA_DELETION_JOB_NAME
)
else:
job_id = databricks_details_db.hget(
self.project_id, DatabricksConstants.METADATA_INGESTION_JOB_NAME
)
if not job_id:
raise ValueError("No job id found for metadata ingestion job, skipping upload to unity catalog")
run_id = self.job_manager.run_job(job_id=job_id,
parameters={"input_message": json.dumps(self.get_job_trigger_payload())})
raise ValueError(
f"No job id found for {self.payload.action_type}, skipping upload to unity catalog"
)
run_id = self.job_manager.run_job(
job_id=job_id,
parameters={"input_message": json.dumps(self.get_job_trigger_payload())},
)
if not run_id:
raise ValueError("Failed to run metadata ingestion job, skipping upload to unity catalog")
raise ValueError(
"Failed to run metadata ingestion job, skipping upload to unity catalog"
)
def get_job_trigger_payload(self):
table_name = self.payload.data[0]['type']
table_name = self.payload.node_type
schema_table = f"{DatabricksConfig.DATABRICKS_PUBLIC_SCHEMA_NAME}.{table_name}"
return {
"table_properties": {
"table_name": f'{self.catalog_name}.{schema_table}',
"table_path": f'{self.payload.databricks_storage_path}/{self.catalog_name}/DELTA/{schema_table}',
"table_name": f"{self.catalog_name}.{schema_table}",
"table_path": f"{self.payload.databricks_storage_path}/{self.catalog_name}/DELTA/{schema_table}",
},
"project_id": self.project_id,
"data": self.payload.data
"data": self.payload.data,
}
from typing import Dict, List
from sqlalchemy import Table, Column, String, BigInteger, DateTime, MetaData, Integer, Date
from sqlalchemy import (
BigInteger,
Column,
Date,
DateTime,
Integer,
MetaData,
String,
Table,
)
from scripts.utils.databricks_utils import DatabricksSQLUtility
from scripts.utils.model_convertor_utils import TypeMapper
......@@ -11,11 +20,14 @@ class DataBricksSQLLayer(DatabricksSQLUtility):
super().__init__(catalog_name, project_id)
self.schema = schema
def create_external_table_from_structure(self, table: Table,
def create_external_table_from_structure(
self,
table: Table,
external_location: str,
file_format: str = "PARQUET",
table_properties: Dict[str, str] = None,
partition_columns: list = None) -> str:
partition_columns: list = None,
) -> str:
"""
Create an external table from a model class.
......@@ -31,12 +43,14 @@ class DataBricksSQLLayer(DatabricksSQLUtility):
"""
schema_table = f"{table.schema}.{table.name}" if table.schema else table.name
columns_sql = TypeMapper().extract_columns_without_constraints(table)
external_location = f"{external_location}/{self.catalog_name}/{file_format}/{schema_table}"
external_location = (
f"{external_location}/{self.catalog_name}/{file_format}/{schema_table}"
)
sql_parts = [
f"CREATE TABLE IF NOT EXISTS {schema_table}",
f"({columns_sql})",
f"USING {file_format}",
f"LOCATION '{external_location}'"
f"LOCATION '{external_location}'",
]
if partition_columns:
partition_clause = ", ".join(partition_columns)
......@@ -64,33 +78,32 @@ class DataBricksSQLLayer(DatabricksSQLUtility):
"""
table_columns = [
Column('timestamp', BigInteger, nullable=False),
Column('dt_timestamp', DateTime, nullable=False),
Column('dt_date', Date, nullable=False),
Column('dt_hour', Integer, nullable=False),
Column('value', String, nullable=False),
Column('value_type', String, nullable=False, default='float'),
Column("c3", String, nullable=False)
Column("timestamp", BigInteger, nullable=False),
Column("dt_timestamp", DateTime, nullable=False),
Column("dt_date", Date, nullable=False),
Column("dt_hour", Integer, nullable=False),
Column("value", String, nullable=False),
Column("value_type", String, nullable=False, default="float"),
Column("c3", String, nullable=False),
]
default_columns = ["c1", "c5", "Q", "T", "D", "P", "A", "B", *columns]
table_columns.extend([Column(col_name, String, nullable=True) for col_name in default_columns])
partition_columns = ['dt_date', 'dt_hour', 'c3']
table_columns.extend(
[Column(col_name, String, nullable=True) for col_name in default_columns]
)
partition_columns = ["dt_date", "dt_hour", "c3"]
table_properties = {
"parquet.compression": "snappy", # Fast decompression for frequent queries
"parquet.page.size": "524288", # 512KB - better time-range filtering
"parquet.block.size": "268435456", # 256MB - efficient sequential reads
"serialization.format": "1" # Support for arrays/complex types
"serialization.format": "1", # Support for arrays/complex types
}
table_obj = Table(
"timeseries_data",
MetaData(),
*table_columns,
schema=self.schema
"timeseries_data", MetaData(), *table_columns, schema=self.schema
)
self.create_external_table_from_structure(
table=table_obj,
external_location=external_location,
partition_columns=partition_columns,
table_properties=table_properties
table_properties=table_properties,
)
return external_location
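Roughly, the DDL assembled by create_external_table_from_structure for this table would look like the sketch below; the rendered column types and storage root are illustrative, since the column SQL actually comes from TypeMapper.extract_columns_without_constraints:

# Illustrative shape only; real output depends on TypeMapper and DatabricksConfig.
expected_ddl = """
CREATE TABLE IF NOT EXISTS analytical.timeseries_data
(timestamp BIGINT, dt_timestamp TIMESTAMP, dt_date DATE, dt_hour INT,
 value STRING, value_type STRING, c3 STRING, c1 STRING, c5 STRING, ...)
USING PARQUET
LOCATION '<storage_root>/<catalog_name>/PARQUET/analytical.timeseries_data'
PARTITIONED BY (dt_date, dt_hour, c3)
"""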
......@@ -14,10 +14,14 @@ class DatabricksJobManager:
databricks_host: Your Databricks workspace URL
access_token: Personal access token or service principal token
"""
self.host = databricks_host if "https://" in databricks_host else f"https://{databricks_host}"
self.host = (
databricks_host
if "https://" in databricks_host
else f"https://{databricks_host}"
)
self.headers = {
'Authorization': f'Bearer {access_token}',
'Content-Type': 'application/json'
"Authorization": f"Bearer {access_token}",
"Content-Type": "application/json",
}
def create_job(self, job_config: dict):
......@@ -32,11 +36,13 @@ class DatabricksJobManager:
response = HTTPXRequestUtil(url).post(headers=self.headers, json=job_config)
if response.status_code == 200:
job_id = response.json()['job_id']
job_id = response.json()["job_id"]
logging.info(f"Job created successfully with ID: {job_id}")
return job_id
else:
logging.error(f"Error creating job: {response.status_code} - {response.text}")
logging.error(
f"Error creating job: {response.status_code} - {response.text}"
)
return None
def run_job(self, job_id: str, parameters=None):
......@@ -57,11 +63,13 @@ class DatabricksJobManager:
response = HTTPXRequestUtil(url).post(headers=self.headers, json=payload)
if response.status_code == 200:
run_id = response.json()['run_id']
run_id = response.json()["run_id"]
logging.info(f"Job run started with ID: {run_id}")
return run_id
else:
logging.error(f"Error running job: {response.status_code} - {response.text}")
logging.error(
f"Error running job: {response.status_code} - {response.text}"
)
return None
def get_run_status(self, run_id):
......@@ -73,12 +81,16 @@ class DatabricksJobManager:
url = f"{self.host}/api/2.1/jobs/runs/get"
params = {"run_id": run_id}
response = HTTPXRequestHandler(url).get(url, headers=self.headers, params=params)
response = HTTPXRequestUtil(url).get(
headers=self.headers, params=params
)
if response.status_code == 200:
return response.json()
else:
logging.error(f"Error getting run status: {response.status_code} - {response.text}")
logging.error(
f"Error getting run status: {response.status_code} - {response.text}"
)
return None
@staticmethod
......@@ -98,16 +110,18 @@ class DatabricksJobManager:
"task_key": "table_update_task",
"notebook_task": {
"notebook_path": notebook_path,
"base_parameters": {
"input_message": "default_value"
}
"base_parameters": {"input_message": "default_value"},
},
"timeout_seconds": 3600
"timeout_seconds": 3600,
}
],
"max_concurrent_runs": 10,
"tags": {
"purpose": "metadata_ingestion",
"compute_type": "serverless"
}
"purpose": (
"metadata_ingestion"
if "ingestion" in job_name
else "metadata_deletion"
),
"compute_type": "serverless",
},
}
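A usage sketch (not part of the diff) tying job creation and execution together; the host, token, and workspace notebook path are placeholders, and the config dict simply mirrors the serverless job structure built above:

manager = DatabricksJobManager(
    databricks_host="adb-1111111111111111.1.azuredatabricks.net",  # placeholder
    access_token="dapiXXXXXXXX",  # placeholder
)
job_config = {
    "name": "metadata_deletion_job",
    "tasks": [
        {
            "task_key": "table_update_task",
            "notebook_task": {
                "notebook_path": "/Shared/metadata_deletion_notebook",  # assumed workspace path
                "base_parameters": {"input_message": "default_value"},
            },
            "timeout_seconds": 3600,
        }
    ],
    "max_concurrent_runs": 10,
    "tags": {"purpose": "metadata_deletion", "compute_type": "serverless"},
}
job_id = manager.create_job(job_config)  # Jobs API 2.1 create
if job_id:
    run_id = manager.run_job(
        job_id=job_id,
        parameters={"input_message": "{}"},  # real value comes from get_job_trigger_payload()
    )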
......@@ -13,13 +13,19 @@ class NotebookManager:
databricks_host: Your Databricks workspace URL (e.g., 'https://your-workspace.cloud.databricks.com')
access_token: Personal access token or service principal token
"""
self.host = databricks_host if "https://" in databricks_host else f"https://{databricks_host}"
self.host = (
databricks_host
if "https://" in databricks_host
else f"https://{databricks_host}"
)
self.headers = {
'Authorization': f'Bearer {access_token}',
'Content-Type': 'application/json'
"Authorization": f"Bearer {access_token}",
"Content-Type": "application/json",
}
def create_notebook(self, notebook_path, notebook_code: str, language='PYTHON', overwrite=True):
def create_notebook(
self, notebook_path, notebook_code: str, language="PYTHON", overwrite=True
):
"""
Create a notebook in Databricks workspace
......@@ -31,18 +37,22 @@ class NotebookManager:
"""
url = f"{self.host}/api/2.0/workspace/import"
# Encode the notebook content in base64
encoded_content = base64.b64encode(notebook_code.encode('utf-8')).decode('utf-8')
encoded_content = base64.b64encode(notebook_code.encode("utf-8")).decode(
"utf-8"
)
payload = {
"path": notebook_path,
"format": "SOURCE",
"language": language,
"content": encoded_content,
"overwrite": overwrite
"overwrite": overwrite,
}
response = HTTPXRequestUtil(url=url).post(json=payload, headers=self.headers)
if response.status_code == 200:
logging.info(f"Notebook created successfully at: {notebook_path}")
return True
else:
logging.error(f"Error creating notebook: {response.status_code} - {response.text}")
logging.error(
f"Error creating notebook: {response.status_code} - {response.text}"
)
return False
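A corresponding sketch for uploading the bundled deletion notebook source; the workspace path and credentials are assumptions, while the local file path comes from NotebookConstants:

from scripts.constants import NotebookConstants

with open(NotebookConstants.METADATA_DELETION_NOTEBOOK_PATH, encoding="utf-8") as fh:
    notebook_code = fh.read()

nb_manager = NotebookManager(
    databricks_host="adb-1111111111111111.1.azuredatabricks.net",  # placeholder
    access_token="dapiXXXXXXXX",  # placeholder
)
created = nb_manager.create_notebook(
    notebook_path="/Shared/metadata_deletion_notebook",  # assumed workspace location
    notebook_code=notebook_code,
)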
import orjson
from scripts.config import RedisConfig
from scripts.db.redis import redis_connector
databricks_details_db = redis_connector.connect(db=RedisConfig.REDIS_DATABRICKS_DB, decode_responses=True)
\ No newline at end of file
databricks_details_db = redis_connector.connect(
db=RedisConfig.REDIS_DATABRICKS_DB, decode_responses=True
)
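For orientation, a sketch of how a created job id would be recorded in this hash so that the hget lookups in ModelInstanceHandler succeed; the project id and job id values are hypothetical:

from scripts.constants import DatabricksConstants
from scripts.db.redis.databricks_details import databricks_details_db

databricks_details_db.hset(
    "project_787",  # hypothetical project id (hash name)
    DatabricksConstants.METADATA_DELETION_JOB_NAME,  # field read back by hget in the handler
    "123456789",  # hypothetical job_id returned by DatabricksJobManager.create_job
)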
......@@ -9,7 +9,9 @@ from ut_sql_utils.config import PostgresConfig
from scripts.config import RedisConfig
from scripts.db.redis import redis_connector
graphql_details_db = redis_connector.connect(db=RedisConfig.REDIS_GRAPHQL_DB, decode_responses=True)
graphql_details_db = redis_connector.connect(
db=RedisConfig.REDIS_GRAPHQL_DB, decode_responses=True
)
def get_models(
......@@ -38,7 +40,9 @@ def get_models(
"""
tables_data = graphql_details_db.hget(info.data["project_id"], "schema_mapper")
if tables_data is None:
raise ILensErrors(f"No GraphQL schema data found for project {info.data['project_id']}")
raise ILensErrors(
f"No GraphQL schema data found for project {info.data['project_id']}"
)
tables: Dict[str, Any] = orjson.loads(tables_data) or {}
if (
......
......@@ -3,7 +3,9 @@ import orjson
from scripts.config import RedisConfig
from scripts.db.redis import redis_connector
project_details_db = redis_connector.connect(db=RedisConfig.REDIS_PROJECT_TAGS_DB, decode_responses=True)
project_details_db = redis_connector.connect(
db=RedisConfig.REDIS_PROJECT_TAGS_DB, decode_responses=True
)
def get_project_time_zone(project_id: str):
......@@ -19,7 +21,9 @@ def get_project_time_zone(project_id: str):
return "UTC"
def fetch_level_details(project_id: str, keys: bool = False, raw: bool = False) -> dict | list:
def fetch_level_details(
project_id: str, keys: bool = False, raw: bool = False
) -> dict | list:
"""
Function to fetch level details from project details
Uses redis project details cache db (db18) and fetches the level details
......@@ -60,7 +64,9 @@ def fetch_asset_level(project_id: str) -> str:
project_details = orjson.loads(project_details)
counter_levels = project_details.get("counter_levels", {})
asset_level = (
counter_levels.get("asset", counter_levels.get("equipment")) if isinstance(counter_levels, dict) else None
counter_levels.get("asset", counter_levels.get("equipment"))
if isinstance(counter_levels, dict)
else None
)
if asset_level:
return asset_level
......@@ -81,6 +87,7 @@ def fetch_asset_level_with_mapping(project_id: str) -> tuple[str, str]:
swapped_dict = {v: k for k, v in counter_levels.items()}
return swapped_dict.get("ast", ""), "ast"
def project_template_keys(project_id: str, levels=False):
val = project_details_db.get(project_id)
if val is None:
......
......@@ -2,4 +2,7 @@ from faststream.confluent import KafkaBroker
from scripts.config import KafkaConfig
broker = KafkaBroker(f'{KafkaConfig.KAFKA_HOST}:{KafkaConfig.KAFKA_PORT}', client_id="model_creator_agent")
broker = KafkaBroker(
f"{KafkaConfig.KAFKA_HOST}:{KafkaConfig.KAFKA_PORT}",
client_id="model_creator_agent",
)
......@@ -7,8 +7,7 @@ from scripts.schemas import ModelCreatorSchema, ModelInstanceSchema
class ModelCreatorAgent:
def __init__(self):
...
def __init__(self): ...
@staticmethod
async def model_creator_agent(message: ModelCreatorSchema):
......@@ -18,10 +17,14 @@ class ModelCreatorAgent:
session_manager=session_manager,
schema=message.schema,
)
model_cal_obj = ModelCreatorHandler(message=message, declarative_utils=declarative_utils)
model_cal_obj = ModelCreatorHandler(
message=message, declarative_utils=declarative_utils
)
await model_cal_obj.create_models_in_unity_catalog()
@staticmethod
async def model_instance_agent(message: ModelInstanceSchema):
model_instance_obj = ModelInstanceHandler(project_id=message.project_id, payload=message)
model_instance_obj = ModelInstanceHandler(
project_id=message.project_id, payload=message
)
await model_instance_obj.upload_instances_to_unity_catalog()
from typing import Optional, Union, Dict, Any, List
from typing import Any, Dict, List, Optional, Union
from pydantic import BaseModel, model_validator
from pydantic import BaseModel, Field, model_validator
from ut_security_util import MetaInfoSchema
from scripts.config import DatabricksConfig
......@@ -20,7 +20,11 @@ class ModelCreatorSchema(BaseModel):
class ModelInstanceSchema(BaseModel):
data: Union[Dict[str, Any], List[Dict[str, Any]]]
project_id: str
schema: Optional[str] = DatabricksConfig.DATABRICKS_PUBLIC_SCHEMA_NAME
action_type: str = "save"
node_type: str
sql_schema: Optional[str] = Field(
default=DatabricksConfig.DATABRICKS_PUBLIC_SCHEMA_NAME, alias="schema"
)
databricks_host: str = DatabricksConfig.DATABRICKS_HOST
databricks_port: int = DatabricksConfig.DATABRICKS_PORT
databricks_access_token: str = DatabricksConfig.DATABRICKS_ACCESS_TOKEN
......@@ -30,6 +34,6 @@ class ModelInstanceSchema(BaseModel):
@model_validator(mode="before")
def validate_data(cls, values: Dict[str, Any]) -> Dict[str, Any]:
if 'data' in values and isinstance(values['data'], dict):
values['data'] = [values['data']]
if "data" in values and isinstance(values["data"], dict):
values["data"] = [values["data"]]
return values
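An illustrative instantiation (hypothetical values) showing the two behaviours introduced here: the "schema" key is accepted through the sql_schema alias, and the before-validator wraps a single dict under data into a list:

message = {
    "project_id": "project_787",  # hypothetical
    "node_type": "enterprise",  # hypothetical
    "action_type": "delete",
    "schema": "public",  # populated via the alias onto sql_schema
    "data": {"id": "row_001"},  # single dict gets wrapped into a list
}
instance = ModelInstanceSchema(**message)
assert instance.sql_schema == "public"
assert instance.data == [{"id": "row_001"}]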
......@@ -29,13 +29,17 @@ class DatabricksSQLUtility:
DatabricksConfig.DATABRICKS_URI,
pool_pre_ping=True,
pool_recycle=3600,
echo=False
echo=False,
)
# Test connection
with self.engine.connect() as conn:
result = conn.execute(text("SELECT current_user() as user, current_catalog() as catalog"))
result = conn.execute(
text("SELECT current_user() as user, current_catalog() as catalog")
)
user_info = result.fetchone()
logger.info(f"Connected as user: {user_info[0]}, current catalog: {user_info[1]}")
logger.info(
f"Connected as user: {user_info[0]}, current catalog: {user_info[1]}"
)
logger.info("Successfully connected to Databricks")
return True
except Exception as e:
......@@ -46,8 +50,12 @@ class DatabricksSQLUtility:
if self.engine:
self.engine.dispose()
def create_catalog(self, managed_location: Optional[str] = None, comment: Optional[str] = None,
properties: Optional[dict] = None):
def create_catalog(
self,
managed_location: Optional[str] = None,
comment: Optional[str] = None,
properties: Optional[dict] = None,
):
"""
Create a new catalog in Unity Catalog
Args:
......@@ -77,8 +85,13 @@ class DatabricksSQLUtility:
logger.error(f"Failed to create catalog '{self.catalog_name}': {str(e)}")
raise
def create_schema(self, schema_name: str, managed_location: Optional[str] = None, comment: Optional[str] = None,
properties: Optional[dict] = None):
def create_schema(
self,
schema_name: str,
managed_location: Optional[str] = None,
comment: Optional[str] = None,
properties: Optional[dict] = None,
):
"""
Create a new schema within a catalog
Args:
......@@ -103,10 +116,14 @@ class DatabricksSQLUtility:
ddl += f"\nWITH DBPROPERTIES ({props})"
self.execute_sql_statement(ddl)
logger.info(f"Schema '{self.catalog_name}.{schema_name}' created successfully")
logger.info(
f"Schema '{self.catalog_name}.{schema_name}' created successfully"
)
return full_schema_name
except Exception as e:
logger.error(f"Failed to create schema '{self.catalog_name}.{schema_name}': {str(e)}")
logger.error(
f"Failed to create schema '{self.catalog_name}.{schema_name}': {str(e)}"
)
raise
def create_external_location(
......@@ -114,7 +131,7 @@ class DatabricksSQLUtility:
location_name: str,
storage_path: str,
credential_name: str,
comment: Optional[str] = None
comment: Optional[str] = None,
) -> str:
"""
Create an external location in Unity Catalog
......@@ -138,7 +155,9 @@ class DatabricksSQLUtility:
logger.info(f"External location '{location_name}' created successfully")
return location_name
except Exception as e:
logger.error(f"Failed to create external location '{location_name}': {str(e)}")
logger.error(
f"Failed to create external location '{location_name}': {str(e)}"
)
raise
def execute_sql_statement(self, query: str):
......
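A minimal usage sketch (not in the diff), assuming the utility's constructor takes the catalog name and project id as implied by DataBricksSQLLayer's super().__init__ call, and that the connection step shown above has already built self.engine; names are illustrative:

util = DatabricksSQLUtility("project_787_catalog", "project_787")  # assumed (catalog_name, project_id)
util.create_catalog(comment="Unity Catalog for project_787")
util.create_schema("public")  # becomes project_787_catalog.public
util.create_schema("analytical")  # becomes project_787_catalog.analytical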
......@@ -18,7 +18,11 @@ class HTTPXRequestUtil:
def delete(self, path="", params=None, **kwargs) -> httpx.Response:
url = self.get_url(path)
logging.info(url)
with httpx.Client(verify=self.verify, headers=kwargs.get("headers"), cookies=kwargs.get("cookies")) as client:
with httpx.Client(
verify=self.verify,
headers=kwargs.get("headers"),
cookies=kwargs.get("cookies"),
) as client:
response: httpx.Response = client.delete(url=url, params=params)
return response
......@@ -27,7 +31,11 @@ class HTTPXRequestUtil:
url = self.get_url(path)
logging.info(url)
with httpx.Client(verify=self.verify, headers=kwargs.get("headers"), cookies=kwargs.get("cookies")) as client:
with httpx.Client(
verify=self.verify,
headers=kwargs.get("headers"),
cookies=kwargs.get("cookies"),
) as client:
response: httpx.Response = client.put(url=url, data=data, json=json)
return response
......@@ -42,7 +50,11 @@ class HTTPXRequestUtil:
"""
url = self.get_url(path)
logging.info(url)
with httpx.Client(verify=self.verify, headers=kwargs.get("headers"), cookies=kwargs.get("cookies")) as client:
with httpx.Client(
verify=self.verify,
headers=kwargs.get("headers"),
cookies=kwargs.get("cookies"),
) as client:
response: httpx.Response = client.post(url=url, data=data, json=json)
return response
......@@ -52,7 +64,11 @@ class HTTPXRequestUtil:
url = self.get_url(path)
logging.info(url)
with httpx.Client(verify=self.verify, headers=kwargs.get("headers"), cookies=kwargs.get("cookies")) as client:
with httpx.Client(
verify=self.verify,
headers=kwargs.get("headers"),
cookies=kwargs.get("cookies"),
) as client:
response: httpx.Response = client.get(url=url, params=params)
return response
......