Hi All,
I’m using scib-metrics (Benchmarking large-scale integration — scib-metrics), and I can get results for every metric except PC regression (`pcr_comparison`).
The command I run is:
bm = Benchmarker(
adata,
batch_key="sample",
label_key="cell_type",
embedding_obsm_keys=["Unintegrated", "scANVI", "scVI"],
pre_integrated_embedding_obsm_key="X_pca",
bio_conservation_metrics=biocons,
n_jobs=-1,
)
bm.prepare(neighbor_computer=faiss_brute_force_nn)
bm.benchmark()
The error I get is:
ValueError Traceback (most recent call last)
Cell In[45], line 34
24 bm = Benchmarker(
25 adata = adata_scanvi,
26 batch_key=batch_key,
(...)
31 n_jobs=-1,
32 )
33 bm.prepare(neighbor_computer=faiss_brute_force_nn)
---> 34 bm.benchmark()
35 end = time.time()
36 print(f"Time: {int((end - start) / 60)} min {int((end - start) % 60)} sec")
File /opt/conda/lib/python3.10/site-packages/scib_metrics/benchmark/_core.py:226, in Benchmarker.benchmark(self)
223 if isinstance(use_metric_or_kwargs, dict):
224 # Kwargs in this case
225 metric_fn = partial(metric_fn, **use_metric_or_kwargs)
--> 226 metric_value = getattr(MetricAnnDataAPI, metric_name)(ad, metric_fn)
227 # nmi/ari metrics return a dict
228 if isinstance(metric_value, dict):
File /opt/conda/lib/python3.10/site-packages/scib_metrics/benchmark/_core.py:91, in MetricAnnDataAPI.<lambda>(ad, fn)
89 graph_connectivity = lambda ad, fn: fn(ad.obsp["15_distances"], ad.obs[_LABELS])
90 silhouette_batch = lambda ad, fn: fn(ad.X, ad.obs[_LABELS], ad.obs[_BATCH])
---> 91 pcr_comparison = lambda ad, fn: fn(ad.obsm[_X_PRE], ad.X, ad.obs[_BATCH], categorical=True)
92 ilisi_knn = lambda ad, fn: fn(ad.obsp["90_distances"], ad.obs[_BATCH])
93 kbet_per_label = lambda ad, fn: fn(ad.obsp["50_connectivities"], ad.obs[_BATCH], ad.obs[_LABELS])
File /opt/conda/lib/python3.10/site-packages/scib_metrics/_pcr_comparison.py:43, in pcr_comparison(X_pre, X_post, covariate, scale, **kwargs)
40 if covariate.shape[0] != X_pre.shape[0]:
41 raise ValueError("Dimension mismatch: `X_pre` and `covariate` must have the same number of samples.")
---> 43 pcr_pre = principal_component_regression(X_pre, covariate, **kwargs)
44 pcr_post = principal_component_regression(X_post, covariate, **kwargs)
46 if scale:
File /opt/conda/lib/python3.10/site-packages/scib_metrics/utils/_pcr.py:49, in principal_component_regression(X, covariate, categorical, n_components)
46 else:
47 covariate = np.asarray(covariate)
---> 49 covariate = one_hot(covariate) if categorical else covariate.reshape((covariate.shape[0], 1))
51 pca_results = pca(X, n_components=n_components)
53 # Center inputs for no intercept
File /opt/conda/lib/python3.10/site-packages/scib_metrics/utils/_utils.py:37, in one_hot(y, n_classes)
22 """One-hot encode an array. Wrapper around :func:`~jax.nn.one_hot`.
23
24 Parameters
(...)
34 Array of shape (n_cells, n_classes).
35 """
36 n_classes = n_classes or jnp.max(y) + 1
---> 37 return nn.one_hot(jnp.ravel(y), n_classes)
File /opt/conda/lib/python3.10/site-packages/jax/_src/nn/functions.py:464, in one_hot(x, num_classes, dtype, axis)
438 """One-hot encodes the given indices.
439
440 Each index in the input ``x`` is encoded as a vector of zeros of length
(...)
459 computed.
460 """
461 num_classes = core.concrete_dim_or_error(
462 num_classes,
463 "The error arose in jax.nn.one_hot argument `num_classes`.")
--> 464 return _one_hot(x, num_classes, dtype=dtype, axis=axis)
ValueError: Non-hashable static arguments are not supported. An error occurred during a call to '_one_hot' while trying to hash an object of type <class 'jaxlib.xla_extension.ArrayImpl'>, 9. The error was:
TypeError: unhashable type: 'ArrayImpl'
Has anyone else encountered this, and found a fix they can share?
Thanks!