Spaces:
Running
on
CPU Upgrade
Running
on
CPU Upgrade
import os | |
import gradio as gr | |
import numpy as np | |
import pandas as pd | |
import periodictable | |
import plotly.graph_objs as go | |
import polars as pl | |
from datasets import concatenate_datasets, load_dataset | |
from pymatgen.analysis.phase_diagram import PDPlotter, PhaseDiagram | |
from pymatgen.core import Composition, Element, Structure | |
from pymatgen.core.composition import Composition | |
from pymatgen.entries.computed_entries import ( | |
ComputedStructureEntry, | |
GibbsComputedStructureEntry, | |
) | |
# Hugging Face access token taken from the environment (None when unset).
HF_TOKEN = os.getenv("HF_TOKEN")

# Dataset configurations of LeMat-Bulk that this app can query.
subsets = ["compatible_pbe", "compatible_pbesol", "compatible_scan"]

# Disabled alternative that read the parquet files directly with polars:
# polars_dfs = {
#     subset: pl.read_parquet(
#         "hf://datasets/LeMaterial/LeMat1/{}/train-*.parquet".format(subset),
#         storage_options={
#             "token": HF_TOKEN,
#         },
#     )
#     for subset in subsets
# }

# Keep only the "train" split of every configuration, restricted to the
# columns requested below.
subsets_ds = {}
for subset in subsets:
    dataset = load_dataset(
        "LeMaterial/LeMat-Bulk",
        subset,
        token=HF_TOKEN,
        columns=[
            "lattice_vectors",
            "species_at_sites",
            "cartesian_site_positions",
            "energy",
            "energy_corrected",
            "immutable_id",
            "elements",
            "functional",
        ],
    )
    subsets_ds[subset] = dataset["train"]

# One pandas frame per configuration holding just the "elements" column.
elements_df = {
    name: split.select_columns("elements").to_pandas()
    for name, split in subsets_ds.items()
}
from scipy.sparse import csr_matrix

# Column index of every element symbol, in periodictable's iteration order.
all_elements = {str(el): i for i, el in enumerate(periodictable.elements)}


def _build_element_matrix(element_lists, n_rows):
    """Return a sparse 0/1 membership matrix of shape (n_rows, len(all_elements)).

    Entry (r, c) is 1 when the r-th item of ``element_lists`` contains the
    symbol mapped to column c by ``all_elements``. The matrix is assembled
    straight from index arrays so no dense ``n_rows x n_elements`` array is
    ever allocated.

    Args:
        element_lists: sized, re-iterable collection with ``n_rows`` items,
            each an iterable of element symbols present in ``all_elements``.
        n_rows: number of rows of the resulting matrix.

    Raises:
        KeyError: if a symbol is not a key of ``all_elements``.
    """
    lengths = np.fromiter(
        (len(symbols) for symbols in element_lists), dtype=np.int64, count=n_rows
    )
    row_ids = np.repeat(np.arange(n_rows, dtype=np.int64), lengths)
    col_ids = np.fromiter(
        (all_elements[symbol] for symbols in element_lists for symbol in symbols),
        dtype=np.int64,
        count=int(lengths.sum()),
    )
    matrix = csr_matrix(
        (np.ones(len(row_ids), dtype=np.float64), (row_ids, col_ids)),
        shape=(n_rows, len(all_elements)),
    )
    # The (data, (row, col)) constructor sums duplicate coordinates; reset the
    # stored values so a symbol listed twice still counts as a single 1.
    matrix.data[:] = 1.0
    return matrix


# Sparse element-membership matrix for every dataset configuration.
elements_indices = {}
for subset, df in elements_df.items():
    print("Processing subset: ", subset)
    elements_indices[subset] = _build_element_matrix(df["elements"], len(df))

