Skip to content

Validation API

odibi.validation.engine

Optimized validation engine for executing declarative data quality tests.

Performance optimizations: - Fail-fast mode for early exit on first failure - DataFrame caching for Spark with many tests - Lazy evaluation for Polars (avoids early .collect()) - Batched null count aggregation (single scan for NOT_NULL) - Vectorized operations (no Python loops over rows) - Memory-efficient mask operations (no full DataFrame copies)

Validator

Validation engine for executing declarative data quality tests. Supports Spark, Pandas, and Polars engines with performance optimizations.

Source code in odibi/validation/engine.py
(Rendered line-number gutter for source lines 23–788 of odibi/validation/engine.py omitted — it carried no content; see the source listing below.)
class Validator:
    """
    Validation engine for executing declarative data quality tests.

    Supports Spark, Pandas, and Polars engines with performance optimizations
    (fail-fast early exit, batched aggregations, and lazy evaluation where the
    engine supports it). Entry point is :meth:`validate`, which dispatches to
    an engine-specific implementation based on the DataFrame's runtime type.
    """

    def validate(
        self, df: Any, config: ValidationConfig, context: Optional[Dict[str, Any]] = None
    ) -> List[str]:
        """
        Run validation checks against a DataFrame.

        Detects the DataFrame engine at runtime (Spark, then Polars, with
        Pandas as the fallback) and dispatches to the matching private
        implementation.

        Args:
            df: Spark, Pandas, or Polars DataFrame (Polars may be a LazyFrame)
            config: Validation configuration
            context: Optional context (e.g. {'columns': ...}) for contracts

        Returns:
            List of error messages (empty if all checks pass)
        """
        ctx = get_logging_context()
        test_count = len(config.tests)
        failures: List[str] = []
        is_spark = False
        is_polars = False
        engine_type = "pandas"  # fallback when neither Spark nor Polars matches

        # Engine detection is isinstance-based; both imports are optional so
        # environments without pyspark/polars installed still work.
        try:
            # Import the submodule class explicitly: `import pyspark` alone
            # does not guarantee the `pyspark.sql` attribute is initialized.
            from pyspark.sql import DataFrame as SparkDataFrame

            if isinstance(df, SparkDataFrame):
                is_spark = True
                engine_type = "spark"
        except ImportError:
            pass

        if not is_spark:
            try:
                import polars as pl

                if isinstance(df, (pl.DataFrame, pl.LazyFrame)):
                    is_polars = True
                    engine_type = "polars"
            except ImportError:
                pass

        ctx.debug(
            "Starting validation",
            test_count=test_count,
            engine=engine_type,
            df_type=type(df).__name__,
            fail_fast=getattr(config, "fail_fast", False),
        )

        if is_spark:
            failures = self._validate_spark(df, config, context)
        elif is_polars:
            failures = self._validate_polars(df, config, context)
        else:
            failures = self._validate_pandas(df, config, context)

        tests_passed = test_count - len(failures)
        ctx.info(
            "Validation complete",
            total_tests=test_count,
            tests_passed=tests_passed,
            tests_failed=len(failures),
            engine=engine_type,
        )

        # Structured summary for downstream log consumers; only a sample of
        # failures is attached to keep the log entry bounded.
        ctx.log_validation_result(
            passed=len(failures) == 0,
            rule_name="batch_validation",
            failures=failures[:5] if failures else None,
            total_tests=test_count,
            tests_passed=tests_passed,
            tests_failed=len(failures),
        )

        return failures

    def _handle_failure(self, message: str, test: Any) -> Optional[str]:
        """Log a failed check and decide whether it counts as a hard failure.

        WARN severity is logged and swallowed (returns None); FAIL severity
        is logged and the message is returned for the caller to record.
        """
        ctx = get_logging_context()
        test_type = str(getattr(test, "type", "unknown"))

        # WARN-severity tests never contribute to the failure list.
        if getattr(test, "on_fail", ContractSeverity.FAIL) == ContractSeverity.WARN:
            ctx.warning(
                f"Validation Warning: {message}",
                test_type=test_type,
                severity="warn",
            )
            return None

        ctx.error(
            f"Validation Failed: {message}",
            test_type=test_type,
            severity="fail",
            test_config=str(test),
        )
        return message

    def _validate_polars(
        self, df: Any, config: ValidationConfig, context: Optional[Dict[str, Any]] = None
    ) -> List[str]:
        """
        Execute checks using Polars with lazy evaluation where possible.

        Optimization: Avoids collecting full LazyFrame. Uses lazy aggregations
        and only collects scalar results.

        Args:
            df: Polars DataFrame or LazyFrame.
            config: Validation configuration holding the tests to run.
            context: Optional contract context; SCHEMA tests read
                ``context['columns']`` for the expected column set.

        Returns:
            List of error messages (empty if all checks pass).
        """
        import polars as pl

        ctx = get_logging_context()
        fail_fast = getattr(config, "fail_fast", False)
        is_lazy = isinstance(df, pl.LazyFrame)

        # Row count and schema are fetched via scalar collects only — the
        # full LazyFrame is never materialized.
        if is_lazy:
            row_count = df.select(pl.len()).collect().item()
            columns = df.collect_schema().names()
        else:
            row_count = len(df)
            columns = df.columns

        ctx.debug("Validating Polars DataFrame", row_count=row_count, is_lazy=is_lazy)

        failures = []

        for test in config.tests:
            msg = None  # set by the matching check below when it fails
            test_type = getattr(test, "type", "unknown")
            ctx.debug("Executing test", test_type=str(test_type))

            if test.type == TestType.SCHEMA:
                # SCHEMA is a no-op when no expected columns were provided.
                if context and "columns" in context:
                    expected = set(context["columns"].keys())
                    actual = set(columns)
                    if getattr(test, "strict", True):
                        # Strict: column sets must match exactly.
                        if actual != expected:
                            msg = f"Schema mismatch. Expected {expected}, got {actual}"
                    else:
                        # Non-strict: extra columns allowed, missing ones are not.
                        missing = expected - actual
                        if missing:
                            msg = f"Schema mismatch. Missing columns: {missing}"

            elif test.type == TestType.ROW_COUNT:
                if test.min is not None and row_count < test.min:
                    msg = f"Row count {row_count} < min {test.min}"
                elif test.max is not None and row_count > test.max:
                    msg = f"Row count {row_count} > max {test.max}"

            elif test.type == TestType.FRESHNESS:
                col = getattr(test, "column", "updated_at")
                if col in columns:
                    if is_lazy:
                        max_ts = df.select(pl.col(col).max()).collect().item()
                    else:
                        max_ts = df[col].max()
                    # NOTE: a falsy max_ts (None — empty/all-null column)
                    # silently skips the freshness comparison.
                    if max_ts:
                        from datetime import datetime, timedelta, timezone

                        # Parse max_age strings of the form "24h", "7d", "30m"
                        # ("m" = minutes). An unknown suffix leaves delta None
                        # and the check is skipped without error.
                        duration_str = test.max_age
                        delta = None
                        if duration_str.endswith("h"):
                            delta = timedelta(hours=int(duration_str[:-1]))
                        elif duration_str.endswith("d"):
                            delta = timedelta(days=int(duration_str[:-1]))
                        elif duration_str.endswith("m"):
                            delta = timedelta(minutes=int(duration_str[:-1]))

                        if delta:
                            # Always compare in UTC — treat naive timestamps as UTC
                            now = datetime.now(timezone.utc)
                            if hasattr(max_ts, "tzinfo") and max_ts.tzinfo is None:
                                max_ts = max_ts.replace(tzinfo=timezone.utc)
                            if now - max_ts > delta:
                                msg = (
                                    f"Data too old. Max timestamp {max_ts} "
                                    f"is older than {test.max_age}"
                                )
                else:
                    msg = f"Freshness check failed: Column '{col}' not found"

            elif test.type == TestType.NOT_NULL:
                # Per-column handling: each null-bearing column produces its
                # own failure, so this branch bypasses the shared `msg` path
                # and ends with `continue`.
                # NOTE(review): columns missing from the frame are silently
                # skipped here, unlike the Pandas path which reports them —
                # confirm this asymmetry is intended.
                for col in test.columns:
                    if col in columns:
                        if is_lazy:
                            null_count = df.select(pl.col(col).is_null().sum()).collect().item()
                        else:
                            null_count = df[col].null_count()
                        if null_count > 0:
                            col_msg = f"Column '{col}' contains {null_count} NULLs"
                            ctx.debug(
                                "NOT_NULL check failed",
                                column=col,
                                null_count=null_count,
                                row_count=row_count,
                            )
                            res = self._handle_failure(col_msg, test)
                            if res:
                                failures.append(res)
                                if fail_fast:
                                    return [f for f in failures if f]
                continue

            elif test.type == TestType.UNIQUE:
                cols = [c for c in test.columns if c in columns]
                if len(cols) != len(test.columns):
                    msg = f"Unique check failed: Columns {set(test.columns) - set(cols)} not found"
                else:
                    # Count groups with more than one row; only a scalar is
                    # collected in the lazy case.
                    if is_lazy:
                        dup_count = (
                            df.group_by(cols)
                            .agg(pl.len().alias("cnt"))
                            .filter(pl.col("cnt") > 1)
                            .select(pl.len())
                            .collect()
                            .item()
                        )
                    else:
                        dup_count = (
                            df.group_by(cols)
                            .agg(pl.len().alias("cnt"))
                            .filter(pl.col("cnt") > 1)
                            .height
                        )
                    if dup_count > 0:
                        msg = f"Column '{', '.join(cols)}' is not unique"
                        ctx.debug(
                            "UNIQUE check failed",
                            columns=cols,
                            duplicate_groups=dup_count,
                        )

            elif test.type == TestType.ACCEPTED_VALUES:
                col = test.column
                if col in columns:
                    if is_lazy:
                        invalid_count = (
                            df.filter(~pl.col(col).is_in(test.values))
                            .select(pl.len())
                            .collect()
                            .item()
                        )
                    else:
                        invalid_count = df.filter(~pl.col(col).is_in(test.values)).height
                    if invalid_count > 0:
                        # Fetch up to 3 offending values for the error message.
                        if is_lazy:
                            examples = (
                                df.filter(~pl.col(col).is_in(test.values))
                                .select(pl.col(col))
                                .limit(3)
                                .collect()[col]
                                .to_list()
                            )
                        else:
                            invalid_rows = df.filter(~pl.col(col).is_in(test.values))
                            examples = invalid_rows[col].head(3).to_list()
                        msg = f"Column '{col}' contains invalid values. Found: {examples}"
                        ctx.debug(
                            "ACCEPTED_VALUES check failed",
                            column=col,
                            invalid_count=invalid_count,
                            examples=examples,
                        )
                else:
                    msg = f"Accepted values check failed: Column '{col}' not found"

            elif test.type == TestType.RANGE:
                col = test.column
                if col in columns:
                    # Build an "out of range" predicate; starts as literal
                    # False so a test with neither bound matches nothing.
                    cond = pl.lit(False)
                    if test.min is not None:
                        cond = cond | (pl.col(col) < test.min)
                    if test.max is not None:
                        cond = cond | (pl.col(col) > test.max)
                    if is_lazy:
                        invalid_count = df.filter(cond).select(pl.len()).collect().item()
                    else:
                        invalid_count = df.filter(cond).height
                    if invalid_count > 0:
                        msg = f"Column '{col}' contains {invalid_count} values out of range"
                        ctx.debug(
                            "RANGE check failed",
                            column=col,
                            invalid_count=invalid_count,
                            min=test.min,
                            max=test.max,
                        )
                else:
                    msg = f"Range check failed: Column '{col}' not found"

            elif test.type == TestType.REGEX_MATCH:
                col = test.column
                if col in columns:
                    # NULLs are excluded — only non-null values that fail the
                    # pattern count as invalid.
                    regex_cond = pl.col(col).is_not_null() & ~pl.col(col).str.contains(test.pattern)
                    if is_lazy:
                        invalid_count = df.filter(regex_cond).select(pl.len()).collect().item()
                    else:
                        invalid_count = df.filter(regex_cond).height
                    if invalid_count > 0:
                        msg = (
                            f"Column '{col}' contains {invalid_count} values "
                            f"that does not match pattern '{test.pattern}'"
                        )
                        ctx.debug(
                            "REGEX_MATCH check failed",
                            column=col,
                            invalid_count=invalid_count,
                            pattern=test.pattern,
                        )
                else:
                    msg = f"Regex check failed: Column '{col}' not found"

            elif test.type == TestType.CUSTOM_SQL:
                # Not implemented for the Polars engine; logged and skipped.
                ctx.warning(
                    "CUSTOM_SQL not fully supported in Polars; skipping",
                    test_name=getattr(test, "name", "custom_sql"),
                )
                continue

            # Shared failure path for all single-message checks above.
            if msg:
                res = self._handle_failure(msg, test)
                if res:
                    failures.append(res)
                    if fail_fast:
                        break

        return [f for f in failures if f]

    def _validate_spark(
        self, df: Any, config: ValidationConfig, context: Optional[Dict[str, Any]] = None
    ) -> List[str]:
        """
        Execute checks using Spark SQL with optimizations.

        Optimizations:
        - Optional DataFrame caching when cache_df=True
        - Batched null count aggregation (single scan for all NOT_NULL columns)
        - Fail-fast mode to skip remaining tests
        - Reuses row_count instead of re-counting

        Args:
            df: Spark DataFrame.
            config: Validation configuration holding the tests to run.
            context: Optional contract context; SCHEMA tests read
                ``context['columns']`` for the expected column set.

        Returns:
            List of error messages (empty if all checks pass).
        """
        from pyspark.sql import functions as F

        ctx = get_logging_context()
        failures = []
        fail_fast = getattr(config, "fail_fast", False)
        cache_df = getattr(config, "cache_df", False)

        # Optionally cache so the many per-test scans below reuse the data.
        df_work = df
        if cache_df:
            df_work = df.cache()
            ctx.debug("DataFrame cached for validation")

        # Counted once and reused by ROW_COUNT (and logging) below.
        row_count = df_work.count()
        ctx.debug("Validating Spark DataFrame", row_count=row_count)

        for test in config.tests:
            msg = None  # set by the matching check below when it fails
            test_type = getattr(test, "type", "unknown")
            ctx.debug("Executing test", test_type=str(test_type))

            if test.type == TestType.ROW_COUNT:
                if test.min is not None and row_count < test.min:
                    msg = f"Row count {row_count} < min {test.min}"
                elif test.max is not None and row_count > test.max:
                    msg = f"Row count {row_count} > max {test.max}"

            elif test.type == TestType.SCHEMA:
                # SCHEMA is a no-op when no expected columns were provided.
                if context and "columns" in context:
                    expected = set(context["columns"].keys())
                    actual = set(df_work.columns)
                    if getattr(test, "strict", True):
                        # Strict: column sets must match exactly.
                        if actual != expected:
                            msg = f"Schema mismatch. Expected {expected}, got {actual}"
                    else:
                        # Non-strict: extra columns allowed, missing ones are not.
                        missing = expected - actual
                        if missing:
                            msg = f"Schema mismatch. Missing columns: {missing}"

            elif test.type == TestType.FRESHNESS:
                col = getattr(test, "column", "updated_at")
                if col in df_work.columns:
                    max_ts = df_work.agg(F.max(col)).collect()[0][0]
                    # NOTE: a falsy max_ts (None — empty/all-null column)
                    # silently skips the freshness comparison.
                    if max_ts:
                        from datetime import datetime, timedelta, timezone

                        # Parse max_age strings of the form "24h", "7d", "30m"
                        # ("m" = minutes). An unknown suffix leaves delta None
                        # and the check is skipped without error.
                        duration_str = test.max_age
                        delta = None
                        if duration_str.endswith("h"):
                            delta = timedelta(hours=int(duration_str[:-1]))
                        elif duration_str.endswith("d"):
                            delta = timedelta(days=int(duration_str[:-1]))
                        elif duration_str.endswith("m"):
                            delta = timedelta(minutes=int(duration_str[:-1]))

                        if delta:
                            # Handle timezone-naive timestamps by using naive now()
                            # NOTE(review): the Polars path instead treats naive
                            # timestamps as UTC — confirm which convention is
                            # intended; naive now() is local-time dependent.
                            if hasattr(max_ts, "tzinfo") and max_ts.tzinfo is not None:
                                now = datetime.now(timezone.utc)
                            else:
                                now = datetime.now()
                            if now - max_ts > delta:
                                msg = f"Data too old. Max timestamp {max_ts} is older than {test.max_age}"
                else:
                    msg = f"Freshness check failed: Column '{col}' not found"

            elif test.type == TestType.NOT_NULL:
                # Batched: one aggregation scan computes null counts for every
                # requested column at once. Each null-bearing column produces
                # its own failure, so this branch bypasses the shared `msg`
                # path and ends with `continue`.
                # NOTE(review): columns missing from the frame are silently
                # skipped here, unlike the Pandas path which reports them —
                # confirm this asymmetry is intended.
                valid_cols = [c for c in test.columns if c in df_work.columns]
                if valid_cols:
                    null_aggs = [
                        F.sum(F.when(F.col(c).isNull(), 1).otherwise(0)).alias(c)
                        for c in valid_cols
                    ]
                    null_counts = df_work.agg(*null_aggs).collect()[0].asDict()
                    for col in valid_cols:
                        # SUM over zero rows yields NULL; coalesce to 0.
                        null_count = null_counts.get(col, 0) or 0
                        if null_count > 0:
                            col_msg = f"Column '{col}' contains {null_count} NULLs"
                            ctx.debug(
                                "NOT_NULL check failed",
                                column=col,
                                null_count=null_count,
                                row_count=row_count,
                            )
                            res = self._handle_failure(col_msg, test)
                            if res:
                                failures.append(res)
                                if fail_fast:
                                    # Early return must still release the cache.
                                    if cache_df:
                                        df_work.unpersist()
                                    return failures
                continue

            elif test.type == TestType.UNIQUE:
                cols = [c for c in test.columns if c in df_work.columns]
                if len(cols) != len(test.columns):
                    msg = f"Unique check failed: Columns {set(test.columns) - set(cols)} not found"
                else:
                    # Count key groups that occur more than once.
                    dup_count = df_work.groupBy(*cols).count().filter("count > 1").count()
                    if dup_count > 0:
                        msg = f"Column '{', '.join(cols)}' is not unique"
                        ctx.debug(
                            "UNIQUE check failed",
                            columns=cols,
                            duplicate_groups=dup_count,
                        )

            elif test.type == TestType.ACCEPTED_VALUES:
                col = test.column
                if col in df_work.columns:
                    invalid_df = df_work.filter(~F.col(col).isin(*test.values))
                    invalid_count = invalid_df.count()
                    if invalid_count > 0:
                        # Fetch up to 3 offending values for the error message.
                        examples_rows = invalid_df.select(col).limit(3).collect()
                        examples = [r[0] for r in examples_rows]
                        msg = f"Column '{col}' contains invalid values. Found: {examples}"
                        ctx.debug(
                            "ACCEPTED_VALUES check failed",
                            column=col,
                            invalid_count=invalid_count,
                            examples=examples,
                        )
                else:
                    msg = f"Accepted values check failed: Column '{col}' not found"

            elif test.type == TestType.RANGE:
                col = test.column
                if col in df_work.columns:
                    # Build range filter without Python | operator to avoid
                    # Py4J .or() dispatch issues across PySpark versions
                    range_parts = []
                    if test.min is not None:
                        range_parts.append(f"`{col}` < {test.min}")
                    if test.max is not None:
                        range_parts.append(f"`{col}` > {test.max}")

                    # No bounds configured: nothing to check for this test.
                    if not range_parts:
                        continue

                    invalid_count = df_work.filter(" OR ".join(range_parts)).count()
                    if invalid_count > 0:
                        msg = f"Column '{col}' contains {invalid_count} values out of range"
                        ctx.debug(
                            "RANGE check failed",
                            column=col,
                            invalid_count=invalid_count,
                            min=test.min,
                            max=test.max,
                        )
                else:
                    msg = f"Range check failed: Column '{col}' not found"

            elif test.type == TestType.REGEX_MATCH:
                col = test.column
                if col in df_work.columns:
                    # Use SQL string filter to avoid Py4J .and() dispatch issues
                    # NULLs are excluded — only non-null values that fail the
                    # pattern count as invalid.
                    escaped_pattern = test.pattern.replace("'", "\\'")
                    invalid_count = df_work.filter(
                        f"`{col}` IS NOT NULL AND NOT `{col}` RLIKE '{escaped_pattern}'"
                    ).count()
                    if invalid_count > 0:
                        msg = (
                            f"Column '{col}' contains {invalid_count} values "
                            f"that does not match pattern '{test.pattern}'"
                        )
                        ctx.debug(
                            "REGEX_MATCH check failed",
                            column=col,
                            invalid_count=invalid_count,
                            pattern=test.pattern,
                        )
                else:
                    msg = f"Regex check failed: Column '{col}' not found"

            elif test.type == TestType.CUSTOM_SQL:
                # User-supplied SQL condition: rows where it is NOT satisfied
                # are invalid. Execution errors (bad SQL) become failures too.
                try:
                    invalid_count = df_work.filter(f"NOT ({test.condition})").count()
                    if invalid_count > 0:
                        msg = (
                            f"Custom check '{getattr(test, 'name', 'custom_sql')}' failed. "
                            f"Found {invalid_count} invalid rows."
                        )
                        ctx.debug(
                            "CUSTOM_SQL check failed",
                            condition=test.condition,
                            invalid_count=invalid_count,
                        )
                except Exception as e:
                    msg = f"Failed to execute custom SQL '{test.condition}': {e}"
                    ctx.error(
                        "CUSTOM_SQL execution error",
                        condition=test.condition,
                        error=str(e),
                    )

            # Shared failure path for all single-message checks above.
            if msg:
                res = self._handle_failure(msg, test)
                if res:
                    failures.append(res)
                    if fail_fast:
                        break

        # Release the cache acquired at the top of this method.
        if cache_df:
            df_work.unpersist()

        return failures

    def _validate_pandas(
        self, df: Any, config: ValidationConfig, context: Dict[str, Any] = None
    ) -> List[str]:
        """
        Execute checks using Pandas with optimizations.

        Optimizations:
        - Single pass for UNIQUE (no double .duplicated() call)
        - Mask-based operations (no full DataFrame copies for invalid rows)
        - Memory-efficient example extraction
        - Fail-fast mode support
        """
        ctx = get_logging_context()
        failures = []
        row_count = len(df)
        fail_fast = getattr(config, "fail_fast", False)

        ctx.debug("Validating Pandas DataFrame", row_count=row_count)

        for test in config.tests:
            msg = None
            test_type = getattr(test, "type", "unknown")
            ctx.debug("Executing test", test_type=str(test_type))

            if test.type == TestType.SCHEMA:
                if context and "columns" in context:
                    expected = set(context["columns"].keys())
                    actual = set(df.columns)
                    if getattr(test, "strict", True):
                        if actual != expected:
                            msg = f"Schema mismatch. Expected {expected}, got {actual}"
                    else:
                        missing = expected - actual
                        if missing:
                            msg = f"Schema mismatch. Missing columns: {missing}"

            elif test.type == TestType.FRESHNESS:
                col = getattr(test, "column", "updated_at")
                if col in df.columns:
                    import pandas as pd

                    if not pd.api.types.is_datetime64_any_dtype(df[col]):
                        try:
                            s = pd.to_datetime(df[col])
                            max_ts = s.max()
                        except Exception as e:
                            logger = get_logging_context()
                            logger.debug(
                                f"Failed to convert column '{col}' to datetime for max_age check: {type(e).__name__}: {e}"
                            )
                            max_ts = None
                    else:
                        max_ts = df[col].max()

                    if max_ts is not None and max_ts is not pd.NaT:
                        from datetime import datetime, timedelta, timezone

                        duration_str = test.max_age
                        delta = None
                        if duration_str.endswith("h"):
                            delta = timedelta(hours=int(duration_str[:-1]))
                        elif duration_str.endswith("d"):
                            delta = timedelta(days=int(duration_str[:-1]))
                        elif duration_str.endswith("m"):
                            delta = timedelta(minutes=int(duration_str[:-1]))

                        if delta:
                            # Handle timezone-naive timestamps by using naive now()
                            if hasattr(max_ts, "tzinfo") and max_ts.tzinfo is not None:
                                now = datetime.now(timezone.utc)
                            else:
                                now = datetime.now()
                            if now - max_ts > delta:
                                msg = f"Data too old. Max timestamp {max_ts} is older than {test.max_age}"
                else:
                    msg = f"Freshness check failed: Column '{col}' not found"

            elif test.type == TestType.ROW_COUNT:
                if test.min is not None and row_count < test.min:
                    msg = f"Row count {row_count} < min {test.min}"
                elif test.max is not None and row_count > test.max:
                    msg = f"Row count {row_count} > max {test.max}"

            elif test.type == TestType.NOT_NULL:
                for col in test.columns:
                    if col in df.columns:
                        null_count = int(df[col].isnull().sum())
                        if null_count > 0:
                            col_msg = f"Column '{col}' contains {null_count} NULLs"
                            ctx.debug(
                                "NOT_NULL check failed",
                                column=col,
                                null_count=null_count,
                                row_count=row_count,
                            )
                            res = self._handle_failure(col_msg, test)
                            if res:
                                failures.append(res)
                                if fail_fast:
                                    return [f for f in failures if f]
                    else:
                        col_msg = f"Column '{col}' not found in DataFrame"
                        ctx.debug(
                            "NOT_NULL check failed - column missing",
                            column=col,
                        )
                        res = self._handle_failure(col_msg, test)
                        if res:
                            failures.append(res)
                            if fail_fast:
                                return [f for f in failures if f]
                continue

            elif test.type == TestType.UNIQUE:
                cols = [c for c in test.columns if c in df.columns]
                if len(cols) != len(test.columns):
                    msg = f"Unique check failed: Columns {set(test.columns) - set(cols)} not found"
                else:
                    dups = df.duplicated(subset=cols)
                    dup_count = int(dups.sum())
                    if dup_count > 0:
                        msg = f"Column '{', '.join(cols)}' is not unique"
                        ctx.debug(
                            "UNIQUE check failed",
                            columns=cols,
                            duplicate_rows=dup_count,
                        )

            elif test.type == TestType.ACCEPTED_VALUES:
                col = test.column
                if col in df.columns:
                    mask = ~df[col].isin(test.values)
                    invalid_count = int(mask.sum())
                    if invalid_count > 0:
                        examples = df.loc[mask, col].dropna().unique()[:3]
                        msg = f"Column '{col}' contains invalid values. Found: {list(examples)}"
                        ctx.debug(
                            "ACCEPTED_VALUES check failed",
                            column=col,
                            invalid_count=invalid_count,
                            examples=list(examples),
                        )
                else:
                    msg = f"Accepted values check failed: Column '{col}' not found"

            elif test.type == TestType.RANGE:
                col = test.column
                if col in df.columns:
                    invalid_count = 0
                    if test.min is not None:
                        invalid_count += int((df[col] < test.min).sum())
                    if test.max is not None:
                        invalid_count += int((df[col] > test.max).sum())

                    if invalid_count > 0:
                        msg = f"Column '{col}' contains {invalid_count} values out of range"
                        ctx.debug(
                            "RANGE check failed",
                            column=col,
                            invalid_count=invalid_count,
                            min=test.min,
                            max=test.max,
                        )
                else:
                    msg = f"Range check failed: Column '{col}' not found"

            elif test.type == TestType.REGEX_MATCH:
                col = test.column
                if col in df.columns:
                    valid_series = df[col].dropna().astype(str)
                    if not valid_series.empty:
                        matches = valid_series.str.match(test.pattern)
                        invalid_count = int((~matches).sum())
                        if invalid_count > 0:
                            msg = (
                                f"Column '{col}' contains {invalid_count} values "
                                f"that does not match pattern '{test.pattern}'"
                            )
                            ctx.debug(
                                "REGEX_MATCH check failed",
                                column=col,
                                invalid_count=invalid_count,
                                pattern=test.pattern,
                            )
                else:
                    msg = f"Regex check failed: Column '{col}' not found"

            elif test.type == TestType.CUSTOM_SQL:
                try:
                    mask = ~df.eval(test.condition)
                    invalid_count = int(mask.sum())
                    if invalid_count > 0:
                        msg = (
                            f"Custom check '{getattr(test, 'name', 'custom_sql')}' failed. "
                            f"Found {invalid_count} invalid rows."
                        )
                        ctx.debug(
                            "CUSTOM_SQL check failed",
                            condition=test.condition,
                            invalid_count=invalid_count,
                        )
                except Exception as e:
                    msg = f"Failed to execute custom SQL '{test.condition}': {e}"
                    ctx.error(
                        "CUSTOM_SQL execution error",
                        condition=test.condition,
                        error=str(e),
                    )

            if msg:
                res = self._handle_failure(msg, test)
                if res:
                    failures.append(res)
                    if fail_fast:
                        break

        return [f for f in failures if f]

