import umage as um
from math import sqrt, atan2, sin, cos, pi

def greyscale(mat_img):
    """Convert an RGB image to grey.

    Each pixel becomes ``(v, v, v)`` where
    ``v = int(r*0.2125 + g*0.7154 + b*0.0721)`` (truncated by ``int``).

    :param mat_img: image as a list of rows, each row a list of ``(r, g, b)`` tuples.
    :return: a new image of the same size; the input is not modified.
    """
    def _luma(r, g, b):
        # Same operand order as the weighted sum above so float results match exactly.
        return int(r*0.2125 + g*0.7154 + b*0.0721)

    return [[(_luma(r, g, b),) * 3 for r, g, b in row] for row in mat_img]

def convolution(mat_img, mat):
    """Apply kernel ``mat`` at every position of ``mat_img``.

    Each output pixel is ``(v, v, v)`` where ``v`` is what
    ``appliquer_convolution`` returns for that position. That helper reads
    channel 0 only and clamps its result to [0, 255].

    :param mat_img: image as a list of rows of pixel tuples.
    :param mat: kernel as a list of rows of weights.
    :return: a new grey image with the same dimensions as ``mat_img``.
    """
    # The row width is read lazily inside the outer loop, so an empty image
    # simply yields an empty result.
    return [
        [(appliquer_convolution(mat_img, mat, col, row),) * 3
         for col in range(len(mat_img[0]))]
        for row in range(len(mat_img))
    ]

def filtre_sobel(img):
    """Return the Sobel edge-magnitude image of ``img``.

    Steps:
      1. ``img`` is converted with ``greyscale`` unless ``is_greyscale`` already holds.
      2. The horizontal and vertical Sobel responses are computed with ``convolution``.
      3. They are combined per pixel as ``int(min(round(sqrt(gx**2 + gy**2)), 255))``.

    Note: ``convolution`` clamps every response to [0, 255], so negative
    derivatives contribute 0 to the magnitude.

    :param img: image as a list of rows of ``(r, g, b)`` tuples.
    :return: a new grey image whose pixels are ``(n, n, n)``.
    """
    def _magnitude(px, py):
        # Both gradient images are grey (v, v, v), so only channel 0 is read.
        return int(min(round(sqrt(px[0]**2 + py[0]**2)), 255))

    def _combine(grad_x, grad_y):
        return [
            [(_magnitude(grad_x[r][c], grad_y[r][c]),) * 3
             for c in range(len(grad_x[0]))]
            for r in range(len(grad_x))
        ]

    if not is_greyscale(img):
        img = greyscale(img)

    sobel_x = [[-1, 0, 1], [-2, 0, 2], [-1, 0, 1]]
    sobel_y = [[-1, -2, -1], [0, 0, 0], [1, 2, 1]]
    return _combine(convolution(img, sobel_x), convolution(img, sobel_y))



#########################################################################
########################Exercices Supplémentaires########################
#########################################################################

def is_greyscale(img):
    """Tell whether every pixel of ``img`` has equal red, green and blue values.

    Scanning stops at the first non-grey pixel. An image without pixels
    counts as grey.

    :param img: image as a list of rows of ``(r, g, b)`` tuples.
    :return: ``True`` if every pixel satisfies ``r == g == b``, else ``False``.
    """
    return all(r == g == b for row in img for r, g, b in row)

def invert(img):
    """Return the colour negative of ``img``: each channel ``c`` becomes ``255 - c``.

    :param img: image as a list of rows of ``(r, g, b)`` tuples.
    :return: a new image; the input is left untouched.
    """
    return [[(255 - r, 255 - g, 255 - b) for (r, g, b) in row] for row in img]

