Bedingte Logik als Array-Operationen – where#

Die Funktion numpy.where ist eine vektorisierte Version von if und else.

Im folgenden Beispiel erzeugen wir zunächst ein boolesches Array und zwei Arrays mit Werten:

[1]:
import numpy as np
[2]:
cond = ([False,  True, False,  True, False, False, False])
data1 = np.random.randn(1, 7)
data2 = np.random.randn(1, 7)

Nun wollen wir Nehmen wir die Werte aus data1 übernehmen, wenn der entsprechende Wert in cond True ist und ansonsten den Wert aus data2 übernommen wird. Mit Pythons if-else könnte das wie folgt aussehen:

[3]:
result = [(x if c else y) for x, y, c in zip(data1, data2, cond)]

result
[3]:
[array([-0.37431791,  0.11490952, -0.24917534,  0.35700256, -0.3293716 ,
        -1.51677151,  0.351892  ])]

Dies hat jedoch die folgenden beiden Probleme:

  • bei großen Arrays wird die Funktion nicht sehr schnell sein

  • dies funktioniert nicht mit mehrdimensionalen Arrays

Mit np.where könnt ihr diese Probleme in einem einzigen Funktionsaufruf umgehen:

[4]:
result = np.where(cond, data1, data2)

result
[4]:
array([[-0.37431791,  0.90681988, -0.24917534,  0.0425698 , -0.3293716 ,
        -1.51677151,  0.351892  ]])

Das zweite und dritte Argument von np.where müssen keine Arrays sein; eines oder beide können auch Skalare sein. Eine typische Anwendung von where in der Datenanalyse besteht darin, ein neues Array von Werten auf der Grundlage eines anderen Arrays zu erzeugen. Angenommen, ihr habt eine Matrix mit zufällig generierten Daten und möchtet alle negativen Werte zu positiven Werten machen:

[5]:
data = np.random.randn(4, 4)

data
[5]:
array([[-1.52714845, -0.17217264,  0.48149727,  0.18465047],
       [ 0.02691677, -0.39642089, -1.54266224,  1.40343846],
       [-0.1541781 , -1.94429536, -1.55113023, -1.27231227],
       [ 0.44520634,  1.17590632,  1.30634966, -1.8479735 ]])
[6]:
data < 0
[6]:
array([[ True,  True, False, False],
       [False,  True,  True, False],
       [ True,  True,  True,  True],
       [False, False, False,  True]])
[7]:
np.where(data < 0, data * -1, data)
[7]:
array([[1.52714845, 0.17217264, 0.48149727, 0.18465047],
       [0.02691677, 0.39642089, 1.54266224, 1.40343846],
       [0.1541781 , 1.94429536, 1.55113023, 1.27231227],
       [0.44520634, 1.17590632, 1.30634966, 1.8479735 ]])