Skip to content

Commit

Permalink
Merge pull request #293 from ornlneutronimaging/add_beamhardending_corr
Browse files Browse the repository at this point in the history
Add Beam hardening correction function to correction module
  • Loading branch information
KedoKudo authored Feb 19, 2024
2 parents 7c9dea7 + 750af05 commit 707c7f6
Show file tree
Hide file tree
Showing 13 changed files with 278 additions and 117 deletions.
1 change: 0 additions & 1 deletion .github/workflows/protected_branches.yml
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,6 @@ jobs:
- name: Build Conda Package
run: |
# boa uses mamba to resolve dependencies
conda install -y anaconda-client boa
cd conda.recipe
VERSION=$(versioningit ../) conda mambabuild --output-folder . -c conda-forge . || exit 1
conda verify noarch/imars3d*.tar.bz2 || exit 1
Expand Down
3 changes: 3 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
@@ -1,3 +1,6 @@
# direnv
.envrc

*.pyc
*~
.ipynb*
Expand Down
8 changes: 0 additions & 8 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -29,14 +29,6 @@ repos:
hooks:
- id: black
args: ['--line-length=119']
- repo: https://github.com/PyCQA/flake8
rev: 6.0.0
hooks:
- id: flake8
exclude: |
(?x)^(
^docs/conf.py
)$
- repo: https://github.com/adrienverge/yamllint.git
rev: v1.31.0
hooks:
Expand Down
17 changes: 10 additions & 7 deletions conda.recipe/meta.yaml
Original file line number Diff line number Diff line change
@@ -1,22 +1,25 @@
# load information from setup.cfg/setup.py
{% set data = load_setup_py_data() %}
{% set license = data.get('license') %}
{% set description = data.get('description') %}
{% set url = data.get('url') %}
{% set pyproject = load_file_data('pyproject.toml') %}
{% set project = pyproject.get('project', {}) %}
{% set license = project.get('license').get('text') %}
{% set description = project.get('description') %}
{% set project_url = pyproject.get('project', {}).get('urls') %}
{% set url = project_url.get('homepage') %}
# this will get the version set by environment variable
{% set version = environ.get('VERSION') %}
{% set version_number = environ.get('GIT_DESCRIBE_NUMBER', '0') | string %}
{% set version_number = version.split('+')[0] %}
{% set build_number = 0 %}

package:
name: imars3d
version: {{ version }}
version: {{ version_number }}

source:
path: ..

build:
noarch: python
number: {{ version_number }}
number: {{ build_number }}
string: py{{py}}
script: {{ PYTHON }} -m pip install . --no-deps --ignore-installed -vvv

Expand Down
43 changes: 29 additions & 14 deletions environment.yml
Original file line number Diff line number Diff line change
Expand Up @@ -2,29 +2,44 @@ name: imars3d
channels:
- conda-forge
dependencies:
# -- Runtime
# base
- python
- versioningit
- toml
# compute
- astropy
- tomopy
- dxchange
- jsonschema
- panel>=0.14.2
- param
- pyvista
- pydocstyle
- algotom
# plot
- holoviews
- bokeh
- datashader
- hvplot
- build
- toml
- conda-build
- conda-verify
- pytest
- pytest-cov
# GUI
- panel<1.3
- param<2
- pyvista
# IO
- dxchange
- jsonschema
# -- Development
# utils
- pre-commit
# packaging
- anaconda-client
- boa
- conda-build < 4
- conda-verify
- python-build
# doc
- pydocstyle
- sphinx
- sphinx_rtd_theme
- versioningit
- pip
# test
- pytest
- pytest-cov
# pip
- pip:
- check-wheel-contents
- pytest-playwright
Expand Down
76 changes: 65 additions & 11 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,13 +1,31 @@
[project]
name = "imars3d"
description = "Neutron imaging data analysis at ORNL"
dynamic = ["version"]
requires-python = ">=3.10"
license = { text = "BSD 3-Clause License" }
dependencies = [
"astropy",
]

[project.urls]
homepage = "https://github.com/ornlneutronimaging/iMars3D"

[project.scripts]
imars3dcli = "imars3d.backend.__main__:main"

[build-system]
requires = ["setuptools >= 40.6.0", "wheel", "toml", "versioningit"]
requires = [
"setuptools >= 40.6.0",
"wheel",
"toml",
"versioningit"
]
build-backend = "setuptools.build_meta"

[tool.black]
line-length = 119

[tool.versioningit.vcs]
method = "git"
default-tag = "5.0.0"
default-tag = "1.0.0"

[tool.versioningit.next-version]
method = "minor"
Expand All @@ -20,13 +38,49 @@ distance-dirty = "{next_version}.dev{distance}+d{build_date:%Y%m%d%H%M}"
[tool.versioningit.write]
file = "src/imars3d/_version.py"

[tool.setuptools.packages.find]
where = ["src"]
exclude = ["tests*", "scripts*", "docs*", "notebooks*"]

[tool.setuptools.package-data]
"*" = ["*.yml", "*.yaml", "*.ini", "schema.json"]