# UI label of each functional -> dataset configuration name.
map_functional = {
    "PBE": "compatible_pbe",
    "PBESol (No correction scheme)": "compatible_pbesol",
    "SCAN (No correction scheme)": "compatible_scan",
}
def create_phase_diagram(
    elements,
    energy_correction,
    plot_style,
    functional,
    finite_temp,
    **kwargs,
):
    """Build a phase-diagram figure for a chemical system.

    Entries whose element set is contained in the requested system are pulled
    from ``subsets_ds`` (located through ``elements_indices``), converted to
    pymatgen entries and handed to ``PhaseDiagram`` / ``PDPlotter``.

    Args:
        elements: dash-separated element symbols, e.g. ``"Li-Fe-O"``. Blank
            tokens and repeated symbols are ignored.
        energy_correction: ``"Database specific, or MP2020"`` or
            ``"The 110 PBE Method"``; corrections are only applied when
            ``functional`` is ``"PBE"``, otherwise the correction is 0.
        plot_style: ``"2D"``; any other value requests the 3D ternary plot.
        functional: a key of ``map_functional``.
        finite_temp: when truthy, entries are converted with
            ``GibbsComputedStructureEntry.from_entries``.
        **kwargs: accepted for interface compatibility and ignored.

    Returns:
        The figure from ``PDPlotter.get_plot()``, or an otherwise empty
        ``go.Figure`` carrying an annotation that explains why no diagram
        could be produced.
    """

    def message_figure(text):
        # All user-facing failures are reported as an annotated empty figure.
        return go.Figure().add_annotation(text=text)

    # Split on '-', trim whitespace, drop blanks and keep first occurrences only.
    element_list = list(
        dict.fromkeys(
            token.strip() for token in (elements or "").split("-") if token.strip()
        )
    )
    if not element_list:
        return message_figure("Please enter at least one element.")

    unknown = [el for el in element_list if el not in all_elements]
    if unknown:
        return message_figure("Unknown element symbol(s): " + ", ".join(unknown))

    if functional not in map_functional:
        return message_figure("Please select a functional.")
    subset_name = map_functional[functional]

    # Reject unsupported 3D requests before touching the dataset.
    if plot_style != "2D" and len(element_list) != 3:
        return message_figure("3D plots are only available for ternary systems.")

    element_mask = np.zeros(len(all_elements), dtype=bool)
    for el in element_list:
        element_mask[all_elements[el]] = True

    # A row belongs to the system when every element it contains is among the
    # requested ones, i.e. its count inside the mask equals its total count.
    membership = elements_indices[subset_name]
    n_total = np.asarray(membership.sum(axis=1)).ravel()
    n_in_query = np.asarray(membership[:, element_mask].sum(axis=1)).ravel()
    indices_with_only_elements = np.flatnonzero(
        (n_in_query == n_total) & (n_total > 0)
    )
    if indices_with_only_elements.size == 0:
        return message_figure("No entries found for the requested elements.")

    entries_df = subsets_ds[subset_name].select(indices_with_only_elements).to_pandas()
    entries_df = entries_df[~entries_df["immutable_id"].isna()]
    print("Selected {} entries from {}".format(len(entries_df), subset_name))

    # The correction scheme is the same for every row, so decide it once.
    is_pbe = functional == "PBE"
    use_mp = is_pbe and energy_correction == "Database specific, or MP2020"
    use_pbe110 = is_pbe and energy_correction == "The 110 PBE Method"
    if use_mp:
        print("applying MP corrections")
    elif use_pbe110:
        print("applying PBE110 corrections")
    else:
        print("not applying any corrections")

    def get_energy_correction(row):
        if use_mp:
            corrected = row["energy_corrected"]
            # Missing corrected energies (NaN or None) mean "no correction".
            return 0 if pd.isna(corrected) else corrected - row["energy"]
        if use_pbe110:
            return row["energy"] * 1.1 - row["energy"]
        return 0

    entries = [
        ComputedStructureEntry(
            Structure(
                [x.tolist() for x in row["lattice_vectors"].tolist()],
                row["species_at_sites"],
                row["cartesian_site_positions"],
                coords_are_cartesian=True,
            ),
            energy=row["energy"],
            correction=get_energy_correction(row),
            entry_id=row["immutable_id"],
            parameters={"run_type": row["functional"]},
        )
        for _, row in entries_df.iterrows()
    ]
    # TODO: Fetch elemental entries (they are usually GGA calculations)
    # entries.extend([e for e in entries if e.composition.is_element])

    if finite_temp:
        entries = GibbsComputedStructureEntry.from_entries(entries)

    # Build the phase diagram and its plotter; a ValueError from either step
    # is shown to the user instead of crashing the app.
    plotter_options = {"show_unstable": True, "backend": "plotly"}
    if plot_style != "2D":
        plotter_options["ternary_style"] = "3d"
    try:
        phase_diagram = PhaseDiagram(entries)
        plotter = PDPlotter(phase_diagram, **plotter_options)
    except ValueError as e:
        print(e)
        return message_figure(str(e))

    # NOTE: filtering by maximum energy above hull is not implemented here.
    return plotter.get_plot()
# Input widgets for the Gradio interface, created in the order in which
# create_phase_diagram receives their values.
elements_input = gr.Textbox(
    value="Li-Fe-O",
    label="Elements (e.g., 'Li-Fe-O')",
    placeholder="Enter elements separated by '-'",
)

# Disabled: slider for a maximum energy above hull.
# max_e_above_hull_slider = gr.Slider(
#     minimum=0, maximum=1, value=0.1, label="Maximum Energy Above Hull (eV)"
# )

