Skip to content

API Reference

sanitongo

Sanitongo - Modern MongoDB Query Sanitizer with Layered Security Protection.

A comprehensive security library for sanitizing MongoDB queries with multiple layers of protection against NoSQL injection attacks and malicious queries.

ComplexityError

Bases: SanitizerError

Raised when query exceeds complexity limits.

Source code in src/sanitongo/exceptions.py
class ComplexityError(SanitizerError):
    """Error signalling that a query went past a configured complexity limit."""

    def __init__(
        self,
        message: str,
        limit_type: str,
        current_value: int,
        max_allowed: int,
        context: dict[str, Any] | None = None,
    ) -> None:
        """Record which limit was exceeded and by how much.

        Args:
            message: Error text, forwarded to ``SanitizerError``.
            limit_type: Name of the limit that was exceeded.
            current_value: Value measured on the offending query.
            max_allowed: Configured ceiling for that limit.
            context: Optional extra details, forwarded to ``SanitizerError``.
        """
        # Limit details are kept as plain attributes so handlers can read them.
        self.max_allowed = max_allowed
        self.current_value = current_value
        self.limit_type = limit_type
        super().__init__(message, context)

__init__(message, limit_type, current_value, max_allowed, context=None)

Initialize complexity error with limit information.

Source code in src/sanitongo/exceptions.py
def __init__(
    self,
    message: str,
    limit_type: str,
    current_value: int,
    max_allowed: int,
    context: dict[str, Any] | None = None,
) -> None:
    """Record which limit was exceeded and by how much.

    Args:
        message: Error text, forwarded to the parent initializer.
        limit_type: Name of the limit that was exceeded.
        current_value: Value measured on the offending query.
        max_allowed: Configured ceiling for that limit.
        context: Optional extra details, forwarded to the parent initializer.
    """
    # Limit details are kept as plain attributes so handlers can read them.
    self.max_allowed = max_allowed
    self.current_value = current_value
    self.limit_type = limit_type
    super().__init__(message, context)

PatternError

Bases: SecurityError

Raised when dangerous patterns are detected in query values.

Source code in src/sanitongo/exceptions.py
class PatternError(SecurityError):
    """Error signalling that a query value matched a dangerous pattern."""

    def __init__(
        self,
        message: str,
        pattern_type: str,
        field_path: str | None = None,
        pattern_value: str | None = None,
        context: dict[str, Any] | None = None,
    ) -> None:
        """Build the error from the matched pattern's details.

        Args:
            message: Error text, forwarded to ``SecurityError``.
            pattern_type: Name of the pattern that matched; also passed to
                ``SecurityError`` as its second positional argument.
            field_path: Location of the offending value, if known.
            pattern_value: The offending value, if known.
            context: Optional extra details, forwarded to ``SecurityError``.
        """
        # Only a truthy value is reported to the parent as a detected pattern.
        detected: list[str] = []
        if pattern_value:
            detected.append(pattern_value)
        super().__init__(message, pattern_type, detected, context)

        self.pattern_value = pattern_value
        self.field_path = field_path
        self.pattern_type = pattern_type

__init__(message, pattern_type, field_path=None, pattern_value=None, context=None)

Initialize pattern error.

Source code in src/sanitongo/exceptions.py
def __init__(
    self,
    message: str,
    pattern_type: str,
    field_path: str | None = None,
    pattern_value: str | None = None,
    context: dict[str, Any] | None = None,
) -> None:
    """Build the error from the matched pattern's details.

    Args:
        message: Error text, forwarded to the parent initializer.
        pattern_type: Name of the pattern that matched; also passed to the
            parent initializer as its second positional argument.
        field_path: Location of the offending value, if known.
        pattern_value: The offending value, if known.
        context: Optional extra details, forwarded to the parent initializer.
    """
    # Only a truthy value is reported to the parent as a detected pattern.
    detected: list[str] = []
    if pattern_value:
        detected.append(pattern_value)
    super().__init__(message, pattern_type, detected, context)

    self.pattern_value = pattern_value
    self.field_path = field_path
    self.pattern_type = pattern_type

SanitizerError

Bases: Exception

Base exception for all sanitizer-related errors.

Source code in src/sanitongo/exceptions.py
class SanitizerError(Exception):
    """Root of the sanitizer exception hierarchy."""

    def __init__(self, message: str, context: dict[str, Any] | None = None) -> None:
        """Store the message and an optional context mapping.

        Args:
            message: Error text, passed to ``Exception``.
            context: Optional extra details. A falsy value (``None`` or an
                empty mapping) is replaced by a new empty dict.
        """
        super().__init__(message)
        if context:
            self.context = context
        else:
            self.context = {}

__init__(message, context=None)

Initialize the exception with message and optional context.

Source code in src/sanitongo/exceptions.py
def __init__(self, message: str, context: dict[str, Any] | None = None) -> None:
    """Store the message and an optional context mapping.

    Args:
        message: Error text, passed to the parent initializer.
        context: Optional extra details. A falsy value (``None`` or an
            empty mapping) is replaced by a new empty dict.
    """
    super().__init__(message)
    if context:
        self.context = context
    else:
        self.context = {}

SchemaViolationError

Bases: SanitizerError

Raised when query violates the defined schema.

Source code in src/sanitongo/exceptions.py
class SchemaViolationError(SanitizerError):
    """Error signalling that a query does not conform to the defined schema."""

    def __init__(
        self,
        message: str,
        field_path: str,
        schema_rule: str | None = None,
        context: dict[str, Any] | None = None,
    ) -> None:
        """Record where the schema was violated and by which rule.

        Args:
            message: Error text, forwarded to ``SanitizerError``.
            field_path: Location of the offending field.
            schema_rule: The rule that was violated, if known.
            context: Optional extra details, forwarded to ``SanitizerError``.
        """
        self.schema_rule = schema_rule
        self.field_path = field_path
        super().__init__(message, context)

__init__(message, field_path, schema_rule=None, context=None)

Initialize schema violation error.

Source code in src/sanitongo/exceptions.py
def __init__(
    self,
    message: str,
    field_path: str,
    schema_rule: str | None = None,
    context: dict[str, Any] | None = None,
) -> None:
    """Record where the schema was violated and by which rule.

    Args:
        message: Error text, forwarded to the parent initializer.
        field_path: Location of the offending field.
        schema_rule: The rule that was violated, if known.
        context: Optional extra details, forwarded to the parent initializer.
    """
    self.schema_rule = schema_rule
    self.field_path = field_path
    super().__init__(message, context)

SecurityError

Bases: SanitizerError

Raised when potentially malicious content is detected.

Source code in src/sanitongo/exceptions.py
class SecurityError(SanitizerError):
    """Error signalling that potentially malicious content was found."""

    def __init__(
        self,
        message: str,
        threat_type: str,
        detected_patterns: list[str] | None = None,
        context: dict[str, Any] | None = None,
    ) -> None:
        """Record the kind of threat and what triggered it.

        Args:
            message: Error text, forwarded to ``SanitizerError``.
            threat_type: Label for the category of threat.
            detected_patterns: Items that triggered the error. A falsy value
                (``None`` or an empty list) is replaced by a new empty list.
            context: Optional extra details, forwarded to ``SanitizerError``.
        """
        super().__init__(message, context)
        self.threat_type = threat_type
        if detected_patterns:
            self.detected_patterns = detected_patterns
        else:
            self.detected_patterns = []

__init__(message, threat_type, detected_patterns=None, context=None)

Initialize security error with threat information.

Source code in src/sanitongo/exceptions.py
def __init__(
    self,
    message: str,
    threat_type: str,
    detected_patterns: list[str] | None = None,
    context: dict[str, Any] | None = None,
) -> None:
    """Record the kind of threat and what triggered it.

    Args:
        message: Error text, forwarded to the parent initializer.
        threat_type: Label for the category of threat.
        detected_patterns: Items that triggered the error. A falsy value
            (``None`` or an empty list) is replaced by a new empty list.
        context: Optional extra details, forwarded to the parent initializer.
    """
    super().__init__(message, context)
    self.threat_type = threat_type
    if detected_patterns:
        self.detected_patterns = detected_patterns
    else:
        self.detected_patterns = []