validate(df, config, context=None)

Run validation checks against a DataFrame.

Parameters:

Name Type Description Default
df Any

Spark, Pandas, or Polars DataFrame

required
config ValidationConfig

Validation configuration

required
context Dict[str, Any]

Optional context (e.g. {'columns': ...}) for contracts

None

Returns:

Type Description
List[str]

List of error messages (empty if all checks pass)

Source code in odibi/validation/engine.py
def validate(
    self, df: Any, config: ValidationConfig, context: Dict[str, Any] = None
) -> List[str]:
    """
    Run validation checks against a DataFrame.

    Dispatches to the engine-specific validator (Spark, Polars, or Pandas)
    based on the runtime type of ``df``, then logs a summary of the outcome.

    Args:
        df: Spark, Pandas, or Polars DataFrame
        config: Validation configuration
        context: Optional context (e.g. {'columns': ...}) for contracts
        Returns:
        List of error messages (empty if all checks pass)
    """
    ctx = get_logging_context()
    test_count = len(config.tests)
    failures = []
    is_spark = False
    is_polars = False
    engine_type = "pandas"

    # Engine detection is purely by isinstance, guarded so a missing
    # optional dependency (pyspark / polars) falls through to pandas.
    try:
        import pyspark

        if isinstance(df, pyspark.sql.DataFrame):
            is_spark = True
            engine_type = "spark"
    except ImportError:
        pass

    if not is_spark:
        try:
            import polars as pl

            if isinstance(df, (pl.DataFrame, pl.LazyFrame)):
                is_polars = True
                engine_type = "polars"
        except ImportError:
            pass

    ctx.debug(
        "Starting validation",
        test_count=test_count,
        engine=engine_type,
        df_type=type(df).__name__,
        fail_fast=getattr(config, "fail_fast", False),
    )

    if is_spark:
        failures = self._validate_spark(df, config, context)
    elif is_polars:
        failures = self._validate_polars(df, config, context)
    else:
        failures = self._validate_pandas(df, config, context)

    # A single test can emit more than one failure message (e.g. NOT_NULL
    # over several columns reports one per column), so clamp at zero to
    # avoid logging a negative pass count.
    tests_passed = max(test_count - len(failures), 0)
    ctx.info(
        "Validation complete",
        total_tests=test_count,
        tests_passed=tests_passed,
        tests_failed=len(failures),
        engine=engine_type,
    )

    ctx.log_validation_result(
        passed=len(failures) == 0,
        rule_name="batch_validation",
        failures=failures[:5] if failures else None,
        total_tests=test_count,
        tests_passed=tests_passed,
        tests_failed=len(failures),
    )

    return failures

