Engineering python pytorch 機械学習

Pytorch入門:torch.whereの使い方と条件に応じた要素の選択方法を徹底解説

pytorchで少し細かい実装をしていると、torch.whereをよく使います。torch.whereを使うと、特定の条件に当てはまる値を変更したりでき非常に便利なのですが、引数の与え方などによって、挙動が変わってきます。ここの挙動を正しく理解しないと、実装ミスにもつながるかなと思いましたので、今回はパターン別にtorch.whereの動きを整理してみました。

torch.whereとは?

公式docs (https://pytorch.org/docs/stable/generated/torch.where.html)によると、以下のように書かれています。基本的には条件に合致した場合、inputを返し、それ以外はotherの値もしくはテンソルを返す関数になります。

- torch.where(condition, input, other, *, out=None) → Tensor
Return a tensor of elements selected from either input or other, depending on condition.
(条件に応じてinputまたはotherの選択された要素のtensorを返す)

引数

torch.whereの引数は、以下のように定義されています。

引数説明
conditionBoolTensorTrueの時inputを出力し、Falseの時otherを出力する
inputTensor or ScalarTrueの時の値またはTensor
otherTensor or ScalarFalseの時の値またはTensor
torch.whereの引数

実際の動き

それではtorch.whereの実際の動きを見てみましょう。

conditionのみ

conditionのみを与えると、TrueになったTorchのインデックスが返ってきます。

# Create Tensor
In [1]: a = torch.Tensor([[0, 1, 2, 3], [4, 5, 6, 7]])
# only condition
In [2]: b = torch.where(a > 1)
In [3]: b
Out[3]: (tensor([0, 0, 1, 1, 1, 1]), tensor([2, 3, 0, 1, 2, 3]))

上記のように(2, 4)のテンソルを入力にすると、tupleで2つテンソルが返ってきます。それぞれのテンソルが、各軸のインデックスに対応していて、それぞれ順番に取り出してみると、

  • [0, 2] -> 2
  • [0. 3] -> 3
  • [1, 0] -> 4
  • [1, 1] -> 5
  • [1, 2] -> 6
  • [1, 3] -> 7

とすべて1より大きい値のインデックスが返っていることがわかります。ここで帰ってくるインデックスのサイズは、入力で使用されるテンソルのサイズに合わせて変わります。

スカラをinput/otherに入れる

inputotherにスカラを入れると、そこで渡した値に入れ替わったテンソルが返ってきます。

# Create Tensor
In [1]: a = torch.Tensor([[0, 1, 2, 3], [4, 5, 6, 7]])
# use scalar
In [2]: b = torch.where(a > 1, 10, 0)
In [3]: b
Out[3]: tensor([[0, 0, 10, 10], 
               [10, 10, 10, 10]])

上記のように、1よりも大きいインデックスの値が10, それ以外が0になったテンソルが返ってきます。これが割と一般的なtorch.whereの使い方かなと思います。

テンソルをinput/otherに入れる

テンソルをinputoutputに入れることもできます。conditionで使用するテンソルと同じサイズである必要がありますが、それぞれ対応するインデックスの値に置き換えたテンソルが返ってきます。

# Create Tensor
In [1]: a = torch.Tensor([[0, 1, 2, 3], [4, 5, 6, 7]])
In [2]: in = torch.Tensor([[10, 11, 12, 13], [14, 15, 16, 17]])
In [3]: other = torch.Tensor([[20, 21, 22, 23], [24, 25, 26, 27]])
# use tensor
In [4]: b = torch.where(a > 1, in, other)
In [5]: b
Out[5]: tensor([[20, 21, 12, 13], 
               [14, 15, 16, 17]])

上記のように、1よりも大きいインデックスの値がinのテンソルの値に置き換わり、それ以外はotherのテンソルの値に置き換わっています。

テンソルとスカラを同時に入れる

スカラとテンソルを同時に使用することも、torch.whereでは可能です。

# Create Tensor
In [1]: a = torch.Tensor([[0, 1, 2, 3], [4, 5, 6, 7]])
In [2]: in = torch.Tensor([[10, 11, 12, 13], [14, 15, 16, 17]])
# use tensor and scalar
In [3]: b = torch.where(a > 1, in, -1)
In [4]: b
Out[4]: tensor([[-1, -1, 12, 13], 
               [14, 15, 16, 17]])

上記のように、1よりも大きいインデックスの値がinのテンソルの値に置き換わり、それ以外は-1に置き換わっています。今のテンソルの値を維持しつつ、特定の値に置換をしたい時などに便利です。

まとめ

今回はtorch.whereの挙動について整理してみました。与える引数の型や、数によって、挙動が全然異なるため、使用時にはそこら辺を意識して使用できればと思います。

-Engineering, python, pytorch, 機械学習
-, ,