Python で正規分布の乱数を生成
Python では numpy を利用することで、正規分布の乱数を容易に生成することができます。
正規分布とは、平均値の付近に乱数の生成が集中しやすい特性があり、様々な数値解析モデルで利用されることが多いです。
動作環境
- Windows 11
- WSL Ubuntu 20.04
- Python 3.8.10
事前準備
各種動作を WSL で確認しています。
事前準備として、必要なパッケージ等をインストールしておきます。
$ sudo apt install python3-pip
$ pip install numpy
標準正規分布
正規分布の中でも、平均値が 0、標準偏差が 1 のものを標準正規分布と呼びます。
これは numpy.random.randn()
で生成可能です。
import numpy
print(numpy.random.randn())
print(numpy.random.randn())
print(numpy.random.randn())
$ python3 main.py
-1.286112660617798
-0.6865218203081558
0.7401399832873927
randn
の引数に、生成する乱数の数を指定することも可能です。
import numpy
print(numpy.random.randn(10))
$ python3 main.py
[-0.6141697 -0.17554056 0.49666993 0.87636681 -0.67411954 0.461894
0.16993512 0.12006013 -0.13954982 0.2314704 ]
ヒストグラムを表示
数値を見ただけでは標準正規分布かどうかわかりづらいので、大量の乱数を生成してヒストグラムを表示します。
以下のサンプルコードでは、1億の正規分布乱数を生成して、これを matplotlib でヒストグラム表示しました。
import matplotlib.pyplot
import numpy
xs = numpy.random.randn(100_000_000)
matplotlib.pyplot.hist(xs, bins=100, range=(-4, 4))
matplotlib.pyplot.show()

平均値である 0 近傍に乱数が集中していることがわかります。
【関連記事】
WSL で matplotlib を利用する記事はこちらです。
任意パラメータの正規分布
実用上では、正規分布の平均や標準偏差を指定したいこともあると思います。
そのような場合は、np.random.normal
を利用します。
第1引数に平均値、第2引数に標準偏差を指定します。
また、randn
同様に第3引数で乱数の数を指定することもできます。
import numpy
print(numpy.random.normal(1, 5, 10))
$ python3 main.py
[-1.87434209 8.23852896 5.00024207 4.0433546 8.15219528 -1.6738393
-2.50390046 -5.14845033 -6.24863085 2.85018825]
平均をずらした正規分布
平均値を 1.0 とした正規分布をヒストグラム表示します。
標準偏差は 1.0 としているため、グラフの形状は標準正規分布と同じです。
x 軸の数値だけシフトしたようになります。
import matplotlib.pyplot
import numpy
xs = numpy.random.normal(1, 1, 100_000_000)
matplotlib.pyplot.hist(xs, bins=100, range=(-4, 4))
matplotlib.pyplot.show()

標準偏差を大きくした正規分布
次に、平均値を 0 に戻して、標準偏差を 2 に大きくしたグラフを描画します。
import matplotlib.pyplot
import numpy
xs = numpy.random.normal(0, 2, 100_000_000)
matplotlib.pyplot.hist(xs, bins=100, range=(-4, 4))
matplotlib.pyplot.show()

標準正規分布よりも、横長のヒストグラムが出力されました。
レンジが狭かったので、もう少し幅を広くして再描画します。
import matplotlib.pyplot
import numpy
xs = numpy.random.normal(0, 2, 100_000_000)
matplotlib.pyplot.hist(xs, bins=100, range=(-8, 8))
matplotlib.pyplot.show()
