diff --git a/diamondSearch.py b/diamondSearch.py index 44b3120..b901309 100644 --- a/diamondSearch.py +++ b/diamondSearch.py @@ -57,9 +57,9 @@ def image_area_filter(picture, x): class DiamondSearch(search.Search): - def __init__(self, current_picture, referenced_picture, n=2, p=2): - super(DiamondSearch, self).__init__(current_picture, referenced_picture, n=2, p=2) - self.numOfcomparedMacroblocks=0 + # def __init__(self, current_picture, referenced_picture, n=2, p=2): + # super(DiamondSearch, self).__init__(current_picture, referenced_picture, n=2, p=2) + # self.numOfcomparedMacroblocks=0 def image_area_filter2(self, curr, ref): return image_area_filter(self.current_picture, curr) and image_area_filter(self.referenced_picture, ref) @@ -90,8 +90,8 @@ def blockSAD(self, ref_center): _sum /= cnt #print "Compute SAD for macroblock ({0[0]}, {0[1]}): {1}".format(ref_center, _sum) - return _sum - + return _sum + def motionVector(self,isInterpolated=False): self.first = ( self.x + self.n/2, self.y + self.n/2) ldsp = LDSPGenerator() @@ -101,8 +101,21 @@ def motionVector(self,isInterpolated=False): ldsp.setOrigin(origin) filtered_pattern = filter(lambda x: self.search_area_filter(x), ldsp.generate()) mblock_cmp = lambda x: self.blockSAD(x) - next_origin = min(filtered_pattern, key=mblock_cmp) - if next_origin[0] == origin[0] and next_origin[1] == origin[1]: + try: + next_origin = min(filtered_pattern, key=mblock_cmp) + if next_origin[0] == origin[0] and next_origin[1] == origin[1]: + # SDSP + sdsp = [(origin[0] + x, origin[1] + y) for x, y in [[0, 0], [1,0], [0, 1], [-1, 0], [0, -1]]] + filtered_pattern = filter(lambda x: self.search_area_filter(x), sdsp) + last = min(filtered_pattern, key=mblock_cmp) + vec = [ -last[0]+self.first[0], last[1]-self.first[1] ] + #print "vector ({0[0]}, {0[1]}) macroblock {1[0]}, {1[1]}".format(vec, self.first) + return [ vec[1], vec[0] ] + origin = next_origin + except ValueError as e: + print e + print "filtered_pattern", filtered_pattern + print "mblock_cmp",mblock_cmp # SDSP sdsp = [(origin[0] + x, origin[1] + y) for x, y in [[0, 0], [1,0], [0, 1], [-1, 0], [0, -1]]] filtered_pattern = filter(lambda x: self.search_area_filter(x), sdsp) @@ -110,4 +123,75 @@ def motionVector(self,isInterpolated=False): vec = [ -last[0]+self.first[0], last[1]-self.first[1] ] #print "vector ({0[0]}, {0[1]}) macroblock {1[0]}, {1[1]}".format(vec, self.first) return [ vec[1], vec[0] ] - origin = next_origin \ No newline at end of file + + # below we use methods from first version + + #remember that y is reversed + # use __makroBlock__ to get value + def __position___(self,i,j): + y=self.y+i + x=self.x+j + return y,x + + # macroblock is defined by top left corner and size self.N + # returns a cut part of picture + def __makroBlock__(self,i,j,isCurrent=True): + y, x=self.__position___(i,j) + if x>=len(self.current_picture[0]) or x<0: # python allows negative index but I don't + raise IndexError('x out of list') + if y>=len(self.current_picture) or y<0: + raise IndexError('y out of list') + if isCurrent: + # print "ref_pict[",y,"]","[",x,"]= ",list(reversed(self.current_picture))[y][x] + # print "row",list(reversed(self.current_picture))[y] + return list(reversed(self.current_picture))[y][x] + else: + return list(reversed(self.referenced_picture))[y][x] + + def motionEstimation(self): + num_of_macroblocs_in_y=len(self.current_picture)/self.n + num_of_macroblocs_in_x=len(self.current_picture[0])/self.n + result=[] + for y in range(num_of_macroblocs_in_y): + row=[] + for x in range(num_of_macroblocs_in_x): + self.x=x*self.n + self.y=y*self.n + row.append(self.motionVector()) + result.append(row) + return result + + def createCompressedImage(self): + num_of_macroblocs_in_y=len(self.current_picture)/self.n + num_of_macroblocs_in_x=len(self.current_picture[0])/self.n + + mE=self.motionEstimation() + compressedImage = [0] * len(self.current_picture) + for i in range(len(compressedImage)): + compressedImage[i] = [0] * len(self.current_picture[0]) + try : + for y_ in range(num_of_macroblocs_in_y): + for x_ in range(num_of_macroblocs_in_x): + x = x_*self.n + y = y_*self.n + for n1 in range(self.n): + for m1 in range(self.n): + offset_y = mE[y_][x_][0] + offset_x = mE[y_][x_][1] + ey = y + n1 + offset_y + ex = x + m1 + offset_x + if ey < 0 or len(self.current_picture) <= ey or ex < 0 or len(self.current_picture[0]) <= ex: + continue + cy = y + n1 + cx = x + m1 + compressedImage[ey][ex] = self.current_picture[cy][cx] + except IndexError as e: + print "Error: ",e.message,e.args + print "compressedImage size: [",len(compressedImage),"][",y + n1 + offset_y*self.n,"]" + print "value: [",y + n1 + offset_y*self.n,"][",x + m1 + offset_x*self.n,"]" + print "y=",y," n1=",n1," offset_y",offset_y," n=",self.n," x=",x," m1=",m1," offset_x=",offset_x + print "current_picture[",y+n1,"][",x+m1,"]= ",self.current_picture[y+n1][x+m1] + raise e + exit() + + return compressedImage diff --git a/main.py b/main.py index 3914918..e808352 100644 --- a/main.py +++ b/main.py @@ -8,7 +8,7 @@ import diamondSearch __name__='MotionVectorEstimator' -__version__='1.1.0' +__version__='1.2.0' def raw_input_with_default(text,default): input = raw_input(text+'['+default +']'+ chr(8)*4) @@ -59,7 +59,7 @@ def readStandardSequence(filename): comparations=None start = time.time() end = time.time() -useInterpolation=True +useInterpolation=False if "full" in feed_in.lower() : @@ -70,6 +70,9 @@ def readStandardSequence(filename): print full_.motionEstimation() comparations=full_.numOfcomparedMacroblocks elif "diamond" in feed_in.lower(): + if useInterpolation: + print "---Diamond search doesn't support interpolation---" + useInterpolation=False ds = diamondSearch.DiamondSearch(current_picture, referenced_picture, n, p) compressedImage = ds.createCompressedImage() end = time.time() @@ -89,7 +92,7 @@ def readStandardSequence(filename): running_time=(end - start) print "it took: ",running_time, "s"," Number of comparitions: ",comparations -print psnr(referenced_picture,compressedImage),"[dB] - bigger value is better" +print "psnr= ",psnr(referenced_picture,compressedImage),"[dB] - bigger value is better" im = Image.new("L", (len(compressedImage[0]), len(compressedImage)), "white") img_list=[]