ValidationError

Bases: SanitizerError

Raised when input validation fails.

Source code in src/sanitongo/exceptions.py
class ValidationError(SanitizerError):
    """Error signalling that input failed validation."""

    def __init__(
        self,
        message: str,
        field_path: str | None = None,
        invalid_value: Any | None = None,
        context: dict[str, Any] | None = None,
    ) -> None:
        """Record which field failed validation and with what value.

        Args:
            message: Error text, forwarded to ``SanitizerError``.
            field_path: Location of the offending field, if known.
            invalid_value: The rejected value, if available.
            context: Optional extra details, forwarded to ``SanitizerError``.
        """
        self.invalid_value = invalid_value
        self.field_path = field_path
        super().__init__(message, context)

__init__(message, field_path=None, invalid_value=None, context=None)

Initialize validation error with field information.

Source code in src/sanitongo/exceptions.py
def __init__(
    self,
    message: str,
    field_path: str | None = None,
    invalid_value: Any | None = None,
    context: dict[str, Any] | None = None,
) -> None:
    """Record which field failed validation and with what value.

    Args:
        message: Error text, forwarded to the parent initializer.
        field_path: Location of the offending field, if known.
        invalid_value: The rejected value, if available.
        context: Optional extra details, forwarded to the parent initializer.
    """
    self.invalid_value = invalid_value
    self.field_path = field_path
    super().__init__(message, context)

ComplexityLimiter

Layer 5: Limits query complexity to prevent DoS attacks.

Source code in src/sanitongo/layers.py
class ComplexityLimiter:
    """Layer 5: Limits query complexity to prevent DoS attacks.

    Four limits are enforced: nesting depth, total dictionary key count,
    length of any list, and length of any string value.
    """

    def __init__(
        self,
        max_depth: int = 10,
        max_keys: int = 100,
        max_array_length: int = 1000,
        max_string_length: int = 10000,
    ) -> None:
        """Initialize complexity limiter.

        Args:
            max_depth: Largest nesting depth accepted by ``validate``.
            max_keys: Largest total number of dictionary keys accepted.
            max_array_length: Largest accepted length of any single list.
            max_string_length: Largest accepted length of any string value.
        """
        self.max_depth = max_depth
        self.max_keys = max_keys
        self.max_array_length = max_array_length
        self.max_string_length = max_string_length

    def validate(self, query: dict[str, Any]) -> LayerResult:
        """Check query complexity limits.

        Depth is checked first, then key count, then list and string
        lengths; the first exceeded limit raises.

        Args:
            query: Query to inspect. It is not modified.

        Returns:
            A successful ``LayerResult`` carrying ``query`` unchanged.

        Raises:
            ComplexityError: If any of the configured limits is exceeded.
        """
        warnings: list[str] = []

        # Check depth first: the recursive helpers below are then only run on
        # input whose nesting is already known to be within ``max_depth``.
        depth = self._calculate_depth(query)
        if depth > self.max_depth:
            raise ComplexityError(
                f"Query depth exceeds limit: {depth} > {self.max_depth}",
                limit_type="depth",
                current_value=depth,
                max_allowed=self.max_depth,
            )

        # Check key count
        key_count = self._count_keys(query)
        if key_count > self.max_keys:
            raise ComplexityError(
                f"Query key count exceeds limit: {key_count} > {self.max_keys}",
                limit_type="keys",
                current_value=key_count,
                max_allowed=self.max_keys,
            )

        # Check arrays and strings
        self._check_arrays_and_strings(query, "")

        return LayerResult(success=True, modified_query=query, warnings=warnings)

    def _calculate_depth(self, obj: Any, current_depth: int = 0) -> int:
        """Calculate maximum nesting depth.

        Every value reachable through dict values and list items is visited;
        the result is the largest level seen, where ``obj`` itself sits at
        ``current_depth`` and each dict/list step adds one.

        Args:
            obj: Root value to measure.
            current_depth: Level assigned to ``obj``.

        Returns:
            The deepest level reached (``current_depth`` for a scalar or an
            empty container).
        """
        deepest = current_depth
        # An explicit stack is used instead of recursion so that a hostile,
        # extremely deep query is measured and rejected by the depth limit
        # rather than crashing this check with RecursionError.
        pending: list[tuple[Any, int]] = [(obj, current_depth)]
        while pending:
            node, level = pending.pop()
            if level > deepest:
                deepest = level
            if isinstance(node, dict):
                pending.extend((child, level + 1) for child in node.values())
            elif isinstance(node, list):
                pending.extend((child, level + 1) for child in node)
        return deepest

    def _count_keys(self, obj: Any) -> int:
        """Count total number of dictionary keys.

        Keys of every dict reachable through dict values and list items are
        counted; lists contribute only through the dicts they contain.
        """
        if isinstance(obj, dict):
            return len(obj) + sum(self._count_keys(value) for value in obj.values())
        if isinstance(obj, list):
            return sum(self._count_keys(item) for item in obj)
        return 0

    def _check_arrays_and_strings(self, obj: Any, path: str) -> None:
        """Check array lengths and string lengths.

        Args:
            obj: Value to inspect; dicts and lists are walked recursively.
            path: Dotted/indexed location of ``obj``, used in error messages.

        Raises:
            ComplexityError: On the first list longer than
                ``max_array_length`` or string longer than
                ``max_string_length``.
        """
        if isinstance(obj, list):
            if len(obj) > self.max_array_length:
                raise ComplexityError(
                    f"Array at '{path}' exceeds length limit: {len(obj)} > {self.max_array_length}",
                    limit_type="array_length",
                    current_value=len(obj),
                    max_allowed=self.max_array_length,
                )
            for i, item in enumerate(obj):
                self._check_arrays_and_strings(item, f"{path}[{i}]")
        elif isinstance(obj, str):
            if len(obj) > self.max_string_length:
                raise ComplexityError(
                    f"String at '{path}' exceeds length limit: {len(obj)} > {self.max_string_length}",
                    limit_type="string_length",
                    current_value=len(obj),
                    max_allowed=self.max_string_length,
                )
        elif isinstance(obj, dict):
            for key, value in obj.items():
                # Top-level keys have no prefix; nested keys are dot-joined.
                current_path = f"{path}.{key}" if path else key
                self._check_arrays_and_strings(value, current_path)

__init__(max_depth=10, max_keys=100, max_array_length=1000, max_string_length=10000)

Initialize complexity limiter.

Source code in src/sanitongo/layers.py
def __init__(
    self,
    max_depth: int = 10,
    max_keys: int = 100,
    max_array_length: int = 1000,
    max_string_length: int = 10000,
) -> None:
    """Store the four complexity limits on the instance.

    Args:
        max_depth: Limit stored as ``self.max_depth``.
        max_keys: Limit stored as ``self.max_keys``.
        max_array_length: Limit stored as ``self.max_array_length``.
        max_string_length: Limit stored as ``self.max_string_length``.
    """
    # Size limits on individual values.
    self.max_string_length = max_string_length
    self.max_array_length = max_array_length
    # Structural limits on the query as a whole.
    self.max_keys = max_keys
    self.max_depth = max_depth

validate(query)

Check query complexity limits.

Source code in src/sanitongo/layers.py
def validate(self, query: dict[str, Any]) -> LayerResult:
    """Reject queries that exceed any configured complexity limit.

    Depth is checked first, then key count, then list and string lengths.

    Args:
        query: Query to inspect. It is not modified.

    Returns:
        A successful ``LayerResult`` carrying ``query`` unchanged.

    Raises:
        ComplexityError: If the depth or key-count limit is exceeded here,
            or propagated from ``_check_arrays_and_strings``.
    """
    # Nesting depth.
    measured_depth = self._calculate_depth(query)
    if not measured_depth <= self.max_depth:
        depth_message = f"Query depth exceeds limit: {measured_depth} > {self.max_depth}"
        raise ComplexityError(
            depth_message,
            limit_type="depth",
            current_value=measured_depth,
            max_allowed=self.max_depth,
        )

    # Total number of dictionary keys.
    total_keys = self._count_keys(query)
    if not total_keys <= self.max_keys:
        keys_message = f"Query key count exceeds limit: {total_keys} > {self.max_keys}"
        raise ComplexityError(
            keys_message,
            limit_type="keys",
            current_value=total_keys,
            max_allowed=self.max_keys,
        )

    # Per-value list and string lengths.
    self._check_arrays_and_strings(query, "")

    no_warnings = []
    return LayerResult(success=True, modified_query=query, warnings=no_warnings)

