feat: reorganize tests by capability with separate test targets
Some checks failed
Test Suite / unit-tests (3.11) (push) Has been cancelled
Test Suite / unit-tests (3.12) (push) Has been cancelled
Test Suite / integration-tests (push) Has been cancelled
Test Suite / e2e-tests (push) Has been cancelled
Test Suite / performance-tests (push) Has been cancelled
Test Suite / code-quality (push) Has been cancelled
Test Suite / security-scan (push) Has been cancelled
Test Suite / test-summary (push) Has been cancelled
Some checks failed
Test Suite / unit-tests (3.11) (push) Has been cancelled
Test Suite / unit-tests (3.12) (push) Has been cancelled
Test Suite / integration-tests (push) Has been cancelled
Test Suite / e2e-tests (push) Has been cancelled
Test Suite / performance-tests (push) Has been cancelled
Test Suite / code-quality (push) Has been cancelled
Test Suite / security-scan (push) Has been cancelled
Test Suite / test-summary (push) Has been cancelled
Separate capability-specific tests from core system tests to establish clear test organization and separation of concerns. ## Test Reorganization: - **markitect-content tests**: Moved 6 tests to capabilities/markitect-content/tests/ - **markitect-finance tests**: Moved 7 tests to markitect/finance/tests/ - **markitect-query tests**: Moved 1 test to markitect/query_paradigms/tests/ - **markitect-graphql tests**: Moved 2 tests to markitect/graphql/tests/ - **markitect-plugins tests**: Moved 2 tests to markitect/plugins/tests/ ## Makefile Updates: - **make test**: Excludes capability tests, runs only core system tests - **make test-capabilities**: Runs all capability tests - **make test-capability-***: Individual capability test targets - Updated all test targets (test-red, test-green, test-ultra-fast, test-perf) - Added capability test targets to help documentation ## Benefits: - Clear separation between core system tests and capability-specific tests - Faster core test execution (capability tests not run by default) - Individual capability testing for focused development - Supports future capability extraction workflow - Maintains capability test independence Test verification: - Core tests: 1291 tests (capability tests excluded) - Finance capability: 143 tests working independently - Content capability: 79 tests working independently 🤖 Generated with [Claude Code](https://claude.ai/code) Co-Authored-By: Claude <noreply@anthropic.com>
This commit is contained in:
0
markitect/finance/tests/__init__.py
Normal file
0
markitect/finance/tests/__init__.py
Normal file
393
markitect/finance/tests/test_cost_cli_commands.py
Normal file
393
markitect/finance/tests/test_cost_cli_commands.py
Normal file
@@ -0,0 +1,393 @@
|
||||
"""
|
||||
Tests for MarkiTect cost tracking CLI commands.
|
||||
|
||||
This module tests the command-line interface for cost management including:
|
||||
- Cost report generation commands
|
||||
- Cost item management commands
|
||||
- Category management commands
|
||||
- Period cost calculations
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import tempfile
|
||||
import os
|
||||
import json
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
from click.testing import CliRunner
|
||||
|
||||
from markitect.finance.cli import cost_commands
|
||||
from markitect.finance.cost_manager import CostItemManager, CostItem
|
||||
from markitect.finance.models import FinanceModels
|
||||
|
||||
|
||||
class TestCostCLICommands:
|
||||
"""Test suite for cost tracking CLI commands."""
|
||||
|
||||
@pytest.fixture
|
||||
def temp_db(self):
|
||||
"""Create temporary database for testing."""
|
||||
fd, path = tempfile.mkstemp(suffix='.db')
|
||||
os.close(fd)
|
||||
yield path
|
||||
os.unlink(path)
|
||||
|
||||
@pytest.fixture
|
||||
def setup_test_data(self, temp_db):
|
||||
"""Setup test database with sample cost data."""
|
||||
finance_models = FinanceModels(temp_db)
|
||||
finance_models.initialize_finance_schema()
|
||||
|
||||
cost_manager = CostItemManager(temp_db)
|
||||
|
||||
# Get categories
|
||||
infra_cat = cost_manager.get_category_by_name('Infrastructure')
|
||||
software_cat = cost_manager.get_category_by_name('Software')
|
||||
|
||||
# Create sample cost items
|
||||
cost_items = [
|
||||
CostItem(
|
||||
category_id=infra_cat['id'],
|
||||
name='Test Server',
|
||||
cost_type='monthly',
|
||||
amount_eur=Decimal('25.00'),
|
||||
starting_from_date=date(2025, 1, 1)
|
||||
),
|
||||
CostItem(
|
||||
category_id=software_cat['id'],
|
||||
name='Test Software',
|
||||
cost_type='one_time',
|
||||
amount_eur=Decimal('50.00'),
|
||||
starting_from_date=date(2025, 1, 15)
|
||||
)
|
||||
]
|
||||
|
||||
for item in cost_items:
|
||||
cost_manager.create_cost_item(item)
|
||||
|
||||
return temp_db
|
||||
|
||||
@pytest.fixture
|
||||
def runner(self):
|
||||
"""Create Click test runner."""
|
||||
return CliRunner()
|
||||
|
||||
def test_cost_report_generate_summary(self, runner, setup_test_data):
|
||||
"""Test cost report generate command with summary format."""
|
||||
result = runner.invoke(cost_commands, [
|
||||
'report', 'generate',
|
||||
'--period', '2025-01',
|
||||
'--format', 'summary',
|
||||
'--database', setup_test_data
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Cost Summary Report - January 2025" in result.output
|
||||
assert "€75.00" in result.output # 25 + 50
|
||||
assert "frontmatter" not in result.output.lower() # Should be properly formatted
|
||||
|
||||
def test_cost_report_generate_detailed(self, runner, setup_test_data):
|
||||
"""Test cost report generate command with detailed format."""
|
||||
result = runner.invoke(cost_commands, [
|
||||
'report', 'generate',
|
||||
'--period', '2025-01',
|
||||
'--format', 'detailed',
|
||||
'--database', setup_test_data
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Detailed Cost Report - January 2025" in result.output
|
||||
assert "Infrastructure" in result.output
|
||||
assert "Software" in result.output
|
||||
assert "Test Server" in result.output
|
||||
assert "Test Software" in result.output
|
||||
|
||||
def test_cost_report_generate_audit(self, runner, setup_test_data):
|
||||
"""Test cost report generate command with audit format."""
|
||||
result = runner.invoke(cost_commands, [
|
||||
'report', 'generate',
|
||||
'--period', '2025-01',
|
||||
'--format', 'audit',
|
||||
'--database', setup_test_data
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Cost Audit Report - January 2025" in result.output
|
||||
assert "Audit Summary" in result.output
|
||||
assert "Transaction History" in result.output
|
||||
|
||||
def test_cost_report_generate_with_output_file(self, runner, setup_test_data):
|
||||
"""Test saving report to output file."""
|
||||
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.md') as f:
|
||||
output_path = f.name
|
||||
|
||||
try:
|
||||
result = runner.invoke(cost_commands, [
|
||||
'report', 'generate',
|
||||
'--period', '2025-01',
|
||||
'--output', output_path,
|
||||
'--database', setup_test_data
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert f"Report saved to: {output_path}" in result.output
|
||||
|
||||
# Verify file was created
|
||||
assert os.path.exists(output_path)
|
||||
with open(output_path, 'r') as f:
|
||||
content = f.read()
|
||||
assert "Cost Summary Report" in content
|
||||
|
||||
finally:
|
||||
if os.path.exists(output_path):
|
||||
os.unlink(output_path)
|
||||
|
||||
def test_cost_report_generate_invalid_period(self, runner, setup_test_data):
|
||||
"""Test report generation with invalid period format."""
|
||||
result = runner.invoke(cost_commands, [
|
||||
'report', 'generate',
|
||||
'--period', 'invalid-period',
|
||||
'--database', setup_test_data
|
||||
])
|
||||
|
||||
assert result.exit_code == 1
|
||||
assert "Period must be in YYYY-MM format" in result.output
|
||||
|
||||
def test_cost_report_generate_default_database(self, runner):
|
||||
"""Test report generation with default database path from config."""
|
||||
result = runner.invoke(cost_commands, [
|
||||
'report', 'generate',
|
||||
'--period', '2025-01'
|
||||
])
|
||||
|
||||
# Should succeed with default config and empty database
|
||||
assert result.exit_code == 0
|
||||
assert "Cost Summary Report - January 2025" in result.output
|
||||
assert "€0.00" in result.output # Empty database shows zero costs
|
||||
|
||||
def test_cost_report_template_show(self, runner):
|
||||
"""Test cost report template show command."""
|
||||
result = runner.invoke(cost_commands, [
|
||||
'report', 'template', '--show'
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Summary Report Template" in result.output
|
||||
assert "Description" in result.output
|
||||
assert "Frontmatter Fields" in result.output
|
||||
|
||||
def test_cost_report_template_different_formats(self, runner):
|
||||
"""Test template show for different formats."""
|
||||
formats = ['summary', 'detailed', 'audit']
|
||||
|
||||
for format_type in formats:
|
||||
result = runner.invoke(cost_commands, [
|
||||
'report', 'template', '--show', '--format', format_type
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert f"{format_type.title()} Report Template" in result.output
|
||||
|
||||
def test_cost_item_add(self, runner, temp_db):
|
||||
"""Test adding new cost item via CLI."""
|
||||
# Initialize database
|
||||
finance_models = FinanceModels(temp_db)
|
||||
finance_models.initialize_finance_schema()
|
||||
|
||||
result = runner.invoke(cost_commands, [
|
||||
'item', 'add', 'Test Item',
|
||||
'--category', 'Infrastructure',
|
||||
'--amount', '15.50',
|
||||
'--type', 'monthly',
|
||||
'--database', temp_db
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "✅ Created cost item 'Test Item'" in result.output
|
||||
|
||||
# Verify item was created
|
||||
cost_manager = CostItemManager(temp_db)
|
||||
items = cost_manager.list_cost_items()
|
||||
assert len(items) == 1
|
||||
assert items[0]['name'] == 'Test Item'
|
||||
assert float(items[0]['amount_eur']) == 15.50
|
||||
|
||||
def test_cost_item_add_with_description_and_date(self, runner, temp_db):
|
||||
"""Test adding cost item with description and start date."""
|
||||
finance_models = FinanceModels(temp_db)
|
||||
finance_models.initialize_finance_schema()
|
||||
|
||||
result = runner.invoke(cost_commands, [
|
||||
'item', 'add', 'Test Item',
|
||||
'--category', 'Software',
|
||||
'--amount', '99.99',
|
||||
'--type', 'one_time',
|
||||
'--description', 'Test description',
|
||||
'--start-date', '2025-01-15',
|
||||
'--database', temp_db
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "✅ Created cost item 'Test Item'" in result.output
|
||||
|
||||
def test_cost_item_add_invalid_category(self, runner, temp_db):
|
||||
"""Test adding item with non-existent category."""
|
||||
finance_models = FinanceModels(temp_db)
|
||||
finance_models.initialize_finance_schema()
|
||||
|
||||
result = runner.invoke(cost_commands, [
|
||||
'item', 'add', 'Test Item',
|
||||
'--category', 'NonExistent',
|
||||
'--amount', '10.00',
|
||||
'--type', 'monthly',
|
||||
'--database', temp_db
|
||||
])
|
||||
|
||||
assert result.exit_code == 1
|
||||
assert "Category 'NonExistent' not found" in result.output
|
||||
assert "Available categories:" in result.output
|
||||
|
||||
def test_cost_item_list(self, runner, setup_test_data):
|
||||
"""Test listing cost items."""
|
||||
result = runner.invoke(cost_commands, [
|
||||
'item', 'list',
|
||||
'--database', setup_test_data
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Test Server" in result.output
|
||||
assert "Test Software" in result.output
|
||||
assert "€25.00" in result.output
|
||||
assert "€50.00" in result.output
|
||||
|
||||
def test_cost_item_list_with_filters(self, runner, setup_test_data):
|
||||
"""Test listing cost items with filters."""
|
||||
# Filter by category
|
||||
result = runner.invoke(cost_commands, [
|
||||
'item', 'list',
|
||||
'--category', 'Infrastructure',
|
||||
'--database', setup_test_data
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Test Server" in result.output
|
||||
assert "Test Software" not in result.output
|
||||
|
||||
# Filter by type
|
||||
result = runner.invoke(cost_commands, [
|
||||
'item', 'list',
|
||||
'--type', 'monthly',
|
||||
'--database', setup_test_data
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Test Server" in result.output
|
||||
assert "Test Software" not in result.output
|
||||
|
||||
def test_cost_category_list(self, runner, setup_test_data):
|
||||
"""Test listing cost categories."""
|
||||
result = runner.invoke(cost_commands, [
|
||||
'category', 'list',
|
||||
'--database', setup_test_data
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Infrastructure" in result.output
|
||||
assert "Software" in result.output
|
||||
assert "Total: 8 categories" in result.output # Default categories
|
||||
|
||||
def test_cost_category_add(self, runner, temp_db):
|
||||
"""Test adding new cost category."""
|
||||
finance_models = FinanceModels(temp_db)
|
||||
finance_models.initialize_finance_schema()
|
||||
|
||||
result = runner.invoke(cost_commands, [
|
||||
'category', 'add', 'Custom Category',
|
||||
'--description', 'Custom test category',
|
||||
'--database', temp_db
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "✅ Created category 'Custom Category'" in result.output
|
||||
|
||||
# Verify category was created
|
||||
cost_manager = CostItemManager(temp_db)
|
||||
categories = cost_manager.list_categories()
|
||||
category_names = [cat['name'] for cat in categories]
|
||||
assert 'Custom Category' in category_names
|
||||
|
||||
def test_cost_calculate(self, runner, setup_test_data):
|
||||
"""Test cost calculation command."""
|
||||
result = runner.invoke(cost_commands, [
|
||||
'calculate',
|
||||
'--period', '2025-01',
|
||||
'--database', setup_test_data
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Cost Calculation - January 2025" in result.output
|
||||
assert "Monthly Recurring: €25.00" in result.output
|
||||
assert "One-time Expenses: €50.00" in result.output
|
||||
assert "Total Period Cost: €75.00" in result.output
|
||||
assert "Active Cost Items: 2" in result.output
|
||||
|
||||
def test_cost_calculate_current_month(self, runner, setup_test_data):
|
||||
"""Test cost calculation for current month (default)."""
|
||||
result = runner.invoke(cost_commands, [
|
||||
'calculate',
|
||||
'--database', setup_test_data
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Cost Calculation" in result.output
|
||||
# Should default to current month
|
||||
|
||||
def test_cost_calculate_invalid_period(self, runner, setup_test_data):
|
||||
"""Test cost calculation with invalid period."""
|
||||
result = runner.invoke(cost_commands, [
|
||||
'calculate',
|
||||
'--period', 'invalid',
|
||||
'--database', setup_test_data
|
||||
])
|
||||
|
||||
assert result.exit_code == 1
|
||||
assert "Period must be in YYYY-MM format" in result.output
|
||||
|
||||
def test_cost_item_add_invalid_date_format(self, runner, temp_db):
|
||||
"""Test adding item with invalid date format."""
|
||||
finance_models = FinanceModels(temp_db)
|
||||
finance_models.initialize_finance_schema()
|
||||
|
||||
result = runner.invoke(cost_commands, [
|
||||
'item', 'add', 'Test Item',
|
||||
'--category', 'Infrastructure',
|
||||
'--amount', '10.00',
|
||||
'--type', 'monthly',
|
||||
'--start-date', 'invalid-date',
|
||||
'--database', temp_db
|
||||
])
|
||||
|
||||
assert result.exit_code == 1
|
||||
assert "Start date must be in YYYY-MM-DD format" in result.output
|
||||
|
||||
def test_help_commands(self, runner):
|
||||
"""Test help output for cost commands."""
|
||||
# Test main cost help
|
||||
result = runner.invoke(cost_commands, ['--help'])
|
||||
assert result.exit_code == 0
|
||||
assert "Cost tracking and financial reporting commands" in result.output
|
||||
|
||||
# Test report help
|
||||
result = runner.invoke(cost_commands, ['report', '--help'])
|
||||
assert result.exit_code == 0
|
||||
assert "Generate cost reports" in result.output
|
||||
|
||||
# Test item help
|
||||
result = runner.invoke(cost_commands, ['item', '--help'])
|
||||
assert result.exit_code == 0
|
||||
assert "Manage cost items" in result.output
|
||||
|
||||
# Test category help
|
||||
result = runner.invoke(cost_commands, ['category', '--help'])
|
||||
assert result.exit_code == 0
|
||||
assert "Manage cost categories" in result.output
|
||||
398
markitect/finance/tests/test_cost_manager.py
Normal file
398
markitect/finance/tests/test_cost_manager.py
Normal file
@@ -0,0 +1,398 @@
|
||||
"""
|
||||
Tests for MarkiTect cost item management system.
|
||||
|
||||
This module tests the complete cost item management functionality including:
|
||||
- Cost item lifecycle (create, update, deactivate)
|
||||
- Category management
|
||||
- Business rule validation
|
||||
- Period-based cost calculations
|
||||
- Integration with database models
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import tempfile
|
||||
import os
|
||||
from datetime import date, datetime
|
||||
from decimal import Decimal
|
||||
|
||||
from markitect.finance.cost_manager import CostItemManager, CostItem, CostCategory
|
||||
from markitect.finance.models import FinanceModels
|
||||
|
||||
|
||||
class TestCostItemManager:
|
||||
"""Test suite for cost item management system."""
|
||||
|
||||
@pytest.fixture
|
||||
def temp_db(self):
|
||||
"""Create temporary database for testing."""
|
||||
fd, path = tempfile.mkstemp(suffix='.db')
|
||||
os.close(fd)
|
||||
yield path
|
||||
os.unlink(path)
|
||||
|
||||
@pytest.fixture
|
||||
def cost_manager(self, temp_db):
|
||||
"""Create CostItemManager instance with initialized database."""
|
||||
finance_models = FinanceModels(temp_db)
|
||||
finance_models.initialize_finance_schema()
|
||||
return CostItemManager(temp_db)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_category_id(self, cost_manager):
|
||||
"""Create a sample category for testing."""
|
||||
return cost_manager.create_category("Test Category", "For testing purposes")
|
||||
|
||||
def test_create_cost_item_valid(self, cost_manager, sample_category_id):
|
||||
"""Test creating a valid cost item."""
|
||||
cost_item = CostItem(
|
||||
category_id=sample_category_id,
|
||||
name="Test Server",
|
||||
description="Monthly hosting",
|
||||
cost_type="monthly",
|
||||
amount_eur=Decimal('25.50'),
|
||||
starting_from_date=date(2025, 1, 1)
|
||||
)
|
||||
|
||||
cost_item_id = cost_manager.create_cost_item(cost_item)
|
||||
assert cost_item_id is not None
|
||||
|
||||
# Verify item was created
|
||||
retrieved = cost_manager.get_cost_item(cost_item_id)
|
||||
assert retrieved['name'] == "Test Server"
|
||||
assert float(retrieved['amount_eur']) == 25.50
|
||||
assert retrieved['cost_type'] == "monthly"
|
||||
assert retrieved['is_active'] == 1 # SQLite stores booleans as integers
|
||||
|
||||
def test_create_cost_item_validation_errors(self, cost_manager, sample_category_id):
|
||||
"""Test cost item validation errors."""
|
||||
# Missing name
|
||||
with pytest.raises(ValueError, match="name is required"):
|
||||
cost_item = CostItem(
|
||||
category_id=sample_category_id,
|
||||
name="",
|
||||
cost_type="monthly",
|
||||
amount_eur=Decimal('10.00'),
|
||||
starting_from_date=date(2025, 1, 1)
|
||||
)
|
||||
cost_manager.create_cost_item(cost_item)
|
||||
|
||||
# Invalid cost type
|
||||
with pytest.raises(ValueError, match="must be 'monthly' or 'one_time'"):
|
||||
cost_item = CostItem(
|
||||
category_id=sample_category_id,
|
||||
name="Test Item",
|
||||
cost_type="invalid",
|
||||
amount_eur=Decimal('10.00'),
|
||||
starting_from_date=date(2025, 1, 1)
|
||||
)
|
||||
cost_manager.create_cost_item(cost_item)
|
||||
|
||||
# Negative amount
|
||||
with pytest.raises(ValueError, match="must be non-negative"):
|
||||
cost_item = CostItem(
|
||||
category_id=sample_category_id,
|
||||
name="Test Item",
|
||||
cost_type="monthly",
|
||||
amount_eur=Decimal('-10.00'),
|
||||
starting_from_date=date(2025, 1, 1)
|
||||
)
|
||||
cost_manager.create_cost_item(cost_item)
|
||||
|
||||
# Invalid date range
|
||||
with pytest.raises(ValueError, match="must be after starting date"):
|
||||
cost_item = CostItem(
|
||||
category_id=sample_category_id,
|
||||
name="Test Item",
|
||||
cost_type="monthly",
|
||||
amount_eur=Decimal('10.00'),
|
||||
starting_from_date=date(2025, 1, 15),
|
||||
ending_date=date(2025, 1, 10)
|
||||
)
|
||||
cost_manager.create_cost_item(cost_item)
|
||||
|
||||
# Inactive without ending date
|
||||
with pytest.raises(ValueError, match="must have an ending date"):
|
||||
cost_item = CostItem(
|
||||
category_id=sample_category_id,
|
||||
name="Test Item",
|
||||
cost_type="monthly",
|
||||
amount_eur=Decimal('10.00'),
|
||||
starting_from_date=date(2025, 1, 1),
|
||||
is_active=False
|
||||
)
|
||||
cost_manager.create_cost_item(cost_item)
|
||||
|
||||
def test_update_cost_item(self, cost_manager, sample_category_id):
|
||||
"""Test updating cost item."""
|
||||
# Create initial cost item
|
||||
cost_item = CostItem(
|
||||
category_id=sample_category_id,
|
||||
name="Original Name",
|
||||
cost_type="monthly",
|
||||
amount_eur=Decimal('10.00'),
|
||||
starting_from_date=date(2025, 1, 1)
|
||||
)
|
||||
cost_item_id = cost_manager.create_cost_item(cost_item)
|
||||
|
||||
# Update the cost item
|
||||
updates = {
|
||||
'name': 'Updated Name',
|
||||
'amount_eur': Decimal('15.50'),
|
||||
'description': 'Updated description'
|
||||
}
|
||||
success = cost_manager.update_cost_item(cost_item_id, updates)
|
||||
assert success is True
|
||||
|
||||
# Verify updates
|
||||
updated = cost_manager.get_cost_item(cost_item_id)
|
||||
assert updated['name'] == 'Updated Name'
|
||||
assert float(updated['amount_eur']) == 15.50
|
||||
assert updated['description'] == 'Updated description'
|
||||
|
||||
def test_update_nonexistent_cost_item(self, cost_manager):
|
||||
"""Test updating non-existent cost item."""
|
||||
with pytest.raises(ValueError, match="not found"):
|
||||
cost_manager.update_cost_item(99999, {'name': 'New Name'})
|
||||
|
||||
def test_deactivate_cost_item(self, cost_manager, sample_category_id):
|
||||
"""Test deactivating cost item."""
|
||||
# Create cost item
|
||||
cost_item = CostItem(
|
||||
category_id=sample_category_id,
|
||||
name="Test Item",
|
||||
cost_type="monthly",
|
||||
amount_eur=Decimal('10.00'),
|
||||
starting_from_date=date(2025, 1, 1)
|
||||
)
|
||||
cost_item_id = cost_manager.create_cost_item(cost_item)
|
||||
|
||||
# Deactivate with specific ending date
|
||||
ending_date = date(2025, 6, 30)
|
||||
success = cost_manager.deactivate_cost_item(cost_item_id, ending_date)
|
||||
assert success is True
|
||||
|
||||
# Verify deactivation
|
||||
updated = cost_manager.get_cost_item(cost_item_id)
|
||||
assert updated['is_active'] == 0 # SQLite stores False as 0
|
||||
assert updated['ending_date'] == ending_date.isoformat()
|
||||
|
||||
def test_list_cost_items_filtering(self, cost_manager, sample_category_id):
|
||||
"""Test listing cost items with filtering."""
|
||||
# Create multiple cost items
|
||||
items = [
|
||||
CostItem(
|
||||
category_id=sample_category_id,
|
||||
name="Monthly Item 1",
|
||||
cost_type="monthly",
|
||||
amount_eur=Decimal('10.00'),
|
||||
starting_from_date=date(2025, 1, 1)
|
||||
),
|
||||
CostItem(
|
||||
category_id=sample_category_id,
|
||||
name="One-time Item",
|
||||
cost_type="one_time",
|
||||
amount_eur=Decimal('50.00'),
|
||||
starting_from_date=date(2025, 1, 1)
|
||||
),
|
||||
CostItem(
|
||||
category_id=sample_category_id,
|
||||
name="Inactive Item",
|
||||
cost_type="monthly",
|
||||
amount_eur=Decimal('5.00'),
|
||||
starting_from_date=date(2025, 1, 1),
|
||||
ending_date=date(2025, 1, 31),
|
||||
is_active=False
|
||||
)
|
||||
]
|
||||
|
||||
for item in items:
|
||||
cost_manager.create_cost_item(item)
|
||||
|
||||
# Test filtering by active only
|
||||
active_items = cost_manager.list_cost_items(active_only=True)
|
||||
assert len(active_items) == 2
|
||||
assert all(item['is_active'] == 1 for item in active_items)
|
||||
|
||||
# Test filtering by cost type
|
||||
monthly_items = cost_manager.list_cost_items(cost_type="monthly")
|
||||
assert len(monthly_items) == 1 # Only active monthly items
|
||||
assert monthly_items[0]['cost_type'] == "monthly"
|
||||
|
||||
# Test including inactive items
|
||||
all_items = cost_manager.list_cost_items(active_only=False)
|
||||
assert len(all_items) == 3
|
||||
|
||||
def test_get_active_costs_for_period(self, cost_manager, sample_category_id):
|
||||
"""Test retrieving active costs for specific period."""
|
||||
# Create cost items with different date ranges
|
||||
items = [
|
||||
CostItem(
|
||||
category_id=sample_category_id,
|
||||
name="Active Throughout",
|
||||
cost_type="monthly",
|
||||
amount_eur=Decimal('10.00'),
|
||||
starting_from_date=date(2024, 12, 1)
|
||||
),
|
||||
CostItem(
|
||||
category_id=sample_category_id,
|
||||
name="Starts Mid-Period",
|
||||
cost_type="monthly",
|
||||
amount_eur=Decimal('15.00'),
|
||||
starting_from_date=date(2025, 1, 15)
|
||||
),
|
||||
CostItem(
|
||||
category_id=sample_category_id,
|
||||
name="Ends Mid-Period",
|
||||
cost_type="monthly",
|
||||
amount_eur=Decimal('20.00'),
|
||||
starting_from_date=date(2024, 12, 1),
|
||||
ending_date=date(2025, 1, 15)
|
||||
),
|
||||
CostItem(
|
||||
category_id=sample_category_id,
|
||||
name="Outside Period",
|
||||
cost_type="monthly",
|
||||
amount_eur=Decimal('25.00'),
|
||||
starting_from_date=date(2025, 2, 1)
|
||||
)
|
||||
]
|
||||
|
||||
for item in items:
|
||||
cost_manager.create_cost_item(item)
|
||||
|
||||
# Get active costs for January 2025
|
||||
period_start = date(2025, 1, 1)
|
||||
period_end = date(2025, 1, 31)
|
||||
active_costs = cost_manager.get_active_costs_for_period(period_start, period_end)
|
||||
|
||||
# Should include first 3 items but not the fourth
|
||||
assert len(active_costs) == 3
|
||||
names = [item['name'] for item in active_costs]
|
||||
assert "Active Throughout" in names
|
||||
assert "Starts Mid-Period" in names
|
||||
assert "Ends Mid-Period" in names
|
||||
assert "Outside Period" not in names
|
||||
|
||||
def test_calculate_period_costs(self, cost_manager, sample_category_id):
|
||||
"""Test period cost calculations."""
|
||||
# Create another category
|
||||
other_category_id = cost_manager.create_category("Other Category")
|
||||
|
||||
# Create cost items in different categories
|
||||
items = [
|
||||
CostItem(
|
||||
category_id=sample_category_id,
|
||||
name="Monthly Cost 1",
|
||||
cost_type="monthly",
|
||||
amount_eur=Decimal('10.00'),
|
||||
starting_from_date=date(2025, 1, 1)
|
||||
),
|
||||
CostItem(
|
||||
category_id=sample_category_id,
|
||||
name="Monthly Cost 2",
|
||||
cost_type="monthly",
|
||||
amount_eur=Decimal('15.00'),
|
||||
starting_from_date=date(2025, 1, 1)
|
||||
),
|
||||
CostItem(
|
||||
category_id=other_category_id,
|
||||
name="One-time Cost",
|
||||
cost_type="one_time",
|
||||
amount_eur=Decimal('100.00'),
|
||||
starting_from_date=date(2025, 1, 1)
|
||||
)
|
||||
]
|
||||
|
||||
for item in items:
|
||||
cost_manager.create_cost_item(item)
|
||||
|
||||
# Calculate costs for January 2025
|
||||
period_start = date(2025, 1, 1)
|
||||
period_end = date(2025, 1, 31)
|
||||
calculations = cost_manager.calculate_period_costs(period_start, period_end)
|
||||
|
||||
assert calculations['total_monthly'] == 25.00
|
||||
assert calculations['total_one_time'] == 100.00
|
||||
assert calculations['total_period'] == 125.00
|
||||
assert calculations['active_cost_items'] == 3
|
||||
|
||||
# Check category breakdown
|
||||
assert 'Test Category' in calculations['category_breakdown']
|
||||
assert 'Other Category' in calculations['category_breakdown']
|
||||
assert calculations['category_breakdown']['Test Category']['monthly'] == 25.00
|
||||
assert calculations['category_breakdown']['Other Category']['one_time'] == 100.00
|
||||
|
||||
def test_category_management(self, cost_manager):
|
||||
"""Test category creation and management."""
|
||||
# Create category with unique name
|
||||
category_id = cost_manager.create_category("Custom Infrastructure", "Custom server costs")
|
||||
assert category_id is not None
|
||||
|
||||
# Retrieve category
|
||||
category = cost_manager.get_category(category_id)
|
||||
assert category['name'] == "Custom Infrastructure"
|
||||
assert category['description'] == "Custom server costs"
|
||||
|
||||
# Test duplicate category
|
||||
with pytest.raises(ValueError, match="already exists"):
|
||||
cost_manager.create_category("Custom Infrastructure")
|
||||
|
||||
# List categories
|
||||
categories = cost_manager.list_categories()
|
||||
category_names = [cat['name'] for cat in categories]
|
||||
assert "Custom Infrastructure" in category_names
|
||||
# Should also include default categories from schema initialization
|
||||
assert len(categories) >= 9 # 8 default + 1 created
|
||||
|
||||
# Get category by name
|
||||
found_category = cost_manager.get_category_by_name("Custom Infrastructure")
|
||||
assert found_category['id'] == category_id
|
||||
|
||||
def test_cost_item_with_category_validation(self, cost_manager):
|
||||
"""Test cost item creation with category validation."""
|
||||
# Try to create cost item with non-existent category
|
||||
with pytest.raises(ValueError, match="does not exist"):
|
||||
cost_item = CostItem(
|
||||
category_id=99999,
|
||||
name="Test Item",
|
||||
cost_type="monthly",
|
||||
amount_eur=Decimal('10.00'),
|
||||
starting_from_date=date(2025, 1, 1)
|
||||
)
|
||||
cost_manager.create_cost_item(cost_item)
|
||||
|
||||
def test_precision_handling(self, cost_manager, sample_category_id):
|
||||
"""Test decimal precision in cost calculations."""
|
||||
# Create cost item with precise decimal
|
||||
cost_item = CostItem(
|
||||
category_id=sample_category_id,
|
||||
name="Precise Cost",
|
||||
cost_type="monthly",
|
||||
amount_eur=Decimal('10.99'),
|
||||
starting_from_date=date(2025, 1, 1)
|
||||
)
|
||||
cost_item_id = cost_manager.create_cost_item(cost_item)
|
||||
|
||||
# Verify precision is maintained
|
||||
retrieved = cost_manager.get_cost_item(cost_item_id)
|
||||
assert float(retrieved['amount_eur']) == 10.99
|
||||
|
||||
# Test in period calculations
|
||||
calculations = cost_manager.calculate_period_costs(date(2025, 1, 1), date(2025, 1, 31))
|
||||
assert calculations['total_monthly'] == 10.99
|
||||
|
||||
def test_empty_database_operations(self, cost_manager):
|
||||
"""Test operations on empty database."""
|
||||
# List items in empty database
|
||||
items = cost_manager.list_cost_items()
|
||||
assert len(items) == 0
|
||||
|
||||
# Get non-existent item
|
||||
item = cost_manager.get_cost_item(99999)
|
||||
assert item is None
|
||||
|
||||
# Calculate costs for empty period
|
||||
calculations = cost_manager.calculate_period_costs(date(2025, 1, 1), date(2025, 1, 31))
|
||||
assert calculations['total_monthly'] == 0.00
|
||||
assert calculations['total_one_time'] == 0.00
|
||||
assert calculations['active_cost_items'] == 0
|
||||
357
markitect/finance/tests/test_cost_report_generator.py
Normal file
357
markitect/finance/tests/test_cost_report_generator.py
Normal file
@@ -0,0 +1,357 @@
|
||||
"""
|
||||
Tests for MarkiTect cost report template generator.
|
||||
|
||||
This module tests the complete cost report generation functionality including:
|
||||
- Report generation in different formats (summary, detailed, audit)
|
||||
- Markdown output with frontmatter and contentmatter
|
||||
- CLI integration and command functionality
|
||||
- Template structure validation
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import tempfile
|
||||
import os
|
||||
import json
|
||||
from datetime import date, datetime
|
||||
from decimal import Decimal
|
||||
|
||||
from markitect.finance.cost_manager import CostItemManager, CostItem
|
||||
from markitect.finance.report_generator import CostReportGenerator, ReportConfig
|
||||
from markitect.finance.models import FinanceModels
|
||||
|
||||
|
||||
class TestCostReportGenerator:
|
||||
"""Test suite for cost report generation system."""
|
||||
|
||||
@pytest.fixture
|
||||
def temp_db(self):
|
||||
"""Create temporary database for testing."""
|
||||
fd, path = tempfile.mkstemp(suffix='.db')
|
||||
os.close(fd)
|
||||
yield path
|
||||
os.unlink(path)
|
||||
|
||||
@pytest.fixture
|
||||
def setup_test_data(self, temp_db):
|
||||
"""Setup test database with sample cost data."""
|
||||
finance_models = FinanceModels(temp_db)
|
||||
finance_models.initialize_finance_schema()
|
||||
|
||||
cost_manager = CostItemManager(temp_db)
|
||||
|
||||
# Get categories
|
||||
infra_cat = cost_manager.get_category_by_name('Infrastructure')
|
||||
software_cat = cost_manager.get_category_by_name('Software')
|
||||
|
||||
# Create sample cost items
|
||||
cost_items = [
|
||||
CostItem(
|
||||
category_id=infra_cat['id'],
|
||||
name='Hosteurope Server',
|
||||
description='Monthly server hosting',
|
||||
cost_type='monthly',
|
||||
amount_eur=Decimal('10.00'),
|
||||
starting_from_date=date(2025, 1, 1)
|
||||
),
|
||||
CostItem(
|
||||
category_id=software_cat['id'],
|
||||
name='Bubble.io Plan',
|
||||
description='No-code platform subscription',
|
||||
cost_type='monthly',
|
||||
amount_eur=Decimal('32.00'),
|
||||
starting_from_date=date(2025, 1, 1)
|
||||
),
|
||||
CostItem(
|
||||
category_id=infra_cat['id'],
|
||||
name='SSL Certificate',
|
||||
description='Annual SSL certificate',
|
||||
cost_type='one_time',
|
||||
amount_eur=Decimal('45.00'),
|
||||
starting_from_date=date(2025, 1, 15)
|
||||
)
|
||||
]
|
||||
|
||||
for item in cost_items:
|
||||
cost_manager.create_cost_item(item)
|
||||
|
||||
return temp_db
|
||||
|
||||
@pytest.fixture
|
||||
def report_generator(self, setup_test_data):
|
||||
"""Create report generator with test data."""
|
||||
return CostReportGenerator(setup_test_data)
|
||||
|
||||
def test_report_config_creation(self):
|
||||
"""Test ReportConfig dataclass creation."""
|
||||
config = ReportConfig(
|
||||
format="summary",
|
||||
period_start=date(2025, 1, 1),
|
||||
period_end=date(2025, 1, 31),
|
||||
currency="EUR"
|
||||
)
|
||||
|
||||
assert config.format == "summary"
|
||||
assert config.period_start == date(2025, 1, 1)
|
||||
assert config.period_end == date(2025, 1, 31)
|
||||
assert config.currency == "EUR"
|
||||
assert config.include_inactive is False
|
||||
assert config.output_path is None
|
||||
|
||||
def test_generate_summary_report(self, report_generator):
|
||||
"""Test generation of summary cost report."""
|
||||
config = ReportConfig(
|
||||
format="summary",
|
||||
period_start=date(2025, 1, 1),
|
||||
period_end=date(2025, 1, 31)
|
||||
)
|
||||
|
||||
report = report_generator.generate_report(config)
|
||||
|
||||
# Check that it's valid markdown with frontmatter
|
||||
assert report.startswith("---")
|
||||
assert "Cost Summary Report - January 2025" in report
|
||||
assert "total_costs: 87.0" in report
|
||||
assert "report_type: \"cost_summary\"" in report
|
||||
|
||||
# Check contentmatter is present
|
||||
assert "contentmatter:" in report
|
||||
assert "cost_data" in report
|
||||
|
||||
# Verify total costs are correct (10 + 32 + 45 = 87)
|
||||
assert "€87.00" in report
|
||||
|
||||
def test_generate_detailed_report(self, report_generator):
|
||||
"""Test generation of detailed cost report."""
|
||||
config = ReportConfig(
|
||||
format="detailed",
|
||||
period_start=date(2025, 1, 1),
|
||||
period_end=date(2025, 1, 31)
|
||||
)
|
||||
|
||||
report = report_generator.generate_report(config)
|
||||
|
||||
# Check report structure
|
||||
assert "Detailed Cost Report - January 2025" in report
|
||||
assert "Executive Summary" in report
|
||||
assert "report_type: \"cost_detailed\"" in report
|
||||
|
||||
# Check category sections are present
|
||||
assert "Infrastructure" in report
|
||||
assert "Software" in report
|
||||
|
||||
# Check individual items are listed
|
||||
assert "Hosteurope Server" in report
|
||||
assert "Bubble.io Plan" in report
|
||||
assert "SSL Certificate" in report
|
||||
|
||||
# Check table format
|
||||
assert "| Name | Type | Amount | Status | Start Date |" in report
|
||||
|
||||
def test_generate_audit_report(self, report_generator):
|
||||
"""Test generation of audit trail report."""
|
||||
config = ReportConfig(
|
||||
format="audit",
|
||||
period_start=date(2025, 1, 1),
|
||||
period_end=date(2025, 1, 31)
|
||||
)
|
||||
|
||||
report = report_generator.generate_report(config)
|
||||
|
||||
# Check report structure
|
||||
assert "Cost Audit Report - January 2025" in report
|
||||
assert "Audit Summary" in report
|
||||
assert "report_type: \"cost_audit\"" in report
|
||||
assert "audit_trail: True" in report
|
||||
|
||||
# Check audit sections
|
||||
assert "Cost Verification" in report
|
||||
assert "Active Cost Items" in report
|
||||
assert "Transaction History" in report
|
||||
assert "Audit Trail" in report
|
||||
|
||||
# Check contentmatter includes audit data
|
||||
assert "audit_data" in report
|
||||
|
||||
def test_generate_period_report_convenience_method(self, report_generator):
|
||||
"""Test convenience method for generating monthly reports."""
|
||||
report = report_generator.generate_period_report(2025, 1, "summary")
|
||||
|
||||
assert "Cost Summary Report - January 2025" in report
|
||||
assert "2025-01-01" in report
|
||||
assert "2025-01-31" in report
|
||||
|
||||
def test_invalid_report_format_raises_error(self, report_generator):
|
||||
"""Test that invalid report format raises ValueError."""
|
||||
config = ReportConfig(
|
||||
format="invalid",
|
||||
period_start=date(2025, 1, 1),
|
||||
period_end=date(2025, 1, 31)
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Unknown report format"):
|
||||
report_generator.generate_report(config)
|
||||
|
||||
def test_frontmatter_structure(self, report_generator):
|
||||
"""Test frontmatter structure in generated reports."""
|
||||
config = ReportConfig(
|
||||
format="summary",
|
||||
period_start=date(2025, 1, 1),
|
||||
period_end=date(2025, 1, 31)
|
||||
)
|
||||
|
||||
report = report_generator.generate_report(config)
|
||||
|
||||
# Extract frontmatter (between first two ---)
|
||||
lines = report.split('\n')
|
||||
frontmatter_lines = []
|
||||
in_frontmatter = False
|
||||
|
||||
for line in lines:
|
||||
if line.strip() == "---":
|
||||
if not in_frontmatter:
|
||||
in_frontmatter = True
|
||||
continue
|
||||
else:
|
||||
break
|
||||
if in_frontmatter:
|
||||
frontmatter_lines.append(line)
|
||||
|
||||
frontmatter_text = '\n'.join(frontmatter_lines)
|
||||
|
||||
# Check required frontmatter fields
|
||||
assert 'report_type:' in frontmatter_text
|
||||
assert 'period_start:' in frontmatter_text
|
||||
assert 'period_end:' in frontmatter_text
|
||||
assert 'total_costs:' in frontmatter_text
|
||||
assert 'currency:' in frontmatter_text
|
||||
assert 'generated_at:' in frontmatter_text
|
||||
|
||||
def test_contentmatter_structure(self, report_generator):
|
||||
"""Test contentmatter structure in generated reports."""
|
||||
config = ReportConfig(
|
||||
format="summary",
|
||||
period_start=date(2025, 1, 1),
|
||||
period_end=date(2025, 1, 31)
|
||||
)
|
||||
|
||||
report = report_generator.generate_report(config)
|
||||
|
||||
# Extract contentmatter (JSON in HTML comment)
|
||||
assert "<!--" in report
|
||||
assert "contentmatter:" in report
|
||||
assert "-->" in report
|
||||
|
||||
# Find and extract JSON
|
||||
start = report.find("contentmatter:\n") + len("contentmatter:\n")
|
||||
end = report.find("\n-->")
|
||||
json_text = report[start:end].strip()
|
||||
|
||||
# Parse JSON to verify structure
|
||||
contentmatter = json.loads(json_text)
|
||||
|
||||
assert "cost_data" in contentmatter
|
||||
assert "total_monthly" in contentmatter["cost_data"]
|
||||
assert "total_one_time" in contentmatter["cost_data"]
|
||||
assert "categories" in contentmatter["cost_data"]
|
||||
assert "active_items" in contentmatter["cost_data"]
|
||||
|
||||
# Verify totals
|
||||
assert contentmatter["cost_data"]["total_monthly"] == 42.0
|
||||
assert contentmatter["cost_data"]["total_one_time"] == 45.0
|
||||
|
||||
def test_save_report_to_file(self, report_generator, temp_db):
|
||||
"""Test saving report to file."""
|
||||
config = ReportConfig(
|
||||
format="summary",
|
||||
period_start=date(2025, 1, 1),
|
||||
period_end=date(2025, 1, 31)
|
||||
)
|
||||
|
||||
report = report_generator.generate_report(config)
|
||||
|
||||
# Save to temporary file
|
||||
with tempfile.NamedTemporaryFile(mode='w', delete=False, suffix='.md') as f:
|
||||
output_path = f.name
|
||||
|
||||
try:
|
||||
report_generator.save_report(report, output_path)
|
||||
|
||||
# Verify file was created and contains expected content
|
||||
with open(output_path, 'r', encoding='utf-8') as f:
|
||||
saved_content = f.read()
|
||||
|
||||
assert saved_content == report
|
||||
assert "Cost Summary Report" in saved_content
|
||||
|
||||
finally:
|
||||
os.unlink(output_path)
|
||||
|
||||
def test_empty_database_report(self, temp_db):
|
||||
"""Test report generation with empty database."""
|
||||
# Initialize empty database
|
||||
finance_models = FinanceModels(temp_db)
|
||||
finance_models.initialize_finance_schema()
|
||||
|
||||
report_generator = CostReportGenerator(temp_db)
|
||||
config = ReportConfig(
|
||||
format="summary",
|
||||
period_start=date(2025, 1, 1),
|
||||
period_end=date(2025, 1, 31)
|
||||
)
|
||||
|
||||
report = report_generator.generate_report(config)
|
||||
|
||||
# Should still generate valid report with zero costs
|
||||
assert "total_costs: 0.0" in report
|
||||
assert "€0.00" in report
|
||||
|
||||
def test_different_currency(self, report_generator):
|
||||
"""Test report generation with different currency."""
|
||||
config = ReportConfig(
|
||||
format="summary",
|
||||
period_start=date(2025, 1, 1),
|
||||
period_end=date(2025, 1, 31),
|
||||
currency="USD"
|
||||
)
|
||||
|
||||
report = report_generator.generate_report(config)
|
||||
|
||||
assert 'currency: "USD"' in report
|
||||
# Note: amounts are still in EUR from database, currency is just metadata
|
||||
|
||||
def test_report_with_inactive_items(self, setup_test_data):
|
||||
"""Test report behavior with inactive cost items."""
|
||||
cost_manager = CostItemManager(setup_test_data)
|
||||
|
||||
# Deactivate one item
|
||||
items = cost_manager.list_cost_items()
|
||||
if items:
|
||||
cost_manager.deactivate_cost_item(items[0]['id'], date(2025, 1, 15))
|
||||
|
||||
report_generator = CostReportGenerator(setup_test_data)
|
||||
config = ReportConfig(
|
||||
format="detailed",
|
||||
period_start=date(2025, 1, 1),
|
||||
period_end=date(2025, 1, 31),
|
||||
include_inactive=False
|
||||
)
|
||||
|
||||
report = report_generator.generate_report(config)
|
||||
|
||||
# Should still generate valid report, potentially with fewer active items
|
||||
assert "Detailed Cost Report" in report
|
||||
assert "contentmatter:" in report
|
||||
|
||||
def test_cross_month_period(self, report_generator):
|
||||
"""Test report generation across multiple months."""
|
||||
config = ReportConfig(
|
||||
format="summary",
|
||||
period_start=date(2025, 1, 15),
|
||||
period_end=date(2025, 2, 15)
|
||||
)
|
||||
|
||||
report = report_generator.generate_report(config)
|
||||
|
||||
assert "2025-01-15" in report
|
||||
assert "2025-02-15" in report
|
||||
# Should include items active during this period
|
||||
430
markitect/finance/tests/test_finance_models.py
Normal file
430
markitect/finance/tests/test_finance_models.py
Normal file
@@ -0,0 +1,430 @@
|
||||
"""
|
||||
Tests for MarkiTect finance models and database schema.
|
||||
|
||||
This module tests the complete finance schema including:
|
||||
- Database table creation and relationships
|
||||
- Data integrity constraints
|
||||
- Index performance
|
||||
- Schema validation
|
||||
- Migration functionality
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import tempfile
|
||||
import os
|
||||
from datetime import date, datetime
|
||||
from decimal import Decimal
|
||||
|
||||
from markitect.finance.models import FinanceModels
|
||||
|
||||
|
||||
class TestFinanceModels:
|
||||
"""Test suite for finance database models."""
|
||||
|
||||
@pytest.fixture
|
||||
def temp_db(self):
|
||||
"""Create temporary database for testing."""
|
||||
fd, path = tempfile.mkstemp(suffix='.db')
|
||||
os.close(fd)
|
||||
yield path
|
||||
os.unlink(path)
|
||||
|
||||
@pytest.fixture
|
||||
def finance_models(self, temp_db):
|
||||
"""Create FinanceModels instance with temporary database."""
|
||||
return FinanceModels(temp_db)
|
||||
|
||||
def test_initialize_finance_schema(self, finance_models):
|
||||
"""Test complete finance schema initialization."""
|
||||
# Initialize schema
|
||||
finance_models.initialize_finance_schema()
|
||||
|
||||
# Validate schema was created
|
||||
assert finance_models.validate_schema()
|
||||
|
||||
# Check all required tables exist
|
||||
schema_info = finance_models.get_schema_info()
|
||||
expected_tables = [
|
||||
'cost_categories',
|
||||
'cost_items',
|
||||
'cost_periods',
|
||||
'cost_transactions',
|
||||
'issue_cost_allocations',
|
||||
'issue_activity_log'
|
||||
]
|
||||
|
||||
for table in expected_tables:
|
||||
assert table in schema_info['tables']
|
||||
assert len(schema_info['tables'][table]['columns']) > 0
|
||||
|
||||
def test_cost_categories_table(self, finance_models):
|
||||
"""Test cost categories table structure and data."""
|
||||
finance_models.initialize_finance_schema()
|
||||
|
||||
conn = finance_models.get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Test default categories were inserted
|
||||
cursor.execute('SELECT COUNT(*) FROM cost_categories')
|
||||
count = cursor.fetchone()[0]
|
||||
assert count >= 8 # At least 8 default categories
|
||||
|
||||
# Test unique constraint
|
||||
with pytest.raises(Exception): # Should violate unique constraint
|
||||
cursor.execute('''
|
||||
INSERT INTO cost_categories (name, description)
|
||||
VALUES ('Infrastructure', 'Duplicate category')
|
||||
''')
|
||||
|
||||
conn.close()
|
||||
|
||||
def test_cost_items_table(self, finance_models):
|
||||
"""Test cost items table constraints and relationships."""
|
||||
finance_models.initialize_finance_schema()
|
||||
|
||||
conn = finance_models.get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Insert test category
|
||||
cursor.execute('''
|
||||
INSERT INTO cost_categories (name, description)
|
||||
VALUES ('Test Category', 'For testing')
|
||||
''')
|
||||
category_id = cursor.lastrowid
|
||||
|
||||
# Test valid cost item insertion
|
||||
cursor.execute('''
|
||||
INSERT INTO cost_items
|
||||
(category_id, name, cost_type, amount_eur, starting_from_date)
|
||||
VALUES (?, 'Test Server', 'monthly', 10.50, '2025-01-01')
|
||||
''', (category_id,))
|
||||
|
||||
# Test cost_type constraint
|
||||
with pytest.raises(Exception):
|
||||
cursor.execute('''
|
||||
INSERT INTO cost_items
|
||||
(category_id, name, cost_type, amount_eur, starting_from_date)
|
||||
VALUES (?, 'Invalid Type', 'invalid', 10.00, '2025-01-01')
|
||||
''', (category_id,))
|
||||
|
||||
# Test negative amount constraint
|
||||
with pytest.raises(Exception):
|
||||
cursor.execute('''
|
||||
INSERT INTO cost_items
|
||||
(category_id, name, cost_type, amount_eur, starting_from_date)
|
||||
VALUES (?, 'Negative Cost', 'monthly', -10.00, '2025-01-01')
|
||||
''', (category_id,))
|
||||
|
||||
# Test date range constraint
|
||||
with pytest.raises(Exception):
|
||||
cursor.execute('''
|
||||
INSERT INTO cost_items
|
||||
(category_id, name, cost_type, amount_eur, starting_from_date, ending_date)
|
||||
VALUES (?, 'Invalid Dates', 'monthly', 10.00, '2025-01-01', '2024-12-31')
|
||||
''', (category_id,))
|
||||
|
||||
conn.close()
|
||||
|
||||
def test_cost_periods_table(self, finance_models):
|
||||
"""Test cost periods table constraints."""
|
||||
finance_models.initialize_finance_schema()
|
||||
|
||||
conn = finance_models.get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Test valid period insertion
|
||||
cursor.execute('''
|
||||
INSERT INTO cost_periods (period_start, period_end)
|
||||
VALUES ('2025-01-01', '2025-01-31')
|
||||
''')
|
||||
|
||||
# Test period date constraint
|
||||
with pytest.raises(Exception):
|
||||
cursor.execute('''
|
||||
INSERT INTO cost_periods (period_start, period_end)
|
||||
VALUES ('2025-01-31', '2025-01-01')
|
||||
''')
|
||||
|
||||
# Test status constraint
|
||||
with pytest.raises(Exception):
|
||||
cursor.execute('''
|
||||
INSERT INTO cost_periods (period_start, period_end, status)
|
||||
VALUES ('2025-02-01', '2025-02-28', 'invalid_status')
|
||||
''')
|
||||
|
||||
# Test unique period constraint
|
||||
with pytest.raises(Exception):
|
||||
cursor.execute('''
|
||||
INSERT INTO cost_periods (period_start, period_end)
|
||||
VALUES ('2025-01-01', '2025-01-31')
|
||||
''')
|
||||
|
||||
conn.close()
|
||||
|
||||
def test_cost_transactions_table(self, finance_models):
|
||||
"""Test cost transactions table and audit trail."""
|
||||
finance_models.initialize_finance_schema()
|
||||
|
||||
conn = finance_models.get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Create test data
|
||||
cursor.execute('''
|
||||
INSERT INTO cost_categories (name) VALUES ('Test Category')
|
||||
''')
|
||||
category_id = cursor.lastrowid
|
||||
|
||||
cursor.execute('''
|
||||
INSERT INTO cost_items
|
||||
(category_id, name, cost_type, amount_eur, starting_from_date)
|
||||
VALUES (?, 'Test Item', 'monthly', 10.00, '2025-01-01')
|
||||
''', (category_id,))
|
||||
cost_item_id = cursor.lastrowid
|
||||
|
||||
cursor.execute('''
|
||||
INSERT INTO cost_periods (period_start, period_end)
|
||||
VALUES ('2025-01-01', '2025-01-31')
|
||||
''')
|
||||
period_id = cursor.lastrowid
|
||||
|
||||
# Test valid transaction
|
||||
cursor.execute('''
|
||||
INSERT INTO cost_transactions
|
||||
(period_id, cost_item_id, transaction_type, amount_eur, transaction_date)
|
||||
VALUES (?, ?, 'cost_incurred', 10.00, '2025-01-15')
|
||||
''', (period_id, cost_item_id))
|
||||
|
||||
# Test transaction type constraint
|
||||
with pytest.raises(Exception):
|
||||
cursor.execute('''
|
||||
INSERT INTO cost_transactions
|
||||
(period_id, cost_item_id, transaction_type, amount_eur, transaction_date)
|
||||
VALUES (?, ?, 'invalid_type', 10.00, '2025-01-15')
|
||||
''', (period_id, cost_item_id))
|
||||
|
||||
conn.close()
|
||||
|
||||
def test_issue_cost_allocations_table(self, finance_models):
|
||||
"""Test issue cost allocations table."""
|
||||
finance_models.initialize_finance_schema()
|
||||
|
||||
conn = finance_models.get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Create test period
|
||||
cursor.execute('''
|
||||
INSERT INTO cost_periods (period_start, period_end)
|
||||
VALUES ('2025-01-01', '2025-01-31')
|
||||
''')
|
||||
period_id = cursor.lastrowid
|
||||
|
||||
# Test valid allocation
|
||||
cursor.execute('''
|
||||
INSERT INTO issue_cost_allocations
|
||||
(issue_id, period_id, allocated_amount, allocation_date)
|
||||
VALUES (123, ?, 5.50, '2025-01-31')
|
||||
''', (period_id,))
|
||||
|
||||
# Test positive amount constraint
|
||||
with pytest.raises(Exception):
|
||||
cursor.execute('''
|
||||
INSERT INTO issue_cost_allocations
|
||||
(issue_id, period_id, allocated_amount, allocation_date)
|
||||
VALUES (124, ?, -1.00, '2025-01-31')
|
||||
''', (period_id,))
|
||||
|
||||
# Test unique issue-period constraint
|
||||
with pytest.raises(Exception):
|
||||
cursor.execute('''
|
||||
INSERT INTO issue_cost_allocations
|
||||
(issue_id, period_id, allocated_amount, allocation_date)
|
||||
VALUES (123, ?, 3.00, '2025-01-31')
|
||||
''', (period_id,))
|
||||
|
||||
conn.close()
|
||||
|
||||
def test_issue_activity_log_table(self, finance_models):
|
||||
"""Test issue activity log table."""
|
||||
finance_models.initialize_finance_schema()
|
||||
|
||||
conn = finance_models.get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Test valid activity log entry
|
||||
cursor.execute('''
|
||||
INSERT INTO issue_activity_log
|
||||
(issue_id, activity_type, activity_date)
|
||||
VALUES (123, 'created', '2025-01-15')
|
||||
''')
|
||||
|
||||
# Test activity type constraint
|
||||
with pytest.raises(Exception):
|
||||
cursor.execute('''
|
||||
INSERT INTO issue_activity_log
|
||||
(issue_id, activity_type, activity_date)
|
||||
VALUES (124, 'invalid_activity', '2025-01-15')
|
||||
''')
|
||||
|
||||
conn.close()
|
||||
|
||||
def test_foreign_key_constraints(self, finance_models):
|
||||
"""Test foreign key relationships are enforced."""
|
||||
finance_models.initialize_finance_schema()
|
||||
|
||||
conn = finance_models.get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Test cost_items references cost_categories
|
||||
with pytest.raises(Exception):
|
||||
cursor.execute('''
|
||||
INSERT INTO cost_items
|
||||
(category_id, name, cost_type, amount_eur, starting_from_date)
|
||||
VALUES (999, 'Invalid Category', 'monthly', 10.00, '2025-01-01')
|
||||
''')
|
||||
|
||||
# Test cost_transactions references cost_periods
|
||||
with pytest.raises(Exception):
|
||||
cursor.execute('''
|
||||
INSERT INTO cost_transactions
|
||||
(period_id, transaction_type, amount_eur, transaction_date)
|
||||
VALUES (999, 'cost_incurred', 10.00, '2025-01-15')
|
||||
''')
|
||||
|
||||
conn.close()
|
||||
|
||||
def test_indexes_created(self, finance_models):
|
||||
"""Test that performance indexes are created."""
|
||||
finance_models.initialize_finance_schema()
|
||||
|
||||
schema_info = finance_models.get_schema_info()
|
||||
index_names = [idx['name'] for idx in schema_info['indexes']]
|
||||
|
||||
# Check critical indexes exist
|
||||
expected_indexes = [
|
||||
'idx_cost_items_active',
|
||||
'idx_cost_items_type',
|
||||
'idx_cost_periods_status',
|
||||
'idx_cost_transactions_period',
|
||||
'idx_issue_allocations_issue'
|
||||
]
|
||||
|
||||
for index in expected_indexes:
|
||||
assert index in index_names
|
||||
|
||||
def test_schema_validation(self, finance_models):
|
||||
"""Test schema validation functionality."""
|
||||
# Before initialization
|
||||
assert not finance_models.validate_schema()
|
||||
|
||||
# After initialization
|
||||
finance_models.initialize_finance_schema()
|
||||
assert finance_models.validate_schema()
|
||||
|
||||
def test_drop_finance_schema(self, finance_models):
|
||||
"""Test schema cleanup functionality."""
|
||||
# Initialize schema
|
||||
finance_models.initialize_finance_schema()
|
||||
assert finance_models.validate_schema()
|
||||
|
||||
# Drop schema
|
||||
finance_models.drop_finance_schema()
|
||||
assert not finance_models.validate_schema()
|
||||
|
||||
def test_database_integration(self, temp_db):
|
||||
"""Test integration with existing DatabaseManager."""
|
||||
from markitect.database import DatabaseManager
|
||||
|
||||
# Initialize standard database
|
||||
db_manager = DatabaseManager(temp_db)
|
||||
db_manager.initialize_database()
|
||||
|
||||
# Verify finance tables were also created
|
||||
finance_models = FinanceModels(temp_db)
|
||||
assert finance_models.validate_schema()
|
||||
|
||||
# Verify existing tables still exist
|
||||
conn = finance_models.get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute('''
|
||||
SELECT name FROM sqlite_master
|
||||
WHERE type='table' AND name IN ('markdown_files', 'schemas')
|
||||
''')
|
||||
existing_tables = [row[0] for row in cursor.fetchall()]
|
||||
|
||||
assert 'markdown_files' in existing_tables
|
||||
assert 'schemas' in existing_tables
|
||||
|
||||
conn.close()
|
||||
|
||||
def test_decimal_precision(self, finance_models):
|
||||
"""Test decimal precision for financial calculations."""
|
||||
finance_models.initialize_finance_schema()
|
||||
|
||||
conn = finance_models.get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Insert test category
|
||||
cursor.execute('''
|
||||
INSERT INTO cost_categories (name) VALUES ('Test Category')
|
||||
''')
|
||||
category_id = cursor.lastrowid
|
||||
|
||||
# Test precise decimal amounts
|
||||
test_amounts = [10.50, 99.99, 0.01, 1234.56]
|
||||
|
||||
for amount in test_amounts:
|
||||
cursor.execute('''
|
||||
INSERT INTO cost_items
|
||||
(category_id, name, cost_type, amount_eur, starting_from_date)
|
||||
VALUES (?, ?, 'monthly', ?, '2025-01-01')
|
||||
''', (category_id, f'Test Item {amount}', amount))
|
||||
|
||||
# Verify precision is maintained
|
||||
cursor.execute('SELECT amount_eur FROM cost_items ORDER BY id')
|
||||
stored_amounts = [float(row[0]) for row in cursor.fetchall()]
|
||||
|
||||
assert stored_amounts == test_amounts
|
||||
|
||||
conn.close()
|
||||
|
||||
def test_example_cost_data(self, finance_models):
|
||||
"""Test insertion of example cost data from issue description."""
|
||||
finance_models.initialize_finance_schema()
|
||||
|
||||
conn = finance_models.get_connection()
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Get category IDs
|
||||
cursor.execute('SELECT id, name FROM cost_categories')
|
||||
categories = {name: id for id, name in cursor.fetchall()}
|
||||
|
||||
# Insert example costs from issue #88
|
||||
example_costs = [
|
||||
('Infrastructure', 'Hosteurope Server', 'Monthly server hosting', 10.00),
|
||||
('Software', 'Bubble.io Plan', 'No-code platform subscription', 32.00),
|
||||
('Domain & DNS', 'Coulomb.social Domain', 'Domain registration', 5.00),
|
||||
('Development Tools', 'Claude Code Plan', 'AI coding assistant', 20.00),
|
||||
('AI & ML Services', 'Gemini Plan', 'LLM API for specifications', 20.00)
|
||||
]
|
||||
|
||||
for category_name, name, description, amount in example_costs:
|
||||
category_id = categories.get(category_name)
|
||||
assert category_id is not None
|
||||
|
||||
cursor.execute('''
|
||||
INSERT INTO cost_items
|
||||
(category_id, name, description, cost_type, amount_eur, starting_from_date)
|
||||
VALUES (?, ?, ?, 'monthly', ?, '2025-01-01')
|
||||
''', (category_id, name, description, amount))
|
||||
|
||||
# Verify total monthly costs
|
||||
cursor.execute('''
|
||||
SELECT SUM(amount_eur) FROM cost_items
|
||||
WHERE cost_type = 'monthly' AND is_active = TRUE
|
||||
''')
|
||||
total_monthly = float(cursor.fetchone()[0])
|
||||
|
||||
assert total_monthly == 87.00 # €87/month as described in issue
|
||||
|
||||
conn.close()
|
||||
794
markitect/finance/tests/test_issue_122_worktime_tracking.py
Normal file
794
markitect/finance/tests/test_issue_122_worktime_tracking.py
Normal file
@@ -0,0 +1,794 @@
|
||||
"""
|
||||
Tests for Issue #122 - Daily worktime estimation and distribution of associated cost
|
||||
|
||||
This module contains comprehensive tests for the worktime tracking system
|
||||
that estimates daily work time and distributes costs proportionally based
|
||||
on time allocation across issues.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import sqlite3
|
||||
from datetime import datetime, date, timedelta
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
from pathlib import Path
|
||||
import tempfile
|
||||
import json
|
||||
from decimal import Decimal
|
||||
|
||||
from markitect.finance.worktime_tracker import WorktimeTracker, WorktimeEntry, DailySummary
|
||||
from markitect.finance.worktime_commands import worktime, _parse_duration, _format_duration
|
||||
|
||||
|
||||
class TestWorktimeEntry:
|
||||
"""Test suite for WorktimeEntry dataclass."""
|
||||
|
||||
def test_worktime_entry_creation(self):
|
||||
"""Test that WorktimeEntry objects can be created properly."""
|
||||
entry = WorktimeEntry(
|
||||
id=1,
|
||||
issue_id=122,
|
||||
work_date=date.today(),
|
||||
duration_minutes=90,
|
||||
description="Working on worktime tracking"
|
||||
)
|
||||
|
||||
assert entry.id == 1
|
||||
assert entry.issue_id == 122
|
||||
assert entry.work_date == date.today()
|
||||
assert entry.duration_minutes == 90
|
||||
assert entry.description == "Working on worktime tracking"
|
||||
|
||||
def test_worktime_entry_defaults(self):
|
||||
"""Test that WorktimeEntry has proper default values."""
|
||||
entry = WorktimeEntry()
|
||||
|
||||
assert entry.id is None
|
||||
assert entry.issue_id is None
|
||||
assert entry.work_date is None
|
||||
assert entry.start_time is None
|
||||
assert entry.end_time is None
|
||||
assert entry.duration_minutes is None
|
||||
assert entry.description is None
|
||||
assert entry.entry_type == "manual"
|
||||
assert entry.created_at is None
|
||||
assert entry.updated_at is None
|
||||
|
||||
|
||||
class TestDailySummary:
|
||||
"""Test suite for DailySummary dataclass."""
|
||||
|
||||
def test_daily_summary_creation(self):
|
||||
"""Test that DailySummary objects can be created properly."""
|
||||
entries = [
|
||||
WorktimeEntry(id=1, issue_id=122, duration_minutes=90),
|
||||
WorktimeEntry(id=2, issue_id=123, duration_minutes=60)
|
||||
]
|
||||
|
||||
summary = DailySummary(
|
||||
work_date=date.today(),
|
||||
total_minutes=150,
|
||||
issue_count=2,
|
||||
entries=entries,
|
||||
cost_per_minute=Decimal('0.1'),
|
||||
total_cost_allocated=Decimal('15.0')
|
||||
)
|
||||
|
||||
assert summary.work_date == date.today()
|
||||
assert summary.total_minutes == 150
|
||||
assert summary.issue_count == 2
|
||||
assert len(summary.entries) == 2
|
||||
assert summary.cost_per_minute == Decimal('0.1')
|
||||
assert summary.total_cost_allocated == Decimal('15.0')
|
||||
|
||||
|
||||
class TestWorktimeTracker:
|
||||
"""Test suite for WorktimeTracker service."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test fixtures with temporary database."""
|
||||
self.temp_db = tempfile.NamedTemporaryFile(suffix='.db', delete=False)
|
||||
self.temp_db.close()
|
||||
self.db_path = self.temp_db.name
|
||||
self.tracker = WorktimeTracker(self.db_path)
|
||||
|
||||
def teardown_method(self):
|
||||
"""Clean up test fixtures."""
|
||||
Path(self.db_path).unlink(missing_ok=True)
|
||||
|
||||
def test_tracker_initialization(self):
|
||||
"""Test that tracker initializes properly with database."""
|
||||
assert self.tracker.db_path == self.db_path
|
||||
assert self.tracker.finance_models is not None
|
||||
|
||||
# Verify worktime tables were created
|
||||
with self.tracker.finance_models.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' ORDER BY name")
|
||||
tables = [row[0] for row in cursor.fetchall()]
|
||||
|
||||
expected_tables = ['worktime_entries', 'daily_worktime_summaries', 'worktime_cost_distributions']
|
||||
for table in expected_tables:
|
||||
assert table in tables
|
||||
|
||||
def test_log_worktime_basic(self):
|
||||
"""Test logging basic worktime entry."""
|
||||
entry_id = self.tracker.log_worktime(
|
||||
issue_id=122,
|
||||
duration_minutes=90,
|
||||
description="Implementing worktime tracking"
|
||||
)
|
||||
|
||||
assert entry_id is not None
|
||||
|
||||
# Verify entry was stored
|
||||
entries = self.tracker.get_worktime_entries(issue_id=122)
|
||||
assert len(entries) == 1
|
||||
assert entries[0].issue_id == 122
|
||||
assert entries[0].duration_minutes == 90
|
||||
assert entries[0].description == "Implementing worktime tracking"
|
||||
|
||||
def test_log_worktime_with_timestamps(self):
|
||||
"""Test logging worktime with start and end times."""
|
||||
now = datetime.now()
|
||||
start_time = now.replace(hour=9, minute=0, second=0, microsecond=0)
|
||||
end_time = now.replace(hour=10, minute=30, second=0, microsecond=0)
|
||||
|
||||
entry_id = self.tracker.log_worktime(
|
||||
issue_id=122,
|
||||
duration_minutes=90,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
description="Morning work session"
|
||||
)
|
||||
|
||||
entries = self.tracker.get_worktime_entries(issue_id=122)
|
||||
assert len(entries) == 1
|
||||
assert entries[0].start_time.hour == 9
|
||||
assert entries[0].end_time.hour == 10
|
||||
assert entries[0].end_time.minute == 30
|
||||
|
||||
def test_log_worktime_validation(self):
|
||||
"""Test worktime logging validation."""
|
||||
# Test negative duration
|
||||
with pytest.raises(ValueError, match="Duration must be positive"):
|
||||
self.tracker.log_worktime(issue_id=122, duration_minutes=-30)
|
||||
|
||||
# Test zero duration
|
||||
with pytest.raises(ValueError, match="Duration must be positive"):
|
||||
self.tracker.log_worktime(issue_id=122, duration_minutes=0)
|
||||
|
||||
def test_get_worktime_entries_filtering(self):
|
||||
"""Test worktime entry retrieval with various filters."""
|
||||
today = date.today()
|
||||
yesterday = today - timedelta(days=1)
|
||||
|
||||
# Create test entries
|
||||
self.tracker.log_worktime(122, 60, work_date=today, description="Today's work")
|
||||
self.tracker.log_worktime(123, 90, work_date=today, description="Today's other work")
|
||||
self.tracker.log_worktime(122, 45, work_date=yesterday, description="Yesterday's work")
|
||||
|
||||
# Test filtering by issue
|
||||
issue_122_entries = self.tracker.get_worktime_entries(issue_id=122)
|
||||
assert len(issue_122_entries) == 2
|
||||
assert all(e.issue_id == 122 for e in issue_122_entries)
|
||||
|
||||
# Test filtering by date
|
||||
today_entries = self.tracker.get_worktime_entries(work_date=today)
|
||||
assert len(today_entries) == 2
|
||||
assert all(e.work_date == today for e in today_entries)
|
||||
|
||||
# Test date range filtering
|
||||
range_entries = self.tracker.get_worktime_entries(start_date=yesterday, end_date=today)
|
||||
assert len(range_entries) == 3
|
||||
|
||||
def test_get_daily_summary(self):
|
||||
"""Test daily worktime summary generation."""
|
||||
today = date.today()
|
||||
|
||||
# Log multiple entries for today
|
||||
self.tracker.log_worktime(122, 90, work_date=today)
|
||||
self.tracker.log_worktime(123, 60, work_date=today)
|
||||
self.tracker.log_worktime(122, 30, work_date=today) # Second entry for same issue
|
||||
|
||||
summary = self.tracker.get_daily_summary(today)
|
||||
|
||||
assert summary is not None
|
||||
assert summary.work_date == today
|
||||
assert summary.total_minutes == 180 # 90 + 60 + 30
|
||||
assert summary.issue_count == 2 # Issues 122 and 123
|
||||
assert len(summary.entries) == 3
|
||||
|
||||
def test_estimate_daily_worktime_equal_distribution(self):
|
||||
"""Test daily worktime estimation with equal distribution."""
|
||||
today = date.today()
|
||||
issues = [122, 123, 124]
|
||||
|
||||
result = self.tracker.estimate_daily_worktime(
|
||||
work_date=today,
|
||||
total_hours=6.0,
|
||||
issues=issues,
|
||||
distribution_method="equal"
|
||||
)
|
||||
|
||||
assert result['work_date'] == today
|
||||
assert result['total_minutes'] == 360 # 6 hours
|
||||
assert result['distribution_method'] == "equal"
|
||||
assert result['issues_count'] == 3
|
||||
|
||||
# Each issue should get 120 minutes (2 hours)
|
||||
for issue_id in issues:
|
||||
assert result['issue_estimates'][issue_id] == 120
|
||||
|
||||
# Verify entries were created
|
||||
entries = self.tracker.get_worktime_entries(work_date=today)
|
||||
assert len(entries) == 3
|
||||
assert all(e.entry_type == "estimated" for e in entries)
|
||||
|
||||
def test_estimate_daily_worktime_activity_based(self):
|
||||
"""Test daily worktime estimation with activity-based distribution."""
|
||||
today = date.today()
|
||||
|
||||
# Mock activity data - issue 122 has more activities
|
||||
with patch.object(self.tracker, '_get_activity_weights_for_date') as mock_weights:
|
||||
mock_weights.return_value = {122: 5, 123: 2, 124: 1} # Different activity levels
|
||||
|
||||
result = self.tracker.estimate_daily_worktime(
|
||||
work_date=today,
|
||||
total_hours=8.0,
|
||||
issues=[122, 123, 124],
|
||||
distribution_method="activity_based"
|
||||
)
|
||||
|
||||
# Verify distribution is proportional to activities
|
||||
total_weight = 5 + 2 + 1 # 8
|
||||
expected_122 = int((5/8) * 480) # 300 minutes
|
||||
expected_123 = int((2/8) * 480) # 120 minutes
|
||||
expected_124 = int((1/8) * 480) # 60 minutes
|
||||
|
||||
assert result['issue_estimates'][122] == expected_122
|
||||
assert result['issue_estimates'][123] == expected_123
|
||||
assert result['issue_estimates'][124] == expected_124
|
||||
|
||||
def test_distribute_daily_costs(self):
|
||||
"""Test daily cost distribution based on time allocation."""
|
||||
today = date.today()
|
||||
|
||||
# Log different amounts of time for different issues
|
||||
self.tracker.log_worktime(122, 120, work_date=today) # 2 hours
|
||||
self.tracker.log_worktime(123, 60, work_date=today) # 1 hour
|
||||
self.tracker.log_worktime(124, 120, work_date=today) # 2 hours
|
||||
# Total: 5 hours (300 minutes)
|
||||
|
||||
total_cost = Decimal('150.00') # €150 for the day
|
||||
result = self.tracker.distribute_daily_costs(
|
||||
work_date=today,
|
||||
total_daily_cost=total_cost
|
||||
)
|
||||
|
||||
assert result['work_date'] == today
|
||||
assert result['total_cost'] == 150.0
|
||||
assert result['total_minutes'] == 300
|
||||
assert result['cost_per_minute'] == 0.5 # €150 / 300 minutes
|
||||
|
||||
# Check cost distribution
|
||||
assert result['distributions'][122]['cost_allocated'] == 60.0 # 120 min * €0.5
|
||||
assert result['distributions'][123]['cost_allocated'] == 30.0 # 60 min * €0.5
|
||||
assert result['distributions'][124]['cost_allocated'] == 60.0 # 120 min * €0.5
|
||||
|
||||
# Check percentages
|
||||
assert result['distributions'][122]['percentage'] == 40.0 # 120/300 * 100
|
||||
assert result['distributions'][123]['percentage'] == 20.0 # 60/300 * 100
|
||||
assert result['distributions'][124]['percentage'] == 40.0 # 120/300 * 100
|
||||
|
||||
def test_distribute_daily_costs_no_worktime(self):
|
||||
"""Test cost distribution when no worktime is logged."""
|
||||
today = date.today()
|
||||
total_cost = Decimal('100.00')
|
||||
|
||||
result = self.tracker.distribute_daily_costs(
|
||||
work_date=today,
|
||||
total_daily_cost=total_cost
|
||||
)
|
||||
|
||||
assert 'message' in result
|
||||
assert "No worktime entries found" in result['message']
|
||||
|
||||
def test_get_worktime_report(self):
|
||||
"""Test comprehensive worktime reporting."""
|
||||
today = date.today()
|
||||
yesterday = today - timedelta(days=1)
|
||||
|
||||
# Create test data across multiple days and issues
|
||||
self.tracker.log_worktime(122, 90, work_date=yesterday)
|
||||
self.tracker.log_worktime(123, 60, work_date=yesterday)
|
||||
self.tracker.log_worktime(122, 120, work_date=today)
|
||||
self.tracker.log_worktime(124, 45, work_date=today)
|
||||
|
||||
report = self.tracker.get_worktime_report(
|
||||
start_date=yesterday,
|
||||
end_date=today
|
||||
)
|
||||
|
||||
assert report['total_entries'] == 4
|
||||
assert report['total_time']['total_minutes'] == 315 # 90+60+120+45
|
||||
assert report['total_time']['hours'] == 5
|
||||
assert report['total_time']['minutes'] == 15
|
||||
assert report['unique_issues'] == 3 # Issues 122, 123, 124
|
||||
assert report['unique_dates'] == 2
|
||||
|
||||
# Check issue breakdown
|
||||
assert 122 in report['issue_breakdown']
|
||||
assert report['issue_breakdown'][122]['total_minutes'] == 210 # 90+120
|
||||
assert report['issue_breakdown'][122]['entry_count'] == 2
|
||||
assert report['issue_breakdown'][122]['unique_dates'] == 2
|
||||
|
||||
def test_delete_worktime_entry(self):
|
||||
"""Test deleting worktime entries."""
|
||||
entry_id = self.tracker.log_worktime(122, 90, description="Test entry")
|
||||
|
||||
# Verify entry exists
|
||||
entries = self.tracker.get_worktime_entries(issue_id=122)
|
||||
assert len(entries) == 1
|
||||
|
||||
# Delete entry
|
||||
success = self.tracker.delete_worktime_entry(entry_id)
|
||||
assert success is True
|
||||
|
||||
# Verify entry is gone
|
||||
entries = self.tracker.get_worktime_entries(issue_id=122)
|
||||
assert len(entries) == 0
|
||||
|
||||
# Try to delete non-existent entry
|
||||
success = self.tracker.delete_worktime_entry(99999)
|
||||
assert success is False
|
||||
|
||||
def test_update_worktime_entry(self):
|
||||
"""Test updating worktime entries."""
|
||||
entry_id = self.tracker.log_worktime(122, 90, description="Original description")
|
||||
|
||||
# Update duration and description
|
||||
success = self.tracker.update_worktime_entry(
|
||||
entry_id=entry_id,
|
||||
duration_minutes=120,
|
||||
description="Updated description"
|
||||
)
|
||||
assert success is True
|
||||
|
||||
# Verify updates
|
||||
entries = self.tracker.get_worktime_entries(issue_id=122)
|
||||
assert len(entries) == 1
|
||||
assert entries[0].duration_minutes == 120
|
||||
assert entries[0].description == "Updated description"
|
||||
|
||||
# Try to update non-existent entry
|
||||
success = self.tracker.update_worktime_entry(
|
||||
entry_id=99999,
|
||||
duration_minutes=60
|
||||
)
|
||||
assert success is False
|
||||
|
||||
|
||||
class TestWorktimeCommands:
|
||||
"""Test suite for worktime CLI commands."""
|
||||
|
||||
def test_parse_duration_minutes(self):
|
||||
"""Test parsing duration strings - minutes format."""
|
||||
assert _parse_duration("90") == 90
|
||||
assert _parse_duration("120") == 120
|
||||
assert _parse_duration("45m") == 45
|
||||
|
||||
def test_parse_duration_hours(self):
|
||||
"""Test parsing duration strings - hours format."""
|
||||
assert _parse_duration("1h") == 60
|
||||
assert _parse_duration("2h") == 120
|
||||
assert _parse_duration("1.5h") == 90
|
||||
assert _parse_duration("2.25h") == 135
|
||||
|
||||
def test_parse_duration_hours_minutes(self):
|
||||
"""Test parsing duration strings - hours and minutes format."""
|
||||
assert _parse_duration("1h30m") == 90
|
||||
assert _parse_duration("2h15m") == 135
|
||||
assert _parse_duration("0h45m") == 45
|
||||
assert _parse_duration("3h0m") == 180
|
||||
|
||||
def test_parse_duration_invalid(self):
|
||||
"""Test parsing invalid duration strings."""
|
||||
with pytest.raises(ValueError):
|
||||
_parse_duration("invalid")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
_parse_duration("1x30m")
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
_parse_duration("")
|
||||
|
||||
def test_format_duration_minutes_only(self):
|
||||
"""Test formatting duration - minutes only."""
|
||||
assert _format_duration(30) == "30m"
|
||||
assert _format_duration(45) == "45m"
|
||||
assert _format_duration(59) == "59m"
|
||||
|
||||
def test_format_duration_hours_only(self):
|
||||
"""Test formatting duration - hours only."""
|
||||
assert _format_duration(60) == "1h"
|
||||
assert _format_duration(120) == "2h"
|
||||
assert _format_duration(180) == "3h"
|
||||
|
||||
def test_format_duration_hours_and_minutes(self):
|
||||
"""Test formatting duration - hours and minutes."""
|
||||
assert _format_duration(90) == "1h30m"
|
||||
assert _format_duration(135) == "2h15m"
|
||||
assert _format_duration(195) == "3h15m"
|
||||
|
||||
@patch('markitect.finance.worktime_commands.WorktimeTracker')
|
||||
def test_log_command_basic(self, mock_tracker_class):
|
||||
"""Test the log command with basic parameters."""
|
||||
mock_tracker = Mock()
|
||||
mock_tracker_class.return_value = mock_tracker
|
||||
mock_tracker.log_worktime.return_value = 1
|
||||
mock_tracker.get_daily_summary.return_value = DailySummary(
|
||||
work_date=date.today(),
|
||||
total_minutes=90,
|
||||
issue_count=1,
|
||||
entries=[]
|
||||
)
|
||||
|
||||
from click.testing import CliRunner
|
||||
runner = CliRunner()
|
||||
|
||||
result = runner.invoke(worktime, ['log', '122', '1h30m'])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "✅ Logged 90min worktime for issue #122" in result.output
|
||||
mock_tracker.log_worktime.assert_called_once()
|
||||
|
||||
@patch('markitect.finance.worktime_commands.WorktimeTracker')
|
||||
def test_log_command_with_description(self, mock_tracker_class):
|
||||
"""Test the log command with description."""
|
||||
mock_tracker = Mock()
|
||||
mock_tracker_class.return_value = mock_tracker
|
||||
mock_tracker.log_worktime.return_value = 1
|
||||
mock_tracker.get_daily_summary.return_value = None
|
||||
|
||||
from click.testing import CliRunner
|
||||
runner = CliRunner()
|
||||
|
||||
result = runner.invoke(worktime, ['log', '122', '90', '--description', 'Testing worktime'])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Testing worktime" in result.output
|
||||
|
||||
@patch('markitect.finance.worktime_commands.WorktimeTracker')
|
||||
def test_list_command_table_format(self, mock_tracker_class):
|
||||
"""Test the list command with table output format."""
|
||||
mock_tracker = Mock()
|
||||
mock_tracker_class.return_value = mock_tracker
|
||||
|
||||
mock_entries = [
|
||||
WorktimeEntry(
|
||||
id=1,
|
||||
issue_id=122,
|
||||
work_date=date.today(),
|
||||
duration_minutes=90,
|
||||
description="Test worktime",
|
||||
entry_type="manual"
|
||||
)
|
||||
]
|
||||
mock_tracker.get_worktime_entries.return_value = mock_entries
|
||||
|
||||
from click.testing import CliRunner
|
||||
runner = CliRunner()
|
||||
|
||||
result = runner.invoke(worktime, ['list'])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "⏰ Worktime Entries" in result.output
|
||||
assert "#122" in result.output
|
||||
assert "1h30m" in result.output
|
||||
|
||||
@patch('markitect.finance.worktime_commands.WorktimeTracker')
|
||||
def test_list_command_json_format(self, mock_tracker_class):
|
||||
"""Test the list command with JSON output format."""
|
||||
mock_tracker = Mock()
|
||||
mock_tracker_class.return_value = mock_tracker
|
||||
|
||||
mock_entries = [
|
||||
WorktimeEntry(
|
||||
id=1,
|
||||
issue_id=122,
|
||||
work_date=date.today(),
|
||||
duration_minutes=90,
|
||||
description="Test worktime",
|
||||
entry_type="manual"
|
||||
)
|
||||
]
|
||||
mock_tracker.get_worktime_entries.return_value = mock_entries
|
||||
|
||||
from click.testing import CliRunner
|
||||
runner = CliRunner()
|
||||
|
||||
result = runner.invoke(worktime, ['list', '--format', 'json'])
|
||||
|
||||
assert result.exit_code == 0
|
||||
# Should be valid JSON
|
||||
output_data = json.loads(result.output.strip())
|
||||
assert len(output_data) == 1
|
||||
assert output_data[0]['issue_id'] == 122
|
||||
assert output_data[0]['duration_minutes'] == 90
|
||||
|
||||
@patch('markitect.finance.worktime_commands.WorktimeTracker')
|
||||
def test_daily_command(self, mock_tracker_class):
|
||||
"""Test the daily summary command."""
|
||||
mock_tracker = Mock()
|
||||
mock_tracker_class.return_value = mock_tracker
|
||||
|
||||
mock_entries = [
|
||||
WorktimeEntry(id=1, issue_id=122, duration_minutes=90, entry_type="manual"),
|
||||
WorktimeEntry(id=2, issue_id=123, duration_minutes=60, entry_type="manual")
|
||||
]
|
||||
|
||||
mock_summary = DailySummary(
|
||||
work_date=date.today(),
|
||||
total_minutes=150,
|
||||
issue_count=2,
|
||||
entries=mock_entries,
|
||||
cost_per_minute=Decimal('0.5'),
|
||||
total_cost_allocated=Decimal('75.0')
|
||||
)
|
||||
mock_tracker.get_daily_summary.return_value = mock_summary
|
||||
|
||||
from click.testing import CliRunner
|
||||
runner = CliRunner()
|
||||
|
||||
today = date.today().strftime('%Y-%m-%d')
|
||||
result = runner.invoke(worktime, ['daily', today])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert f"📅 Daily Summary for {date.today()}" in result.output
|
||||
assert "Total Time: 2h30m" in result.output
|
||||
assert "Issues Worked: 2" in result.output
|
||||
assert "Cost per Minute: €0.5000" in result.output
|
||||
|
||||
@patch('markitect.finance.worktime_commands.WorktimeTracker')
|
||||
def test_estimate_command(self, mock_tracker_class):
|
||||
"""Test the estimate worktime command."""
|
||||
mock_tracker = Mock()
|
||||
mock_tracker_class.return_value = mock_tracker
|
||||
|
||||
mock_result = {
|
||||
'work_date': date.today(),
|
||||
'total_minutes': 480, # 8 hours
|
||||
'distribution_method': 'equal',
|
||||
'issue_estimates': {122: 240, 123: 240},
|
||||
'issues_count': 2
|
||||
}
|
||||
mock_tracker.estimate_daily_worktime.return_value = mock_result
|
||||
|
||||
from click.testing import CliRunner
|
||||
runner = CliRunner()
|
||||
|
||||
today = date.today().strftime('%Y-%m-%d')
|
||||
result = runner.invoke(worktime, ['estimate', today, '8', '-i', '122', '-i', '123'])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "📊 Worktime Estimation" in result.output
|
||||
assert "Total Hours: 8.0h" in result.output
|
||||
assert "✅ Created 2 estimated worktime entries" in result.output
|
||||
|
||||
@patch('markitect.finance.worktime_commands.WorktimeTracker')
|
||||
def test_distribute_command(self, mock_tracker_class):
|
||||
"""Test the cost distribution command."""
|
||||
mock_tracker = Mock()
|
||||
mock_tracker_class.return_value = mock_tracker
|
||||
|
||||
mock_result = {
|
||||
'work_date': date.today(),
|
||||
'total_cost': 100.0,
|
||||
'total_minutes': 200,
|
||||
'cost_per_minute': 0.5,
|
||||
'distributions': {
|
||||
122: {'minutes': 120, 'percentage': 60.0, 'cost_allocated': 60.0},
|
||||
123: {'minutes': 80, 'percentage': 40.0, 'cost_allocated': 40.0}
|
||||
},
|
||||
'issues_count': 2
|
||||
}
|
||||
mock_tracker.distribute_daily_costs.return_value = mock_result
|
||||
|
||||
from click.testing import CliRunner
|
||||
runner = CliRunner()
|
||||
|
||||
today = date.today().strftime('%Y-%m-%d')
|
||||
result = runner.invoke(worktime, ['distribute', today, '100'])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "💰 Cost Distribution" in result.output
|
||||
assert "Total Cost: €100.00" in result.output
|
||||
assert "Cost per Minute: €0.5000" in result.output
|
||||
|
||||
@patch('markitect.finance.worktime_commands.WorktimeTracker')
|
||||
def test_delete_command(self, mock_tracker_class):
|
||||
"""Test the delete command."""
|
||||
mock_tracker = Mock()
|
||||
mock_tracker_class.return_value = mock_tracker
|
||||
mock_tracker.delete_worktime_entry.return_value = True
|
||||
|
||||
from click.testing import CliRunner
|
||||
runner = CliRunner()
|
||||
|
||||
# Auto-confirm the deletion
|
||||
result = runner.invoke(worktime, ['delete', '1'], input='y\n')
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "✅ Deleted worktime entry #1" in result.output
|
||||
mock_tracker.delete_worktime_entry.assert_called_once_with(1)
|
||||
|
||||
@patch('markitect.finance.worktime_commands.WorktimeTracker')
|
||||
def test_update_command(self, mock_tracker_class):
|
||||
"""Test the update command."""
|
||||
mock_tracker = Mock()
|
||||
mock_tracker_class.return_value = mock_tracker
|
||||
mock_tracker.update_worktime_entry.return_value = True
|
||||
|
||||
from click.testing import CliRunner
|
||||
runner = CliRunner()
|
||||
|
||||
result = runner.invoke(worktime, ['update', '1', '--duration', '2h', '--description', 'Updated'])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "✅ Updated worktime entry #1" in result.output
|
||||
|
||||
|
||||
class TestWorktimeIntegration:
|
||||
"""Integration tests for the complete worktime tracking system."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up integration test fixtures."""
|
||||
self.temp_db = tempfile.NamedTemporaryFile(suffix='.db', delete=False)
|
||||
self.temp_db.close()
|
||||
self.db_path = self.temp_db.name
|
||||
|
||||
def teardown_method(self):
|
||||
"""Clean up integration test fixtures."""
|
||||
Path(self.db_path).unlink(missing_ok=True)
|
||||
|
||||
def test_full_worktime_lifecycle(self):
|
||||
"""Test the complete lifecycle of worktime tracking."""
|
||||
tracker = WorktimeTracker(self.db_path)
|
||||
|
||||
# 1. Log worktime for multiple issues across multiple days
|
||||
today = date.today()
|
||||
yesterday = today - timedelta(days=1)
|
||||
|
||||
tracker.log_worktime(122, 120, work_date=yesterday, description="Initial development")
|
||||
tracker.log_worktime(123, 90, work_date=yesterday, description="Code review")
|
||||
|
||||
tracker.log_worktime(122, 90, work_date=today, description="Bug fixes")
|
||||
tracker.log_worktime(124, 60, work_date=today, description="Documentation")
|
||||
|
||||
# 2. Verify daily summaries
|
||||
yesterday_summary = tracker.get_daily_summary(yesterday)
|
||||
assert yesterday_summary.total_minutes == 210 # 120 + 90
|
||||
assert yesterday_summary.issue_count == 2
|
||||
|
||||
today_summary = tracker.get_daily_summary(today)
|
||||
assert today_summary.total_minutes == 150 # 90 + 60
|
||||
assert today_summary.issue_count == 2
|
||||
|
||||
# 3. Distribute costs for a day
|
||||
distribution = tracker.distribute_daily_costs(
|
||||
work_date=today,
|
||||
total_daily_cost=Decimal('75.00') # €75 for today's work
|
||||
)
|
||||
|
||||
assert distribution['total_cost'] == 75.0
|
||||
assert distribution['total_minutes'] == 150
|
||||
assert distribution['cost_per_minute'] == 0.5
|
||||
|
||||
# Issue 122: 90 minutes = €45
|
||||
# Issue 124: 60 minutes = €30
|
||||
assert distribution['distributions'][122]['cost_allocated'] == 45.0
|
||||
assert distribution['distributions'][124]['cost_allocated'] == 30.0
|
||||
|
||||
# 4. Generate comprehensive report
|
||||
report = tracker.get_worktime_report(
|
||||
start_date=yesterday,
|
||||
end_date=today
|
||||
)
|
||||
|
||||
assert report['total_entries'] == 4
|
||||
assert report['total_time']['total_minutes'] == 360 # 210 + 150
|
||||
assert report['unique_issues'] == 3 # Issues 122, 123, 124
|
||||
assert report['unique_dates'] == 2
|
||||
|
||||
# 5. Test estimation functionality
|
||||
tomorrow = today + timedelta(days=1)
|
||||
estimation = tracker.estimate_daily_worktime(
|
||||
work_date=tomorrow,
|
||||
total_hours=6.0,
|
||||
issues=[122, 125, 126],
|
||||
distribution_method="equal"
|
||||
)
|
||||
|
||||
assert estimation['total_minutes'] == 360
|
||||
assert len(estimation['issue_estimates']) == 3
|
||||
# Each issue should get 120 minutes (equal distribution)
|
||||
for minutes in estimation['issue_estimates'].values():
|
||||
assert minutes == 120
|
||||
|
||||
# 6. Verify estimated entries were created
|
||||
tomorrow_entries = tracker.get_worktime_entries(work_date=tomorrow)
|
||||
assert len(tomorrow_entries) == 3
|
||||
assert all(e.entry_type == "estimated" for e in tomorrow_entries)
|
||||
|
||||
def test_cost_distribution_accuracy(self):
|
||||
"""Test accurate cost distribution calculations."""
|
||||
tracker = WorktimeTracker(self.db_path)
|
||||
work_date = date.today()
|
||||
|
||||
# Log precise worktime amounts
|
||||
tracker.log_worktime(122, 100, work_date=work_date) # 100 minutes
|
||||
tracker.log_worktime(123, 50, work_date=work_date) # 50 minutes
|
||||
tracker.log_worktime(124, 150, work_date=work_date) # 150 minutes
|
||||
# Total: 300 minutes
|
||||
|
||||
# Distribute exactly €300
|
||||
distribution = tracker.distribute_daily_costs(
|
||||
work_date=work_date,
|
||||
total_daily_cost=Decimal('300.00')
|
||||
)
|
||||
|
||||
# Should be exactly €1 per minute
|
||||
assert distribution['cost_per_minute'] == 1.0
|
||||
|
||||
# Verify exact cost allocation
|
||||
assert distribution['distributions'][122]['cost_allocated'] == 100.0
|
||||
assert distribution['distributions'][123]['cost_allocated'] == 50.0
|
||||
assert distribution['distributions'][124]['cost_allocated'] == 150.0
|
||||
|
||||
# Verify percentages sum to 100%
|
||||
total_percentage = sum(
|
||||
dist['percentage'] for dist in distribution['distributions'].values()
|
||||
)
|
||||
assert abs(total_percentage - 100.0) < 0.01 # Allow for rounding
|
||||
|
||||
# Verify cost allocation was logged to database
|
||||
with tracker.finance_models.get_connection() as conn:
|
||||
cursor = conn.cursor()
|
||||
cursor.execute('''
|
||||
SELECT issue_id, cost_allocated
|
||||
FROM worktime_cost_distributions
|
||||
WHERE work_date = ?
|
||||
ORDER BY issue_id
|
||||
''', (work_date.isoformat() if hasattr(work_date, 'isoformat') else work_date,))
|
||||
results = cursor.fetchall()
|
||||
|
||||
assert len(results) == 3
|
||||
assert results[0] == (122, 100.0)
|
||||
assert results[1] == (123, 50.0)
|
||||
assert results[2] == (124, 150.0)
|
||||
|
||||
def test_worktime_modification_and_summary_updates(self):
|
||||
"""Test that modifying worktime entries correctly updates summaries."""
|
||||
tracker = WorktimeTracker(self.db_path)
|
||||
work_date = date.today()
|
||||
|
||||
# Log initial worktime
|
||||
entry_id = tracker.log_worktime(122, 60, work_date=work_date)
|
||||
|
||||
# Check initial summary
|
||||
summary = tracker.get_daily_summary(work_date)
|
||||
assert summary.total_minutes == 60
|
||||
|
||||
# Update the entry
|
||||
tracker.update_worktime_entry(entry_id, duration_minutes=120)
|
||||
|
||||
# Check updated summary
|
||||
summary = tracker.get_daily_summary(work_date)
|
||||
assert summary.total_minutes == 120
|
||||
|
||||
# Delete the entry
|
||||
tracker.delete_worktime_entry(entry_id)
|
||||
|
||||
# Check final summary
|
||||
summary = tracker.get_daily_summary(work_date)
|
||||
assert summary is None or summary.total_minutes == 0
|
||||
454
markitect/finance/tests/test_period_cli_commands.py
Normal file
454
markitect/finance/tests/test_period_cli_commands.py
Normal file
@@ -0,0 +1,454 @@
|
||||
"""
|
||||
Tests for MarkiTect period management CLI commands.
|
||||
|
||||
This module tests the command-line interface for period management including:
|
||||
- Period creation, listing, and status management
|
||||
- Period calculation and lifecycle operations
|
||||
- CLI error handling and validation
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import tempfile
|
||||
import os
|
||||
from datetime import date
|
||||
from decimal import Decimal
|
||||
from click.testing import CliRunner
|
||||
|
||||
from markitect.finance.cli import cost_commands
|
||||
from markitect.finance.models import FinanceModels
|
||||
from markitect.finance.period_manager import PeriodManager
|
||||
from markitect.finance.cost_manager import CostItemManager, CostItem
|
||||
|
||||
|
||||
class TestPeriodCLICommands:
|
||||
"""Test suite for period management CLI commands."""
|
||||
|
||||
@pytest.fixture
|
||||
def temp_db(self):
|
||||
"""Create temporary database for testing."""
|
||||
fd, path = tempfile.mkstemp(suffix='.db')
|
||||
os.close(fd)
|
||||
yield path
|
||||
os.unlink(path)
|
||||
|
||||
@pytest.fixture
|
||||
def setup_test_data(self, temp_db):
|
||||
"""Setup test database with sample period data."""
|
||||
finance_models = FinanceModels(temp_db)
|
||||
finance_models.initialize_finance_schema()
|
||||
|
||||
period_manager = PeriodManager(temp_db)
|
||||
|
||||
# Create sample period
|
||||
period_id = period_manager.create_period(
|
||||
period_start=date(2025, 1, 1),
|
||||
period_end=date(2025, 1, 31),
|
||||
period_type='monthly'
|
||||
)
|
||||
|
||||
return temp_db, period_id
|
||||
|
||||
@pytest.fixture
|
||||
def runner(self):
|
||||
"""Create Click test runner."""
|
||||
return CliRunner()
|
||||
|
||||
def test_period_create_success(self, runner, temp_db):
|
||||
"""Test period creation via CLI."""
|
||||
# Initialize database first
|
||||
finance_models = FinanceModels(temp_db)
|
||||
finance_models.initialize_finance_schema()
|
||||
|
||||
result = runner.invoke(cost_commands, [
|
||||
'period', 'create',
|
||||
'--start-date', '2025-02-01',
|
||||
'--end-date', '2025-02-28',
|
||||
'--database', temp_db
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "✅ Created period #" in result.output
|
||||
assert "📅 Period: 2025-02-01 to 2025-02-28" in result.output
|
||||
assert "📊 Type: monthly" in result.output
|
||||
|
||||
def test_period_create_with_loss_forward(self, runner, temp_db):
|
||||
"""Test period creation with loss carried forward."""
|
||||
finance_models = FinanceModels(temp_db)
|
||||
finance_models.initialize_finance_schema()
|
||||
|
||||
result = runner.invoke(cost_commands, [
|
||||
'period', 'create',
|
||||
'--start-date', '2025-03-01',
|
||||
'--end-date', '2025-03-31',
|
||||
'--loss-forward', '15.75',
|
||||
'--database', temp_db
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "✅ Created period #" in result.output
|
||||
assert "💸 Loss carried forward: €15.7500" in result.output
|
||||
|
||||
def test_period_create_invalid_dates(self, runner, temp_db):
|
||||
"""Test period creation with invalid date format."""
|
||||
finance_models = FinanceModels(temp_db)
|
||||
finance_models.initialize_finance_schema()
|
||||
|
||||
result = runner.invoke(cost_commands, [
|
||||
'period', 'create',
|
||||
'--start-date', 'invalid-date',
|
||||
'--end-date', '2025-02-28',
|
||||
'--database', temp_db
|
||||
])
|
||||
|
||||
assert result.exit_code == 1
|
||||
assert "Error: Dates must be in YYYY-MM-DD format" in result.output
|
||||
|
||||
def test_period_create_overlapping_fails(self, runner, setup_test_data):
|
||||
"""Test that creating overlapping periods fails."""
|
||||
temp_db, existing_period_id = setup_test_data
|
||||
|
||||
result = runner.invoke(cost_commands, [
|
||||
'period', 'create',
|
||||
'--start-date', '2025-01-15', # Overlaps with existing period
|
||||
'--end-date', '2025-02-15',
|
||||
'--database', temp_db
|
||||
])
|
||||
|
||||
assert result.exit_code == 1
|
||||
assert "Error:" in result.output
|
||||
assert "overlaps" in result.output.lower()
|
||||
|
||||
def test_period_list_all(self, runner, setup_test_data):
|
||||
"""Test listing all periods."""
|
||||
temp_db, period_id = setup_test_data
|
||||
|
||||
result = runner.invoke(cost_commands, [
|
||||
'period', 'list',
|
||||
'--database', temp_db
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "📅 Calculation Periods" in result.output
|
||||
assert "2025-01-01" in result.output
|
||||
assert "2025-01-31" in result.output
|
||||
assert "Total: 1 periods" in result.output
|
||||
|
||||
def test_period_list_with_status_filter(self, runner, setup_test_data):
|
||||
"""Test listing periods with status filter."""
|
||||
temp_db, period_id = setup_test_data
|
||||
|
||||
# Create second period and close it
|
||||
period_manager = PeriodManager(temp_db)
|
||||
period_id2 = period_manager.create_period(
|
||||
period_start=date(2025, 2, 1),
|
||||
period_end=date(2025, 2, 28)
|
||||
)
|
||||
period_manager.close_period(period_id2)
|
||||
|
||||
# Filter by open status
|
||||
result = runner.invoke(cost_commands, [
|
||||
'period', 'list',
|
||||
'--status', 'open',
|
||||
'--database', temp_db
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "2025-01-01" in result.output # First period should be shown
|
||||
assert "2025-02-01" not in result.output # Second period should be filtered out
|
||||
|
||||
# Filter by closed status
|
||||
result = runner.invoke(cost_commands, [
|
||||
'period', 'list',
|
||||
'--status', 'closed',
|
||||
'--database', temp_db
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "2025-02-01" in result.output # Second period should be shown
|
||||
assert "2025-01-01" not in result.output # First period should be filtered out
|
||||
|
||||
def test_period_list_with_date_filters(self, runner, temp_db):
|
||||
"""Test listing periods with date range filters."""
|
||||
finance_models = FinanceModels(temp_db)
|
||||
finance_models.initialize_finance_schema()
|
||||
|
||||
period_manager = PeriodManager(temp_db)
|
||||
|
||||
# Create periods in different months
|
||||
jan_period = period_manager.create_period(date(2025, 1, 1), date(2025, 1, 31))
|
||||
feb_period = period_manager.create_period(date(2025, 2, 1), date(2025, 2, 28))
|
||||
|
||||
# Filter by start date
|
||||
result = runner.invoke(cost_commands, [
|
||||
'period', 'list',
|
||||
'--start-from', '2025-02-01',
|
||||
'--database', temp_db
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "2025-02-01" in result.output
|
||||
assert "2025-01-01" not in result.output
|
||||
|
||||
def test_period_list_empty(self, runner, temp_db):
|
||||
"""Test listing periods when none exist."""
|
||||
finance_models = FinanceModels(temp_db)
|
||||
finance_models.initialize_finance_schema()
|
||||
|
||||
result = runner.invoke(cost_commands, [
|
||||
'period', 'list',
|
||||
'--database', temp_db
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "No periods found matching criteria" in result.output
|
||||
|
||||
def test_period_show_details(self, runner, setup_test_data):
|
||||
"""Test showing period details."""
|
||||
temp_db, period_id = setup_test_data
|
||||
|
||||
result = runner.invoke(cost_commands, [
|
||||
'period', 'show', str(period_id),
|
||||
'--database', temp_db
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert f"📅 Period #{period_id} Details" in result.output
|
||||
assert "Start Date: 2025-01-01" in result.output
|
||||
assert "End Date: 2025-01-31" in result.output
|
||||
assert "Type: monthly" in result.output
|
||||
assert "Status: open" in result.output
|
||||
|
||||
def test_period_show_nonexistent(self, runner, temp_db):
|
||||
"""Test showing non-existent period."""
|
||||
finance_models = FinanceModels(temp_db)
|
||||
finance_models.initialize_finance_schema()
|
||||
|
||||
result = runner.invoke(cost_commands, [
|
||||
'period', 'show', '999',
|
||||
'--database', temp_db
|
||||
])
|
||||
|
||||
assert result.exit_code == 1
|
||||
assert "Period #999 not found" in result.output
|
||||
|
||||
def test_period_calculate(self, runner, setup_test_data):
|
||||
"""Test period cost calculation."""
|
||||
temp_db, period_id = setup_test_data
|
||||
|
||||
# Add some cost items for calculation
|
||||
cost_manager = CostItemManager(temp_db)
|
||||
infra_cat = cost_manager.get_category_by_name('Infrastructure')
|
||||
|
||||
cost_item = CostItem(
|
||||
category_id=infra_cat['id'],
|
||||
name='Test Server',
|
||||
cost_type='monthly',
|
||||
amount_eur=Decimal('25.00'),
|
||||
starting_from_date=date(2025, 1, 1)
|
||||
)
|
||||
cost_manager.create_cost_item(cost_item)
|
||||
|
||||
result = runner.invoke(cost_commands, [
|
||||
'period', 'calculate', str(period_id),
|
||||
'--database', temp_db
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert f"📊 Period #{period_id} Cost Calculation" in result.output
|
||||
assert "Period: 2025-01-01 to 2025-01-31" in result.output
|
||||
assert "Monthly Recurring: €25.00" in result.output
|
||||
assert "Total Period Cost: €25.00" in result.output
|
||||
|
||||
def test_period_calculate_nonexistent(self, runner, temp_db):
|
||||
"""Test calculating costs for non-existent period."""
|
||||
finance_models = FinanceModels(temp_db)
|
||||
finance_models.initialize_finance_schema()
|
||||
|
||||
result = runner.invoke(cost_commands, [
|
||||
'period', 'calculate', '999',
|
||||
'--database', temp_db
|
||||
])
|
||||
|
||||
assert result.exit_code == 1
|
||||
assert "Error calculating period:" in result.output
|
||||
|
||||
def test_period_status_update(self, runner, setup_test_data):
|
||||
"""Test period status update."""
|
||||
temp_db, period_id = setup_test_data
|
||||
|
||||
result = runner.invoke(cost_commands, [
|
||||
'period', 'status', str(period_id), 'calculating',
|
||||
'--database', temp_db
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert f"✅ Period #{period_id} status updated to 'calculating'" in result.output
|
||||
|
||||
# Verify the status was actually updated
|
||||
result = runner.invoke(cost_commands, [
|
||||
'period', 'show', str(period_id),
|
||||
'--database', temp_db
|
||||
])
|
||||
|
||||
assert "Status: calculating" in result.output
|
||||
|
||||
def test_period_status_update_invalid_status(self, runner, setup_test_data):
|
||||
"""Test period status update with invalid status."""
|
||||
temp_db, period_id = setup_test_data
|
||||
|
||||
result = runner.invoke(cost_commands, [
|
||||
'period', 'status', str(period_id), 'invalid',
|
||||
'--database', temp_db
|
||||
])
|
||||
|
||||
assert result.exit_code == 2 # Click validation error
|
||||
assert "Invalid value" in result.output
|
||||
|
||||
def test_period_status_update_nonexistent(self, runner, temp_db):
|
||||
"""Test status update for non-existent period."""
|
||||
finance_models = FinanceModels(temp_db)
|
||||
finance_models.initialize_finance_schema()
|
||||
|
||||
result = runner.invoke(cost_commands, [
|
||||
'period', 'status', '999', 'calculating',
|
||||
'--database', temp_db
|
||||
])
|
||||
|
||||
assert result.exit_code == 1
|
||||
assert "Error:" in result.output
|
||||
|
||||
def test_period_close(self, runner, setup_test_data):
|
||||
"""Test period closure."""
|
||||
temp_db, period_id = setup_test_data
|
||||
|
||||
result = runner.invoke(cost_commands, [
|
||||
'period', 'close', str(period_id),
|
||||
'--database', temp_db
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert f"✅ Period #{period_id} has been closed" in result.output
|
||||
assert "💰 Final total cost:" in result.output
|
||||
|
||||
# Verify the period is actually closed
|
||||
result = runner.invoke(cost_commands, [
|
||||
'period', 'show', str(period_id),
|
||||
'--database', temp_db
|
||||
])
|
||||
|
||||
assert "Status: closed" in result.output
|
||||
|
||||
def test_period_close_nonexistent(self, runner, temp_db):
|
||||
"""Test closing non-existent period."""
|
||||
finance_models = FinanceModels(temp_db)
|
||||
finance_models.initialize_finance_schema()
|
||||
|
||||
result = runner.invoke(cost_commands, [
|
||||
'period', 'close', '999',
|
||||
'--database', temp_db
|
||||
])
|
||||
|
||||
assert result.exit_code == 1
|
||||
assert "Error:" in result.output
|
||||
|
||||
def test_period_current_exists(self, runner, setup_test_data):
|
||||
"""Test finding current period when it exists."""
|
||||
temp_db, period_id = setup_test_data
|
||||
|
||||
result = runner.invoke(cost_commands, [
|
||||
'period', 'current',
|
||||
'--date', '2025-01-15',
|
||||
'--database', temp_db
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "📅 Current Active Period" in result.output
|
||||
assert f"Period #{period_id}" in result.output
|
||||
assert "Dates: 2025-01-01 to 2025-01-31" in result.output
|
||||
|
||||
def test_period_current_not_found(self, runner, temp_db):
|
||||
"""Test finding current period when none exists."""
|
||||
finance_models = FinanceModels(temp_db)
|
||||
finance_models.initialize_finance_schema()
|
||||
|
||||
result = runner.invoke(cost_commands, [
|
||||
'period', 'current',
|
||||
'--date', '2025-03-15',
|
||||
'--database', temp_db
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "No active period found for 2025-03-15" in result.output
|
||||
|
||||
def test_period_current_default_to_today(self, runner, temp_db):
|
||||
"""Test current period defaults to today."""
|
||||
finance_models = FinanceModels(temp_db)
|
||||
finance_models.initialize_finance_schema()
|
||||
|
||||
result = runner.invoke(cost_commands, [
|
||||
'period', 'current',
|
||||
'--database', temp_db
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "No active period found for today" in result.output
|
||||
assert "💡 Create one with:" in result.output
|
||||
assert "markitect cost period create" in result.output
|
||||
|
||||
def test_period_current_invalid_date(self, runner, temp_db):
|
||||
"""Test current period with invalid date format."""
|
||||
finance_models = FinanceModels(temp_db)
|
||||
finance_models.initialize_finance_schema()
|
||||
|
||||
result = runner.invoke(cost_commands, [
|
||||
'period', 'current',
|
||||
'--date', 'invalid-date',
|
||||
'--database', temp_db
|
||||
])
|
||||
|
||||
assert result.exit_code == 1
|
||||
assert "Error: Date must be in YYYY-MM-DD format" in result.output
|
||||
|
||||
def test_period_help_commands(self, runner):
|
||||
"""Test help output for period commands."""
|
||||
# Test main period help
|
||||
result = runner.invoke(cost_commands, ['period', '--help'])
|
||||
assert result.exit_code == 0
|
||||
assert "Manage calculation periods and lifecycle" in result.output
|
||||
|
||||
# Test create help
|
||||
result = runner.invoke(cost_commands, ['period', 'create', '--help'])
|
||||
assert result.exit_code == 0
|
||||
assert "Create a new calculation period" in result.output
|
||||
|
||||
# Test list help
|
||||
result = runner.invoke(cost_commands, ['period', 'list', '--help'])
|
||||
assert result.exit_code == 0
|
||||
assert "List calculation periods with optional filtering" in result.output
|
||||
|
||||
def test_period_commands_missing_database(self, runner):
|
||||
"""Test period commands without database specification."""
|
||||
# These should use default config path and still work or show appropriate error
|
||||
result = runner.invoke(cost_commands, [
|
||||
'period', 'list'
|
||||
])
|
||||
|
||||
# Should succeed with default database configuration
|
||||
assert result.exit_code == 0
|
||||
|
||||
def test_period_create_quarterly_type(self, runner, temp_db):
|
||||
"""Test creating quarterly period type."""
|
||||
finance_models = FinanceModels(temp_db)
|
||||
finance_models.initialize_finance_schema()
|
||||
|
||||
result = runner.invoke(cost_commands, [
|
||||
'period', 'create',
|
||||
'--start-date', '2025-04-01',
|
||||
'--end-date', '2025-06-30',
|
||||
'--type', 'quarterly',
|
||||
'--database', temp_db
|
||||
])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "✅ Created period #" in result.output
|
||||
assert "📊 Type: quarterly" in result.output
|
||||
489
markitect/finance/tests/test_period_manager.py
Normal file
489
markitect/finance/tests/test_period_manager.py
Normal file
@@ -0,0 +1,489 @@
|
||||
"""
|
||||
Tests for MarkiTect Period Management Framework.
|
||||
|
||||
This module tests the complete period lifecycle management system including:
|
||||
- Period creation, status management, and lifecycle transitions
|
||||
- Period overlap validation and conflict resolution
|
||||
- Period calculations and cost aggregation
|
||||
- Period closure validation and audit trails
|
||||
- Current period detection and auto-creation
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import tempfile
|
||||
import os
|
||||
from datetime import date, datetime, timedelta
|
||||
from decimal import Decimal
|
||||
|
||||
from markitect.finance.period_manager import PeriodManager, PeriodStatus, Period
|
||||
from markitect.finance.models import FinanceModels
|
||||
from markitect.finance.cost_manager import CostItemManager, CostItem
|
||||
|
||||
|
||||
class TestPeriodManager:
|
||||
"""Test suite for period management system."""
|
||||
|
||||
@pytest.fixture
|
||||
def temp_db(self):
|
||||
"""Create temporary database for testing."""
|
||||
fd, path = tempfile.mkstemp(suffix='.db')
|
||||
os.close(fd)
|
||||
yield path
|
||||
os.unlink(path)
|
||||
|
||||
@pytest.fixture
|
||||
def period_manager(self, temp_db):
|
||||
"""Create period manager with initialized database."""
|
||||
finance_models = FinanceModels(temp_db)
|
||||
finance_models.initialize_finance_schema()
|
||||
return PeriodManager(temp_db)
|
||||
|
||||
@pytest.fixture
|
||||
def sample_period_data(self):
|
||||
"""Sample period data for testing."""
|
||||
return {
|
||||
'period_start': date(2025, 1, 1),
|
||||
'period_end': date(2025, 1, 31),
|
||||
'period_type': 'monthly'
|
||||
}
|
||||
|
||||
def test_period_status_enum(self):
|
||||
"""Test period status enumeration."""
|
||||
assert PeriodStatus.OPEN.value == 'open'
|
||||
assert PeriodStatus.CALCULATING.value == 'calculating'
|
||||
assert PeriodStatus.CLOSED.value == 'closed'
|
||||
|
||||
def test_period_dataclass(self):
|
||||
"""Test Period dataclass creation."""
|
||||
period = Period(
|
||||
id=1,
|
||||
period_start=date(2025, 1, 1),
|
||||
period_end=date(2025, 1, 31),
|
||||
period_type='monthly',
|
||||
status='open',
|
||||
total_costs=Decimal('100.50')
|
||||
)
|
||||
|
||||
assert period.id == 1
|
||||
assert period.period_start == date(2025, 1, 1)
|
||||
assert period.period_end == date(2025, 1, 31)
|
||||
assert period.total_costs == Decimal('100.50')
|
||||
|
||||
def test_create_period_success(self, period_manager, sample_period_data):
|
||||
"""Test successful period creation."""
|
||||
period_id = period_manager.create_period(
|
||||
period_start=sample_period_data['period_start'],
|
||||
period_end=sample_period_data['period_end'],
|
||||
period_type=sample_period_data['period_type']
|
||||
)
|
||||
|
||||
assert period_id is not None
|
||||
assert isinstance(period_id, int)
|
||||
|
||||
# Verify period was created
|
||||
created_period = period_manager.get_period_by_id(period_id)
|
||||
assert created_period is not None
|
||||
assert created_period['period_start'] == sample_period_data['period_start'].isoformat()
|
||||
assert created_period['period_end'] == sample_period_data['period_end'].isoformat()
|
||||
assert created_period['status'] == PeriodStatus.OPEN.value
|
||||
|
||||
def test_create_period_invalid_dates(self, period_manager):
|
||||
"""Test period creation with invalid date range."""
|
||||
with pytest.raises(ValueError, match="Period end date must be after start date"):
|
||||
period_manager.create_period(
|
||||
period_start=date(2025, 1, 31),
|
||||
period_end=date(2025, 1, 1) # End before start
|
||||
)
|
||||
|
||||
def test_create_period_with_loss_carried_forward(self, period_manager, sample_period_data):
|
||||
"""Test period creation with loss carried forward."""
|
||||
loss_amount = Decimal('25.50')
|
||||
period_id = period_manager.create_period(
|
||||
period_start=sample_period_data['period_start'],
|
||||
period_end=sample_period_data['period_end'],
|
||||
loss_carried_forward=loss_amount
|
||||
)
|
||||
|
||||
created_period = period_manager.get_period_by_id(period_id)
|
||||
assert created_period['loss_carried_forward'] == loss_amount
|
||||
|
||||
def test_find_overlapping_periods(self, period_manager, sample_period_data):
|
||||
"""Test overlap detection functionality."""
|
||||
# Create first period
|
||||
period_id1 = period_manager.create_period(
|
||||
period_start=sample_period_data['period_start'],
|
||||
period_end=sample_period_data['period_end']
|
||||
)
|
||||
|
||||
# Test overlapping period detection
|
||||
overlapping = period_manager.find_overlapping_periods(
|
||||
period_start=date(2025, 1, 15), # Overlaps with existing
|
||||
period_end=date(2025, 2, 15)
|
||||
)
|
||||
|
||||
assert len(overlapping) == 1
|
||||
assert overlapping[0]['id'] == period_id1
|
||||
|
||||
def test_create_overlapping_period_fails(self, period_manager, sample_period_data):
|
||||
"""Test that creating overlapping periods fails."""
|
||||
# Create first period
|
||||
period_manager.create_period(
|
||||
period_start=sample_period_data['period_start'],
|
||||
period_end=sample_period_data['period_end']
|
||||
)
|
||||
|
||||
# Try to create overlapping period
|
||||
with pytest.raises(ValueError, match="Period overlaps with existing periods"):
|
||||
period_manager.create_period(
|
||||
period_start=date(2025, 1, 15), # Overlaps
|
||||
period_end=date(2025, 2, 15)
|
||||
)
|
||||
|
||||
def test_update_period_status_valid_transition(self, period_manager, sample_period_data):
|
||||
"""Test valid period status transitions."""
|
||||
period_id = period_manager.create_period(
|
||||
period_start=sample_period_data['period_start'],
|
||||
period_end=sample_period_data['period_end']
|
||||
)
|
||||
|
||||
# Transition from OPEN to CALCULATING
|
||||
success = period_manager.update_period_status(period_id, PeriodStatus.CALCULATING.value)
|
||||
assert success is True
|
||||
|
||||
updated_period = period_manager.get_period_by_id(period_id)
|
||||
assert updated_period['status'] == PeriodStatus.CALCULATING.value
|
||||
|
||||
# Transition from CALCULATING to CLOSED
|
||||
success = period_manager.update_period_status(period_id, PeriodStatus.CLOSED.value)
|
||||
assert success is True
|
||||
|
||||
updated_period = period_manager.get_period_by_id(period_id)
|
||||
assert updated_period['status'] == PeriodStatus.CLOSED.value
|
||||
|
||||
def test_update_period_status_invalid_status(self, period_manager, sample_period_data):
|
||||
"""Test update with invalid status."""
|
||||
period_id = period_manager.create_period(
|
||||
period_start=sample_period_data['period_start'],
|
||||
period_end=sample_period_data['period_end']
|
||||
)
|
||||
|
||||
with pytest.raises(ValueError, match="Invalid status 'invalid'"):
|
||||
period_manager.update_period_status(period_id, 'invalid')
|
||||
|
||||
def test_update_period_status_nonexistent_period(self, period_manager):
|
||||
"""Test update status for non-existent period."""
|
||||
with pytest.raises(ValueError, match="Period #999 not found"):
|
||||
period_manager.update_period_status(999, PeriodStatus.CALCULATING.value)
|
||||
|
||||
def test_calculate_period_costs(self, period_manager, sample_period_data, temp_db):
|
||||
"""Test period cost calculation functionality."""
|
||||
# Create period
|
||||
period_id = period_manager.create_period(
|
||||
period_start=sample_period_data['period_start'],
|
||||
period_end=sample_period_data['period_end']
|
||||
)
|
||||
|
||||
# Set up cost manager and add test data
|
||||
finance_models = FinanceModels(temp_db)
|
||||
cost_manager = CostItemManager(temp_db)
|
||||
|
||||
# Get categories
|
||||
infra_cat = cost_manager.get_category_by_name('Infrastructure')
|
||||
software_cat = cost_manager.get_category_by_name('Software')
|
||||
|
||||
# Create test cost items
|
||||
monthly_item = CostItem(
|
||||
category_id=infra_cat['id'],
|
||||
name='Monthly Server',
|
||||
cost_type='monthly',
|
||||
amount_eur=Decimal('25.00'),
|
||||
starting_from_date=date(2024, 12, 1) # Started before period
|
||||
)
|
||||
|
||||
one_time_item = CostItem(
|
||||
category_id=software_cat['id'],
|
||||
name='One-time License',
|
||||
cost_type='one_time',
|
||||
amount_eur=Decimal('50.00'),
|
||||
starting_from_date=date(2025, 1, 15) # Within period
|
||||
)
|
||||
|
||||
cost_manager.create_cost_item(monthly_item)
|
||||
cost_manager.create_cost_item(one_time_item)
|
||||
|
||||
# Calculate period costs
|
||||
calculation_result = period_manager.calculate_period_costs(period_id)
|
||||
|
||||
# Verify calculation results
|
||||
assert calculation_result['period_id'] == period_id
|
||||
assert calculation_result['monthly_costs'] == 25.0
|
||||
assert calculation_result['one_time_costs'] == 50.0
|
||||
assert calculation_result['total_costs'] == 75.0
|
||||
|
||||
# Verify period was updated
|
||||
updated_period = period_manager.get_period_by_id(period_id)
|
||||
assert updated_period['total_costs'] == Decimal('75.00')
|
||||
|
||||
def test_close_period(self, period_manager, sample_period_data):
|
||||
"""Test period closure functionality."""
|
||||
period_id = period_manager.create_period(
|
||||
period_start=sample_period_data['period_start'],
|
||||
period_end=sample_period_data['period_end']
|
||||
)
|
||||
|
||||
# Close the period
|
||||
success = period_manager.close_period(period_id)
|
||||
assert success is True
|
||||
|
||||
# Verify period is closed
|
||||
closed_period = period_manager.get_period_by_id(period_id)
|
||||
assert closed_period['status'] == PeriodStatus.CLOSED.value
|
||||
|
||||
def test_close_period_already_closed(self, period_manager, sample_period_data):
|
||||
"""Test closing an already closed period."""
|
||||
period_id = period_manager.create_period(
|
||||
period_start=sample_period_data['period_start'],
|
||||
period_end=sample_period_data['period_end']
|
||||
)
|
||||
|
||||
# Close period first time
|
||||
period_manager.close_period(period_id)
|
||||
|
||||
# Close again (should succeed without error)
|
||||
success = period_manager.close_period(period_id)
|
||||
assert success is True
|
||||
|
||||
def test_close_nonexistent_period(self, period_manager):
|
||||
"""Test closing non-existent period."""
|
||||
with pytest.raises(ValueError, match="Period #999 not found"):
|
||||
period_manager.close_period(999)
|
||||
|
||||
def test_list_periods_no_filter(self, period_manager, sample_period_data):
|
||||
"""Test listing all periods without filters."""
|
||||
# Create multiple periods
|
||||
period_id1 = period_manager.create_period(
|
||||
period_start=date(2025, 1, 1),
|
||||
period_end=date(2025, 1, 31)
|
||||
)
|
||||
|
||||
period_id2 = period_manager.create_period(
|
||||
period_start=date(2025, 2, 1),
|
||||
period_end=date(2025, 2, 28)
|
||||
)
|
||||
|
||||
# List all periods
|
||||
periods = period_manager.list_periods()
|
||||
|
||||
assert len(periods) == 2
|
||||
period_ids = [p['id'] for p in periods]
|
||||
assert period_id1 in period_ids
|
||||
assert period_id2 in period_ids
|
||||
|
||||
def test_list_periods_with_status_filter(self, period_manager):
|
||||
"""Test listing periods with status filter."""
|
||||
# Create periods with different statuses
|
||||
period_id1 = period_manager.create_period(
|
||||
period_start=date(2025, 1, 1),
|
||||
period_end=date(2025, 1, 31)
|
||||
)
|
||||
|
||||
period_id2 = period_manager.create_period(
|
||||
period_start=date(2025, 2, 1),
|
||||
period_end=date(2025, 2, 28)
|
||||
)
|
||||
|
||||
# Close one period
|
||||
period_manager.close_period(period_id2)
|
||||
|
||||
# Filter by open status
|
||||
open_periods = period_manager.list_periods(status_filter=PeriodStatus.OPEN.value)
|
||||
assert len(open_periods) == 1
|
||||
assert open_periods[0]['id'] == period_id1
|
||||
|
||||
# Filter by closed status
|
||||
closed_periods = period_manager.list_periods(status_filter=PeriodStatus.CLOSED.value)
|
||||
assert len(closed_periods) == 1
|
||||
assert closed_periods[0]['id'] == period_id2
|
||||
|
||||
def test_list_periods_with_date_filters(self, period_manager):
|
||||
"""Test listing periods with date range filters."""
|
||||
# Create periods in different months
|
||||
jan_period = period_manager.create_period(
|
||||
period_start=date(2025, 1, 1),
|
||||
period_end=date(2025, 1, 31)
|
||||
)
|
||||
|
||||
feb_period = period_manager.create_period(
|
||||
period_start=date(2025, 2, 1),
|
||||
period_end=date(2025, 2, 28)
|
||||
)
|
||||
|
||||
# Filter by start date
|
||||
periods_from_feb = period_manager.list_periods(start_date=date(2025, 2, 1))
|
||||
assert len(periods_from_feb) == 1
|
||||
assert periods_from_feb[0]['id'] == feb_period
|
||||
|
||||
# Filter by end date
|
||||
periods_until_jan = period_manager.list_periods(end_date=date(2025, 1, 31))
|
||||
assert len(periods_until_jan) == 1
|
||||
assert periods_until_jan[0]['id'] == jan_period
|
||||
|
||||
def test_get_current_period(self, period_manager):
|
||||
"""Test getting current period for a specific date."""
|
||||
# Create period covering January 2025
|
||||
period_id = period_manager.create_period(
|
||||
period_start=date(2025, 1, 1),
|
||||
period_end=date(2025, 1, 31)
|
||||
)
|
||||
|
||||
# Test date within period
|
||||
current = period_manager.get_current_period(date(2025, 1, 15))
|
||||
assert current is not None
|
||||
assert current['id'] == period_id
|
||||
|
||||
# Test date outside period
|
||||
current = period_manager.get_current_period(date(2025, 2, 15))
|
||||
assert current is None
|
||||
|
||||
def test_get_current_period_defaults_to_today(self, period_manager):
|
||||
"""Test that get_current_period defaults to today's date."""
|
||||
today = date.today()
|
||||
|
||||
# Create period covering today
|
||||
period_id = period_manager.create_period(
|
||||
period_start=date(today.year, today.month, 1),
|
||||
period_end=date(today.year, today.month, 31) if today.month != 12
|
||||
else date(today.year, 12, 31)
|
||||
)
|
||||
|
||||
# Get current period without specifying date
|
||||
current = period_manager.get_current_period()
|
||||
assert current is not None
|
||||
assert current['id'] == period_id
|
||||
|
||||
def test_create_monthly_period(self, period_manager):
|
||||
"""Test convenience method for creating monthly periods."""
|
||||
period_id = period_manager.create_monthly_period(2025, 3)
|
||||
assert period_id is not None
|
||||
|
||||
# Verify correct dates were set
|
||||
period = period_manager.get_period_by_id(period_id)
|
||||
assert period['period_start'] == '2025-03-01'
|
||||
assert period['period_end'] == '2025-03-31'
|
||||
assert period['period_type'] == 'monthly'
|
||||
|
||||
def test_create_monthly_period_december(self, period_manager):
|
||||
"""Test creating monthly period for December (year boundary)."""
|
||||
period_id = period_manager.create_monthly_period(2025, 12)
|
||||
|
||||
period = period_manager.get_period_by_id(period_id)
|
||||
assert period['period_start'] == '2025-12-01'
|
||||
assert period['period_end'] == '2025-12-31'
|
||||
|
||||
def test_auto_create_period_for_date(self, period_manager):
|
||||
"""Test automatic period creation for a given date."""
|
||||
test_date = date(2025, 5, 15)
|
||||
|
||||
# First call should create new period
|
||||
period_id = period_manager.auto_create_period_for_date(test_date)
|
||||
assert period_id is not None
|
||||
|
||||
# Second call should return existing period
|
||||
period_id2 = period_manager.auto_create_period_for_date(test_date)
|
||||
assert period_id2 == period_id
|
||||
|
||||
# Verify period covers the test date
|
||||
period = period_manager.get_period_by_id(period_id)
|
||||
assert period['period_start'] == '2025-05-01'
|
||||
assert period['period_end'] == '2025-05-31'
|
||||
|
||||
def test_period_calculation_with_loss_carried_forward(self, period_manager, temp_db):
|
||||
"""Test period calculation including loss carried forward."""
|
||||
# Create period with loss carried forward
|
||||
period_id = period_manager.create_period(
|
||||
period_start=date(2025, 1, 1),
|
||||
period_end=date(2025, 1, 31),
|
||||
loss_carried_forward=Decimal('15.75')
|
||||
)
|
||||
|
||||
# Add a cost item
|
||||
cost_manager = CostItemManager(temp_db)
|
||||
infra_cat = cost_manager.get_category_by_name('Infrastructure')
|
||||
|
||||
cost_item = CostItem(
|
||||
category_id=infra_cat['id'],
|
||||
name='Test Server',
|
||||
cost_type='monthly',
|
||||
amount_eur=Decimal('10.00'),
|
||||
starting_from_date=date(2025, 1, 1)
|
||||
)
|
||||
cost_manager.create_cost_item(cost_item)
|
||||
|
||||
# Calculate costs
|
||||
calculation = period_manager.calculate_period_costs(period_id)
|
||||
|
||||
# Should include loss carried forward
|
||||
assert calculation['loss_carried_forward'] == 15.75
|
||||
assert calculation['monthly_costs'] == 10.0
|
||||
assert calculation['total_costs'] == 25.75 # 10.0 + 15.75
|
||||
|
||||
def test_period_cost_calculation_edge_cases(self, period_manager, temp_db):
|
||||
"""Test period cost calculation with various edge cases."""
|
||||
# Create period
|
||||
period_id = period_manager.create_period(
|
||||
period_start=date(2025, 3, 1),
|
||||
period_end=date(2025, 3, 31)
|
||||
)
|
||||
|
||||
cost_manager = CostItemManager(temp_db)
|
||||
infra_cat = cost_manager.get_category_by_name('Infrastructure')
|
||||
|
||||
# Item that starts before period and ends during period
|
||||
item1 = CostItem(
|
||||
category_id=infra_cat['id'],
|
||||
name='Ending Item',
|
||||
cost_type='monthly',
|
||||
amount_eur=Decimal('20.00'),
|
||||
starting_from_date=date(2025, 1, 1),
|
||||
ending_date=date(2025, 3, 15)
|
||||
)
|
||||
|
||||
# Item that starts after period
|
||||
item2 = CostItem(
|
||||
category_id=infra_cat['id'],
|
||||
name='Future Item',
|
||||
cost_type='monthly',
|
||||
amount_eur=Decimal('30.00'),
|
||||
starting_from_date=date(2025, 4, 1)
|
||||
)
|
||||
|
||||
# One-time item outside period
|
||||
item3 = CostItem(
|
||||
category_id=infra_cat['id'],
|
||||
name='Past One-time',
|
||||
cost_type='one_time',
|
||||
amount_eur=Decimal('100.00'),
|
||||
starting_from_date=date(2025, 2, 15)
|
||||
)
|
||||
|
||||
cost_manager.create_cost_item(item1)
|
||||
cost_manager.create_cost_item(item2)
|
||||
cost_manager.create_cost_item(item3)
|
||||
|
||||
# Calculate costs
|
||||
calculation = period_manager.calculate_period_costs(period_id)
|
||||
|
||||
# Only item1 should be included (ends during period)
|
||||
assert calculation['monthly_costs'] == 20.0
|
||||
assert calculation['one_time_costs'] == 0.0
|
||||
assert calculation['total_costs'] == 20.0
|
||||
|
||||
def test_error_handling_database_errors(self, period_manager):
|
||||
"""Test error handling for database-related issues."""
|
||||
# Test with invalid period ID
|
||||
with pytest.raises(ValueError, match="Period #-1 not found"):
|
||||
period_manager.calculate_period_costs(-1)
|
||||
|
||||
# Test getting non-existent period
|
||||
result = period_manager.get_period_by_id(99999)
|
||||
assert result is None
|
||||
0
markitect/graphql/tests/__init__.py
Normal file
0
markitect/graphql/tests/__init__.py
Normal file
797
markitect/graphql/tests/test_issue_10_graphql_mutations.py
Normal file
797
markitect/graphql/tests/test_issue_10_graphql_mutations.py
Normal file
@@ -0,0 +1,797 @@
|
||||
"""
|
||||
Comprehensive tests for GraphQL mutations (Issue #10).
|
||||
|
||||
Tests all aspects of the GraphQL write interface including:
|
||||
- Mutation schema validation
|
||||
- Markdown file CRUD operations
|
||||
- Schema CRUD operations
|
||||
- Error handling
|
||||
- CLI integration
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import json
|
||||
import sqlite3
|
||||
import tempfile
|
||||
import os
|
||||
from datetime import datetime
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from markitect.graphql.schema import schema
|
||||
from markitect.database import DatabaseManager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_db_path():
|
||||
"""Create temporary database for testing."""
|
||||
with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as f:
|
||||
db_path = f.name
|
||||
|
||||
# Initialize database with test data
|
||||
db_manager = DatabaseManager(db_path)
|
||||
db_manager.initialize_database()
|
||||
|
||||
yield db_path
|
||||
|
||||
# Cleanup
|
||||
os.unlink(db_path)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def populated_db_path():
|
||||
"""Create temporary database with some test data."""
|
||||
with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as f:
|
||||
db_path = f.name
|
||||
|
||||
# Initialize database with test data
|
||||
db_manager = DatabaseManager(db_path)
|
||||
db_manager.initialize_database()
|
||||
|
||||
# Add sample data
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Sample markdown file
|
||||
cursor.execute("""
|
||||
INSERT INTO markdown_files (filename, content, front_matter, created_at)
|
||||
VALUES (?, ?, ?, ?)
|
||||
""", (
|
||||
'existing.md',
|
||||
'# Existing Document\n\nThis document already exists.',
|
||||
'{"title": "Existing Document"}',
|
||||
datetime.now().isoformat()
|
||||
))
|
||||
|
||||
# Sample schema
|
||||
cursor.execute("""
|
||||
INSERT INTO schemas (filename, title, description, schema_content, created_at)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
""", (
|
||||
'existing-schema.json',
|
||||
'Existing Schema',
|
||||
'A schema that already exists',
|
||||
'{"type": "object", "properties": {"name": {"type": "string"}}}',
|
||||
datetime.now().isoformat()
|
||||
))
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
yield db_path
|
||||
|
||||
# Cleanup
|
||||
os.unlink(db_path)
|
||||
|
||||
|
||||
class TestGraphQLMutationSchema:
|
||||
"""Test GraphQL mutation schema definition and validation."""
|
||||
|
||||
def test_schema_has_mutations(self):
|
||||
"""Test that the GraphQL schema has mutations."""
|
||||
result = schema.execute('''
|
||||
{
|
||||
__schema {
|
||||
mutationType {
|
||||
name
|
||||
fields {
|
||||
name
|
||||
description
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
''')
|
||||
|
||||
assert result.errors is None
|
||||
mutation_type = result.data['__schema']['mutationType']
|
||||
assert mutation_type is not None
|
||||
assert mutation_type['name'] == 'Mutation'
|
||||
|
||||
field_names = [field['name'] for field in mutation_type['fields']]
|
||||
assert 'addMarkdownFile' in field_names
|
||||
assert 'updateMarkdownFile' in field_names
|
||||
assert 'addSchema' in field_names
|
||||
assert 'updateSchema' in field_names
|
||||
assert 'deleteSchema' in field_names
|
||||
|
||||
def test_add_markdown_file_mutation_signature(self):
|
||||
"""Test addMarkdownFile mutation has correct signature."""
|
||||
result = schema.execute('''
|
||||
{
|
||||
__schema {
|
||||
mutationType {
|
||||
fields {
|
||||
name
|
||||
args {
|
||||
name
|
||||
type {
|
||||
name
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
''')
|
||||
|
||||
mutation_fields = result.data['__schema']['mutationType']['fields']
|
||||
add_file_field = next(f for f in mutation_fields if f['name'] == 'addMarkdownFile')
|
||||
|
||||
arg_names = [arg['name'] for arg in add_file_field['args']]
|
||||
assert 'filename' in arg_names
|
||||
assert 'content' in arg_names
|
||||
|
||||
def test_mutation_payload_types(self):
|
||||
"""Test that mutation payload types have correct structure."""
|
||||
result = schema.execute('''
|
||||
{
|
||||
__schema {
|
||||
types {
|
||||
name
|
||||
fields {
|
||||
name
|
||||
type {
|
||||
name
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
''')
|
||||
|
||||
types = {t['name']: t for t in result.data['__schema']['types']}
|
||||
|
||||
# Check AddMarkdownFilePayload
|
||||
payload = types.get('AddMarkdownFilePayload')
|
||||
assert payload is not None
|
||||
field_names = [f['name'] for f in payload['fields']]
|
||||
assert 'markdownFile' in field_names
|
||||
assert 'success' in field_names
|
||||
assert 'errors' in field_names
|
||||
|
||||
|
||||
class TestMarkdownFileMutations:
|
||||
"""Test markdown file CRUD mutations."""
|
||||
|
||||
def test_add_markdown_file_success(self, temp_db_path):
|
||||
"""Test successful markdown file creation."""
|
||||
with patch('markitect.graphql.resolvers.get_default_database_path', return_value=temp_db_path):
|
||||
mutation = '''
|
||||
mutation {
|
||||
addMarkdownFile(
|
||||
filename: "new-file.md"
|
||||
content: "# New File\\n\\nThis is new content."
|
||||
) {
|
||||
success
|
||||
markdownFile {
|
||||
id
|
||||
filename
|
||||
content
|
||||
wordCount
|
||||
}
|
||||
errors
|
||||
}
|
||||
}
|
||||
'''
|
||||
|
||||
result = schema.execute(mutation)
|
||||
|
||||
assert result.errors is None
|
||||
data = result.data['addMarkdownFile']
|
||||
assert data['success'] is True
|
||||
assert len(data['errors']) == 0
|
||||
assert data['markdownFile'] is not None
|
||||
assert data['markdownFile']['filename'] == 'new-file.md'
|
||||
assert 'New File' in data['markdownFile']['content']
|
||||
assert data['markdownFile']['wordCount'] > 0
|
||||
|
||||
def test_add_markdown_file_with_front_matter(self, temp_db_path):
|
||||
"""Test markdown file creation with front matter."""
|
||||
with patch('markitect.graphql.resolvers.get_default_database_path', return_value=temp_db_path):
|
||||
content_with_frontmatter = '''---
|
||||
title: "Test Document"
|
||||
author: "Test Author"
|
||||
tags: ["test", "markdown"]
|
||||
---
|
||||
|
||||
# Test Document
|
||||
|
||||
This is a test document with front matter.
|
||||
'''
|
||||
|
||||
mutation = '''
|
||||
mutation {
|
||||
addMarkdownFile(
|
||||
filename: "with-frontmatter.md"
|
||||
content: "%s"
|
||||
) {
|
||||
success
|
||||
markdownFile {
|
||||
id
|
||||
filename
|
||||
hasFrontMatter
|
||||
frontMatter {
|
||||
key
|
||||
value
|
||||
}
|
||||
}
|
||||
errors
|
||||
}
|
||||
}
|
||||
''' % content_with_frontmatter.replace('\n', '\\n').replace('"', '\\"')
|
||||
|
||||
result = schema.execute(mutation)
|
||||
|
||||
assert result.errors is None
|
||||
data = result.data['addMarkdownFile']
|
||||
assert data['success'] is True
|
||||
assert data['markdownFile']['hasFrontMatter'] is True
|
||||
front_matter_keys = [fm['key'] for fm in data['markdownFile']['frontMatter']]
|
||||
assert 'title' in front_matter_keys
|
||||
assert 'author' in front_matter_keys
|
||||
|
||||
def test_add_markdown_file_duplicate_filename(self, populated_db_path):
|
||||
"""Test adding a file with duplicate filename (should succeed as update)."""
|
||||
with patch('markitect.graphql.resolvers.get_default_database_path', return_value=populated_db_path):
|
||||
mutation = '''
|
||||
mutation {
|
||||
addMarkdownFile(
|
||||
filename: "existing.md"
|
||||
content: "# Updated Content\\n\\nThis content replaces the existing."
|
||||
) {
|
||||
success
|
||||
markdownFile {
|
||||
filename
|
||||
content
|
||||
}
|
||||
errors
|
||||
}
|
||||
}
|
||||
'''
|
||||
|
||||
result = schema.execute(mutation)
|
||||
|
||||
assert result.errors is None
|
||||
data = result.data['addMarkdownFile']
|
||||
assert data['success'] is True
|
||||
assert 'Updated Content' in data['markdownFile']['content']
|
||||
|
||||
def test_update_markdown_file_success(self, populated_db_path):
|
||||
"""Test successful markdown file update."""
|
||||
with patch('markitect.graphql.resolvers.get_default_database_path', return_value=populated_db_path):
|
||||
mutation = '''
|
||||
mutation {
|
||||
updateMarkdownFile(
|
||||
id: 1
|
||||
content: "# Updated Title\\n\\nThis content has been updated."
|
||||
) {
|
||||
success
|
||||
markdownFile {
|
||||
id
|
||||
content
|
||||
wordCount
|
||||
}
|
||||
errors
|
||||
}
|
||||
}
|
||||
'''
|
||||
|
||||
result = schema.execute(mutation)
|
||||
|
||||
assert result.errors is None
|
||||
data = result.data['updateMarkdownFile']
|
||||
assert data['success'] is True
|
||||
assert len(data['errors']) == 0
|
||||
assert 'Updated Title' in data['markdownFile']['content']
|
||||
|
||||
def test_update_markdown_file_not_found(self, temp_db_path):
|
||||
"""Test updating non-existent markdown file."""
|
||||
with patch('markitect.graphql.resolvers.get_default_database_path', return_value=temp_db_path):
|
||||
mutation = '''
|
||||
mutation {
|
||||
updateMarkdownFile(
|
||||
id: 999
|
||||
content: "# This should fail"
|
||||
) {
|
||||
success
|
||||
markdownFile {
|
||||
id
|
||||
}
|
||||
errors
|
||||
}
|
||||
}
|
||||
'''
|
||||
|
||||
result = schema.execute(mutation)
|
||||
|
||||
assert result.errors is None
|
||||
data = result.data['updateMarkdownFile']
|
||||
assert data['success'] is False
|
||||
assert data['markdownFile'] is None
|
||||
assert len(data['errors']) > 0
|
||||
assert 'not found' in data['errors'][0].lower()
|
||||
|
||||
def test_update_markdown_file_no_content(self, populated_db_path):
|
||||
"""Test updating markdown file without providing content."""
|
||||
with patch('markitect.graphql.resolvers.get_default_database_path', return_value=populated_db_path):
|
||||
mutation = '''
|
||||
mutation {
|
||||
updateMarkdownFile(id: 1) {
|
||||
success
|
||||
errors
|
||||
}
|
||||
}
|
||||
'''
|
||||
|
||||
result = schema.execute(mutation)
|
||||
|
||||
assert result.errors is None
|
||||
data = result.data['updateMarkdownFile']
|
||||
assert data['success'] is False
|
||||
assert 'required' in data['errors'][0].lower()
|
||||
|
||||
|
||||
class TestSchemaMutations:
|
||||
"""Test JSON schema CRUD mutations."""
|
||||
|
||||
def test_add_schema_success(self, temp_db_path):
|
||||
"""Test successful schema creation."""
|
||||
with patch('markitect.graphql.resolvers.get_default_database_path', return_value=temp_db_path):
|
||||
schema_content = {
|
||||
"type": "object",
|
||||
"title": "User Schema",
|
||||
"description": "Schema for user objects",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"age": {"type": "integer", "minimum": 0}
|
||||
},
|
||||
"required": ["name"]
|
||||
}
|
||||
|
||||
mutation = '''
|
||||
mutation {
|
||||
addSchema(
|
||||
filename: "user-schema.json"
|
||||
schemaContent: "%s"
|
||||
) {
|
||||
success
|
||||
schema {
|
||||
id
|
||||
filename
|
||||
title
|
||||
description
|
||||
propertyCount
|
||||
}
|
||||
errors
|
||||
}
|
||||
}
|
||||
''' % json.dumps(schema_content).replace('"', '\\"')
|
||||
|
||||
result = schema.execute(mutation)
|
||||
|
||||
assert result.errors is None
|
||||
data = result.data['addSchema']
|
||||
assert data['success'] is True
|
||||
assert len(data['errors']) == 0
|
||||
assert data['schema']['filename'] == 'user-schema.json'
|
||||
assert data['schema']['title'] == 'User Schema'
|
||||
assert data['schema']['propertyCount'] == 2
|
||||
|
||||
def test_add_schema_invalid_json(self, temp_db_path):
|
||||
"""Test adding schema with invalid JSON."""
|
||||
with patch('markitect.graphql.resolvers.get_default_database_path', return_value=temp_db_path):
|
||||
mutation = '''
|
||||
mutation {
|
||||
addSchema(
|
||||
filename: "invalid-schema.json"
|
||||
schemaContent: "{ invalid json }"
|
||||
) {
|
||||
success
|
||||
schema {
|
||||
id
|
||||
}
|
||||
errors
|
||||
}
|
||||
}
|
||||
'''
|
||||
|
||||
result = schema.execute(mutation)
|
||||
|
||||
# GraphQL should reject invalid JSON at the schema validation level
|
||||
assert result.errors is not None
|
||||
assert len(result.errors) > 0
|
||||
assert "Badly formed JSONString" in str(result.errors[0])
|
||||
|
||||
def test_update_schema_success(self, populated_db_path):
|
||||
"""Test successful schema update."""
|
||||
with patch('markitect.graphql.resolvers.get_default_database_path', return_value=populated_db_path):
|
||||
new_schema = {
|
||||
"type": "object",
|
||||
"title": "Updated Schema",
|
||||
"properties": {
|
||||
"name": {"type": "string"},
|
||||
"email": {"type": "string", "format": "email"}
|
||||
}
|
||||
}
|
||||
|
||||
mutation = '''
|
||||
mutation {
|
||||
updateSchema(
|
||||
id: 1
|
||||
schemaContent: "%s"
|
||||
) {
|
||||
success
|
||||
schema {
|
||||
title
|
||||
propertyCount
|
||||
}
|
||||
errors
|
||||
}
|
||||
}
|
||||
''' % json.dumps(new_schema).replace('"', '\\"')
|
||||
|
||||
result = schema.execute(mutation)
|
||||
|
||||
assert result.errors is None
|
||||
data = result.data['updateSchema']
|
||||
assert data['success'] is True
|
||||
assert data['schema']['title'] == 'Updated Schema'
|
||||
assert data['schema']['propertyCount'] == 2
|
||||
|
||||
def test_update_schema_not_found(self, temp_db_path):
|
||||
"""Test updating non-existent schema."""
|
||||
with patch('markitect.graphql.resolvers.get_default_database_path', return_value=temp_db_path):
|
||||
mutation = '''
|
||||
mutation {
|
||||
updateSchema(
|
||||
id: 999
|
||||
schemaContent: "{\\"type\\": \\"object\\"}"
|
||||
) {
|
||||
success
|
||||
errors
|
||||
}
|
||||
}
|
||||
'''
|
||||
|
||||
result = schema.execute(mutation)
|
||||
|
||||
assert result.errors is None
|
||||
data = result.data['updateSchema']
|
||||
assert data['success'] is False
|
||||
assert 'not found' in data['errors'][0].lower()
|
||||
|
||||
def test_delete_schema_success(self, populated_db_path):
|
||||
"""Test successful schema deletion."""
|
||||
with patch('markitect.graphql.resolvers.get_default_database_path', return_value=populated_db_path):
|
||||
mutation = '''
|
||||
mutation {
|
||||
deleteSchema(filename: "existing-schema.json") {
|
||||
success
|
||||
deletedFilename
|
||||
errors
|
||||
}
|
||||
}
|
||||
'''
|
||||
|
||||
result = schema.execute(mutation)
|
||||
|
||||
assert result.errors is None
|
||||
data = result.data['deleteSchema']
|
||||
assert data['success'] is True
|
||||
assert data['deletedFilename'] == 'existing-schema.json'
|
||||
assert len(data['errors']) == 0
|
||||
|
||||
def test_delete_schema_not_found(self, temp_db_path):
|
||||
"""Test deleting non-existent schema."""
|
||||
with patch('markitect.graphql.resolvers.get_default_database_path', return_value=temp_db_path):
|
||||
mutation = '''
|
||||
mutation {
|
||||
deleteSchema(filename: "nonexistent.json") {
|
||||
success
|
||||
deletedFilename
|
||||
errors
|
||||
}
|
||||
}
|
||||
'''
|
||||
|
||||
result = schema.execute(mutation)
|
||||
|
||||
assert result.errors is None
|
||||
data = result.data['deleteSchema']
|
||||
assert data['success'] is False
|
||||
assert data['deletedFilename'] is None
|
||||
|
||||
|
||||
class TestMutationErrorHandling:
|
||||
"""Test error handling in mutations."""
|
||||
|
||||
def test_database_error_handling(self, temp_db_path):
|
||||
"""Test mutation behavior when database is unavailable."""
|
||||
# Use a non-existent database path
|
||||
with patch('markitect.graphql.resolvers.get_default_database_path', return_value='/nonexistent/path.db'):
|
||||
mutation = '''
|
||||
mutation {
|
||||
addMarkdownFile(
|
||||
filename: "test.md"
|
||||
content: "# Test"
|
||||
) {
|
||||
success
|
||||
errors
|
||||
}
|
||||
}
|
||||
'''
|
||||
|
||||
result = schema.execute(mutation)
|
||||
|
||||
assert result.errors is None
|
||||
data = result.data['addMarkdownFile']
|
||||
assert data['success'] is False
|
||||
assert len(data['errors']) > 0
|
||||
|
||||
def test_invalid_mutation_syntax(self):
|
||||
"""Test handling of invalid mutation syntax."""
|
||||
mutation = '''
|
||||
mutation {
|
||||
addMarkdownFile(filename: "test.md") {
|
||||
success
|
||||
}
|
||||
}
|
||||
'''
|
||||
|
||||
result = schema.execute(mutation)
|
||||
|
||||
# Should have errors due to missing required 'content' argument
|
||||
assert result.errors is not None
|
||||
|
||||
def test_missing_required_arguments(self):
|
||||
"""Test mutations with missing required arguments."""
|
||||
mutation = '''
|
||||
mutation {
|
||||
addSchema(filename: "test.json") {
|
||||
success
|
||||
errors
|
||||
}
|
||||
}
|
||||
'''
|
||||
|
||||
result = schema.execute(mutation)
|
||||
|
||||
# Should have errors due to missing required 'schemaContent' argument
|
||||
assert result.errors is not None
|
||||
|
||||
|
||||
class TestMutationIntegration:
|
||||
"""Test full integration of mutations with database."""
|
||||
|
||||
def test_crud_workflow(self, temp_db_path):
|
||||
"""Test complete CRUD workflow for markdown files."""
|
||||
with patch('markitect.graphql.resolvers.get_default_database_path', return_value=temp_db_path):
|
||||
# 1. Create a file
|
||||
create_mutation = '''
|
||||
mutation {
|
||||
addMarkdownFile(
|
||||
filename: "workflow-test.md"
|
||||
content: "# Original Content\\n\\nOriginal text."
|
||||
) {
|
||||
success
|
||||
markdownFile {
|
||||
id
|
||||
filename
|
||||
content
|
||||
}
|
||||
}
|
||||
}
|
||||
'''
|
||||
|
||||
result = schema.execute(create_mutation)
|
||||
assert result.data['addMarkdownFile']['success'] is True
|
||||
file_id = result.data['addMarkdownFile']['markdownFile']['id']
|
||||
|
||||
# 2. Update the file
|
||||
update_mutation = '''
|
||||
mutation {
|
||||
updateMarkdownFile(
|
||||
id: %d
|
||||
content: "# Updated Content\\n\\nUpdated text."
|
||||
) {
|
||||
success
|
||||
markdownFile {
|
||||
content
|
||||
}
|
||||
}
|
||||
}
|
||||
''' % file_id
|
||||
|
||||
result = schema.execute(update_mutation)
|
||||
assert result.data['updateMarkdownFile']['success'] is True
|
||||
assert 'Updated Content' in result.data['updateMarkdownFile']['markdownFile']['content']
|
||||
|
||||
def test_schema_crud_workflow(self, temp_db_path):
|
||||
"""Test complete CRUD workflow for schemas."""
|
||||
with patch('markitect.graphql.resolvers.get_default_database_path', return_value=temp_db_path):
|
||||
# 1. Create a schema
|
||||
create_mutation = '''
|
||||
mutation {
|
||||
addSchema(
|
||||
filename: "workflow-schema.json"
|
||||
schemaContent: "{\\"type\\": \\"object\\", \\"title\\": \\"Original\\"}"
|
||||
) {
|
||||
success
|
||||
schema {
|
||||
id
|
||||
title
|
||||
}
|
||||
}
|
||||
}
|
||||
'''
|
||||
|
||||
result = schema.execute(create_mutation)
|
||||
assert result.data['addSchema']['success'] is True
|
||||
schema_id = result.data['addSchema']['schema']['id']
|
||||
|
||||
# 2. Update the schema
|
||||
update_mutation = '''
|
||||
mutation {
|
||||
updateSchema(
|
||||
id: %d
|
||||
schemaContent: "{\\"type\\": \\"object\\", \\"title\\": \\"Updated\\"}"
|
||||
) {
|
||||
success
|
||||
schema {
|
||||
title
|
||||
}
|
||||
}
|
||||
}
|
||||
''' % schema_id
|
||||
|
||||
result = schema.execute(update_mutation)
|
||||
assert result.data['updateSchema']['success'] is True
|
||||
assert result.data['updateSchema']['schema']['title'] == 'Updated'
|
||||
|
||||
# 3. Delete the schema
|
||||
delete_mutation = '''
|
||||
mutation {
|
||||
deleteSchema(filename: "workflow-schema.json") {
|
||||
success
|
||||
deletedFilename
|
||||
}
|
||||
}
|
||||
'''
|
||||
|
||||
result = schema.execute(delete_mutation)
|
||||
assert result.data['deleteSchema']['success'] is True
|
||||
assert result.data['deleteSchema']['deletedFilename'] == 'workflow-schema.json'
|
||||
|
||||
|
||||
class TestMutationCLI:
|
||||
"""Test CLI integration for mutations."""
|
||||
|
||||
def test_graphql_mutate_command_available(self):
|
||||
"""Test that graphql-mutate command is available."""
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
result = subprocess.run(
|
||||
[sys.executable, "-m", "markitect.cli", "graphql-mutate", "--help"],
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
|
||||
assert result.returncode == 0
|
||||
assert "Execute GraphQL mutations" in result.stdout
|
||||
assert "--local" in result.stdout
|
||||
assert "--variables" in result.stdout
|
||||
|
||||
def test_mutation_examples_in_help(self):
|
||||
"""Test that mutation examples are included in help."""
|
||||
import subprocess
|
||||
import sys
|
||||
|
||||
result = subprocess.run(
|
||||
[sys.executable, "-m", "markitect.cli", "graphql-examples"],
|
||||
capture_output=True,
|
||||
text=True
|
||||
)
|
||||
|
||||
assert result.returncode == 0
|
||||
assert "Mutation Examples" in result.stdout
|
||||
assert "addMarkdownFile" in result.stdout
|
||||
assert "updateMarkdownFile" in result.stdout
|
||||
assert "addSchema" in result.stdout
|
||||
assert "deleteSchema" in result.stdout
|
||||
|
||||
|
||||
class TestMutationPayloads:
|
||||
"""Test mutation payload structures."""
|
||||
|
||||
def test_add_markdown_file_payload_structure(self, temp_db_path):
|
||||
"""Test AddMarkdownFilePayload has correct structure."""
|
||||
with patch('markitect.graphql.resolvers.get_default_database_path', return_value=temp_db_path):
|
||||
mutation = '''
|
||||
mutation {
|
||||
addMarkdownFile(
|
||||
filename: "payload-test.md"
|
||||
content: "# Payload Test"
|
||||
) {
|
||||
success
|
||||
markdownFile {
|
||||
id
|
||||
filename
|
||||
content
|
||||
wordCount
|
||||
lineCount
|
||||
hasFrontMatter
|
||||
createdAt
|
||||
}
|
||||
errors
|
||||
}
|
||||
}
|
||||
'''
|
||||
|
||||
result = schema.execute(mutation)
|
||||
|
||||
assert result.errors is None
|
||||
payload = result.data['addMarkdownFile']
|
||||
|
||||
# Check payload structure
|
||||
assert isinstance(payload['success'], bool)
|
||||
assert isinstance(payload['errors'], list)
|
||||
|
||||
if payload['success']:
|
||||
md_file = payload['markdownFile']
|
||||
assert md_file is not None
|
||||
assert isinstance(md_file['id'], int)
|
||||
assert isinstance(md_file['filename'], str)
|
||||
assert isinstance(md_file['wordCount'], int)
|
||||
assert isinstance(md_file['lineCount'], int)
|
||||
assert isinstance(md_file['hasFrontMatter'], bool)
|
||||
|
||||
def test_error_payload_structure(self, temp_db_path):
|
||||
"""Test error payloads have correct structure."""
|
||||
with patch('markitect.graphql.resolvers.get_default_database_path', return_value='/nonexistent/path.db'):
|
||||
mutation = '''
|
||||
mutation {
|
||||
addMarkdownFile(
|
||||
filename: "error-test.md"
|
||||
content: "# Error Test"
|
||||
) {
|
||||
success
|
||||
markdownFile {
|
||||
id
|
||||
}
|
||||
errors
|
||||
}
|
||||
}
|
||||
'''
|
||||
|
||||
result = schema.execute(mutation)
|
||||
|
||||
assert result.errors is None
|
||||
payload = result.data['addMarkdownFile']
|
||||
|
||||
assert payload['success'] is False
|
||||
assert payload['markdownFile'] is None
|
||||
assert isinstance(payload['errors'], list)
|
||||
assert len(payload['errors']) > 0
|
||||
assert all(isinstance(error, str) for error in payload['errors'])
|
||||
619
markitect/graphql/tests/test_issue_9_graphql_interface.py
Normal file
619
markitect/graphql/tests/test_issue_9_graphql_interface.py
Normal file
@@ -0,0 +1,619 @@
|
||||
"""
|
||||
Comprehensive tests for GraphQL interface (Issue #9).
|
||||
|
||||
Tests all aspects of the GraphQL read interface including:
|
||||
- Schema definition and validation
|
||||
- Resolver functionality
|
||||
- Server endpoints
|
||||
- CLI integration
|
||||
- Error handling
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import json
|
||||
import sqlite3
|
||||
import tempfile
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch, MagicMock
|
||||
import subprocess
|
||||
import sys
|
||||
import os
|
||||
from datetime import datetime
|
||||
|
||||
from markitect.graphql.schema import schema, MarkdownFile, Schema as SchemaType, AST, DatabaseStats
|
||||
from markitect.graphql.resolvers import Query, MarkiTectResolver, get_default_database_path
|
||||
from markitect.graphql.server import GraphQLServer, GraphQLClient
|
||||
from markitect.database import DatabaseManager
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def temp_db_path():
|
||||
"""Create temporary database for testing."""
|
||||
with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as f:
|
||||
db_path = f.name
|
||||
|
||||
# Initialize database with test data
|
||||
db_manager = DatabaseManager(db_path)
|
||||
db_manager.initialize_database()
|
||||
|
||||
# Add sample data
|
||||
conn = sqlite3.connect(db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
# Sample markdown file
|
||||
cursor.execute("""
|
||||
INSERT INTO markdown_files (filename, content, front_matter, created_at)
|
||||
VALUES (?, ?, ?, ?)
|
||||
""", (
|
||||
'test.md',
|
||||
'# Test Document\n\nThis is a test document with [a link](http://example.com).',
|
||||
'{"title": "Test Document", "author": "Test Author"}',
|
||||
datetime.now().isoformat()
|
||||
))
|
||||
|
||||
# Sample schema
|
||||
cursor.execute("""
|
||||
INSERT INTO schemas (filename, title, description, schema_content, created_at)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
""", (
|
||||
'test-schema.json',
|
||||
'Test Schema',
|
||||
'A test schema for testing',
|
||||
'{"type": "object", "properties": {"name": {"type": "string"}}}',
|
||||
datetime.now().isoformat()
|
||||
))
|
||||
|
||||
conn.commit()
|
||||
conn.close()
|
||||
|
||||
yield db_path
|
||||
|
||||
# Cleanup
|
||||
os.unlink(db_path)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def graphql_resolver(temp_db_path):
|
||||
"""Create GraphQL resolver with test database."""
|
||||
return MarkiTectResolver(temp_db_path)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def graphql_query(temp_db_path):
|
||||
"""Create GraphQL Query instance with test database."""
|
||||
with patch('markitect.graphql.resolvers.get_default_database_path', return_value=temp_db_path):
|
||||
return Query()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def flask_app(temp_db_path):
|
||||
"""Create Flask app for testing GraphQL server."""
|
||||
server = GraphQLServer(db_path=temp_db_path, enable_cors=True)
|
||||
app = server.create_app()
|
||||
app.config['TESTING'] = True
|
||||
return app
|
||||
|
||||
|
||||
class TestGraphQLSchema:
|
||||
"""Test GraphQL schema definition and validation."""
|
||||
|
||||
def test_schema_is_valid(self):
|
||||
"""Test that the GraphQL schema is valid."""
|
||||
assert schema is not None
|
||||
assert hasattr(schema, 'execute')
|
||||
|
||||
def test_schema_has_required_types(self):
|
||||
"""Test that schema contains all required types."""
|
||||
schema_str = str(schema)
|
||||
|
||||
# Check for main types
|
||||
assert 'MarkdownFile' in schema_str
|
||||
assert 'Schema' in schema_str
|
||||
assert 'AST' in schema_str
|
||||
assert 'DatabaseStats' in schema_str
|
||||
assert 'SearchResult' in schema_str
|
||||
|
||||
def test_query_type_fields(self):
|
||||
"""Test that Query type has all required fields."""
|
||||
schema_str = str(schema)
|
||||
|
||||
# Check for query fields
|
||||
assert 'markdownFile' in schema_str
|
||||
assert 'markdownFiles' in schema_str
|
||||
assert 'schema' in schema_str
|
||||
assert 'schemas' in schema_str
|
||||
assert 'ast' in schema_str
|
||||
assert 'search' in schema_str
|
||||
assert 'databaseStats' in schema_str
|
||||
assert 'astQuery' in schema_str
|
||||
|
||||
|
||||
class TestGraphQLResolvers:
|
||||
"""Test GraphQL resolver functionality."""
|
||||
|
||||
def test_resolver_initialization(self, temp_db_path):
|
||||
"""Test resolver initializes correctly."""
|
||||
resolver = MarkiTectResolver(temp_db_path)
|
||||
|
||||
assert resolver.db_path == temp_db_path
|
||||
assert resolver.db_manager is not None
|
||||
assert resolver.ast_service is not None
|
||||
|
||||
def test_get_connection(self, graphql_resolver):
|
||||
"""Test database connection method."""
|
||||
conn = graphql_resolver.get_connection()
|
||||
|
||||
assert conn is not None
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT 1")
|
||||
result = cursor.fetchone()
|
||||
assert result[0] == 1
|
||||
conn.close()
|
||||
|
||||
def test_row_to_dict(self, graphql_resolver):
|
||||
"""Test row to dictionary conversion."""
|
||||
conn = graphql_resolver.get_connection()
|
||||
cursor = conn.cursor()
|
||||
cursor.execute("SELECT 1 as test_col")
|
||||
row = cursor.fetchone()
|
||||
|
||||
result = graphql_resolver.row_to_dict(cursor, row)
|
||||
assert result == {'test_col': 1}
|
||||
conn.close()
|
||||
|
||||
def test_resolve_markdown_file_by_id(self, graphql_query):
|
||||
"""Test resolving markdown file by ID."""
|
||||
result = graphql_query.resolve_markdown_file(None, id=1)
|
||||
|
||||
assert result is not None
|
||||
assert isinstance(result, MarkdownFile)
|
||||
assert result.filename == 'test.md'
|
||||
assert 'Test Document' in result.content
|
||||
|
||||
def test_resolve_markdown_file_by_filename(self, graphql_query):
|
||||
"""Test resolving markdown file by filename."""
|
||||
result = graphql_query.resolve_markdown_file(None, filename='test.md')
|
||||
|
||||
assert result is not None
|
||||
assert isinstance(result, MarkdownFile)
|
||||
assert result.id == 1
|
||||
|
||||
def test_resolve_markdown_file_not_found(self, graphql_query):
|
||||
"""Test resolving non-existent markdown file."""
|
||||
result = graphql_query.resolve_markdown_file(None, id=999)
|
||||
assert result is None
|
||||
|
||||
result = graphql_query.resolve_markdown_file(None, filename='nonexistent.md')
|
||||
assert result is None
|
||||
|
||||
def test_resolve_schema_by_id(self, graphql_query):
|
||||
"""Test resolving schema by ID."""
|
||||
result = graphql_query.resolve_schema(None, id=1)
|
||||
|
||||
assert result is not None
|
||||
assert isinstance(result, SchemaType)
|
||||
assert result.title == 'Test Schema'
|
||||
|
||||
def test_resolve_markdown_files_list(self, graphql_query):
|
||||
"""Test resolving list of markdown files."""
|
||||
results = graphql_query.resolve_markdown_files(None, limit=10, offset=0)
|
||||
|
||||
assert isinstance(results, list)
|
||||
assert len(results) >= 1
|
||||
assert all(isinstance(f, MarkdownFile) for f in results)
|
||||
|
||||
def test_resolve_schemas_list(self, graphql_query):
|
||||
"""Test resolving list of schemas."""
|
||||
results = graphql_query.resolve_schemas(None, limit=10, offset=0)
|
||||
|
||||
assert isinstance(results, list)
|
||||
assert len(results) >= 1
|
||||
assert all(isinstance(s, SchemaType) for s in results)
|
||||
|
||||
def test_resolve_search_files(self, graphql_query):
|
||||
"""Test search functionality for files."""
|
||||
results = graphql_query.resolve_search(None, query="Test", type="file", limit=10)
|
||||
|
||||
assert isinstance(results, list)
|
||||
assert len(results) >= 1
|
||||
assert all(hasattr(r, 'type') and hasattr(r, 'score') for r in results)
|
||||
|
||||
def test_resolve_database_stats(self, graphql_query):
|
||||
"""Test database statistics resolver."""
|
||||
result = graphql_query.resolve_database_stats(None)
|
||||
|
||||
assert result is not None
|
||||
assert isinstance(result, DatabaseStats)
|
||||
assert result.total_files >= 1
|
||||
assert result.total_schemas >= 1
|
||||
assert result.total_size_bytes > 0
|
||||
|
||||
@patch('markitect.graphql.resolvers.Path.exists')
|
||||
def test_resolve_ast_file_not_found(self, mock_exists, graphql_query):
|
||||
"""Test AST resolution when file doesn't exist."""
|
||||
mock_exists.return_value = False
|
||||
|
||||
result = graphql_query.resolve_ast(None, filename='nonexistent.md')
|
||||
assert result is None
|
||||
|
||||
|
||||
class TestGraphQLServer:
|
||||
"""Test GraphQL server functionality."""
|
||||
|
||||
def test_server_initialization(self, temp_db_path):
|
||||
"""Test server initializes correctly."""
|
||||
server = GraphQLServer(db_path=temp_db_path, enable_cors=True)
|
||||
|
||||
assert server.db_path == temp_db_path
|
||||
assert server.enable_cors is True
|
||||
assert server.app is None
|
||||
|
||||
def test_server_initialization_without_flask(self):
|
||||
"""Test server initialization when Flask is not available."""
|
||||
with patch('markitect.graphql.server.FLASK_AVAILABLE', False):
|
||||
with pytest.raises(ImportError, match="Flask is required"):
|
||||
GraphQLServer()
|
||||
|
||||
def test_create_app(self, temp_db_path):
|
||||
"""Test Flask app creation."""
|
||||
server = GraphQLServer(db_path=temp_db_path)
|
||||
app = server.create_app()
|
||||
|
||||
assert app is not None
|
||||
assert server.app is app
|
||||
|
||||
def test_graphql_endpoint_post(self, flask_app):
|
||||
"""Test GraphQL POST endpoint."""
|
||||
with flask_app.test_client() as client:
|
||||
query = '{ databaseStats { totalFiles } }'
|
||||
response = client.post('/graphql',
|
||||
json={'query': query},
|
||||
content_type='application/json')
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.get_json()
|
||||
assert 'data' in data
|
||||
assert 'databaseStats' in data['data']
|
||||
|
||||
def test_graphql_endpoint_invalid_json(self, flask_app):
|
||||
"""Test GraphQL endpoint with invalid JSON."""
|
||||
with flask_app.test_client() as client:
|
||||
response = client.post('/graphql',
|
||||
data='invalid json',
|
||||
content_type='application/json')
|
||||
|
||||
# Flask returns 500 for malformed JSON, which is reasonable
|
||||
assert response.status_code in [400, 500]
|
||||
|
||||
def test_graphql_endpoint_no_query(self, flask_app):
|
||||
"""Test GraphQL endpoint without query."""
|
||||
with flask_app.test_client() as client:
|
||||
response = client.post('/graphql',
|
||||
json={},
|
||||
content_type='application/json')
|
||||
|
||||
assert response.status_code == 400
|
||||
data = response.get_json()
|
||||
assert 'error' in data
|
||||
|
||||
def test_graphql_playground(self, flask_app):
|
||||
"""Test GraphQL playground endpoint."""
|
||||
with flask_app.test_client() as client:
|
||||
response = client.get('/graphql')
|
||||
|
||||
assert response.status_code == 200
|
||||
assert 'GraphQL Playground' in response.get_data(as_text=True)
|
||||
|
||||
def test_schema_endpoint(self, flask_app):
|
||||
"""Test schema introspection endpoint."""
|
||||
with flask_app.test_client() as client:
|
||||
response = client.get('/schema')
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.get_json()
|
||||
assert 'schema' in data
|
||||
|
||||
def test_health_check_healthy(self, flask_app):
|
||||
"""Test health check endpoint when healthy."""
|
||||
with flask_app.test_client() as client:
|
||||
response = client.get('/health')
|
||||
|
||||
assert response.status_code == 200
|
||||
data = response.get_json()
|
||||
assert data['status'] == 'healthy'
|
||||
assert data['database'] == 'connected'
|
||||
|
||||
def test_health_check_unhealthy(self, temp_db_path):
|
||||
"""Test health check when database is unavailable."""
|
||||
# Use non-existent database path
|
||||
server = GraphQLServer(db_path='/nonexistent/path.db')
|
||||
app = server.create_app()
|
||||
|
||||
with app.test_client() as client:
|
||||
response = client.get('/health')
|
||||
|
||||
assert response.status_code == 500
|
||||
data = response.get_json()
|
||||
assert data['status'] == 'unhealthy'
|
||||
|
||||
|
||||
class TestGraphQLClient:
|
||||
"""Test GraphQL client functionality."""
|
||||
|
||||
def test_client_initialization(self):
|
||||
"""Test client initializes correctly."""
|
||||
client = GraphQLClient("http://localhost:5000/graphql")
|
||||
assert client.endpoint == "http://localhost:5000/graphql"
|
||||
|
||||
def test_client_default_endpoint(self):
|
||||
"""Test client uses default endpoint."""
|
||||
client = GraphQLClient()
|
||||
assert client.endpoint == "http://localhost:5000/graphql"
|
||||
|
||||
@patch('requests.post')
|
||||
def test_client_execute_query(self, mock_post):
|
||||
"""Test client query execution."""
|
||||
# Mock response
|
||||
mock_response = Mock()
|
||||
mock_response.json.return_value = {
|
||||
'data': {'databaseStats': {'totalFiles': 5}}
|
||||
}
|
||||
mock_post.return_value = mock_response
|
||||
|
||||
client = GraphQLClient()
|
||||
result = client.execute('{ databaseStats { totalFiles } }')
|
||||
|
||||
assert result['data']['databaseStats']['totalFiles'] == 5
|
||||
mock_post.assert_called_once()
|
||||
|
||||
def test_client_execute_local(self, temp_db_path):
|
||||
"""Test client local query execution."""
|
||||
with patch('markitect.graphql.resolvers.get_default_database_path', return_value=temp_db_path):
|
||||
client = GraphQLClient()
|
||||
result = client.execute_local('{ databaseStats { totalFiles } }', context={'db_path': temp_db_path})
|
||||
|
||||
assert result is not None
|
||||
assert 'data' in result
|
||||
# The databaseStats resolver might return None if db is empty, so let's be more flexible
|
||||
if result['data']['databaseStats'] is not None:
|
||||
assert result['data']['databaseStats']['totalFiles'] >= 0
|
||||
|
||||
def test_client_execute_without_requests(self):
|
||||
"""Test client execution when requests is not available."""
|
||||
import builtins
|
||||
original_import = builtins.__import__
|
||||
|
||||
def mock_import(name, *args, **kwargs):
|
||||
if name == 'requests':
|
||||
raise ImportError("No module named 'requests'")
|
||||
return original_import(name, *args, **kwargs)
|
||||
|
||||
with patch('builtins.__import__', side_effect=mock_import):
|
||||
client = GraphQLClient()
|
||||
|
||||
with pytest.raises(ImportError, match="requests is required"):
|
||||
client.execute('{ databaseStats { totalFiles } }')
|
||||
|
||||
|
||||
class TestGraphQLQueries:
|
||||
"""Test actual GraphQL query execution."""
|
||||
|
||||
def test_simple_database_stats_query(self, temp_db_path):
|
||||
"""Test simple database stats query."""
|
||||
with patch('markitect.graphql.resolvers.get_default_database_path', return_value=temp_db_path):
|
||||
query = """
|
||||
{
|
||||
databaseStats {
|
||||
totalFiles
|
||||
totalSchemas
|
||||
totalSizeBytes
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
result = schema.execute(query, context={'db_path': temp_db_path})
|
||||
|
||||
assert result.errors is None
|
||||
assert result.data is not None
|
||||
assert 'databaseStats' in result.data
|
||||
if result.data['databaseStats'] is not None:
|
||||
assert result.data['databaseStats']['totalFiles'] >= 1
|
||||
assert result.data['databaseStats']['totalSchemas'] >= 1
|
||||
|
||||
def test_markdown_file_query_with_computed_fields(self, temp_db_path):
|
||||
"""Test markdown file query with computed fields."""
|
||||
with patch('markitect.graphql.resolvers.get_default_database_path', return_value=temp_db_path):
|
||||
query = """
|
||||
{
|
||||
markdownFile(id: 1) {
|
||||
id
|
||||
filename
|
||||
content
|
||||
wordCount
|
||||
lineCount
|
||||
hasFrontMatter
|
||||
frontMatter {
|
||||
key
|
||||
value
|
||||
}
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
result = schema.execute(query, context={'db_path': temp_db_path})
|
||||
|
||||
assert result.errors is None
|
||||
assert result.data is not None
|
||||
data = result.data['markdownFile']
|
||||
if data is not None:
|
||||
assert data['id'] == 1
|
||||
assert data['filename'] == 'test.md'
|
||||
assert data['wordCount'] > 0
|
||||
assert data['lineCount'] > 0
|
||||
assert data['hasFrontMatter'] is True
|
||||
assert len(data['frontMatter']) > 0
|
||||
|
||||
def test_search_query(self, temp_db_path):
|
||||
"""Test search functionality."""
|
||||
with patch('markitect.graphql.resolvers.get_default_database_path', return_value=temp_db_path):
|
||||
query = """
|
||||
{
|
||||
search(query: "Test", type: "all", limit: 10) {
|
||||
type
|
||||
score
|
||||
file {
|
||||
filename
|
||||
}
|
||||
schema {
|
||||
title
|
||||
}
|
||||
highlight
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
result = schema.execute(query, context={'db_path': temp_db_path})
|
||||
|
||||
assert result.errors is None
|
||||
assert result.data is not None
|
||||
if result.data['search'] is not None:
|
||||
assert len(result.data['search']) >= 0
|
||||
|
||||
def test_pagination_query(self, temp_db_path):
|
||||
"""Test pagination in list queries."""
|
||||
with patch('markitect.graphql.resolvers.get_default_database_path', return_value=temp_db_path):
|
||||
query = """
|
||||
{
|
||||
markdownFiles(limit: 1, offset: 0) {
|
||||
id
|
||||
filename
|
||||
}
|
||||
}
|
||||
"""
|
||||
|
||||
result = schema.execute(query, context={'db_path': temp_db_path})
|
||||
|
||||
assert result.errors is None
|
||||
assert result.data is not None
|
||||
if result.data['markdownFiles'] is not None:
|
||||
assert len(result.data['markdownFiles']) <= 1
|
||||
|
||||
|
||||
@pytest.mark.e2e
|
||||
class TestGraphQLCLIIntegration:
|
||||
"""Test GraphQL CLI command integration."""
|
||||
|
||||
def test_graphql_schema_command(self, isolated_environment):
|
||||
"""Test graphql-schema CLI command."""
|
||||
result = subprocess.run(
|
||||
[sys.executable, "-m", "markitect.cli", "graphql-schema", "--local"],
|
||||
env=isolated_environment,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
cwd=Path.cwd()
|
||||
)
|
||||
|
||||
assert result.returncode == 0
|
||||
assert "type Query" in result.stdout
|
||||
|
||||
def test_graphql_query_command(self, isolated_environment):
|
||||
"""Test graphql-query CLI command."""
|
||||
query = "{ databaseStats { totalFiles } }"
|
||||
|
||||
result = subprocess.run(
|
||||
[sys.executable, "-m", "markitect.cli", "graphql-query", query, "--local"],
|
||||
env=isolated_environment,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
cwd=Path.cwd()
|
||||
)
|
||||
|
||||
assert result.returncode == 0
|
||||
# The database might be empty in test environment, so check for JSON structure
|
||||
assert "databaseStats" in result.stdout
|
||||
|
||||
def test_graphql_examples_command(self, isolated_environment):
|
||||
"""Test graphql-examples CLI command."""
|
||||
result = subprocess.run(
|
||||
[sys.executable, "-m", "markitect.cli", "graphql-examples"],
|
||||
env=isolated_environment,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
cwd=Path.cwd()
|
||||
)
|
||||
|
||||
assert result.returncode == 0
|
||||
assert "GraphQL Query Examples" in result.stdout
|
||||
assert "databaseStats" in result.stdout
|
||||
|
||||
@patch('markitect.graphql.server.GraphQLServer')
|
||||
def test_graphql_serve_command(self, mock_server_class, isolated_environment):
|
||||
"""Test graphql-serve CLI command."""
|
||||
mock_server = Mock()
|
||||
mock_server_class.return_value = mock_server
|
||||
|
||||
# We can't actually start the server in tests, so we just test command parsing
|
||||
result = subprocess.run(
|
||||
[sys.executable, "-m", "markitect.cli", "graphql-serve", "--help"],
|
||||
env=isolated_environment,
|
||||
capture_output=True,
|
||||
text=True,
|
||||
cwd=Path.cwd()
|
||||
)
|
||||
|
||||
assert result.returncode == 0
|
||||
assert "Start GraphQL server" in result.stdout
|
||||
|
||||
|
||||
class TestErrorHandling:
|
||||
"""Test error handling in GraphQL interface."""
|
||||
|
||||
def test_invalid_query_syntax(self, temp_db_path):
|
||||
"""Test handling of invalid GraphQL syntax."""
|
||||
with patch('markitect.graphql.resolvers.get_default_database_path', return_value=temp_db_path):
|
||||
query = "{ invalidSyntax }"
|
||||
|
||||
result = schema.execute(query)
|
||||
|
||||
assert result.errors is not None
|
||||
assert len(result.errors) > 0
|
||||
|
||||
def test_nonexistent_field_query(self, temp_db_path):
|
||||
"""Test querying nonexistent fields."""
|
||||
with patch('markitect.graphql.resolvers.get_default_database_path', return_value=temp_db_path):
|
||||
query = "{ nonexistentField }"
|
||||
|
||||
result = schema.execute(query)
|
||||
|
||||
assert result.errors is not None
|
||||
|
||||
def test_resolver_database_error(self, temp_db_path):
|
||||
"""Test resolver behavior when database is corrupted."""
|
||||
# Corrupt the database file
|
||||
with open(temp_db_path, 'w') as f:
|
||||
f.write("corrupted data")
|
||||
|
||||
with patch('markitect.graphql.resolvers.get_default_database_path', return_value=temp_db_path):
|
||||
query = "{ databaseStats { totalFiles } }"
|
||||
|
||||
result = schema.execute(query, context={'db_path': temp_db_path})
|
||||
|
||||
# Should handle database errors gracefully - either with errors or None data
|
||||
assert result.errors is not None or result.data['databaseStats'] is None
|
||||
|
||||
|
||||
class TestUtilityFunctions:
|
||||
"""Test utility functions in GraphQL module."""
|
||||
|
||||
def test_get_default_database_path_with_env(self):
|
||||
"""Test get_default_database_path with environment variable."""
|
||||
with patch.dict(os.environ, {'MARKITECT_DB': '/custom/path.db'}):
|
||||
path = get_default_database_path()
|
||||
assert path == '/custom/path.db'
|
||||
|
||||
def test_get_default_database_path_default(self):
|
||||
"""Test get_default_database_path with default location."""
|
||||
with patch.dict(os.environ, {}, clear=True):
|
||||
path = get_default_database_path()
|
||||
assert path.endswith('markitect.db')
|
||||
assert '.markitect' in path
|
||||
0
markitect/plugins/tests/__init__.py
Normal file
0
markitect/plugins/tests/__init__.py
Normal file
855
markitect/plugins/tests/test_issue_19_plugin_architecture.py
Normal file
855
markitect/plugins/tests/test_issue_19_plugin_architecture.py
Normal file
@@ -0,0 +1,855 @@
|
||||
"""
|
||||
Tests for Issue #19: Plugin Architecture and Extensions System
|
||||
|
||||
This module provides comprehensive tests for the MarkiTect plugin system
|
||||
including plugin discovery, loading, management, and CLI integration.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import json
|
||||
import tempfile
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import Mock, patch
|
||||
|
||||
from markitect.plugins import (
|
||||
PluginManager,
|
||||
BasePlugin,
|
||||
ProcessorPlugin,
|
||||
FormatterPlugin,
|
||||
PluginType,
|
||||
PluginMetadata,
|
||||
plugin_registry,
|
||||
register_plugin
|
||||
)
|
||||
from markitect.plugins.manager import PluginManager
|
||||
from markitect.plugins.registry import PluginRegistry
|
||||
|
||||
|
||||
class TestPluginArchitecture:
|
||||
"""Test suite for plugin architecture components."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test environment."""
|
||||
# Clear plugin registry for clean tests
|
||||
plugin_registry.cleanup_all()
|
||||
plugin_registry._plugins.clear()
|
||||
plugin_registry._instances.clear()
|
||||
plugin_registry._plugins_by_type.clear()
|
||||
|
||||
def teardown_method(self):
|
||||
"""Clean up after tests."""
|
||||
plugin_registry.cleanup_all()
|
||||
plugin_registry._plugins.clear()
|
||||
plugin_registry._instances.clear()
|
||||
plugin_registry._plugins_by_type.clear()
|
||||
|
||||
|
||||
class TestPluginBase:
|
||||
"""Test base plugin functionality."""
|
||||
|
||||
def test_plugin_metadata_creation(self):
|
||||
"""Test PluginMetadata creation and properties."""
|
||||
metadata = PluginMetadata(
|
||||
name="test_plugin",
|
||||
version="1.0.0",
|
||||
description="Test plugin",
|
||||
author="Test Author",
|
||||
plugin_type=PluginType.PROCESSOR,
|
||||
dependencies=["dep1", "dep2"],
|
||||
markitect_version=">=0.1.0"
|
||||
)
|
||||
|
||||
assert metadata.name == "test_plugin"
|
||||
assert metadata.version == "1.0.0"
|
||||
assert metadata.description == "Test plugin"
|
||||
assert metadata.author == "Test Author"
|
||||
assert metadata.plugin_type == PluginType.PROCESSOR
|
||||
assert metadata.dependencies == ["dep1", "dep2"]
|
||||
assert metadata.markitect_version == ">=0.1.0"
|
||||
|
||||
def test_base_plugin_initialization(self):
|
||||
"""Test BasePlugin initialization."""
|
||||
|
||||
class TestPlugin(BasePlugin):
|
||||
@property
|
||||
def metadata(self):
|
||||
return PluginMetadata(
|
||||
name="test",
|
||||
version="1.0.0",
|
||||
description="Test",
|
||||
plugin_type=PluginType.EXTENSION
|
||||
)
|
||||
|
||||
config = {"option1": "value1", "option2": "value2"}
|
||||
plugin = TestPlugin(config)
|
||||
|
||||
assert plugin.config == config
|
||||
assert not plugin.is_initialized
|
||||
|
||||
def test_plugin_initialization_lifecycle(self):
|
||||
"""Test plugin initialization and cleanup lifecycle."""
|
||||
|
||||
class TestPlugin(BasePlugin):
|
||||
def __init__(self, config=None):
|
||||
super().__init__(config)
|
||||
self.initialized = False
|
||||
self.cleaned_up = False
|
||||
|
||||
@property
|
||||
def metadata(self):
|
||||
return PluginMetadata(
|
||||
name="test",
|
||||
version="1.0.0",
|
||||
description="Test",
|
||||
plugin_type=PluginType.EXTENSION
|
||||
)
|
||||
|
||||
def _initialize(self):
|
||||
self.initialized = True
|
||||
|
||||
def cleanup(self):
|
||||
self.cleaned_up = True
|
||||
|
||||
plugin = TestPlugin()
|
||||
assert not plugin.initialized
|
||||
assert not plugin.is_initialized
|
||||
|
||||
# Test initialization
|
||||
result = plugin.initialize()
|
||||
assert result is True
|
||||
assert plugin.initialized
|
||||
assert plugin.is_initialized
|
||||
|
||||
# Test cleanup
|
||||
plugin.cleanup()
|
||||
assert plugin.cleaned_up
|
||||
|
||||
def test_plugin_initialization_failure(self):
|
||||
"""Test plugin initialization failure handling."""
|
||||
|
||||
class FailingPlugin(BasePlugin):
|
||||
@property
|
||||
def metadata(self):
|
||||
return PluginMetadata(
|
||||
name="failing",
|
||||
version="1.0.0",
|
||||
description="Failing plugin",
|
||||
plugin_type=PluginType.EXTENSION
|
||||
)
|
||||
|
||||
def _initialize(self):
|
||||
raise Exception("Initialization failed")
|
||||
|
||||
plugin = FailingPlugin()
|
||||
result = plugin.initialize()
|
||||
assert result is False
|
||||
assert not plugin.is_initialized
|
||||
|
||||
|
||||
class TestProcessorPlugin:
|
||||
"""Test processor plugin functionality."""
|
||||
|
||||
def test_processor_plugin_interface(self):
|
||||
"""Test processor plugin interface implementation."""
|
||||
|
||||
class TestProcessor(ProcessorPlugin):
|
||||
@property
|
||||
def metadata(self):
|
||||
return PluginMetadata(
|
||||
name="test_processor",
|
||||
version="1.0.0",
|
||||
description="Test processor",
|
||||
plugin_type=PluginType.PROCESSOR
|
||||
)
|
||||
|
||||
def process(self, content: str, **kwargs) -> str:
|
||||
return content.upper()
|
||||
|
||||
processor = TestProcessor()
|
||||
result = processor.process("hello world")
|
||||
assert result == "HELLO WORLD"
|
||||
|
||||
# Test default can_process implementation
|
||||
assert processor.can_process("any content")
|
||||
|
||||
def test_processor_plugin_with_options(self):
|
||||
"""Test processor plugin with processing options."""
|
||||
|
||||
class ConfigurableProcessor(ProcessorPlugin):
|
||||
@property
|
||||
def metadata(self):
|
||||
return PluginMetadata(
|
||||
name="configurable_processor",
|
||||
version="1.0.0",
|
||||
description="Configurable processor",
|
||||
plugin_type=PluginType.PROCESSOR
|
||||
)
|
||||
|
||||
def process(self, content: str, **kwargs) -> str:
|
||||
if kwargs.get('uppercase', False):
|
||||
content = content.upper()
|
||||
if kwargs.get('reverse', False):
|
||||
content = content[::-1]
|
||||
return content
|
||||
|
||||
processor = ConfigurableProcessor()
|
||||
|
||||
# Test with no options
|
||||
result = processor.process("hello")
|
||||
assert result == "hello"
|
||||
|
||||
# Test with uppercase option
|
||||
result = processor.process("hello", uppercase=True)
|
||||
assert result == "HELLO"
|
||||
|
||||
# Test with both options
|
||||
result = processor.process("hello", uppercase=True, reverse=True)
|
||||
assert result == "OLLEH"
|
||||
|
||||
|
||||
class TestFormatterPlugin:
|
||||
"""Test formatter plugin functionality."""
|
||||
|
||||
def test_formatter_plugin_interface(self):
|
||||
"""Test formatter plugin interface implementation."""
|
||||
|
||||
class TestFormatter(FormatterPlugin):
|
||||
@property
|
||||
def metadata(self):
|
||||
return PluginMetadata(
|
||||
name="test_formatter",
|
||||
version="1.0.0",
|
||||
description="Test formatter",
|
||||
plugin_type=PluginType.FORMATTER
|
||||
)
|
||||
|
||||
def format(self, data, **kwargs) -> str:
|
||||
return json.dumps(data, indent=kwargs.get('indent', 2))
|
||||
|
||||
def get_file_extension(self) -> str:
|
||||
return '.json'
|
||||
|
||||
formatter = TestFormatter()
|
||||
data = {"key": "value", "number": 42}
|
||||
|
||||
result = formatter.format(data)
|
||||
parsed = json.loads(result)
|
||||
assert parsed == data
|
||||
|
||||
extension = formatter.get_file_extension()
|
||||
assert extension == '.json'
|
||||
|
||||
|
||||
class TestPluginRegistry:
|
||||
"""Test plugin registry functionality."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test environment."""
|
||||
self.registry = PluginRegistry()
|
||||
|
||||
def test_plugin_registration(self):
|
||||
"""Test plugin registration."""
|
||||
|
||||
class TestPlugin(BasePlugin):
|
||||
@property
|
||||
def metadata(self):
|
||||
return PluginMetadata(
|
||||
name="test",
|
||||
version="1.0.0",
|
||||
description="Test",
|
||||
plugin_type=PluginType.EXTENSION
|
||||
)
|
||||
|
||||
# Test registration
|
||||
name = self.registry.register(TestPlugin)
|
||||
assert name == "TestPlugin"
|
||||
assert "TestPlugin" in self.registry._plugins
|
||||
|
||||
# Test registration with custom name
|
||||
custom_name = self.registry.register(TestPlugin, "custom_name")
|
||||
assert custom_name == "custom_name"
|
||||
assert "custom_name" in self.registry._plugins
|
||||
|
||||
def test_plugin_registration_duplicate_name(self):
|
||||
"""Test plugin registration with duplicate name."""
|
||||
|
||||
class TestPlugin(BasePlugin):
|
||||
@property
|
||||
def metadata(self):
|
||||
return PluginMetadata(
|
||||
name="test",
|
||||
version="1.0.0",
|
||||
description="Test",
|
||||
plugin_type=PluginType.EXTENSION
|
||||
)
|
||||
|
||||
self.registry.register(TestPlugin, "test_name")
|
||||
|
||||
# Should raise error for duplicate name
|
||||
with pytest.raises(ValueError, match="already registered"):
|
||||
self.registry.register(TestPlugin, "test_name")
|
||||
|
||||
def test_plugin_retrieval(self):
|
||||
"""Test plugin retrieval from registry."""
|
||||
|
||||
class TestPlugin(BasePlugin):
|
||||
@property
|
||||
def metadata(self):
|
||||
return PluginMetadata(
|
||||
name="test",
|
||||
version="1.0.0",
|
||||
description="Test",
|
||||
plugin_type=PluginType.EXTENSION
|
||||
)
|
||||
|
||||
self.registry.register(TestPlugin, "test_plugin")
|
||||
|
||||
# Test successful retrieval
|
||||
plugin = self.registry.get_plugin("test_plugin")
|
||||
assert plugin is not None
|
||||
assert isinstance(plugin, TestPlugin)
|
||||
|
||||
# Test non-existent plugin
|
||||
plugin = self.registry.get_plugin("non_existent")
|
||||
assert plugin is None
|
||||
|
||||
def test_plugin_unregistration(self):
|
||||
"""Test plugin unregistration."""
|
||||
|
||||
class TestPlugin(BasePlugin):
|
||||
@property
|
||||
def metadata(self):
|
||||
return PluginMetadata(
|
||||
name="test",
|
||||
version="1.0.0",
|
||||
description="Test",
|
||||
plugin_type=PluginType.EXTENSION
|
||||
)
|
||||
|
||||
self.registry.register(TestPlugin, "test_plugin")
|
||||
plugin = self.registry.get_plugin("test_plugin")
|
||||
assert plugin is not None
|
||||
|
||||
# Test unregistration
|
||||
result = self.registry.unregister("test_plugin")
|
||||
assert result is True
|
||||
|
||||
# Plugin should no longer be available
|
||||
plugin = self.registry.get_plugin("test_plugin")
|
||||
assert plugin is None
|
||||
|
||||
# Test unregistering non-existent plugin
|
||||
result = self.registry.unregister("non_existent")
|
||||
assert result is False
|
||||
|
||||
def test_plugins_by_type(self):
|
||||
"""Test retrieving plugins by type."""
|
||||
|
||||
class ProcessorPlugin1(ProcessorPlugin):
|
||||
@property
|
||||
def metadata(self):
|
||||
return PluginMetadata(
|
||||
name="processor1",
|
||||
version="1.0.0",
|
||||
description="Processor 1",
|
||||
plugin_type=PluginType.PROCESSOR
|
||||
)
|
||||
|
||||
def process(self, content, **kwargs):
|
||||
return content
|
||||
|
||||
class FormatterPlugin1(FormatterPlugin):
|
||||
@property
|
||||
def metadata(self):
|
||||
return PluginMetadata(
|
||||
name="formatter1",
|
||||
version="1.0.0",
|
||||
description="Formatter 1",
|
||||
plugin_type=PluginType.FORMATTER
|
||||
)
|
||||
|
||||
def format(self, data, **kwargs):
|
||||
return str(data)
|
||||
|
||||
def get_file_extension(self):
|
||||
return '.txt'
|
||||
|
||||
self.registry.register(ProcessorPlugin1, "processor1")
|
||||
self.registry.register(FormatterPlugin1, "formatter1")
|
||||
|
||||
# Test getting processors
|
||||
processors = self.registry.get_plugins_by_type(PluginType.PROCESSOR)
|
||||
assert "processor1" in processors
|
||||
assert "formatter1" not in processors
|
||||
|
||||
# Test getting formatters
|
||||
formatters = self.registry.get_plugins_by_type(PluginType.FORMATTER)
|
||||
assert "formatter1" in formatters
|
||||
assert "processor1" not in formatters
|
||||
|
||||
def test_list_plugins(self):
|
||||
"""Test listing all plugins with metadata."""
|
||||
|
||||
class TestPlugin(BasePlugin):
|
||||
@property
|
||||
def metadata(self):
|
||||
return PluginMetadata(
|
||||
name="test",
|
||||
version="1.0.0",
|
||||
description="Test plugin",
|
||||
author="Test Author",
|
||||
plugin_type=PluginType.EXTENSION
|
||||
)
|
||||
|
||||
self.registry.register(TestPlugin, "test_plugin")
|
||||
|
||||
plugins = self.registry.list_plugins()
|
||||
assert "test_plugin" in plugins
|
||||
|
||||
plugin_info = plugins["test_plugin"]
|
||||
assert plugin_info["name"] == "test"
|
||||
assert plugin_info["version"] == "1.0.0"
|
||||
assert plugin_info["description"] == "Test plugin"
|
||||
assert plugin_info["author"] == "Test Author"
|
||||
assert plugin_info["type"] == "extension"
|
||||
|
||||
|
||||
class TestPluginManager:
|
||||
"""Test plugin manager functionality."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test environment."""
|
||||
# Clear plugin registry
|
||||
plugin_registry.cleanup_all()
|
||||
plugin_registry._plugins.clear()
|
||||
plugin_registry._instances.clear()
|
||||
plugin_registry._plugins_by_type.clear()
|
||||
|
||||
def test_plugin_manager_initialization(self):
|
||||
"""Test plugin manager initialization."""
|
||||
manager = PluginManager()
|
||||
assert manager.config is not None
|
||||
assert isinstance(manager.plugin_directories, list)
|
||||
|
||||
def test_plugin_manager_with_config(self):
|
||||
"""Test plugin manager with custom configuration."""
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.yml', delete=False) as f:
|
||||
f.write("""
|
||||
plugin_directories:
|
||||
- "custom_plugins"
|
||||
auto_discover: false
|
||||
plugins:
|
||||
test_plugin:
|
||||
enabled: true
|
||||
""")
|
||||
config_path = f.name
|
||||
|
||||
try:
|
||||
manager = PluginManager(config_path)
|
||||
assert "custom_plugins" in manager.config.get('plugin_directories', [])
|
||||
assert manager.config.get('auto_discover') is False
|
||||
assert 'test_plugin' in manager.config.get('plugins', {})
|
||||
finally:
|
||||
os.unlink(config_path)
|
||||
|
||||
def test_plugin_discovery_empty(self):
|
||||
"""Test plugin discovery with no plugins."""
|
||||
manager = PluginManager()
|
||||
discovered = manager.discover_plugins()
|
||||
# Should be a dictionary (empty or with built-ins)
|
||||
assert isinstance(discovered, dict)
|
||||
|
||||
@patch('importlib.import_module')
|
||||
def test_load_plugin_success(self, mock_import):
|
||||
"""Test successful plugin loading."""
|
||||
|
||||
class TestPlugin(BasePlugin):
|
||||
@property
|
||||
def metadata(self):
|
||||
return PluginMetadata(
|
||||
name="test",
|
||||
version="1.0.0",
|
||||
description="Test",
|
||||
plugin_type=PluginType.EXTENSION
|
||||
)
|
||||
|
||||
# Mock module with plugin
|
||||
mock_module = Mock()
|
||||
mock_module.TestPlugin = TestPlugin
|
||||
mock_import.return_value = mock_module
|
||||
|
||||
manager = PluginManager()
|
||||
|
||||
# Manually add to discovered plugins
|
||||
manager._discovered_plugins = {
|
||||
"test_plugin": {
|
||||
"module_name": "test_module",
|
||||
"class_name": "TestPlugin"
|
||||
}
|
||||
}
|
||||
|
||||
plugin = manager.load_plugin("test_plugin")
|
||||
assert plugin is not None
|
||||
assert isinstance(plugin, TestPlugin)
|
||||
|
||||
def test_load_plugin_not_found(self):
|
||||
"""Test loading non-existent plugin."""
|
||||
manager = PluginManager()
|
||||
plugin = manager.load_plugin("non_existent_plugin")
|
||||
assert plugin is None
|
||||
|
||||
def test_get_plugins_by_type(self):
|
||||
"""Test getting plugins by type."""
|
||||
|
||||
class TestProcessor(ProcessorPlugin):
|
||||
@property
|
||||
def metadata(self):
|
||||
return PluginMetadata(
|
||||
name="test_processor",
|
||||
version="1.0.0",
|
||||
description="Test processor",
|
||||
plugin_type=PluginType.PROCESSOR
|
||||
)
|
||||
|
||||
def process(self, content, **kwargs):
|
||||
return content
|
||||
|
||||
# Register plugin directly
|
||||
plugin_registry.register(TestProcessor, "test_processor")
|
||||
|
||||
manager = PluginManager()
|
||||
processors = manager.get_plugins_by_type(PluginType.PROCESSOR)
|
||||
|
||||
# Should have at least our test processor
|
||||
assert len(processors) >= 1
|
||||
assert any(isinstance(p, TestProcessor) for p in processors)
|
||||
|
||||
|
||||
class TestPluginDecorator:
|
||||
"""Test plugin registration decorator."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test environment."""
|
||||
# Clear plugin registry
|
||||
plugin_registry.cleanup_all()
|
||||
plugin_registry._plugins.clear()
|
||||
plugin_registry._instances.clear()
|
||||
plugin_registry._plugins_by_type.clear()
|
||||
|
||||
def test_register_plugin_decorator(self):
|
||||
"""Test @register_plugin decorator."""
|
||||
|
||||
@register_plugin("decorated_plugin")
|
||||
class DecoratedPlugin(BasePlugin):
|
||||
@property
|
||||
def metadata(self):
|
||||
return PluginMetadata(
|
||||
name="decorated",
|
||||
version="1.0.0",
|
||||
description="Decorated plugin",
|
||||
plugin_type=PluginType.EXTENSION
|
||||
)
|
||||
|
||||
# Plugin should be automatically registered
|
||||
assert "decorated_plugin" in plugin_registry._plugins
|
||||
|
||||
# Should be able to retrieve it
|
||||
plugin = plugin_registry.get_plugin("decorated_plugin")
|
||||
assert plugin is not None
|
||||
assert isinstance(plugin, DecoratedPlugin)
|
||||
|
||||
def test_register_plugin_decorator_no_name(self):
|
||||
"""Test @register_plugin decorator without name."""
|
||||
|
||||
@register_plugin()
|
||||
class AutoNamedPlugin(BasePlugin):
|
||||
@property
|
||||
def metadata(self):
|
||||
return PluginMetadata(
|
||||
name="auto_named",
|
||||
version="1.0.0",
|
||||
description="Auto named plugin",
|
||||
plugin_type=PluginType.EXTENSION
|
||||
)
|
||||
|
||||
# Should use class name
|
||||
assert "AutoNamedPlugin" in plugin_registry._plugins
|
||||
|
||||
|
||||
class TestBuiltinPlugins:
|
||||
"""Test built-in plugins."""
|
||||
|
||||
def test_json_formatter_plugin(self):
|
||||
"""Test built-in JSON formatter plugin."""
|
||||
from markitect.plugins.builtin.formatters import JsonFormatter
|
||||
|
||||
formatter = JsonFormatter()
|
||||
assert formatter.metadata.plugin_type == PluginType.FORMATTER
|
||||
|
||||
data = {"key": "value", "number": 42}
|
||||
result = formatter.format(data)
|
||||
|
||||
parsed = json.loads(result)
|
||||
assert parsed == data
|
||||
|
||||
assert formatter.get_file_extension() == '.json'
|
||||
|
||||
def test_table_formatter_plugin(self):
|
||||
"""Test built-in table formatter plugin."""
|
||||
from markitect.plugins.builtin.formatters import TableFormatter
|
||||
|
||||
formatter = TableFormatter()
|
||||
assert formatter.metadata.plugin_type == PluginType.FORMATTER
|
||||
|
||||
# Test with list of dictionaries
|
||||
data = [
|
||||
{"name": "John", "age": 30},
|
||||
{"name": "Jane", "age": 25}
|
||||
]
|
||||
|
||||
result = formatter.format(data)
|
||||
assert "John" in result
|
||||
assert "Jane" in result
|
||||
assert "name" in result
|
||||
assert "age" in result
|
||||
|
||||
assert formatter.get_file_extension() == '.txt'
|
||||
|
||||
def test_markdown_processor_plugin(self):
|
||||
"""Test built-in markdown processor plugin."""
|
||||
from markitect.plugins.builtin.processors import MarkdownProcessor
|
||||
|
||||
processor = MarkdownProcessor()
|
||||
assert processor.metadata.plugin_type == PluginType.PROCESSOR
|
||||
|
||||
# Test basic processing
|
||||
content = "# Header\n\nSome content\n"
|
||||
result = processor.process(content)
|
||||
assert isinstance(result, str)
|
||||
|
||||
# Test can_process
|
||||
assert processor.can_process("# Markdown header")
|
||||
assert processor.can_process("Some **bold** text")
|
||||
|
||||
|
||||
class TestPluginCLIIntegration:
|
||||
"""Test plugin CLI command integration."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test environment."""
|
||||
# Clear plugin registry
|
||||
plugin_registry.cleanup_all()
|
||||
plugin_registry._plugins.clear()
|
||||
plugin_registry._instances.clear()
|
||||
plugin_registry._plugins_by_type.clear()
|
||||
|
||||
def test_plugin_list_command_import(self):
|
||||
"""Test that plugin CLI commands can be imported."""
|
||||
# This tests that the CLI commands are properly integrated
|
||||
from markitect.cli import plugin_list, plugin_load, plugin_info
|
||||
|
||||
assert callable(plugin_list)
|
||||
assert callable(plugin_load)
|
||||
assert callable(plugin_info)
|
||||
|
||||
def test_plugin_type_enum_import(self):
|
||||
"""Test that PluginType enum is accessible for CLI."""
|
||||
from markitect.plugins.base import PluginType
|
||||
|
||||
# Test all plugin types are available
|
||||
assert PluginType.PROCESSOR
|
||||
assert PluginType.FORMATTER
|
||||
assert PluginType.VALIDATOR
|
||||
assert PluginType.EXPORTER
|
||||
assert PluginType.GENERATOR
|
||||
assert PluginType.IMPORTER
|
||||
assert PluginType.TRANSFORMER
|
||||
assert PluginType.EXTENSION
|
||||
assert PluginType.BACKEND
|
||||
assert PluginType.COMMAND
|
||||
|
||||
# Test values are strings
|
||||
assert isinstance(PluginType.PROCESSOR.value, str)
|
||||
|
||||
|
||||
class TestPluginErrorHandling:
|
||||
"""Test plugin error handling and edge cases."""
|
||||
|
||||
def test_plugin_with_invalid_metadata(self):
|
||||
"""Test plugin with invalid metadata."""
|
||||
|
||||
class BadMetadataPlugin(BasePlugin):
|
||||
@property
|
||||
def metadata(self):
|
||||
# Missing required fields
|
||||
return None
|
||||
|
||||
plugin = BadMetadataPlugin()
|
||||
|
||||
# Should handle gracefully
|
||||
try:
|
||||
plugin_registry.register(BadMetadataPlugin, "bad_plugin")
|
||||
# Should not crash, might register as extension type
|
||||
except Exception:
|
||||
# Exception is acceptable for invalid metadata
|
||||
pass
|
||||
|
||||
def test_plugin_initialization_with_bad_config(self):
|
||||
"""Test plugin initialization with invalid configuration."""
|
||||
|
||||
class ConfigValidatingPlugin(BasePlugin):
|
||||
@property
|
||||
def metadata(self):
|
||||
return PluginMetadata(
|
||||
name="config_validator",
|
||||
version="1.0.0",
|
||||
description="Config validating plugin",
|
||||
plugin_type=PluginType.EXTENSION
|
||||
)
|
||||
|
||||
def validate_config(self):
|
||||
errors = []
|
||||
if 'required_field' not in self.config:
|
||||
errors.append("Missing required_field")
|
||||
return errors
|
||||
|
||||
# Test with invalid config
|
||||
plugin = ConfigValidatingPlugin({"wrong_field": "value"})
|
||||
errors = plugin.validate_config()
|
||||
assert len(errors) > 0
|
||||
assert "required_field" in errors[0]
|
||||
|
||||
# Test with valid config
|
||||
plugin = ConfigValidatingPlugin({"required_field": "value"})
|
||||
errors = plugin.validate_config()
|
||||
assert len(errors) == 0
|
||||
|
||||
def test_plugin_manager_with_invalid_config_file(self):
|
||||
"""Test plugin manager with invalid configuration file."""
|
||||
with tempfile.NamedTemporaryFile(mode='w', suffix='.yml', delete=False) as f:
|
||||
f.write("invalid: yaml: content: [") # Invalid YAML
|
||||
config_path = f.name
|
||||
|
||||
try:
|
||||
# Should not crash, should use defaults
|
||||
manager = PluginManager(config_path)
|
||||
assert manager.config is not None
|
||||
# Should fall back to defaults
|
||||
assert 'plugin_directories' in manager.config
|
||||
finally:
|
||||
os.unlink(config_path)
|
||||
|
||||
|
||||
class TestPluginIntegration:
|
||||
"""Integration tests for the plugin system."""
|
||||
|
||||
def setup_method(self):
|
||||
"""Set up test environment."""
|
||||
# Clear plugin registry
|
||||
plugin_registry.cleanup_all()
|
||||
plugin_registry._plugins.clear()
|
||||
plugin_registry._instances.clear()
|
||||
plugin_registry._plugins_by_type.clear()
|
||||
|
||||
def test_end_to_end_plugin_workflow(self):
|
||||
"""Test complete plugin workflow from registration to usage."""
|
||||
|
||||
# 1. Create a plugin
|
||||
@register_plugin("workflow_processor")
|
||||
class WorkflowProcessor(ProcessorPlugin):
|
||||
@property
|
||||
def metadata(self):
|
||||
return PluginMetadata(
|
||||
name="workflow_processor",
|
||||
version="1.0.0",
|
||||
description="End-to-end workflow processor",
|
||||
plugin_type=PluginType.PROCESSOR
|
||||
)
|
||||
|
||||
def process(self, content, **kwargs):
|
||||
prefix = kwargs.get('prefix', self.config.get('prefix', ''))
|
||||
return f"{prefix}{content}"
|
||||
|
||||
# 2. Verify registration
|
||||
assert "workflow_processor" in plugin_registry._plugins
|
||||
|
||||
# 3. Create manager and load plugin
|
||||
manager = PluginManager()
|
||||
plugin = manager.load_plugin("workflow_processor", {"prefix": ">> "})
|
||||
|
||||
# 4. Use plugin
|
||||
assert plugin is not None
|
||||
result = plugin.process("Hello World")
|
||||
assert result == ">> Hello World"
|
||||
|
||||
# 5. Verify plugin is in registry
|
||||
assert plugin_registry.is_loaded("workflow_processor")
|
||||
|
||||
# 6. Get plugin by type
|
||||
processors = manager.get_plugins_by_type(PluginType.PROCESSOR)
|
||||
assert any(isinstance(p, WorkflowProcessor) for p in processors)
|
||||
|
||||
# 7. Unload plugin
|
||||
success = manager.unload_plugin("workflow_processor")
|
||||
assert success is True
|
||||
assert not plugin_registry.is_loaded("workflow_processor")
|
||||
|
||||
def test_multiple_plugins_interaction(self):
|
||||
"""Test interaction between multiple plugins."""
|
||||
|
||||
# Register multiple plugins
|
||||
@register_plugin("upper_processor")
|
||||
class UpperProcessor(ProcessorPlugin):
|
||||
@property
|
||||
def metadata(self):
|
||||
return PluginMetadata(
|
||||
name="upper_processor",
|
||||
version="1.0.0",
|
||||
description="Uppercase processor",
|
||||
plugin_type=PluginType.PROCESSOR
|
||||
)
|
||||
|
||||
def process(self, content, **kwargs):
|
||||
return content.upper()
|
||||
|
||||
@register_plugin("json_test_formatter")
|
||||
class JsonTestFormatter(FormatterPlugin):
|
||||
@property
|
||||
def metadata(self):
|
||||
return PluginMetadata(
|
||||
name="json_test_formatter",
|
||||
version="1.0.0",
|
||||
description="JSON test formatter",
|
||||
plugin_type=PluginType.FORMATTER
|
||||
)
|
||||
|
||||
def format(self, data, **kwargs):
|
||||
return json.dumps(data)
|
||||
|
||||
def get_file_extension(self):
|
||||
return '.json'
|
||||
|
||||
manager = PluginManager()
|
||||
|
||||
# Load both plugins
|
||||
processor = manager.load_plugin("upper_processor")
|
||||
formatter = manager.load_plugin("json_test_formatter")
|
||||
|
||||
assert processor is not None
|
||||
assert formatter is not None
|
||||
|
||||
# Use them together
|
||||
processed = processor.process("hello world")
|
||||
formatted = formatter.format({"result": processed})
|
||||
|
||||
data = json.loads(formatted)
|
||||
assert data["result"] == "HELLO WORLD"
|
||||
|
||||
# Verify both are loaded
|
||||
assert plugin_registry.is_loaded("upper_processor")
|
||||
assert plugin_registry.is_loaded("json_test_formatter")
|
||||
|
||||
|
||||
if __name__ == '__main__':
|
||||
pytest.main([__file__])
|
||||
627
markitect/plugins/tests/test_issue_83_full_text_search.py
Normal file
627
markitect/plugins/tests/test_issue_83_full_text_search.py
Normal file
@@ -0,0 +1,627 @@
|
||||
"""
|
||||
Tests for Issue #83: Full text search functionality.
|
||||
|
||||
Tests the FTS5-based full text search plugin including indexing,
|
||||
querying, and CLI integration.
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import tempfile
|
||||
import sqlite3
|
||||
import json
|
||||
import os
|
||||
from pathlib import Path
|
||||
from unittest.mock import patch, MagicMock
|
||||
|
||||
from markitect.plugins.builtin.search import FTSSearchPlugin, SearchIndexer, QueryParser
|
||||
from markitect.database import DatabaseManager
|
||||
|
||||
|
||||
class TestSearchIndexer:
|
||||
"""Test the search indexing functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def temp_db_path(self):
|
||||
"""Create a temporary database for testing."""
|
||||
with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as f:
|
||||
db_path = f.name
|
||||
|
||||
# Initialize database with test data
|
||||
db_manager = DatabaseManager(db_path)
|
||||
db_manager.initialize_database()
|
||||
|
||||
# Add test markdown files
|
||||
db_manager.store_markdown_file("test1.md", "# Test Document\n\nThis is a test document about API development.")
|
||||
db_manager.store_markdown_file("test2.md", "# Another Document\n\nGraphQL interface documentation.")
|
||||
db_manager.store_markdown_file("test3.md", "---\ntitle: Blog Post\n---\n# My Blog\n\nContent about technology.")
|
||||
|
||||
# Add test schemas
|
||||
schema1 = {"type": "object", "title": "User Schema", "description": "Schema for user objects"}
|
||||
schema2 = {"type": "object", "title": "Product Schema", "description": "E-commerce product definition"}
|
||||
db_manager.store_schema_file("user.json", json.dumps(schema1))
|
||||
db_manager.store_schema_file("product.json", json.dumps(schema2))
|
||||
|
||||
yield db_path
|
||||
|
||||
# Cleanup
|
||||
os.unlink(db_path)
|
||||
|
||||
def test_check_fts_availability(self, temp_db_path):
|
||||
"""Test checking FTS5 availability."""
|
||||
indexer = SearchIndexer()
|
||||
available = indexer.check_fts_availability(temp_db_path)
|
||||
|
||||
# FTS5 should be available in most modern SQLite installations
|
||||
assert isinstance(available, bool)
|
||||
|
||||
def test_initialize_fts_tables(self, temp_db_path):
|
||||
"""Test FTS5 table initialization."""
|
||||
indexer = SearchIndexer()
|
||||
indexer.initialize_fts_tables(temp_db_path)
|
||||
|
||||
# Check that FTS tables were created
|
||||
conn = sqlite3.connect(temp_db_path)
|
||||
cursor = conn.cursor()
|
||||
|
||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name LIKE 'fts_%'")
|
||||
fts_tables = [row[0] for row in cursor.fetchall()]
|
||||
|
||||
if indexer.check_fts_availability(temp_db_path):
|
||||
assert 'fts_files' in fts_tables
|
||||
assert 'fts_schemas' in fts_tables
|
||||
else:
|
||||
# If FTS5 not available, should have status table
|
||||
cursor.execute("SELECT name FROM sqlite_master WHERE type='table' AND name='fts_status'")
|
||||
assert cursor.fetchone() is not None
|
||||
|
||||
conn.close()
|
||||
|
||||
def test_rebuild_index(self, temp_db_path):
|
||||
"""Test rebuilding search indexes."""
|
||||
indexer = SearchIndexer()
|
||||
indexer.initialize_fts_tables(temp_db_path)
|
||||
|
||||
stats = indexer.rebuild_index(temp_db_path)
|
||||
|
||||
assert 'files_indexed' in stats
|
||||
assert 'schemas_indexed' in stats
|
||||
|
||||
if indexer.check_fts_availability(temp_db_path):
|
||||
# If FTS5 is available, should index successfully
|
||||
assert stats['files_indexed'] >= 0
|
||||
assert stats['schemas_indexed'] >= 0
|
||||
else:
|
||||
# If FTS5 not available, might have error
|
||||
pass # Just check stats exist
|
||||
|
||||
def test_get_index_info(self, temp_db_path):
|
||||
"""Test getting index information."""
|
||||
indexer = SearchIndexer()
|
||||
indexer.initialize_fts_tables(temp_db_path)
|
||||
indexer.rebuild_index(temp_db_path)
|
||||
|
||||
info = indexer.get_index_info(temp_db_path)
|
||||
|
||||
assert 'fts_enabled' in info
|
||||
if info['fts_enabled']:
|
||||
assert 'fts_tables' in info
|
||||
assert 'fts_files_count' in info
|
||||
assert 'fts_schemas_count' in info
|
||||
|
||||
|
||||
class TestQueryParser:
|
||||
"""Test query parsing functionality."""
|
||||
|
||||
def test_parse_simple_query(self):
|
||||
"""Test parsing simple queries."""
|
||||
parser = QueryParser()
|
||||
|
||||
# Simple word
|
||||
result = parser.parse_query("test")
|
||||
assert "test*" in result
|
||||
|
||||
# Multiple words
|
||||
result = parser.parse_query("test document")
|
||||
assert "test*" in result
|
||||
assert "document*" in result
|
||||
assert "AND" in result
|
||||
|
||||
def test_parse_phrase_query(self):
|
||||
"""Test parsing phrase queries."""
|
||||
parser = QueryParser()
|
||||
|
||||
result = parser.parse_query('"exact phrase"')
|
||||
assert '"exact phrase"' in result
|
||||
|
||||
def test_parse_boolean_operators(self):
|
||||
"""Test parsing boolean operators."""
|
||||
parser = QueryParser()
|
||||
|
||||
# AND operator - if already FTS5, should be preserved
|
||||
result = parser.parse_query("test AND document")
|
||||
assert "test" in result
|
||||
assert "AND" in result
|
||||
assert "document" in result
|
||||
|
||||
# OR operator - if already FTS5, should be preserved
|
||||
result = parser.parse_query("test OR document")
|
||||
assert "test" in result
|
||||
assert "OR" in result
|
||||
assert "document" in result
|
||||
|
||||
# NOT operator - if already FTS5, should be preserved
|
||||
result = parser.parse_query("test NOT document")
|
||||
assert "test" in result
|
||||
assert "NOT" in result
|
||||
|
||||
def test_validate_query(self):
|
||||
"""Test query validation."""
|
||||
parser = QueryParser()
|
||||
|
||||
# Valid queries
|
||||
valid, error = parser.validate_query("test")
|
||||
assert valid
|
||||
assert error is None
|
||||
|
||||
valid, error = parser.validate_query('"exact phrase"')
|
||||
assert valid
|
||||
assert error is None
|
||||
|
||||
# Invalid queries
|
||||
valid, error = parser.validate_query('unmatched "quote')
|
||||
assert not valid
|
||||
assert "quotes" in error
|
||||
|
||||
valid, error = parser.validate_query("test (unmatched")
|
||||
assert not valid
|
||||
assert "parentheses" in error
|
||||
|
||||
def test_get_query_terms(self):
|
||||
"""Test extracting terms from queries."""
|
||||
parser = QueryParser()
|
||||
|
||||
terms = parser.get_query_terms("test document AND api")
|
||||
assert "test" in terms
|
||||
assert "document" in terms
|
||||
assert "api" in terms
|
||||
assert "AND" not in terms # Operators should be excluded
|
||||
|
||||
def test_build_column_query(self):
|
||||
"""Test building column-specific queries."""
|
||||
parser = QueryParser()
|
||||
|
||||
result = parser.build_column_query("test", ["title", "content"])
|
||||
assert "title:" in result
|
||||
assert "content:" in result
|
||||
assert "OR" in result
|
||||
|
||||
|
||||
class TestFTSSearchPlugin:
|
||||
"""Test the main FTS search plugin."""
|
||||
|
||||
@pytest.fixture
|
||||
def temp_db_path(self):
|
||||
"""Create a temporary database with test data."""
|
||||
with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as f:
|
||||
db_path = f.name
|
||||
|
||||
# Initialize database with test data
|
||||
db_manager = DatabaseManager(db_path)
|
||||
db_manager.initialize_database()
|
||||
|
||||
# Add test markdown files
|
||||
db_manager.store_markdown_file("api-guide.md", "# API Guide\n\nComprehensive API development guide with examples.")
|
||||
db_manager.store_markdown_file("tutorial.md", "# GraphQL Tutorial\n\nLearn GraphQL basics and advanced concepts.")
|
||||
db_manager.store_markdown_file("readme.md", "---\ntitle: Project README\ntags: [documentation, guide]\n---\n# Project\n\nProject documentation and setup guide.")
|
||||
|
||||
# Add test schemas
|
||||
schema1 = {"type": "object", "title": "API Schema", "description": "REST API response schema", "properties": {"data": {"type": "object"}}}
|
||||
schema2 = {"type": "object", "title": "User Schema", "description": "User profile schema", "properties": {"name": {"type": "string"}}}
|
||||
db_manager.store_schema_file("api-schema.json", json.dumps(schema1))
|
||||
db_manager.store_schema_file("user-schema.json", json.dumps(schema2))
|
||||
|
||||
yield db_path
|
||||
|
||||
# Cleanup
|
||||
os.unlink(db_path)
|
||||
|
||||
def test_plugin_metadata(self):
|
||||
"""Test plugin metadata."""
|
||||
plugin = FTSSearchPlugin()
|
||||
metadata = plugin.metadata
|
||||
|
||||
assert metadata.name == "fts_search"
|
||||
assert metadata.version == "1.0.0"
|
||||
assert "full text search" in metadata.description.lower()
|
||||
|
||||
def test_initialize_plugin(self, temp_db_path):
|
||||
"""Test plugin initialization."""
|
||||
plugin = FTSSearchPlugin()
|
||||
plugin.initialize(temp_db_path)
|
||||
|
||||
# Check that FTS tables exist (if FTS5 is available)
|
||||
stats = plugin.get_search_stats(temp_db_path)
|
||||
assert 'fts_enabled' in stats
|
||||
|
||||
def test_search_files_only(self, temp_db_path):
|
||||
"""Test searching only in files."""
|
||||
plugin = FTSSearchPlugin()
|
||||
plugin.initialize(temp_db_path)
|
||||
plugin.rebuild_index(temp_db_path)
|
||||
|
||||
results = plugin.search(temp_db_path, "API", content_type="files", limit=10)
|
||||
|
||||
# Should find files containing "API"
|
||||
assert isinstance(results, list)
|
||||
for result in results:
|
||||
assert result['type'] == 'file'
|
||||
assert 'file' in result
|
||||
assert 'score' in result
|
||||
|
||||
def test_search_schemas_only(self, temp_db_path):
|
||||
"""Test searching only in schemas."""
|
||||
plugin = FTSSearchPlugin()
|
||||
plugin.initialize(temp_db_path)
|
||||
plugin.rebuild_index(temp_db_path)
|
||||
|
||||
results = plugin.search(temp_db_path, "schema", content_type="schemas", limit=10)
|
||||
|
||||
# Should find schemas
|
||||
assert isinstance(results, list)
|
||||
for result in results:
|
||||
assert result['type'] == 'schema'
|
||||
assert 'schema' in result
|
||||
assert 'score' in result
|
||||
|
||||
def test_search_all_content(self, temp_db_path):
|
||||
"""Test searching all content types."""
|
||||
plugin = FTSSearchPlugin()
|
||||
plugin.initialize(temp_db_path)
|
||||
plugin.rebuild_index(temp_db_path)
|
||||
|
||||
results = plugin.search(temp_db_path, "guide", content_type="all", limit=10)
|
||||
|
||||
# Should find both files and schemas (or empty list if FTS5 unavailable)
|
||||
assert isinstance(results, list)
|
||||
|
||||
# If results found, should be properly formatted and sorted
|
||||
if results:
|
||||
# Results should be sorted by score
|
||||
scores = [result.get('score', 0) for result in results]
|
||||
assert scores == sorted(scores, reverse=True)
|
||||
|
||||
# Check result structure
|
||||
for result in results:
|
||||
assert 'type' in result
|
||||
assert 'score' in result
|
||||
|
||||
def test_search_with_pagination(self, temp_db_path):
|
||||
"""Test search with pagination."""
|
||||
plugin = FTSSearchPlugin()
|
||||
plugin.initialize(temp_db_path)
|
||||
plugin.rebuild_index(temp_db_path)
|
||||
|
||||
# Get first page
|
||||
results1 = plugin.search(temp_db_path, "guide", limit=1, offset=0)
|
||||
|
||||
# Get second page
|
||||
results2 = plugin.search(temp_db_path, "guide", limit=1, offset=1)
|
||||
|
||||
# Results should be different (if there are enough results)
|
||||
if len(results1) > 0 and len(results2) > 0:
|
||||
assert results1[0] != results2[0]
|
||||
|
||||
def test_fallback_search(self, temp_db_path):
|
||||
"""Test fallback search when FTS5 fails."""
|
||||
plugin = FTSSearchPlugin()
|
||||
plugin.initialize(temp_db_path)
|
||||
|
||||
# Force fallback by using invalid FTS5 query syntax with mock
|
||||
with patch.object(plugin, '_search_files', side_effect=Exception("FTS5 error")):
|
||||
with patch.object(plugin, '_search_schemas', side_effect=Exception("FTS5 error")):
|
||||
results = plugin.search(temp_db_path, "API", content_type="all", limit=10)
|
||||
|
||||
# Should still return results via fallback
|
||||
assert isinstance(results, list)
|
||||
|
||||
def test_get_search_stats(self, temp_db_path):
|
||||
"""Test getting search statistics."""
|
||||
plugin = FTSSearchPlugin()
|
||||
plugin.initialize(temp_db_path)
|
||||
|
||||
stats = plugin.get_search_stats(temp_db_path)
|
||||
|
||||
assert 'fts_enabled' in stats
|
||||
assert 'fts_tables' in stats
|
||||
|
||||
|
||||
class TestSearchCLI:
|
||||
"""Test search CLI commands."""
|
||||
|
||||
@pytest.fixture
|
||||
def temp_db_path(self):
|
||||
"""Create a temporary database with test data."""
|
||||
with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as f:
|
||||
db_path = f.name
|
||||
|
||||
# Initialize database with test data
|
||||
db_manager = DatabaseManager(db_path)
|
||||
db_manager.initialize_database()
|
||||
|
||||
# Add test data
|
||||
db_manager.store_markdown_file("test.md", "# Test\n\nThis is a test document.")
|
||||
|
||||
yield db_path
|
||||
|
||||
# Cleanup
|
||||
os.unlink(db_path)
|
||||
|
||||
def test_search_init_command(self, temp_db_path):
|
||||
"""Test the search init CLI command."""
|
||||
from click.testing import CliRunner
|
||||
from markitect.cli import cli
|
||||
|
||||
runner = CliRunner()
|
||||
|
||||
with patch('markitect.cli.get_database_path', return_value=temp_db_path):
|
||||
result = runner.invoke(cli, ['search', 'init'])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Search indexes initialized" in result.output or "Search plugin not available" in result.output
|
||||
|
||||
def test_search_query_command(self, temp_db_path):
|
||||
"""Test the search query CLI command."""
|
||||
from click.testing import CliRunner
|
||||
from markitect.cli import cli
|
||||
|
||||
runner = CliRunner()
|
||||
|
||||
with patch('markitect.cli.get_database_path', return_value=temp_db_path):
|
||||
# Initialize search first
|
||||
runner.invoke(cli, ['search', 'init'])
|
||||
|
||||
# Perform search
|
||||
result = runner.invoke(cli, ['search', 'query', 'test'])
|
||||
|
||||
assert result.exit_code == 0
|
||||
# Should either show results or indicate no search plugin
|
||||
assert "results" in result.output or "Search plugin not available" in result.output
|
||||
|
||||
def test_search_status_command(self, temp_db_path):
|
||||
"""Test the search status CLI command."""
|
||||
from click.testing import CliRunner
|
||||
from markitect.cli import cli
|
||||
|
||||
runner = CliRunner()
|
||||
|
||||
with patch('markitect.cli.get_database_path', return_value=temp_db_path):
|
||||
result = runner.invoke(cli, ['search', 'status'])
|
||||
|
||||
assert result.exit_code == 0
|
||||
assert "Search Index Status" in result.output or "Search plugin not available" in result.output
|
||||
|
||||
def test_search_rebuild_command(self, temp_db_path):
|
||||
"""Test the search rebuild CLI command."""
|
||||
from click.testing import CliRunner
|
||||
from markitect.cli import cli
|
||||
|
||||
runner = CliRunner()
|
||||
|
||||
with patch('markitect.cli.get_database_path', return_value=temp_db_path):
|
||||
# Initialize search first
|
||||
runner.invoke(cli, ['search', 'init'])
|
||||
|
||||
# Rebuild indexes
|
||||
result = runner.invoke(cli, ['search', 'rebuild'])
|
||||
|
||||
if result.exit_code != 0:
|
||||
print(f"Command output: {result.output}")
|
||||
print(f"Exception: {result.exception}")
|
||||
|
||||
# Should succeed or fail gracefully with plugin unavailable message or database error
|
||||
acceptable_errors = [
|
||||
"Search plugin not available",
|
||||
"database disk image is malformed", # Can happen with concurrent access
|
||||
"database is locked"
|
||||
]
|
||||
|
||||
if result.exit_code == 0:
|
||||
assert "Rebuilding search indexes" in result.output
|
||||
else:
|
||||
# Check if it's an acceptable error
|
||||
assert any(error in result.output for error in acceptable_errors)
|
||||
|
||||
|
||||
class TestSearchIntegration:
|
||||
"""Integration tests for search functionality."""
|
||||
|
||||
@pytest.fixture
|
||||
def populated_db_path(self):
|
||||
"""Create a database with realistic test data."""
|
||||
with tempfile.NamedTemporaryFile(suffix='.db', delete=False) as f:
|
||||
db_path = f.name
|
||||
|
||||
db_manager = DatabaseManager(db_path)
|
||||
db_manager.initialize_database()
|
||||
|
||||
# Add realistic markdown files
|
||||
files = [
|
||||
("api-documentation.md", """# API Documentation
|
||||
|
||||
## Authentication
|
||||
The API uses Bearer token authentication. Include your token in the Authorization header.
|
||||
|
||||
## Endpoints
|
||||
- GET /users - List all users
|
||||
- POST /users - Create a new user
|
||||
- GET /users/{id} - Get specific user
|
||||
|
||||
## Error Handling
|
||||
All errors return JSON with error message and status code.
|
||||
"""),
|
||||
("graphql-guide.md", """---
|
||||
title: GraphQL Complete Guide
|
||||
tags: [graphql, api, tutorial]
|
||||
author: Development Team
|
||||
---
|
||||
|
||||
# GraphQL Complete Guide
|
||||
|
||||
GraphQL is a query language for APIs and a runtime for executing those queries.
|
||||
|
||||
## Benefits
|
||||
- Single endpoint
|
||||
- Type safety
|
||||
- Efficient data fetching
|
||||
- Strong introspection
|
||||
|
||||
## Schema Definition
|
||||
Define your GraphQL schema using SDL (Schema Definition Language).
|
||||
"""),
|
||||
("project-readme.md", """# MarkiTect Project
|
||||
|
||||
MarkiTect is a comprehensive markdown content management and analysis system.
|
||||
|
||||
## Features
|
||||
- Document indexing and storage
|
||||
- Full text search capabilities
|
||||
- GraphQL API interface
|
||||
- Plugin system for extensibility
|
||||
|
||||
## Installation
|
||||
1. Clone the repository
|
||||
2. Install dependencies: pip install -r requirements.txt
|
||||
3. Initialize database: markitect init
|
||||
|
||||
## Usage Examples
|
||||
Search for content: markitect search query "API documentation"
|
||||
""")
|
||||
]
|
||||
|
||||
for filename, content in files:
|
||||
db_manager.store_markdown_file(filename, content)
|
||||
|
||||
# Add realistic schemas
|
||||
schemas = [
|
||||
("user-schema.json", {
|
||||
"type": "object",
|
||||
"title": "User Schema",
|
||||
"description": "Schema for user profile data in the API",
|
||||
"properties": {
|
||||
"id": {"type": "integer"},
|
||||
"name": {"type": "string"},
|
||||
"email": {"type": "string", "format": "email"},
|
||||
"created_at": {"type": "string", "format": "date-time"}
|
||||
},
|
||||
"required": ["id", "name", "email"]
|
||||
}),
|
||||
("api-response-schema.json", {
|
||||
"type": "object",
|
||||
"title": "API Response Schema",
|
||||
"description": "Standard API response format for all endpoints",
|
||||
"properties": {
|
||||
"data": {"type": "object"},
|
||||
"success": {"type": "boolean"},
|
||||
"message": {"type": "string"},
|
||||
"errors": {"type": "array", "items": {"type": "string"}}
|
||||
},
|
||||
"required": ["success"]
|
||||
})
|
||||
]
|
||||
|
||||
for filename, schema in schemas:
|
||||
db_manager.store_schema_file(filename, json.dumps(schema))
|
||||
|
||||
yield db_path
|
||||
|
||||
# Cleanup
|
||||
os.unlink(db_path)
|
||||
|
||||
def test_end_to_end_search_workflow(self, populated_db_path):
|
||||
"""Test complete search workflow from initialization to querying."""
|
||||
plugin = FTSSearchPlugin()
|
||||
|
||||
# Initialize search
|
||||
plugin.initialize(populated_db_path)
|
||||
|
||||
# Rebuild indexes
|
||||
stats = plugin.rebuild_index(populated_db_path)
|
||||
|
||||
if plugin.indexer.check_fts_availability(populated_db_path):
|
||||
# If FTS5 is available, should index files
|
||||
assert stats['files_indexed'] >= 0
|
||||
assert stats['schemas_indexed'] >= 0
|
||||
else:
|
||||
# If FTS5 not available, might be 0
|
||||
pass
|
||||
|
||||
# Search for API-related content
|
||||
results = plugin.search(populated_db_path, "API", content_type="all", limit=10)
|
||||
|
||||
# Results should be a list (may be empty if FTS5 not available)
|
||||
assert isinstance(results, list)
|
||||
|
||||
# If we have results, verify they're properly formatted
|
||||
if results:
|
||||
# Should find both files and schemas
|
||||
result_types = {result['type'] for result in results}
|
||||
assert len(result_types) > 0 # At least one type found
|
||||
|
||||
# Verify results have required fields
|
||||
for result in results:
|
||||
assert 'type' in result
|
||||
assert 'score' in result
|
||||
assert result['score'] > 0
|
||||
|
||||
if result['type'] == 'file':
|
||||
assert 'file' in result
|
||||
assert 'filename' in result['file']
|
||||
elif result['type'] == 'schema':
|
||||
assert 'schema' in result
|
||||
assert 'filename' in result['schema']
|
||||
|
||||
def test_search_ranking_quality(self, populated_db_path):
|
||||
"""Test that search ranking produces sensible results."""
|
||||
plugin = FTSSearchPlugin()
|
||||
plugin.initialize(populated_db_path)
|
||||
plugin.rebuild_index(populated_db_path)
|
||||
|
||||
# Search for "GraphQL"
|
||||
results = plugin.search(populated_db_path, "GraphQL", content_type="files", limit=10)
|
||||
|
||||
if results:
|
||||
# The GraphQL guide should rank highest
|
||||
top_result = results[0]
|
||||
assert 'graphql' in top_result['file']['filename'].lower()
|
||||
|
||||
# Search for exact phrase
|
||||
results = plugin.search(populated_db_path, '"API documentation"', content_type="files", limit=10)
|
||||
|
||||
if results:
|
||||
# Should find exact phrase matches
|
||||
for result in results:
|
||||
content = result['file'].get('content', '').lower()
|
||||
# Either in content or highlighted
|
||||
assert 'api documentation' in content or 'api documentation' in result.get('highlight', '').lower()
|
||||
|
||||
def test_search_error_handling(self, populated_db_path):
|
||||
"""Test search error handling and edge cases."""
|
||||
plugin = FTSSearchPlugin()
|
||||
plugin.initialize(populated_db_path)
|
||||
|
||||
# Empty query
|
||||
results = plugin.search(populated_db_path, "", content_type="all", limit=10)
|
||||
assert isinstance(results, list)
|
||||
|
||||
# Very long query
|
||||
long_query = "word " * 100
|
||||
results = plugin.search(populated_db_path, long_query, content_type="all", limit=10)
|
||||
assert isinstance(results, list)
|
||||
|
||||
# Special characters
|
||||
results = plugin.search(populated_db_path, "query with @#$%", content_type="all", limit=10)
|
||||
assert isinstance(results, list)
|
||||
|
||||
# Zero limit
|
||||
results = plugin.search(populated_db_path, "API", content_type="all", limit=0)
|
||||
assert len(results) == 0
|
||||
0
markitect/query_paradigms/tests/__init__.py
Normal file
0
markitect/query_paradigms/tests/__init__.py
Normal file
333
markitect/query_paradigms/tests/test_query_paradigms.py
Normal file
333
markitect/query_paradigms/tests/test_query_paradigms.py
Normal file
@@ -0,0 +1,333 @@
|
||||
"""
|
||||
Tests for query paradigm system - Issue #62
|
||||
"""
|
||||
|
||||
import pytest
|
||||
import json
|
||||
from markitect.query_paradigms.registry import registry
|
||||
from markitect.query_paradigms.base import BaseQueryParadigm, QueryResult
|
||||
from markitect.query_paradigms.paradigms.sql_paradigm import SQLQueryParadigm
|
||||
from markitect.query_paradigms.paradigms.fts_paradigm import FullTextSearchParadigm
|
||||
from markitect.query_paradigms.paradigms.qbe_paradigm import QueryByExampleParadigm
|
||||
|
||||
|
||||
class TestQueryParadigmRegistry:
|
||||
"""Test the query paradigm registry system."""
|
||||
|
||||
def test_registry_has_paradigms(self):
|
||||
"""Test that paradigms are automatically registered."""
|
||||
paradigms = registry.list_all()
|
||||
assert len(paradigms) >= 14 # We expect at least 14 paradigms
|
||||
|
||||
# Check that key paradigms are present
|
||||
paradigm_names = [p.name for p in paradigms]
|
||||
assert "SQL" in paradigm_names
|
||||
assert "FTS" in paradigm_names
|
||||
assert "GraphQL" in paradigm_names
|
||||
assert "Natural Language" in paradigm_names
|
||||
|
||||
def test_get_paradigm_by_name(self):
|
||||
"""Test retrieving paradigms by name."""
|
||||
sql_paradigm = registry.get("SQL")
|
||||
assert sql_paradigm is not None
|
||||
assert sql_paradigm.name == "SQL"
|
||||
assert sql_paradigm.category == "structural"
|
||||
|
||||
# Test case insensitive lookup
|
||||
fts_paradigm = registry.get("fts")
|
||||
assert fts_paradigm is not None
|
||||
assert fts_paradigm.name == "FTS"
|
||||
|
||||
def test_get_nonexistent_paradigm(self):
|
||||
"""Test getting a paradigm that doesn't exist."""
|
||||
result = registry.get("NonExistentParadigm")
|
||||
assert result is None
|
||||
|
||||
def test_list_by_category(self):
|
||||
"""Test filtering paradigms by category."""
|
||||
structural = registry.list_by_category("structural")
|
||||
assert len(structural) > 0
|
||||
|
||||
for paradigm in structural:
|
||||
assert paradigm.category == "structural"
|
||||
|
||||
textual = registry.list_by_category("textual")
|
||||
assert len(textual) > 0
|
||||
|
||||
for paradigm in textual:
|
||||
assert paradigm.category == "textual"
|
||||
|
||||
def test_list_by_complexity(self):
|
||||
"""Test filtering paradigms by complexity."""
|
||||
beginner = registry.list_by_complexity("beginner")
|
||||
assert len(beginner) > 0
|
||||
|
||||
for paradigm in beginner:
|
||||
assert paradigm.complexity == "beginner"
|
||||
|
||||
def test_search_paradigms(self):
|
||||
"""Test searching paradigms by query."""
|
||||
# Search by name
|
||||
sql_results = registry.search_paradigms("SQL")
|
||||
assert len(sql_results) > 0
|
||||
assert any(p.name == "SQL" for p in sql_results)
|
||||
|
||||
# Search by description
|
||||
visual_results = registry.search_paradigms("visual")
|
||||
assert len(visual_results) > 0
|
||||
assert any("visual" in p.description.lower() for p in visual_results)
|
||||
|
||||
# Search for non-existent term
|
||||
empty_results = registry.search_paradigms("xyznonexistent")
|
||||
assert len(empty_results) == 0
|
||||
|
||||
def test_get_categories(self):
|
||||
"""Test getting all available categories."""
|
||||
categories = registry.get_categories()
|
||||
assert isinstance(categories, list)
|
||||
assert len(categories) > 0
|
||||
assert "structural" in categories
|
||||
assert "textual" in categories
|
||||
assert "semantic" in categories
|
||||
|
||||
def test_get_complexity_levels(self):
|
||||
"""Test getting all complexity levels."""
|
||||
levels = registry.get_complexity_levels()
|
||||
assert isinstance(levels, list)
|
||||
assert len(levels) > 0
|
||||
assert "beginner" in levels
|
||||
assert "intermediate" in levels
|
||||
assert "advanced" in levels
|
||||
|
||||
|
||||
class TestSQLParadigm:
|
||||
"""Test the SQL query paradigm."""
|
||||
|
||||
def test_paradigm_properties(self):
|
||||
"""Test SQL paradigm basic properties."""
|
||||
paradigm = SQLQueryParadigm()
|
||||
assert paradigm.name == "SQL"
|
||||
assert paradigm.category == "structural"
|
||||
assert paradigm.complexity == "intermediate"
|
||||
assert "database" in paradigm.description.lower()
|
||||
|
||||
def test_validate_query(self):
|
||||
"""Test SQL query validation."""
|
||||
paradigm = SQLQueryParadigm()
|
||||
|
||||
# Valid queries
|
||||
valid, error = paradigm.validate_query("SELECT * FROM files")
|
||||
assert valid
|
||||
assert error is None
|
||||
|
||||
valid, error = paradigm.validate_query("SELECT name FROM files WHERE author = 'Alice'")
|
||||
assert valid
|
||||
|
||||
# Invalid queries
|
||||
valid, error = paradigm.validate_query("")
|
||||
assert not valid
|
||||
assert error is not None
|
||||
|
||||
valid, error = paradigm.validate_query(" ")
|
||||
assert not valid
|
||||
|
||||
def test_get_examples(self):
|
||||
"""Test SQL paradigm examples."""
|
||||
paradigm = SQLQueryParadigm()
|
||||
examples = paradigm.get_examples()
|
||||
|
||||
assert isinstance(examples, list)
|
||||
assert len(examples) > 0
|
||||
|
||||
for example in examples:
|
||||
assert "name" in example
|
||||
assert "description" in example
|
||||
assert "query" in example
|
||||
assert isinstance(example["query"], str)
|
||||
|
||||
def test_get_syntax_help(self):
|
||||
"""Test SQL syntax help."""
|
||||
paradigm = SQLQueryParadigm()
|
||||
help_text = paradigm.get_syntax_help()
|
||||
|
||||
assert isinstance(help_text, str)
|
||||
assert len(help_text) > 0
|
||||
assert "SELECT" in help_text
|
||||
|
||||
|
||||
class TestFTSParadigm:
|
||||
"""Test the Full Text Search paradigm."""
|
||||
|
||||
def test_paradigm_properties(self):
|
||||
"""Test FTS paradigm basic properties."""
|
||||
paradigm = FullTextSearchParadigm()
|
||||
assert paradigm.name == "FTS"
|
||||
assert paradigm.category == "textual"
|
||||
assert paradigm.complexity == "beginner"
|
||||
assert "search" in paradigm.description.lower()
|
||||
|
||||
def test_validate_query(self):
|
||||
"""Test FTS query validation."""
|
||||
paradigm = FullTextSearchParadigm()
|
||||
|
||||
# Valid queries
|
||||
valid, error = paradigm.validate_query("documentation")
|
||||
assert valid
|
||||
assert error is None
|
||||
|
||||
valid, error = paradigm.validate_query("API AND documentation")
|
||||
assert valid
|
||||
|
||||
valid, error = paradigm.validate_query('"getting started"')
|
||||
assert valid
|
||||
|
||||
# Invalid queries
|
||||
valid, error = paradigm.validate_query("")
|
||||
assert not valid
|
||||
assert error is not None
|
||||
|
||||
def test_get_examples(self):
|
||||
"""Test FTS paradigm examples."""
|
||||
paradigm = FullTextSearchParadigm()
|
||||
examples = paradigm.get_examples()
|
||||
|
||||
assert isinstance(examples, list)
|
||||
assert len(examples) > 0
|
||||
|
||||
# Check for expected example types
|
||||
example_names = [ex["name"] for ex in examples]
|
||||
assert "Simple search" in example_names
|
||||
assert "Boolean search" in example_names
|
||||
|
||||
|
||||
class TestQueryByExampleParadigm:
|
||||
"""Test the Query By Example paradigm (documentation-only)."""
|
||||
|
||||
def test_paradigm_properties(self):
|
||||
"""Test QBE paradigm basic properties."""
|
||||
paradigm = QueryByExampleParadigm()
|
||||
assert paradigm.name == "Query By Example"
|
||||
assert paradigm.category == "visual"
|
||||
assert paradigm.complexity == "beginner"
|
||||
assert "template" in paradigm.description.lower()
|
||||
|
||||
def test_validate_query(self):
|
||||
"""Test QBE query validation."""
|
||||
paradigm = QueryByExampleParadigm()
|
||||
|
||||
# Valid JSON templates
|
||||
valid, error = paradigm.validate_query('{"author": "Alice"}')
|
||||
assert valid
|
||||
assert error is None
|
||||
|
||||
valid, error = paradigm.validate_query('{"tags": ["tutorial"], "type": "markdown"}')
|
||||
assert valid
|
||||
|
||||
# Invalid queries
|
||||
valid, error = paradigm.validate_query("")
|
||||
assert not valid
|
||||
assert error is not None
|
||||
|
||||
valid, error = paradigm.validate_query("not json")
|
||||
assert not valid
|
||||
assert "JSON" in error
|
||||
|
||||
valid, error = paradigm.validate_query('["not", "an", "object"]')
|
||||
assert not valid
|
||||
assert "object" in error
|
||||
|
||||
def test_execute_returns_not_implemented(self):
|
||||
"""Test that QBE execution returns not implemented error."""
|
||||
paradigm = QueryByExampleParadigm()
|
||||
result = paradigm.execute('{"author": "Alice"}')
|
||||
|
||||
assert isinstance(result, QueryResult)
|
||||
assert not result.success
|
||||
assert result.error_message is not None
|
||||
assert "not yet implemented" in result.error_message.lower()
|
||||
assert result.metadata["status"] == "not_implemented"
|
||||
|
||||
def test_get_syntax_help(self):
|
||||
"""Test QBE syntax help."""
|
||||
paradigm = QueryByExampleParadigm()
|
||||
help_text = paradigm.get_syntax_help()
|
||||
|
||||
assert isinstance(help_text, str)
|
||||
assert len(help_text) > 0
|
||||
assert "JSON" in help_text
|
||||
assert "template" in help_text.lower()
|
||||
|
||||
|
||||
class TestQueryResult:
|
||||
"""Test the QueryResult data structure."""
|
||||
|
||||
def test_query_result_creation(self):
|
||||
"""Test creating a QueryResult."""
|
||||
result = QueryResult(
|
||||
paradigm="Test",
|
||||
query="test query",
|
||||
execution_time_ms=10.5,
|
||||
result_count=3,
|
||||
results=[{"id": 1}, {"id": 2}, {"id": 3}],
|
||||
metadata={"type": "test"},
|
||||
success=True
|
||||
)
|
||||
|
||||
assert result.paradigm == "Test"
|
||||
assert result.query == "test query"
|
||||
assert result.execution_time_ms == 10.5
|
||||
assert result.result_count == 3
|
||||
assert len(result.results) == 3
|
||||
assert result.metadata["type"] == "test"
|
||||
assert result.success is True
|
||||
assert result.error_message is None
|
||||
|
||||
def test_query_result_with_error(self):
|
||||
"""Test creating a QueryResult with error."""
|
||||
result = QueryResult(
|
||||
paradigm="Test",
|
||||
query="bad query",
|
||||
execution_time_ms=1.0,
|
||||
result_count=0,
|
||||
results=[],
|
||||
metadata={},
|
||||
success=False,
|
||||
error_message="Query failed"
|
||||
)
|
||||
|
||||
assert not result.success
|
||||
assert result.error_message == "Query failed"
|
||||
assert result.result_count == 0
|
||||
|
||||
|
||||
class TestBaseQueryParadigm:
|
||||
"""Test the base query paradigm interface."""
|
||||
|
||||
def test_cannot_instantiate_base_class(self):
|
||||
"""Test that BaseQueryParadigm cannot be instantiated directly."""
|
||||
with pytest.raises(TypeError):
|
||||
BaseQueryParadigm()
|
||||
|
||||
def test_paradigm_interface(self):
|
||||
"""Test that paradigms implement the required interface."""
|
||||
paradigm = SQLQueryParadigm()
|
||||
|
||||
# Test all required properties
|
||||
assert hasattr(paradigm, 'name')
|
||||
assert hasattr(paradigm, 'description')
|
||||
assert hasattr(paradigm, 'category')
|
||||
assert hasattr(paradigm, 'complexity')
|
||||
|
||||
# Test all required methods
|
||||
assert hasattr(paradigm, 'execute')
|
||||
assert hasattr(paradigm, 'get_examples')
|
||||
assert hasattr(paradigm, 'validate_query')
|
||||
assert hasattr(paradigm, 'get_syntax_help')
|
||||
|
||||
# Test optional methods
|
||||
assert hasattr(paradigm, 'can_translate_from')
|
||||
assert hasattr(paradigm, 'translate_query')
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__])
|
||||
Reference in New Issue
Block a user