trailofbits.python.numpy-in-pytorch-modules.numpy-in-pytorch-modules

Author: unknown
Usage of the NumPy library inside the PyTorch $MODULE module was found. Avoid mixing these libraries, both for efficiency and for proper ONNX loading.
Definition
rules:
  - id: numpy-in-pytorch-modules
    message: Usage of NumPy library inside PyTorch `$MODULE` module was found. Avoid
      mixing these libraries for efficiency and proper ONNX loading
    languages:
      - python
    severity: WARNING
    metadata:
      category: performance
      subcategory:
        - audit
      confidence: MEDIUM
      technology:
        - pytorch
        - numpy
      description: Uses of `NumPy` functions inside `PyTorch` modules
      references:
        - https://tanelp.github.io/posts/a-bug-that-plagues-thousands-of-open-source-ml-projects
      license: CC-BY-NC-SA-4.0
    patterns:
      - pattern: $RESULT = numpy.$FUNCTION(...)
      - pattern-inside: |
          class $MODULE(torch.nn.Module):
              ...
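The `patterns` block in the definition requires two things at once: an assignment whose right-hand side calls some `numpy` function, and a location inside a class that derives from `torch.nn.Module`. As a rough illustration only, the sketch below approximates that logic with Python's standard `ast` module. `find_numpy_in_modules` and its helpers are hypothetical names. This is not how Semgrep matches code internally, and it only understands the common import forms (`import numpy as np`, `import torch`, `import torch.nn as nn`, `from torch import nn`).

```python
import ast

# Hypothetical, stdlib-only approximation of the rule's `patterns` block.
# It is NOT Semgrep's matching engine; it handles only common import forms.


def _import_aliases(tree):
    """Return the local names bound to numpy, torch, and torch.nn by imports."""
    numpy_names, torch_names, nn_names = set(), set(), set()
    for node in ast.walk(tree):
        if isinstance(node, ast.Import):
            for alias in node.names:
                if alias.name == "numpy":
                    numpy_names.add(alias.asname or "numpy")
                elif alias.name == "torch.nn" and alias.asname:
                    nn_names.add(alias.asname)  # import torch.nn as nn
                elif alias.name == "torch":
                    torch_names.add(alias.asname or "torch")
                elif alias.name.startswith("torch.") and alias.asname is None:
                    torch_names.add("torch")  # `import torch.nn` binds `torch`
        elif isinstance(node, ast.ImportFrom) and node.module == "torch":
            for alias in node.names:
                if alias.name == "nn":
                    nn_names.add(alias.asname or "nn")  # from torch import nn
    return numpy_names, torch_names, nn_names


def _is_torch_module(base, torch_names, nn_names):
    """True if a base-class expression is `nn.Module` or `torch.nn.Module`."""
    if not (isinstance(base, ast.Attribute) and base.attr == "Module"):
        return False
    owner = base.value
    if isinstance(owner, ast.Name):
        return owner.id in nn_names  # nn.Module
    return (
        isinstance(owner, ast.Attribute)
        and owner.attr == "nn"
        and isinstance(owner.value, ast.Name)
        and owner.value.id in torch_names  # torch.nn.Module
    )


def find_numpy_in_modules(source):
    """Return (class name, line) for each `x = numpy.f(...)` inside an nn.Module subclass."""
    tree = ast.parse(source)
    numpy_names, torch_names, nn_names = _import_aliases(tree)
    hits = []
    for cls in ast.walk(tree):
        if not isinstance(cls, ast.ClassDef):
            continue
        # Mirrors `pattern-inside: class $MODULE(torch.nn.Module): ...`
        if not any(_is_torch_module(b, torch_names, nn_names) for b in cls.bases):
            continue
        for node in ast.walk(cls):
            # Mirrors `pattern: $RESULT = numpy.$FUNCTION(...)`
            if (
                isinstance(node, ast.Assign)
                and isinstance(node.value, ast.Call)
                and isinstance(node.value.func, ast.Attribute)
                and isinstance(node.value.func.value, ast.Name)
                and node.value.func.value.id in numpy_names
            ):
                hits.append((cls.name, node.lineno))
    return hits


SAMPLE = '''
import torch
import torch.nn as nn
import numpy as np

class MyModule(nn.Module):
    def forward(self, x, y):
        y = np.concatenate((x, y), axis=1)
        return y

    def forward_correct(self, x, y):
        y = torch.cat((x, y), 1)
        return y

def helper(a, b):
    c = np.concatenate((a, b))  # outside any nn.Module: not reported
    return c
'''

print(find_numpy_in_modules(SAMPLE))  # → [('MyModule', 8)]
```

As in the rule, only the NumPy call inside the `nn.Module` subclass is reported; the same call in a free function is ignored because it fails the `pattern-inside` condition.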
Examples
numpy-in-pytorch-modules.py
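import torch
import torch.nn as nn
import numpy as np


class MyModule(nn.Module):
    def __init__(self):
        # nn.Module must be initialised before submodules are assigned
        super().__init__()
        self.dropout = nn.Dropout(0.5)

    def forward(self, x, y):
        x = self.dropout(x)
        # Flagged: NumPy runs outside PyTorch, so autograd and ONNX export cannot track this op
        # ruleid: numpy-in-pytorch-modules
        y = np.concatenate((x, y), axis=1)
        return y

    def forward_correct(self, x, y):
        x = self.dropout(x)
        # Stays in PyTorch: torch.cat (torch.nn has no `cat`)
        # ok: numpy-in-pytorch-modules
        y = torch.cat((x, y), 1)
        return y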
Short Link: https://sg.run/9vxr