OperatorFilter

Layer 3: Filters and validates MongoDB operators.

Source code in src/sanitongo/layers.py
class OperatorFilter:
    """Layer 3: Filters and validates MongoDB operators.

    Keys starting with ``$`` are treated as operators. Operators in the
    dangerous set are removed (or raise in strict mode); operators outside
    the allowed set are removed with a warning.
    """

    def __init__(
        self,
        allowed_operators: set[str] | None = None,
        dangerous_operators: set[str] | None = None,
        strict_mode: bool = True,
    ) -> None:
        """Initialize operator filter.

        Args:
            allowed_operators: Operators that may stay in a query. ``None``
                selects the built-in safe set; an explicit empty set is
                honoured and allows no operators at all.
            dangerous_operators: Operators treated as dangerous. ``None``
                selects the built-in dangerous set; an explicit empty set is
                honoured.
            strict_mode: When true, a dangerous operator raises
                ``SecurityError`` instead of being silently removed.
        """
        # Compare against None rather than relying on truthiness: an empty
        # set is a legitimate configuration and must not fall back to the
        # defaults.
        self.allowed_operators = (
            self._get_safe_operators()
            if allowed_operators is None
            else allowed_operators
        )
        self.dangerous_operators = (
            self._get_dangerous_operators()
            if dangerous_operators is None
            else dangerous_operators
        )
        self.strict_mode = strict_mode

    def validate(self, query: dict[str, Any]) -> LayerResult:
        """Filter MongoDB operators from the query.

        Args:
            query: Query to filter. It is not modified; a filtered copy is
                built.

        Returns:
            A successful ``LayerResult`` with the filtered copy, the
            collected warnings and the removed items keyed by path.

        Raises:
            SecurityError: In strict mode, when a dangerous operator is found.
        """
        modified_query: dict[str, Any] = {}
        removed_items: dict[str, Any] = {}
        warnings: list[str] = []

        self._process_dict(query, modified_query, removed_items, warnings, "")

        return LayerResult(
            success=True,
            modified_query=modified_query,
            warnings=warnings,
            removed_items=removed_items,
        )

    def _process_dict(
        self,
        source: dict[str, Any],
        target: dict[str, Any],
        removed: dict[str, Any],
        warnings: list[str],
        path: str,
    ) -> None:
        """Process dictionary removing dangerous operators.

        Args:
            source: Dictionary to read from.
            target: Dictionary that receives the kept entries.
            removed: Receives each dropped value, keyed by its path.
            warnings: Receives one message per dropped operator.
            path: Location of ``source`` within the whole query.

        Raises:
            SecurityError: In strict mode, when a dangerous operator is found.
        """
        for key, value in source.items():
            current_path = f"{path}.{key}" if path else key

            if key.startswith("$"):
                if key in self.dangerous_operators:
                    # Recorded before raising so the caller-supplied
                    # collections reflect the finding either way.
                    removed[current_path] = value
                    warnings.append(f"Removed dangerous operator: {key}")
                    if self.strict_mode:
                        raise SecurityError(
                            f"Dangerous operator detected: {key}",
                            threat_type="dangerous_operator",
                            detected_patterns=[key],
                        )
                    continue
                elif key not in self.allowed_operators:
                    removed[current_path] = value
                    warnings.append(f"Removed unknown operator: {key}")
                    continue

            # Kept entry: copy it, rebuilding nested containers so operators
            # inside them are filtered as well.
            if isinstance(value, dict):
                target[key] = {}
                self._process_dict(value, target[key], removed, warnings, current_path)
            elif isinstance(value, list):
                target[key] = []
                self._process_list(value, target[key], removed, warnings, current_path)
            else:
                target[key] = value

    def _process_list(
        self,
        source: list[Any],
        target: list[Any],
        removed: dict[str, Any],
        warnings: list[str],
        path: str,
    ) -> None:
        """Process list items recursively.

        Dict and list items are rebuilt through the filtering helpers; every
        other item is appended to ``target`` unchanged.
        """
        for i, item in enumerate(source):
            item_path = f"{path}[{i}]"
            if isinstance(item, dict):
                processed_item: dict[str, Any] = {}
                self._process_dict(item, processed_item, removed, warnings, item_path)
                target.append(processed_item)
            elif isinstance(item, list):
                processed_list: list[Any] = []
                self._process_list(item, processed_list, removed, warnings, item_path)
                target.append(processed_list)
            else:
                target.append(item)

    def _get_safe_operators(self) -> set[str]:
        """Get set of safe MongoDB operators."""
        return {
            # Comparison
            "$eq",
            "$ne",
            "$gt",
            "$gte",
            "$lt",
            "$lte",
            "$in",
            "$nin",
            # Logical
            "$and",
            "$or",
            "$not",
            "$nor",
            # Element
            "$exists",
            "$type",
            # Array (limited set)
            "$all",
            "$size",
            # Text search (controlled)
            "$text",
        }

    def _get_dangerous_operators(self) -> set[str]:
        """Get set of dangerous MongoDB operators."""
        return {
            # JavaScript execution
            "$where",
            "$function",
            # Regex with potential for ReDoS
            "$regex",
            # Aggregation operators in wrong context
            "$expr",
            "$jsonSchema",
            # Modification operators
            "$set",
            "$unset",
            "$inc",
            "$push",
            "$pull",
            # Advanced array operators
            "$elemMatch",
            "$slice",
            "$position",
        }

__init__(allowed_operators=None, dangerous_operators=None, strict_mode=True)

Initialize operator filter.

Source code in src/sanitongo/layers.py
def __init__(
    self,
    allowed_operators: set[str] | None = None,
    dangerous_operators: set[str] | None = None,
    strict_mode: bool = True,
) -> None:
    """Initialize operator filter.

    Args:
        allowed_operators: Operators that may stay in a query. ``None``
            selects the set from ``_get_safe_operators``; an explicit empty
            set is honoured and allows no operators at all.
        dangerous_operators: Operators treated as dangerous. ``None``
            selects the set from ``_get_dangerous_operators``; an explicit
            empty set is honoured.
        strict_mode: Stored as ``self.strict_mode``.
    """
    # Compare against None rather than relying on truthiness: an empty set
    # is a legitimate configuration and must not fall back to the defaults.
    self.allowed_operators = (
        self._get_safe_operators()
        if allowed_operators is None
        else allowed_operators
    )
    self.dangerous_operators = (
        self._get_dangerous_operators()
        if dangerous_operators is None
        else dangerous_operators
    )
    self.strict_mode = strict_mode

validate(query)

Filter MongoDB operators from the query.

Source code in src/sanitongo/layers.py
def validate(self, query: dict[str, Any]) -> LayerResult:
    """Build a filtered copy of the query with disallowed operators removed.

    Args:
        query: Query to filter; it is read by ``_process_dict`` and the
            kept entries are written into a new dictionary.

    Returns:
        A successful ``LayerResult`` holding the filtered dictionary, the
        warnings gathered during filtering and the removed items.
    """
    # These three collections are filled in place by ``_process_dict``.
    cleaned = {}
    dropped = {}
    notes = []
    self._process_dict(query, cleaned, dropped, notes, "")

    outcome = LayerResult(
        success=True,
        modified_query=cleaned,
        warnings=notes,
        removed_items=dropped,
    )
    return outcome

PatternValidator

Layer 4: Validates patterns in string values.

