blinx is a lightweight, high-performance JAX/Flax library designed for Batched LoRA (Low-Rank Adaptation). It enables serving multiple independent, fine-tuned adapters simultaneously within the same batch during inference with minimal overhead, making it ideal for high-throughput Multi-LoRA model serving.
- Dynamic Adapter Batching: Slice and map different LoRA adapters to different sequences in the same input batch.
- Universal Flax Wrapper: Wrap any linear operation, such as `nn.Dense` or custom `Einsum` submodules, or apply it directly to raw parameters.
- Support for Scanned Layers: Out-of-the-box support for models whose parameters are stacked along a scan dimension (e.g., layers grouped via `jax.lax.scan`).
- Zero-overhead Inference: Uses the batch dimensions of JAX's native `dot_general` to compute the low-rank updates efficiently.
- Simple Disk Serialization: Serializes adapter checkpoints and configurations using `orbax` and `flax.serialization`.
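The batching idea behind these features can be sketched in plain JAX. Each sequence `b` in the batch computes `x_b W + (alpha / r) * (x_b A_b) B_b`, where `W` is the shared base kernel and `(A_b, B_b)` is that sequence's adapter. This is an illustration under assumptions, not blinx's implementation: the function names, the adapter layouts `(B, D_in, r)` and `(B, r, ...)`, and the conventional `alpha / r` scaling are all assumed. The first function expresses the per-sequence contraction through `lax.dot_general` batch dimensions; the second handles a multi-axis `"BTD,NDH->BTNH"` projection like the `q_einsum` example further down.

```python
import jax
import jax.numpy as jnp
from jax import lax

# lax.dot_general dimension numbers: ((lhs_contract, rhs_contract), (lhs_batch, rhs_batch)).
# Contract the feature axis and treat axis 0 (the sequence index) as a batch dimension.
_BATCHED_MATMUL = (((2,), (1,)), ((0,), (0,)))


def batched_lora_dense(x, w, lora_a, lora_b, scale):
    """Dense layer whose LoRA adapter may differ for every sequence (assumed layouts).

    x: (B, T, D_in), w: (D_in, D_out) shared base kernel,
    lora_a: (B, D_in, r), lora_b: (B, r, D_out), one adapter per sequence.
    """
    base = x @ w  # (B, T, D_out): the same base weights for every sequence
    # Axis 0 is a dot_general batch dimension, so sequence b is only multiplied by adapter b.
    h = lax.dot_general(x, lora_a, _BATCHED_MATMUL)  # (B, T, r)
    delta = lax.dot_general(h, lora_b, _BATCHED_MATMUL)  # (B, T, D_out)
    return base + scale * delta


def batched_lora_q_einsum(x, w, lora_a, lora_b, scale):
    """Multi-axis variant for a "BTD,NDH->BTNH" projection (assumed layouts).

    w: (N, D, H), lora_a: (B, D, r), lora_b: (B, r, N, H).
    """
    base = jnp.einsum("BTD,NDH->BTNH", x, w)
    delta = jnp.einsum("BTD,BDr,BrNH->BTNH", x, lora_a, lora_b)
    return base + scale * delta


# Usage: 3 sequences, each with its own rank-2 adapter.
B, T, D, N, H, R = 3, 5, 8, 2, 4, 2
alpha = 4.0
scale = alpha / R  # conventional LoRA scaling (assumed)
ks = jax.random.split(jax.random.PRNGKey(0), 6)
x = jax.random.normal(ks[0], (B, T, D))
a = jax.random.normal(ks[1], (B, D, R))

w_dense = jax.random.normal(ks[2], (D, N * H))
b_dense = jax.random.normal(ks[3], (B, R, N * H))
y_dense = batched_lora_dense(x, w_dense, a, b_dense, scale)

w_q = jax.random.normal(ks[4], (N, D, H))
b_q = jax.random.normal(ks[5], (B, R, N, H))
y_q = batched_lora_q_einsum(x, w_q, a, b_q, scale)

# Reference: merge each adapter into the base kernel and run one sequence at a time.
ref_dense = jnp.stack([x[i] @ (w_dense + scale * a[i] @ b_dense[i]) for i in range(B)])
ref_q = jnp.stack([
    jnp.einsum("TD,NDH->TNH", x[i], w_q + scale * jnp.einsum("Dr,rNH->NDH", a[i], b_q[i]))
    for i in range(B)
])
print(y_dense.shape, jnp.allclose(y_dense, ref_dense, rtol=1e-4, atol=1e-4))
print(y_q.shape, jnp.allclose(y_q, ref_q, rtol=1e-4, atol=1e-4))
```

Because the low-rank factors are never merged into `W`, the base kernel is shared by the whole batch and each extra adapter only adds two thin matmuls of rank `r`, which is what makes mixing adapters within one batch cheap.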
blinx is inspired by the design of lorax by Predibase, bringing flexible, multi-adapter inference capabilities to JAX-based architectures.
Adding LoRA capabilities to your JAX model with blinx takes three main steps:

- Configure rules using `LoraRule` and `LoraConfig`.
- Initialize the adapter parameters.
- Wrap your model layers using `BatchedLora`.
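Step 1 selects layers by matching a regex against each layer's Flax module path. As a rough illustration of that selection (plain Python `re`, not blinx code), the sketch below tries the two patterns from the configuration example that follows against a few hypothetical `/`-joined module paths. It assumes full-string matching and first-match-wins ordering; blinx's actual matching rules are not documented here.

```python
import re

# Patterns from the LoraConfig example; the module paths are hypothetical.
PATTERNS = [".*MlpBlock.*", ".*q_einsum.*"]
PATHS = [
    "Transformer/layers/MlpBlock_0",
    "Transformer/layers/Attention_0/q_einsum",
    "Transformer/layers/Attention_0/kv_einsum",
]


def first_matching_pattern(path, patterns):
    """Return the first pattern that matches the whole path, or None (assumed semantics)."""
    for pattern in patterns:
        if re.fullmatch(pattern, path):
            return pattern
    return None


selected = {path: first_matching_pattern(path, PATTERNS) for path in PATHS}
for path, pattern in selected.items():
    print(f"{path:42} -> {pattern}")
```

With `.*` on both sides, these patterns behave like substring matches under `re.match`, `re.search`, or `re.fullmatch` alike, so the assumed matching mode does not change the result for them; `kv_einsum` is left unadapted because it does not contain `q_einsum`.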
To adapt a model, specify rules that select layers by matching regex patterns against their Flax module paths:
```python
from blinx import LoraConfig, LoraConfigManager, LoraRule

# Configure rules matching target layers in your architecture
lora_config = LoraConfig(
    model_name="my_model",
    lora_rank=16,
    alpha=16,
    rules=[
        # Rule for standard MLP Block layers
        LoraRule(
            pattern=".*MlpBlock.*",
            kernel_name="kernel",
            in_posn=[1],
            out_posn=[2],
            scan_posn=0,
        ),
        # Rule for custom Einsum attention projections
        LoraRule(
            pattern=".*q_einsum.*",
            kernel_name="w",
            in_posn=[2],
            out_posn=[1, 3],
            scan_posn=0,
        ),
    ],
)

# Register configuration globally
LoraConfigManager.configs.append(lora_config)
```

Using the existing base model parameter tree, initialize matching LoRA parameters:
```python
from blinx import init_lora_params

# base_params: the parameter tree of the base model.
# Returns a dictionary tree of LoRA weights conforming to your LoraConfig rules.
lora_params = init_lora_params(base_params, lora_config)
```

Note: `init_lora_params` creates the weight tree for a single adapter. To serve multiple adapters in one batch, stack the individual adapter trees along a new leading axis with `jax.tree.map(lambda *x: jnp.stack(x), ...)`, producing a batched LoRA parameter tree whose leaves have shape `(batch_size, ...)`.
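The stacking described in the note can be sketched as follows. `make_adapter` and the `lora_a`/`lora_b` leaf names are hypothetical stand-ins for the tree returned by `init_lora_params`, whose real structure is not shown here. The sketch stacks three adapters into a pool with one leading axis, then gathers one adapter per sequence with an index array; that gather step is an assumption about how routing could be done, not a documented blinx API.

```python
import jax
import jax.numpy as jnp


def make_adapter(key, d_in=8, d_out=6, rank=2):
    """Hypothetical single-adapter tree standing in for init_lora_params output."""
    key_a, key_b = jax.random.split(key)
    return {
        "mlp": {
            "lora_a": jax.random.normal(key_a, (d_in, rank)),
            "lora_b": jax.random.normal(key_b, (rank, d_out)),
        }
    }


# Three independently trained adapters sharing the same tree structure.
adapters = [make_adapter(k) for k in jax.random.split(jax.random.PRNGKey(0), 3)]

# Stack leaf by leaf: every leaf gains a leading axis of size len(adapters).
# (jax.tree.map is the newer alias of jax.tree_util.tree_map.)
pool = jax.tree_util.tree_map(lambda *xs: jnp.stack(xs), *adapters)

# Route sequence i of the batch to adapter adapter_ids[i].
adapter_ids = jnp.array([2, 0, 0, 1])
batched = jax.tree_util.tree_map(lambda leaf: leaf[adapter_ids], pool)

print(jax.tree_util.tree_map(lambda leaf: leaf.shape, batched))
# {'mlp': {'lora_a': (4, 8, 2), 'lora_b': (4, 2, 6)}}
```

Gathering from a fixed pool keeps every array shape static, so a `jax.jit`-compiled serving step would not need to recompile when the mix of adapters in a batch changes; only the values in `adapter_ids` differ.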
There are two primary ways to apply `BatchedLora` within a Flax module.

You can directly wrap any `nn.Module` subclass (e.g., standard `nn.Dense` or custom `Einsum` layers). `BatchedLora` intercepts the module call, retrieves the matching LoRA weights, computes the adapter output, and adds it to the base layer's output:
```python
from blinx import BatchedLora
import flax.linen as nn

class MyAttention(nn.Module):
    num_heads: int
    head_dim: int

    @nn.compact
    def __call__(self, x, lora):
        # lora: the batched LoRA parameter tree.
        # self.q_einsum is assumed to be an Einsum submodule (with kernel
        # parameter "w") available on this module.
        # x_arg=1 is the index of `x` among the positional arguments passed to
        # q_einsum (index 0 is the einsum equation string).
        q = BatchedLora(self.q_einsum, x_arg=1)("BTD,NDH->BTNH", x, lora=lora)
        return q
```

For layers whose weights are defined directly as parameters via `self.param` (rather than as submodules), you can query the rules and apply the adapter delta manually:
```python
from blinx import BatchedLora
import flax.linen as nn
import jax.numpy as jnp

class CustomMLP(nn.Module):
    features: int
    hidden_dim: int

    @nn.compact
    def __call__(self, x, lora):
        # Base gating kernel of shape (2, features, hidden_dim); any Flax initializer works.
        w_gating = self.param(
            "gating_einsum", nn.initializers.lecun_normal(), (2, self.features, self.hidden_dim)
        )

        # Instantiate BatchedLora on self
        bl = BatchedLora(self)

        # Query rules and retrieve the LoRA parameter matrices for this kernel
        lora_rule = bl.find_lora_rule(suffix="/gating_einsum")
        lora_a, lora_b = bl.find_lora_params(lora, lora_rule)

        # Apply the adapter delta manually and combine it with the base layer output.
        # transpose(2, 0, 1, 3)[0] moves axis 2 of lora_b to the front and takes
        # index 0, presumably the slice matching w_gating[0].
        ff_gate = jnp.dot(x, w_gating[0]) + bl.compute_lora_delta(
            lora_rule, x, lora_a, lora_b.transpose(2, 0, 1, 3)[0]
        )
        return nn.gelu(ff_gate)
```

blinx provides built-in utilities, based on `orbax`, to save adapters to disk and load them back:
```python
from blinx import save_lora_adapter, load_lora_adapter

# Save adapter weights and configuration
save_lora_adapter("path/to/adapter_directory", lora_params, lora_config)

# Restore adapter weights and configuration
loaded_params, loaded_config = load_lora_adapter("path/to/adapter_directory")
```

blinx has been verified with the following models:

| Model | Repository | Description |
|---|---|---|
| PaliGemma / PaliGemma 2 | google/big_vision | Verified compatible with these vision-language models (VLMs) via their JAX/Flax implementation. |
This project is licensed under the MIT License. See below for details.
Copyright (c) 2026 AlphaDeep Inc.
Permission is hereby granted, free of charge, to any person obtaining a copy of this software and associated documentation files (the "Software"), to deal in the Software without restriction, including without limitation the rights to use, copy, modify, merge, publish, distribute, sublicense, and/or sell copies of the Software, and to permit persons to whom the Software is furnished to do so, subject to the following conditions:
The above copyright notice and this permission notice shall be included in all copies or substantial portions of the Software.
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
