trailofbits.python.pickles-in-pytorch-distributed.pickles-in-pytorch-distributed

trailofbits
Author
unknown

Functions reliant on pickle can result in arbitrary code execution
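Pickle lets an object decide how it is rebuilt. `__reduce__` can return any callable plus its arguments, and `pickle.loads` calls that callable during deserialization. The stdlib-only sketch below shows the mechanism with a deliberately harmless `eval` of an arithmetic expression. `Payload` is a hypothetical class used only for illustration. PyTorch's documentation for the object collectives warns that they use pickle implicitly, so the same applies to whatever a peer rank sends through the functions this rule flags.

```python
import pickle


class Payload:
    def __reduce__(self):
        # Tell pickle to "rebuild" this object by calling eval("6 * 7").
        # A real attacker would return something like (os.system, ("...",)) instead.
        return (eval, ("6 * 7",))


blob = pickle.dumps(Payload())  # what a malicious sender would transmit
result = pickle.loads(blob)     # the receiver runs eval("6 * 7") while unpickling
print(result)                   # 42
print(type(result).__name__)    # int, not Payload
```

The receiver never gets a `Payload` back; it gets whatever the sender's chosen callable returned, after that callable has already run.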

Definition

rules:
  - id: pickles-in-pytorch-distributed
    message: Functions reliant on pickle can result in arbitrary code execution
    languages:
      - python
    severity: ERROR
    metadata:
      category: security
      cwe: "CWE-502: Deserialization of Untrusted Data"
      subcategory:
        - vuln
      confidence: MEDIUM
      likelihood: MEDIUM
      impact: HIGH
      technology:
        - pytorch
      description: Potential arbitrary code execution from `torch.distributed`
        functions reliant on pickling
      references:
        - https://blog.trailofbits.com/2021/03/15/never-a-dill-moment-exploiting-machine-learning-pickle-files/
      license: AGPL-3.0 license
      vulnerability_class:
        - "Insecure Deserialization"
    patterns:
      - pattern-either:
          - pattern: torch.distributed.broadcast_object_list(...)
          - pattern: torch.distributed.all_gather_object(...)
          - pattern: torch.distributed.gather_object(...)
          - pattern: torch.distributed.scatter_object_list(...)
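The `pattern-either` list flags a call when it resolves to any of the four fully qualified names, regardless of how `torch.distributed` was imported. As a rough illustration only (this is not how Semgrep's engine works), the stdlib-only sketch below approximates that matching with Python's `ast` module, resolving `import ... as` and `from ... import` aliases. `FLAGGED`, `dotted_name`, and `find_pickling_collectives` are hypothetical names.

```python
import ast

# The four fully qualified call targets listed under pattern-either in the rule.
FLAGGED = {
    "torch.distributed.broadcast_object_list",
    "torch.distributed.all_gather_object",
    "torch.distributed.gather_object",
    "torch.distributed.scatter_object_list",
}


def dotted_name(node):
    """Return "a.b.c" for a Name/Attribute chain such as dist.gather_object, else None."""
    parts = []
    while isinstance(node, ast.Attribute):
        parts.append(node.attr)
        node = node.value
    if not isinstance(node, ast.Name):
        return None
    parts.append(node.id)
    return ".".join(reversed(parts))


def find_pickling_collectives(source):
    """Return sorted line numbers of calls that resolve to a FLAGGED name."""
    tree = ast.parse(source)

    # Map each locally bound name to the module path or object it refers to.
    aliases = {}
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                if alias.asname:
                    # import torch.distributed as dist  ->  dist = torch.distributed
                    aliases[alias.asname] = alias.name
                else:
                    # import torch.distributed  ->  binds only the top-level name torch
                    top = alias.name.split(".")[0]
                    aliases[top] = top
        elif isinstance(node, ast.ImportFrom) and node.module and node.level == 0:
            for alias in node.names:
                # from torch.distributed import gather_object [as g]
                aliases[alias.asname or alias.name] = f"{node.module}.{alias.name}"

    hits = []
    for node in ast.walk(tree):
        if not isinstance(node, ast.Call):
            continue
        name = dotted_name(node.func)
        if name is None:
            continue
        head, _, rest = name.partition(".")
        if head not in aliases:
            continue  # not bound by an import, so it cannot be resolved
        resolved = aliases[head] + (f".{rest}" if rest else "")
        if resolved in FLAGGED:
            hits.append(node.lineno)
    return sorted(hits)


sample = """import torch.distributed as dist
dist.broadcast_object_list(objs, src=0)
dist.scatter(t, lst, src=0)
"""
print(find_pickling_collectives(sample))  # [2]
```

Only line 2 is reported: `dist.scatter` resolves to `torch.distributed.scatter`, which is not in the list, mirroring the `ok` case in the example file.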

Examples

pickles-in-pytorch-distributed.py

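import torch
import torch.distributed as dist

# Assumes the script is launched with torchrun, which sets the environment
# variables init_process_group reads; gloo works with the CPU tensors below.
dist.init_process_group(backend="gloo")
rank = dist.get_rank()
world_size = dist.get_world_size()

if rank == 0:
    objects = ["f", 1]
else:
    objects = [None, None]

# ruleid: pickles-in-pytorch-distributed
dist.broadcast_object_list(objects, src=0)

# One picklable object per rank; each rank contributes its own entry.
gather_objects = [{"rank": i} for i in range(world_size)]
output = [None] * world_size

# ruleid: pickles-in-pytorch-distributed
dist.all_gather_object(output, gather_objects[rank])

# ruleid: pickles-in-pytorch-distributed
dist.gather_object(
    gather_objects[rank],
    output if rank == 0 else None,
    dst=0,
)

# scatter_object_list needs one input object per rank on the source rank.
scatter_inputs = [f"item-{i}" for i in range(world_size)] if rank == 0 else None
output_list = [None]

# ruleid: pickles-in-pytorch-distributed
dist.scatter_object_list(output_list, scatter_inputs, src=0)

# dist.scatter moves tensors only, so no pickling is involved.
output_tensor = torch.zeros(1)
scatter_list = [torch.tensor([float(i)]) for i in range(world_size)] if rank == 0 else None

# ok: pickles-in-pytorch-distributed
dist.scatter(output_tensor, scatter_list, src=0)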
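When the payload is plain data, one way to avoid the flagged calls is to serialize it as JSON and move the raw bytes with the tensor-only `dist.broadcast`. The sketch below is a minimal illustration under stated assumptions: an already initialized default process group, CPU tensors (for example the gloo backend), and JSON-compatible objects. `encode_payload`, `decode_payload`, and `broadcast_json` are hypothetical helpers, not part of PyTorch.

```python
import json


def encode_payload(obj):
    """Serialize JSON-compatible data to UTF-8 bytes (no pickle involved)."""
    return json.dumps(obj).encode("utf-8")


def decode_payload(data):
    """Parse bytes produced by encode_payload; json.loads never executes code."""
    return json.loads(data.decode("utf-8"))


def broadcast_json(obj, src=0):
    """Broadcast JSON-compatible data from rank `src` using tensor-only collectives.

    Assumes the default process group is already initialized and uses CPU
    tensors (e.g. the gloo backend). Non-source ranks may pass obj=None.
    """
    # Imported here so the encoding helpers above work without PyTorch installed.
    import torch
    import torch.distributed as dist

    is_src = dist.get_rank() == src
    data = encode_payload(obj) if is_src else b""

    # Step 1: share the payload length so receivers can allocate a buffer.
    length = torch.tensor([len(data)], dtype=torch.long)
    dist.broadcast(length, src=src)

    # Step 2: share the raw bytes as a uint8 tensor.
    if is_src:
        buffer = torch.tensor(list(data), dtype=torch.uint8)
    else:
        buffer = torch.empty(int(length.item()), dtype=torch.uint8)
    dist.broadcast(buffer, src=src)

    return decode_payload(bytes(buffer.tolist()))
```

For instance, `objects = broadcast_json(objects, src=0)` carries the `["f", 1]` list from the first example without pickle. JSON does change some types on the way through (tuples come back as lists, non-string dict keys become strings), and receivers should still validate what they get.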