Source code in src/sanitongo/layers.py
class PatternValidator:
    """Layer 4: scans the strings of a query for dangerous patterns."""

    def __init__(
        self,
        custom_patterns: dict[str, Pattern[str]] | None = None,
        fail_on_dangerous_patterns: bool = True,
    ) -> None:
        """Set up the pattern table and the failure policy.

        Args:
            custom_patterns: Extra named patterns merged over the built-in
                ones; an entry with an existing name replaces it.
            fail_on_dangerous_patterns: When true, the first match raises
                ``PatternError``; otherwise matches only produce warnings.
        """
        self.fail_on_dangerous_patterns = fail_on_dangerous_patterns
        table = self._get_dangerous_patterns()
        if custom_patterns:
            table.update(custom_patterns)
        self.dangerous_patterns = table

    def validate(self, query: dict[str, Any]) -> LayerResult:
        """Scan the query and report every dangerous pattern found.

        Args:
            query: Query to scan. It is returned unmodified.

        Returns:
            A successful ``LayerResult`` with one warning per match.

        Raises:
            PatternError: On the first match when
                ``fail_on_dangerous_patterns`` is true.
        """
        found: list[str] = []
        self._check_patterns(query, found, "")
        return LayerResult(success=True, modified_query=query, warnings=found)

    def _check_patterns(self, obj: Any, warnings: list[str], path: str) -> None:
        """Walk ``obj`` and test every string against the pattern table.

        Args:
            obj: Value to inspect; dicts and lists are walked recursively.
            warnings: Receives one message per match.
            path: Location of ``obj``, used in messages and errors.
        """
        if isinstance(obj, dict):
            for key, value in obj.items():
                child_path = key if not path else f"{path}.{key}"
                # Keys are scanned too, except operator keys starting with "$".
                if not key.startswith("$"):
                    self._check_patterns(key, warnings, f"{child_path}#key")
                self._check_patterns(value, warnings, child_path)
            return

        if isinstance(obj, list):
            for index, element in enumerate(obj):
                self._check_patterns(element, warnings, f"{path}[{index}]")
            return

        if not isinstance(obj, str):
            return

        for label, regex in self.dangerous_patterns.items():
            if not regex.search(obj):
                continue
            warnings.append(f"Dangerous pattern '{label}' detected at '{path}'")
            if self.fail_on_dangerous_patterns:
                raise PatternError(
                    f"Dangerous pattern detected: {label}",
                    pattern_type=label,
                    field_path=path,
                    pattern_value=obj,
                )

    def _get_dangerous_patterns(self) -> dict[str, Pattern[str]]:
        """Return a fresh table of the built-in named patterns."""
        table: dict[str, Pattern[str]] = {}
        table["javascript"] = re.compile(
            r"(?i)(function\s*\(|eval\s*\(|setTimeout|setInterval)", re.IGNORECASE
        )
        table["script_tags"] = re.compile(
            r"<script[^>]*>.*?</script>", re.IGNORECASE | re.DOTALL
        )
        table["sql_injection"] = re.compile(
            r"(?i)(union\s+select|drop\s+table|insert\s+into)", re.IGNORECASE
        )
        table["command_injection"] = re.compile(r"[;&|`$()]", re.MULTILINE)
        table["prototype_pollution"] = re.compile(
            r"__proto__|constructor|prototype", re.IGNORECASE
        )
        table["redos_suspicious"] = re.compile(
            r"(\+|\*|\{[\d,]*\})\+|\*\*|\+\+", re.MULTILINE
        )
        return table

__init__(custom_patterns=None, fail_on_dangerous_patterns=True)

Initialize pattern validator.

Source code in src/sanitongo/layers.py
def __init__(
    self,
    custom_patterns: dict[str, Pattern[str]] | None = None,
    fail_on_dangerous_patterns: bool = True,
) -> None:
    """Set up the pattern table and the failure policy.

    Args:
        custom_patterns: Extra named patterns merged over the table
            returned by ``_get_dangerous_patterns``; an entry with an
            existing name replaces it.
        fail_on_dangerous_patterns: Stored as
            ``self.fail_on_dangerous_patterns``.
    """
    table = self._get_dangerous_patterns()
    if custom_patterns:
        table.update(custom_patterns)
    self.dangerous_patterns = table
    self.fail_on_dangerous_patterns = fail_on_dangerous_patterns

validate(query)

Validate string patterns in the query.

Source code in src/sanitongo/layers.py
def validate(self, query: dict[str, Any]) -> LayerResult:
    """Scan the query via ``_check_patterns`` and report what was found.

    Args:
        query: Query to scan. It is returned unmodified.

    Returns:
        A successful ``LayerResult`` carrying ``query`` and the warnings
        collected during the scan.
    """
    # ``_check_patterns`` appends its messages to this list in place.
    found = []
    self._check_patterns(query, found, "")
    outcome = LayerResult(success=True, modified_query=query, warnings=found)
    return outcome

SchemaEnforcer

Layer 2: Enforces field schema and validation rules.

Source code in src/sanitongo/layers.py
class SchemaEnforcer:
    """Layer 2: checks a query against an optional schema validator."""

    def __init__(
        self,
        schema_validator: SchemaValidator | None = None,
        fail_on_violation: bool = True,
    ) -> None:
        """Store the validator and the violation policy.

        Args:
            schema_validator: Object whose ``validate_query`` is called by
                ``validate``; when falsy, no schema check is performed.
            fail_on_violation: When true, a failed check raises; otherwise
                it is downgraded to a warning.
        """
        self.fail_on_violation = fail_on_violation
        self.schema_validator = schema_validator

    def validate(self, query: dict[str, Any]) -> LayerResult:
        """Run the schema check, if a validator is configured.

        Args:
            query: Query to check. It is returned unmodified.

        Returns:
            A successful ``LayerResult``; it carries a warning when no
            schema is configured or when a violation was tolerated.

        Raises:
            ValidationError: If the validator raises and
                ``fail_on_violation`` is true.
        """
        validator = self.schema_validator
        if not validator:
            # Without a schema every field is accepted.
            return LayerResult(
                success=True, modified_query=query, warnings=["No schema defined"]
            )

        try:
            validator.validate_query(query)
            return LayerResult(success=True, modified_query=query)
        except Exception as exc:
            if not self.fail_on_violation:
                # Lenient mode: report the problem but let the query through.
                return LayerResult(
                    success=True,
                    modified_query=query,
                    warnings=[f"Schema validation warning: {exc}"],
                )
            raise ValidationError(f"Schema validation failed: {exc}") from exc

__init__(schema_validator=None, fail_on_violation=True)

Initialize schema enforcer.

Source code in src/sanitongo/layers.py
def __init__(
    self,
    schema_validator: SchemaValidator | None = None,
    fail_on_violation: bool = True,
) -> None:
    """Store the validator and the violation policy.

    Args:
        schema_validator: Stored as ``self.schema_validator``.
        fail_on_violation: Stored as ``self.fail_on_violation``.
    """
    self.fail_on_violation = fail_on_violation
    self.schema_validator = schema_validator

validate(query)

Enforce schema rules on the query.

Source code in src/sanitongo/layers.py
def validate(self, query: dict[str, Any]) -> LayerResult:
    """Run the schema check, if a validator is configured.

    Args:
        query: Query to check. It is returned unmodified.

    Returns:
        A successful ``LayerResult``; it carries a warning when no schema
        is configured or when a violation was tolerated.

    Raises:
        ValidationError: If the validator raises and
            ``self.fail_on_violation`` is true.
    """
    validator = self.schema_validator
    if not validator:
        # Without a schema every field is accepted.
        return LayerResult(
            success=True, modified_query=query, warnings=["No schema defined"]
        )

    try:
        validator.validate_query(query)
        return LayerResult(success=True, modified_query=query)
    except Exception as exc:
        if not self.fail_on_violation:
            # Lenient mode: report the problem but let the query through.
            return LayerResult(
                success=True,
                modified_query=query,
                warnings=[f"Schema validation warning: {exc}"],
            )
        raise ValidationError(f"Schema validation failed: {exc}") from exc

