Skip to content

Commit

Permalink
Merge pull request #50 from AllenInstitute/kt_stitch
Browse files Browse the repository at this point in the history
Kt stitch
  • Loading branch information
RussTorres authored Apr 10, 2024
2 parents aace92c + 5338431 commit 56dc5c5
Show file tree
Hide file tree
Showing 7 changed files with 634 additions and 0 deletions.
25 changes: 25 additions & 0 deletions acpreprocessing/stitching_modules/acstitch/ccorr_stitch.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
import numpy
from acpreprocessing.stitching_modules.acstitch.rtccorr import get_point_correspondence


def get_correspondences(A1_ds,A2_ds,A1_pts,A2_pts,w,r=1,pad=False,cc_threshold=0.8,min_value=0):
cc_threshold /= r # TODO: is this the correct rescaling accounting for expanding reference?
w = numpy.asarray(w,dtype=int)
if len(A1_pts.shape)<2:
A1_pts = numpy.array([A1_pts])
A2_pts = numpy.array([A2_pts])
pm1 = []
pm2 = []
for p,q in zip(A1_pts.astype(int),A2_pts.astype(int)):
A2sub = A2_ds[0,0,(q-r*w)[0]:(q+r*w)[0],(q-r*w)[1]:(q+r*w)[1],(q-r*w)[2]:(q+r*w)[2]]
A1sub = A1_ds[0,0,(p-w)[0]:(p+w)[0],(p-w)[1]:(p+w)[1],(p-w)[2]:(p+w)[2]]
if r > 1:
pw = numpy.asarray([((r-1)*wi,(r-1)*wi) for wi in w],dtype=int)
A1sub = numpy.pad(A1sub,pw)
p1,p2 = get_point_correspondence(p,q,A1sub,A2sub,autocorrelation_threshold=cc_threshold,padarray=pad,value_threshold=min_value)
if not p1 is None:
pm1.append(p1)
pm2.append(p2)
if pm1:
return numpy.asarray(pm1),numpy.asarray(pm2)
return None,None
32 changes: 32 additions & 0 deletions acpreprocessing/stitching_modules/acstitch/io.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
# -*- coding: utf-8 -*-
"""
Created on Mon Jul 31 13:33:46 2023
@author: kevint
"""

import json
import gzip
import numpy

class NumpyEncoder(json.JSONEncoder):
def default(self, obj):
if isinstance(obj, numpy.ndarray):
return obj.tolist()
return json.JSONEncoder.default(self, obj)


def save_pointmatch_file(pmdata,jsonpath):
with gzip.open(jsonpath, 'w') as fout:
fout.write(json.dumps(pmdata,cls=NumpyEncoder).encode('utf-8'))


def read_pointmatch_file(jsonpath):
with gzip.open(jsonpath, 'r') as fin:
data = json.loads(fin.read().decode('utf-8'))
if data:
for tspec in data:
for key in ["p_pts","q_pts"]:
if not tspec[key] is None:
tspec[key] = numpy.asarray(tspec[key])
return data
80 changes: 80 additions & 0 deletions acpreprocessing/stitching_modules/acstitch/rtccorr.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@

import numpy
import scipy.ndimage

def correlate_fftns(fft1, fft2):
prod = fft1 * fft2.conj()
res = numpy.fft.ifftn(prod)

corr = numpy.fft.fftshift(res).real
return corr


def ccorr_fftn(img1, img2):
# TODO do we want to pad this?
fft1 = numpy.fft.fftn(img1)
fft2 = numpy.fft.fftn(img2)

return correlate_fftns(fft1, fft2)


def autocorr_fftn(img):
fft = numpy.fft.fftn(img)
return correlate_fftns(fft, fft)


def ccorr_and_autocorr_fftn(img1, img2):
# TODO do we want to pad this?
fft1 = numpy.fft.fftn(img1)
fft2 = numpy.fft.fftn(img2)
ccorr = correlate_fftns(fft1, fft2)
acorr1 = correlate_fftns(fft1, fft1)
acorr2 = correlate_fftns(fft2, fft2)
return ccorr, acorr1, acorr2


def subpixel_maximum(arr):
max_loc = numpy.unravel_index(numpy.argmax(arr), arr.shape)

sub_arr = arr[
tuple(slice(ml-1, ml+2) for ml in max_loc)
]

# get center of mass of sub_arr
subpixel_max_loc = numpy.array(scipy.ndimage.center_of_mass(sub_arr)) - 1
return subpixel_max_loc + max_loc


def ccorr_disp(img1, img2, autocorrelation_threshold=None, padarray=False, value_threshold=0):
if padarray:
d = numpy.ceil(numpy.array(img1.shape) / 2)
pw = numpy.asarray([(di,di) for di in d],dtype=int)
img1 = numpy.pad(img1,pw)
img2 = numpy.pad(img2,pw)
if value_threshold:
img1[img1<value_threshold] = 0
img2[img2<value_threshold] = 0
if autocorrelation_threshold is not None:
cc, ac1, ac2 = ccorr_and_autocorr_fftn(img1, img2)
ac1max = ac1.max()
ac2max = ac2.max()
if (not numpy.isnan(ac1max) and ac1max > 0) and (not numpy.isnan(ac2max) and ac2max > 0):
autocorrelation_ratio = cc.max() / (numpy.sqrt(ac1max*ac2max))
if autocorrelation_ratio < autocorrelation_threshold:
# what to do here?
print("ratio below threshold: " + str(autocorrelation_ratio))
return None
else:
return None
else:
cc = ccorr_fftn(img1, img2)
max_loc = subpixel_maximum(cc)
mid_point = numpy.array(img1.shape) // 2
return max_loc - mid_point


def get_point_correspondence(src_pt, dst_pt, src_patch, dst_patch, autocorrelation_threshold=0.8,padarray=False,value_threshold=0):
disp = ccorr_disp(src_patch, dst_patch, autocorrelation_threshold, padarray,value_threshold)
if disp is not None:
return src_pt, dst_pt - disp
return None,None
Loading

0 comments on commit 56dc5c5

Please sign in to comment.