odibi.validation.gate

Quality Gate support for batch-level validation.

Gates evaluate the entire batch before writing, ensuring data quality thresholds are met at the aggregate level.

GateResult dataclass

Result of gate evaluation.

Source code in odibi/validation/gate.py
@dataclass
class GateResult:
    """Result of gate evaluation."""

    passed: bool  # overall gate verdict
    pass_rate: float  # fraction of rows passing every per-row test (0.0-1.0)
    total_rows: int  # rows evaluated in the batch
    passed_rows: int  # rows that passed all combinable per-row tests
    failed_rows: int  # total_rows - passed_rows
    details: Dict[str, Any] = field(default_factory=dict)  # per-test rates, row-count check, failure reasons
    action: GateOnFail = GateOnFail.ABORT  # action to take when the gate fails
    failure_reasons: List[str] = field(default_factory=list)  # human-readable reasons the gate failed

evaluate_gate(df, validation_results, gate_config, engine, catalog=None, node_name=None)

Evaluate quality gate on validation results.

Parameters:

Name Type Description Default
df Any

DataFrame being validated

required
validation_results Dict[str, List[bool]]

Dict of test_name -> per-row boolean results (True=passed)

required
gate_config GateConfig

Gate configuration

required
engine Any

Engine instance

required
catalog Optional[Any]