TypeValidator

Layer 1: Validates input types and basic structure.

Source code in src/sanitongo/layers.py
class TypeValidator:
    """Layer 1: checks that a query is a dict built from supported types."""

    def __init__(self, strict_mode: bool = True) -> None:
        """Store the strictness policy.

        Args:
            strict_mode: When true, a non-dict query raises; otherwise it
                yields an unsuccessful result.
        """
        self.strict_mode = strict_mode

    def validate(self, query: Any) -> LayerResult:
        """Check the query's top-level type and everything nested in it.

        Args:
            query: Candidate query of any type.

        Returns:
            ``LayerResult(success=False)`` for a non-dict query in
            non-strict mode; otherwise a successful result carrying the
            query unchanged, with a warning if it is empty.

        Raises:
            ValidationError: For a non-dict query in strict mode, or for
                any invalid nested key or value.
        """
        if not isinstance(query, dict):
            if not self.strict_mode:
                return LayerResult(success=False)
            raise ValidationError(
                f"Query must be a dictionary, got {type(query).__name__}"
            )

        # An empty query is allowed but flagged.
        notes = [] if query else ["Empty query detected"]

        try:
            self._validate_nested_types(query, "")
        except ValidationError:
            raise
        except Exception as exc:
            # Any other failure is reported as a validation problem.
            raise ValidationError(f"Type validation failed: {exc}") from exc

        return LayerResult(success=True, modified_query=query, warnings=notes)

    def _validate_nested_types(self, obj: Any, path: str) -> None:
        """Walk ``obj`` and reject unsupported key or value types.

        Args:
            obj: Value to inspect; dicts and lists are walked recursively.
            path: Location of ``obj``, used in error messages.

        Raises:
            ValidationError: For a non-string dict key, or a value that is
                not ``None``, ``str``, ``int``, ``float``, ``bool``, a dict
                or a list.
        """
        if isinstance(obj, dict):
            for key, value in obj.items():
                if not isinstance(key, str):
                    raise ValidationError(
                        f"Dictionary key at '{path}' must be string, got {type(key).__name__}"
                    )
                child_path = f"{path}.{key}" if path else key
                self._validate_nested_types(value, child_path)
            return

        if isinstance(obj, list):
            for index, element in enumerate(obj):
                self._validate_nested_types(element, f"{path}[{index}]")
            return

        # Scalars: only None and the basic types are accepted.
        if obj is None or isinstance(obj, (str, int, float, bool)):
            return
        raise ValidationError(f"Unsupported type at '{path}': {type(obj).__name__}")

__init__(strict_mode=True)

Initialize type validator.

Source code in src/sanitongo/layers.py
def __init__(self, strict_mode: bool = True) -> None:
    """Store the strictness policy.

    Args:
        strict_mode: Stored as ``self.strict_mode``.
    """
    self.strict_mode = strict_mode

validate(query)

Validate query types and structure.

Source code in src/sanitongo/layers.py
def validate(self, query: Any) -> LayerResult:
    """Check the query's top-level type and everything nested in it.

    Args:
        query: Candidate query of any type.

    Returns:
        ``LayerResult(success=False)`` for a non-dict query when
        ``self.strict_mode`` is false; otherwise a successful result
        carrying the query unchanged, with a warning if it is empty.

    Raises:
        ValidationError: For a non-dict query in strict mode, or when the
            nested check fails for any reason.
    """
    if not isinstance(query, dict):
        if not self.strict_mode:
            return LayerResult(success=False)
        raise ValidationError(
            f"Query must be a dictionary, got {type(query).__name__}"
        )

    # An empty query is allowed but flagged.
    notes = [] if query else ["Empty query detected"]

    try:
        self._validate_nested_types(query, "")
    except ValidationError:
        raise
    except Exception as exc:
        # Any other failure is reported as a validation problem.
        raise ValidationError(f"Type validation failed: {exc}") from exc

    return LayerResult(success=True, modified_query=query, warnings=notes)

MongoSanitizer

Main MongoDB query sanitizer with layered protection.

Implements a five-layer protection system: (1) Type Validation, (2) Schema Enforcement, (3) Operator Filtering, (4) Pattern Validation, and (5) Complexity Limiting.

