""" Utility functions and base classes for asset management operations. This module provides common functionality shared across asset management modules, including path operations, content hashing, validation, and base classes. """ import hashlib import logging import time from abc import ABC, abstractmethod from pathlib import Path from typing import Optional, Union, List, Dict, Any, Protocol, runtime_checkable from dataclasses import dataclass, field from concurrent.futures import ThreadPoolExecutor logger = logging.getLogger('markitect.assets.utils') class PathUtils: """Utilities for path operations and normalization.""" @staticmethod def normalize_path(path_input: Union[str, Path]) -> Path: """Normalize path strings to Path objects with consistent separators.""" if isinstance(path_input, str): # Replace Windows-style backslashes with forward slashes normalized_str = path_input.replace("\\", "/") return Path(normalized_str) return path_input @staticmethod def ensure_path_exists(path: Path, create_parents: bool = True) -> None: """Ensure a directory path exists, creating it if necessary.""" if create_parents: path.mkdir(parents=True, exist_ok=True) else: path.mkdir(exist_ok=True) @staticmethod def get_relative_path(target: Path, base: Path) -> Path: """Get relative path from base to target, handling cross-platform issues.""" try: return target.relative_to(base) except ValueError: # Paths are not related, return absolute path return target.resolve() @staticmethod def is_safe_path(path: Path, base_path: Path) -> bool: """Check if path is safe (doesn't escape base directory).""" try: resolved_path = (base_path / path).resolve() resolved_base = base_path.resolve() return resolved_path.is_relative_to(resolved_base) except (ValueError, OSError): return False class ContentHasher: """Utilities for content hashing and verification.""" @staticmethod def hash_content(content: bytes, algorithm: str = 'sha256') -> str: """Generate content hash using specified algorithm.""" 
hasher = hashlib.new(algorithm) hasher.update(content) return hasher.hexdigest() @staticmethod def hash_file(file_path: Path, algorithm: str = 'sha256', chunk_size: int = 8192) -> str: """Generate content hash for a file.""" hasher = hashlib.new(algorithm) with open(file_path, 'rb') as f: while chunk := f.read(chunk_size): hasher.update(chunk) return hasher.hexdigest() @staticmethod def verify_file_integrity(file_path: Path, expected_hash: str, algorithm: str = 'sha256') -> bool: """Verify file integrity against expected hash.""" try: actual_hash = ContentHasher.hash_file(file_path, algorithm) return actual_hash == expected_hash except Exception as e: logger.warning(f"Failed to verify file integrity for {file_path}: {e}") return False @runtime_checkable class ProgressReporter(Protocol): """Protocol for progress reporting interfaces.""" def start(self, total_items: int) -> None: """Start progress tracking.""" ... def update(self, current: int, item_name: str = "") -> None: """Update progress.""" ... def finish(self) -> None: """Finish progress tracking.""" ... 
@dataclass class BaseResult: """Base class for operation results with common fields.""" # Using field() to handle inheritance with required fields success: bool = field(default=True) error: Optional[Exception] = field(default=None) processing_time: float = field(default=0.0) def __post_init__(self): """Post-initialization validation.""" if self.error is not None and self.success: self.success = False class TimedOperation: """Context manager for timing operations.""" def __init__(self, operation_name: str = "operation"): self.operation_name = operation_name self.start_time = 0.0 self.end_time = 0.0 def __enter__(self): self.start_time = time.time() logger.debug(f"Starting {self.operation_name}") return self def __exit__(self, exc_type, exc_val, exc_tb): self.end_time = time.time() duration = self.elapsed_time if exc_type is None: logger.debug(f"Completed {self.operation_name} in {duration:.3f}s") else: logger.error(f"Failed {self.operation_name} after {duration:.3f}s: {exc_val}") @property def elapsed_time(self) -> float: """Get elapsed time in seconds.""" if self.end_time > 0: return self.end_time - self.start_time return time.time() - self.start_time if self.start_time > 0 else 0.0 class BatchProcessor: """Base class for batch processing operations.""" def __init__(self, max_concurrent: int = 4, chunk_size: int = 50): self.max_concurrent = max_concurrent self.chunk_size = chunk_size self.logger = logging.getLogger(f'{__name__}.{self.__class__.__name__}') def process_batch(self, items: List[Any], processor_func, progress_reporter: Optional[ProgressReporter] = None) -> List[Any]: """Process items in batches with optional progress reporting.""" results = [] if progress_reporter: progress_reporter.start(len(items)) with ThreadPoolExecutor(max_workers=self.max_concurrent) as executor: # Process in chunks to avoid overwhelming the system for i in range(0, len(items), self.chunk_size): chunk = items[i:i + self.chunk_size] # Submit chunk for processing futures = 
[executor.submit(processor_func, item) for item in chunk] # Collect results for j, future in enumerate(futures): try: result = future.result() results.append(result) if progress_reporter: progress_reporter.update(len(results), str(chunk[j])) except Exception as e: self.logger.error(f"Failed to process item {chunk[j]}: {e}") results.append(self._create_error_result(chunk[j], e)) if progress_reporter: progress_reporter.finish() return results def _create_error_result(self, item: Any, error: Exception) -> BaseResult: """Create error result for failed processing.""" return BaseResult(success=False, error=error) class ConfigurationValidator: """Utilities for configuration validation.""" @staticmethod def validate_path_config(config: Dict[str, Any], key: str, default: Optional[Path] = None) -> Path: """Validate and normalize path configuration.""" if key not in config: if default is None: raise ValueError(f"Required configuration key '{key}' not found") return default path_value = config[key] if isinstance(path_value, str): return PathUtils.normalize_path(path_value) elif isinstance(path_value, Path): return path_value else: raise ValueError(f"Configuration key '{key}' must be a string or Path, got {type(path_value)}") @staticmethod def validate_int_range(config: Dict[str, Any], key: str, min_val: int, max_val: int, default: int) -> int: """Validate integer configuration within range.""" value = config.get(key, default) if not isinstance(value, int): raise ValueError(f"Configuration key '{key}' must be an integer, got {type(value)}") if not (min_val <= value <= max_val): raise ValueError(f"Configuration key '{key}' must be between {min_val} and {max_val}, got {value}") return value @staticmethod def validate_boolean(config: Dict[str, Any], key: str, default: bool) -> bool: """Validate boolean configuration.""" value = config.get(key, default) if not isinstance(value, bool): raise ValueError(f"Configuration key '{key}' must be a boolean, got {type(value)}") return value 
class MemoryCache:
    """Simple in-memory cache with TTL support.

    Entries are stored together with an absolute expiry timestamp and are
    removed lazily, when they are looked up or when the size is queried.
    """

    def __init__(self, default_ttl: float = 300.0):  # 5 minutes default
        self.default_ttl = default_ttl
        self._cache: Dict[str, tuple] = {}  # key -> (value, expiry_time)

    def get(self, key: str) -> Optional[Any]:
        """Get value from cache if not expired.

        Returns:
            The cached value, or None when the key is unknown or its entry
            has expired (an expired entry is dropped from the cache).
        """
        if key not in self._cache:
            return None

        value, expiry = self._cache[key]
        if time.time() > expiry:
            del self._cache[key]
            return None

        return value

    def set(self, key: str, value: Any, ttl: Optional[float] = None) -> None:
        """Set value in cache with TTL.

        Args:
            key: Cache key.
            value: Value to store.
            ttl: Lifetime in seconds. ``None`` selects the cache's default
                TTL; an explicit ``0`` is honoured as "expires immediately".
        """
        # Test against None rather than truthiness so that ttl=0 is not
        # silently replaced by the default.
        if ttl is None:
            ttl = self.default_ttl
        self._cache[key] = (value, time.time() + ttl)

    def clear(self) -> None:
        """Clear all cached values."""
        self._cache.clear()

    def size(self) -> int:
        """Get current cache size.

        Expired entries are purged first, so only live entries are counted.
        """
        # Clean expired entries first
        current_time = time.time()
        expired_keys = [k for k, (_, expiry) in self._cache.items() if current_time > expiry]
        for key in expired_keys:
            del self._cache[key]

        return len(self._cache)