Optional CatalogManager for historical row count checks

None
node_name Optional[str]

Optional node name for historical lookups

None

Returns:

Type Description
GateResult

GateResult with pass/fail status and action to take

Source code in odibi/validation/gate.py
def evaluate_gate(
    df: Any,
    validation_results: Dict[str, List[bool]],
    gate_config: GateConfig,
    engine: Any,
    catalog: Optional[Any] = None,
    node_name: Optional[str] = None,
) -> GateResult:
    """
    Evaluate quality gate on validation results.

    Args:
        df: DataFrame being validated
        validation_results: Dict of test_name -> per-row boolean results (True=passed)
        gate_config: Gate configuration
        engine: Engine instance
        catalog: Optional CatalogManager for historical row count checks
        node_name: Optional node name for historical lookups

    Returns:
        GateResult with pass/fail status and action to take
    """
    is_spark = False

    # Spark detection: either the engine exposes a `spark` session or the
    # DataFrame itself is a pyspark DataFrame. If pyspark is not installed
    # we fall through to engine.count_rows / len(df) below.
    try:
        import pyspark

        if hasattr(engine, "spark") or isinstance(df, pyspark.sql.DataFrame):
            is_spark = True
    except ImportError:
        pass

    if is_spark:
        total_rows = df.count()
    elif hasattr(engine, "count_rows"):
        total_rows = engine.count_rows(df)
    else:
        total_rows = len(df)

    # An empty batch passes by definition; no thresholds can fail on 0 rows.
    if total_rows == 0:
        return GateResult(
            passed=True,
            pass_rate=1.0,
            total_rows=0,
            passed_rows=0,
            failed_rows=0,
            action=gate_config.on_fail,
            details={"message": "Empty dataset - gate passed by default"},
        )

    # AND together every per-row mask into a single "row passed all tests"
    # mask. Result lists whose length does not match the row count cannot
    # be aligned row-wise and are skipped.
    passed_rows = total_rows
    if validation_results:
        all_pass_mask = None
        for test_name, results in validation_results.items():
            if len(results) == total_rows:
                if all_pass_mask is None:
                    all_pass_mask = results.copy()
                else:
                    all_pass_mask = [a and b for a, b in zip(all_pass_mask, results)]

        if all_pass_mask:
            passed_rows = sum(all_pass_mask)

    # NOTE(review): rounding to 10 places presumably avoids float-noise when
    # the rate should compare equal to a configured threshold — TODO confirm.
    pass_rate = round(passed_rows / total_rows, 10) if total_rows > 0 else 1.0
    failed_rows = total_rows - passed_rows

    details: Dict[str, Any] = {
        "overall_pass_rate": pass_rate,
        "per_test_rates": {},
        "row_count_check": None,
    }

    gate_passed = True
    failure_reasons: List[str] = []

    # Overall pass-rate check first; per-threshold and row-count reasons are
    # appended after it, so the ordering of failure_reasons is significant.
    if pass_rate < gate_config.require_pass_rate:
        gate_passed = False
        failure_reasons.append(
            f"Overall pass rate {pass_rate:.1%} < required {gate_config.require_pass_rate:.1%}"
        )

    for threshold in gate_config.thresholds:
        test_results = validation_results.get(threshold.test)
        # NOTE(review): a threshold whose test has no recorded results (or an
        # empty list) is silently skipped here — confirm this is intended.
        if test_results:
            test_total = len(test_results)
            test_passed = sum(test_results)
            test_pass_rate = test_passed / test_total if test_total > 0 else 1.0
            details["per_test_rates"][threshold.test] = test_pass_rate

            if test_pass_rate < threshold.min_pass_rate:
                gate_passed = False
                failure_reasons.append(
                    f"Test '{threshold.test}' pass rate {test_pass_rate:.1%} "
                    f"< required {threshold.min_pass_rate:.1%}"
                )

    # Optional absolute/historical row-count check (delegated to helper).
    if gate_config.row_count:
        row_check = _check_row_count(
            total_rows,
            gate_config.row_count,
            catalog,
            node_name,
        )
        details["row_count_check"] = row_check

        if not row_check["passed"]:
            gate_passed = False
            failure_reasons.append(row_check["reason"])

    details["failure_reasons"] = failure_reasons

    if gate_passed:
        logger.info(f"Gate passed: {pass_rate:.1%} pass rate ({passed_rows}/{total_rows} rows)")
    else:
        logger.warning(f"Gate failed: {', '.join(failure_reasons)}")

    return GateResult(
        passed=gate_passed,
        pass_rate=pass_rate,
        total_rows=total_rows,
        passed_rows=passed_rows,
        failed_rows=failed_rows,
        details=details,
        action=gate_config.on_fail,
        failure_reasons=failure_reasons,
    )

odibi.validation.quarantine

Optimized quarantine table support for routing failed validation rows.

Performance optimizations:

- Removed per-row test_results lists (O(N*tests) memory savings)
- Added sampling/limiting for large invalid sets
- Single pass for combined mask evaluation
- No unnecessary Python list conversions

This module provides functionality to:

1. Split DataFrames into valid and invalid portions based on test results
2. Add metadata columns to quarantined rows
3. Write quarantined rows to a dedicated table (with optional sampling)

QuarantineResult dataclass

Result of quarantine operation.

Source code in odibi/validation/quarantine.py
@dataclass
class QuarantineResult:
    """Result of quarantine operation."""

    valid_df: Any  # rows that passed every quarantine test
    invalid_df: Any  # rows that failed at least one quarantine test
    rows_quarantined: int  # row count of invalid_df
    rows_valid: int  # row count of valid_df
    test_results: Dict[str, Dict[str, int]] = field(default_factory=dict)  # test name -> {"pass_count", "fail_count"}
    failed_test_details: Dict[int, List[str]] = field(default_factory=dict)  # row index -> failed test names (always empty from split_valid_invalid)

add_quarantine_metadata(invalid_df, test_results, config, engine, node_name, run_id, tests)

Add metadata columns to quarantined rows.

Parameters:

Name Type Description Default
invalid_df Any

DataFrame of invalid rows

required
test_results Dict[str, Any]

Dict of test_name -> aggregate results (not per-row)

required
config QuarantineColumnsConfig

QuarantineColumnsConfig specifying which columns to add

required
engine Any

Engine instance

required
node_name str

Name of the originating node

required
run_id str

Current run ID

required
tests List[TestConfig]

List of test configurations (for building failure reasons)

required

Returns:

Type Description
Any

DataFrame with added metadata columns