Source code in src/sanitongo/sanitizer.py
class MongoSanitizer:
    """
    Main MongoDB query sanitizer with layered protection.

    Implements a five-layer protection system:
    1. Type Validation
    2. Schema Enforcement
    3. Operator Filtering
    4. Pattern Validation
    5. Complexity Limiting
    """

    def __init__(self, config: SanitizerConfig | None = None) -> None:
        """Initialize the sanitizer with configuration.

        Args:
            config: Settings to use; a default ``SanitizerConfig()`` is
                created when none is given.
        """
        self.config = config or SanitizerConfig()
        self.logger = self._setup_logging()

        # Initialize protection layers
        self._init_layers()

    def sanitize(self, query: Any) -> SanitizationReport:
        """
        Sanitize a MongoDB query through all protection layers.

        Args:
            query: The MongoDB query to sanitize

        Returns:
            SanitizationReport with detailed results. When a layer raises
            and the error is not re-raised, the report has ``success=False``
            and ``error`` set to the exception.

        Raises:
            Exception: Whatever a layer raised, when
                ``_should_reraise_error`` decides it must propagate.
        """
        import time

        # perf_counter is monotonic, so the elapsed time cannot go negative
        # or jump when the wall clock is adjusted.
        start_time = time.perf_counter()

        # Create initial report
        report = SanitizationReport(
            original_query=query.copy() if isinstance(query, dict) else query,
            sanitized_query=query,
            success=False,
        )

        try:
            current_query = query

            # NOTE(review): result.success is not inspected below; a layer
            # that reports failure without raising does not fail the
            # report - confirm this is intended.

            # Layer 1: Type Validation
            result = self._run_layer(
                "Type Validation", self.type_validator, current_query, report
            )
            # Compare against None explicitly: an empty dict is a valid
            # layer output and must not fall back to the previous query.
            if result.modified_query is not None:
                current_query = result.modified_query

            # Layer 2: Schema Enforcement
            if isinstance(current_query, dict):
                result = self._run_layer(
                    "Schema Enforcement", self.schema_enforcer, current_query, report
                )
                if result.modified_query is not None:
                    current_query = result.modified_query

            # Layer 3: Operator Filtering
            if isinstance(current_query, dict):
                result = self._run_layer(
                    "Operator Filtering", self.operator_filter, current_query, report
                )
                if result.modified_query is not None:
                    current_query = result.modified_query
                if result.removed_items:
                    report.removed_items.update(result.removed_items)

            # Layer 4: Pattern Validation
            if isinstance(current_query, dict):
                result = self._run_layer(
                    "Pattern Validation", self.pattern_validator, current_query, report
                )
                if result.modified_query is not None:
                    current_query = result.modified_query

            # Layer 5: Complexity Limiting
            if isinstance(current_query, dict):
                result = self._run_layer(
                    "Complexity Limiting",
                    self.complexity_limiter,
                    current_query,
                    report,
                )
                if result.modified_query is not None:
                    current_query = result.modified_query

            # Finalize report
            report.sanitized_query = current_query
            report.success = True

            # Log results
            if self.config.enable_logging:
                self._log_sanitization_results(report)

        except Exception as e:
            report.error = e
            report.success = False
            if self.config.enable_logging:
                self.logger.error(f"Sanitization failed: {e}")

            # Re-raise based on configuration
            if self._should_reraise_error(e):
                raise

        # Performance metrics
        report.performance_metrics = {
            "processing_time_ms": round((time.perf_counter() - start_time) * 1000, 2),
            "layers_processed": len(report.layers_processed),
        }

        return report

    def sanitize_query(self, query: dict[str, Any]) -> dict[str, Any]:
        """
        Sanitize a query and return only the cleaned query.

        This is a convenience method that returns just the sanitized query
        without the full report.

        Raises:
            SanitizerError: If the report comes back unsuccessful.
        """
        report = self.sanitize(query)
        if not report.success:
            raise SanitizerError(f"Query sanitization failed: {report.error}")
        return report.sanitized_query

    def is_query_safe(self, query: Any) -> bool:
        """
        Check whether a query passes sanitization.

        Returns True if the query passes all validation layers and the
        report lists no security issues; any exception counts as unsafe.
        """
        try:
            report = self.sanitize(query)
            return report.success and not report.has_security_issues()
        except Exception:
            return False

    def _init_layers(self) -> None:
        """Initialize all protection layers from the current config."""
        self.type_validator = TypeValidator(strict_mode=self.config.strict_types)

        self.schema_enforcer = SchemaEnforcer(
            schema_validator=self.config.schema_validator,
            fail_on_violation=self.config.fail_on_schema_violation,
        )

        self.operator_filter = OperatorFilter(
            allowed_operators=self.config.allowed_operators,
            dangerous_operators=self.config.dangerous_operators,
            strict_mode=self.config.strict_operators,
        )

        if self.config.enable_pattern_validation:
            custom_patterns = {}
            if self.config.custom_dangerous_patterns:
                import re

                # Compile once here so the validator works with ready
                # pattern objects.
                custom_patterns = {
                    name: re.compile(pattern)
                    for name, pattern in self.config.custom_dangerous_patterns.items()
                }
            self.pattern_validator = PatternValidator(
                custom_patterns=custom_patterns,
                fail_on_dangerous_patterns=self.config.fail_on_dangerous_patterns,
            )
        else:
            # _run_layer treats a None layer as a pass-through.
            self.pattern_validator = None

        self.complexity_limiter = ComplexityLimiter(
            max_depth=self.config.max_depth,
            max_keys=self.config.max_keys,
            max_array_length=self.config.max_array_length,
            max_string_length=self.config.max_string_length,
        )

    def _run_layer(
        self,
        layer_name: str,
        layer_instance: Any,
        query: Any,
        report: SanitizationReport,
    ) -> LayerResult:
        """Run a single protection layer and update the report.

        Args:
            layer_name: Label recorded in ``report.layers_processed`` and
                used as the log prefix.
            layer_instance: Object exposing ``validate(query)``, or None to
                skip the layer.
            query: Query handed to the layer.
            report: Report that collects the layer name and warnings.

        Returns:
            The layer's result, or a successful pass-through result when
            the layer is disabled.
        """
        if layer_instance is None:
            return LayerResult(success=True, modified_query=query)

        try:
            result = layer_instance.validate(query)
            report.layers_processed.append(layer_name)
            report.warnings.extend(result.warnings)

            if self.config.enable_logging and result.warnings:
                for warning in result.warnings:
                    self.logger.warning(f"{layer_name}: {warning}")

            return result

        except Exception as e:
            if self.config.enable_logging:
                self.logger.error(f"{layer_name} failed: {e}")
            raise

    def _should_reraise_error(self, error: Exception) -> bool:
        """Determine if an error should be re-raised based on config.

        Each known error type maps to its own ``fail_on_*`` option; any
        other exception is always re-raised.
        """
        from .exceptions import (
            ComplexityError,
            PatternError,
            SchemaViolationError,
            SecurityError,
        )

        if isinstance(error, SchemaViolationError):
            return self.config.fail_on_schema_violation
        # PatternError derives from SecurityError, so it must be tested
        # first or its dedicated option would never be consulted.
        elif isinstance(error, PatternError):
            return self.config.fail_on_dangerous_patterns
        elif isinstance(error, SecurityError):
            return self.config.fail_on_dangerous_operators
        elif isinstance(error, ComplexityError):
            return self.config.fail_on_complexity_exceeded

        return True  # Re-raise unexpected errors

    def _setup_logging(self) -> logging.Logger:
        """Set up logging for the sanitizer.

        Returns the shared "sanitongo" logger; a stream handler and level
        are attached only when logging is enabled and the logger has no
        handlers yet.
        """
        logger = logging.getLogger("sanitongo")

        if not logger.handlers and self.config.enable_logging:
            handler = logging.StreamHandler()
            formatter = logging.Formatter(
                "%(asctime)s - %(name)s - %(levelname)s - %(message)s"
            )
            handler.setFormatter(formatter)
            logger.addHandler(handler)
            logger.setLevel(getattr(logging, self.config.log_level.upper()))

        return logger

    def _log_sanitization_results(self, report: SanitizationReport) -> None:
        """Log the results of sanitization."""
        if report.success:
            self.logger.info(f"Sanitization completed: {report.get_summary()}")

            if self.config.log_removed_items and report.removed_items:
                self.logger.warning(f"Removed items: {report.removed_items}")

            if report.has_security_issues():
                for issue in report.security_issues:
                    self.logger.warning(f"Security issue: {issue}")
        else:
            self.logger.error(f"Sanitization failed: {report.error}")

    def get_config(self) -> SanitizerConfig:
        """Get the current configuration."""
        return self.config

    def update_config(self, **kwargs: Any) -> None:
        """Update sanitizer configuration and rebuild the layers.

        All option names are checked before anything is written, so an
        unknown name leaves the configuration untouched.

        Raises:
            ValueError: If any keyword is not an attribute of the config.
        """
        for key in kwargs:
            if not hasattr(self.config, key):
                raise ValueError(f"Unknown configuration option: {key}")

        for key, value in kwargs.items():
            setattr(self.config, key, value)

        # Re-initialize layers with new config
        self._init_layers()

__init__(config=None)

Initialize the sanitizer with configuration.

Source code in src/sanitongo/sanitizer.py
def __init__(self, config: SanitizerConfig | None = None) -> None:
    """Initialize the sanitizer with configuration.

    Args:
        config: Settings to use; when omitted a default
            ``SanitizerConfig()`` is created.
    """
    self.config = config or SanitizerConfig()
    # The logger is set up from the config before the layers are built.
    self.logger = self._setup_logging()

    # Initialize protection layers
    self._init_layers()

sanitize(query)

Sanitize a MongoDB query through all protection layers.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| query | Any | The MongoDB query to sanitize | required |

Returns:

| Type | Description |
|------|-------------|
| SanitizationReport | SanitizationReport with detailed results |

Source code in src/sanitongo/sanitizer.py
def sanitize(self, query: Any) -> SanitizationReport:
    """
    Sanitize a MongoDB query through all protection layers.

    Args:
        query: The MongoDB query to sanitize

    Returns:
        SanitizationReport with detailed results. When a layer raises and
        the error is not re-raised, the report has ``success=False`` and
        ``error`` set to the exception.

    Raises:
        Exception: Whatever a layer raised, when ``_should_reraise_error``
            decides it must propagate.
    """
    import time

    # perf_counter is monotonic, so the elapsed time cannot go negative or
    # jump when the wall clock is adjusted.
    start_time = time.perf_counter()

    # Create initial report
    report = SanitizationReport(
        original_query=query.copy() if isinstance(query, dict) else query,
        sanitized_query=query,
        success=False,
    )

    try:
        current_query = query

        # NOTE(review): result.success is not inspected below; a layer that
        # reports failure without raising does not fail the report -
        # confirm this is intended.

        # Layer 1: Type Validation
        result = self._run_layer(
            "Type Validation", self.type_validator, current_query, report
        )
        # Compare against None explicitly: an empty dict is a valid layer
        # output and must not fall back to the previous query.
        if result.modified_query is not None:
            current_query = result.modified_query

        # Layer 2: Schema Enforcement
        if isinstance(current_query, dict):
            result = self._run_layer(
                "Schema Enforcement", self.schema_enforcer, current_query, report
            )
            if result.modified_query is not None:
                current_query = result.modified_query

        # Layer 3: Operator Filtering
        if isinstance(current_query, dict):
            result = self._run_layer(
                "Operator Filtering", self.operator_filter, current_query, report
            )
            if result.modified_query is not None:
                current_query = result.modified_query
            if result.removed_items:
                report.removed_items.update(result.removed_items)

        # Layer 4: Pattern Validation
        if isinstance(current_query, dict):
            result = self._run_layer(
                "Pattern Validation", self.pattern_validator, current_query, report
            )
            if result.modified_query is not None:
                current_query = result.modified_query

        # Layer 5: Complexity Limiting
        if isinstance(current_query, dict):
            result = self._run_layer(
                "Complexity Limiting",
                self.complexity_limiter,
                current_query,
                report,
            )
            if result.modified_query is not None:
                current_query = result.modified_query

        # Finalize report
        report.sanitized_query = current_query
        report.success = True

        # Log results
        if self.config.enable_logging:
            self._log_sanitization_results(report)

    except Exception as e:
        report.error = e
        report.success = False
        if self.config.enable_logging:
            self.logger.error(f"Sanitization failed: {e}")

        # Re-raise based on configuration
        if self._should_reraise_error(e):
            raise

    # Performance metrics
    report.performance_metrics = {
        "processing_time_ms": round((time.perf_counter() - start_time) * 1000, 2),
        "layers_processed": len(report.layers_processed),
    }

    return report

