VyLala committed · Commit a48b71a · verified · 1 Parent(s): 687a800

Update confidence_score.py

Files changed (1): confidence_score.py (+262 -262)
confidence_score.py CHANGED
from typing import Dict, Any, Tuple, List, Optional
import standardize_location
def set_rules() -> Dict[str, Any]:
    """
    Define weights, penalties and thresholds for the confidence score.

    V1 principles:
    - Interpretability > mathematical purity
    - Conservative > aggressive
    - Explainable > comprehensive
    """
    return {
        "direct_evidence": {
            # Based on the table we discussed:
            # Accession explicitly linked to country in paper/supplement
            "explicit_geo_pubmed_text": 40,
            # PubMed ID exists AND geo_loc_name exists
            "geo_and_pubmed": 30,
            # geo_loc_name exists (GenBank only)
            "geo_only": 20,
            # accession appears in external text but no structured geo_loc_name
            "accession_in_text_only": 10,
        },
        "consistency": {
            # Predicted country matches GenBank field
            "match": 20,
            # No contradiction detected across sources (when some evidence exists)
            "no_contradiction": 10,
            # Clear contradiction detected between prediction and GenBank
            "contradiction": -30,
        },
        "evidence_density": {
            # ≥2 linked publications
            "two_or_more_pubs": 20,
            # 1 linked publication
            "one_pub": 10,
            # 0 publications
            "none": 0,
        },
        "risk_penalties": {
            # Missing key metadata fields (geo, host, collection_date, etc.)
            "missing_key_fields": -10,
            # Known failure accession pattern (from your existing bug list)
            "known_failure_pattern": -20,
        },
        "tiers": {
            # Confidence tiers (researchers think in categories, not decimals)
            "high_min": 70,
            "medium_min": 40,  # >= medium_min and < high_min = medium; rest = low
        },
    }
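# Worked example of the weights above: geo_loc_name + PubMed + accession-in-text
# (40) with a matching country (20) and ≥2 publications (20) totals 80 -> "high".
# The same row with a country contradiction scores 40 - 30 + 20 = 30 -> "low",
# so a single hard conflict is enough to sink an otherwise strong record.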
def normalize_country(name: Optional[str]) -> Optional[str]:
    """
    Normalize country names to improve simple equality checks.

    This is intentionally simple and rule-based.
    You can extend the mapping as you see real-world variants.
    """
    if not name:
        return None
    name = name.strip().lower()

    mapping = {
        "usa": "united states",
        "u.s.a.": "united states",
        "u.s.": "united states",
        "us": "united states",
        "united states of america": "united states",
        "uk": "united kingdom",
        "u.k.": "united kingdom",
        "england": "united kingdom",
        # Add more mappings here as you encounter them in real data
    }

    return mapping.get(name, name)
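# For orientation (not executed): normalize_country("U.S.A.") -> "united states",
# while an unmapped value such as "Deutschland" passes through lowercased as
# "deutschland", so equality checks stay conservative for unknown names.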
def compute_confidence_score_and_tier(
    signals: Dict[str, Any],
    rules: Optional[Dict[str, Any]] = None,
) -> Tuple[int, str, List[str]]:
    """
    Compute confidence score and tier for a single accession row.

    Input `signals` dict is expected to contain:

        has_geo_loc_name: bool
        has_pubmed: bool
        accession_found_in_text: bool  # accession present in extracted external text
        predicted_country: str | None  # final model label / country prediction
        genbank_country: str | None    # from NCBI / GenBank metadata
        num_publications: int
        missing_key_fields: bool
        known_failure_pattern: bool

    Returns:
        score (0–100), tier ("high"/"medium"/"low"),
        explanations (list of short human-readable reasons)
    """
    if rules is None:
        rules = set_rules()

    score = 0
    explanations: List[str] = []

    # ---------- Signal 1: Direct evidence strength ----------
    has_geo = bool(signals.get("has_geo_loc_name"))
    has_pubmed = bool(signals.get("has_pubmed"))
    accession_in_text = bool(signals.get("accession_found_in_text"))

    direct_cfg = rules["direct_evidence"]

    # We pick the strongest applicable case.
    if has_geo and has_pubmed and accession_in_text:
        score += direct_cfg["explicit_geo_pubmed_text"]
        explanations.append(
            "Accession linked to a country in GenBank and associated publication text."
        )
    elif has_geo and has_pubmed:
        score += direct_cfg["geo_and_pubmed"]
        explanations.append(
            "GenBank geo_loc_name and linked publication found."
        )
    elif has_geo:
        score += direct_cfg["geo_only"]
        explanations.append("GenBank geo_loc_name present.")
    elif accession_in_text:
        score += direct_cfg["accession_in_text_only"]
        explanations.append("Accession keyword found in extracted external text.")

    # ---------- Signal 2: Cross-source consistency ----------
    # Both country fields may be None (see docstring), so guard before calling
    # .lower(); fall back to normalize_country() when the standardize_location
    # lookup reports "not found".
    pred_raw = signals.get("predicted_country")
    pred_country = (
        standardize_location.smart_country_lookup(pred_raw.lower())
        if pred_raw else None
    )
    if pred_country == "not found":
        pred_country = normalize_country(pred_raw)

    gb_raw = signals.get("genbank_country")
    gb_country = (
        standardize_location.smart_country_lookup(gb_raw.lower())
        if gb_raw else None
    )
    if gb_country == "not found":
        gb_country = normalize_country(gb_raw)

    cons_cfg = rules["consistency"]

    if gb_country is not None and pred_country is not None:
        if gb_country == pred_country:
            score += cons_cfg["match"]
            explanations.append(
                "Predicted country matches GenBank country metadata."
            )
        else:
            score += cons_cfg["contradiction"]
            explanations.append(
                "Conflict between predicted country and GenBank country metadata."
            )
    else:
        # Only give "no contradiction" bonus if there is at least some evidence
        if has_geo or has_pubmed or accession_in_text:
            score += cons_cfg["no_contradiction"]
            explanations.append(
                "No contradiction detected across available sources."
            )

    # ---------- Signal 3: Evidence density ----------
    num_pubs = int(signals.get("num_publications", 0))
    dens_cfg = rules["evidence_density"]

    if num_pubs >= 2:
        score += dens_cfg["two_or_more_pubs"]
        explanations.append("Multiple linked publications available.")
    elif num_pubs == 1:
        score += dens_cfg["one_pub"]
        explanations.append("One linked publication available.")
    # else: 0 publications → no extra score

    # ---------- Signal 4: Risk flags ----------
    risk_cfg = rules["risk_penalties"]

    if signals.get("missing_key_fields"):
        score += risk_cfg["missing_key_fields"]
        explanations.append(
            "Missing key metadata fields (higher uncertainty)."
        )

    if signals.get("known_failure_pattern"):
        score += risk_cfg["known_failure_pattern"]
        explanations.append(
            "Accession matches a known risky/failure pattern."
        )

    # ---------- Clamp score and determine tier ----------
    score = max(0, min(100, score))

    tiers = rules["tiers"]
    if score >= tiers["high_min"]:
        tier = "high"
    elif score >= tiers["medium_min"]:
        tier = "medium"
    else:
        tier = "low"

    # Keep explanations short and readable
    if len(explanations) > 3:
        explanations = explanations[:3]

    return score, tier, explanations
if __name__ == "__main__":
    # Quick local sanity-check examples (manual smoke tests)
    rules = set_rules()

    examples = [
        {
            "name": "Strong, clean case",
            "signals": {
                "has_geo_loc_name": True,
                "has_pubmed": True,
                "accession_found_in_text": True,
                "predicted_country": "USA",
                "genbank_country": "United States of America",
                "num_publications": 3,
                "missing_key_fields": False,
                "known_failure_pattern": False,
            },
        },
        {
            "name": "Weak, conflicting case",
            "signals": {
                "has_geo_loc_name": True,
                "has_pubmed": False,
                "accession_found_in_text": False,
                "predicted_country": "Japan",
                "genbank_country": "France",
                "num_publications": 0,
                "missing_key_fields": True,
                "known_failure_pattern": True,
            },
        },
        {
            "name": "Medium, sparse but okay",
            "signals": {
                "has_geo_loc_name": False,
                "has_pubmed": True,
                "accession_found_in_text": False,
                "predicted_country": "United Kingdom",
                "genbank_country": None,
                "num_publications": 1,
                "missing_key_fields": False,
                "known_failure_pattern": False,
            },
        },
    ]

    for ex in examples:
        score, tier, expl = compute_confidence_score_and_tier(
            ex["signals"], rules
        )
        print("====", ex["name"], "====")
        print("Score:", score, "| Tier:", tier)
        print("Reasons:")
        for e in expl:
            print("  -", e)
        print()
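Note on running this file: confidence_score.py imports standardize_location from elsewhere in this repo. To exercise the __main__ smoke tests in isolation, a minimal stand-in is sketched below. It assumes only the contract actually used above, namely that smart_country_lookup takes a lowercased name and returns a canonical country name or the string "not found"; the real module is richer, and this stub is for local testing only.

# standardize_location.py (hypothetical stub for local testing only)
def smart_country_lookup(name: str) -> str:
    # Return a canonical lowercase country name, or the sentinel "not found"
    # so that confidence_score.py falls back to normalize_country().
    canonical = {
        "usa": "united states",
        "united states of america": "united states",
        "united kingdom": "united kingdom",
        "japan": "japan",
        "france": "france",
    }
    return canonical.get(name.strip().lower(), "not found")

With this stub on the import path, running python confidence_score.py prints a score, tier, and reason list for each of the three example rows.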