Source code in odibi/validation/quarantine.py
def add_quarantine_metadata(
    invalid_df: Any,
    test_results: Dict[str, Any],
    config: QuarantineColumnsConfig,
    engine: Any,
    node_name: str,
    run_id: str,
    tests: List[TestConfig],
) -> Any:
    """
    Add metadata columns to quarantined rows.

    Args:
        invalid_df: DataFrame of invalid rows
        test_results: Dict of test_name -> aggregate results (not per-row)
        config: QuarantineColumnsConfig specifying which columns to add
        engine: Engine instance
        node_name: Name of the originating node
        run_id: Current run ID
        tests: List of test configurations (for building failure reasons)

    Returns:
        DataFrame with added metadata columns
    """
    # NOTE(review): `test_results` is currently unused in this function —
    # rejection reasons are rebuilt from `tests` below. Confirm whether the
    # parameter is kept for API compatibility.
    is_spark = False
    is_polars = False

    # Engine detection mirrors the rest of this module: Spark wins if the
    # engine carries a session or the frame is a pyspark DataFrame.
    try:
        import pyspark

        if hasattr(engine, "spark") or isinstance(invalid_df, pyspark.sql.DataFrame):
            is_spark = True
    except ImportError:
        pass

    if not is_spark:
        try:
            import polars as pl

            if isinstance(invalid_df, (pl.DataFrame, pl.LazyFrame)):
                is_polars = True
        except ImportError:
            pass

    # One UTC ISO timestamp shared by every quarantined row in this batch.
    rejected_at = datetime.now(timezone.utc).isoformat()

    # Aggregate reason string: names of every quarantine-severity test,
    # regardless of which specific test(s) each row failed.
    quarantine_tests = [t for t in tests if t.on_fail == ContractSeverity.QUARANTINE]
    test_names = [t.name or f"{t.type.value}" for t in quarantine_tests]
    failed_tests_str = ",".join(test_names)
    all_tests_reason = f"Failed tests: {failed_tests_str}"

    if is_spark:
        from pyspark.sql import functions as F

        # Spark and Polars branches stamp the same aggregate reason on every
        # row; only the pandas branch computes per-row reasons.
        result_df = invalid_df

        if config.rejection_reason:
            result_df = result_df.withColumn("_rejection_reason", F.lit(all_tests_reason))

        if config.rejected_at:
            result_df = result_df.withColumn("_rejected_at", F.lit(rejected_at))

        if config.source_batch_id:
            result_df = result_df.withColumn("_source_batch_id", F.lit(run_id))

        if config.failed_tests:
            result_df = result_df.withColumn("_failed_tests", F.lit(failed_tests_str))

        if config.original_node:
            result_df = result_df.withColumn("_original_node", F.lit(node_name))

        return result_df

    elif is_polars:
        import polars as pl

        result_df = invalid_df

        if config.rejection_reason:
            result_df = result_df.with_columns(pl.lit(all_tests_reason).alias("_rejection_reason"))

        if config.rejected_at:
            result_df = result_df.with_columns(pl.lit(rejected_at).alias("_rejected_at"))

        if config.source_batch_id:
            result_df = result_df.with_columns(pl.lit(run_id).alias("_source_batch_id"))

        if config.failed_tests:
            result_df = result_df.with_columns(pl.lit(failed_tests_str).alias("_failed_tests"))

        if config.original_node:
            result_df = result_df.with_columns(pl.lit(node_name).alias("_original_node"))

        return result_df

    else:
        import pandas as pd

        # Copy so metadata columns never mutate the caller's DataFrame.
        result_df = invalid_df.copy()

        if config.rejection_reason:
            # Build per-row rejection reasons by re-evaluating each
            # quarantine test mask and appending the test name to rows
            # that fail it (mask=True means the row passed).
            reasons = pd.Series([""] * len(result_df), index=result_df.index)
            for qt in quarantine_tests:
                mask = _evaluate_test_mask(result_df, qt, is_spark=False, is_polars=False)
                failed = ~mask
                name = qt.name or qt.type.value
                reasons = reasons.where(~failed, reasons + name + ",")
            result_df["_rejection_reason"] = "Failed: " + reasons.str.rstrip(",")

        if config.rejected_at:
            result_df["_rejected_at"] = rejected_at

        if config.source_batch_id:
            result_df["_source_batch_id"] = run_id

        if config.failed_tests:
            result_df["_failed_tests"] = failed_tests_str

        if config.original_node:
            result_df["_original_node"] = node_name

        return result_df

has_quarantine_tests(tests)

Check if any tests use quarantine severity.

Source code in odibi/validation/quarantine.py
def has_quarantine_tests(tests: List[TestConfig]) -> bool:
    """Return True when at least one test routes failures to quarantine."""
    for test in tests:
        if test.on_fail == ContractSeverity.QUARANTINE:
            return True
    return False

split_valid_invalid(df, tests, engine)

Split DataFrame into valid and invalid portions based on quarantine tests.

Only tests with on_fail == QUARANTINE are evaluated for splitting. A row is invalid if it fails ANY quarantine test.

Performance: Removed per-row test_results lists to save O(N*tests) memory. Now stores only aggregate counts per test.

Parameters:

Name Type Description Default
df Any

DataFrame to split

required
tests List[TestConfig]

List of test configurations

required
engine Any

Engine instance (Spark, Pandas, or Polars)

required

Returns:

Type Description
QuarantineResult

QuarantineResult with valid_df, invalid_df, and test metadata

Source code in odibi/validation/quarantine.py
def split_valid_invalid(
    df: Any,
    tests: List[TestConfig],
    engine: Any,
) -> QuarantineResult:
    """
    Split DataFrame into valid and invalid portions based on quarantine tests.

    Only tests with on_fail == QUARANTINE are evaluated for splitting.
    A row is invalid if it fails ANY quarantine test.

    Performance: Removed per-row test_results lists to save O(N*tests) memory.
    Now stores only aggregate counts per test.

    Args:
        df: DataFrame to split
        tests: List of test configurations
        engine: Engine instance (Spark, Pandas, or Polars)

    Returns:
        QuarantineResult with valid_df, invalid_df, and test metadata
    """
    is_spark = False
    is_polars = False

    # Engine detection (same pattern as the rest of the module).
    try:
        import pyspark

        if hasattr(engine, "spark") or isinstance(df, pyspark.sql.DataFrame):
            is_spark = True
    except ImportError:
        pass

    if not is_spark:
        try:
            import polars as pl

            if isinstance(df, (pl.DataFrame, pl.LazyFrame)):
                is_polars = True
        except ImportError:
            pass

    quarantine_tests = [t for t in tests if t.on_fail == ContractSeverity.QUARANTINE]

    # Fast path: nothing routes to quarantine, so everything is valid and
    # the invalid frame is an empty frame with the same schema.
    if not quarantine_tests:
        if is_spark:
            from pyspark.sql import functions as F

            empty_df = df.filter(F.lit(False))
        elif is_polars:
            import polars as pl

            empty_df = df.filter(pl.lit(False))
        else:
            empty_df = df.iloc[0:0].copy()

        row_count = engine.count_rows(df) if hasattr(engine, "count_rows") else len(df)
        return QuarantineResult(
            valid_df=df,
            invalid_df=empty_df,
            rows_quarantined=0,
            rows_valid=row_count,
            test_results={},
            failed_test_details={},
        )

    # One boolean mask per quarantine test; mask=True means the row passed.
    # Duplicate test names are disambiguated with the list index suffix so
    # no mask silently overwrites another.
    test_masks = {}
    test_names = []

    for idx, test in enumerate(quarantine_tests):
        base_name = test.name or f"{test.type.value}"
        test_name = base_name if base_name not in test_masks else f"{base_name}_{idx}"
        test_names.append(test_name)
        mask = _evaluate_test_mask(df, test, is_spark, is_polars)
        test_masks[test_name] = mask

    if is_spark:
        from pyspark.sql import functions as F

        combined_valid_mask = F.lit(True)
        for mask in test_masks.values():
            combined_valid_mask = combined_valid_mask & mask

        # Cache the source frame because it is scanned once per test mask
        # below, plus the two filter/count passes.
        df_cached = df.cache()

        valid_df = df_cached.filter(combined_valid_mask)
        invalid_df = df_cached.filter(~combined_valid_mask)

        # valid_df / invalid_df stay cached — presumably for the downstream
        # write and metadata steps; only the base frame is unpersisted here.
        valid_df = valid_df.cache()
        invalid_df = invalid_df.cache()

        rows_valid = valid_df.count()
        rows_quarantined = invalid_df.count()
        total = rows_valid + rows_quarantined

        # Per-test aggregate counts derived from one filtered count each.
        test_results = {}
        for name, mask in test_masks.items():
            pass_count = df_cached.filter(mask).count()
            fail_count = total - pass_count
            test_results[name] = {"pass_count": pass_count, "fail_count": fail_count}

        df_cached.unpersist()

    elif is_polars:
        import polars as pl

        combined_valid_mask = pl.lit(True)
        for mask in test_masks.values():
            combined_valid_mask = combined_valid_mask & mask

        valid_df = df.filter(combined_valid_mask)
        invalid_df = df.filter(~combined_valid_mask)

        # NOTE(review): pl.LazyFrame is accepted by the isinstance check
        # above, but len() is only defined for eager DataFrames — confirm
        # LazyFrames are collected before reaching this point.
        rows_valid = len(valid_df)
        rows_quarantined = len(invalid_df)

        # NOTE(review): per-test counts are not computed for Polars (unlike
        # the Spark and pandas branches) — confirm this gap is intentional.
        test_results = {}

    else:
        import pandas as pd

        combined_valid_mask = pd.Series([True] * len(df), index=df.index)
        for mask in test_masks.values():
            combined_valid_mask = combined_valid_mask & mask

        valid_df = df[combined_valid_mask].copy()
        invalid_df = df[~combined_valid_mask].copy()

        rows_valid = len(valid_df)
        rows_quarantined = len(invalid_df)

        # Aggregate pass/fail counts per test from the in-memory masks.
        test_results = {}
        for name, mask in test_masks.items():
            pass_count = int(mask.sum())
            fail_count = len(df) - pass_count
            test_results[name] = {"pass_count": pass_count, "fail_count": fail_count}

    logger.info(f"Quarantine split: {rows_valid} valid, {rows_quarantined} invalid")

    return QuarantineResult(
        valid_df=valid_df,
        invalid_df=invalid_df,
        rows_quarantined=rows_quarantined,
        rows_valid=rows_valid,
        test_results=test_results,
        failed_test_details={},
    )

write_quarantine(invalid_df, config, engine, connections)

Write quarantined rows to destination (always append mode).

Supports optional sampling/limiting via config.max_rows and config.sample_fraction.

Parameters:

Name Type Description Default
invalid_df Any

DataFrame of invalid rows with metadata

required
config QuarantineConfig

QuarantineConfig specifying destination and sampling options

required
engine Any

Engine instance

required
connections Dict[str, Any]

