Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
95 changes: 95 additions & 0 deletions B_mean.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
import numpy as np
import matplotlib.pyplot as plt
from scipy.linalg import norm
from scipy.ndimage import map_coordinates
from typing import Tuple


def linear_extrapolate(data_, points):
return map_coordinates(data_, points.T, order=1, mode='nearest')


data = np.load("C:/Users/user/Downloads/vtk_field/npy_field/Bnlfffe_NORH_NLFFFE_170904_055842.npy")

nx, ny, nz = data.shape[:3]
x, y, z = np.arange(nx), np.arange(ny), np.arange(nz)
X, Y, Z = np.meshgrid(x, y, z, indexing='ij')
coordinates = np.stack((X, Y, Z), axis=-1)


def calculate_b_mean(data_, coordinates_, center_, radius_):
squared_dist = np.sum((coordinates_ - np.array(center_)) ** 2, axis=-1)
mask = np.array(squared_dist <= radius_**2)
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This already should be a np.array, no need to convert again

return np.mean(data_[mask], axis=0), np.where(mask.ravel())[0]


# noinspection PyUnreachableCode
def create_local_frame(b_mean_: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
b_n_ = b_mean_ / norm(b_mean_)
basis = np.eye(3)
x_prime = np.cross(basis[np.argmin(np.abs(basis @ b_n_))], b_n_)
x_prime /= norm(x_prime)
y_prime = np.cross(b_n_, x_prime)
y_prime /= norm(y_prime)
return np.vstack([x_prime, y_prime, b_n_]), b_n_


def transform_field(b_, rotation_matrix_):
return np.dot(b_, rotation_matrix_.T)


def plot_all_2d_components(b_local_, center_, radius_, rotation_matrix_, title_prefix=""):
fig, axes = plt.subplots(1, 3, figsize=(18, 6))
x_prime, y_prime = rotation_matrix_[0], rotation_matrix_[1]

u = v = np.linspace(-radius_, radius_, int(2 * radius_))
u_grid, v_grid = np.meshgrid(u, v)
x_plane = center_[0] + u_grid * x_prime[0] + v_grid * y_prime[0]
y_plane = center_[1] + u_grid * x_prime[1] + v_grid * y_prime[1]
z_plane = center_[2] + u_grid * x_prime[2] + v_grid * y_prime[2]
sample_points = np.stack((x_plane, y_plane, z_plane), axis=-1)

b_x = linear_extrapolate(b_local_[..., 0], sample_points)
b_y = linear_extrapolate(b_local_[..., 1], sample_points)
b_z = linear_extrapolate(b_local_[..., 2], sample_points)

phi = np.arctan2(v_grid, u_grid)
b_r = b_x * np.cos(phi) + b_y * np.sin(phi)
b_phi = -b_x * np.sin(phi) + b_y * np.cos(phi)

vmax = max(np.max(np.abs(b_r)), np.max(np.abs(b_phi)), np.max(np.abs(b_z)), 1e-6)
for ax, component in zip(axes, [b_r, b_phi, b_z]):
im = ax.imshow(component, cmap='bwr', vmin=-vmax, vmax=vmax,
extent=[u[0], u[-1], v[-1], v[0]], origin='lower')
plt.colorbar(im, ax=ax)
ax.add_patch(plt.Circle((0, 0), radius, color='k', fill=False, linestyle='--'))
ax.scatter(0, 0, c='k', s=100, marker='*')
ax.grid(True, linestyle=':', alpha=0.5)

axes[0].set_title(fr"{title_prefix}$B_r$")
axes[1].set_title(fr"{title_prefix}$B_\phi$")
axes[2].set_title(fr"{title_prefix}$B_n$")
plt.tight_layout()
plt.show()


def calculate_curl(b_):
gradients = np.stack(np.gradient(b_, axis=(0, 1, 2)), axis=-1)
curl = np.stack([gradients[..., 2, 1] - gradients[..., 1, 2],
gradients[..., 0, 2] - gradients[..., 2, 0],
gradients[..., 1, 0] - gradients[..., 0, 1]], axis=-1)
return curl


center = (203, 10, 1)
radius = 15

b_mean, indices = calculate_b_mean(data, coordinates, center, radius)
rotation_matrix, B_n = create_local_frame(b_mean)
b_local = np.apply_along_axis(transform_field, 3, data, rotation_matrix)
Copy link
Copy Markdown
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It is better to apply_along_axis when there are no other ways.

You can try options like:
reshape B:

‘Nx Ny Nz 3 -> -1, 3
B@rotation.T
-1, 3 -> Nx Ny Nz 3

Or you can also try np.einsum


plot_all_2d_components(b_local, center, radius, rotation_matrix, "Local components: ")

curl_global = calculate_curl(data)
curl_local = np.apply_along_axis(transform_field, 3, curl_global, rotation_matrix)
plot_all_2d_components(curl_local, center, radius, rotation_matrix, "Curl ")
1 change: 1 addition & 0 deletions requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ idna==3.10
isodate==0.7.2
joblib==1.5.0
lxml==5.4.0
matplotlib==3.10.3
multidict==6.4.3
numpy==2.2.5
packaging==25.0
Expand Down