energy_correction_dropdown = gr.Dropdown(
    label="Energy correction",
    choices=["The 110 PBE Method", "Database specific, or MP2020"],
)

plot_style_dropdown = gr.Dropdown(
    label="Plot Style",
    choices=["2D", "3D"],
)

functional_dropdown = gr.Dropdown(
    label="Functional",
    choices=[
        "PBE",
        "PBESol (No correction scheme)",
        "SCAN (No correction scheme)",
    ],
)

finite_temp_toggle = gr.Checkbox(label="Enable Finite Temperature Estimation")
# Caveat about the energy-correction schemes; it opens the app description.
warning_message = (
    "⚠️ This application uses energy correction schemes directly"
    " from the data providers (Alexandria, MP) and has the 2020 MP"
    " Compatibility scheme applied to OQMD. However, because we did"
    " not directly apply the compatibility schemes to Alexandria, MP"
    " we have noticed discrepancies in the data. While the correction"
    " scheme will be standardized in a soon to be released update, for"
    " now please take caution when analyzing the results of this"
    " application."
    "<br> Additionally, we have provided the 110 PBE correction method"
    " from <a href='https://chemrxiv.org/engage/api-gateway/chemrxiv/assets/orp/resource/item/67252d617be152b1d0b2c1ef/original/a-simple-linear-relation-solves-unphysical-dft-energy-corrections.pdf' target='_blank'>Rohr et al (2024)</a>."
)

# Full description: the caveat, a one-line summary, then the acknowledgement.
message = f"{warning_message} <br><br> Generate a phase diagram for a set of elements using LeMat-Bulk data."
message += """
<div style="font-size: 14px; line-height: 1.6; padding: 20px 0">
<p>
This web app is powered by
<a href="https://github.com/materialsproject/crystaltoolkit" target="_blank" style="text-decoration: none;">Crystal Toolkit</a>,
<a href="https://github.com/materialsproject/dash-mp-components" target="_blank" style="text-decoration: none;">MP Dash Components</a>,
and
<a href="https://pymatgen.org/" target="_blank" style="text-decoration: none;">Pymatgen</a>.
All tools are developed by the
<a href="https://next-gen.materialsproject.org/" target="_blank" style="text-decoration: none;">Materials Project</a>.
We are grateful for their open-source software packages. This app is intended for data exploration in LeMat-Bulk and is not affiliated with or endorsed by the Materials Project.
</p>
</div>
"""

# Citation and licensing notes placed below the interface.
footer_content = """
<div style="font-size: 14px; line-height: 1.6; padding: 20px 0; text-align: center;">
<hr style="border-top: 1px solid #ddd; margin: 10px 0;">
<p>
<strong>CC-BY-4.0</strong> requires proper acknowledgement. If you use materials data with an immutable_id starting with <code>mp-</code>, please cite the
<a href="https://pubs.aip.org/aip/apm/article/1/1/011002/119685/Commentary-The-Materials-Project-A-materials" target="_blank" style="text-decoration: none;">Materials Project</a>.
If you use materials data with an immutable_id starting with <code>agm-</code>, please cite
<a href="https://www.science.org/doi/10.1126/sciadv.abi7948" target="_blank" style="text-decoration: none;">Alexandria, PBE</a>
or
<a href="https://hdl.handle.net/10.1038/s41597-022-01177-w" target="_blank" style="text-decoration: none;">Alexandria PBESol, SCAN</a>.
If you use materials data with an immutable_id starting with <code>oqmd-</code>, please cite
<a href="https://link.springer.com/article/10.1007/s11837-013-0755-4" target="_blank" style="text-decoration: none;">OQMD</a>.
</p>
<p>
If you use the Phase Diagram or Crystal Viewer, please acknowledge
<a href="https://github.com/materialsproject/crystaltoolkit" target="_blank" style="text-decoration: none;">Crystal Toolkit</a>.
</p>
</div>
"""
# Assemble the Gradio interface around create_phase_diagram.
# The input order must match the function's positional parameters.
interface_inputs = [
    elements_input,
    # max_e_above_hull_slider,
    energy_correction_dropdown,
    plot_style_dropdown,
    functional_dropdown,
    finite_temp_toggle,
]
phase_diagram_output = gr.Plot(label="Phase Diagram", elem_classes="plot-out")

iface = gr.Interface(
    fn=create_phase_diagram,
    inputs=interface_inputs,
    outputs=phase_diagram_output,
    title="MP Phase Diagram Viewer for LeMat-Bulk",
    description=message,
    article=footer_content,
    css=".plot-out {background-color: #ffffff;}",
)

# Start serving the app.
iface.launch()