Dict of connection configurations

required

Returns:

Type Description
Dict[str, Any]

Dict with write result metadata

Source code in odibi/validation/quarantine.py
def write_quarantine(
    invalid_df: Any,
    config: QuarantineConfig,
    engine: Any,
    connections: Dict[str, Any],
) -> Dict[str, Any]:
    """
    Append quarantined rows to the configured destination.

    Optional sampling/limiting is applied first via config.max_rows and
    config.sample_fraction.

    Args:
        invalid_df: DataFrame of invalid rows with metadata
        config: QuarantineConfig specifying destination and sampling options
        engine: Engine instance
        connections: Dict of connection configurations

    Returns:
        Dict with write result metadata
    """
    # Detect the engine family: Spark when the engine carries a session or
    # the frame is a pyspark DataFrame; otherwise probe for Polars.
    spark_frame = False
    polars_frame = False

    try:
        import pyspark

        if hasattr(engine, "spark") or isinstance(invalid_df, pyspark.sql.DataFrame):
            spark_frame = True
    except ImportError:
        pass

    if not spark_frame:
        try:
            import polars as pl

            if isinstance(invalid_df, (pl.DataFrame, pl.LazyFrame)):
                polars_frame = True
        except ImportError:
            pass

    # Sampling/limiting happens before counting so the reported row count
    # reflects what is actually written.
    invalid_df = _apply_sampling(invalid_df, config, spark_frame, polars_frame)

    # Polars and pandas frames both support len(); Spark needs count().
    row_count = invalid_df.count() if spark_frame else len(invalid_df)

    destination = config.path or config.table

    if row_count == 0:
        # Nothing to quarantine — skip the write entirely.
        return {
            "rows_quarantined": 0,
            "quarantine_path": destination,
            "write_info": None,
        }

    connection = connections.get(config.connection)
    if connection is None:
        raise ValueError(
            f"Quarantine connection '{config.connection}' not found. "
            f"Available: {', '.join(connections.keys())}"
        )

    # Always append: quarantine tables accumulate across runs. Delta for
    # table destinations, parquet for path destinations.
    try:
        write_result = engine.write(
            invalid_df,
            connection=connection,
            format="delta" if config.table else "parquet",
            path=config.path,
            table=config.table,
            mode="append",
        )
    except Exception as e:
        logger.error(f"Failed to write quarantine data: {e}")
        raise

    logger.info(f"Wrote {row_count} rows to quarantine: {destination}")

    return {
        "rows_quarantined": row_count,
        "quarantine_path": destination,
        "write_info": write_result,
    }

odibi.validation.fk

Foreign Key Validation Module

Declare and validate referential integrity between fact and dimension tables.

Features:

- Declare relationships in YAML
- Validate referential integrity on fact load
- Detect orphan records
- Generate lineage from relationships
- Integration with FactPattern

Example Config
- name: orders_to_customers
  fact: fact_orders
  dimension: dim_customer
  fact_key: customer_sk
  dimension_key: customer_sk

- name: orders_to_products
  fact: fact_orders
  dimension: dim_product
  fact_key: product_sk
  dimension_key: product_sk

FKValidationReport dataclass

Complete FK validation report for a fact table.

Source code in odibi/validation/fk.py
@dataclass
class FKValidationReport:
    """Complete FK validation report for a fact table."""

    fact_table: str  # name of the fact table that was validated
    all_valid: bool  # True when every relationship passed
    total_relationships: int  # relationships evaluated for this fact
    valid_relationships: int  # relationships that passed
    results: List[FKValidationResult] = field(default_factory=list)  # per-relationship outcomes
    orphan_records: List[OrphanRecord] = field(default_factory=list)  # collected orphan samples, if any
    elapsed_ms: float = 0.0  # total validation wall time in milliseconds

FKValidationResult dataclass

Result of FK validation.

Source code in odibi/validation/fk.py
@dataclass
class FKValidationResult:
    """Result of FK validation."""

    relationship_name: str  # name of the relationship that was checked
    valid: bool  # True when the relationship passed
    total_rows: int  # fact rows examined (0 when validation could not run)
    orphan_count: int  # orphan fact keys found; -1 sentinel on error (e.g. missing dimension)
    null_count: int  # NULL foreign-key count (0 when validation errors out)
    orphan_values: List[Any] = field(default_factory=list)  # example orphan key values
    elapsed_ms: float = 0.0  # validation wall time in milliseconds
    error: Optional[str] = None  # error message when validation could not run

FKValidator

Validate foreign key relationships between fact and dimension tables.

Usage

registry = RelationshipRegistry(relationships=[...]) validator = FKValidator(registry) report = validator.validate_fact(fact_df, "fact_orders", context)

Source code in odibi/validation/fk.py
class FKValidator:
    """
    Validate foreign key relationships between fact and dimension tables.

    Usage:
        registry = RelationshipRegistry(relationships=[...])
        validator = FKValidator(registry)
        report = validator.validate_fact(fact_df, "fact_orders", context)
    """

    def __init__(self, registry: RelationshipRegistry):
        """
        Initialize with relationship registry.

        Args:
            registry: RelationshipRegistry with relationship definitions
        """
        self.registry = registry

    def validate_relationship(
        self,
        fact_df: Any,
        relationship: RelationshipConfig,
        context: EngineContext,
    ) -> FKValidationResult:
        """
        Validate a single FK relationship.

        Args:
            fact_df: Fact DataFrame to validate
            relationship: Relationship configuration
            context: EngineContext with dimension data

        Returns:
            FKValidationResult with validation details. ``orphan_count`` is -1
            when the dimension table is missing from the context (sentinel
            distinguishing "could not check" from "checked, zero orphans").
        """
        ctx = get_logging_context()
        start_time = time.time()

        ctx.debug(
            "Validating FK relationship",
            relationship=relationship.name,
            fact=relationship.fact,
            dimension=relationship.dimension,
        )

        try:
            dim_df = context.get(relationship.dimension)
        except KeyError:
            # Dimension not loaded into the context: report invalid with the
            # -1 sentinel so callers can tell "unverifiable" from "0 orphans".
            elapsed_ms = (time.time() - start_time) * 1000
            return FKValidationResult(
                relationship_name=relationship.name,
                valid=False,
                total_rows=0,
                orphan_count=-1,
                null_count=0,
                elapsed_ms=elapsed_ms,
                error=f"Dimension table '{relationship.dimension}' not found",
            )

        try:
            # Route by engine. NOTE(review): the non-Spark path assumes a
            # pandas-compatible API (isna/isin/dropna) — confirm for Polars.
            if context.engine_type == EngineType.SPARK:
                result = self._validate_spark(fact_df, dim_df, relationship)
            else:
                result = self._validate_pandas(fact_df, dim_df, relationship)

            result.elapsed_ms = (time.time() - start_time) * 1000

            if result.valid:
                ctx.debug(
                    "FK validation passed",
                    relationship=relationship.name,
                    total_rows=result.total_rows,
                )
            else:
                ctx.warning(
                    "FK validation failed",
                    relationship=relationship.name,
                    orphan_count=result.orphan_count,
                    null_count=result.null_count,
                )

            return result

        except Exception as e:
            # Engine-level failure (bad column name, schema mismatch, ...):
            # surface as an invalid result instead of raising so validate_fact
            # can continue checking the remaining relationships.
            elapsed_ms = (time.time() - start_time) * 1000
            ctx.error(
                f"FK validation error: {e}",
                relationship=relationship.name,
            )
            return FKValidationResult(
                relationship_name=relationship.name,
                valid=False,
                total_rows=0,
                orphan_count=0,
                null_count=0,
                elapsed_ms=elapsed_ms,
                error=str(e),
            )

    def _validate_spark(
        self,
        fact_df: Any,
        dim_df: Any,
        relationship: RelationshipConfig,
    ) -> FKValidationResult:
        """Validate using Spark.

        Orphans are found with a left-anti join of non-null fact keys against
        the distinct dimension keys.
        """
        from pyspark.sql import functions as F

        fk_col = relationship.fact_key
        dk_col = relationship.dimension_key

        total_rows = fact_df.count()

        null_count = fact_df.filter(F.col(fk_col).isNull()).count()

        # Distinct dimension keys, aliased to avoid column-name collisions.
        dim_keys = dim_df.select(F.col(dk_col).alias("_dim_key")).distinct()

        # Nulls are counted separately above; only non-null keys can be orphans.
        non_null_facts = fact_df.filter(F.col(fk_col).isNotNull())
        orphans = non_null_facts.join(
            dim_keys,
            non_null_facts[fk_col] == dim_keys["_dim_key"],
            "left_anti",
        )

        orphan_count = orphans.count()

        orphan_values = []
        if orphan_count > 0:
            # BUG FIX: previously samples were only collected when
            # orphan_count <= 100, so large failures returned NO examples and
            # downstream error messages (orphan_values[:5]) were empty.
            # Always sample up to 100 distinct orphan keys, matching the
            # pandas path.
            orphan_values = [
                row[fk_col] for row in orphans.select(fk_col).distinct().limit(100).collect()
            ]

        is_valid = orphan_count == 0 and (relationship.nullable or null_count == 0)

        return FKValidationResult(
            relationship_name=relationship.name,
            valid=is_valid,
            total_rows=total_rows,
            orphan_count=orphan_count,
            null_count=null_count,
            orphan_values=orphan_values,
        )

    def _validate_pandas(
        self,
        fact_df: Any,
        dim_df: Any,
        relationship: RelationshipConfig,
    ) -> FKValidationResult:
        """Validate using Pandas.

        Uses a set-membership test (O(1) per row) against the distinct
        dimension keys instead of a join.
        """

        fk_col = relationship.fact_key
        dk_col = relationship.dimension_key

        total_rows = len(fact_df)

        null_count = int(fact_df[fk_col].isna().sum())

        dim_keys = set(dim_df[dk_col].dropna().unique())

        # Only non-null keys can be orphans; nulls are tracked separately.
        non_null_fks = fact_df[fk_col].dropna()
        orphan_mask = ~non_null_fks.isin(dim_keys)
        orphan_count = int(orphan_mask.sum())

        orphan_values = []
        if orphan_count > 0:
            # Cap the diagnostic sample at 100 distinct values.
            orphan_values = list(non_null_fks[orphan_mask].unique()[:100])

        is_valid = orphan_count == 0 and (relationship.nullable or null_count == 0)

        return FKValidationResult(
            relationship_name=relationship.name,
            valid=is_valid,
            total_rows=total_rows,
            orphan_count=orphan_count,
            null_count=null_count,
            orphan_values=orphan_values,
        )

    def validate_fact(
        self,
        fact_df: Any,
        fact_table: str,
        context: EngineContext,
    ) -> FKValidationReport:
        """
        Validate all FK relationships for a fact table.

        Args:
            fact_df: Fact DataFrame to validate
            fact_table: Fact table name
            context: EngineContext with dimension data

        Returns:
            FKValidationReport with all validation results. A fact table with
            no declared relationships yields an all_valid=True report.
        """
        ctx = get_logging_context()
        start_time = time.time()

        ctx.info("Starting FK validation", fact_table=fact_table)

        relationships = self.registry.get_fact_relationships(fact_table)

        if not relationships:
            ctx.warning(
                "No FK relationships defined",
                fact_table=fact_table,
            )
            return FKValidationReport(
                fact_table=fact_table,
                all_valid=True,
                total_relationships=0,
                valid_relationships=0,
                elapsed_ms=(time.time() - start_time) * 1000,
            )

        results = []
        all_orphans = []

        for relationship in relationships:
            result = self.validate_relationship(fact_df, relationship, context)
            results.append(result)

            # Flatten the sampled orphan values into OrphanRecord entries.
            if result.orphan_count > 0:
                for orphan_val in result.orphan_values:
                    all_orphans.append(
                        OrphanRecord(
                            fact_key_value=orphan_val,
                            fact_key_column=relationship.fact_key,
                            dimension_table=relationship.dimension,
                        )
                    )

        all_valid = all(r.valid for r in results)
        valid_count = sum(1 for r in results if r.valid)
        elapsed_ms = (time.time() - start_time) * 1000

        if all_valid:
            ctx.info(
                "FK validation passed",
                fact_table=fact_table,
                relationships=len(relationships),
            )
        else:
            ctx.warning(
                "FK validation failed",
                fact_table=fact_table,
                valid=valid_count,
                total=len(relationships),
            )

        return FKValidationReport(
            fact_table=fact_table,
            all_valid=all_valid,
            total_relationships=len(relationships),
            valid_relationships=valid_count,
            results=results,
            orphan_records=all_orphans,
            elapsed_ms=elapsed_ms,
        )

