mumin.dataset
Script containing the main dataset class
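A minimal usage sketch of the `MuminDataset` class defined in this module (not part of the module source), assuming the package is installed and a valid Twitter bearer token is available in the `TWITTER_API_KEY` environment variable or passed explicitly:

from mumin.dataset import MuminDataset

# The bearer token is read from the TWITTER_API_KEY environment variable (or a
# .env file in the working directory) when it is not passed explicitly.
dataset = MuminDataset(size="small")

# Download, rehydrate and post-process the graph; the result is cached in
# ./mumin-small.zip, so subsequent calls are cheap.
dataset.compile()

# Optionally compute transformer embeddings for node features and dump them to
# the same zip file.
dataset.add_embeddings(nodes_to_embed=["tweet", "claim"])

# Nodes are plain pandas DataFrames; relations are keyed by
# (source, relation, target) triples.
claim_df = dataset.nodes["claim"]
discusses_df = dataset.rels[("tweet", "discusses", "claim")]

# Export the heterogeneous graph to DGL for graph machine learning.
graph = dataset.to_dgl()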
View Source
"""Script containing the main dataset class""" import io import logging import multiprocessing as mp import os import warnings import zipfile from functools import partial from pathlib import Path from shutil import rmtree from typing import Dict, List, Optional, Tuple, Union import numpy as np import pandas as pd import requests from dotenv import load_dotenv from tqdm.auto import tqdm from .data_extractor import DataExtractor from .dgl import build_dgl_dataset from .embedder import Embedder from .id_updator import IdUpdator from .twitter import Twitter # Load environment variables load_dotenv() # Set up logging logger = logging.getLogger(__name__) # Disable tokenizer parallelism os.environ["TOKENIZERS_PARALLELISM"] = "false" # Prevents crashes during data extraction on MacOS os.environ["OBJC_DISABLE_INITIALIZE_FORK_SAFETY"] = "YES" # Allows progress bars with `pd.DataFrame.progress_apply` tqdm.pandas() class MuminDataset: """The MuMiN misinformation dataset, from [1]. Args: twitter_bearer_token (str or None, optional): The Twitter bearer token. If None then the the bearer token must be stored in the environment variable `TWITTER_API_KEY`, or placed in a file named `.env` in the working directory, formatted as "TWITTER_API_KEY=xxxxx". Defaults to None. size (str, optional): The size of the dataset. Can be either 'small', 'medium' or 'large'. Defaults to 'small'. include_replies (bool, optional): Whether to include replies and quote tweets in the dataset. Defaults to True. include_articles (bool, optional): Whether to include articles in the dataset. This will mean that compilation of the dataset will take a bit longer, as these need to be downloaded and parsed. Defaults to True. include_tweet_images (bool, optional): Whether to include images from the tweets in the dataset. This will mean that compilation of the dataset will take a bit longer, as these need to be downloaded and parsed. Defaults to True. include_extra_images (bool, optional): Whether to include images from the articles and users in the dataset. This will mean that compilation of the dataset will take a bit longer, as these need to be downloaded and parsed. Defaults to False. include_hashtags (bool, optional): Whether to include hashtags in the dataset. Defaults to True. include_mentions (bool, optional): Whether to include mentions in the dataset. Defaults to True. include_timelines (bool, optional): Whether to include timelines in the dataset. Defaults to False. text_embedding_model_id (str, optional): The HuggingFace Hub model ID to use when embedding texts. Defaults to 'xlm-roberta-base'. image_embedding_model_id (str, optional): The HuggingFace Hub model ID to use when embedding images. Defaults to 'google/vit-base-patch16-224-in21k'. dataset_path (str, pathlib Path or None, optional): The path to the file where the dataset should be stored. If None then the dataset will be stored at './mumin-<size>.zip'. Defaults to None. n_jobs (int, optional): The number of jobs to use for parallel processing. Defaults to the number of available CPU cores minus one. chunksize (int, optional): The number of articles/images to process in each job. This speeds up processing time, but also increases memory load. Defaults to 10. verbose (bool, optional): Whether extra information should be outputted. Defaults to True. Attributes: include_replies (bool): Whether to include replies. include_articles (bool): Whether to include articles. include_tweet_images (bool): Whether to include tweet images. 
include_extra_images (bool): Whether to include user/article images. include_hashtags (bool): Whether to include hashtags. include_mentions (bool): Whether to include mentions. include_timelines (bool): Whether to include timelines. size (str): The size of the dataset. dataset_path (pathlib Path): The dataset file. text_embedding_model_id (str): The model ID used for embedding text. image_embedding_model_id (str): The model ID used for embedding images. nodes (dict): The nodes of the dataset. rels (dict): The relations of the dataset. rehydrated (bool): Whether the tweets and/or replies have been rehydrated. compiled (bool): Whether the dataset has been compiled. n_jobs (int): The number of jobs to use for parallel processing. chunksize (int): The number of articles/images to process in each job. verbose (bool): Whether extra information should be outputted. download_url (str): The URL to download the dataset from. Raises: ValueError: If `twitter_bearer_token` is None and the environment variable `TWITTER_API_KEY` is not set. References: - [1] Nielsen and McConville: MuMiN: A Large-Scale Multilingual Multimodal Fact-Checked Misinformation Dataset with Linked Social Network Posts (2021) """ download_url: str = ( "https://data.bris.ac.uk/datasets/23yv276we2mll25f" "jakkfim2ml/23yv276we2mll25fjakkfim2ml.zip" ) _node_dump: List[str] = [ "claim", "tweet", "user", "image", "article", "hashtag", "reply", ] _rel_dump: List[Tuple[str, str, str]] = [ ("tweet", "discusses", "claim"), ("tweet", "mentions", "user"), ("tweet", "has_image", "image"), ("tweet", "has_hashtag", "hashtag"), ("tweet", "has_article", "article"), ("reply", "reply_to", "tweet"), ("reply", "quote_of", "tweet"), ("user", "posted", "tweet"), ("user", "posted", "reply"), ("user", "mentions", "user"), ("user", "has_hashtag", "hashtag"), ("user", "has_profile_picture", "image"), ("user", "retweeted", "tweet"), ("user", "follows", "user"), ("article", "has_top_image", "image"), ] def __init__( self, twitter_bearer_token: Optional[str] = None, size: str = "small", include_replies: bool = True, include_articles: bool = True, include_tweet_images: bool = True, include_extra_images: bool = False, include_hashtags: bool = True, include_mentions: bool = True, include_timelines: bool = False, text_embedding_model_id: str = "xlm-roberta-base", image_embedding_model_id: str = "google/vit-base-patch16-224-in21k", dataset_path: Optional[Union[str, Path]] = None, n_jobs: int = mp.cpu_count() - 1, chunksize: int = 10, verbose: bool = True, ): self.size = size self.include_replies = include_replies self.include_articles = include_articles self.include_tweet_images = include_tweet_images self.include_extra_images = include_extra_images self.include_hashtags = include_hashtags self.include_mentions = include_mentions self.include_timelines = include_timelines self.text_embedding_model_id = text_embedding_model_id self.image_embedding_model_id = image_embedding_model_id self.verbose = verbose self.compiled: bool = False self.rehydrated: bool = False self.nodes: Dict[str, pd.DataFrame] = dict() self.rels: Dict[Tuple[str, str, str], pd.DataFrame] = dict() # Load the bearer token if it is not provided if twitter_bearer_token is None: twitter_bearer_token = os.environ.get("TWITTER_API_KEY") # If no bearer token is available, raise a warning and set the `_twitter` # attribute to None. Otherwise, set the `_twitter` attribute to a `Twitter` # instance. 
self._twitter: Twitter if twitter_bearer_token is None: warnings.warn( "Twitter bearer token not provided, so rehydration can not be " "performed. This is fine if you are using a pre-compiled MuMiN, but " "if this is not the case then you will need to either specify the " "`twitter_bearer_token` argument or set the environment variable " "`TWITTER_API_KEY`." ) else: self._twitter = Twitter(twitter_bearer_token=twitter_bearer_token) self._extractor = DataExtractor( include_replies=include_replies, include_articles=include_articles, include_tweet_images=include_tweet_images, include_extra_images=include_extra_images, include_hashtags=include_hashtags, include_mentions=include_mentions, n_jobs=n_jobs, chunksize=chunksize, ) self._updator = IdUpdator() self._embedder = Embedder( text_embedding_model_id=text_embedding_model_id, image_embedding_model_id=image_embedding_model_id, include_articles=include_articles, include_tweet_images=include_tweet_images, include_extra_images=include_extra_images, ) if dataset_path is None: dataset_path = f"./mumin-{size}.zip" self.dataset_path = Path(dataset_path) # Set up logging verbosity if self.verbose: logger.setLevel(logging.INFO) else: logger.setLevel(logging.WARNING) def compile(self, overwrite: bool = False): """Compiles the dataset. This entails downloading the dataset, rehydrating the Twitter data and downloading the relevant associated data, such as articles and images. Args: overwrite (bool, optional): Whether the dataset directory should be overwritten, in case it already exists. Defaults to False. Raises: RuntimeError: If the dataset needs to be compiled and a Twitter bearer token has not been provided. """ self._download(overwrite=overwrite) self._load_dataset() # Variable to check if dataset has been rehydrated and/or compiled if "text" in self.nodes["tweet"].columns: self.rehydrated = True if self.rehydrated and str(self.nodes["tweet"].tweet_id.dtype) == "uint64": self.compiled = True # Only compile the dataset if it has not already been compiled if not self.compiled: # If the dataset has not already been rehydrated, rehydrate it if not self.rehydrated: # Shrink dataset to the correct size self._shrink_dataset() # If the bearer token is not available then raise an error if not isinstance(self._twitter, Twitter): raise RuntimeError( "Twitter bearer token not provided. You need to either specify " "the `twitter_bearer_token` argument in the `MuminDataset` " "constructor or set the environment variable `TWITTER_API_KEY`." ) # Rehydrate the tweets self._rehydrate(node_type="tweet") self._rehydrate(node_type="reply") # Update the IDs of the data that was there pre-hydration self.nodes, self.rels = self._updator.update_all( nodes=self.nodes, rels=self.rels ) # Save dataset self._dump_dataset() # Set the rehydrated flag to True self.rehydrated = True # Extract data from the rehydrated tweets self.nodes, self.rels = self._extractor.extract_all( nodes=self.nodes, rels=self.rels ) # Filter the data self._filter_node_features() self._filter_relations() # Set datatypes self._set_datatypes() # Remove unnecessary bits self._remove_auxilliaries() self._remove_islands() # Save dataset if not self.compiled: self._dump_dataset() # Mark dataset as compiled self.compiled = True return self def _download(self, overwrite: bool = False): """Downloads the dataset. Args: overwrite (bool, optional): Whether the dataset directory should be overwritten, in case it already exists. Defaults to False. 
""" if not self.dataset_path.exists() or (self.dataset_path.exists() and overwrite): logger.info("Downloading dataset") # Remove existing directory if we are overwriting if self.dataset_path.exists() and overwrite: self.dataset_path.unlink() # Set up download stream of dataset with requests.get(self.download_url, stream=True) as response: # If the response was unsuccessful then raise an error if response.status_code != 200: raise RuntimeError(f"[{response.status_code}] {response.content!r}") # Download dataset with progress bar total = int(response.headers["Content-Length"]) with tqdm( total=total, unit="iB", unit_scale=True, desc="Downloading MuMiN" ) as pbar: with Path(self.dataset_path).open("wb") as f: for data in response.iter_content(1024): pbar.update(len(data)) f.write(data) # The data.bris zip file contains two files: `mumin.zip` and # `readme.txt`. We only want the first one, so we extract that and # replace the original file with it. with zipfile.ZipFile( self.dataset_path, mode="r", compression=zipfile.ZIP_DEFLATED ) as zipf: zipdata = zipf.read("23yv276we2mll25fjakkfim2ml/mumin.zip") with Path(self.dataset_path).open("wb") as f: f.write(zipdata) logger.info("Converting dataset to less compressed format") # Open the zip file containing the dataset data_dict = dict() with zipfile.ZipFile( self.dataset_path, mode="r", compression=zipfile.ZIP_DEFLATED ) as zip_file: # Loop over all the files in the zipped file for name in zip_file.namelist(): # Extract the dataframe in the file byte_data = zip_file.read(name=name) df = pd.read_pickle(io.BytesIO(byte_data), compression="xz") data_dict[name] = df # Overwrite the zip file in a less compressed way, to make io operations # faster with zipfile.ZipFile( self.dataset_path, mode="w", compression=zipfile.ZIP_STORED ) as zip_file: for name, df in data_dict.items(): buffer = io.BytesIO() df.to_pickle(buffer, protocol=4) zip_file.writestr(name, data=buffer.getvalue()) return self def _load_dataset(self): """Loads the dataset files into memory. Raises: RuntimeError: If the dataset has not been downloaded yet. 
""" # Raise error if the dataset has not been downloaded yet if not self.dataset_path.exists(): raise RuntimeError("Dataset has not been downloaded yet!") logger.info("Loading dataset") # Reset `nodes` and `relations` to ensure a fresh start self.nodes = dict() self.rels = dict() # Open the zip file containing the dataset with zipfile.ZipFile( self.dataset_path, mode="r", compression=zipfile.ZIP_STORED ) as zip_file: # Loop over all the files in the zipped file for name in zip_file.namelist(): # Extract the dataframe in the file byte_data = zip_file.read(name=name) df = pd.read_pickle(io.BytesIO(byte_data)) # If there are no underscores in the filename then we assume that it # contains node data if "_" not in name: self.nodes[name.replace(".pickle", "")] = df.copy() # Otherwise, with underscores in the filename then we assume it # contains relation data else: splits = name.replace(".pickle", "").split("_") src = splits[0] tgt = splits[-1] rel = "_".join(splits[1:-1]) self.rels[(src, rel, tgt)] = df.copy() # Ensure that claims are present in the dataset if "claim" not in self.nodes.keys(): raise RuntimeError("No claims are present in the file!") # Ensure that tweets are present in the dataset, and also that the tweet # IDs are unique if "tweet" not in self.nodes.keys(): raise RuntimeError("No tweets are present in the file!") else: tweet_df = self.nodes["tweet"] duplicated = tweet_df[tweet_df.tweet_id.duplicated()].tweet_id.tolist() if len(duplicated) > 0: raise RuntimeError( f"The tweet IDs {duplicated} are " f"duplicate in the dataset!" ) return self def _shrink_dataset(self): """Shrink dataset if `size` is 'small' or 'medium""" logger.info("Shrinking dataset") # Define the `relevance` threshold if self.size == "small": threshold = 0.80 # noqa elif self.size == "medium": threshold = 0.75 # noqa elif self.size == "large": threshold = 0.70 # noqa elif self.size == "test": threshold = 0.995 # noqa # Filter nodes ntypes = ["tweet", "reply", "user", "article"] for ntype in ntypes: self.nodes[ntype] = ( self.nodes[ntype] .query("relevance > @threshold") .drop(columns=["relevance"]) .reset_index(drop=True) ) # Filter relations etypes = [ ("reply", "reply_to", "tweet"), ("reply", "quote_of", "tweet"), ("user", "retweeted", "tweet"), ("user", "follows", "user"), ("tweet", "discusses", "claim"), ("article", "discusses", "claim"), ] for etype in etypes: self.rels[etype] = ( self.rels[etype] .query("relevance > @threshold") .drop(columns=["relevance"]) .reset_index(drop=True) ) # Filter claims claim_df = self.nodes["claim"] discusses_rel = self.rels[("tweet", "discusses", "claim")] include_claim = claim_df.id.isin(discusses_rel.tgt.tolist()) self.nodes["claim"] = claim_df[include_claim].reset_index(drop=True) # Filter timeline tweets if not self.include_timelines: src_tweet_ids = ( self.rels[("tweet", "discusses", "claim")] .src.astype(np.uint64) .tolist() ) is_src = self.nodes["tweet"].tweet_id.isin(src_tweet_ids) self.nodes["tweet"] = self.nodes["tweet"].loc[is_src] return self def _rehydrate(self, node_type: str): """Rehydrate the tweets and users in the dataset. Args: node_type (str): The type of node to rehydrate. 
""" if node_type in self.nodes.keys() and ( node_type != "reply" or self.include_replies ): logger.info(f"Rehydrating {node_type} nodes") # Get the tweet IDs, and if the node type is a tweet then separate these # into source tweets and the rest (i.e., timeline tweets) if node_type == "tweet": source_tweet_ids = ( self.rels[("tweet", "discusses", "claim")] .src.astype(np.uint64) .tolist() ) tweet_ids = [ tweet_id for tweet_id in self.nodes[node_type].tweet_id.astype(np.uint64) if tweet_id not in source_tweet_ids ] else: source_tweet_ids = list() tweet_ids = self.nodes[node_type].tweet_id.astype(np.uint64).tolist() # Store any features the nodes might have had before hydration prehydration_df = self.nodes[node_type].copy() # Rehydrate the source tweets if len(source_tweet_ids) > 0: params = dict(tweet_ids=source_tweet_ids) source_tweet_dfs = self._twitter.rehydrate_tweets(**params) # Return error if there are no tweets were rehydrated. This is probably # because the bearer token is wrong if len(source_tweet_dfs) == 0: raise RuntimeError( "No tweets were rehydrated. Check if the bearer token is " "correct." ) if len(tweet_ids) == 0: tweet_dfs = {key: pd.DataFrame() for key in source_tweet_dfs.keys()} # Rehydrate the other tweets if len(tweet_ids) > 0: params = dict(tweet_ids=tweet_ids) tweet_dfs = self._twitter.rehydrate_tweets(**params) # Return error if there are no tweets were rehydrated. This is # probably because the bearer token is wrong if len(tweet_dfs) == 0: raise RuntimeError( "No tweets were rehydrated. Check if " "the bearer token is correct." ) if len(source_tweet_ids) == 0: source_tweet_dfs = {key: pd.DataFrame() for key in tweet_dfs.keys()} # Extract and store tweets and users tweet_df = pd.concat( [source_tweet_dfs["tweets"], tweet_dfs["tweets"]], ignore_index=True ) self.nodes[node_type] = tweet_df.drop_duplicates( subset="tweet_id" ).reset_index(drop=True) user_df = pd.concat( [source_tweet_dfs["users"], tweet_dfs["users"]], ignore_index=True ) if "user" in self.nodes.keys() and "username" in self.nodes["user"].columns: user_df = ( pd.concat((self.nodes["user"], user_df), axis=0) .drop_duplicates(subset="user_id") .reset_index(drop=True) ) self.nodes["user"] = user_df # Add prehydration tweet features back to the tweets self.nodes[node_type] = ( self.nodes[node_type] .merge(prehydration_df, on="tweet_id", how="outer") .reset_index(drop=True) ) # Extract and store images # Note: This will store `self.nodes['image']`, but this is only to enable # extraction of URLs later on. The `self.nodes['image']` will be # overwritten later on. if ( node_type == "tweet" and self.include_tweet_images and len(source_tweet_dfs["media"]) ): image_df = ( source_tweet_dfs["media"] .query('type == "photo"') .drop_duplicates(subset="media_key") .reset_index(drop=True) ) if "image" in self.nodes.keys(): image_df = ( pd.concat((self.nodes["image"], image_df), axis=0) .drop_duplicates(subset="media_key") .reset_index(drop=True) ) self.nodes["image"] = image_df return self def add_embeddings( self, nodes_to_embed: List[str] = [ "tweet", "reply", "user", "claim", "article", "image", ], ): """Computes, stores and dumps embeddings of node features. Args: nodes_to_embed (list of str): The node types which needs to be embedded. If a node type does not exist in the graph it will be ignored. Defaults to ['tweet', 'reply', 'user', 'claim', 'article', 'image']. 
""" # Compute the embeddings self.nodes, embeddings_added = self._embedder.embed_all( nodes=self.nodes, nodes_to_embed=nodes_to_embed ) # Store dataset if any embeddings were added if embeddings_added: self._dump_dataset() return self def _filter_node_features(self): """Filters the node features to avoid redundancies and noise""" logger.info("Filters node features") # Set up the node features that should be kept size = "small" if self.size == "test" else self.size node_feats = dict( claim=[ "embedding", "label", "reviewers", "date", "language", "keywords", "cluster_keywords", "cluster", f"{size}_train_mask", f"{size}_val_mask", f"{size}_test_mask", ], tweet=[ "tweet_id", "text", "created_at", "lang", "source", "public_metrics.retweet_count", "public_metrics.reply_count", "public_metrics.quote_count", "label", f"{size}_train_mask", f"{size}_val_mask", f"{size}_test_mask", ], reply=[ "tweet_id", "text", "created_at", "lang", "source", "public_metrics.retweet_count", "public_metrics.reply_count", "public_metrics.quote_count", ], user=[ "user_id", "verified", "protected", "created_at", "username", "description", "url", "name", "public_metrics.followers_count", "public_metrics.following_count", "public_metrics.tweet_count", "public_metrics.listed_count", "location", ], image=["url", "pixels", "width", "height"], article=["url", "title", "content"], place=[ "place_id", "name", "full_name", "country_code", "country", "place_type", "lat", "lng", ], hashtag=["tag"], poll=[ "poll_id", "labels", "votes", "end_datetime", "voting_status", "duration_minutes", ], ) # Set up renaming of node features that should be kept node_feat_renaming = { "public_metrics.retweet_count": "num_retweets", "public_metrics.reply_count": "num_replies", "public_metrics.quote_count": "num_quote_tweets", "public_metrics.followers_count": "num_followers", "public_metrics.following_count": "num_followees", "public_metrics.tweet_count": "num_tweets", "public_metrics.listed_count": "num_listed", f"{size}_train_mask": "train_mask", f"{size}_val_mask": "val_mask", f"{size}_test_mask": "test_mask", } # Filter and rename the node features for node_type, features in node_feats.items(): if node_type in self.nodes.keys(): filtered_feats = [ feat for feat in features if feat in self.nodes[node_type].columns ] renaming_dict = { old: new for old, new in node_feat_renaming.items() if old in features } self.nodes[node_type] = self.nodes[node_type][filtered_feats].rename( columns=renaming_dict ) return self def _filter_relations(self): """Filters the relations to only include node IDs that exist""" logger.info("Filters relations") # Remove article relations if they are not included if not self.include_articles: rels_to_pop = list() for rel_type in self.rels.keys(): src, _, tgt = rel_type if src == "article" or tgt == "article": rels_to_pop.append(rel_type) for rel_type in rels_to_pop: self.rels.pop(rel_type) # Remove reply relations if they are not included if not self.include_replies: rels_to_pop = list() for rel_type in self.rels.keys(): src, _, tgt = rel_type if src == "reply" or tgt == "reply": rels_to_pop.append(rel_type) for rel_type in rels_to_pop: self.rels.pop(rel_type) # Remove mention relations if they are not included if not self.include_mentions: rels_to_pop = list() for rel_type in self.rels.keys(): _, rel, _ = rel_type if rel == "mentions": rels_to_pop.append(rel_type) for rel_type in rels_to_pop: self.rels.pop(rel_type) # Loop over the relations, extract the associated node IDs and filter the # relation dataframe to only 
include relations between nodes that exist rels_to_pop = list() for rel_type, rel_df in self.rels.items(): # Pop the relation if the dataframe does not exist if rel_df is None or len(rel_df) == 0: rels_to_pop.append(rel_type) continue # Pop the relation if the source or target node does not exist src, _, tgt = rel_type if src not in self.nodes.keys() or tgt not in self.nodes.keys(): rels_to_pop.append(rel_type) # Otherwise filter the relation dataframe to only include nodes that exist else: src_ids = self.nodes[src].index.tolist() tgt_ids = self.nodes[tgt].index.tolist() rel_df = rel_df[rel_df.src.isin(src_ids)] rel_df = rel_df[rel_df.tgt.isin(tgt_ids)] self.rels[rel_type] = rel_df # Pop the relations that has been assigned to be popped for rel_type in rels_to_pop: self.rels.pop(rel_type) def _set_datatypes(self): """Set datatypes in the dataframes, to use less memory""" # Set up all the dtypes of the columns dtypes = dict( tweet=dict( tweet_id="uint64", text="str", created_at={"created_at": "datetime64[ns]"}, lang="category", source="str", num_retweets="uint64", num_replies="uint64", num_quote_tweets="uint64", ), user=dict( user_id="uint64", verified="bool", protected="bool", created_at={"created_at": "datetime64[ns]"}, username="str", description="str", url="str", name="str", num_followers="uint64", num_followees="uint64", num_tweets="uint64", num_listed="uint64", location="category", ), ) if self.include_hashtags: dtypes["hashtag"] = dict(tag="str") if self.include_replies: dtypes["reply"] = dict( tweet_id="uint64", text="str", created_at={"created_at": "datetime64[ns]"}, lang="category", source="str", num_retweets="uint64", num_replies="uint64", num_quote_tweets="uint64", ) if self.include_tweet_images or self.include_extra_images: dtypes["image"] = dict( url="str", pixels="numpy:uint8", width="uint64", height="uint64" ) if self.include_articles: dtypes["article"] = dict(url="str", title="str", content="str") # Create conversion function for missing values def fill_na_values(dtype: Union[str, dict]): if dtype == "uint64": return 0 elif dtype == "bool": return False elif dtype == dict(created_at="datetime64[ns]"): return np.datetime64("NaT") elif dtype == "category": return "NaN" elif dtype == "str": return "" else: return np.nan # Loop over all nodes for ntype, dtype_dict in dtypes.items(): if ntype in self.nodes.keys(): # Set the dtypes for non-numpy columns dtype_dict_no_numpy = { col: dtype for col, dtype in dtype_dict.items() if not isinstance(dtype, str) or not dtype.startswith("numpy") } for col, dtype in dtype_dict_no_numpy.items(): if col in self.nodes[ntype].columns: # Fill NaN values with canonical values in accordance with the # datatype self.nodes[ntype][col].fillna( fill_na_values(dtype), inplace=True ) # Set the dtype self.nodes[ntype][col] = self.nodes[ntype][col].astype(dtype) # For numpy columns, set the type manually def numpy_fn(x, dtype: str): return np.asarray(x, dtype=dtype) for col, dtype in dtype_dict.items(): if ( isinstance(dtype, str) and dtype.startswith("numpy") and col in self.nodes[ntype].columns ): # Fill NaN values with canonical values in accordance with the # datatype self.nodes[ntype][col].fillna( fill_na_values(dtype), inplace=True ) # Extract the NumPy datatype numpy_dtype = dtype.split(":")[-1] # Tweak the numpy function to include the datatype fn = partial(numpy_fn, dtype=numpy_dtype) # Set the dtype self.nodes[ntype][col] = self.nodes[ntype][col].map(fn) def _remove_auxilliaries(self): """Removes node types that are not in use anymore""" # 
Remove auxilliary node types nodes_to_remove = [ node_type for node_type in self.nodes.keys() if node_type not in self._node_dump ] for node_type in nodes_to_remove: self.nodes.pop(node_type) # Remove auxilliary relation types rels_to_remove = [ rel_type for rel_type in self.rels.keys() if rel_type not in self._rel_dump ] for rel_type in rels_to_remove: self.rels.pop(rel_type) return self def _remove_islands(self): """Removes nodes and relations that are not connected to anything""" # Loop over all the node types for node_type, node_df in self.nodes.items(): # For each node type, loop over all the relations, to see what nodes of # that node type does not appear in any of the relations for rel_type, rel_df in self.rels.items(): src, _, tgt = rel_type # If the node is the source of the relation if node_type == src: # Store all the nodes connected to the relation (or any of the # previously checked relations) connected = node_df.index.isin(rel_df.src.tolist()) if "connected" in node_df.columns: connected = node_df.connected | connected node_df["connected"] = connected # If the node is the source of the relation if node_type == tgt: # Store all the nodes connected to the relation (or any of the # previously checked relations) connected = node_df.index.isin(rel_df.tgt.tolist()) if "connected" in node_df.columns: connected = node_df.connected | connected node_df["connected"] = connected # Filter the node dataframe to only keep the connected ones if "connected" in node_df.columns and "index" not in node_df.columns: self.nodes[node_type] = ( node_df.query("connected == True") .drop(columns="connected") .reset_index() ) # Update the relevant relations for rel_type, rel_df in self.rels.items(): src, _, tgt = rel_type # If islands have been removed from the source, then update those # indices if node_type == src and "index" in self.nodes[node_type]: node_df = ( self.nodes[node_type] .rename(columns=dict(index="old_idx")) .reset_index() ) rel_df = ( rel_df.merge( node_df[["index", "old_idx"]], left_on="src", right_on="old_idx", ) .drop(columns=["src", "old_idx"]) .rename(columns=dict(index="src")) ) self.rels[rel_type] = rel_df[["src", "tgt"]] self.nodes[node_type] = self.nodes[node_type] # If islands have been removed from the target, then update those # indices if node_type == tgt and "index" in self.nodes[node_type]: node_df = ( self.nodes[node_type] .rename(columns=dict(index="old_idx")) .reset_index() ) rel_df = ( rel_df.merge( node_df[["index", "old_idx"]], left_on="tgt", right_on="old_idx", ) .drop(columns=["tgt", "old_idx"]) .rename(columns=dict(index="tgt")) ) self.rels[rel_type] = rel_df[["src", "tgt"]] self.nodes[node_type] = self.nodes[node_type] if "index" in self.nodes[node_type]: self.nodes[node_type] = self.nodes[node_type].drop(columns="index") return self def _dump_dataset(self): """Dumps the dataset to a zip file""" logger.info("Dumping dataset") # Create a temporary pickle folder temp_pickle_folder = Path("temp_pickle_folder") if not temp_pickle_folder.exists(): temp_pickle_folder.mkdir() # Make temporary pickle list pickle_list = list() # Create progress bar total = len(self._node_dump) + len(self._rel_dump) + 1 pbar = tqdm(total=total) # Store the nodes for node_type in self._node_dump: pbar.set_description(f"Storing {node_type} nodes") if node_type in self.nodes.keys(): pickle_list.append(node_type) pickle_path = temp_pickle_folder / f"{node_type}.pickle" self.nodes[node_type].to_pickle(pickle_path, protocol=4) pbar.update() # Store the relations for rel_type in self._rel_dump: 
pbar.set_description(f"Storing {rel_type} relations") if rel_type in self.rels.keys(): name = "_".join(rel_type) pickle_list.append(name) pickle_path = temp_pickle_folder / f"{name}.pickle" self.rels[rel_type].to_pickle(pickle_path, protocol=4) pbar.update() # Zip the nodes and relations, and save the zip file with zipfile.ZipFile( self.dataset_path, mode="w", compression=zipfile.ZIP_STORED ) as zip_file: pbar.set_description("Dumping dataset") for name in pickle_list: fname = f"{name}.pickle" zip_file.write(temp_pickle_folder / fname, arcname=fname) # Remove the temporary pickle folder rmtree(temp_pickle_folder) # Final progress bar update and close it pbar.update() pbar.close() return self def to_dgl(self): """Convert the dataset to a DGL dataset. Returns: DGLHeteroGraph: The graph in DGL format. """ logger.info("Outputting to DGL") return build_dgl_dataset(nodes=self.nodes, relations=self.rels) def __repr__(self) -> str: """A string representation of the dataset. Returns: str: The representation of the dataset. """ bearer_token_available = self._twitter is not None if len(self.nodes) == 0 or len(self.rels) == 0: return ( f"MuminDataset(size={self.size}, " f"rehydrated={self.rehydrated}, " f"compiled={self.compiled}, " f"bearer_token_available={bearer_token_available})" ) else: num_nodes = sum([len(df) for df in self.nodes.values()]) num_rels = sum([len(df) for df in self.rels.values()]) return ( f"MuminDataset(num_nodes={num_nodes:,}, " f"num_relations={num_rels:,}, " f"size='{self.size}', " f"rehydrated={self.rehydrated}, " f"compiled={self.compiled}, " f"bearer_token_available={bearer_token_available})" )
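For reference, the zip archive written by `_dump_dataset` can also be inspected directly with the standard library, mirroring what `_load_dataset` does internally. A minimal sketch, assuming a compiled `./mumin-small.zip` is present in the working directory:

import io
import zipfile

import pandas as pd

# Each archive entry is a pickled DataFrame: node files are named
# "<node_type>.pickle" and relation files "<src>_<relation>_<tgt>.pickle".
with zipfile.ZipFile("mumin-small.zip", mode="r") as zip_file:
    for name in zip_file.namelist():
        df = pd.read_pickle(io.BytesIO(zip_file.read(name)))
        print(name, df.shape)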
View Source
class MuminDataset: """The MuMiN misinformation dataset, from [1]. Args: twitter_bearer_token (str or None, optional): The Twitter bearer token. If None then the the bearer token must be stored in the environment variable `TWITTER_API_KEY`, or placed in a file named `.env` in the working directory, formatted as "TWITTER_API_KEY=xxxxx". Defaults to None. size (str, optional): The size of the dataset. Can be either 'small', 'medium' or 'large'. Defaults to 'small'. include_replies (bool, optional): Whether to include replies and quote tweets in the dataset. Defaults to True. include_articles (bool, optional): Whether to include articles in the dataset. This will mean that compilation of the dataset will take a bit longer, as these need to be downloaded and parsed. Defaults to True. include_tweet_images (bool, optional): Whether to include images from the tweets in the dataset. This will mean that compilation of the dataset will take a bit longer, as these need to be downloaded and parsed. Defaults to True. include_extra_images (bool, optional): Whether to include images from the articles and users in the dataset. This will mean that compilation of the dataset will take a bit longer, as these need to be downloaded and parsed. Defaults to False. include_hashtags (bool, optional): Whether to include hashtags in the dataset. Defaults to True. include_mentions (bool, optional): Whether to include mentions in the dataset. Defaults to True. include_timelines (bool, optional): Whether to include timelines in the dataset. Defaults to False. text_embedding_model_id (str, optional): The HuggingFace Hub model ID to use when embedding texts. Defaults to 'xlm-roberta-base'. image_embedding_model_id (str, optional): The HuggingFace Hub model ID to use when embedding images. Defaults to 'google/vit-base-patch16-224-in21k'. dataset_path (str, pathlib Path or None, optional): The path to the file where the dataset should be stored. If None then the dataset will be stored at './mumin-<size>.zip'. Defaults to None. n_jobs (int, optional): The number of jobs to use for parallel processing. Defaults to the number of available CPU cores minus one. chunksize (int, optional): The number of articles/images to process in each job. This speeds up processing time, but also increases memory load. Defaults to 10. verbose (bool, optional): Whether extra information should be outputted. Defaults to True. Attributes: include_replies (bool): Whether to include replies. include_articles (bool): Whether to include articles. include_tweet_images (bool): Whether to include tweet images. include_extra_images (bool): Whether to include user/article images. include_hashtags (bool): Whether to include hashtags. include_mentions (bool): Whether to include mentions. include_timelines (bool): Whether to include timelines. size (str): The size of the dataset. dataset_path (pathlib Path): The dataset file. text_embedding_model_id (str): The model ID used for embedding text. image_embedding_model_id (str): The model ID used for embedding images. nodes (dict): The nodes of the dataset. rels (dict): The relations of the dataset. rehydrated (bool): Whether the tweets and/or replies have been rehydrated. compiled (bool): Whether the dataset has been compiled. n_jobs (int): The number of jobs to use for parallel processing. chunksize (int): The number of articles/images to process in each job. verbose (bool): Whether extra information should be outputted. download_url (str): The URL to download the dataset from. 
Raises: ValueError: If `twitter_bearer_token` is None and the environment variable `TWITTER_API_KEY` is not set. References: - [1] Nielsen and McConville: MuMiN: A Large-Scale Multilingual Multimodal Fact-Checked Misinformation Dataset with Linked Social Network Posts (2021) """ download_url: str = ( "https://data.bris.ac.uk/datasets/23yv276we2mll25f" "jakkfim2ml/23yv276we2mll25fjakkfim2ml.zip" ) _node_dump: List[str] = [ "claim", "tweet", "user", "image", "article", "hashtag", "reply", ] _rel_dump: List[Tuple[str, str, str]] = [ ("tweet", "discusses", "claim"), ("tweet", "mentions", "user"), ("tweet", "has_image", "image"), ("tweet", "has_hashtag", "hashtag"), ("tweet", "has_article", "article"), ("reply", "reply_to", "tweet"), ("reply", "quote_of", "tweet"), ("user", "posted", "tweet"), ("user", "posted", "reply"), ("user", "mentions", "user"), ("user", "has_hashtag", "hashtag"), ("user", "has_profile_picture", "image"), ("user", "retweeted", "tweet"), ("user", "follows", "user"), ("article", "has_top_image", "image"), ] def __init__( self, twitter_bearer_token: Optional[str] = None, size: str = "small", include_replies: bool = True, include_articles: bool = True, include_tweet_images: bool = True, include_extra_images: bool = False, include_hashtags: bool = True, include_mentions: bool = True, include_timelines: bool = False, text_embedding_model_id: str = "xlm-roberta-base", image_embedding_model_id: str = "google/vit-base-patch16-224-in21k", dataset_path: Optional[Union[str, Path]] = None, n_jobs: int = mp.cpu_count() - 1, chunksize: int = 10, verbose: bool = True, ): self.size = size self.include_replies = include_replies self.include_articles = include_articles self.include_tweet_images = include_tweet_images self.include_extra_images = include_extra_images self.include_hashtags = include_hashtags self.include_mentions = include_mentions self.include_timelines = include_timelines self.text_embedding_model_id = text_embedding_model_id self.image_embedding_model_id = image_embedding_model_id self.verbose = verbose self.compiled: bool = False self.rehydrated: bool = False self.nodes: Dict[str, pd.DataFrame] = dict() self.rels: Dict[Tuple[str, str, str], pd.DataFrame] = dict() # Load the bearer token if it is not provided if twitter_bearer_token is None: twitter_bearer_token = os.environ.get("TWITTER_API_KEY") # If no bearer token is available, raise a warning and set the `_twitter` # attribute to None. Otherwise, set the `_twitter` attribute to a `Twitter` # instance. self._twitter: Twitter if twitter_bearer_token is None: warnings.warn( "Twitter bearer token not provided, so rehydration can not be " "performed. This is fine if you are using a pre-compiled MuMiN, but " "if this is not the case then you will need to either specify the " "`twitter_bearer_token` argument or set the environment variable " "`TWITTER_API_KEY`." 
) else: self._twitter = Twitter(twitter_bearer_token=twitter_bearer_token) self._extractor = DataExtractor( include_replies=include_replies, include_articles=include_articles, include_tweet_images=include_tweet_images, include_extra_images=include_extra_images, include_hashtags=include_hashtags, include_mentions=include_mentions, n_jobs=n_jobs, chunksize=chunksize, ) self._updator = IdUpdator() self._embedder = Embedder( text_embedding_model_id=text_embedding_model_id, image_embedding_model_id=image_embedding_model_id, include_articles=include_articles, include_tweet_images=include_tweet_images, include_extra_images=include_extra_images, ) if dataset_path is None: dataset_path = f"./mumin-{size}.zip" self.dataset_path = Path(dataset_path) # Set up logging verbosity if self.verbose: logger.setLevel(logging.INFO) else: logger.setLevel(logging.WARNING) def compile(self, overwrite: bool = False): """Compiles the dataset. This entails downloading the dataset, rehydrating the Twitter data and downloading the relevant associated data, such as articles and images. Args: overwrite (bool, optional): Whether the dataset directory should be overwritten, in case it already exists. Defaults to False. Raises: RuntimeError: If the dataset needs to be compiled and a Twitter bearer token has not been provided. """ self._download(overwrite=overwrite) self._load_dataset() # Variable to check if dataset has been rehydrated and/or compiled if "text" in self.nodes["tweet"].columns: self.rehydrated = True if self.rehydrated and str(self.nodes["tweet"].tweet_id.dtype) == "uint64": self.compiled = True # Only compile the dataset if it has not already been compiled if not self.compiled: # If the dataset has not already been rehydrated, rehydrate it if not self.rehydrated: # Shrink dataset to the correct size self._shrink_dataset() # If the bearer token is not available then raise an error if not isinstance(self._twitter, Twitter): raise RuntimeError( "Twitter bearer token not provided. You need to either specify " "the `twitter_bearer_token` argument in the `MuminDataset` " "constructor or set the environment variable `TWITTER_API_KEY`." ) # Rehydrate the tweets self._rehydrate(node_type="tweet") self._rehydrate(node_type="reply") # Update the IDs of the data that was there pre-hydration self.nodes, self.rels = self._updator.update_all( nodes=self.nodes, rels=self.rels ) # Save dataset self._dump_dataset() # Set the rehydrated flag to True self.rehydrated = True # Extract data from the rehydrated tweets self.nodes, self.rels = self._extractor.extract_all( nodes=self.nodes, rels=self.rels ) # Filter the data self._filter_node_features() self._filter_relations() # Set datatypes self._set_datatypes() # Remove unnecessary bits self._remove_auxilliaries() self._remove_islands() # Save dataset if not self.compiled: self._dump_dataset() # Mark dataset as compiled self.compiled = True return self def _download(self, overwrite: bool = False): """Downloads the dataset. Args: overwrite (bool, optional): Whether the dataset directory should be overwritten, in case it already exists. Defaults to False. 
""" if not self.dataset_path.exists() or (self.dataset_path.exists() and overwrite): logger.info("Downloading dataset") # Remove existing directory if we are overwriting if self.dataset_path.exists() and overwrite: self.dataset_path.unlink() # Set up download stream of dataset with requests.get(self.download_url, stream=True) as response: # If the response was unsuccessful then raise an error if response.status_code != 200: raise RuntimeError(f"[{response.status_code}] {response.content!r}") # Download dataset with progress bar total = int(response.headers["Content-Length"]) with tqdm( total=total, unit="iB", unit_scale=True, desc="Downloading MuMiN" ) as pbar: with Path(self.dataset_path).open("wb") as f: for data in response.iter_content(1024): pbar.update(len(data)) f.write(data) # The data.bris zip file contains two files: `mumin.zip` and # `readme.txt`. We only want the first one, so we extract that and # replace the original file with it. with zipfile.ZipFile( self.dataset_path, mode="r", compression=zipfile.ZIP_DEFLATED ) as zipf: zipdata = zipf.read("23yv276we2mll25fjakkfim2ml/mumin.zip") with Path(self.dataset_path).open("wb") as f: f.write(zipdata) logger.info("Converting dataset to less compressed format") # Open the zip file containing the dataset data_dict = dict() with zipfile.ZipFile( self.dataset_path, mode="r", compression=zipfile.ZIP_DEFLATED ) as zip_file: # Loop over all the files in the zipped file for name in zip_file.namelist(): # Extract the dataframe in the file byte_data = zip_file.read(name=name) df = pd.read_pickle(io.BytesIO(byte_data), compression="xz") data_dict[name] = df # Overwrite the zip file in a less compressed way, to make io operations # faster with zipfile.ZipFile( self.dataset_path, mode="w", compression=zipfile.ZIP_STORED ) as zip_file: for name, df in data_dict.items(): buffer = io.BytesIO() df.to_pickle(buffer, protocol=4) zip_file.writestr(name, data=buffer.getvalue()) return self def _load_dataset(self): """Loads the dataset files into memory. Raises: RuntimeError: If the dataset has not been downloaded yet. 
""" # Raise error if the dataset has not been downloaded yet if not self.dataset_path.exists(): raise RuntimeError("Dataset has not been downloaded yet!") logger.info("Loading dataset") # Reset `nodes` and `relations` to ensure a fresh start self.nodes = dict() self.rels = dict() # Open the zip file containing the dataset with zipfile.ZipFile( self.dataset_path, mode="r", compression=zipfile.ZIP_STORED ) as zip_file: # Loop over all the files in the zipped file for name in zip_file.namelist(): # Extract the dataframe in the file byte_data = zip_file.read(name=name) df = pd.read_pickle(io.BytesIO(byte_data)) # If there are no underscores in the filename then we assume that it # contains node data if "_" not in name: self.nodes[name.replace(".pickle", "")] = df.copy() # Otherwise, with underscores in the filename then we assume it # contains relation data else: splits = name.replace(".pickle", "").split("_") src = splits[0] tgt = splits[-1] rel = "_".join(splits[1:-1]) self.rels[(src, rel, tgt)] = df.copy() # Ensure that claims are present in the dataset if "claim" not in self.nodes.keys(): raise RuntimeError("No claims are present in the file!") # Ensure that tweets are present in the dataset, and also that the tweet # IDs are unique if "tweet" not in self.nodes.keys(): raise RuntimeError("No tweets are present in the file!") else: tweet_df = self.nodes["tweet"] duplicated = tweet_df[tweet_df.tweet_id.duplicated()].tweet_id.tolist() if len(duplicated) > 0: raise RuntimeError( f"The tweet IDs {duplicated} are " f"duplicate in the dataset!" ) return self def _shrink_dataset(self): """Shrink dataset if `size` is 'small' or 'medium""" logger.info("Shrinking dataset") # Define the `relevance` threshold if self.size == "small": threshold = 0.80 # noqa elif self.size == "medium": threshold = 0.75 # noqa elif self.size == "large": threshold = 0.70 # noqa elif self.size == "test": threshold = 0.995 # noqa # Filter nodes ntypes = ["tweet", "reply", "user", "article"] for ntype in ntypes: self.nodes[ntype] = ( self.nodes[ntype] .query("relevance > @threshold") .drop(columns=["relevance"]) .reset_index(drop=True) ) # Filter relations etypes = [ ("reply", "reply_to", "tweet"), ("reply", "quote_of", "tweet"), ("user", "retweeted", "tweet"), ("user", "follows", "user"), ("tweet", "discusses", "claim"), ("article", "discusses", "claim"), ] for etype in etypes: self.rels[etype] = ( self.rels[etype] .query("relevance > @threshold") .drop(columns=["relevance"]) .reset_index(drop=True) ) # Filter claims claim_df = self.nodes["claim"] discusses_rel = self.rels[("tweet", "discusses", "claim")] include_claim = claim_df.id.isin(discusses_rel.tgt.tolist()) self.nodes["claim"] = claim_df[include_claim].reset_index(drop=True) # Filter timeline tweets if not self.include_timelines: src_tweet_ids = ( self.rels[("tweet", "discusses", "claim")] .src.astype(np.uint64) .tolist() ) is_src = self.nodes["tweet"].tweet_id.isin(src_tweet_ids) self.nodes["tweet"] = self.nodes["tweet"].loc[is_src] return self def _rehydrate(self, node_type: str): """Rehydrate the tweets and users in the dataset. Args: node_type (str): The type of node to rehydrate. 
""" if node_type in self.nodes.keys() and ( node_type != "reply" or self.include_replies ): logger.info(f"Rehydrating {node_type} nodes") # Get the tweet IDs, and if the node type is a tweet then separate these # into source tweets and the rest (i.e., timeline tweets) if node_type == "tweet": source_tweet_ids = ( self.rels[("tweet", "discusses", "claim")] .src.astype(np.uint64) .tolist() ) tweet_ids = [ tweet_id for tweet_id in self.nodes[node_type].tweet_id.astype(np.uint64) if tweet_id not in source_tweet_ids ] else: source_tweet_ids = list() tweet_ids = self.nodes[node_type].tweet_id.astype(np.uint64).tolist() # Store any features the nodes might have had before hydration prehydration_df = self.nodes[node_type].copy() # Rehydrate the source tweets if len(source_tweet_ids) > 0: params = dict(tweet_ids=source_tweet_ids) source_tweet_dfs = self._twitter.rehydrate_tweets(**params) # Return error if there are no tweets were rehydrated. This is probably # because the bearer token is wrong if len(source_tweet_dfs) == 0: raise RuntimeError( "No tweets were rehydrated. Check if the bearer token is " "correct." ) if len(tweet_ids) == 0: tweet_dfs = {key: pd.DataFrame() for key in source_tweet_dfs.keys()} # Rehydrate the other tweets if len(tweet_ids) > 0: params = dict(tweet_ids=tweet_ids) tweet_dfs = self._twitter.rehydrate_tweets(**params) # Return error if there are no tweets were rehydrated. This is # probably because the bearer token is wrong if len(tweet_dfs) == 0: raise RuntimeError( "No tweets were rehydrated. Check if " "the bearer token is correct." ) if len(source_tweet_ids) == 0: source_tweet_dfs = {key: pd.DataFrame() for key in tweet_dfs.keys()} # Extract and store tweets and users tweet_df = pd.concat( [source_tweet_dfs["tweets"], tweet_dfs["tweets"]], ignore_index=True ) self.nodes[node_type] = tweet_df.drop_duplicates( subset="tweet_id" ).reset_index(drop=True) user_df = pd.concat( [source_tweet_dfs["users"], tweet_dfs["users"]], ignore_index=True ) if "user" in self.nodes.keys() and "username" in self.nodes["user"].columns: user_df = ( pd.concat((self.nodes["user"], user_df), axis=0) .drop_duplicates(subset="user_id") .reset_index(drop=True) ) self.nodes["user"] = user_df # Add prehydration tweet features back to the tweets self.nodes[node_type] = ( self.nodes[node_type] .merge(prehydration_df, on="tweet_id", how="outer") .reset_index(drop=True) ) # Extract and store images # Note: This will store `self.nodes['image']`, but this is only to enable # extraction of URLs later on. The `self.nodes['image']` will be # overwritten later on. if ( node_type == "tweet" and self.include_tweet_images and len(source_tweet_dfs["media"]) ): image_df = ( source_tweet_dfs["media"] .query('type == "photo"') .drop_duplicates(subset="media_key") .reset_index(drop=True) ) if "image" in self.nodes.keys(): image_df = ( pd.concat((self.nodes["image"], image_df), axis=0) .drop_duplicates(subset="media_key") .reset_index(drop=True) ) self.nodes["image"] = image_df return self def add_embeddings( self, nodes_to_embed: List[str] = [ "tweet", "reply", "user", "claim", "article", "image", ], ): """Computes, stores and dumps embeddings of node features. Args: nodes_to_embed (list of str): The node types which needs to be embedded. If a node type does not exist in the graph it will be ignored. Defaults to ['tweet', 'reply', 'user', 'claim', 'article', 'image']. 
""" # Compute the embeddings self.nodes, embeddings_added = self._embedder.embed_all( nodes=self.nodes, nodes_to_embed=nodes_to_embed ) # Store dataset if any embeddings were added if embeddings_added: self._dump_dataset() return self def _filter_node_features(self): """Filters the node features to avoid redundancies and noise""" logger.info("Filters node features") # Set up the node features that should be kept size = "small" if self.size == "test" else self.size node_feats = dict( claim=[ "embedding", "label", "reviewers", "date", "language", "keywords", "cluster_keywords", "cluster", f"{size}_train_mask", f"{size}_val_mask", f"{size}_test_mask", ], tweet=[ "tweet_id", "text", "created_at", "lang", "source", "public_metrics.retweet_count", "public_metrics.reply_count", "public_metrics.quote_count", "label", f"{size}_train_mask", f"{size}_val_mask", f"{size}_test_mask", ], reply=[ "tweet_id", "text", "created_at", "lang", "source", "public_metrics.retweet_count", "public_metrics.reply_count", "public_metrics.quote_count", ], user=[ "user_id", "verified", "protected", "created_at", "username", "description", "url", "name", "public_metrics.followers_count", "public_metrics.following_count", "public_metrics.tweet_count", "public_metrics.listed_count", "location", ], image=["url", "pixels", "width", "height"], article=["url", "title", "content"], place=[ "place_id", "name", "full_name", "country_code", "country", "place_type", "lat", "lng", ], hashtag=["tag"], poll=[ "poll_id", "labels", "votes", "end_datetime", "voting_status", "duration_minutes", ], ) # Set up renaming of node features that should be kept node_feat_renaming = { "public_metrics.retweet_count": "num_retweets", "public_metrics.reply_count": "num_replies", "public_metrics.quote_count": "num_quote_tweets", "public_metrics.followers_count": "num_followers", "public_metrics.following_count": "num_followees", "public_metrics.tweet_count": "num_tweets", "public_metrics.listed_count": "num_listed", f"{size}_train_mask": "train_mask", f"{size}_val_mask": "val_mask", f"{size}_test_mask": "test_mask", } # Filter and rename the node features for node_type, features in node_feats.items(): if node_type in self.nodes.keys(): filtered_feats = [ feat for feat in features if feat in self.nodes[node_type].columns ] renaming_dict = { old: new for old, new in node_feat_renaming.items() if old in features } self.nodes[node_type] = self.nodes[node_type][filtered_feats].rename( columns=renaming_dict ) return self def _filter_relations(self): """Filters the relations to only include node IDs that exist""" logger.info("Filters relations") # Remove article relations if they are not included if not self.include_articles: rels_to_pop = list() for rel_type in self.rels.keys(): src, _, tgt = rel_type if src == "article" or tgt == "article": rels_to_pop.append(rel_type) for rel_type in rels_to_pop: self.rels.pop(rel_type) # Remove reply relations if they are not included if not self.include_replies: rels_to_pop = list() for rel_type in self.rels.keys(): src, _, tgt = rel_type if src == "reply" or tgt == "reply": rels_to_pop.append(rel_type) for rel_type in rels_to_pop: self.rels.pop(rel_type) # Remove mention relations if they are not included if not self.include_mentions: rels_to_pop = list() for rel_type in self.rels.keys(): _, rel, _ = rel_type if rel == "mentions": rels_to_pop.append(rel_type) for rel_type in rels_to_pop: self.rels.pop(rel_type) # Loop over the relations, extract the associated node IDs and filter the # relation dataframe to only 
include relations between nodes that exist rels_to_pop = list() for rel_type, rel_df in self.rels.items(): # Pop the relation if the dataframe does not exist if rel_df is None or len(rel_df) == 0: rels_to_pop.append(rel_type) continue # Pop the relation if the source or target node does not exist src, _, tgt = rel_type if src not in self.nodes.keys() or tgt not in self.nodes.keys(): rels_to_pop.append(rel_type) # Otherwise filter the relation dataframe to only include nodes that exist else: src_ids = self.nodes[src].index.tolist() tgt_ids = self.nodes[tgt].index.tolist() rel_df = rel_df[rel_df.src.isin(src_ids)] rel_df = rel_df[rel_df.tgt.isin(tgt_ids)] self.rels[rel_type] = rel_df # Pop the relations that has been assigned to be popped for rel_type in rels_to_pop: self.rels.pop(rel_type) def _set_datatypes(self): """Set datatypes in the dataframes, to use less memory""" # Set up all the dtypes of the columns dtypes = dict( tweet=dict( tweet_id="uint64", text="str", created_at={"created_at": "datetime64[ns]"}, lang="category", source="str", num_retweets="uint64", num_replies="uint64", num_quote_tweets="uint64", ), user=dict( user_id="uint64", verified="bool", protected="bool", created_at={"created_at": "datetime64[ns]"}, username="str", description="str", url="str", name="str", num_followers="uint64", num_followees="uint64", num_tweets="uint64", num_listed="uint64", location="category", ), ) if self.include_hashtags: dtypes["hashtag"] = dict(tag="str") if self.include_replies: dtypes["reply"] = dict( tweet_id="uint64", text="str", created_at={"created_at": "datetime64[ns]"}, lang="category", source="str", num_retweets="uint64", num_replies="uint64", num_quote_tweets="uint64", ) if self.include_tweet_images or self.include_extra_images: dtypes["image"] = dict( url="str", pixels="numpy:uint8", width="uint64", height="uint64" ) if self.include_articles: dtypes["article"] = dict(url="str", title="str", content="str") # Create conversion function for missing values def fill_na_values(dtype: Union[str, dict]): if dtype == "uint64": return 0 elif dtype == "bool": return False elif dtype == dict(created_at="datetime64[ns]"): return np.datetime64("NaT") elif dtype == "category": return "NaN" elif dtype == "str": return "" else: return np.nan # Loop over all nodes for ntype, dtype_dict in dtypes.items(): if ntype in self.nodes.keys(): # Set the dtypes for non-numpy columns dtype_dict_no_numpy = { col: dtype for col, dtype in dtype_dict.items() if not isinstance(dtype, str) or not dtype.startswith("numpy") } for col, dtype in dtype_dict_no_numpy.items(): if col in self.nodes[ntype].columns: # Fill NaN values with canonical values in accordance with the # datatype self.nodes[ntype][col].fillna( fill_na_values(dtype), inplace=True ) # Set the dtype self.nodes[ntype][col] = self.nodes[ntype][col].astype(dtype) # For numpy columns, set the type manually def numpy_fn(x, dtype: str): return np.asarray(x, dtype=dtype) for col, dtype in dtype_dict.items(): if ( isinstance(dtype, str) and dtype.startswith("numpy") and col in self.nodes[ntype].columns ): # Fill NaN values with canonical values in accordance with the # datatype self.nodes[ntype][col].fillna( fill_na_values(dtype), inplace=True ) # Extract the NumPy datatype numpy_dtype = dtype.split(":")[-1] # Tweak the numpy function to include the datatype fn = partial(numpy_fn, dtype=numpy_dtype) # Set the dtype self.nodes[ntype][col] = self.nodes[ntype][col].map(fn) def _remove_auxilliaries(self): """Removes node types that are not in use anymore""" # 
Remove auxilliary node types nodes_to_remove = [ node_type for node_type in self.nodes.keys() if node_type not in self._node_dump ] for node_type in nodes_to_remove: self.nodes.pop(node_type) # Remove auxilliary relation types rels_to_remove = [ rel_type for rel_type in self.rels.keys() if rel_type not in self._rel_dump ] for rel_type in rels_to_remove: self.rels.pop(rel_type) return self def _remove_islands(self): """Removes nodes and relations that are not connected to anything""" # Loop over all the node types for node_type, node_df in self.nodes.items(): # For each node type, loop over all the relations, to see what nodes of # that node type does not appear in any of the relations for rel_type, rel_df in self.rels.items(): src, _, tgt = rel_type # If the node is the source of the relation if node_type == src: # Store all the nodes connected to the relation (or any of the # previously checked relations) connected = node_df.index.isin(rel_df.src.tolist()) if "connected" in node_df.columns: connected = node_df.connected | connected node_df["connected"] = connected # If the node is the source of the relation if node_type == tgt: # Store all the nodes connected to the relation (or any of the # previously checked relations) connected = node_df.index.isin(rel_df.tgt.tolist()) if "connected" in node_df.columns: connected = node_df.connected | connected node_df["connected"] = connected # Filter the node dataframe to only keep the connected ones if "connected" in node_df.columns and "index" not in node_df.columns: self.nodes[node_type] = ( node_df.query("connected == True") .drop(columns="connected") .reset_index() ) # Update the relevant relations for rel_type, rel_df in self.rels.items(): src, _, tgt = rel_type # If islands have been removed from the source, then update those # indices if node_type == src and "index" in self.nodes[node_type]: node_df = ( self.nodes[node_type] .rename(columns=dict(index="old_idx")) .reset_index() ) rel_df = ( rel_df.merge( node_df[["index", "old_idx"]], left_on="src", right_on="old_idx", ) .drop(columns=["src", "old_idx"]) .rename(columns=dict(index="src")) ) self.rels[rel_type] = rel_df[["src", "tgt"]] self.nodes[node_type] = self.nodes[node_type] # If islands have been removed from the target, then update those # indices if node_type == tgt and "index" in self.nodes[node_type]: node_df = ( self.nodes[node_type] .rename(columns=dict(index="old_idx")) .reset_index() ) rel_df = ( rel_df.merge( node_df[["index", "old_idx"]], left_on="tgt", right_on="old_idx", ) .drop(columns=["tgt", "old_idx"]) .rename(columns=dict(index="tgt")) ) self.rels[rel_type] = rel_df[["src", "tgt"]] self.nodes[node_type] = self.nodes[node_type] if "index" in self.nodes[node_type]: self.nodes[node_type] = self.nodes[node_type].drop(columns="index") return self def _dump_dataset(self): """Dumps the dataset to a zip file""" logger.info("Dumping dataset") # Create a temporary pickle folder temp_pickle_folder = Path("temp_pickle_folder") if not temp_pickle_folder.exists(): temp_pickle_folder.mkdir() # Make temporary pickle list pickle_list = list() # Create progress bar total = len(self._node_dump) + len(self._rel_dump) + 1 pbar = tqdm(total=total) # Store the nodes for node_type in self._node_dump: pbar.set_description(f"Storing {node_type} nodes") if node_type in self.nodes.keys(): pickle_list.append(node_type) pickle_path = temp_pickle_folder / f"{node_type}.pickle" self.nodes[node_type].to_pickle(pickle_path, protocol=4) pbar.update() # Store the relations for rel_type in self._rel_dump: 
pbar.set_description(f"Storing {rel_type} relations") if rel_type in self.rels.keys(): name = "_".join(rel_type) pickle_list.append(name) pickle_path = temp_pickle_folder / f"{name}.pickle" self.rels[rel_type].to_pickle(pickle_path, protocol=4) pbar.update() # Zip the nodes and relations, and save the zip file with zipfile.ZipFile( self.dataset_path, mode="w", compression=zipfile.ZIP_STORED ) as zip_file: pbar.set_description("Dumping dataset") for name in pickle_list: fname = f"{name}.pickle" zip_file.write(temp_pickle_folder / fname, arcname=fname) # Remove the temporary pickle folder rmtree(temp_pickle_folder) # Final progress bar update and close it pbar.update() pbar.close() return self def to_dgl(self): """Convert the dataset to a DGL dataset. Returns: DGLHeteroGraph: The graph in DGL format. """ logger.info("Outputting to DGL") return build_dgl_dataset(nodes=self.nodes, relations=self.rels) def __repr__(self) -> str: """A string representation of the dataset. Returns: str: The representation of the dataset. """ bearer_token_available = self._twitter is not None if len(self.nodes) == 0 or len(self.rels) == 0: return ( f"MuminDataset(size={self.size}, " f"rehydrated={self.rehydrated}, " f"compiled={self.compiled}, " f"bearer_token_available={bearer_token_available})" ) else: num_nodes = sum([len(df) for df in self.nodes.values()]) num_rels = sum([len(df) for df in self.rels.values()]) return ( f"MuminDataset(num_nodes={num_nodes:,}, " f"num_relations={num_rels:,}, " f"size='{self.size}', " f"rehydrated={self.rehydrated}, " f"compiled={self.compiled}, " f"bearer_token_available={bearer_token_available})" )
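The _dump_dataset method above stores every node and relation DataFrame as a pickle inside a single uncompressed zip archive at dataset_path. As a hedged illustration (not part of the library itself), such an archive could be inspected directly with pandas; the archive name below assumes the default path for size='small', and the member names follow the '<node_type>.pickle' / '<src>_<relation>_<tgt>.pickle' scheme used in the code above.

```python
import io
import zipfile

import pandas as pd

# A minimal sketch of inspecting a dataset dump produced by _dump_dataset.
# "mumin-small.zip" is the default dataset_path for size='small'; the actual
# members depend on which node and relation types were compiled.
with zipfile.ZipFile("mumin-small.zip", mode="r") as zip_file:
    for fname in zip_file.namelist():
        if not fname.endswith(".pickle"):
            continue
        with zip_file.open(fname) as f:
            df = pd.read_pickle(io.BytesIO(f.read()))
        print(fname, df.shape)
```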
The MuMiN misinformation dataset, from [1].
Args
- twitter_bearer_token (str or None, optional): The Twitter bearer token. If None then the bearer token must be stored in the environment variable TWITTER_API_KEY, or placed in a file named .env in the working directory, formatted as "TWITTER_API_KEY=xxxxx". Defaults to None.
- size (str, optional): The size of the dataset. Can be either 'small', 'medium' or 'large'. Defaults to 'small'.
- include_replies (bool, optional): Whether to include replies and quote tweets in the dataset. Defaults to True.
- include_articles (bool, optional): Whether to include articles in the dataset. This will mean that compilation of the dataset will take a bit longer, as these need to be downloaded and parsed. Defaults to True.
- include_tweet_images (bool, optional): Whether to include images from the tweets in the dataset. This will mean that compilation of the dataset will take a bit longer, as these need to be downloaded and parsed. Defaults to True.
- include_extra_images (bool, optional): Whether to include images from the articles and users in the dataset. This will mean that compilation of the dataset will take a bit longer, as these need to be downloaded and parsed. Defaults to False.
- include_hashtags (bool, optional): Whether to include hashtags in the dataset. Defaults to True.
- include_mentions (bool, optional): Whether to include mentions in the dataset. Defaults to True.
- include_timelines (bool, optional): Whether to include timelines in the dataset. Defaults to False.
- text_embedding_model_id (str, optional): The HuggingFace Hub model ID to use when embedding texts. Defaults to 'xlm-roberta-base'.
- image_embedding_model_id (str, optional): The HuggingFace Hub model ID to use when embedding images. Defaults to 'google/vit-base-patch16-224-in21k'.
- dataset_path (str, pathlib Path or None, optional): The path to the file where the dataset should be stored. If None then the dataset will be stored at './mumin-<size>.zip'. Defaults to None.
- n_jobs (int, optional): The number of jobs to use for parallel processing. Defaults to the number of available CPU cores minus one.
- chunksize (int, optional): The number of articles/images to process in each job. This speeds up processing time, but also increases memory load. Defaults to 10.
- verbose (bool, optional): Whether extra information should be outputted. Defaults to True.
Attributes
- include_replies (bool): Whether to include replies.
- include_articles (bool): Whether to include articles.
- include_tweet_images (bool): Whether to include tweet images.
- include_extra_images (bool): Whether to include user/article images.
- include_hashtags (bool): Whether to include hashtags.
- include_mentions (bool): Whether to include mentions.
- include_timelines (bool): Whether to include timelines.
- size (str): The size of the dataset.
- dataset_path (pathlib Path): The dataset file.
- text_embedding_model_id (str): The model ID used for embedding text.
- image_embedding_model_id (str): The model ID used for embedding images.
- nodes (dict): The nodes of the dataset.
- rels (dict): The relations of the dataset.
- rehydrated (bool): Whether the tweets and/or replies have been rehydrated.
- compiled (bool): Whether the dataset has been compiled.
- n_jobs (int): The number of jobs to use for parallel processing.
- chunksize (int): The number of articles/images to process in each job.
- verbose (bool): Whether extra information should be outputted.
- download_url (str): The URL to download the dataset from.
Raises
- ValueError: If twitter_bearer_token is None and the environment variable TWITTER_API_KEY is not set.
References
- [1] Nielsen and McConville: MuMiN: A Large-Scale Multilingual Multimodal Fact-Checked Misinformation Dataset with Linked Social Network Posts (2021)
MuminDataset(
twitter_bearer_token: Union[str, NoneType] = None,
size: str = 'small',
include_replies: bool = True,
include_articles: bool = True,
include_tweet_images: bool = True,
include_extra_images: bool = False,
include_hashtags: bool = True,
include_mentions: bool = True,
include_timelines: bool = False,
text_embedding_model_id: str = 'xlm-roberta-base',
image_embedding_model_id: str = 'google/vit-base-patch16-224-in21k',
dataset_path: Union[str, pathlib.Path, NoneType] = None,
n_jobs: int = 1,
chunksize: int = 10,
verbose: bool = True
)
View Source
def __init__( self, twitter_bearer_token: Optional[str] = None, size: str = "small", include_replies: bool = True, include_articles: bool = True, include_tweet_images: bool = True, include_extra_images: bool = False, include_hashtags: bool = True, include_mentions: bool = True, include_timelines: bool = False, text_embedding_model_id: str = "xlm-roberta-base", image_embedding_model_id: str = "google/vit-base-patch16-224-in21k", dataset_path: Optional[Union[str, Path]] = None, n_jobs: int = mp.cpu_count() - 1, chunksize: int = 10, verbose: bool = True, ): self.size = size self.include_replies = include_replies self.include_articles = include_articles self.include_tweet_images = include_tweet_images self.include_extra_images = include_extra_images self.include_hashtags = include_hashtags self.include_mentions = include_mentions self.include_timelines = include_timelines self.text_embedding_model_id = text_embedding_model_id self.image_embedding_model_id = image_embedding_model_id self.verbose = verbose self.compiled: bool = False self.rehydrated: bool = False self.nodes: Dict[str, pd.DataFrame] = dict() self.rels: Dict[Tuple[str, str, str], pd.DataFrame] = dict() # Load the bearer token if it is not provided if twitter_bearer_token is None: twitter_bearer_token = os.environ.get("TWITTER_API_KEY") # If no bearer token is available, raise a warning and set the `_twitter` # attribute to None. Otherwise, set the `_twitter` attribute to a `Twitter` # instance. self._twitter: Twitter if twitter_bearer_token is None: warnings.warn( "Twitter bearer token not provided, so rehydration can not be " "performed. This is fine if you are using a pre-compiled MuMiN, but " "if this is not the case then you will need to either specify the " "`twitter_bearer_token` argument or set the environment variable " "`TWITTER_API_KEY`." ) else: self._twitter = Twitter(twitter_bearer_token=twitter_bearer_token) self._extractor = DataExtractor( include_replies=include_replies, include_articles=include_articles, include_tweet_images=include_tweet_images, include_extra_images=include_extra_images, include_hashtags=include_hashtags, include_mentions=include_mentions, n_jobs=n_jobs, chunksize=chunksize, ) self._updator = IdUpdator() self._embedder = Embedder( text_embedding_model_id=text_embedding_model_id, image_embedding_model_id=image_embedding_model_id, include_articles=include_articles, include_tweet_images=include_tweet_images, include_extra_images=include_extra_images, ) if dataset_path is None: dataset_path = f"./mumin-{size}.zip" self.dataset_path = Path(dataset_path) # Set up logging verbosity if self.verbose: logger.setLevel(logging.INFO) else: logger.setLevel(logging.WARNING)
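As a usage sketch (not taken from the source), constructing the dataset object might look like the following; the bearer token value is a placeholder, and it can instead be supplied through the TWITTER_API_KEY environment variable or a .env file, as described in the Args section above.

```python
from mumin.dataset import MuminDataset

# A minimal sketch; "<your-twitter-bearer-token>" is a placeholder and the
# remaining keyword arguments simply restate their documented defaults.
dataset = MuminDataset(
    twitter_bearer_token="<your-twitter-bearer-token>",
    size="small",
    include_articles=True,
    include_tweet_images=True,
)
print(dataset)  # -> MuminDataset(size=small, rehydrated=False, compiled=False, ...)
```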
download_url: str = 'https://data.bris.ac.uk/datasets/23yv276we2mll25fjakkfim2ml/23yv276we2mll25fjakkfim2ml.zip'
View Source
def compile(self, overwrite: bool = False): """Compiles the dataset. This entails downloading the dataset, rehydrating the Twitter data and downloading the relevant associated data, such as articles and images. Args: overwrite (bool, optional): Whether the dataset directory should be overwritten, in case it already exists. Defaults to False. Raises: RuntimeError: If the dataset needs to be compiled and a Twitter bearer token has not been provided. """ self._download(overwrite=overwrite) self._load_dataset() # Variable to check if dataset has been rehydrated and/or compiled if "text" in self.nodes["tweet"].columns: self.rehydrated = True if self.rehydrated and str(self.nodes["tweet"].tweet_id.dtype) == "uint64": self.compiled = True # Only compile the dataset if it has not already been compiled if not self.compiled: # If the dataset has not already been rehydrated, rehydrate it if not self.rehydrated: # Shrink dataset to the correct size self._shrink_dataset() # If the bearer token is not available then raise an error if not isinstance(self._twitter, Twitter): raise RuntimeError( "Twitter bearer token not provided. You need to either specify " "the `twitter_bearer_token` argument in the `MuminDataset` " "constructor or set the environment variable `TWITTER_API_KEY`." ) # Rehydrate the tweets self._rehydrate(node_type="tweet") self._rehydrate(node_type="reply") # Update the IDs of the data that was there pre-hydration self.nodes, self.rels = self._updator.update_all( nodes=self.nodes, rels=self.rels ) # Save dataset self._dump_dataset() # Set the rehydrated flag to True self.rehydrated = True # Extract data from the rehydrated tweets self.nodes, self.rels = self._extractor.extract_all( nodes=self.nodes, rels=self.rels ) # Filter the data self._filter_node_features() self._filter_relations() # Set datatypes self._set_datatypes() # Remove unnecessary bits self._remove_auxilliaries() self._remove_islands() # Save dataset if not self.compiled: self._dump_dataset() # Mark dataset as compiled self.compiled = True return self
Compiles the dataset.
This entails downloading the dataset, rehydrating the Twitter data and downloading the relevant associated data, such as articles and images.
Args
- overwrite (bool, optional): Whether the dataset directory should be overwritten, in case it already exists. Defaults to False.
Raises
- RuntimeError: If the dataset needs to be compiled and a Twitter bearer token has not been provided.
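A hedged sketch of a typical compile call, assuming a bearer token is available; compilation downloads the raw dataset from download_url, rehydrates the tweets and replies, extracts the associated data and dumps the result to dataset_path.

```python
from mumin.dataset import MuminDataset

# A minimal sketch; the bearer token value is a placeholder.
dataset = MuminDataset(
    twitter_bearer_token="<your-twitter-bearer-token>", size="small"
)
dataset.compile()  # pass overwrite=True to overwrite an existing dataset file
```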
def add_embeddings(
self,
nodes_to_embed: List[str] = ['tweet', 'reply', 'user', 'claim', 'article', 'image']
):
View Source
def add_embeddings( self, nodes_to_embed: List[str] = [ "tweet", "reply", "user", "claim", "article", "image", ], ): """Computes, stores and dumps embeddings of node features. Args: nodes_to_embed (list of str): The node types which needs to be embedded. If a node type does not exist in the graph it will be ignored. Defaults to ['tweet', 'reply', 'user', 'claim', 'article', 'image']. """ # Compute the embeddings self.nodes, embeddings_added = self._embedder.embed_all( nodes=self.nodes, nodes_to_embed=nodes_to_embed ) # Store dataset if any embeddings were added if embeddings_added: self._dump_dataset() return self
Computes, stores and dumps embeddings of node features.
Args
- nodes_to_embed (list of str): The node types which need to be embedded. If a node type does not exist in the graph it will be ignored. Defaults to ['tweet', 'reply', 'user', 'claim', 'article', 'image'].
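A hedged sketch of adding embeddings after compilation; the chosen node types are illustrative, and any type not present in the graph is ignored, as noted above.

```python
from mumin.dataset import MuminDataset

# A minimal sketch; embeddings are computed with the configured HuggingFace
# models (text_embedding_model_id / image_embedding_model_id) and the dataset
# is dumped again if any embeddings were added.
dataset = MuminDataset(
    twitter_bearer_token="<your-twitter-bearer-token>", size="small"
)
dataset.compile()
dataset.add_embeddings(nodes_to_embed=["tweet", "claim"])
```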
View Source
def to_dgl(self): """Convert the dataset to a DGL dataset. Returns: DGLHeteroGraph: The graph in DGL format. """ logger.info("Outputting to DGL") return build_dgl_dataset(nodes=self.nodes, relations=self.rels)
Convert the dataset to a DGL dataset.
Returns
DGLHeteroGraph: The graph in DGL format.
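A hedged end-to-end sketch of exporting the compiled dataset to DGL; this assumes the dgl package is installed, and the type listing at the end is just one way to check the resulting heterogeneous graph.

```python
from mumin.dataset import MuminDataset

# A minimal sketch; the bearer token value is a placeholder.
dataset = MuminDataset(
    twitter_bearer_token="<your-twitter-bearer-token>", size="small"
)
dataset.compile()
dataset.add_embeddings()
graph = dataset.to_dgl()

# Inspect the node and (canonical) edge types of the heterogeneous graph
print(graph.ntypes)
print(graph.canonical_etypes)
```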