[tool.pytest.ini_options]
pythonpath = [
".", "src", "scripts"
]
pythonpath = [".", "src", "scripts"]
testpaths = ["tests"]
python_files = ["test*.py"]
norecursedirs = [".git", "tmp*", "_tmp*", "__pycache__", "*dataset*", "*data_set*"]
markers = [
"datarepo: mark a test as using imars3d-data repository"
norecursedirs = [
".git", "tmp*", "_tmp*", "__pycache__",
"*dataset*", "*data_set*",
"*ui*"
]
markers = ["datarepo: mark a test as using imars3d-data repository"]

[tool.pylint]
max-line-length = 120
disable = ["too-many-locals",
"too-many-statements",
"too-many-instance-attributes",
"too-many-arguments",
"duplicate-code"
]

[tool.coverage.run]
source = [
"src/imars3d/backend"
]
omit = [
"*/tests/*",
"src/imars3d/__init__.py",
"src/imars3d/ui/*"
]

[tool.coverage.report]
fail_under = 60
exclude_lines = [
"pragma: no cover",
"def __repr__",
"def __str__",
"if TYPE_CHECKING:",
"if __name__ == .__main__.:"
]
67 changes: 0 additions & 67 deletions setup.cfg

This file was deleted.

6 changes: 0 additions & 6 deletions setup.py

This file was deleted.

97 changes: 97 additions & 0 deletions src/imars3d/backend/corrections/beam_hardening.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,97 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""Imaging correction for beam hardening."""
import logging
import param
import numpy as np
from imars3d.backend.util.functions import clamp_max_workers
from multiprocessing.managers import SharedMemoryManager
from functools import partial
from tqdm.contrib.concurrent import process_map
from algotom.prep.correction import beam_hardening_correction as algotom_beam_hardening_correction

logger = logging.getLogger(__name__)


class beam_hardening_correction(param.ParameterizedFunction):
"""Imaging correction for beam hardening.
Parameters
----------
arrays: np.ndarray
The image stack to be corrected for beam hardening, must be normalized to 0-1.
q: float
The beam hardening correction parameter, must be positive.
n: float
The beam hardening correction parameter, must be greater than 1.
opt: bool
If True, correction biased towards 1.0, else correction biased towards 0.0.
max_workers: int
The maximum number of workers to use for parallel processing.
tqdm_class: panel.widgets.Tqdm
Class to be used for rendering tqdm progress
Returns
-------
np.ndarray
The corrected image stack.
"""

arrays = param.Array(
doc="The image stack to be corrected for beam hardening, must be normalized to 0-1.",
default=None,
)
q = param.Number(
doc="The beam hardening correction parameter.",
default=0.005,
bounds=(0, None),
)
n = param.Number(
doc="The beam hardening correction parameter.",
default=20.0,
bounds=(1, None),
)
opt = param.Boolean(
doc="If True, correction biased towards 1.0, else correction biased towards 0.0.",
default=True,
)
max_workers = param.Integer(
doc="The maximum number of workers to use for parallel processing, default is 0, which means using all available cores.",
default=0,
bounds=(0, None),
)
tqdm_class = param.ClassSelector(class_=object, doc="Progress bar to render with")

def __call__(self, **params):
"""Perform the beam hardening correction."""
logger.info("Performing beam hardening correction.")
# type check & bounds check
_ = self.instance(**params)
# sanitize arguments
params = param.ParamOverrides(self, params)
# set max_workers
self.max_workers = clamp_max_workers(params.max_workers)
logger.debug(f"max_worker={self.max_workers}")

if params.arrays.ndim == 2:
return algotom_beam_hardening_correction(params.arrays, params.q, params.n, params.opt)
elif params.arrays.ndim == 3:
with SharedMemoryManager() as smm:
shm = smm.SharedMemory(params.arrays.nbytes)
shm_arrays = np.ndarray(params.arrays.shape, dtype=params.arrays.dtype, buffer=shm.buf)
np.copyto(shm_arrays, params.arrays)
# mp
kwargs = {
"max_workers": self.max_workers,
"desc": "denoise_by_bilateral",
}
if self.tqdm_class:
kwargs["tqdm_class"] = self.tqdm_class
rst = process_map(
partial(algotom_beam_hardening_correction, q=params.q, n=params.n, opt=params.opt),
[shm_arrays[i] for i in range(shm_arrays.shape[0])],
**kwargs,
)
return np.array(rst)
else:
raise ValueError("The input array must be either 2D or 3D.")
2 changes: 1 addition & 1 deletion src/imars3d/backend/corrections/ring_removal.py
Original file line number Diff line number Diff line change
Expand Up @@ -250,7 +250,7 @@ def _remove_ring_artifact(
correction_range=correction_range,
),
[shm_arrays[:, sino_idx, :] for sino_idx in range(shm_arrays.shape[1])],
**kwargs
**kwargs,
)
rst = np.array(rst)
for i in range(arrays.shape[1]):
Expand Down
Loading

0 comments on commit 707c7f6

Please sign in to comment.