__init__(registry)

Initialize with relationship registry.

Parameters:

Name Type Description Default
registry RelationshipRegistry

RelationshipRegistry with relationship definitions

required
Source code in odibi/validation/fk.py
def __init__(self, registry: RelationshipRegistry):
    """
    Initialize with relationship registry.

    Args:
        registry: RelationshipRegistry with relationship definitions
    """
    # The registry is only queried (get_fact_relationships); never mutated here.
    self.registry = registry

validate_fact(fact_df, fact_table, context)

Validate all FK relationships for a fact table.

Parameters:

Name Type Description Default
fact_df Any

Fact DataFrame to validate

required
fact_table str

Fact table name

required
context EngineContext

EngineContext with dimension data

required

Returns:

Type Description
FKValidationReport

FKValidationReport with all validation results

Source code in odibi/validation/fk.py
def validate_fact(
    self,
    fact_df: Any,
    fact_table: str,
    context: EngineContext,
) -> FKValidationReport:
    """
    Validate every declared FK relationship for a fact table.

    Args:
        fact_df: Fact DataFrame to validate
        fact_table: Fact table name
        context: EngineContext with dimension data

    Returns:
        FKValidationReport summarizing all per-relationship results
    """
    log = get_logging_context()
    started = time.time()

    log.info("Starting FK validation", fact_table=fact_table)

    rels = self.registry.get_fact_relationships(fact_table)

    # No declared relationships: trivially valid report.
    if not rels:
        log.warning(
            "No FK relationships defined",
            fact_table=fact_table,
        )
        return FKValidationReport(
            fact_table=fact_table,
            all_valid=True,
            total_relationships=0,
            valid_relationships=0,
            elapsed_ms=(time.time() - started) * 1000,
        )

    # Run each relationship check in declaration order.
    results = [self.validate_relationship(fact_df, rel, context) for rel in rels]

    # Flatten the sampled orphan values of every failing relationship.
    all_orphans = [
        OrphanRecord(
            fact_key_value=value,
            fact_key_column=rel.fact_key,
            dimension_table=rel.dimension,
        )
        for rel, outcome in zip(rels, results)
        if outcome.orphan_count > 0
        for value in outcome.orphan_values
    ]

    all_valid = all(outcome.valid for outcome in results)
    valid_count = sum(1 for outcome in results if outcome.valid)
    elapsed = (time.time() - started) * 1000

    if all_valid:
        log.info(
            "FK validation passed",
            fact_table=fact_table,
            relationships=len(rels),
        )
    else:
        log.warning(
            "FK validation failed",
            fact_table=fact_table,
            valid=valid_count,
            total=len(rels),
        )

    return FKValidationReport(
        fact_table=fact_table,
        all_valid=all_valid,
        total_relationships=len(rels),
        valid_relationships=valid_count,
        results=results,
        orphan_records=all_orphans,
        elapsed_ms=elapsed,
    )

validate_relationship(fact_df, relationship, context)

Validate a single FK relationship.

Parameters:

Name Type Description Default
fact_df Any

Fact DataFrame to validate

required
relationship RelationshipConfig

Relationship configuration

required
context EngineContext

EngineContext with dimension data

required

Returns:

Type Description
FKValidationResult

FKValidationResult with validation details

Source code in odibi/validation/fk.py
def validate_relationship(
    self,
    fact_df: Any,
    relationship: RelationshipConfig,
    context: EngineContext,
) -> FKValidationResult:
    """
    Validate a single FK relationship.

    Args:
        fact_df: Fact DataFrame to validate
        relationship: Relationship configuration
        context: EngineContext with dimension data

    Returns:
        FKValidationResult with validation details. orphan_count is -1 when
        the dimension table is missing from the context (sentinel that
        distinguishes "could not check" from "checked, zero orphans").
    """
    ctx = get_logging_context()
    start_time = time.time()

    ctx.debug(
        "Validating FK relationship",
        relationship=relationship.name,
        fact=relationship.fact,
        dimension=relationship.dimension,
    )

    try:
        dim_df = context.get(relationship.dimension)
    except KeyError:
        # Dimension not loaded into the context: report invalid with the
        # -1 sentinel rather than raising.
        elapsed_ms = (time.time() - start_time) * 1000
        return FKValidationResult(
            relationship_name=relationship.name,
            valid=False,
            total_rows=0,
            orphan_count=-1,
            null_count=0,
            elapsed_ms=elapsed_ms,
            error=f"Dimension table '{relationship.dimension}' not found",
        )

    try:
        # Route by engine; non-Spark engines take the pandas-style path.
        if context.engine_type == EngineType.SPARK:
            result = self._validate_spark(fact_df, dim_df, relationship)
        else:
            result = self._validate_pandas(fact_df, dim_df, relationship)

        elapsed_ms = (time.time() - start_time) * 1000
        result.elapsed_ms = elapsed_ms

        if result.valid:
            ctx.debug(
                "FK validation passed",
                relationship=relationship.name,
                total_rows=result.total_rows,
            )
        else:
            ctx.warning(
                "FK validation failed",
                relationship=relationship.name,
                orphan_count=result.orphan_count,
                null_count=result.null_count,
            )

        return result

    except Exception as e:
        # Engine-level failure (bad column, schema mismatch, ...): surface
        # as an invalid result so callers can continue with other checks.
        elapsed_ms = (time.time() - start_time) * 1000
        ctx.error(
            f"FK validation error: {e}",
            relationship=relationship.name,
        )
        return FKValidationResult(
            relationship_name=relationship.name,
            valid=False,
            total_rows=0,
            orphan_count=0,
            null_count=0,
            elapsed_ms=elapsed_ms,
            error=str(e),
        )

OrphanRecord dataclass

Details of an orphan record.

Source code in odibi/validation/fk.py
@dataclass
class OrphanRecord:
    """Details of an orphan record.

    One fact-table FK value that has no matching key in its dimension table.
    """

    # The orphaned foreign-key value found in the fact table.
    fact_key_value: Any
    # Name of the FK column in the fact table.
    fact_key_column: str
    # Name of the dimension table the value failed to match.
    dimension_table: str
    # Optional positional index of the offending row, when known.
    row_index: Optional[int] = None

RelationshipConfig

Bases: BaseModel

Configuration for a foreign key relationship.

Attributes:

Name Type Description
name str

Unique relationship identifier

fact str

Fact table name

dimension str

Dimension table name

fact_key str

Foreign key column in fact table

dimension_key str

Primary/surrogate key column in dimension

nullable bool

Whether nulls are allowed in fact_key

on_violation str

Action on violation ("warn", "error", "quarantine")

Source code in odibi/validation/fk.py
class RelationshipConfig(BaseModel):
    """
    Declarative configuration for a single foreign-key relationship.

    Attributes:
        name: Unique relationship identifier
        fact: Fact table name
        dimension: Dimension table name
        fact_key: Foreign key column in fact table
        dimension_key: Primary/surrogate key column in dimension
        nullable: Whether nulls are allowed in fact_key
        on_violation: Action on violation ("warn", "error", "quarantine")
    """

    name: str = Field(..., description="Unique relationship identifier")
    fact: str = Field(..., description="Fact table name")
    dimension: str = Field(..., description="Dimension table name")
    fact_key: str = Field(..., description="FK column in fact table")
    dimension_key: str = Field(..., description="PK/SK column in dimension")
    nullable: bool = Field(default=False, description="Allow nulls in fact_key")
    on_violation: str = Field(default="error", description="Action on violation")

    @field_validator("name", "fact", "dimension", "fact_key", "dimension_key")
    @classmethod
    def validate_not_empty(cls, v: str, info) -> str:
        # Reject empty / whitespace-only identifiers; store the trimmed value.
        stripped = v.strip() if v else ""
        if not stripped:
            raise ValueError(
                f"RelationshipConfig.{info.field_name} cannot be empty. "
                f"Got: {v!r}. Provide a non-empty string value."
            )
        return stripped

    @field_validator("on_violation")
    @classmethod
    def validate_on_violation(cls, v: str) -> str:
        # Normalize to lowercase and restrict to the supported actions.
        valid = ("warn", "error", "quarantine")
        normalized = v.lower()
        if normalized not in valid:
            raise ValueError(f"Invalid on_violation value. Expected one of {valid}, got: {v!r}.")
        return normalized

