2:4 sparsity with to_sparse_semi_structured method from pytorch results in memory issue #28

Description

@Ahmed-Roushdy

I am trying to reduce the memory footprint of a LLaMA2 model pruned to 2:4 sparsity with SparseGPT, using PyTorch's to_sparse_semi_structured method. However, when I apply it to change how the sparse parameters are stored, I run out of memory. Note that the original dense model does not run out of memory.
Below is the code I was running, where model_path is the path to the pruned model.

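import torch.nn as nn
from torch.sparse import to_sparse_semi_structured
from transformers import AutoModelForCausalLM

device = "cuda"  # assumed; `device` was not defined in the original snippet
# model_path: path to the 2:4 SparseGPT-pruned checkpoint
model = AutoModelForCausalLM.from_pretrained(model_path)
model = model.to(device).half()

# Replace each dense Linear weight with its 2:4 semi-structured compressed form.
for fqn, module in model.named_modules():
    if isinstance(module, nn.Linear):
        module.weight = nn.Parameter(to_sparse_semi_structured(module.weight))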
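
For reference, here is a rough back-of-the-envelope estimate of how much a single weight matrix should shrink under 2:4 storage. It is only a sketch: it assumes the compressed form keeps 2 of every 4 values plus a 2-bit index per kept value, and it ignores any padding, alignment, or temporary buffers used during conversion. The function names `dense_bytes` and `semi_structured_bytes` are hypothetical helpers, not part of PyTorch.

```python
def dense_bytes(rows: int, cols: int, bits_per_value: int = 16) -> int:
    """Bytes needed to store a dense rows x cols matrix."""
    return rows * cols * bits_per_value // 8


def semi_structured_bytes(rows: int, cols: int, bits_per_value: int = 16) -> int:
    """Estimated bytes for the same matrix in 2:4 compressed form.

    Assumption: each group of 4 values keeps 2 values and stores a
    2-bit position index for each kept value.
    """
    if cols % 4 != 0:
        raise ValueError("cols must be a multiple of 4 for 2:4 sparsity")
    groups = rows * cols // 4
    value_bits = groups * 2 * bits_per_value  # the 2 kept values per group
    index_bits = groups * 2 * 2               # 2-bit index per kept value
    return (value_bits + index_bits) // 8


if __name__ == "__main__":
    # Example shape: a 4096 x 4096 projection stored in fp16.
    rows, cols = 4096, 4096
    dense = dense_bytes(rows, cols)
    sparse = semi_structured_bytes(rows, cols)
    print(f"dense:  {dense} bytes")
    print(f"sparse: {sparse} bytes (ratio {sparse / dense:.4f})")
```

Under these assumptions an fp16 matrix compresses to 9/16 of its dense size, so the steady-state footprint should go down. The estimate does not cover the conversion step itself, during which the dense weight and its compressed copy can both be resident.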
