JJ-Han/nli-consensus-selection


Premise-Direction NLI Consensus for Response Selection in LLMs

Selecting the most reliable response from multiple stochastic LLM samples using directional NLI consensus — evaluated on TruthfulQA.

CS 9860: Advanced Machine Learning · University of Western Ontario

📄 Full Report (PDF)


TL;DR

Instead of trying to reduce an LLM's randomness, this project asks: given several responses sampled at temperature 1.0, can we select the most truthful one? The approach builds a pairwise Natural Language Inference (NLI) matrix over candidates and picks the response that most strongly entails the rest in the premise direction.

Three findings:

  1. Direction is the driver. Premise-direction scoring beats greedy decoding (p < 10⁻⁵); reversing to the hypothesis direction erases the gain entirely (p > 0.13). An ablation confirms this asymmetry is the core mechanism, not an incidental effect.
  2. Simplicity wins. A plain row-sum baseline matches or beats more structured methods (clustering, directed graph) under an N=5 sampling budget — no statistically significant difference between them.
  3. Evaluation signal must be filtered. A reliability filter isolates the 74.3% of cases where the NLI signal is discriminative; outside that regime the consensus signal is indistinguishable from noise.

Key Results

On a filtered subset of TruthfulQA (325 questions, 13 categories), within the reliable evaluation tier:

| Method | Correctness | vs. Greedy | p-value |
|---|---|---|---|
| Oracle (upper bound) | 0.6452 | +0.299 | |
| Baseline (ours) | 0.3979 | +0.052 | 1.5×10⁻⁶ |
| Clustering (ours) | 0.3962 | +0.050 | 3.2×10⁻⁶ |
| DLG (ours) | 0.3950 | +0.049 | 6.8×10⁻⁶ |
| Greedy decoding | 0.3459 | | |

The Baseline closes roughly 17% of the gap between greedy decoding and the oracle upper bound. All three proposed methods improve over greedy with high significance, and cross-model validation with BART shows directionally consistent gains, suggesting the improvements are not an artifact of relying on a single NLI model.
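The gap-closure figure follows directly from the correctness values in the table. As a quick check (variable names are illustrative only):

```python
# Correctness values copied from the Key Results table.
greedy = 0.3459
baseline = 0.3979
oracle = 0.6452

# Fraction of the greedy-to-oracle gap recovered by the Baseline selector.
gap_closed = (baseline - greedy) / (oracle - greedy)
print(f"Gap closed: {gap_closed:.1%}")  # Gap closed: 17.4%
```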


Method

The pipeline runs in four stages:

  1. Generate — sample N=5 responses per question from Llama 3.1 8B (4-bit) at temperature 1.0, across 10 runs.
  2. Score — build an N×N pairwise NLI matrix M[i][j] = P(entail) − P(contradict) with DeBERTa-v3, using R_i as premise and R_j as hypothesis.
  3. Select — pick the most consensual response via one of three methods:
    • Baseline — highest row-sum (most strongly entails the pool)
    • Clustering — average-linkage on the symmetrized matrix
    • DLG — highest out-degree in a thresholded directed graph
  4. Evaluate — score the selected response against reference answers, partitioned by a reliability filter.
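The Score and Select stages can be sketched in a few lines of plain Python. This is a minimal illustration, not the code in src/nli_scorer.py or src/selector.py: the `Scorer` callback, the function names, the zeroed diagonal, the `>=` edge rule, and the first-index tie-break are all assumptions, and `toy_scorer` is a word-overlap stand-in for a real NLI model. Clustering is omitted because this README does not spell out how a response is chosen once clusters are formed.

```python
from typing import Callable, List, Sequence, Tuple

# Hypothetical scorer: (premise, hypothesis) -> (P(entail), P(contradict)).
Scorer = Callable[[str, str], Tuple[float, float]]


def build_nli_matrix(responses: Sequence[str], scorer: Scorer) -> List[List[float]]:
    """M[i][j] = P(entail) - P(contradict) with R_i as premise, R_j as hypothesis."""
    n = len(responses)
    matrix = [[0.0] * n for _ in range(n)]
    for i, premise in enumerate(responses):
        for j, hypothesis in enumerate(responses):
            if i == j:
                continue  # assumption: self-pairs are left at 0
            p_entail, p_contradict = scorer(premise, hypothesis)
            matrix[i][j] = p_entail - p_contradict
    return matrix


def select_baseline(matrix: List[List[float]]) -> int:
    """Index of the response with the highest row-sum (ties: lowest index)."""
    row_sums = [sum(row) for row in matrix]
    return max(range(len(matrix)), key=row_sums.__getitem__)


def select_dlg(matrix: List[List[float]], threshold: float = 0.8) -> int:
    """Index with the highest out-degree, drawing edge i -> j when M[i][j] >= threshold."""
    out_degree = [
        sum(1 for j, score in enumerate(row) if j != i and score >= threshold)
        for i, row in enumerate(matrix)
    ]
    return max(range(len(matrix)), key=out_degree.__getitem__)


def toy_scorer(premise: str, hypothesis: str) -> Tuple[float, float]:
    """Stand-in for an NLI model: 'entails' if every hypothesis word is in the premise."""
    p_words = set(premise.lower().split())
    h_words = set(hypothesis.lower().split())
    return (0.9, 0.05) if h_words <= p_words else (0.1, 0.2)


responses = [
    "paris is the capital of france",
    "paris is the capital",
    "the capital is lyon",
]
M = build_nli_matrix(responses, toy_scorer)
print(select_baseline(M), select_dlg(M))
```

With the toy scorer, only the first response entails another candidate, so both selectors return index 0. In the real pipeline the callback would wrap the DeBERTa-v3 cross-encoder.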

The key idea is directional asymmetry: NLI is not symmetric, and scoring in the premise direction (which response best entails the others) carries the signal that drives selection quality.
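To make the asymmetry concrete, the sketch below scores a small hand-written matrix in both directions. It assumes that "premise direction" means row-sums (how strongly a response entails the others) and that "hypothesis direction" means column-sums (how strongly a response is entailed by the others); the matrix values are invented for illustration and are not taken from the experiments.

```python
from typing import List


def premise_scores(matrix: List[List[float]]) -> List[float]:
    """Row-sums: how strongly each response, as premise, entails the pool."""
    return [sum(row) for row in matrix]


def hypothesis_scores(matrix: List[List[float]]) -> List[float]:
    """Column-sums: how strongly each response, as hypothesis, is entailed by the pool."""
    return [sum(col) for col in zip(*matrix)]


def argmax(scores: List[float]) -> int:
    return max(range(len(scores)), key=scores.__getitem__)


# Illustrative asymmetric matrix: M[i][j] != M[j][i] in general.
M = [
    [0.0, 0.9, 0.8],
    [0.1, 0.0, 0.2],
    [-0.3, 0.4, 0.0],
]

premise_pick = argmax(premise_scores(M))        # response 0
hypothesis_pick = argmax(hypothesis_scores(M))  # response 1
print(premise_pick, hypothesis_pick)
```

Because the matrix is not symmetric, the two directions can select different responses from the same scores, which is the effect the direction ablation isolates.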

A reliability filter partitions evaluation instances by the oracle's score margin |Δ|, isolating the 74.3% of cases where the NLI signal is discriminative enough to trust:

Reliability tiers

Oracle score margin distribution. Instances with |Δ| ≥ 0.5 (the Reliable tier, green) carry a clear NLI signal; below that, the signal is too weak to distinguish correct from incorrect.
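The partition itself is a simple threshold on the margin. The sketch below assumes each evaluation instance already has a precomputed margin Δ (how Δ is derived is defined in the report, not here); the function name and the sample margins are hypothetical.

```python
from typing import List, Sequence, Tuple


def partition_by_margin(
    margins: Sequence[float], threshold: float = 0.5
) -> Tuple[List[int], List[int]]:
    """Split instance indices into (reliable, unreliable) by |margin| >= threshold."""
    reliable = [i for i, m in enumerate(margins) if abs(m) >= threshold]
    unreliable = [i for i, m in enumerate(margins) if abs(m) < threshold]
    return reliable, unreliable


# Invented margins for five instances, for illustration only.
margins = [0.82, -0.61, 0.12, 0.5, -0.03]
reliable, unreliable = partition_by_margin(margins)
print(reliable, unreliable)  # [0, 1, 3] [2, 4]
```

The default of 0.5 mirrors DIFF_THRESHOLD in src/config.py.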


Repository Structure

.
├── report.pdf                 # Full CS9860 report
├── requirements.txt
├── assets/                    # Figures for this README
├── src/                       # Core pipeline
│   ├── config.py              # Central configuration
│   ├── download_all.py        # Model download
│   ├── generator.py           # LLM response generation
│   ├── generate_dataset.py    # Stage 1: generate
│   ├── nli_scorer.py          # Stage 2: pairwise NLI matrix
│   ├── selector.py            # Stage 3: Baseline / Clustering / DLG
│   ├── evaluate.py            # Stage 4: evaluate stochastic
│   └── evaluate_greedy.py     # Stage 4: evaluate greedy baseline
├── analysis/                  # Reproduce paper tables & figures
│   ├── analyze_final_result.py    # Main results (Table 4)
│   ├── analyze_thresholds.py      # NLI distributions (Figures 1–3)
│   ├── analyze_win_rate.py        # Win/Tie/Loss (Table 5)
│   ├── analyze_question_level.py  # Question-level (Table 6)
│   ├── analyze_comparison.py      # Method comparison (§6.1)
│   ├── analyze_bart.py            # Cross-model validation (Appendix C)
│   └── analyze_fallback_ratio.py  # Threshold selection support
└── results/                   # Pre-computed data (~25 MB)
    ├── temp1.0/               # Stochastic: responses, NLI matrices, eval
    └── temp0.1/               # Greedy baseline

Reproducing the Results

The repository ships with pre-computed data, so the headline numbers can be reproduced without a GPU:

pip install -r requirements.txt
python analysis/analyze_final_result.py

This reads results/temp1.0/evaluated/eval_final.json and reproduces the main results table.

On torch: requirements.txt specifies torch>=2.5 without a CUDA build tag, so pip installs the default wheel for your platform. The analysis scripts run fine on CPU. The full generation pipeline below requires a CUDA GPU; if torch.cuda.is_available() returns False, install a CUDA build from pytorch.org matching your driver.
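A quick way to see which build was installed before attempting the full pipeline is a small check like the following. The helper name `cuda_status` is illustrative and not part of this repository.

```python
def cuda_status() -> str:
    """Report whether the installed torch build can see a CUDA device."""
    try:
        import torch
    except ImportError:
        return "torch is not installed"
    if torch.cuda.is_available():
        return f"CUDA available: {torch.cuda.get_device_name(0)}"
    return f"No CUDA device visible (torch {torch.__version__})"


print(cuda_status())
```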

Running the full pipeline from scratch

Requires a CUDA GPU (tested on RTX 4080 Super, 16 GB).

# 1. Download models (Llama 3.1 8B + DeBERTa-v3 NLI)
python src/download_all.py

# 2. Generate responses (TruthfulQA downloads automatically)
python src/generate_dataset.py

# 3. Compute pairwise NLI matrices
python src/nli_scorer.py

# 4. Evaluate
python src/evaluate_greedy.py
python src/evaluate.py

Key parameters live in src/config.py:

| Parameter | Default | Description |
|---|---|---|
| N_SAMPLES | 5 | Responses per query per run |
| ROUND | 10 | Number of runs |
| TEMPERATURE | 1.0 | Sampling temperature |
| CLUSTER_AVG_THRESHOLD | 0.45 | Clustering threshold (τ_c) |
| DLG_THRESHOLD | 0.8 | DLG threshold (τ_d) |
| DIFF_THRESHOLD | 0.5 | Reliability filter threshold |

Models

| Role | Model |
|---|---|
| Generation | meta-llama/Llama-3.1-8B-Instruct (4-bit NF4) |
| NLI scoring | cross-encoder/nli-deberta-v3-large |
| Cross-validation | facebook/bart-large-mnli |

Notes

AI assistance was used for coding support (generation and debugging) during this project. All experimental design, analysis, and interpretation are the author's own.