class FileValidator:
    """Utilities for file validation and safety checks."""

    # File suffixes (lowercase, including the dot) accepted by
    # is_safe_file_type.
    SAFE_EXTENSIONS = {
        '.md', '.mdx', '.txt', '.json', '.yaml', '.yml',
        '.png', '.jpg', '.jpeg', '.gif', '.svg', '.webp',
        '.pdf', '.zip', '.tar', '.gz'
    }

    @staticmethod
    def is_safe_file_type(file_path: Path) -> bool:
        """Check if file type is considered safe.

        The comparison uses the final suffix only and ignores letter case.
        """
        return file_path.suffix.lower() in FileValidator.SAFE_EXTENSIONS

    @staticmethod
    def validate_file_size(file_path: Path, max_size_bytes: int = 100 * 1024 * 1024) -> bool:
        """Validate file size is within acceptable limits.

        Returns:
            True when the file's size does not exceed ``max_size_bytes``;
            False when it does or when the file cannot be inspected.
        """
        try:
            return file_path.stat().st_size <= max_size_bytes
        except OSError:
            return False

    @staticmethod
    def is_readable_file(file_path: Path) -> bool:
        """Check if file exists and is readable.

        Readability is judged from the permission bits: at least one of the
        owner/group/other read bits must be set.

        Returns:
            True for an existing regular file with a read bit set; False
            otherwise, including when the file cannot be inspected.
        """
        try:
            # is_file() is already False for a missing path. bool() turns
            # the permission bitmask into the real boolean the signature
            # promises.
            return file_path.is_file() and bool(file_path.stat().st_mode & 0o444)
        except OSError:
            # The file may vanish or become inaccessible between checks.
            return False