|
- """Heterograph NN modules"""
- import torch as th
- import torch.nn as nn
-
- __all__ = ['HeteroGraphConv']
-
class HeteroGraphConv(nn.Module):
    r"""A generic module for computing convolution on heterogeneous graphs.

    The heterograph convolution applies sub-modules on their associating
    relation graphs, which reads the features from source nodes and writes the
    updated ones to destination nodes. If multiple relations have the same
    destination node types, their results are aggregated by the specified method.

    If the relation graph has no edge, the corresponding module will not be called.

    Parameters
    ----------
    mods : dict[str, nn.Module]
        Modules associated with every edge types. The forward function of each
        module must have a `DGLHeteroGraph` object as the first argument, and
        its second argument is either a tensor object representing the node
        features or a pair of tensor object representing the source and destination
        node features.
    aggregate : str, callable, optional
        Method for aggregating node features generated by different relations.
        Allowed string values are 'sum', 'max', 'min', 'mean', 'stack'.
        The 'stack' aggregation is performed along the second dimension, whose order
        is deterministic.
        User can also customize the aggregator by providing a callable instance.
        For example, aggregation by summation is equivalent to the follows:

        .. code::

            def my_agg_func(tensors, dsttype):
                # tensors: is a list of tensors to aggregate
                # dsttype: string name of the destination node type for which the
                # aggregation is performed
                stacked = torch.stack(tensors, dim=0)
                return torch.sum(stacked, dim=0)

    Attributes
    ----------
    mods : dict[str, nn.Module]
        Modules associated with every edge types.
    """
    def __init__(self, mods, aggregate='sum'):
        super(HeteroGraphConv, self).__init__()
        self.mods = nn.ModuleDict(mods)
        # Do not break if graph has 0-in-degree nodes.
        # Because there is no general rule to add self-loop for heterograph.
        for _, v in self.mods.items():
            set_allow_zero_in_degree_fn = getattr(v, 'set_allow_zero_in_degree', None)
            if callable(set_allow_zero_in_degree_fn):
                set_allow_zero_in_degree_fn(True)
        # Resolve the cross-relation aggregator once at construction time.
        # A string selects one of the built-in reducers; a callable is used
        # as-is and must accept ``(list_of_tensors, dsttype)``.
        if isinstance(aggregate, str):
            self.agg_fn = self._get_aggregate_fn(aggregate)
        else:
            self.agg_fn = aggregate

    @staticmethod
    def _get_aggregate_fn(agg):
        """Return a callable ``(tensors, dsttype) -> Tensor`` for the named aggregator.

        Parameters
        ----------
        agg : str
            One of 'sum', 'max', 'min', 'mean', 'stack'.

        Raises
        ------
        ValueError
            If ``agg`` is not a supported aggregator name.
        """
        if agg == 'stack':
            # Stack along a new second dimension. The order of ``tensors``
            # follows the canonical edge type order, so it is deterministic.
            return lambda tensors, dsttype: th.stack(tensors, dim=1)
        # ``th.max``/``th.min`` with ``dim=`` return (values, indices); keep values.
        reducers = {
            'sum': lambda stacked: th.sum(stacked, dim=0),
            'max': lambda stacked: th.max(stacked, dim=0)[0],
            'min': lambda stacked: th.min(stacked, dim=0)[0],
            'mean': lambda stacked: th.mean(stacked, dim=0),
        }
        if agg not in reducers:
            raise ValueError(
                "Invalid cross type aggregator. Must be one of "
                "'sum', 'max', 'min', 'mean' or 'stack'. But got '%s'" % agg)
        reducer = reducers[agg]
        return lambda tensors, dsttype: reducer(th.stack(tensors, dim=0))

    def forward(self, g, inputs, mod_args=None, mod_kwargs=None):
        """Forward computation

        Invoke the forward function with each module and aggregate their results.

        Parameters
        ----------
        g : DGLHeteroGraph
            Graph data.
        inputs : dict[str, Tensor] or pair of dict[str, Tensor]
            Input node features.
        mod_args : dict[str, tuple[any]], optional
            Extra positional arguments for the sub-modules.
        mod_kwargs : dict[str, dict[str, any]], optional
            Extra key-word arguments for the sub-modules.

        Returns
        -------
        dict[str, Tensor]
            Output representations for every types of nodes. Destination node
            types that received no message (no relation with edges and known
            source features) are omitted from the result.
        """
        if mod_args is None:
            mod_args = {}
        if mod_kwargs is None:
            mod_kwargs = {}
        # Collect per-relation results, keyed by destination node type.
        outputs = {nty: [] for nty in g.dsttypes}
        if isinstance(inputs, tuple) or g.is_block:
            if isinstance(inputs, tuple):
                src_inputs, dst_inputs = inputs
            else:
                # For a block, destination nodes are a prefix of source nodes.
                src_inputs = inputs
                dst_inputs = {k: v[:g.number_of_dst_nodes(k)] for k, v in inputs.items()}

            for stype, etype, dtype in g.canonical_etypes:
                rel_graph = g[stype, etype, dtype]
                if rel_graph.number_of_edges() == 0:
                    continue
                if stype not in src_inputs or dtype not in dst_inputs:
                    continue
                dstdata = self.mods[etype](
                    rel_graph,
                    (src_inputs[stype], dst_inputs[dtype]),
                    *mod_args.get(etype, ()),
                    **mod_kwargs.get(etype, {}))
                outputs[dtype].append(dstdata)
        else:
            for stype, etype, dtype in g.canonical_etypes:
                rel_graph = g[stype, etype, dtype]
                if rel_graph.number_of_edges() == 0:
                    continue
                if stype not in inputs:
                    continue
                dstdata = self.mods[etype](
                    rel_graph,
                    (inputs[stype], inputs[dtype]),
                    *mod_args.get(etype, ()),
                    **mod_kwargs.get(etype, {}))
                outputs[dtype].append(dstdata)
        # Aggregate the per-relation results for each destination type, as
        # documented above; types with no results are left out entirely.
        rsts = {}
        for nty, alist in outputs.items():
            if len(alist) != 0:
                rsts[nty] = self.agg_fn(alist, nty)
        return rsts
|