From ca94f1794a3143ae433c19141c67cc4191ff6535 Mon Sep 17 00:00:00 2001 From: Clare72 Date: Wed, 10 Jun 2026 17:08:58 +0100 Subject: [PATCH] show subclass results in class connectivity queries --- .../test_downstream_class_connectivity.py | 79 +++- src/test/test_upstream_class_connectivity.py | 80 +++- src/vfbquery/vfb_queries.py | 377 ++++++++++++------ 3 files changed, 398 insertions(+), 138 deletions(-) diff --git a/src/test/test_downstream_class_connectivity.py b/src/test/test_downstream_class_connectivity.py index df7286b..ce701d6 100644 --- a/src/test/test_downstream_class_connectivity.py +++ b/src/test/test_downstream_class_connectivity.py @@ -39,8 +39,9 @@ def test_row_has_expected_keys(self): assert result["rows"], "Expected at least one row" row = result["rows"][0] expected_keys = { - "id", "downstream_class", "total_n", "connected_n", - "percent_connected", "pairwise_connections", "total_weight", "avg_weight", + "id", "query_id", "upstream_class", "downstream_class", + "total_n", "connected_n", "percent_connected", + "pairwise_connections", "total_weight", "avg_weight", } assert expected_keys.issubset(row.keys()) @@ -87,8 +88,9 @@ def test_dataframe_has_expected_columns(self): TEST_CLASS, return_dataframe=True, limit=1, force_refresh=True ) expected_cols = { - "id", "downstream_class", "total_n", "connected_n", - "percent_connected", "pairwise_connections", "total_weight", "avg_weight", + "id", "query_id", "upstream_class", "downstream_class", + "total_n", "connected_n", "percent_connected", + "pairwise_connections", "total_weight", "avg_weight", } assert expected_cols.issubset(set(df.columns)) @@ -129,7 +131,7 @@ def test_parent_class_appears_with_sensible_counts(self, result): """ from vfbquery.vfb_queries import vc, get_dict_cursor - rows = result["rows"] + rows = [r for r in result["rows"] if r["query_id"] == TEST_CLASS] ids = [r["id"] for r in rows] assert ids, "Expected at least one row to test against" @@ -147,7 +149,7 @@ def test_parent_class_appears_with_sensible_counts(self, result): parent_id = pairs[0]["parent"] child_id = pairs[0]["child"] parent_row = next(r for r in rows if r["id"] == parent_id) - # Sum connected_n across all descendant rows (not just the one returned). + # Sum connected_n across all descendant rows. desc_q = ( "MATCH (p:Class {short_form: '%s'})<-[:SUBCLASSOF*1..]-(c:Class) " "WHERE c.short_form IN %s " @@ -169,18 +171,67 @@ def test_parent_class_appears_with_sensible_counts(self, result): ) @pytest.mark.integration - def test_total_n_is_constant_across_rows(self, result): - """`total_n` is the queried-side instance count and must be the same - for every output row (regression for the previous summed-across- - subclasses value). + def test_total_n_constant_within_each_query_class(self, result): + """In the downstream direction the presynaptic side is the queried + class, so (matching VFB_connect's normalization) `total_n` is the + queried (sub)class instance count: constant within each query block (it + varies between blocks), and `connected_n` never exceeds it. """ + from collections import defaultdict + rows = result["rows"] assert rows, "Expected at least one row" - total_ns = {r["total_n"] for r in rows} - assert len(total_ns) == 1, ( - f"Expected total_n to be constant across rows, got: {total_ns}" + by_query = defaultdict(set) + for r in rows: + assert r["connected_n"] <= r["total_n"], ( + f"connected_n={r['connected_n']} > total_n={r['total_n']} " + f"for {r['id']}" + ) + by_query[r["query_id"]].add(r["total_n"]) + for qid, totals in by_query.items(): + assert len(totals) == 1, ( + f"Expected total_n constant within block {qid}, got: {totals}" + ) + assert next(iter(totals)) > 0 + + @pytest.mark.integration + def test_includes_subclass_breakdown(self, result): + """The result should contain the input term's own rows plus a block of + rows for each subclass that has connectivity instances. Any non-input + query_id must be a genuine subclass of the input term. + """ + from vfbquery.vfb_queries import vc, get_dict_cursor + + rows = result["rows"] + query_ids = {r["query_id"] for r in rows} + assert TEST_CLASS in query_ids, "Expected the input term's own rows" + + # Full subclass closure (incl. the input term itself). + q = ( + "MATCH (sub:Class)-[:SUBCLASSOF*0..]->(:Class {short_form: '%s'}) " + "RETURN collect(DISTINCT sub.short_form) AS ids" % TEST_CLASS + ) + subtree_rows = get_dict_cursor()(vc.nc.commit_list([q])) + subtree = set(subtree_rows[0]["ids"]) if subtree_rows else set() + offenders = [q for q in query_ids if q not in subtree] + assert not offenders, ( + f"query_id(s) not in the input term's subclass closure: {offenders}" + ) + + # Subclasses of the input term that have connectivity instances. + sub_q = ( + "MATCH (sub:Class)-[:SUBCLASSOF*1..]->(:Class {short_form: '%s'}) " + "WHERE (sub)<-[:SUBCLASSOF*0..]-(:Class)<-[:INSTANCEOF]-" + "(:Individual:has_neuron_connectivity) " + "RETURN collect(DISTINCT sub.short_form) AS ids" % TEST_CLASS + ) + sub_rows = get_dict_cursor()(vc.nc.commit_list([sub_q])) + connected_subclasses = set(sub_rows[0]["ids"]) if sub_rows else set() + if not connected_subclasses: + pytest.skip("Input term has no connectivity-bearing subclasses") + assert query_ids & connected_subclasses, ( + "Expected subclass breakdown rows but none were present" ) - assert next(iter(total_ns)) > 0 @pytest.mark.integration def test_no_rows_above_neuron_root(self, result): diff --git a/src/test/test_upstream_class_connectivity.py b/src/test/test_upstream_class_connectivity.py index 7cc538b..f6e64b5 100644 --- a/src/test/test_upstream_class_connectivity.py +++ b/src/test/test_upstream_class_connectivity.py @@ -39,8 +39,9 @@ def test_row_has_expected_keys(self): assert result["rows"], "Expected at least one row" row = result["rows"][0] expected_keys = { - "id", "upstream_class", "total_n", "connected_n", - "percent_connected", "pairwise_connections", "total_weight", "avg_weight", + "id", "query_id", "upstream_class", "downstream_class", + "total_n", "connected_n", "percent_connected", + "pairwise_connections", "total_weight", "avg_weight", } assert expected_keys.issubset(row.keys()) @@ -87,8 +88,9 @@ def test_dataframe_has_expected_columns(self): TEST_CLASS, return_dataframe=True, limit=1, force_refresh=True ) expected_cols = { - "id", "upstream_class", "total_n", "connected_n", - "percent_connected", "pairwise_connections", "total_weight", "avg_weight", + "id", "query_id", "upstream_class", "downstream_class", + "total_n", "connected_n", "percent_connected", + "pairwise_connections", "total_weight", "avg_weight", } assert expected_cols.issubset(set(df.columns)) @@ -126,10 +128,13 @@ def test_parent_class_appears_with_sensible_counts(self, result): """A row keyed on a parent class should have connected_n at least as large as any of its descendant rows (set-union semantics) and at most the sum of descendant connected_n. + + Restricted to the input term's own block so partner rows are not mixed + across queried (sub)classes. """ from vfbquery.vfb_queries import vc, get_dict_cursor - rows = result["rows"] + rows = [r for r in result["rows"] if r["query_id"] == TEST_CLASS] ids = [r["id"] for r in rows] assert ids, "Expected at least one row to test against" @@ -166,17 +171,68 @@ def test_parent_class_appears_with_sensible_counts(self, result): ) @pytest.mark.integration - def test_total_n_is_constant_across_rows(self, result): - """`total_n` is the queried-side instance count and must be the same - for every output row. + def test_total_n_is_per_partner(self, result): + """In the upstream direction the presynaptic side is the partner, so + (matching VFB_connect's normalization) `total_n` describes the partner + (`upstream_class`): it must be constant across every row referencing the + same partner id, regardless of which queried (sub)class block it is in, + and `connected_n` must never exceed it. """ + from collections import defaultdict + rows = result["rows"] assert rows, "Expected at least one row" - total_ns = {r["total_n"] for r in rows} - assert len(total_ns) == 1, ( - f"Expected total_n to be constant across rows, got: {total_ns}" + by_partner = defaultdict(set) + for r in rows: + assert r["connected_n"] <= r["total_n"], ( + f"connected_n={r['connected_n']} > total_n={r['total_n']} " + f"for {r['id']}" + ) + by_partner[r["id"]].add(r["total_n"]) + for pid, totals in by_partner.items(): + assert len(totals) == 1, ( + f"total_n varies for partner {pid}: {totals}" + ) + assert next(iter(totals)) > 0 + + @pytest.mark.integration + def test_includes_subclass_breakdown(self, result): + """The result should contain the input term's own rows plus a block of + rows for each subclass that has connectivity instances. Any non-input + query_id must be a genuine subclass of the input term. + """ + from vfbquery.vfb_queries import vc, get_dict_cursor + + rows = result["rows"] + query_ids = {r["query_id"] for r in rows} + assert TEST_CLASS in query_ids, "Expected the input term's own rows" + + # Full subclass closure (incl. the input term itself). + q = ( + "MATCH (sub:Class)-[:SUBCLASSOF*0..]->(:Class {short_form: '%s'}) " + "RETURN collect(DISTINCT sub.short_form) AS ids" % TEST_CLASS + ) + subtree_rows = get_dict_cursor()(vc.nc.commit_list([q])) + subtree = set(subtree_rows[0]["ids"]) if subtree_rows else set() + offenders = [q for q in query_ids if q not in subtree] + assert not offenders, ( + f"query_id(s) not in the input term's subclass closure: {offenders}" + ) + + # Subclasses of the input term that have connectivity instances. + sub_q = ( + "MATCH (sub:Class)-[:SUBCLASSOF*1..]->(:Class {short_form: '%s'}) " + "WHERE (sub)<-[:SUBCLASSOF*0..]-(:Class)<-[:INSTANCEOF]-" + "(:Individual:has_neuron_connectivity) " + "RETURN collect(DISTINCT sub.short_form) AS ids" % TEST_CLASS + ) + sub_rows = get_dict_cursor()(vc.nc.commit_list([sub_q])) + connected_subclasses = set(sub_rows[0]["ids"]) if sub_rows else set() + if not connected_subclasses: + pytest.skip("Input term has no connectivity-bearing subclasses") + assert query_ids & connected_subclasses, ( + "Expected subclass breakdown rows but none were present" ) - assert next(iter(total_ns)) > 0 @pytest.mark.integration def test_no_rows_above_neuron_root(self, result): diff --git a/src/vfbquery/vfb_queries.py b/src/vfbquery/vfb_queries.py index b32f337..8657b35 100644 --- a/src/vfbquery/vfb_queries.py +++ b/src/vfbquery/vfb_queries.py @@ -1597,14 +1597,18 @@ def DownstreamClassConnectivity_to_schema(name, take_default): Matching criteria: Class + Neuron Implementation: multi-step aggregation, not a single Solr lookup. - 1. Neo4j: instances in the SUBCLASSOF closure of the queried class. + 1. Neo4j: the queried class plus each subclass that has connectivity + instances, with the instances in each one's SUBCLASSOF closure. 2. Solr cache (batched): per-instance synaptic partners. 3. Solr: direct partner classes from the downstream_connectivity_query field (seed set for the partner-side ancestor walk). 4. Neo4j: walk SUBCLASSOF up from each direct partner to the neuron root. 5. Neo4j (batched): partner_instance -> {class_ids} membership map. 6. In-memory aggregation with set-union semantics to handle FBbt - multi-inheritance without double-counting. + multi-inheritance without double-counting, emitted as a separate row + block per queried (sub)class (input term first); the queried (sub)class + fills the ``upstream_class`` slot (downstream query) or + ``downstream_class`` slot (upstream query) of the v2 layout. Results are cached server-side (@with_solr_cache) per queried class, so repeat calls return in milliseconds, but cold calls on broad classes can @@ -1618,7 +1622,7 @@ def DownstreamClassConnectivity_to_schema(name, take_default): "default": take_default, } preview = 5 - preview_columns = ["downstream_class", "total_n", "connected_n", "percent_connected", "pairwise_connections", "total_weight", "avg_weight"] + preview_columns = ["upstream_class", "downstream_class", "total_n", "connected_n", "percent_connected", "pairwise_connections", "total_weight", "avg_weight"] return Query(query=query, label=label, function=function, takes=takes, preview=preview, preview_columns=preview_columns) @@ -1641,7 +1645,7 @@ def UpstreamClassConnectivity_to_schema(name, take_default): "default": take_default, } preview = 5 - preview_columns = ["upstream_class", "total_n", "connected_n", "percent_connected", "pairwise_connections", "total_weight", "avg_weight"] + preview_columns = ["upstream_class", "downstream_class", "total_n", "connected_n", "percent_connected", "pairwise_connections", "total_weight", "avg_weight"] return Query(query=query, label=label, function=function, takes=takes, preview=preview, preview_columns=preview_columns) @@ -3350,24 +3354,31 @@ def get_neuron_region_connectivity(short_form: str, return_dataframe=True, limit } -def _fetch_connectivity_entries(short_form: str, solr_field: str): +def _fetch_connectivity_entries(short_form: str, solr_field: str, subclass_ids=None): """Fetch connectivity entries from Solr for a neuron class and all its OWLERY subclasses. Returns a flat list of parsed JSON entries (dicts) from the Solr connectivity field, collected across every subclass doc. + + ``subclass_ids`` may be a pre-resolved subclass set (the queried class plus + its Owlery subclass closure) to avoid re-querying Owlery when the caller has + already computed it; when omitted it is fetched here. """ - # Step 1: OWLERY subclass expansion (includes the class itself) - owl_query = f"<{short_form}>" - try: - subclass_ids = vc.vfb.oc.get_subclasses( - query=owl_query, query_by_label=False, verbose=False - ) - except Exception as e: - print(f"Owlery subclass query failed for {short_form}: {e}") - subclass_ids = [] + # Step 1: OWLERY subclass expansion (includes the class itself). Use the + # caller-supplied set when given, otherwise resolve it via Owlery. + if subclass_ids is None: + owl_query = f"<{short_form}>" + try: + subclass_ids = vc.vfb.oc.get_subclasses( + query=owl_query, query_by_label=False, verbose=False + ) + except Exception as e: + print(f"Owlery subclass query failed for {short_form}: {e}") + subclass_ids = [] - # Always include the queried class itself + # Always include the queried class itself; normalise to a list. + subclass_ids = list(subclass_ids) if short_form not in subclass_ids: subclass_ids.insert(0, short_form) @@ -3548,43 +3559,84 @@ def _bulk_fetch_per_instance_connectivity(instance_ids): def _aggregate_class_connectivity(short_form, direction, neuron_root=NEURON_ROOT_SHORT_FORM): - """Aggregate class-level partner connectivity correctly under FBbt - multi-inheritance, using set-union over instance memberships. + """Aggregate class-level partner connectivity for the queried class AND + each of its subclasses individually, correctly under FBbt + multi-inheritance using set-union over instance memberships. ``direction`` is ``'downstream'`` (partner = downstream of queried class) - or ``'upstream'``. Returns a list of row dicts with the same fields the - previous summation-based implementation produced. + or ``'upstream'``. Returns a flat list of row dicts; every row is tagged + with the queried (sub)class it belongs to via ``query_id`` / + ``_query_label``. The input term's own rows (aggregated over its full + instance population, exactly as before) come first, followed by a block of + rows for each subclass that has connectivity instances, ordered by class id. + + The expensive pieces (per-instance edges, partner-side hierarchy and + membership) are computed once for the whole subtree and instances are then + partitioned by queried (sub)class, so cost is roughly independent of the + number of subclasses. """ from collections import defaultdict - # 1. Queried-side instances (subclass closure via Neo4j — Owlery's - # get_instances has been observed to hang for some classes, while a - # SUBCLASSOF traversal in Cypher is fast and equivalent here). - queried_q = ( - "MATCH (n:Individual:has_neuron_connectivity)-[:INSTANCEOF]->" - "(:Class)-[:SUBCLASSOF*0..]->(:Class {short_form: '%s'}) " - "RETURN DISTINCT n.short_form AS sf" % short_form + # 1a. Queried (sub)classes in scope: the input term plus every subclass. + # Reuse Owlery's reasoner subclass closure (the canonical subclass set + # used throughout VFBquery — get_instances, _fetch_connectivity_entries + # — and effectively cached) rather than a fresh Neo4j SUBCLASSOF + # traversal. The same set seeds the partner-fetch in step 3 below, so + # it is computed once here. Owlery excludes the queried class itself, so + # add it back. + try: + owl_query = f"<{short_form}>" + subclass_ids = vc.vfb.oc.get_subclasses( + query=owl_query, query_by_label=False, verbose=False + ) + except Exception as e: + print(f"Owlery subclass query failed for {short_form}: {e}") + subclass_ids = [] + query_class_ids = {short_form, *(subclass_ids or [])} + + # 1b. queried (sub)class -> its instances (SUBCLASSOF closure), with labels. + # The proven anchored membership query (single variable-length walk + # bounded by ``WHERE ... IN [ids]``) returns the instances AND the label + # for every queried (sub)class that actually has connectivity instances + # — which is exactly the set of blocks we emit — so no separate label + # lookup or subtree query is needed. Classes with no instances simply + # don't come back. + membership_q = ( + "MATCH (c:Class)<-[:SUBCLASSOF*0..]-(:Class)<-[:INSTANCEOF]-" + "(n:Individual:has_neuron_connectivity) " + "WHERE c.short_form IN %s " + "RETURN c.short_form AS cid, c.label AS label, " + "collect(DISTINCT n.short_form) AS iids" % sorted(query_class_ids) ) try: - results = vc.nc.commit_list([queried_q]) - rows = get_dict_cursor()(results) - queried_instances = [r['sf'] for r in rows if r.get('sf')] + rows = get_dict_cursor()(vc.nc.commit_list([membership_q])) except Exception as e: - print(f"Queried-side instance query failed for {short_form}: {e}") + print(f"Queried-side membership query failed for {short_form}: {e}") return [] - if not queried_instances: + query_class_to_instances = defaultdict(set) + query_labels = {} + all_instances = set() + for r in rows: + cid = r.get('cid') + iids = set(r.get('iids') or []) + if not cid or not iids: + continue + query_class_to_instances[cid] = iids + query_labels[cid] = r.get('label') or cid + all_instances.update(iids) + if not query_class_to_instances: return [] - queried_instance_set = set(queried_instances) - total_n_queried = len(queried_instance_set) + query_labels.setdefault(short_form, short_form) - # 2. Per-instance edges from cache. Cache misses are skipped with a warning; - # the resulting connected_n / pairwise / total_weight will be a slight - # underestimate when this happens. - found_edges, missing = _bulk_fetch_per_instance_connectivity(queried_instances) + # 2. Per-instance edges from cache (once for the whole subtree). Cache + # misses are skipped with a warning; the resulting connected_n / + # pairwise / total_weight will be a slight underestimate when this + # happens. + found_edges, missing = _bulk_fetch_per_instance_connectivity(all_instances) if missing: print( f"Warning: per-instance connectivity cache missing for " - f"{len(missing)}/{total_n_queried} instances of {short_form}; " + f"{len(missing)}/{len(all_instances)} instances under {short_form}; " f"those will be skipped (results may be a slight underestimate)." ) if not found_edges: @@ -3593,13 +3645,15 @@ def _aggregate_class_connectivity(short_form, direction, weight_key = 'outputs' if direction == 'downstream' else 'inputs' # 3. Direct partner classes from the existing class-level connectivity - # field (already cached) — used as the seed set for the partner-side - # ancestor walk. + # field (already cached, unioned across the input term's subclass docs) + # — used as the seed set for the partner-side ancestor walk. Reuse the + # subclass set already resolved in step 1a rather than re-querying Owlery. solr_field = ( 'downstream_connectivity_query' if direction == 'downstream' else 'upstream_connectivity_query' ) - class_entries = _fetch_connectivity_entries(short_form, solr_field) + class_entries = _fetch_connectivity_entries( + short_form, solr_field, subclass_ids=query_class_ids) direct_partner_ids = set() for entry in class_entries: obj = entry.get('object', {}) @@ -3615,63 +3669,123 @@ def _aggregate_class_connectivity(short_form, direction, return [] # 5. Build partner_instance_id -> {class_ids it belongs to}, restricted - # to in-scope partner classes. - instance_to_classes = _build_partner_instance_class_membership(partner_class_ids) - - # 6. Aggregate edges into per-class buckets via set-union semantics. - buckets = defaultdict(lambda: { - 'edges': set(), 'weight_sum': 0.0, 'connected_n1': set(), - }) - for n1, partner_rows in found_edges.items(): - if n1 not in queried_instance_set: - continue - for prow in partner_rows or []: - n2 = prow.get('id') - w = prow.get(weight_key) - if not n2 or not w: - continue - try: - w_num = float(w) - except (TypeError, ValueError): - continue - if w_num <= 0: - continue - for c in instance_to_classes.get(n2, ()): - b = buckets[c] - b['edges'].add((n1, n2)) - b['weight_sum'] += w_num - b['connected_n1'].add(n1) - - # 7. Emit one row per partner class that received at least one edge. - rows = [] - for cid, b in buckets.items(): - pw = len(b['edges']) - cn = len(b['connected_n1']) - tw = b['weight_sum'] - pct = round((cn / total_n_queried) * 100) if total_n_queried else 0 - avg = (tw / pw) if pw else 0 - label = class_labels.get(cid, cid) - rows.append({ - 'id': cid, - '_label': label, - 'total_n': total_n_queried, - 'connected_n': cn, - 'percent_connected': pct, - 'pairwise_connections': pw, - 'total_weight': tw, - 'avg_weight': avg, + # to in-scope partner classes. The helper already returns this + # instance -> {classes} mapping, so it is used directly. From it we also + # derive the total instance count per partner class (with SUBCLASSOF + # closure), which is the denominator when the partner is the presynaptic + # side (the upstream direction — see VFB_connect parity below). + instance_to_partner_classes = _build_partner_instance_class_membership(partner_class_ids) + partner_class_total = defaultdict(int) + _partner_class_members = defaultdict(set) + for iid, classes in instance_to_partner_classes.items(): + for c in classes: + _partner_class_members[c].add(iid) + for c, members in _partner_class_members.items(): + partner_class_total[c] = len(members) + + # 6. Aggregate edges into per-(partner-class) buckets via set-union + # semantics, separately for each queried (sub)class. + # + # Normalization matches VFB_connect's ``get_connected_neurons_by_type``: + # ``total_n`` / ``connected_n`` describe the PRESYNAPTIC (source) side of + # each connection (the column names stay as the v2 frontend expects). + # - downstream direction: queried class -> partner, so the presynaptic + # side is the queried (sub)class. ``total_n`` is the queried-class + # instance count (constant within the block) and ``connected_n`` + # counts queried instances that connect. + # - upstream direction: partner -> queried class, so the presynaptic + # side is the partner class. ``total_n`` is the partner class + # instance count (varies per partner row) and ``connected_n`` counts + # partner instances that connect. + queried_is_presynaptic = (direction == 'downstream') + + def block_for(query_id): + instances = query_class_to_instances.get(query_id) or set() + if not instances: + return [] + total_queried = len(instances) + buckets = defaultdict(lambda: { + 'edges': set(), 'weight_sum': 0.0, + 'connected_queried': set(), 'connected_partner': set(), }) + for n1 in instances: + for prow in found_edges.get(n1) or []: + n2 = prow.get('id') + w = prow.get(weight_key) + if not n2 or not w: + continue + try: + w_num = float(w) + except (TypeError, ValueError): + continue + if w_num <= 0: + continue + for c in instance_to_partner_classes.get(n2, ()): + b = buckets[c] + b['edges'].add((n1, n2)) + b['weight_sum'] += w_num + b['connected_queried'].add(n1) + b['connected_partner'].add(n2) + block = [] + for cid, b in buckets.items(): + pw = len(b['edges']) + tw = b['weight_sum'] + if queried_is_presynaptic: + total = total_queried + connected = len(b['connected_queried']) + else: + total = partner_class_total.get(cid, 0) + connected = len(b['connected_partner']) + pct = round((connected / total) * 100) if total else 0 + avg = (tw / pw) if pw else 0 + block.append({ + 'id': cid, + '_label': class_labels.get(cid, cid), + 'query_id': query_id, + '_query_label': query_labels.get(query_id, query_id), + 'total_n': total, + 'connected_n': connected, + 'percent_connected': pct, + 'pairwise_connections': pw, + 'total_weight': tw, + 'avg_weight': avg, + }) + block.sort(key=lambda r: r.get('pairwise_connections', 0), reverse=True) + return block + + # 7. Input term first, then one block per subclass (ordered by class id). + rows = block_for(short_form) + subclass_ids = sorted( + cid for cid in query_class_to_instances if cid != short_form + ) + for cid in subclass_ids: + rows.extend(block_for(cid)) return rows -def _format_class_connectivity_rows(rows, partner_key): - """Add the markdown-link partner column expected by callers and drop the - internal ``_label`` field.""" +def _format_class_connectivity_rows(rows, partner_key, query_key): + """Populate both markdown-link class columns expected by the v2 layout and + drop the internal ``_label`` / ``_query_label`` fields. + + ``partner_key`` (``'downstream_class'`` or ``'upstream_class'``) receives the + partner class; ``query_key`` (the other of the two) receives the queried + (sub)class this row belongs to. Reusing the existing + ``upstream_class`` / ``downstream_class`` slots avoids adding a column and + lets the per-subclass breakdown show the actual (sub)class per row instead + of the constant the processor used to synthesise. + + ``query_id`` is retained (not displayed) so callers can group rows by the + queried (sub)class without parsing the markdown link. + """ out = [] for r in rows: formatted = dict(r) - label = formatted.pop('_label', formatted['id']) - formatted[partner_key] = f"[{label}]({formatted['id']})" + partner_label = formatted.pop('_label', formatted['id']) + formatted[partner_key] = f"[{partner_label}]({formatted['id']})" + query_id = formatted.get('query_id') + query_label = formatted.pop('_query_label', query_id) + if query_id is not None: + formatted[query_key] = f"[{query_label or query_id}]({query_id})" out.append(formatted) return out @@ -3679,7 +3793,8 @@ def _format_class_connectivity_rows(rows, partner_key): @with_solr_cache('downstream_class_connectivity_query') def get_downstream_class_connectivity(short_form: str, return_dataframe=True, limit: int = -1): """ - Retrieves downstream connectivity classes for the specified neuron class. + Retrieves downstream connectivity classes for the specified neuron class + AND, as separate row blocks, for each of its subclasses. Uses a Neo4j SUBCLASSOF traversal to enumerate instances of the queried class (Owlery's get_instances was observed to hang for some classes; @@ -3690,6 +3805,20 @@ class (Owlery's get_instances was observed to hang for some classes; to a child class also count toward each ancestor class's row, without double-counting under FBbt multi-inheritance. + Every row carries both the ``upstream_class`` and ``downstream_class`` + columns of the v2 layout: ``downstream_class`` is the partner and + ``upstream_class`` is the queried (sub)class this row belongs to. The input + term's rows come first, followed by a block of rows for each subclass that + has connectivity instances, so the queried-side column shows the actual + (sub)class per row rather than a single constant. + + Counts use VFB_connect's normalization (``get_connected_neurons_by_type``): + ``total_n`` / ``connected_n`` describe the PRESYNAPTIC (source) side. For the + downstream direction the queried class is presynaptic, so ``total_n`` is the + queried (sub)class instance count (constant within each block) and + ``connected_n`` is the number of those instances that connect to the + partner. ``percent_connected`` = connected_n / total_n. + Server-side cached via ``@with_solr_cache``; cold calls on broad classes can take tens of seconds because of the aggregation work (already batched across Solr/Neo4j round-trips). @@ -3707,8 +3836,11 @@ class (Owlery's get_instances was observed to hang for some classes; return pd.DataFrame() return {'headers': {}, 'rows': [], 'count': 0} - rows = _format_class_connectivity_rows(rows, partner_key='downstream_class') - rows.sort(key=lambda r: r.get('pairwise_connections', 0), reverse=True) + # Rows arrive grouped by queried (sub)class (input term first) and sorted by + # pairwise_connections within each group; preserve that order. The partner is + # the downstream class; the queried (sub)class fills the upstream_class slot. + rows = _format_class_connectivity_rows( + rows, partner_key='downstream_class', query_key='upstream_class') total_count = len(rows) if limit != -1: @@ -3716,18 +3848,19 @@ class (Owlery's get_instances was observed to hang for some classes; if return_dataframe: df = pd.DataFrame(rows) - df = encode_markdown_links(df, ['downstream_class']) + df = encode_markdown_links(df, ['upstream_class', 'downstream_class']) return df headers = { 'id': {'title': 'ID', 'type': 'selection_id', 'order': -1}, - 'downstream_class': {'title': 'Downstream Class', 'type': 'markdown', 'order': 0}, - 'total_n': {'title': 'Total N', 'type': 'number', 'order': 1}, - 'connected_n': {'title': 'Connected N', 'type': 'number', 'order': 2}, - 'percent_connected': {'title': '% Connected', 'type': 'number', 'order': 3}, - 'pairwise_connections': {'title': 'Pairwise Connections', 'type': 'number', 'order': 4}, - 'total_weight': {'title': 'Total Weight', 'type': 'number', 'order': 5}, - 'avg_weight': {'title': 'Avg Weight', 'type': 'number', 'order': 6}, + 'upstream_class': {'title': 'Upstream Class', 'type': 'markdown', 'order': 0}, + 'downstream_class': {'title': 'Downstream Class', 'type': 'markdown', 'order': 1}, + 'total_n': {'title': 'Total N', 'type': 'number', 'order': 2}, + 'connected_n': {'title': 'Connected N', 'type': 'number', 'order': 3}, + 'percent_connected': {'title': '% Connected', 'type': 'number', 'order': 4}, + 'pairwise_connections': {'title': 'Pairwise Connections', 'type': 'number', 'order': 5}, + 'total_weight': {'title': 'Total Weight', 'type': 'number', 'order': 6}, + 'avg_weight': {'title': 'Avg Weight', 'type': 'number', 'order': 7}, } return {'headers': headers, 'rows': rows, 'count': total_count} @@ -3735,7 +3868,8 @@ class (Owlery's get_instances was observed to hang for some classes; @with_solr_cache('upstream_class_connectivity_query') def get_upstream_class_connectivity(short_form: str, return_dataframe=True, limit: int = -1): """ - Retrieves upstream connectivity classes for the specified neuron class. + Retrieves upstream connectivity classes for the specified neuron class + AND, as separate row blocks, for each of its subclasses. Same multi-step aggregation as ``get_downstream_class_connectivity`` but walking the upstream side: Neo4j SUBCLASSOF enumerates queried-class @@ -3743,6 +3877,21 @@ def get_upstream_class_connectivity(short_form: str, return_dataframe=True, limi partner-side hierarchy is walked up to ``NEURON_ROOT_SHORT_FORM`` with set-union semantics to avoid double-counting under FBbt multi-inheritance. + Every row carries both the ``upstream_class`` and ``downstream_class`` + columns of the v2 layout: ``upstream_class`` is the partner and + ``downstream_class`` is the queried (sub)class this row belongs to. The + input term's rows come first, followed by a block of rows for each subclass + that has connectivity instances, so the queried-side column shows the actual + (sub)class per row rather than a single constant. + + Counts use VFB_connect's normalization (``get_connected_neurons_by_type``): + ``total_n`` / ``connected_n`` describe the PRESYNAPTIC (source) side. For the + upstream direction the partner is presynaptic, so ``total_n`` is the partner + (``upstream_class``) instance count — it varies per partner row, NOT per + queried (sub)class block — and ``connected_n`` is the number of partner + instances that connect to the queried (sub)class. ``percent_connected`` = + connected_n / total_n. + Server-side cached via ``@with_solr_cache``; cold calls on broad classes can take tens of seconds because of the aggregation work (already batched across Solr/Neo4j round-trips). @@ -3760,8 +3909,11 @@ def get_upstream_class_connectivity(short_form: str, return_dataframe=True, limi return pd.DataFrame() return {'headers': {}, 'rows': [], 'count': 0} - rows = _format_class_connectivity_rows(rows, partner_key='upstream_class') - rows.sort(key=lambda r: r.get('pairwise_connections', 0), reverse=True) + # Rows arrive grouped by queried (sub)class (input term first) and sorted by + # pairwise_connections within each group; preserve that order. The partner is + # the upstream class; the queried (sub)class fills the downstream_class slot. + rows = _format_class_connectivity_rows( + rows, partner_key='upstream_class', query_key='downstream_class') total_count = len(rows) if limit != -1: @@ -3769,18 +3921,19 @@ def get_upstream_class_connectivity(short_form: str, return_dataframe=True, limi if return_dataframe: df = pd.DataFrame(rows) - df = encode_markdown_links(df, ['upstream_class']) + df = encode_markdown_links(df, ['upstream_class', 'downstream_class']) return df headers = { 'id': {'title': 'ID', 'type': 'selection_id', 'order': -1}, 'upstream_class': {'title': 'Upstream Class', 'type': 'markdown', 'order': 0}, - 'total_n': {'title': 'Total N', 'type': 'number', 'order': 1}, - 'connected_n': {'title': 'Connected N', 'type': 'number', 'order': 2}, - 'percent_connected': {'title': '% Connected', 'type': 'number', 'order': 3}, - 'pairwise_connections': {'title': 'Pairwise Connections', 'type': 'number', 'order': 4}, - 'total_weight': {'title': 'Total Weight', 'type': 'number', 'order': 5}, - 'avg_weight': {'title': 'Avg Weight', 'type': 'number', 'order': 6}, + 'downstream_class': {'title': 'Downstream Class', 'type': 'markdown', 'order': 1}, + 'total_n': {'title': 'Total N', 'type': 'number', 'order': 2}, + 'connected_n': {'title': 'Connected N', 'type': 'number', 'order': 3}, + 'percent_connected': {'title': '% Connected', 'type': 'number', 'order': 4}, + 'pairwise_connections': {'title': 'Pairwise Connections', 'type': 'number', 'order': 5}, + 'total_weight': {'title': 'Total Weight', 'type': 'number', 'order': 6}, + 'avg_weight': {'title': 'Avg Weight', 'type': 'number', 'order': 7}, } return {'headers': headers, 'rows': rows, 'count': total_count}