def pixel(img, i, j, default=(0,0,0)):
    """Return the pixel at column ``i`` and row ``j``, or ``default`` outside the image.

    The convolution helpers use this as zero padding. Any out-of-range
    position yields ``default`` instead of raising, including negative
    indices, which Python would otherwise wrap around.

    :param img: image as a list of rows of pixel tuples (may be empty).
    :param i: column index.
    :param j: row index.
    :param default: value returned for positions outside the image.
    :return: ``img[j][i]`` or ``default``.
    """
    # Check the row first, so an empty image never triggers an IndexError
    # on img[0], and a short row is bounded by its own length.
    if 0 <= j < len(img) and 0 <= i < len(img[j]):
        return img[j][i]
    return default

def appliquer_convolution(img, mat, i, j):
    """Weighted sum of kernel ``mat`` centred on column ``i``, row ``j`` of ``img``.

    How the sum is formed:
      * the kernel is applied as written, with no flip;
      * channel 0 of each pixel is read;
      * positions outside the image come from ``pixel`` with its default ``(0, 0, 0)``.

    The result is clamped to [0, 255].

    :param img: image as a list of rows of pixel tuples.
    :param mat: kernel as a list of rows of weights.
    :param i: column of the centre pixel.
    :param j: row of the centre pixel.
    :return: the clamped weighted sum.
    """
    h = len(mat)
    acc = 0
    for ky in range(h):
        w = len(mat[0])
        for kx in range(w):
            valeur = pixel(img, i + kx - w // 2, j + ky - h // 2)[0]
            acc += valeur * mat[ky][kx]
    return min(max(acc, 0), 255)



######################################################################
#########################Exercices personnels#########################
######################################################################
def convolution_gauss(mat_img):
    """Smooth ``mat_img`` with a 5x5 Gaussian-shaped kernel.

    The kernel has integer weights divided by 159. Each output pixel is
    ``(v, v, v)``, with ``v`` computed by ``reduction_bruit``: channel 0 only,
    rounded, and ``pixel``'s default ``(0, 0, 0)`` outside the image.

    :param mat_img: image as a list of rows of pixel tuples.
    :return: a new grey image with the same dimensions.
    """
    poids = [
        [2, 4, 5, 4, 2],
        [4, 9, 12, 9, 4],
        [5, 12, 15, 12, 5],
        [4, 9, 12, 9, 4],
        [2, 4, 5, 4, 2],
    ]
    # The integer weights sum to 159, so dividing by it normalises the kernel.
    noyau = [[w / 159 for w in rangee] for rangee in poids]

    return [
        [(reduction_bruit(mat_img, noyau, col, row),) * 3
         for col in range(len(mat_img[0]))]
        for row in range(len(mat_img))
    ]

def reduction_bruit(img, mat, i, j):
    """Weighted sum of kernel ``mat`` centred on column ``i``, row ``j``, then rounded.

    The traversal matches ``appliquer_convolution``: channel 0, no kernel
    flip, ``pixel``'s default outside the image. The sum goes through
    ``round`` instead of being clamped.

    :param img: image as a list of rows of pixel tuples.
    :param mat: kernel as a list of rows of weights.
    :param i: column of the centre pixel.
    :param j: row of the centre pixel.
    :return: ``round`` of the weighted sum.
    """
    hauteur = len(mat)
    total = 0
    for y in range(hauteur):
        largeur = len(mat[0])
        ligne_noyau = mat[y]
        for x in range(largeur):
            voisin = pixel(img, i + x - largeur // 2, j + y - hauteur // 2)
            total += voisin[0] * ligne_noyau[x]
    return round(total)

def filtre_canny(img, Th=200):
    """Detect edges in ``img`` with a Canny-style pipeline.

    Steps:
      1. convert to grey with ``greyscale`` unless ``is_greyscale`` already holds;
      2. smooth with ``convolution_gauss``;
      3. compute signed horizontal/vertical derivatives of the SMOOTHED image
         (kernels ``[[-1, 0, 1]]`` and ``[[1], [0], [-1]]``) and turn them
         into per-pixel ``(norm, angle)`` pairs;
      4. thin the edges with ``delete_pixel`` (non-maximum suppression),
         keeping the gradient magnitude;
      5. keep edges with ``hysteresis`` using the high threshold ``Th``.

    :param img: image as a list of rows of ``(r, g, b)`` tuples.
    :param Th: high hysteresis threshold applied to the gradient norm
        (default 200, the value used by the module-level test script).
    :return: a new image produced by ``hysteresis``.
    """
    def gradient_signe(image, mat):
        # Same kernel traversal as appliquer_convolution (channel 0, no flip,
        # zero padding via pixel) but WITHOUT clamping to [0, 255]. Clamping
        # would zero every negative derivative, losing half of the edges and
        # restricting the angle to [0, pi/2].
        demi_h = len(mat) // 2
        demi_w = len(mat[0]) // 2
        resultat = []
        for j, rangee in enumerate(image):
            ligne = []
            for i in range(len(rangee)):
                somme = 0
                for y, poids_ligne in enumerate(mat):
                    for x, poids in enumerate(poids_ligne):
                        somme += pixel(image, i - demi_w + x, j - demi_h + y)[0] * poids
                ligne.append(somme)
            resultat.append(ligne)
        return resultat

    def norme_gradient(gx, gy):
        # Norm capped at 255 so it can be stored as a grey level. The angle
        # comes from atan2; because mat_y is "pixel above minus pixel below",
        # it is measured with the y axis pointing up.
        norm = min(round(sqrt(gx**2 + gy**2)), 255)
        return norm, atan2(gy, gx)

    if not is_greyscale(img):
        img = greyscale(img)

    mat_x = [[-1, 0, 1]]
    mat_y = [[1], [0], [-1]]

    # Noise reduction: the derivatives must be taken on the smoothed image.
    img_no_bruit = convolution_gauss(img)
    Jx = gradient_signe(img_no_bruit, mat_x)
    Jy = gradient_signe(img_no_bruit, mat_y)
    normGrad = [
        [norme_gradient(gx, gy) for gx, gy in zip(ligne_x, ligne_y)]
        for ligne_x, ligne_y in zip(Jx, Jy)
    ]

    # Non-maximum suppression keeps the gradient magnitude (not the image
    # intensity), then hysteresis thresholds the thinned magnitudes.
    img_norme = [[(norm,) * 3 for norm, _ in ligne] for ligne in normGrad]
    sans_non_maxima = delete_pixel(img_norme, normGrad)
    return hysteresis(sans_non_maxima, normGrad, Th)


#temp
def norme_gradient(pixel1, pixel2):
    """Return ``(norm, angle)`` for a gradient given as two pixels.

    The x component is channel 0 of ``pixel1`` and the y component is
    channel 0 of ``pixel2``. ``norm`` is ``round(sqrt(x**2 + y**2))`` capped
    at 255; ``angle`` is ``atan2(y, x)`` in radians.
    """
    gx, gy = pixel1[0], pixel2[0]
    magnitude = round(sqrt(gx**2 + gy**2))
    if magnitude > 255:
        magnitude = 255
    return magnitude, atan2(gy, gx)

#temp
def liste_normGrad(im1, im2):
    """Combine two same-sized images pixel by pixel with ``norme_gradient``.

    :param im1: image of x derivatives (channel 0 is read).
    :param im2: image of y derivatives, with the same dimensions as ``im1``.
    :return: list of rows of ``(norm, angle)`` tuples.
    """
    return [
        [norme_gradient(im1[r][c], im2[r][c]) for c in range(len(im1[0]))]
        for r in range(len(im1))
    ]

# Temporary test script: these statements run at module import time.
# Derivative kernels: horizontal [-1, 0, 1], and vertical [1, 0, -1]
# (pixel above minus pixel below).
mat_x = [[-1,0,1]]
mat_y = [[1],[0],[-1]]
# temp
# Smoothing.
# NOTE(review): loading happens on import, so importing this file just to reuse
# the filters also loads this image. The path uses backslash (Windows-style)
# separators.
img = um.load("imageEngine\\images\\valve.png")
# NOTE(review): there is no greyscale() step here and convolution_gauss only
# reads channel 0 of each pixel; presumably valve.png is already grey — confirm.
img = convolution_gauss(img)
# Derivatives of the smoothed image. convolution clamps each response to
# [0, 255], so negative derivatives become 0 and the angles computed below
# lie in [0, pi/2].
Jx = convolution(img, mat_x)
Jy = convolution(img, mat_y)
# Per-pixel (norm, angle) pairs.
normGrad = liste_normGrad(Jx, Jy)
###########



def find_neighbord_norm(mat, i, j, rad):
    """Return the norms of the two neighbours of (column ``i``, row ``j``) along ``rad``.

    ``rad`` is quantised to one of four orientations: horizontal, vertical,
    or one of the two diagonals.

    The angle is taken with the y axis pointing up. This matches this
    module's vertical derivative kernel ``[[1], [0], [-1]]`` (pixel above
    minus pixel below). So a positive vertical component means a smaller
    row index, and the two diagonals are told apart by the signs of cos/sin.

    :param mat: list of rows of ``(norm, angle)`` tuples.
    :param i: column of the centre pixel.
    :param j: row of the centre pixel.
    :param rad: direction in radians.
    :return: ``(norm_behind, norm_ahead)``; a neighbour outside ``mat`` gives -1.
    """
    c, s = cos(rad), sin(rad)
    # Step horizontally when the direction is within 3*pi/8 of the x axis.
    dx = 0 if abs(c) < cos(3*pi/8) else (1 if c > 0 else -1)
    # Step vertically when the direction is at least pi/8 away from the x
    # axis. Both steps together mean a diagonal.
    dy = 0 if abs(s) < sin(pi/8) else (1 if s > 0 else -1)
    # Image rows grow downwards, so an upward step is a negative row offset.
    di, dj = dx, -dy

    def norme(ii, jj):
        if 0 <= jj < len(mat) and 0 <= ii < len(mat[jj]):
            return mat[jj][ii][0]
        return -1

    return norme(i - di, j - dj), norme(i + di, j + dj)

def delete_pixel(mat_img, mat):
    """Non-maximum suppression along the gradient direction.

    A pixel is blanked when its gradient norm is strictly smaller than the
    norm of either neighbour returned by ``find_neighbord_norm``. Neighbours
    outside ``mat`` report -1 and never cause suppression.

    :param mat_img: image whose pixel is copied where the norm is a local maximum.
    :param mat: list of rows of ``(norm, angle)`` tuples, same size as ``mat_img``.
    :return: a new image in which suppressed positions are ``(0, 0, 0)``.

    NOTE(review): the kept value is ``mat_img[j][i]``. For Canny this should
    be the gradient-magnitude image, yet the module-level script passes the
    smoothed intensity image here.
    """
    img_to_return = []
    for j in range(len(mat)):
        ligne = []
        for i in range(len(mat[j])):
            norm = mat[j][i][0]
            rad = mat[j][i][1]
            norms = find_neighbord_norm(mat, i, j, rad)
            # Compare the pixel's own NORM with its neighbours' norms. The
            # angle (radians) is not comparable with a 0..255 norm.
            if norm < norms[0] or norm < norms[1]:
                ligne.append((0,)*3)
            else:
                ligne.append(mat_img[j][i])
        img_to_return.append(ligne)
    return img_to_return



def hysteresis(mat_img, mat_norm, Th, Tl=None):
    """Keep strong edges and the weak edges attached to them.

    Each pixel is classified by channel 0 of ``mat_img``:
      * strong (``>= Th``): always white;
      * weak (``Tl <= v < Th``): white only if one of its two neighbours
        along the edge is strong. The edge direction is the gradient angle
        from ``mat_norm`` rotated by pi/2;
      * below ``Tl``: black.

    This is a single pass. A weak pixel linked to a strong one only through
    other weak pixels is not promoted.

    :param mat_img: image to threshold, typically after non-maximum suppression.
    :param mat_norm: list of rows of ``(norm, angle)`` tuples, same size.
    :param Th: high threshold.
    :param Tl: low threshold; defaults to ``Th / 2``.
    :return: a new image made of ``(255, 255, 255)`` and ``(0, 0, 0)`` pixels.
    """
    if Tl is None:
        Tl = Th / 2
    # Strong pixels become 255 and low ones 0; weak pixels keep their value.
    # Neighbours are looked up in this classified copy.
    classes = yesOrNo(mat_img, Th, Tl)
    result_image = []
    for j in range(len(mat_img)):
        ligne = []
        for i in range(len(mat_img[0])):
            valeur = mat_img[j][i][0]
            if valeur >= Th:
                garde = True
            elif valeur < Tl:
                garde = False
            else:
                rad = mat_norm[j][i][1]
                color1, color2 = find_neighbord_pixel(classes, i, j, rad+(pi/2))
                garde = color1 == 255 or color2 == 255
            ligne.append((255 if garde else 0,) * 3)
        result_image.append(ligne)
    return result_image

def find_neighbord_pixel(mat_image, i, j, rad):
    """Return channel 0 of the two neighbours of (column ``i``, row ``j``) along ``rad``.

    ``rad`` is quantised to one of four orientations: horizontal, vertical,
    or one of the two diagonals.

    The angle is taken with the y axis pointing up, matching this module's
    vertical derivative kernel ``[[1], [0], [-1]]``. So a positive vertical
    component means a smaller row index, and the two diagonals are told
    apart by the signs of cos/sin.

    :param mat_image: image as a list of rows of pixel tuples.
    :param i: column of the centre pixel.
    :param j: row of the centre pixel.
    :param rad: direction in radians.
    :return: ``(value_behind, value_ahead)``; a neighbour outside the image gives 0.
    """
    c, s = cos(rad), sin(rad)
    # Step horizontally when the direction is within 3*pi/8 of the x axis.
    dx = 0 if abs(c) < cos(3*pi/8) else (1 if c > 0 else -1)
    # Step vertically when the direction is at least pi/8 away from the x axis.
    dy = 0 if abs(s) < sin(pi/8) else (1 if s > 0 else -1)
    # Image rows grow downwards, so an upward step is a negative row offset.
    di, dj = dx, -dy

    def couleur(ii, jj):
        if 0 <= jj < len(mat_image) and 0 <= ii < len(mat_image[jj]):
            return mat_image[jj][ii][0]
        return 0

    return couleur(i - di, j - dj), couleur(i + di, j + dj)

def yesOrNo(mat_img, Th, Tl):
    """Apply a double threshold to channel 0 of every pixel.

    Rules:
      * ``v >= Th`` gives ``(255, 255, 255)``;
      * ``v < Tl`` gives ``(0, 0, 0)``;
      * otherwise the pixel tuple is kept unchanged.

    :param mat_img: image as a list of rows of pixel tuples.
    :param Th: high threshold.
    :param Tl: low threshold.
    :return: a new image with the same dimensions.
    """
    def classer(pix):
        niveau = pix[0]
        if Th <= niveau:
            return (255,) * 3
        if niveau < Tl:
            return (0,) * 3
        return pix

    return [
        [classer(mat_img[r][c]) for c in range(len(mat_img[0]))]
        for r in range(len(mat_img))
    ]


# Temporary pipeline, run at module import time. It applies non-maximum
# suppression, keeping pixels of the smoothed image `img`. It then applies
# hysteresis with high threshold 200 (low threshold Th / 2 = 100 inside
# hysteresis) and saves the result.
zt_no_maxima = delete_pixel(img, normGrad)
zt_hysteresis = hysteresis(zt_no_maxima, normGrad, 200)

# NOTE(review): presumably writes the result as a PNG under imageEngine\test;
# um.save's exact path/extension handling is not visible here — confirm.
um.save(zt_hysteresis, "imageEngine\\test\\valve", "png")