trailofbits.python.pickles-in-pytorch.pickles-in-pytorch

Author: unknown

Functions that rely on pickle can lead to arbitrary code execution. Consider loading from `state_dict`, using fickling, or switching to a safer serialization format such as ONNX.
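Why pickle-backed loading amounts to code execution: pickle lets any object define `__reduce__`, which tells the unpickler which callable to invoke, and with which arguments, when the object is rebuilt. `torch.load` relies on this same unpickling machinery. The minimal sketch below uses only the standard library; the `Payload` class is hypothetical, and a harmless `eval("6 * 7")` stands in for a real attacker's command.

```python
import pickle


class Payload:
    """Hypothetical attacker class: __reduce__ decides what runs on load."""

    def __reduce__(self):
        # "To rebuild this object, call eval('6 * 7')."
        # A real exploit would return something like (os.system, ("...",)).
        return (eval, ("6 * 7",))


# The attacker serializes the payload (in a real attack, inside a model file)...
blob = pickle.dumps(Payload())

# ...and the victim only has to load it: eval runs during unpickling,
# and what comes back is eval's result, not a Payload instance.
result = pickle.loads(blob)
print(result)  # 42
```

Nothing in the loading code has to call the object: deserializing the bytes is enough to run the embedded callable.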

Definition

rules:
  - id: pickles-in-pytorch
    message: Functions reliant on pickle can result in arbitrary code
      execution. Consider loading from `state_dict`, using fickling, or
      switching to a safer serialization method like ONNX.
    languages:
      - python
    severity: ERROR
    metadata:
      category: security
      cwe: "CWE-502: Deserialization of Untrusted Data"
      subcategory:
        - vuln
      confidence: MEDIUM
      likelihood: MEDIUM
      impact: HIGH
      technology:
        - pytorch
      description: Potential arbitrary code execution from `PyTorch` functions reliant
        on pickling
      references:
        - https://blog.trailofbits.com/2021/03/15/never-a-dill-moment-exploiting-machine-learning-pickle-files/
      license: AGPL-3.0 license
      vulnerability_class:
        - "Insecure Deserialization "
    patterns:
      - pattern-either:
          - pattern: torch.save(...)
          - pattern: torch.load(...)
      - pattern-not: torch.load("...")
      - pattern-not: torch.save(..., "...")
      - pattern-not: torch.save($M.state_dict(), ...)
      - pattern-not-inside: $M.load_state_dict(...)
      - pattern-not:
          patterns:
            - pattern: torch.save($STATE_DICT, ...)
            - pattern-inside: |
                $STATE_DICT = $M.state_dict()
                ...
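The `patterns` block flags every `torch.save` and `torch.load` call, then subtracts the shapes it treats as acceptable: a string-literal path, saving a `state_dict()` (directly or via a variable assigned from one), and a `torch.load` nested inside `load_state_dict`. As a rough illustration of that logic, and not of how Semgrep actually matches, the sketch below walks Python's `ast`. The helper names are hypothetical; it only crudely imitates Semgrep's constant propagation (module-level string constants such as `PATH = "x"`) and ignores statement order and scope.

```python
import ast


def _is_call_to(node, attr, obj=None):
    """True for <obj>.<attr>(...) calls; any receiver if obj is None."""
    if not (isinstance(node, ast.Call) and isinstance(node.func, ast.Attribute)):
        return False
    if node.func.attr != attr:
        return False
    return obj is None or (
        isinstance(node.func.value, ast.Name) and node.func.value.id == obj
    )


def find_pickle_calls(source):
    """Return sorted (line, call) pairs the rule would roughly report."""
    tree = ast.parse(source)

    # Module-level NAME = "literal" assignments: a crude stand-in for
    # constant propagation, so torch.load(PATH) counts as a literal path.
    constants = {
        t.id
        for stmt in tree.body
        if isinstance(stmt, ast.Assign)
        and isinstance(stmt.value, ast.Constant)
        and isinstance(stmt.value.value, str)
        for t in stmt.targets
        if isinstance(t, ast.Name)
    }
    # Names bound to X.state_dict(), mirroring the $STATE_DICT pattern.
    state_dicts = {
        t.id
        for node in ast.walk(tree)
        if isinstance(node, ast.Assign) and _is_call_to(node.value, "state_dict")
        for t in node.targets
        if isinstance(t, ast.Name)
    }
    # Every node nested inside X.load_state_dict(...) is exempt.
    exempt = {
        id(sub)
        for node in ast.walk(tree)
        if _is_call_to(node, "load_state_dict")
        for sub in ast.walk(node)
    }

    def is_literal(arg):
        if isinstance(arg, ast.Constant):
            return isinstance(arg.value, str)
        return isinstance(arg, ast.Name) and arg.id in constants

    findings = []
    for node in ast.walk(tree):
        if id(node) in exempt:
            continue
        if _is_call_to(node, "load", "torch"):
            if len(node.args) == 1 and is_literal(node.args[0]):
                continue  # torch.load("...")
            findings.append((node.lineno, "torch.load"))
        elif _is_call_to(node, "save", "torch") and node.args:
            first, last = node.args[0], node.args[-1]
            if is_literal(last) or _is_call_to(first, "state_dict"):
                continue  # torch.save(..., "...") or torch.save(m.state_dict(), ...)
            if isinstance(first, ast.Name) and first.id in state_dicts:
                continue  # torch.save(state_dict, ...)
            findings.append((node.lineno, "torch.save"))
    return sorted(findings)


# Condensed copy of the example file; line 1 of this string is empty.
sample = '''
import torch
PATH = "x"
model = torch.load(PATH)
torch.save(model, PATH)

def test(arg):
    model = torch.load(arg)
    torch.save(model, arg)
    torch.save(model.state_dict(), arg)
    model.load_state_dict(torch.load(arg))
    state_dict = model.state_dict()
    torch.save(state_dict, arg)
'''
print(find_pickle_calls(sample))  # [(8, 'torch.load'), (9, 'torch.save')]
```

On that condensed example, only the two calls annotated `ruleid` in the test file are reported.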

Examples

pickles-in-pytorch.py

from torch import nn, optim
import torch.nn.functional as F
import torch

PATH = "x"

# ok: pickles-in-pytorch
model = torch.load(PATH)

# ok: pickles-in-pytorch
torch.save(model, PATH)

# ok: pickles-in-pytorch
torch.save(model.state_dict(), PATH)

# ok: pickles-in-pytorch
model.load_state_dict(torch.load(PATH))


def test(arg):
    # ruleid: pickles-in-pytorch
    model = torch.load(arg)

    # ruleid: pickles-in-pytorch
    torch.save(model, arg)

    # ok: pickles-in-pytorch
    torch.save(model.state_dict(), arg)

    # ok: pickles-in-pytorch
    model.load_state_dict(torch.load(arg))
    
    state_dict = model.state_dict()
    # ok: pickles-in-pytorch
    torch.save(state_dict, arg)
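The rule's message points toward safer loading. One mitigation described in the Python `pickle` documentation ("Restricting Globals") is to override `Unpickler.find_class` with an allowlist, so a file cannot name arbitrary callables; recent PyTorch releases apply a similar idea through `torch.load(..., weights_only=True)`. The sketch below is stdlib-only; the allowlist contents, `AllowlistUnpickler`, and `safe_loads` are hypothetical, and real tensor data would need more allowlist entries than shown. An allowlist narrows the attack surface but is only as safe as the callables it permits.

```python
import io
import pickle
from collections import OrderedDict

# Hypothetical allowlist of (module, name) globals the loader may resolve.
# A module's state_dict() is an OrderedDict, so it is the one entry here;
# this sketch stores plain floats instead of tensors.
ALLOWED_GLOBALS = {("collections", "OrderedDict")}


class AllowlistUnpickler(pickle.Unpickler):
    """Unpickler that refuses every global not in ALLOWED_GLOBALS."""

    def find_class(self, module, name):
        if (module, name) in ALLOWED_GLOBALS:
            return super().find_class(module, name)
        raise pickle.UnpicklingError(f"blocked global {module}.{name}")


def safe_loads(data):
    """Hypothetical helper: like pickle.loads, but allowlist-restricted."""
    return AllowlistUnpickler(io.BytesIO(data)).load()


class Payload:
    # Same trick an attacker's file would use: rebuild by calling eval(...).
    def __reduce__(self):
        return (eval, ("6 * 7",))


# Plain containers and the allowlisted OrderedDict load normally.
weights = safe_loads(pickle.dumps(OrderedDict(weight=[0.1, 0.2])))
print(weights)

# The payload needs builtins.eval, which find_class rejects before any call.
try:
    safe_loads(pickle.dumps(Payload()))
    blocked = False
except pickle.UnpicklingError as exc:
    blocked = True
    print(exc)  # blocked global builtins.eval
```

The check happens when the unpickler resolves the global, so the blocked callable is never invoked.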