Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor volume consensus #928

Draft
wants to merge 3 commits into
base: devel
Choose a base branch
from
Draft
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
110 changes: 83 additions & 27 deletions src/xmipp/applications/scripts/volume_consensus/volume_consensus.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,11 +27,13 @@
*
**************************************************************************
"""
from os.path import splitext
import math
import os
import itertools
import numpy as np
import pywt
import pywt.data

from scipy.ndimage import zoom
from xmipp_base import XmippScript
import xmippLib
Expand All @@ -54,51 +56,105 @@ def run(self):
outVolFn = self.getParam('-o')
self.computeVolumeConsensus(inputFile, outVolFn)

def resize(self, image, dim):
imageFt = np.fft.rfftn(image)
resultFt = np.zeros(dim[:-1] + (dim[-1]//2+1,), dtype=imageFt.dtype)

copyExtent = np.minimum(image.shape, dim) // 2
srcCornerStart = image.shape-copyExtent
dstCornerStart = dim-copyExtent
for corners in itertools.product(range(2), repeat=len(dim)-1):
corners = np.array(corners + (0, ))
srcStart = np.where(corners, srcCornerStart, 0)
srcEnd = srcStart + copyExtent
dstStart = np.where(corners, dstCornerStart, 0)
dstEnd = dstStart + copyExtent
srcSlices = [slice(s, e) for s, e in zip(srcStart, srcEnd)]
dstSlices = [slice(s, e) for s, e in zip(dstStart, dstEnd)]
resultFt[tuple(dstSlices)] = imageFt[tuple(srcSlices)]

return np.fft.irfftn(resultFt)

def computeVolumeConsensus(self, inputFile, outVolFn, wavelet='sym11'):
outputWt = None
outputMin = None
xdim2 = None
xdimOrig = None
image = xmippLib.Image()
with open(inputFile) as f:
for line in f:
fileName = line.split()[0]
if fileName.endswith('.mrc'):
fileName += ':mrc'
V = xmippLib.Image(fileName)
vol = V.getData()

image.read(fileName)
volume = image.getData()

if xdimOrig is None:
xdimOrig = vol.shape[0]
xdim2 = 2**(math.ceil(math.log(xdimOrig, 2))) # Next power of 2
ydimOrig = vol.shape[1]
ydim2 = 2 ** (math.ceil(math.log(ydimOrig, 2))) # Next power of 2
zdimOrig = vol.shape[2]
zdim2 = 2 ** (math.ceil(math.log(zdimOrig, 2))) # Next power of 2
xdimOrig = volume.shape[0]
xdim2 = 2**(math.ceil(math.log2(xdimOrig))) # Next power of 2
ydimOrig = volume.shape[1]
ydim2 = 2**(math.ceil(math.log2(ydimOrig))) # Next power of 2
zdimOrig = volume.shape[2]
zdim2 = 2**(math.ceil(math.log2(zdimOrig))) # Next power of 2

if xdimOrig!=xdim2 or ydimOrig!=ydim2 or zdimOrig!=zdim2:
vol = zoom(vol, (xdim2/xdimOrig,ydim2/ydimOrig,zdim2/zdimOrig))
nlevel = pywt.swt_max_level(len(vol))
wt = pywt.swtn(vol, wavelet, nlevel, 0)
#volume = zoom(volume, (xdim2/xdimOrig,ydim2/ydimOrig,zdim2/zdimOrig))
volume = self.resize(volume, (zdim2, ydim2, xdim2))

nlevel = pywt.dwtn_max_level(volume.shape, wavelet=wavelet)
wt = pywt.wavedecn(
data=volume,
wavelet=wavelet,
level=nlevel
)

if outputWt == None:
outputWt = wt
outputMin = wt[0]['aaa']*0
outputMin = np.zeros_like(volume)
else:
for level in range(0, nlevel):
diff = np.abs(np.abs(wt[0]) - np.abs(outputWt[0]))
diff = self.resize(diff, outputMin.shape)
np.maximum(
diff, outputMin,
out=outputMin
)
outputWt[0] = np.where(
np.abs(wt[0]) > np.abs(outputWt[0]),
wt[0], outputWt[0]
)

for level in range(1, nlevel+1):
wtLevel = wt[level]
outputWtLevel = outputWt[level]
for key in wtLevel:
outputWtLevel[key] = np.where(np.abs(outputWtLevel[key]) > np.abs(wtLevel[key]),
outputWtLevel[key], wtLevel[key])
diff = np.abs(np.abs(outputWtLevel[key]) - np.abs(wtLevel[key]))
outputMin = np.where(outputMin > diff, outputMin, diff)
for detail in wtLevel:
wtLevelDetail = wtLevel[detail]
outputWtLevelDetail = outputWtLevel[detail]

diff = np.abs(np.abs(wtLevelDetail) - np.abs(outputWtLevelDetail))
diff = self.resize(diff, outputMin.shape)
np.maximum(
diff, outputMin,
out=outputMin
)

outputWtLevelDetail[...] = np.where(
np.abs(wtLevelDetail) > np.abs(outputWtLevelDetail),
wtLevelDetail, outputWtLevelDetail
)


f.close()
consensus = pywt.iswtn(outputWt, wavelet)
consensus = pywt.waverecn(outputWt, wavelet)
if xdimOrig!=xdim2 or ydimOrig!=ydim2 or zdimOrig!=zdim2:
consensus = self.resize(consensus, (zdimOrig, ydimOrig, xdimOrig))
image.setData(consensus)
image.write(outVolFn)
if xdimOrig!=xdim2 or ydimOrig!=ydim2 or zdimOrig!=zdim2:
consensus = zoom(consensus, (xdimOrig/xdim2,ydimOrig/ydim2,zdimOrig/zdim2))
V = xmippLib.Image()
V.setData(consensus)
V.write(outVolFn)
V.setData(outputMin)
outVolFn2 = splitext(outVolFn)[0] + '_diff.mrc'
V.write(outVolFn2)
outputMin = self.resize(outputMin, (zdimOrig, ydimOrig, xdimOrig))
image.setData(outputMin)
outVolFn2 = os.path.splitext(outVolFn)[0] + '_diff.mrc'
image.write(outVolFn2)


if __name__=="__main__":
Expand Down