trailofbits.python.numpy-in-pytorch-modules.numpy-in-pytorch-modules

Author: trailofbits

Usage of NumPy library inside PyTorch $MODULE module was found. Avoid mixing these libraries for efficiency and proper ONNX loading

Definition

rules:
  - id: numpy-in-pytorch-modules
    message: Usage of NumPy library inside PyTorch `$MODULE` module was found. Avoid
      mixing these libraries for efficiency and proper ONNX loading
    languages:
      - python
    severity: WARNING
    metadata:
      category: performance
      subcategory:
        - audit
      confidence: MEDIUM
      technology:
        - pytorch
        - numpy
      description: Uses of `NumPy` functions inside `PyTorch` modules
      references:
        - https://tanelp.github.io/posts/a-bug-that-plagues-thousands-of-open-source-ml-projects
      license: AGPL-3.0 license
    patterns:
      - pattern-either:
          - pattern: numpy.$FN(...)
          - pattern: numpy. ... .$FN(...)
      - pattern-inside: |
          class $MODULE(torch.nn.Module):
              ...
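
To make the two pattern clauses above concrete, here is a rough approximation of the same check using Python's standard `ast` module. It is only a sketch of the idea (a NumPy call located inside a class that derives from `torch.nn.Module`), not how Semgrep itself works. The helper names `numpy_aliases`, `root_name`, and `find_numpy_calls` are hypothetical, and the base-class test is deliberately simplified to "any base spelled `<something>.Module`".

```python
import ast

SOURCE = '''
import torch.nn as nn
import numpy as np

class MyModule(nn.Module):
    def forward(self, x, y):
        return np.concatenate((x, y), axis=1)

class Helper:
    def stack(self, x, y):
        return np.concatenate((x, y), axis=1)
'''


def numpy_aliases(tree):
    """Collect local names bound to the numpy package (e.g. `np`)."""
    aliases = set()
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                if alias.name == "numpy":
                    aliases.add(alias.asname or "numpy")
    return aliases


def root_name(node):
    """Return the leftmost name of a dotted expression such as np.ndarray.sort."""
    while isinstance(node, ast.Attribute):
        node = node.value
    return node.id if isinstance(node, ast.Name) else None


def find_numpy_calls(source):
    """Return (class name, line number) for NumPy calls inside Module subclasses."""
    tree = ast.parse(source)
    aliases = numpy_aliases(tree)
    findings = []
    for cls in ast.walk(tree):
        if not isinstance(cls, ast.ClassDef):
            continue
        # Assumption: a base written as `<something>.Module` is torch.nn.Module.
        if not any(isinstance(b, ast.Attribute) and b.attr == "Module" for b in cls.bases):
            continue
        for node in ast.walk(cls):
            # Mirrors `numpy.$FN(...)` and `numpy. ... .$FN(...)`:
            # any attribute call whose dotted path starts at a numpy alias.
            if (isinstance(node, ast.Call)
                    and isinstance(node.func, ast.Attribute)
                    and root_name(node.func) in aliases):
                findings.append((cls.name, node.lineno))
    return findings


print(find_numpy_calls(SOURCE))
```

In this sketch only `MyModule` is reported. The identical call in `Helper` is ignored because that class does not derive from a `Module` base, which corresponds to the `pattern-inside` clause of the rule.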

Examples

numpy-in-pytorch-modules.py

import torch
import torch.nn as nn
import numpy as np

class MyModule(nn.Module):
    def __init__(self):
        super().__init__()  # required before registering submodules
        self.dropout = nn.Dropout(0.5)

    def forward(self, x, y):
        x = self.dropout(x)
        # ruleid: numpy-in-pytorch-modules
        y = np.concatenate((x, y), axis=1)

        # ruleid: numpy-in-pytorch-modules
        np.ndarray.sort(y)
        return y

    def forward_correct(self, x, y):
        x = self.dropout(x)
        # ok: numpy-in-pytorch-modules
        # torch.cat is the PyTorch equivalent of np.concatenate
        y = torch.cat((x, y), 1)
        return y
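
As a possible follow-up to the test file, the sketch below shows one way both flagged NumPy calls might be replaced with PyTorch operations, so that the whole computation stays on tensors and remains visible to autograd. The class name `ConcatSort` and the input values are hypothetical and chosen only for illustration. Note that, unlike the in-place `np.ndarray.sort`, `torch.sort` returns new `(values, indices)` tensors.

```python
import torch
import torch.nn as nn


class ConcatSort(nn.Module):
    """Hypothetical module: concatenate two inputs and sort each row with torch ops only."""

    def forward(self, x, y):
        z = torch.cat((x, y), dim=1)       # replaces np.concatenate((x, y), axis=1)
        values, _ = torch.sort(z, dim=1)   # replaces np.ndarray.sort(y); not in-place
        return values


x = torch.tensor([[3.0, 1.0]], requires_grad=True)
y = torch.tensor([[2.0, 0.0]])

out = ConcatSort()(x, y)
out.sum().backward()   # gradients flow back to x because every op is a torch op

print(out)
print(x.grad)
```

Because no NumPy call appears inside the class, this module would not be reported by the rule.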