Reasoning Joint Embedding Predictive Architecture
R-JEPA is an experimental world model for text reasoning that operates on latent representations instead of tokens, adapting Meta AI's V-JEPA (video prediction) approach to textual reasoning sequences.
R-JEPA operates on LLM hidden states to understand and improve reasoning chains
Extract hidden states from LLM reasoning steps (layer -2)
Mask reasoning steps and predict their latent representations
Learn stable relationships between concepts (like physics of thought)
Use learned world model to correct, complete, or re-rank reasoning
Operates on LLM hidden states (4096-dim vectors from layer -2), not tokens. Uses contiguous masking (30-70% of steps) to predict reasoning step representations.
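The contiguous masking step can be sketched as follows, with random tensors standing in for real layer -2 hidden states (`contiguous_mask` is an illustrative helper, not the project's API):

```python
import numpy as np

def contiguous_mask(num_steps, lo=0.3, hi=0.7, rng=None):
    """Boolean mask over reasoning steps: one contiguous span covering
    30-70% of the steps is hidden (True = masked, to be predicted)."""
    rng = rng or np.random.default_rng()
    span = max(1, int(round(rng.uniform(lo, hi) * num_steps)))
    start = rng.integers(0, num_steps - span + 1)
    mask = np.zeros(num_steps, dtype=bool)
    mask[start:start + span] = True
    return mask

# Example: a 10-step chain of 4096-dim latents (stand-in for layer -2 states)
steps = np.random.randn(10, 4096).astype(np.float32)
mask = contiguous_mask(len(steps), rng=np.random.default_rng(0))
visible = steps[~mask]   # fed to the context encoder
targets = steps[mask]    # predicted in latent space
```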
Tested with Qwen3, Llama3, Mistral, DeepSeek, Phi families. Includes projection adapters (W_in/W_out) for calibrating to new LLMs without full retraining.
Trains on chains verified by automatic validators: sympy for math, sandboxed execution for code. Filters with is_valid=True so that only verified reasoning paths enter training.
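The sympy check can be illustrated like this (`validate_math_step` is a hypothetical helper; the project's actual validators are more involved):

```python
import sympy

def validate_math_step(claimed: str, expected: str) -> bool:
    """Return True only when the claimed expression is provably equal
    to the expected one under symbolic simplification."""
    try:
        diff = sympy.simplify(sympy.sympify(claimed) - sympy.sympify(expected))
    except (sympy.SympifyError, TypeError):
        return False
    return diff == 0

# Chains whose final answers validate are stored with is_valid=True
ok = validate_math_step("2*x + 2*x", "4*x")   # equivalent forms pass
bad = validate_math_step("2*x + 3", "4*x")    # wrong results fail
```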
Includes continuous learning module: user interaction logging with PII filtering, feedback validation, and scheduled retraining support via Prefect flows.
Includes loaders for GSM8K, MATH, HumanEval, MMLU (57 subjects), Big-Bench Hard (23 tasks), ARC, and HellaSwag. Correlation analysis between JEPA-loss and correctness.
7 Docker services via docker-compose: student-llm, rjepa-service, teacher-orch, data-pipeline, prefect-server, ui-backend, ui-frontend.
Full open-source flexibility or managed cloud service with advanced features. Your data stays private either way.
With R-JEPA Cloud, your text data never leaves your infrastructure. Here's how it works:
1. You extract latent vectors locally from your LLM (4096-dim abstract representations)
2. Only these latent vectors are sent to our API — they act as pointers
3. Your text stays in your local database, indexed by these latent references
4. We return memory indices + guidance — you reconstruct context locally
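A toy sketch of step 3, assuming a local key-value store keyed by a latent fingerprint (`LocalLatentIndex` and its hashing scheme are illustrative, not the actual client):

```python
import hashlib
import numpy as np

class LocalLatentIndex:
    """Toy local store: text stays here, keyed by a latent fingerprint."""
    def __init__(self):
        self._texts = {}

    @staticmethod
    def key(latent) -> str:
        # Deterministic pointer derived from the latent; the text never leaves
        raw = np.asarray(latent, dtype=np.float32).tobytes()
        return hashlib.sha256(raw).hexdigest()[:16]

    def put(self, latent, text) -> str:
        k = self.key(latent)
        self._texts[k] = text
        return k  # only this key and the latent ever reach the remote API

    def get(self, key) -> str:
        return self._texts[key]

index = LocalLatentIndex()
latent = np.random.default_rng(0).standard_normal(4096)
ptr = index.put(latent, "Step 3: factor the quadratic ...")
```

When the API returns memory indices, the client resolves them against this local store, so the reconstruction of context happens entirely on your side.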
Latent vectors are lossy, high-dimensional abstractions; reconstructing the original text from them is designed to be computationally impractical.
Compatible with GDPR, HIPAA, and enterprise security requirements.
Latents act as pointers to your local database. Store experiences, retrieve by semantic similarity. Your text stays local — we only see abstract vectors. Cross-session context without privacy compromise.
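Retrieval by semantic similarity can be sketched with plain cosine similarity over the stored latents (a stand-in for whatever index the service actually uses):

```python
import numpy as np

def retrieve(query, memory, k=3):
    """Indices of the k stored latents most cosine-similar to the query."""
    q = query / np.linalg.norm(query)
    m = memory / np.linalg.norm(memory, axis=1, keepdims=True)
    return np.argsort(-(m @ q))[:k]
```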
LLM adapts to R-JEPA guidance — never the reverse. R-JEPA learns from validated reasoning only. Prevents hallucination contamination through strict unidirectional flow.
Cognitive firewall: Only mathematically or logically verified CoTs enter memory. Sympy for math, sandboxed execution for code. Prevents learning from incorrect reasoning paths.
Dual-encoder architecture with EMA target encoder, inspired by V-JEPA
flowchart TB
subgraph INPUT["📥 Input: LLM Latents"]
H["H [B, S, 4096]
from Qwen3-8B layer -2"]
end
subgraph MASKING["🎭 Masking Strategy"]
M["Contiguous Masking
30-70% of steps"]
end
subgraph CONTEXT["🟢 Context Encoder"]
CE["StepTransformer
6 layers, 2048-dim, 16 heads
TRAINABLE"]
end
subgraph TARGET["🟡 Target Encoder"]
TE["StepTransformer
EMA copy (τ = 0.996→0.9999)
FROZEN"]
end
subgraph PREDICTOR["🔵 Predictor"]
P["StepPredictor
4 layers
TRAINABLE"]
end
subgraph OUTPUTS["📤 Outputs"]
Z_PRED["z_pred
Predicted Latents"]
Z_TARGET["z_target
Ground Truth"]
end
subgraph LOSS["📉 Loss Function"]
L["Loss = L1(z_pred, z_target)
+ variance_reg (0.01)
+ contrastive (0.1)"]
end
H --> M
M -->|visible steps| CE
M -->|all steps| TE
CE --> P
P --> Z_PRED
TE --> Z_TARGET
Z_PRED --> L
Z_TARGET --> L
style INPUT fill:#1e293b,stroke:#6366f1,stroke-width:2px,color:#e2e8f0
style MASKING fill:#1e293b,stroke:#22d3ee,stroke-width:2px,color:#e2e8f0
style CONTEXT fill:#134e4a,stroke:#22c55e,stroke-width:3px,color:#e2e8f0
style TARGET fill:#422006,stroke:#f59e0b,stroke-width:3px,color:#e2e8f0
style PREDICTOR fill:#1e1b4b,stroke:#6366f1,stroke-width:3px,color:#e2e8f0
style OUTPUTS fill:#1e293b,stroke:#0ea5e9,stroke-width:2px,color:#e2e8f0
style LOSS fill:#4c1d95,stroke:#a78bfa,stroke-width:2px,color:#e2e8f0
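The trainable/EMA split in the diagram can be sketched in PyTorch (`StepTransformer` here is a simplified stand-in whose defaults mirror the diagram; `make_target` and `ema_update` are illustrative helpers):

```python
import copy
import torch
import torch.nn as nn

class StepTransformer(nn.Module):
    """Simplified stand-in for the encoder (6 layers, 2048-dim, 16 heads)."""
    def __init__(self, d_in=4096, d_model=2048, layers=6, heads=16):
        super().__init__()
        self.proj = nn.Linear(d_in, d_model)
        layer = nn.TransformerEncoderLayer(d_model, heads, batch_first=True)
        self.encoder = nn.TransformerEncoder(layer, layers)

    def forward(self, h):                  # h: [B, S, d_in]
        return self.encoder(self.proj(h))  # -> [B, S, d_model]

def make_target(online):
    """Frozen EMA copy of the context encoder."""
    target = copy.deepcopy(online)
    for p in target.parameters():
        p.requires_grad_(False)
    return target

@torch.no_grad()
def ema_update(target, online, tau):
    """tau is annealed from 0.996 toward 0.9999 over training."""
    for pt, po in zip(target.parameters(), online.parameters()):
        pt.lerp_(po, 1.0 - tau)
```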
Three ways R-JEPA can help improve reasoning.
Generate K=4 reasoning chain candidates, compute JEPA-loss for each, select the lowest-loss candidate. Optionally combines with logprob scoring: score = α×logprob + β×(-JEPA_loss).
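The combined score can be sketched as follows (the weights alpha and beta are illustrative defaults; the write-up does not fix their values):

```python
import numpy as np

def rerank(logprobs, jepa_losses, alpha=0.5, beta=0.5):
    """Best-of-K selection: pick the candidate maximizing
    alpha*logprob + beta*(-JEPA_loss)."""
    scores = alpha * np.asarray(logprobs) + beta * (-np.asarray(jepa_losses))
    return int(np.argmax(scores))

# K=4 candidate chains: favors high logprob combined with low JEPA-loss
best = rerank(logprobs=[-12.0, -10.5, -11.0, -14.2],
              jepa_losses=[0.42, 0.18, 0.35, 0.12])
```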
Token-by-token guidance using predicted latents. Projects predicted h_next to vocabulary space via learned MLP. Formula: logits_final = logits_llm + α×guidance_bias. Requires calibration step.
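A numerical sketch of the guidance formula, with random matrices standing in for the learned MLP (`W1`, `W2`, and the default alpha are assumptions, not the trained values):

```python
import numpy as np

rng = np.random.default_rng(0)
d_model, vocab = 4096, 32000

# Stand-ins for the learned MLP that projects predicted h_next to vocab space
W1 = rng.standard_normal((d_model, 256)) * 0.01
W2 = rng.standard_normal((256, vocab)) * 0.01

def guidance_bias(h_next_pred):
    return np.tanh(h_next_pred @ W1) @ W2

def guided_logits(logits_llm, h_next_pred, alpha=0.3):
    """logits_final = logits_llm + alpha * guidance_bias;
    alpha is set by the calibration step."""
    return logits_llm + alpha * guidance_bias(h_next_pred)
```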
Predict latents for missing reasoning steps based on context. Decodes predicted vectors to text via separate decoder module or LLM prompting. Useful for gap completion.
Architecture specifications and training configuration
Context Encoder: 6-layer Transformer, 2048-dim, 16 heads
Predictor: 4-layer Transformer
Target Encoder: EMA copy (momentum 0.996→0.9999)
Input: 4096-dim latents from LLM layer -2
Primary: L1 reconstruction on masked steps
Regularization: Variance loss (prevents collapse)
Optional: InfoNCE contrastive (weight 0.1)
Temperature: 0.07 for contrastive
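Putting the three terms together (the variance hinge below is one common formulation, assumed rather than taken from the project's code):

```python
import torch
import torch.nn.functional as F

def rjepa_loss(z_pred, z_target, var_w=0.01, con_w=0.1, temp=0.07):
    """L1 reconstruction + variance regularizer + optional InfoNCE."""
    l1 = F.l1_loss(z_pred, z_target)
    # Variance hinge: keep per-dimension std above 1 to prevent collapse
    flat = z_pred.flatten(0, -2)
    var = F.relu(1.0 - flat.std(dim=0)).mean()
    # InfoNCE over masked steps, matching each prediction to its target
    zp = F.normalize(flat, dim=-1)
    zt = F.normalize(z_target.flatten(0, -2), dim=-1)
    logits = zp @ zt.t() / temp
    nce = F.cross_entropy(logits, torch.arange(logits.size(0)))
    return l1 + var_w * var + con_w * nce
```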
Metadata: Parquet (zstd compression)
Tensors: SafeTensors format
Indexing: DuckDB for SQL queries
Training data: GSM8K, MATH, HumanEval
Optimizer: AdamW, lr=3e-4
Precision: bfloat16 AMP
Gradient clip: 1.0
Masking: Contiguous 30-70% of steps
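A minimal training step matching these settings (the model is a placeholder; R-JEPA's real step also runs the masking and EMA update):

```python
import torch

model = torch.nn.Linear(4096, 2048)   # placeholder for the R-JEPA modules
opt = torch.optim.AdamW(model.parameters(), lr=3e-4)

def train_step(batch):
    opt.zero_grad(set_to_none=True)
    # bfloat16 autocast, matching the bf16 AMP setting
    with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
        loss = model(batch).pow(2).mean()  # stand-in objective
    loss.backward()
    # gradient clipping at 1.0, as configured
    torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
    opt.step()
    return float(loss)
```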
Based on published work in self-supervised learning
Self-supervised method that learns representations by predicting masked regions in latent space rather than pixel space.
Vision for AI systems that learn world models through self-supervised prediction in latent space: the theoretical foundation for JEPA architectures.