퍼셉트론 구하기, 라이브러리 의존도 0% [ML with Python]

안녕하세요. 언제나 휴일에 언휴예요.

이번 강의는 퍼셉트론(Perceptron) 구하는 코드를 작성해 볼 거예요.

파이썬의 강력한 라이브러리를 이용하면 순식간에 만들겠죠. 눈 떴다 감을 시간도 없어요.

강력한 라이브러리 도움을 받아 기계 학습(Merchine Learning)을 할 수 있다는 것은 매우 매력적입니다.

하지만 처음 진입하는 이들에게 라이브러리에 의해 감쳐진 알고리즘 때문에 답답함을 주기도 합니다.

” 답답하지 않다.” 하시는 분들은 이번 강의는 가감하게 Pass~

== 다루는 내용 ==
단층 퍼셉트론(Single Layer Perceptron)
가중치, 임계치 설정 함수 정의
판별 함수 작성
테스트에 사용할 데이터 미리 살펴보기
예측 정확도 함수 작성
기계 학습 함수 작성
테스트 코드 작성

댠층 퍼셉트론(Single Layer Perceptron)

입력 인자로 구성한 평면에서 하나의 직선으로 참인 집단과 거짓인 집단을 구했을 때 퍼셉트론을 찾은 거예요.

[그림] or

위 그림은 OR 연산에 입력 인자 X1, X2를 축으로 하는 평면입니다.

네 개의 점은 X1, X2가 취할 수 있는 값이죠.

그리고 하나의 직선으로 결과가 참일 때와 거짓일 때를 구분하고 있습니다.

이 때의 w1, w2, b를 찾는 것이 선형 퍼셉트론입니다.

여기에서 w1,w2는 결과 (Y)에 입력 인자(x1, x2)가 어느 정도 영향을 주는지 나타내는 가중치입니다.

0이면 가중치가 없는 것이고 1이면 가중치가 100%예요.

그리고 b는 임계치입니다.

가중치, 임계치 설정 함수 정의

먼저 가중치와 임계치를 전역 변수로 선언할게요.

그리고 설정하는 함수도 정의합시다.

w1,w2,b=0,0,0
def setwb(wt1,wt2,bt):
    global w1,w2,b
    w1,w2,b = wt1,wt2,bt

판별 함수 작성

[그림] 판별식

위 그림은 판별식입니다.

이를 함수로 구현합시다.

반환 값은 참과 거짓 대신 0과 1을 반환하게 구현하였습니다.

data를 표현할 때 더 간단해서 이와 같이 표현한 곳이 많습니다.

def discriminate(x1,x2):
    if(w1*x1+w2*x2+b<=0):
        return 0
    else:
        return 1

테스트에 사용할 데이터 미리 살펴보기

기계 학습 부분을 구현하기 전에 어떠한 데이터를 전달하여 테스트를 할 것인지 데이터를 먼저 살펴볼게요.

여기에서는 AND 연산, OR 연산, XOR 연산 데이터를 가지고 테스트를 할 거예요.

AND연산을 예로 들면 x1=0,x2=0일 때 y=0이죠. 이를 [0,0,0]으로 표현할게요.

그리고 네 가지이므로 [ [0,0,0],[0,1,0],[1,0,0],[1,1,1] ] 처럼 표현할 거예요.

ds_and=[
    [0,0,0],[0,1,0],[1,0,0],[1,1,1]
    ]
ds_or=[
    [0,0,0],[0,1,1],[1,0,1],[1,1,1]
    ]
ds_xor=[
    [0,0,0],[0,1,1],[1,0,1],[1,1,0]
    ]

예측 정확도 함수 작성

데이터와 가중치, 임계치를 전달받아 얼마나 예측이 맞는지 계산하는 함수를 정의합시다.

먼저 입력 인자로 전달받은 가중치와 입계치를 설정합니다.(앞에서 정의한 setwb호출)

데이터에 각 항목을 판별식으로 확인합니다. (앞에서 정의한 discriminate 호출)

결과가 맞다면 성공을 1 증가합니다. (ok를 으로 초기화, 판별식이 참일 때 1 증가)

전체 데이터를 판별하였을 때 성공 확률을 반환합니다.

def test(ds,wt1,wt2,bt):
    setwb(wt1,wt2,bt)
    ok,total=0,0
    for x1,x2,y in ds:
        if(discriminate(x1,x2)==y):
