
pytorchで少し細かい実装をしていると、torch.whereをよく使います。torch.whereを使うと、特定の条件に当てはまる値を変更したりでき非常に便利なのですが、引数の与え方などによって、挙動が変わってきます。ここの挙動を正しく理解しないと、実装ミスにもつながるかなと思いましたので、今回はパターン別にtorch.whereの動きを整理してみました。
torch.whereとは?
公式docs (https://pytorch.org/docs/stable/generated/torch.where.html)によると、以下のように書かれています。基本的には条件に合致した場合、inputを返し、それ以外はotherの値もしくはテンソルを返す関数になります。
- torch.where(condition, input, other, *, out=None) → Tensor
Return a tensor of elements selected from either input or other, depending on condition.
(条件に応じてinputまたはotherの選択された要素のtensorを返す)
引数
torch.whereの引数は、以下のように定義されています。
| 引数 | 型 | 説明 |
condition | BoolTensor | Trueの時inputを出力し、Falseの時otherを出力する |
input | Tensor or Scalar | Trueの時の値またはTensor |
other | Tensor or Scalar | Falseの時の値またはTensor |
torch.whereの引数実際の動き
それではtorch.whereの実際の動きを見てみましょう。
conditionのみ
conditionのみを与えると、TrueになったTorchのインデックスが返ってきます。
# Create Tensor
In [1]: a = torch.Tensor([[0, 1, 2, 3], [4, 5, 6, 7]])
# only condition
In [2]: b = torch.where(a > 1)
In [3]: b
Out[3]: (tensor([0, 0, 1, 1, 1, 1]), tensor([2, 3, 0, 1, 2, 3]))
上記のように(2, 4)のテンソルを入力にすると、tupleで2つテンソルが返ってきます。それぞれのテンソルが、各軸のインデックスに対応していて、それぞれ順番に取り出してみると、
- [0, 2] -> 2
- [0. 3] -> 3
- [1, 0] -> 4
- [1, 1] -> 5
- [1, 2] -> 6
- [1, 3] -> 7
とすべて1より大きい値のインデックスが返っていることがわかります。ここで帰ってくるインデックスのサイズは、入力で使用されるテンソルのサイズに合わせて変わります。
スカラをinput/otherに入れる
inputやotherにスカラを入れると、そこで渡した値に入れ替わったテンソルが返ってきます。
# Create Tensor
In [1]: a = torch.Tensor([[0, 1, 2, 3], [4, 5, 6, 7]])
# use scalar
In [2]: b = torch.where(a > 1, 10, 0)
In [3]: b
Out[3]: tensor([[0, 0, 10, 10],
[10, 10, 10, 10]])
上記のように、1よりも大きいインデックスの値が10, それ以外が0になったテンソルが返ってきます。これが割と一般的なtorch.whereの使い方かなと思います。
テンソルをinput/otherに入れる
テンソルをinputやoutputに入れることもできます。conditionで使用するテンソルと同じサイズである必要がありますが、それぞれ対応するインデックスの値に置き換えたテンソルが返ってきます。
# Create Tensor
In [1]: a = torch.Tensor([[0, 1, 2, 3], [4, 5, 6, 7]])
In [2]: in = torch.Tensor([[10, 11, 12, 13], [14, 15, 16, 17]])
In [3]: other = torch.Tensor([[20, 21, 22, 23], [24, 25, 26, 27]])
# use tensor
In [4]: b = torch.where(a > 1, in, other)
In [5]: b
Out[5]: tensor([[20, 21, 12, 13],
[14, 15, 16, 17]])
上記のように、1よりも大きいインデックスの値がinのテンソルの値に置き換わり、それ以外はotherのテンソルの値に置き換わっています。
テンソルとスカラを同時に入れる
スカラとテンソルを同時に使用することも、torch.whereでは可能です。
# Create Tensor
In [1]: a = torch.Tensor([[0, 1, 2, 3], [4, 5, 6, 7]])
In [2]: in = torch.Tensor([[10, 11, 12, 13], [14, 15, 16, 17]])
# use tensor and scalar
In [3]: b = torch.where(a > 1, in, -1)
In [4]: b
Out[4]: tensor([[-1, -1, 12, 13],
[14, 15, 16, 17]])
上記のように、1よりも大きいインデックスの値がinのテンソルの値に置き換わり、それ以外は-1に置き換わっています。今のテンソルの値を維持しつつ、特定の値に置換をしたい時などに便利です。
まとめ
今回はtorch.whereの挙動について整理してみました。与える引数の型や、数によって、挙動が全然異なるため、使用時にはそこら辺を意識して使用できればと思います。