sanitize_query(query)

Sanitize a query and return only the cleaned query.

This is a convenience method that returns just the sanitized query without the full report.

Source code in src/sanitongo/sanitizer.py
def sanitize_query(self, query: dict[str, Any]) -> dict[str, Any]:
    """
    Run the full sanitization and hand back just the cleaned query.

    Convenience wrapper around ``sanitize`` for callers that do not need
    the report.

    Raises:
        SanitizerError: If the report comes back unsuccessful.
    """
    outcome = self.sanitize(query)
    if outcome.success:
        return outcome.sanitized_query
    raise SanitizerError(f"Query sanitization failed: {outcome.error}")

is_query_safe(query)

Check if a query is safe without modifying it.

Returns True if the query passes all validation layers.

Source code in src/sanitongo/sanitizer.py
def is_query_safe(self, query: Any) -> bool:
    """
    Report whether a query gets through sanitization cleanly.

    The query is run through ``sanitize``; the answer is truthy only when
    the report succeeded and lists no security issues. Any exception
    raised along the way is treated as "not safe".
    """
    try:
        outcome = self.sanitize(query)
        return outcome.success and not outcome.has_security_issues()
    except Exception:
        # Best-effort check: a failing sanitizer means the query is unsafe.
        return False

get_config()

Get the current configuration.

Source code in src/sanitongo/sanitizer.py
def get_config(self) -> SanitizerConfig:
    """Get the current configuration.

    Returns the live ``self.config`` object itself, not a copy.
    """
    return self.config

update_config(**kwargs)

Update sanitizer configuration.

Source code in src/sanitongo/sanitizer.py
def update_config(self, **kwargs: Any) -> None:
    """Update sanitizer configuration and rebuild the layers.

    All option names are checked before anything is written, so an unknown
    name leaves the configuration untouched.

    Args:
        **kwargs: Configuration attribute names mapped to their new values.

    Raises:
        ValueError: If any keyword is not an attribute of the config.
    """
    # Validate first: failing halfway would leave a partly-updated config
    # that the layers were never rebuilt from.
    for key in kwargs:
        if not hasattr(self.config, key):
            raise ValueError(f"Unknown configuration option: {key}")

    for key, value in kwargs.items():
        setattr(self.config, key, value)

    # Re-initialize layers with new config
    self._init_layers()

SanitizationReport dataclass

Detailed report of the sanitization process.

Source code in src/sanitongo/sanitizer.py
@dataclass
class SanitizationReport:
    """Detailed report of the sanitization process."""

    original_query: dict[str, Any]
    sanitized_query: dict[str, Any]
    success: bool
    layers_processed: list[str] = field(default_factory=list)
    warnings: list[str] = field(default_factory=list)
    removed_items: dict[str, Any] = field(default_factory=dict)
    security_issues: list[str] = field(default_factory=list)
    performance_metrics: dict[str, Any] = field(default_factory=dict)
    error: Exception | None = None

    def has_warnings(self) -> bool:
        """Check if there are any warnings."""
        return bool(self.warnings)

    def has_security_issues(self) -> bool:
        """Check if there are any security issues."""
        return bool(self.security_issues)

    def has_modifications(self) -> bool:
        """Check if the query was modified."""
        return self.original_query != self.sanitized_query

    def get_summary(self) -> str:
        """Get a summary of the sanitization results."""
        if not self.success:
            return f"Sanitization failed: {self.error}"

        parts = []
        if self.has_modifications():
            parts.append(f"Query modified ({len(self.removed_items)} items removed)")
        if self.has_warnings():
            parts.append(f"{len(self.warnings)} warnings")
        if self.has_security_issues():
            parts.append(f"{len(self.security_issues)} security issues")

        if not parts:
            return "Query passed sanitization without issues"

        return "Sanitization completed with: " + ", ".join(parts)

has_warnings()

Check if there are any warnings.

Source code in src/sanitongo/sanitizer.py
def has_warnings(self) -> bool:
    """Check if there are any warnings.

    Returns True when ``self.warnings`` is non-empty.
    """
    return bool(self.warnings)

has_security_issues()

Check if there are any security issues.

Source code in src/sanitongo/sanitizer.py
def has_security_issues(self) -> bool:
    """Check if there are any security issues.

    Returns True when ``self.security_issues`` is non-empty.
    """
    return bool(self.security_issues)

has_modifications()

Check if the query was modified.

Source code in src/sanitongo/sanitizer.py
def has_modifications(self) -> bool:
    """Check if the query was modified.

    Returns True when ``original_query`` and ``sanitized_query`` compare
    unequal.
    """
    return self.original_query != self.sanitized_query

get_summary()

Get a summary of the sanitization results.

Source code in src/sanitongo/sanitizer.py
def get_summary(self) -> str:
    """Build a one-line, human-readable summary of the run.

    A failed run reports its error; a successful one lists modifications,
    warnings and security issues, or states that nothing was found.
    """
    if not self.success:
        return f"Sanitization failed: {self.error}"

    notes = []
    if self.has_modifications():
        removed_count = len(self.removed_items)
        notes.append(f"Query modified ({removed_count} items removed)")
    if self.has_warnings():
        warning_count = len(self.warnings)
        notes.append(f"{warning_count} warnings")
    if self.has_security_issues():
        issue_count = len(self.security_issues)
        notes.append(f"{issue_count} security issues")

    if notes:
        return "Sanitization completed with: " + ", ".join(notes)
    return "Query passed sanitization without issues"

SanitizerConfig dataclass

Configuration for the MongoDB sanitizer.

Source code in src/sanitongo/sanitizer.py
@dataclass
class SanitizerConfig:
    """Configuration for the MongoDB sanitizer.

    ``MongoSanitizer`` reads these fields when it builds its protection
    layers and when it decides whether a layer error is re-raised.
    """

    # Schema validation
    # Handed to the schema enforcement layer; None by default.
    schema_validator: SchemaValidator | None = None

    # Type validation
    # Passed to the type validator as its strict_mode flag.
    strict_types: bool = True

    # Operator filtering
    allowed_operators: set[str] | None = None
    dangerous_operators: set[str] | None = None
    # Passed to the operator filter as its strict_mode flag.
    strict_operators: bool = True

    # Pattern validation
    # When False, no pattern validator is created and that layer is skipped.
    enable_pattern_validation: bool = True
    # Maps a pattern name to a regular-expression string; the sanitizer
    # compiles each one with re.compile.
    custom_dangerous_patterns: dict[str, str] | None = None

    # Complexity limits
    max_depth: int = 10
    max_keys: int = 100
    max_array_length: int = 1000
    max_string_length: int = 10000

    # Logging
    enable_logging: bool = True
    # Name of a logging level; upper-cased and looked up on the logging module.
    log_level: str = "INFO"
    log_removed_items: bool = True

    # Error handling
    # Each flag decides whether the matching error type is re-raised by
    # the sanitizer or only recorded on the report.
    fail_on_schema_violation: bool = True
    fail_on_dangerous_operators: bool = True
    fail_on_dangerous_patterns: bool = True
    fail_on_complexity_exceeded: bool = True