RelationshipRegistry

Bases: BaseModel

Registry of all declared relationships.

Attributes:

Name Type Description
relationships List[RelationshipConfig]

List of relationship configurations

Source code in odibi/validation/fk.py
class RelationshipRegistry(BaseModel):
    """
    Registry of all declared relationships.

    All lookup helpers match table / relationship names case-insensitively.

    Attributes:
        relationships: List of relationship configurations
    """

    relationships: List[RelationshipConfig] = Field(
        default_factory=list, description="Relationship definitions"
    )

    def get_relationship(self, name: str) -> Optional[RelationshipConfig]:
        """Get a relationship by name (case-insensitive); None if not found."""
        target = name.lower()  # hoisted: lower the query once, not per item
        for rel in self.relationships:
            if rel.name.lower() == target:
                return rel
        return None

    def get_fact_relationships(self, fact_table: str) -> List[RelationshipConfig]:
        """Get all relationships for a fact table (case-insensitive)."""
        target = fact_table.lower()
        return [rel for rel in self.relationships if rel.fact.lower() == target]

    def get_dimension_relationships(self, dim_table: str) -> List[RelationshipConfig]:
        """Get all relationships referencing a dimension (case-insensitive)."""
        target = dim_table.lower()
        return [rel for rel in self.relationships if rel.dimension.lower() == target]

    def generate_lineage(self) -> Dict[str, List[str]]:
        """
        Generate lineage map from relationships.

        Returns:
            Dict mapping fact tables to their dimension dependencies, in
            first-seen declaration order with duplicates removed.
        """
        lineage: Dict[str, List[str]] = {}
        for rel in self.relationships:
            # setdefault replaces the manual "if key not in dict" dance.
            dims = lineage.setdefault(rel.fact, [])
            if rel.dimension not in dims:
                dims.append(rel.dimension)
        return lineage

generate_lineage()

Generate lineage map from relationships.

Returns:

Type Description
Dict[str, List[str]]

Dict mapping fact tables to their dimension dependencies

Source code in odibi/validation/fk.py
def generate_lineage(self) -> Dict[str, List[str]]:
    """
    Generate lineage map from relationships.

    Returns:
        Dict mapping fact tables to their dimension dependencies, in
        first-seen declaration order with duplicates removed.
    """
    lineage: Dict[str, List[str]] = {}
    for rel in self.relationships:
        dims = lineage.setdefault(rel.fact, [])
        if rel.dimension not in dims:
            dims.append(rel.dimension)
    return lineage

get_dimension_relationships(dim_table)

Get all relationships referencing a dimension.

Source code in odibi/validation/fk.py
def get_dimension_relationships(self, dim_table: str) -> List[RelationshipConfig]:
    """Return every relationship referencing *dim_table* (case-insensitive)."""
    target = dim_table.lower()
    return [rel for rel in self.relationships if rel.dimension.lower() == target]

get_fact_relationships(fact_table)

Get all relationships for a fact table.

Source code in odibi/validation/fk.py
def get_fact_relationships(self, fact_table: str) -> List[RelationshipConfig]:
    """Return every relationship declared for *fact_table* (case-insensitive)."""
    target = fact_table.lower()
    return [rel for rel in self.relationships if rel.fact.lower() == target]

get_relationship(name)

Get a relationship by name.

Source code in odibi/validation/fk.py
def get_relationship(self, name: str) -> Optional[RelationshipConfig]:
    """Return the relationship whose name matches (case-insensitive), else None."""
    target = name.lower()
    return next(
        (rel for rel in self.relationships if rel.name.lower() == target),
        None,
    )

get_orphan_records(fact_df, relationship, dim_df, engine_type)

Extract orphan records from a fact table.

Parameters:

Name Type Description Default
fact_df Any

Fact DataFrame

required
relationship RelationshipConfig

Relationship configuration

required
dim_df Any

Dimension DataFrame

required
engine_type EngineType

Engine type (SPARK or PANDAS)

required

Returns:

Type Description
Any

DataFrame containing orphan records

Source code in odibi/validation/fk.py
def get_orphan_records(
    fact_df: Any,
    relationship: RelationshipConfig,
    dim_df: Any,
    engine_type: EngineType,
) -> Any:
    """
    Extract orphan records from a fact table.

    Orphans are rows whose non-null FK value has no matching key in the
    dimension table. NULL FK values are never reported as orphans.

    Args:
        fact_df: Fact DataFrame
        relationship: Relationship configuration
        dim_df: Dimension DataFrame
        engine_type: Engine type (SPARK or PANDAS)

    Returns:
        DataFrame containing orphan records
    """
    fk = relationship.fact_key
    dk = relationship.dimension_key

    if engine_type != EngineType.SPARK:
        # Pandas path: set-membership test against distinct dimension keys.
        known_keys = set(dim_df[dk].dropna().unique())
        is_orphan = fact_df[fk].notna() & ~fact_df[fk].isin(known_keys)
        return fact_df[is_orphan].copy()

    # Spark path: left-anti join of non-null fact keys vs. distinct dim keys.
    from pyspark.sql import functions as F

    dim_keys = dim_df.select(F.col(dk).alias("_dim_key")).distinct()
    candidates = fact_df.filter(F.col(fk).isNotNull())
    return candidates.join(
        dim_keys,
        candidates[fk] == dim_keys["_dim_key"],
        "left_anti",
    )

parse_relationships_config(config_dict)

Parse relationships from a configuration dictionary.

Parameters:

Name Type Description Default
config_dict Dict[str, Any]

Config dict with "relationships" key

required

Returns:

Type Description
RelationshipRegistry

RelationshipRegistry instance

Source code in odibi/validation/fk.py
def parse_relationships_config(config_dict: Dict[str, Any]) -> RelationshipRegistry:
    """
    Parse relationships from a configuration dictionary.

    Args:
        config_dict: Config dict with "relationships" key (missing key is
            treated as an empty list)

    Returns:
        RelationshipRegistry instance
    """
    rel_configs = [
        RelationshipConfig(**entry) for entry in config_dict.get("relationships", [])
    ]
    return RelationshipRegistry(relationships=rel_configs)

validate_fk_on_load(fact_df, relationships, context, on_failure='error')

Validate FK constraints and optionally filter orphans.

This is a convenience function for use in FactPattern.

Parameters:

Name Type Description Default
fact_df Any

Fact DataFrame to validate

required
relationships List[RelationshipConfig]

List of relationship configs

required
context EngineContext

EngineContext with dimension data

required
on_failure str

Action on failure ("error", "warn", "filter")

'error'

Returns:

Type Description
Any

fact_df (possibly filtered if on_failure="filter")

Raises:

Type Description
ValueError

If on_failure="error" and validation fails

Source code in odibi/validation/fk.py
def validate_fk_on_load(
    fact_df: Any,
    relationships: List[RelationshipConfig],
    context: EngineContext,
    on_failure: str = "error",
) -> Any:
    """
    Validate FK constraints and optionally filter orphans.

    This is a convenience function for use in FactPattern.

    Args:
        fact_df: Fact DataFrame to validate
        relationships: List of relationship configs
        context: EngineContext with dimension data
        on_failure: Action on failure ("error", "warn", "filter")

    Returns:
        fact_df (possibly filtered if on_failure="filter"). When filtering,
        rows with NULL FK values are preserved for nullable relationships.

    Raises:
        ValueError: If on_failure="error" and validation fails
    """
    ctx = get_logging_context()

    registry = RelationshipRegistry(relationships=relationships)
    validator = FKValidator(registry)

    for rel in relationships:
        result = validator.validate_relationship(fact_df, rel, context)

        if result.valid:
            continue

        if on_failure == "error":
            raise ValueError(
                f"FK validation failed for '{rel.name}': "
                f"{result.orphan_count} orphans, {result.null_count} nulls. "
                f"Sample orphan values: {result.orphan_values[:5]}"
            )
        elif on_failure == "warn":
            ctx.warning(
                f"FK validation warning for '{rel.name}': "
                f"{result.orphan_count} orphans, {result.null_count} nulls"
            )
        elif on_failure == "filter":
            try:
                dim_df = context.get(rel.dimension)
            except KeyError:
                # Dimension missing: nothing to filter against; skip.
                continue

            if context.engine_type == EngineType.SPARK:
                from pyspark.sql import functions as F

                dim_keys = dim_df.select(F.col(rel.dimension_key).alias("_fk_key")).distinct()
                non_null = fact_df.filter(F.col(rel.fact_key).isNotNull())
                matched = non_null.join(
                    dim_keys,
                    non_null[rel.fact_key] == dim_keys["_fk_key"],
                    "inner",
                ).drop("_fk_key")
                if rel.nullable:
                    # BUG FIX: the inner join silently dropped NULL FK rows
                    # even for nullable relationships; re-attach them.
                    null_rows = fact_df.filter(F.col(rel.fact_key).isNull())
                    fact_df = matched.unionByName(null_rows)
                else:
                    fact_df = matched
            else:
                dim_keys = set(dim_df[rel.dimension_key].dropna().unique())
                keep_mask = fact_df[rel.fact_key].isin(dim_keys)
                if rel.nullable:
                    # BUG FIX: isin() is False for NaN, so nullable NULL FK
                    # rows were previously filtered out; keep them.
                    keep_mask |= fact_df[rel.fact_key].isna()
                fact_df = fact_df[keep_mask].copy()

            ctx.info(
                f"Filtered orphans for '{rel.name}'",
                remaining_rows=len(fact_df) if hasattr(fact_df, "__len__") else "N/A",
            )

    return fact_df