#            print("T", end=' ')  #테스트 목적일 때는 잠시 주석을 풀어서 확인하세요.
            ok+=1
#        else:                    #테스트 목적일 때는 잠시 주석을 풀어서 확인하세요.
#            print("F",end=' ')   #테스트 목적일 때는 잠시 주석을 풀어서 확인하세요.
        total+=1
    return ok/total

기계 학습 함수 작성

선형 퍼셉트론을 구하는 기계 학습이라 무리한 반복문을 수행할게요.

가중치 wt1을 0~1까지 0.1씩 증가시키면서 테스트 합니다. (가중치는 0일 때 무관, 1일 때 100% 영향)

가중치 wt2을 0~1까지 0.1씩 증가시키면서 테스트 합니다.

임계치 bt를 -1~1까지 0.1씩 증가시키면서 테스트 합니다. (임계치는 편향으로 직선 이동 정도입니다.)

def myr(s,e,st):   #range와 같은 목적, step이 실수
    r=s
    while(r<e):
        yield r
        r+=st
def find_wb(ds): #기계 학습
    for wt1 in myr(0,1,0.1):
        for wt2 in myr(0,1,0.1):
            for bt in myr(-1,1,0.1):
                if(test(ds,wt1,wt2,bt)==1.0):
                    return True
    return False

테스트 코드 작성

이제 and, or, xor 테스트를 위한 코드를 작성하세요.

ds_and=[
    [0,0,0],[0,1,0],[1,0,0],[1,1,1]
    ]
if find_wb(ds_and):
    print("w1:{0:.1f} w2:{1:.1f} b:{2:.1f} ## and".format(w1,w2,b))
else:
    print("not founded ## and")
ds_or=[
    [0,0,0],[0,1,1],[1,0,1],[1,1,1]
    ]
if find_wb(ds_or):
    print("w1:{0:.1f} w2:{1:.1f} b:{2:.1f} ## or".format(w1,w2,b))
else:
    print("not founded ## or")
ds_xor=[
    [0,0,0],[0,1,1],[1,0,1],[1,1,0]
    ]
if find_wb(ds_xor):
    print("w1:{0:.1f} w2:{1:.1f} b:{2:.1f} ## xor".format(w1,w2,b))
else:
    print("not founded ## xor")

전체 소스 코드

# https://ehpub.co.kr
# 머신 러닝 with pYTHON
# 선형 퍼셉트론 구하기 - 라이브러리 의존도 0%

w1,w2,b=0,0,0 #가중치1,2와 임계치
def setwb(wt1,wt2,bt): #가중치, 임계치 설정 함수
    global w1,w2,b
    w1,w2,b = wt1,wt2,bt
def discriminate(x1,x2): #판별 함수
    if(w1*x1+w2*x2+b<=0):
        return 0
    else:
        return 1
def test(ds,wt1,wt2,bt): #예측 정확도 계산 함수
    setwb(wt1,wt2,bt)
    ok,total=0,0
    for x1,x2,y in ds:
        if(discriminate(x1,x2)==y):
#            print("T", end=' ')
            ok+=1
#        else:
#            print("F",end=' ')
        total+=1
    return ok/total
def myr(s,e,st): #range와 같은 목적, step이 실수
    r=s
    while(r<e):
        yield r
        r+=st
def find_wb(ds): #기계 학습
    for wt1 in myr(0,1,0.1):
        for wt2 in myr(0,1,0.1):
            for bt in myr(-1,1,0.1):
                if(test(ds,wt1,wt2,bt)==1.0):
                    return True
    return False

#테스트 코드
ds_and=[
    [0,0,0],[0,1,0],[1,0,0],[1,1,1]
    ]
if find_wb(ds_and):
    print("w1:{0:.1f} w2:{1:.1f} b:{2:.1f} ## and".format(w1,w2,b))
else:
    print("not founded ## and")
ds_or=[
    [0,0,0],[0,1,1],[1,0,1],[1,1,1]
    ]
if find_wb(ds_or):
    print("w1:{0:.1f} w2:{1:.1f} b:{2:.1f} ## or".format(w1,w2,b))
else:
    print("not founded ## or")
ds_xor=[
    [0,0,0],[0,1,1],[1,0,1],[1,1,0]
    ]
if find_wb(ds_xor):
    print("w1:{0:.1f} w2:{1:.1f} b:{2:.1f} ## xor".format(w1,w2,b))
else:
    print("not founded ## xor")