FieldType

Bases: Enum

Supported field types for schema validation.

Source code in src/sanitongo/schema.py
class FieldType(Enum):
    """Supported field types for schema validation.

    Each member's value is the lowercase string used to name the type in a
    dictionary-style schema; ``create_sanitizer`` turns such strings into
    members with ``FieldType(<string>)``.
    """

    STRING = "string"
    INTEGER = "integer"
    FLOAT = "float"
    BOOLEAN = "boolean"
    OBJECT_ID = "objectid"
    DATETIME = "datetime"
    ARRAY = "array"
    OBJECT = "object"
    # Default type used by create_sanitizer when a field config omits "type".
    ANY = "any"

SchemaValidator

Validates MongoDB queries against a predefined schema.

Source code in src/sanitongo/schema.py
class SchemaValidator:
    """Checks MongoDB queries against a fixed field schema."""

    def __init__(self, schema: dict[str, FieldRule]) -> None:
        """Store the schema and precompute the set of permitted field names."""
        self.schema = schema
        self._allowed_fields = set(schema.keys())

    def validate_query(self, query: dict[str, Any], path_prefix: str = "") -> None:
        """Validate *query* against the schema.

        Args:
            query: Query dictionary to check.
            path_prefix: Dotted path prepended to field names in error
                locations.

        Raises:
            ValidationError: If *query* is not a dictionary or a field rule
                rejects (or fails on) a value.
            SchemaViolationError: If *query* contains fields that the schema
                does not declare.
        """
        if not isinstance(query, dict):
            raise ValidationError("Query must be a dictionary")

        # Anything present in the query but absent from the schema is illegal.
        unexpected = query.keys() - self._allowed_fields
        if unexpected:
            raise SchemaViolationError(
                f"Unknown fields in query: {sorted(unexpected)}",
                field_path=path_prefix,
                schema_rule="allowed_fields",
            )

        # Every schema field is checked, including ones missing from the
        # query (their rule receives None).
        for name, rule in self.schema.items():
            location = name if not path_prefix else f"{path_prefix}.{name}"
            candidate = query.get(name)

            try:
                rule.validate_value(candidate, location)
            except ValidationError:
                raise
            except Exception as exc:
                raise ValidationError(
                    f"Validation failed for field '{location}': {exc}"
                ) from exc

    def get_allowed_fields(self) -> set[str]:
        """Return a copy of the permitted field names."""
        return self._allowed_fields.copy()

    def is_field_allowed(self, field_name: str) -> bool:
        """Return True when *field_name* is declared in the schema."""
        return field_name in self._allowed_fields

    def get_field_rule(self, field_name: str) -> FieldRule | None:
        """Return the rule for *field_name*, or None when it is not declared."""
        return self.schema.get(field_name)

__init__(schema)

Initialize validator with field schema.

Source code in src/sanitongo/schema.py
def __init__(self, schema: dict[str, FieldRule]) -> None:
    """Initialize validator with field schema.

    Args:
        schema: Maps each permitted field name to its validation rule.
    """
    self.schema = schema
    # Snapshot of the field names taken at construction time.
    self._allowed_fields = set(schema.keys())

validate_query(query, path_prefix='')

Validate a query dictionary against the schema.

Source code in src/sanitongo/schema.py
def validate_query(self, query: dict[str, Any], path_prefix: str = "") -> None:
    """Validate *query* against the schema.

    Args:
        query: Query dictionary to check.
        path_prefix: Dotted path prepended to field names in error
            locations.

    Raises:
        ValidationError: If *query* is not a dictionary or a field rule
            rejects (or fails on) a value.
        SchemaViolationError: If *query* contains fields that the schema
            does not declare.
    """
    if not isinstance(query, dict):
        raise ValidationError("Query must be a dictionary")

    # Anything present in the query but absent from the schema is illegal.
    unexpected = query.keys() - self._allowed_fields
    if unexpected:
        raise SchemaViolationError(
            f"Unknown fields in query: {sorted(unexpected)}",
            field_path=path_prefix,
            schema_rule="allowed_fields",
        )

    # Every schema field is checked, including ones missing from the query
    # (their rule receives None).
    for name, rule in self.schema.items():
        location = name if not path_prefix else f"{path_prefix}.{name}"
        candidate = query.get(name)

        try:
            rule.validate_value(candidate, location)
        except ValidationError:
            raise
        except Exception as exc:
            raise ValidationError(
                f"Validation failed for field '{location}': {exc}"
            ) from exc

get_allowed_fields()

Get the set of allowed field names.

Source code in src/sanitongo/schema.py
def get_allowed_fields(self) -> set[str]:
    """Get the set of allowed field names.

    Returns a copy, so changes made by the caller do not affect the
    validator.
    """
    return self._allowed_fields.copy()

is_field_allowed(field_name)

Check if a field name is allowed by the schema.

Source code in src/sanitongo/schema.py
def is_field_allowed(self, field_name: str) -> bool:
    """Check if a field name is allowed by the schema.

    Returns True when *field_name* was a key of the schema given at
    construction.
    """
    return field_name in self._allowed_fields

get_field_rule(field_name)

Get the validation rule for a specific field.

Source code in src/sanitongo/schema.py
def get_field_rule(self, field_name: str) -> FieldRule | None:
    """Get the validation rule for a specific field.

    Returns None when the schema has no entry for *field_name*.
    """
    return self.schema.get(field_name)

create_sanitizer(schema=None, strict_mode=True, **config_kwargs)

Create a MongoDB sanitizer with common configuration.

Parameters:

| Name | Type | Description | Default |
|------|------|-------------|---------|
| schema | dict[str, Any] \| None | Optional schema definition for field validation | None |
| strict_mode | bool | Whether to use strict validation mode | True |
| **config_kwargs | Any | Additional configuration options | {} |

Returns:

| Type | Description |
|------|-------------|
| MongoSanitizer | Configured MongoSanitizer instance |

Source code in src/sanitongo/sanitizer.py
def create_sanitizer(
    schema: dict[str, Any] | None = None,
    strict_mode: bool = True,
    **config_kwargs: Any,
) -> MongoSanitizer:
    """
    Create a MongoDB sanitizer with common configuration.

    Args:
        schema: Optional schema definition for field validation. Each value
            may be a ``FieldRule``, a dict of rule options, or a field type
            name.
        strict_mode: Whether to use strict validation mode. Supplies the
            default for ``strict_types``, ``strict_operators`` and
            ``fail_on_dangerous_patterns``.
        **config_kwargs: Additional configuration options. An option given
            here overrides the value derived from ``strict_mode``.

    Returns:
        Configured MongoSanitizer instance
    """
    # Merge instead of passing both sets of keywords to the constructor:
    # repeating a keyword in a call raises TypeError, which made it
    # impossible to override a strict_mode-derived option.
    settings: dict[str, Any] = {
        "strict_types": strict_mode,
        "strict_operators": strict_mode,
        "fail_on_dangerous_patterns": strict_mode,
    }
    settings.update(config_kwargs)
    config = SanitizerConfig(**settings)

    if schema:
        from .schema import FieldRule, FieldType, SchemaValidator

        # Convert simple schema to FieldRule objects if needed
        schema_rules = {}
        for field_name, field_config in schema.items():
            if isinstance(field_config, FieldRule):
                schema_rules[field_name] = field_config
            elif isinstance(field_config, dict):
                # Create FieldRule from dict config
                field_type = FieldType(field_config.get("type", "any"))
                schema_rules[field_name] = FieldRule(
                    field_type=field_type,
                    required=field_config.get("required", False),
                    allowed_values=field_config.get("allowed_values"),
                    min_length=field_config.get("min_length"),
                    max_length=field_config.get("max_length"),
                    pattern=field_config.get("pattern"),
                )
            else:
                # Assume it's a field type string
                schema_rules[field_name] = FieldRule(FieldType(field_config))

        # A schema argument takes precedence over any schema_validator
        # passed through config_kwargs.
        config.schema_validator = SchemaValidator(schema_rules)

    return MongoSanitizer(config)