Why Is Your LoRA Not Memory Efficient?
LoRA (Low-Rank Adaptation) is a fine-tuning technique that reduces the memory footprint of training by approximating weight updates with low-rank matrices. However, subtle mistakes in a LoRA implementation can increase computational overhead and memory usage instead of reducing them. In this blog, we will explore one such mistake and show how to make a LoRA implementation genuinely memory efficient.
Background on LoRA
LoRA works by decomposing the weight updates of a neural network into low-rank components. Assume the pre-trained weight matrix is denoted as $W$ with shape $(o, i)$, where $o$ is the output dimension and $i$ is the input dimension. LoRA approximates the adapted weight matrix as follows: \(W' = W + \Delta W\) where $\Delta W$ is the low-rank adaptation term, which can be expressed as:
\[\Delta W = B \cdot A\]Here, $A$ is a matrix of shape $(r,i)$ and $B$ is a matrix of shape $(o,r)$, where $r$ is the rank of the adaptation. Usually, $r$ is much smaller than both $o$ and $i$, so $o \cdot r + r \cdot i$ is much smaller than $o \cdot i$, which greatly reduces the number of trainable parameters.
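To make the savings concrete, here is a quick back-of-the-envelope check (the dimensions below are illustrative, not taken from any particular model):

```python
# Illustrative dimensions: a 4096 x 4096 projection adapted with rank r = 8.
o, i, r = 4096, 4096, 8

full_params = o * i          # parameters in the dense weight W
lora_params = o * r + r * i  # parameters in B (o x r) plus A (r x i)

print(full_params)                 # 16777216
print(lora_params)                 # 65536
print(full_params // lora_params)  # 256: trainable parameters shrink 256x
```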
During fine-tuning, we zero-initialize $B$ and randomly initialize $A$, so that $\Delta W$ starts at zero overall. This lets the model start from the pre-trained weights and gradually adapt to the new task without abruptly altering the original behavior. After fine-tuning, we can merge the low-rank adaptation term $\Delta W$ back into the original weight matrix $W$ to obtain the adapted weights $W'$, simply by adding $\Delta W$ to $W$. This allows inference without any additional computational overhead.
Below, we can give a simple implementation of LoRA in PyTorch:
from typing import Optional
import math

import torch
import torch.nn.functional as F
from torch import nn
from torch.nn import Module


class LoRALinear(Module):
    r"""
    Base class of LoRA fine-tuning layers.

    Args:
        rank (int): The rank of the LoRA layer. Must be greater than 0.
        lora_alpha (Optional[int]): The alpha value for LoRA. Defaults to None.
            In the LoRA paper, alpha is used to rescale the LoRA weight :math:`BA` as :math:`\frac{\alpha}{r}BA`.
            If None, alpha is set to the rank, so the final scaling factor is always 1, matching the LoRA paper.
    """
    def __init__(
        self,
        parent_module: nn.Module,
        rank: int,
        lora_alpha: Optional[int] = None,
        train_bias: bool = False,
    ):
        super().__init__()
        if rank <= 0:
            raise ValueError('Rank must be greater than 0')
        self.rank = rank
        if lora_alpha is not None:
            self.lora_alpha = lora_alpha
        else:
            self.lora_alpha = rank
        self.merged = False
        self.scaling = self.lora_alpha / self.rank
        self.parent_module = parent_module
        # Actual trainable parameters
        self.lora_B = nn.Parameter(self.parent_module.weight.new_zeros((parent_module.out_features, rank)))
        self.lora_A = nn.Parameter(self.parent_module.weight.new_zeros((rank, parent_module.in_features)))
        self._lora_B_initialization(self.lora_B)
        self._lora_A_initialization(self.lora_A)
        # Freeze the pre-trained weights; only A and B are trained.
        self.parent_module.weight.requires_grad = False
        if not train_bias and self.parent_module.bias is not None:
            self.parent_module.bias.requires_grad = False

    def _lora_B_initialization(self, B_para):
        nn.init.zeros_(B_para)

    def _lora_A_initialization(self, A_para):
        nn.init.kaiming_uniform_(A_para, a=math.sqrt(5))

    def lora_weight(self):
        return self.lora_B @ self.lora_A * self.scaling

    def merge_weight(self):
        if self.merged:
            raise ValueError('The weight is already merged')
        self.parent_module.weight.data += self.lora_weight()
        self.merged = True

    def unmerge_weight(self):
        if not self.merged:
            raise ValueError('The weight is already unmerged')
        self.parent_module.weight.data -= self.lora_weight()
        self.merged = False

    @property
    def weight(self):
        if self.merged:
            return self.parent_module.weight
        return self.parent_module.weight + self.lora_weight()

    @property
    def bias(self):
        return self.parent_module.bias

    def forward(self, x: torch.Tensor):
        if self.training:
            if self.merged:
                self.unmerge_weight()
            return F.linear(x,
                            self.parent_module.weight + self.lora_weight(),
                            self.parent_module.bias)
        else:
            if not self.merged:
                self.merge_weight()
            return self.parent_module(x)
Evaluating the Memory Efficiency of the LoRA Implementation
Now, we can evaluate the LoRA implementation we have above. First, let’s check the trainable parameters:
from torch.nn import Linear

linear = Linear(128, 256)
linear.train()
linear_lora = LoRALinear(Linear(128, 256), rank=4)
linear_lora.train()

def num_trainable_parameters(model):
    return sum(p.numel() for p in model.parameters() if p.requires_grad)

print(f'Number of trainable parameters in the original linear layer: {num_trainable_parameters(linear)}')
print(f'Number of trainable parameters in the LoRA linear layer: {num_trainable_parameters(linear_lora)}')
Number of trainable parameters in the original linear layer: 33024
Number of trainable parameters in the LoRA linear layer: 1536
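These numbers match a hand count: the dense layer trains its full weight matrix plus a bias vector, while the LoRA layer only trains $A$ and $B$ (the parent weight and bias are frozen in `__init__`):

```python
in_features, out_features, rank = 128, 256, 4

# Dense layer: full weight matrix plus bias vector.
dense_trainable = in_features * out_features + out_features
# LoRA layer: only A (rank x in) and B (out x rank) are trainable.
lora_trainable = rank * in_features + out_features * rank

print(dense_trainable)  # 33024
print(lora_trainable)   # 1536
```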
It is clear that the LoRA linear layer has far fewer trainable parameters than the original linear layer. But what effect do fewer trainable parameters have on performance? The forward pass of the network should be slightly slower than the original linear layer, because we need to compute the LoRA weight and add it to the original weight. However, the backward pass should be faster, because we only need to compute gradients for the LoRA parameters, which are far fewer than the original parameters. For the same reason, memory usage during training should also be reduced, because we only need to store gradients for the LoRA parameters. Let’s check the FLOPs and peak memory usage of the network now:
import torch
from torch.profiler import ProfilerActivity, profile

def measure_flops(function, *args, **kwargs):
    activities = [ProfilerActivity.CPU, ProfilerActivity.CUDA]
    with profile(activities=activities, record_shapes=False, with_flops=True) as prof:
        _ = function(*args, **kwargs)
    return sum(evt.flops for evt in prof.key_averages() if evt.flops is not None)

def measure_forward_flops(model, input_shape, device='cuda:0'):
    model = model.to(device)
    inputs = torch.randn(input_shape, device=device)
    return measure_flops(model, inputs)

def measure_backward_flops(model, input_shape, device='cuda:0'):
    model = model.to(device)
    inputs = torch.randn(input_shape, device=device, requires_grad=False)
    outputs = model(inputs)
    loss = (outputs ** 2).mean()
    return measure_flops(loss.backward)

def measure_vram(function, device_index=0, *args, **kwargs):
    torch.cuda.empty_cache()
    torch.cuda.synchronize(device_index)
    bytes_before = torch.cuda.memory_allocated(device_index)
    torch.cuda.reset_peak_memory_stats(device_index)
    _ = function(*args, **kwargs)
    torch.cuda.synchronize(device_index)
    peak_bytes = torch.cuda.max_memory_allocated(device_index)
    return (peak_bytes - bytes_before) / 1024  # convert bytes to kilobytes

def measure_forward_vram(model, input_shape, device_index=0):
    model = model.to(f"cuda:{device_index}")
    inputs = torch.randn(input_shape, device=f"cuda:{device_index}")
    return measure_vram(model, device_index, inputs)

def measure_backward_vram(model, input_shape, device_index=0):
    model = model.to(f"cuda:{device_index}")
    inputs = torch.randn(input_shape, device=f"cuda:{device_index}", requires_grad=True)
    outputs = model(inputs)
    loss = (outputs ** 2).mean()
    return measure_vram(loss.backward, device_index=device_index)

linear_performance = [
    measure_forward_flops(linear, (1, 128)),
    measure_backward_flops(linear, (1, 128)),
    measure_forward_vram(linear, (1, 128)),
    measure_backward_vram(linear, (1, 128)),
]
lora_linear_performance = [
    measure_forward_flops(linear_lora, (1, 128)),
    measure_backward_flops(linear_lora, (1, 128)),
    measure_forward_vram(linear_lora, (1, 128)),
    measure_backward_vram(linear_lora, (1, 128)),
]
improvements = [f"{(lora-linear) / linear *100:.2f}%" for lora, linear in zip(lora_linear_performance, linear_performance)]
print("network forward_flops backward_flops forward_vram backward_vram")
print("linear", *linear_performance)
print("lora_linear", *lora_linear_performance)
print("improvements", *improvements)
network forward_flops backward_flops forward_vram backward_vram
linear 65536 66048 1025.0 131.0
lora_linear 393216 623104 1153.0 130.0
improvements 500.00% 843.41% 12.49% -0.76%
Surprisingly, both the forward and backward FLOPs of the LoRA linear layer are much higher than those of the original linear layer. Meanwhile, although the backward pass requires slightly less peak memory than the original linear layer, the forward pass requires more. In a word, this LoRA implementation is not efficient compared to the standard linear layer at all. Why?
Before giving the answer, we can first show the correct implementation of LoRA in PyTorch:
class LoRALinearFixed(LoRALinear):
    def forward(self, x: torch.Tensor):
        if self.training:
            if self.merged:
                self.unmerge_weight()
            # Previous code: F.linear(x, self.parent_module.weight + self.lora_weight(), self.parent_module.bias)
            return (
                self.parent_module(x)
                + (x @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1))
                * self.scaling
            )
        else:
            if not self.merged:
                self.merge_weight()
            return self.parent_module(x)

fixed_linear_lora = LoRALinearFixed(Linear(128, 256), rank=4)
fixed_linear_lora.train()

fixed_linear_lora_performance = [
    measure_forward_flops(fixed_linear_lora, (1, 128)),
    measure_backward_flops(fixed_linear_lora, (1, 128)),
    measure_forward_vram(fixed_linear_lora, (1, 128)),
    measure_backward_vram(fixed_linear_lora, (1, 128)),
]
fixed_improvements = [f"{(fixed-linear) / linear *100:.2f}%" for fixed, linear in zip(fixed_linear_lora_performance, linear_performance)]
print("network forward_flops backward_flops forward_vram backward_vram")
print("linear", *linear_performance)
print("lora_linear", *lora_linear_performance)
print("lora_linear_fixed", *fixed_linear_lora_performance)
print("fixed_improvements", *fixed_improvements)
network forward_flops backward_flops forward_vram backward_vram
linear 65536 66048 1025.0 131.0
lora_linear 393216 623104 1153.0 130.0
lora_linear_fixed 69120 5888 1025.0 7.0
fixed_improvements 5.47% -91.09% 0.00% -94.66%
With the correct implementation, only the forward pass of the LoRA linear layer has slightly higher FLOPs than the original linear layer, while the backward pass has far lower FLOPs. Likewise, the forward pass now has the same peak memory usage as the non-LoRA linear layer, and the backward pass requires much less peak memory than the original linear layer.
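The backward memory saving has a simple source: with the pre-trained weights frozen, autograd never allocates a full-size gradient buffer for them, and only the small LoRA matrices receive gradients. A minimal check of this (independent of the classes above, using a plain frozen linear plus explicit `lora_A` and `lora_B` tensors):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)
base = nn.Linear(128, 256)
base.weight.requires_grad_(False)  # freeze the pre-trained weight
base.bias.requires_grad_(False)

rank, scaling = 4, 1.0
lora_A = nn.Parameter(torch.randn(rank, 128) * 0.01)
lora_B = nn.Parameter(torch.zeros(256, rank))

x = torch.randn(1, 128)
y = base(x) + (x @ lora_A.T @ lora_B.T) * scaling
y.pow(2).mean().backward()

print(base.weight.grad)   # None: no (256, 128) gradient buffer is allocated
print(lora_A.grad.shape)  # torch.Size([4, 128])
print(lora_B.grad.shape)  # torch.Size([256, 4])
```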
Compared to the previous code, we only changed
F.linear(x,self.parent_module.weight+self.lora_weight(),self.parent_module.bias)
to
self.parent_module(x) + (x @ self.lora_A.transpose(0, 1) @ self.lora_B.transpose(0, 1)) * self.scaling
Mathematically, the difference is: \(\text{Previous code: } y = x \cdot (W + B \cdot A)^T + b\) \(\text{Correct code: } y = x \cdot W^T + b + x \cdot A^T \cdot B^T\) (the scaling factor $\frac{\alpha}{r}$ is omitted here for clarity).
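A quick numerical sanity check confirms the two expressions agree up to floating-point error (the shapes and scaling below are arbitrary illustrative values):

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
o, i, r = 256, 128, 4
scaling = 2.0

weight = torch.randn(o, i)
bias = torch.randn(o)
lora_A = torch.randn(r, i)
lora_B = torch.randn(o, r)
x = torch.randn(1, i)

# Previous code: materialize the full-size delta weight, then one big matmul.
y_prev = F.linear(x, weight + (lora_B @ lora_A) * scaling, bias)
# Correct code: keep the low-rank factors separate and route x through them.
y_fixed = F.linear(x, weight, bias) + (x @ lora_A.T @ lora_B.T) * scaling

print(torch.allclose(y_prev, y_fixed, atol=1e-4))  # True
```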
Although mathematically equivalent, the previous code leads to much higher FLOPs and peak memory usage than the correct code. The reason is that the previous code constructs the full-size weight matrix $B \cdot A$, so autograd must track a computation graph over full-size $(o, i)$ matrices, i.e., “B @ A → delta → weight → linear”: computing $B \cdot A$ alone costs on the order of $o \cdot r \cdot i$ multiply-adds, materializes extra $(o, i)$ tensors, and forces the backward pass to propagate gradients through all of them. In the correct code, the full-size weight matrix is never constructed; the computation graph only contains small intermediate results, i.e., “x @ A^T → xA → xA @ B^T → xAB”, whose sizes scale with the batch size and the rank rather than with $o \cdot i$.
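The gap in the FLOP table can be reproduced from the dominant matmul shapes alone, counting $2 \cdot m \cdot k \cdot n$ FLOPs per $(m, k) \times (k, n)$ matmul (small elementwise additions and scalings are ignored in this sketch):

```python
# Dominant matmul FLOPs for the two formulations, for the layer profiled above.
o, i, r, n = 256, 128, 4, 1  # out features, in features, rank, batch size

base_matmul = 2 * n * i * o  # x @ W^T, shared by both paths

# Previous code: building the full delta weight B @ A costs this much
# on every forward pass, independent of the batch size.
naive_extra = 2 * o * r * i
# Correct code: x @ A^T, then the result @ B^T; scales with batch and rank.
factored_extra = 2 * n * r * i + 2 * n * o * r

print(base_matmul)     # 65536
print(naive_extra)     # 262144: 4x the base matmul
print(factored_extra)  # 3072: negligible for small batches
```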
Therefore, when implementing LoRA, always pay attention to the shape of the computation graph and avoid materializing full-size weight matrices during training, which can otherwise make LoRA more expensive in FLOPs and peak memory than fine-tuning the original linear layer.