Merge tag 'upstream/0.17.0' into debian
Upstream version 0.17.0
Erik Johnston
0 | Changes in synapse v0.17.0 (2016-08-08) | |
1 | ======================================= | |
2 | ||
3 | This release contains significant security bug fixes regarding authenticating | |
4 | events received over federation. PLEASE UPGRADE. | |
5 | ||
6 | This release changes the LDAP configuration format in a backwards incompatible | |
7 | way, see PR #843 for details. | |
8 | ||
9 | ||
10 | Changes: | |
11 | ||
12 | * Add federation /version API (PR #990) | |
13 | * Make psutil dependency optional (PR #992) | |
14 | ||
15 | ||
16 | Bug fixes: | |
17 | ||
18 | * Fix URL preview API to exclude HTML comments in description (PR #988) | |
19 | * Fix error handling of remote joins (PR #991) | |
20 | ||
21 | ||
22 | Changes in synapse v0.17.0-rc4 (2016-08-05) | |
23 | =========================================== | |
24 | ||
25 | Changes: | |
26 | ||
27 | * Change the way we summarize URLs when previewing (PR #973) | |
28 | * Add new ``/state_ids/`` federation API (PR #979) | |
29 | * Speed up processing of ``/state/`` response (PR #986) | |
30 | ||
31 | Bug fixes: | |
32 | ||
33 | * Fix event persistence when event has already been partially persisted | |
34 | (PR #975, #983, #985) | |
35 | * Fix port script to also copy across backfilled events (PR #982) | |
36 | ||
37 | ||
38 | Changes in synapse v0.17.0-rc3 (2016-08-02) | |
39 | =========================================== | |
40 | ||
41 | Changes: | |
42 | ||
43 | * Forbid non-ASes from registering users whose names begin with '_' (PR #958) | |
44 | * Add some basic admin API docs (PR #963) | |
45 | ||
46 | ||
47 | Bug fixes: | |
48 | ||
49 | * Send the correct host header when fetching keys (PR #941) | |
50 | * Fix joining a room that has missing auth events (PR #964) | |
51 | * Fix various push bugs (PR #966, #970) | |
52 | * Fix adding emails on registration (PR #968) | |
53 | ||
54 | ||
55 | Changes in synapse v0.17.0-rc2 (2016-08-02) | |
56 | =========================================== | |
57 | ||
58 | (This release did not include the changes advertised and was identical to RC1) | |
59 | ||
60 | ||
61 | Changes in synapse v0.17.0-rc1 (2016-07-28) | |
62 | =========================================== | |
63 | ||
64 | This release changes the LDAP configuration format in a backwards incompatible | |
65 | way, see PR #843 for details. | |
66 | ||
67 | ||
68 | Features: | |
69 | ||
70 | * Add purge_media_cache admin API (PR #902) | |
71 | * Add deactivate account admin API (PR #903) | |
72 | * Add optional pepper to password hashing (PR #907, #910 by KentShikama) | |
73 | * Add an admin option to shared secret registration (breaks backwards compat) | |
74 | (PR #909) | |
75 | * Add purge local room history API (PR #911, #923, #924) | |
76 | * Add requestToken endpoints (PR #915) | |
77 | * Add an /account/deactivate endpoint (PR #921) | |
78 | * Add filter param to /messages. Add 'contains_url' to filter. (PR #922) | |
79 | * Add device_id support to /login (PR #929) | |
80 | * Add device_id support to /v2/register flow. (PR #937, #942) | |
81 | * Add GET /devices endpoint (PR #939, #944) | |
82 | * Add GET /device/{deviceId} (PR #943) | |
83 | * Add update and delete APIs for devices (PR #949) | |
84 | ||
85 | ||
86 | Changes: | |
87 | ||
88 | * Rewrite LDAP Authentication against ldap3 (PR #843 by mweinelt) | |
89 | * Linearize some federation endpoints based on (origin, room_id) (PR #879) | |
90 | * Remove the legacy v0 content upload API. (PR #888) | |
91 | * Use similar naming we use in email notifs for push (PR #894) | |
92 | * Optionally include password hash in createUser endpoint (PR #905 by | |
93 | KentShikama) | |
94 | * Use a query that postgresql optimises better for get_events_around (PR #906) | |
95 | * Fall back to 'username' if 'user' is not given for appservice registration. | |
96 | (PR #927 by Half-Shot) | |
97 | * Add metrics for psutil derived memory usage (PR #936) | |
98 | * Record device_id in client_ips (PR #938) | |
99 | * Send the correct host header when fetching keys (PR #941) | |
100 | * Log the hostname the reCAPTCHA was completed on (PR #946) | |
101 | * Make the device id on e2e key upload optional (PR #956) | |
102 | * Add r0.2.0 to the "supported versions" list (PR #960) | |
103 | * Don't include name of room for invites in push (PR #961) | |
104 | ||
105 | ||
106 | Bug fixes: | |
107 | ||
108 | * Fix substitution failure in mail template (PR #887) | |
109 | * Put most recent 20 messages in email notif (PR #892) | |
110 | * Ensure that the guest user is in the database when upgrading accounts | |
111 | (PR #914) | |
112 | * Fix various edge cases in auth handling (PR #919) | |
113 | * Fix 500 ISE when sending alias event without a state_key (PR #925) | |
114 | * Fix bug where we stored rejections in the state_group, persist all | |
115 | rejections (PR #948) | |
116 | * Fix lack of check of if the user is banned when handling 3pid invites | |
117 | (PR #952) | |
118 | * Fix a couple of bugs in the transaction and keyring code (PR #954, #955) | |
119 | ||
120 | ||
121 | ||
0 | 122 | Changes in synapse v0.16.1-r1 (2016-07-08) |
1 | 123 | ========================================== |
2 | 124 |
13 | 13 | recursive-include res * |
14 | 14 | recursive-include scripts * |
15 | 15 | recursive-include scripts-dev * |
16 | recursive-include synapse *.pyi | |
16 | 17 | recursive-include tests *.py |
17 | 18 | |
18 | 19 | recursive-include synapse/static *.css |
22 | 23 | |
23 | 24 | exclude jenkins.sh |
24 | 25 | exclude jenkins*.sh |
26 | exclude jenkins* | |
27 | recursive-exclude jenkins *.sh | |
25 | 28 | |
26 | 29 | prune demo/etc |
10 | 10 | like ``#matrix:matrix.org`` or ``#test:localhost:8448``. |
11 | 11 | |
12 | 12 | - Matrix user IDs look like ``@matthew:matrix.org`` (although in the future |
13 | you will normally refer to yourself and others using a 3PID: email | |
14 | address, phone number, etc rather than manipulating Matrix user IDs) | |
13 | you will normally refer to yourself and others using a third party identifier | |
14 | (3PID): email address, phone number, etc rather than manipulating Matrix user IDs) | |
15 | 15 | |
16 | 16 | The overall architecture is:: |
17 | 17 | |
444 | 444 | IDs: |
445 | 445 | |
446 | 446 | 1) Use the machine's own hostname as available on public DNS in the form of |
447 | its A or AAAA records. This is easier to set up initially, perhaps for | |
447 | its A records. This is easier to set up initially, perhaps for | |
448 | 448 | testing, but lacks the flexibility of SRV. |
449 | 449 | |
450 | 450 | 2) Set up a SRV record for your domain name. This requires you create a SRV |
26 | 26 | # Pull the latest version of the master branch. |
27 | 27 | git pull |
28 | 28 | # Update the versions of synapse's python dependencies. |
29 | python synapse/python_dependencies.py | xargs -n1 pip install | |
29 | python synapse/python_dependencies.py | xargs -n1 pip install --upgrade | |
30 | 30 | |
31 | 31 | |
32 | 32 | Upgrading to v0.15.0 |
0 | Admin APIs | |
1 | ========== | |
2 | ||
3 | This directory includes documentation for the various synapse-specific admin | |
4 | APIs available. | |
5 | ||
6 | Only users that are server admins can use these APIs. A user can be marked as a | |
7 | server admin by updating the database directly, e.g.: | |
8 | ||
9 | ``UPDATE users SET admin = 1 WHERE name = '@foo:bar.com'`` | |
10 | ||
11 | Restarting may be required for the changes to register. |
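
As a rough sketch (assuming a SQLite-backed homeserver whose database lives at
``homeserver.db``; adapt the path and user ID), the same update can be made
from Python::

    import sqlite3

    conn = sqlite3.connect("homeserver.db")
    # The users table stores the admin flag as 0/1.
    conn.execute("UPDATE users SET admin = 1 WHERE name = ?", ("@foo:bar.com",))
    conn.commit()
    conn.close()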
0 | Purge History API | |
1 | ================= | |
2 | ||
3 | The purge history API allows server admins to purge historic events from their | |
4 | database, reclaiming disk space. | |
5 | ||
6 | Depending on the amount of history being purged, a call to the API may take | |
7 | several minutes or longer. During this period users will not be able to | |
8 | paginate further back in the room from the point being purged from. | |
9 | ||
10 | The API is simply: | |
11 | ||
12 | ``POST /_matrix/client/r0/admin/purge_history/<room_id>/<event_id>`` | |
13 | ||
14 | including an ``access_token`` of a server admin. |
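
As a minimal sketch using the Python ``requests`` library (the server URL,
room ID, event ID and token below are placeholders)::

    import requests

    resp = requests.post(
        "https://localhost:8448/_matrix/client/r0/admin/purge_history"
        "/%s/%s" % ("!roomid:example.com", "$eventid:example.com"),
        params={"access_token": "ADMIN_ACCESS_TOKEN"},
        verify=False,  # self-signed certificates are common on homeservers
    )
    resp.raise_for_status()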
0 | Purge Remote Media API | |
1 | ====================== | |
2 | ||
3 | The purge remote media API allows server admins to purge old cached remote | |
4 | media. | |
5 | ||
6 | The API is:: | |
7 | ||
8 | POST /_matrix/client/r0/admin/purge_media_cache | |
9 | ||
10 | { | |
11 | "before_ts": <unix_timestamp_in_ms> | |
12 | } | |
13 | ||
14 | This will remove all cached media that was last accessed before | |
15 | ``<unix_timestamp_in_ms>``. | |
16 | ||
17 | If the user re-requests purged remote media, synapse will re-request the media | |
18 | from the originating server. |
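
A similar hedged sketch for this endpoint (URL and token are placeholders)::

    import json
    import time

    import requests

    # Purge remote media that has not been accessed for 30 days.
    before_ts = int(time.time() * 1000) - 30 * 24 * 60 * 60 * 1000
    resp = requests.post(
        "https://localhost:8448/_matrix/client/r0/admin/purge_media_cache",
        params={"access_token": "ADMIN_ACCESS_TOKEN"},
        data=json.dumps({"before_ts": before_ts}),
        verify=False,
    )
    resp.raise_for_status()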
42 | 42 | together, or want to deliberately extend or preserve vertical/horizontal |
43 | 43 | space) |
44 | 44 | |
45 | Comments should follow the google code style. This is so that we can generate | |
46 | documentation with sphinx (http://sphinxcontrib-napoleon.readthedocs.org/en/latest/) | |
45 | Comments should follow the `google code style <http://google.github.io/styleguide/pyguide.html?showone=Comments#Comments>`_. | |
46 | This is so that we can generate documentation with | |
47 | `sphinx <http://sphinxcontrib-napoleon.readthedocs.org/en/latest/>`_. See the | |
48 | `examples <http://sphinxcontrib-napoleon.readthedocs.io/en/latest/example_google.html>`_ | |
49 | in the sphinx documentation. | |
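
For instance, a Google-style docstring on a hypothetical helper looks like::

    def get_device(user_id, device_id):
        """Looks up a single device for a user.

        Args:
            user_id (str): fully-qualified Matrix user ID.
            device_id (str): identifier of the device to fetch.

        Returns:
            dict: the device's metadata.

        Raises:
            NotFoundError: if the device is unknown.
        """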
47 | 50 | |
48 | 51 | Code should pass pep8 --max-line-length=100 without any warnings. |
8 | 8 | server through the use of a secret shared between the Home Server and the |
9 | 9 | TURN server. |
10 | 10 | |
11 | This document described how to install coturn | |
12 | (https://code.google.com/p/coturn/) which also supports the TURN REST API, | |
11 | This document describes how to install coturn | |
12 | (https://github.com/coturn/coturn) which also supports the TURN REST API, | |
13 | 13 | and integrate it with synapse. |
14 | 14 | |
15 | 15 | coturn Setup |
16 | 16 | ============ |
17 | 17 | |
18 | You may be able to set up coturn via your package manager, or set it up manually using the usual ``configure, make, make install`` process. | |
19 | ||
18 | 20 | 1. Check out coturn:: |
19 | svn checkout http://coturn.googlecode.com/svn/trunk/ coturn | |
21 | ||
22 | git clone https://github.com/coturn/coturn.git coturn | |
20 | 23 | cd coturn |
21 | 24 | |
22 | 25 | 2. Configure it:: |
26 | ||
23 | 27 | ./configure |
24 | 28 | |
25 | You may need to install libevent2: if so, you should do so | |
29 | You may need to install ``libevent2``: if so, you should do so | |
26 | 30 | in the way recommended by your operating system. |
27 | 31 | You can ignore warnings about lack of database support: a |
28 | 32 | database is unnecessary for this purpose. |
29 | 33 | |
30 | 34 | 3. Build and install it:: |
35 | ||
31 | 36 | make |
32 | 37 | make install |
33 | 38 | |
34 | 4. Make a config file in /etc/turnserver.conf. You can customise | |
35 | a config file from turnserver.conf.default. The relevant | |
39 | 4. Create or edit the config file in ``/etc/turnserver.conf``. The relevant | |
36 | 40 | lines, with example values, are:: |
37 | 41 | |
38 | 42 | lt-cred-mech |
40 | 44 | static-auth-secret=[your secret key here] |
41 | 45 | realm=turn.myserver.org |
42 | 46 | |
43 | See turnserver.conf.default for explanations of the options. | |
47 | See turnserver.conf for explanations of the options. | |
44 | 48 | One way to generate the static-auth-secret is with pwgen:: |
45 | 49 | |
46 | 50 | pwgen -s 64 1 |
53 | 57 | import your private key and certificate. |
54 | 58 | |
55 | 59 | 7. Start the turn server:: |
60 | ||
56 | 61 | bin/turnserver -o |
57 | 62 | |
58 | 63 |
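
With ``static-auth-secret`` set, coturn follows the TURN REST API scheme: the
homeserver mints short-lived credentials by HMAC-ing an expiring username with
the shared secret. A minimal sketch of that derivation (for illustration, not
synapse's actual code)::

    import base64
    import hashlib
    import hmac
    import time

    def turn_credentials(shared_secret, user, ttl=86400):
        # The ephemeral username is "<expiry unix time>:<user>"; the
        # password is the base64 HMAC-SHA1 of that username under the
        # shared secret.
        username = "%d:%s" % (int(time.time()) + ttl, user)
        mac = hmac.new(shared_secret.encode(), username.encode(), hashlib.sha1)
        return username, base64.b64encode(mac.digest())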
0 | #! /bin/bash | |
1 | ||
2 | # This clones a project from github into a named subdirectory | |
3 | # If the project has a branch with the same name as this branch | |
4 | # then it will check out that branch after cloning. | |
5 | # Otherwise it will check out "origin/develop". | |
6 | # The first argument is the name of the directory to check out | |
7 | # the branch into. | |
8 | # The second argument is the URL of the remote repository to checkout. | |
9 | # Usually something like https://github.com/matrix-org/sytest.git | |
10 | ||
11 | set -eux | |
12 | ||
13 | NAME=$1 | |
14 | PROJECT=$2 | |
15 | BASE=".$NAME-base" | |
16 | ||
17 | # Update our mirror. | |
18 | if [ ! -d ".$NAME-base" ]; then | |
19 | # Create a local mirror of the source repository. | |
20 | # This saves us from having to download the entire repository | |
21 | # when this script is next run. | |
22 | git clone "$PROJECT" "$BASE" --mirror | |
23 | else | |
24 | # Fetch any updates from the source repository. | |
25 | (cd "$BASE"; git fetch -p) | |
26 | fi | |
27 | ||
28 | # Remove the existing repository so that we have a clean copy | |
29 | rm -rf "$NAME" | |
30 | # Cloning with --shared means that we will share portions of the | |
31 | # .git directory with our local mirror. | |
32 | git clone "$BASE" "$NAME" --shared | |
33 | ||
34 | # Jenkins may have supplied us with the name of the branch in the | |
35 | # environment. Otherwise we will have to guess based on the current | |
36 | # commit. | |
37 | : ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"} | |
38 | cd "$NAME" | |
39 | # check out the relevant branch | |
40 | git checkout "${GIT_BRANCH}" || ( | |
41 | echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" | |
42 | git checkout "origin/develop" | |
43 | ) |
0 | #! /bin/bash | |
1 | ||
2 | cd "`dirname $0`/.." | |
3 | ||
4 | TOX_DIR=$WORKSPACE/.tox | |
5 | ||
6 | mkdir -p $TOX_DIR | |
7 | ||
8 | if ! [ $TOX_DIR -ef .tox ]; then | |
9 | ln -s "$TOX_DIR" .tox | |
10 | fi | |
11 | ||
12 | # set up the virtualenv | |
13 | tox -e py27 --notest -v | |
14 | ||
15 | TOX_BIN=$TOX_DIR/py27/bin | |
16 | python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install | |
17 | $TOX_BIN/pip install lxml | |
18 | $TOX_BIN/pip install psycopg2 |
3 | 3 | |
4 | 4 | : ${WORKSPACE:="$(pwd)"} |
5 | 5 | |
6 | export WORKSPACE | |
6 | 7 | export PYTHONDONTWRITEBYTECODE=yep |
7 | 8 | export SYNAPSE_CACHE_FACTOR=1 |
8 | 9 | |
9 | # Output test results as junit xml | |
10 | export TRIAL_FLAGS="--reporter=subunit" | |
11 | export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml" | |
12 | # Write coverage reports to a separate file for each process | |
13 | export COVERAGE_OPTS="-p" | |
14 | export DUMP_COVERAGE_COMMAND="coverage help" | |
10 | ./jenkins/prepare_synapse.sh | |
11 | ./jenkins/clone.sh sytest https://github.com/matrix-org/sytest.git | |
12 | ./jenkins/clone.sh dendron https://github.com/matrix-org/dendron.git | |
13 | ./dendron/jenkins/build_dendron.sh | |
14 | ./sytest/jenkins/prep_sytest_for_postgres.sh | |
15 | 15 | |
16 | # Output flake8 violations to violations.flake8.log | |
17 | # Don't exit with non-0 status code on Jenkins, | |
18 | # so that the build steps continue and a later step can decide whether to | |
19 | # UNSTABLE or FAILURE this build. | |
20 | export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?" | |
21 | ||
22 | rm .coverage* || echo "No coverage files to remove" | |
23 | ||
24 | tox --notest -e py27 | |
25 | ||
26 | TOX_BIN=$WORKSPACE/.tox/py27/bin | |
27 | python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install | |
28 | $TOX_BIN/pip install psycopg2 | |
29 | $TOX_BIN/pip install lxml | |
30 | ||
31 | : ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"} | |
32 | ||
33 | if [[ ! -e .dendron-base ]]; then | |
34 | git clone https://github.com/matrix-org/dendron.git .dendron-base --mirror | |
35 | else | |
36 | (cd .dendron-base; git fetch -p) | |
37 | fi | |
38 | ||
39 | rm -rf dendron | |
40 | git clone .dendron-base dendron --shared | |
41 | cd dendron | |
42 | ||
43 | : ${GOPATH:=${WORKSPACE}/.gopath} | |
44 | if [[ "${GOPATH}" != *:* ]]; then | |
45 | mkdir -p "${GOPATH}" | |
46 | export PATH="${GOPATH}/bin:${PATH}" | |
47 | fi | |
48 | export GOPATH | |
49 | ||
50 | git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop) | |
51 | ||
52 | go get github.com/constabulary/gb/... | |
53 | gb generate | |
54 | gb build | |
55 | ||
56 | cd .. | |
57 | ||
58 | ||
59 | if [[ ! -e .sytest-base ]]; then | |
60 | git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror | |
61 | else | |
62 | (cd .sytest-base; git fetch -p) | |
63 | fi | |
64 | ||
65 | rm -rf sytest | |
66 | git clone .sytest-base sytest --shared | |
67 | cd sytest | |
68 | ||
69 | git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop) | |
70 | ||
71 | : ${PORT_BASE:=8000} | |
72 | ||
73 | ./jenkins/prep_sytest_for_postgres.sh | |
74 | ||
75 | mkdir -p var | |
76 | ||
77 | echo >&2 "Running sytest with PostgreSQL"; | |
78 | ./jenkins/install_and_run.sh --python $TOX_BIN/python \ | |
79 | --synapse-directory $WORKSPACE \ | |
80 | --dendron $WORKSPACE/dendron/bin/dendron \ | |
81 | --pusher \ | |
82 | --synchrotron \ | |
83 | --port-base $PORT_BASE | |
84 | ||
85 | cd .. | |
16 | ./sytest/jenkins/install_and_run.sh \ | |
17 | --synapse-directory $WORKSPACE \ | |
18 | --dendron $WORKSPACE/dendron/bin/dendron \ | |
19 | --pusher \ | |
20 | --synchrotron \ | |
21 | --federation-reader \ |
3 | 3 | |
4 | 4 | : ${WORKSPACE:="$(pwd)"} |
5 | 5 | |
6 | export WORKSPACE | |
6 | 7 | export PYTHONDONTWRITEBYTECODE=yep |
7 | 8 | export SYNAPSE_CACHE_FACTOR=1 |
8 | 9 | |
9 | # Output test results as junit xml | |
10 | export TRIAL_FLAGS="--reporter=subunit" | |
11 | export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml" | |
12 | # Write coverage reports to a separate file for each process | |
13 | export COVERAGE_OPTS="-p" | |
14 | export DUMP_COVERAGE_COMMAND="coverage help" | |
10 | ./jenkins/prepare_synapse.sh | |
11 | ./jenkins/clone.sh sytest https://github.com/matrix-org/sytest.git | |
15 | 12 | |
16 | # Output flake8 violations to violations.flake8.log | |
17 | # Don't exit with non-0 status code on Jenkins, | |
18 | # so that the build steps continue and a later step can decide whether to | |
19 | # UNSTABLE or FAILURE this build. | |
20 | export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?" | |
13 | ./sytest/jenkins/prep_sytest_for_postgres.sh | |
21 | 14 | |
22 | rm .coverage* || echo "No coverage files to remove" | |
23 | ||
24 | tox --notest -e py27 | |
25 | ||
26 | TOX_BIN=$WORKSPACE/.tox/py27/bin | |
27 | python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install | |
28 | $TOX_BIN/pip install psycopg2 | |
29 | $TOX_BIN/pip install lxml | |
30 | ||
31 | : ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"} | |
32 | ||
33 | if [[ ! -e .sytest-base ]]; then | |
34 | git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror | |
35 | else | |
36 | (cd .sytest-base; git fetch -p) | |
37 | fi | |
38 | ||
39 | rm -rf sytest | |
40 | git clone .sytest-base sytest --shared | |
41 | cd sytest | |
42 | ||
43 | git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop) | |
44 | ||
45 | : ${PORT_BASE:=8000} | |
46 | ||
47 | ./jenkins/prep_sytest_for_postgres.sh | |
48 | ||
49 | echo >&2 "Running sytest with PostgreSQL"; | |
50 | ./jenkins/install_and_run.sh --coverage \ | |
51 | --python $TOX_BIN/python \ | |
52 | --synapse-directory $WORKSPACE \ | |
53 | --port-base $PORT_BASE | |
54 | ||
55 | cd .. | |
56 | cp sytest/.coverage.* . | |
57 | ||
58 | # Combine the coverage reports | |
59 | echo "Combining:" .coverage.* | |
60 | $TOX_BIN/python -m coverage combine | |
61 | # Output coverage to coverage.xml | |
62 | $TOX_BIN/coverage xml -o coverage.xml | |
15 | ./sytest/jenkins/install_and_run.sh \ | |
16 | --synapse-directory $WORKSPACE \ |
3 | 3 | |
4 | 4 | : ${WORKSPACE:="$(pwd)"} |
5 | 5 | |
6 | export WORKSPACE | |
6 | 7 | export PYTHONDONTWRITEBYTECODE=yep |
7 | 8 | export SYNAPSE_CACHE_FACTOR=1 |
8 | 9 | |
9 | # Output test results as junit xml | |
10 | export TRIAL_FLAGS="--reporter=subunit" | |
11 | export TOXSUFFIX="| subunit-1to2 | subunit2junitxml --no-passthrough --output-to=results.xml" | |
12 | # Write coverage reports to a separate file for each process | |
13 | export COVERAGE_OPTS="-p" | |
14 | export DUMP_COVERAGE_COMMAND="coverage help" | |
10 | ./jenkins/prepare_synapse.sh | |
11 | ./jenkins/clone.sh sytest https://github.com/matrix-org/sytest.git | |
15 | 12 | |
16 | # Output flake8 violations to violations.flake8.log | |
17 | # Don't exit with non-0 status code on Jenkins, | |
18 | # so that the build steps continue and a later step can decide whether to | |
19 | # UNSTABLE or FAILURE this build. | |
20 | export PEP8SUFFIX="--output-file=violations.flake8.log || echo flake8 finished with status code \$?" | |
21 | ||
22 | rm .coverage* || echo "No coverage files to remove" | |
23 | ||
24 | tox --notest -e py27 | |
25 | TOX_BIN=$WORKSPACE/.tox/py27/bin | |
26 | python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install | |
27 | $TOX_BIN/pip install lxml | |
28 | ||
29 | : ${GIT_BRANCH:="origin/$(git rev-parse --abbrev-ref HEAD)"} | |
30 | ||
31 | if [[ ! -e .sytest-base ]]; then | |
32 | git clone https://github.com/matrix-org/sytest.git .sytest-base --mirror | |
33 | else | |
34 | (cd .sytest-base; git fetch -p) | |
35 | fi | |
36 | ||
37 | rm -rf sytest | |
38 | git clone .sytest-base sytest --shared | |
39 | cd sytest | |
40 | ||
41 | git checkout "${GIT_BRANCH}" || (echo >&2 "No ref ${GIT_BRANCH} found, falling back to develop" ; git checkout develop) | |
42 | ||
43 | : ${PORT_BASE:=8500} | |
44 | ./jenkins/install_and_run.sh --coverage \ | |
45 | --python $TOX_BIN/python \ | |
46 | --synapse-directory $WORKSPACE \ | |
47 | --port-base $PORT_BASE | |
48 | ||
49 | cd .. | |
50 | cp sytest/.coverage.* . | |
51 | ||
52 | # Combine the coverage reports | |
53 | echo "Combining:" .coverage.* | |
54 | $TOX_BIN/python -m coverage combine | |
55 | # Output coverage to coverage.xml | |
56 | $TOX_BIN/coverage xml -o coverage.xml | |
13 | ./sytest/jenkins/install_and_run.sh \ | |
14 | --synapse-directory $WORKSPACE \ |
21 | 21 | |
22 | 22 | rm .coverage* || echo "No coverage files to remove" |
23 | 23 | |
24 | tox --notest -e py27 | |
25 | TOX_BIN=$WORKSPACE/.tox/py27/bin | |
26 | python synapse/python_dependencies.py | xargs -n1 $TOX_BIN/pip install | |
27 | ||
24 | 28 | tox -e py27 |
35 | 35 | <div class="debug"> |
36 | 36 | Sending email at {{ reason.now|format_ts("%c") }} due to activity in room {{ reason.room_name }} because |
37 | 37 | an event was received at {{ reason.received_at|format_ts("%c") }} |
38 | which is more than {{ "%.1f"|format(reason.delay_before_mail_ms / (60*1000)) }} (delay_before_mail_ms) mins ago, | |
38 | which is more than {{ "%.1f"|format(reason.delay_before_mail_ms / (60*1000)) }} ({{ reason.delay_before_mail_ms }}) mins ago, | |
39 | 39 | {% if reason.last_sent_ts %} |
40 | 40 | and the last time we sent a mail for this room was {{ reason.last_sent_ts|format_ts("%c") }}, |
41 | 41 | which is more than {{ "%.1f"|format(reason.throttle_ms / (60*1000)) }} (current throttle_ms) mins ago. |
0 | 0 | #!/usr/bin/env python |
1 | 1 | |
2 | 2 | import argparse |
3 | ||
4 | import sys | |
5 | ||
3 | 6 | import bcrypt |
4 | 7 | import getpass |
5 | 8 | |
9 | import yaml | |
10 | ||
6 | 11 | bcrypt_rounds=12 |
12 | password_pepper = "" | |
7 | 13 | |
8 | 14 | def prompt_for_pass(): |
9 | 15 | password = getpass.getpass("Password: ") |
27 | 33 | default=None, |
28 | 34 | help="New password for user. Will prompt if omitted.", |
29 | 35 | ) |
36 | parser.add_argument( | |
37 | "-c", "--config", | |
38 | type=argparse.FileType('r'), | |
39 | help="Path to server config file. Used to read in bcrypt_rounds and password_pepper.", | |
40 | ) | |
30 | 41 | |
31 | 42 | args = parser.parse_args() |
43 | if "config" in args and args.config: | |
44 | config = yaml.safe_load(args.config) | |
45 | bcrypt_rounds = config.get("bcrypt_rounds", bcrypt_rounds) | |
46 | password_config = config.get("password_config", {}) | |
47 | password_pepper = password_config.get("pepper", password_pepper) | |
32 | 48 | password = args.password |
33 | 49 | |
34 | 50 | if not password: |
35 | 51 | password = prompt_for_pass() |
36 | 52 | |
37 | print bcrypt.hashpw(password, bcrypt.gensalt(bcrypt_rounds)) | |
53 | print bcrypt.hashpw(password + password_pepper, bcrypt.gensalt(bcrypt_rounds)) | |
38 | 54 |
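
To verify a hash produced with a pepper, the pepper must be appended again
before checking. A minimal sketch mirroring the script above (Python 2 style,
to match)::

    import bcrypt

    def check_password(password, stored_hash, pepper=""):
        # Passing the stored hash as the salt makes hashpw re-derive the
        # full hash; equality means the password (plus pepper) matches.
        return bcrypt.hashpw(password + pepper, stored_hash) == stored_hash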
24 | 24 | import yaml |
25 | 25 | |
26 | 26 | |
27 | def request_registration(user, password, server_location, shared_secret): | |
27 | def request_registration(user, password, server_location, shared_secret, admin=False): | |
28 | 28 | mac = hmac.new( |
29 | 29 | key=shared_secret, |
30 | msg=user, | |
31 | 30 | digestmod=hashlib.sha1, |
32 | ).hexdigest() | |
31 | ) | |
32 | ||
33 | mac.update(user) | |
34 | mac.update("\x00") | |
35 | mac.update(password) | |
36 | mac.update("\x00") | |
37 | mac.update("admin" if admin else "notadmin") | |
38 | ||
39 | mac = mac.hexdigest() | |
33 | 40 | |
34 | 41 | data = { |
35 | 42 | "user": user, |
36 | 43 | "password": password, |
37 | 44 | "mac": mac, |
38 | 45 | "type": "org.matrix.login.shared_secret", |
46 | "admin": admin, | |
39 | 47 | } |
40 | 48 | |
41 | 49 | server_location = server_location.rstrip("/") |
67 | 75 | sys.exit(1) |
68 | 76 | |
69 | 77 | |
70 | def register_new_user(user, password, server_location, shared_secret): | |
78 | def register_new_user(user, password, server_location, shared_secret, admin): | |
71 | 79 | if not user: |
72 | 80 | try: |
73 | 81 | default_user = getpass.getuser() |
98 | 106 | print "Passwords do not match" |
99 | 107 | sys.exit(1) |
100 | 108 | |
101 | request_registration(user, password, server_location, shared_secret) | |
109 | if not admin: | |
110 | admin = raw_input("Make admin [no]: ") | |
111 | if admin in ("y", "yes", "true"): | |
112 | admin = True | |
113 | else: | |
114 | admin = False | |
115 | ||
116 | request_registration(user, password, server_location, shared_secret, bool(admin)) | |
102 | 117 | |
103 | 118 | |
104 | 119 | if __name__ == "__main__": |
117 | 132 | "-p", "--password", |
118 | 133 | default=None, |
119 | 134 | help="New password for user. Will prompt if omitted.", |
135 | ) | |
136 | parser.add_argument( | |
137 | "-a", "--admin", | |
138 | action="store_true", | |
139 | help="Register new user as an admin. Will prompt if omitted.", | |
120 | 140 | ) |
121 | 141 | |
122 | 142 | group = parser.add_mutually_exclusive_group(required=True) |
150 | 170 | else: |
151 | 171 | secret = args.shared_secret |
152 | 172 | |
153 | register_new_user(args.user, args.password, args.server_url, secret) | |
173 | register_new_user(args.user, args.password, args.server_url, secret, args.admin) |
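
For reference, the MAC the server now expects can be computed standalone; this
sketch simply restates the construction above::

    import hashlib
    import hmac

    def registration_mac(shared_secret, user, password, admin=False):
        # HMAC-SHA1 over user, password and the admin flag, NUL-separated,
        # keyed with the registration shared secret.
        mac = hmac.new(key=shared_secret, digestmod=hashlib.sha1)
        mac.update(user)
        mac.update("\x00")
        mac.update(password)
        mac.update("\x00")
        mac.update("admin" if admin else "notadmin")
        return mac.hexdigest()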
33 | 33 | |
34 | 34 | |
35 | 35 | BOOLEAN_COLUMNS = { |
36 | "events": ["processed", "outlier"], | |
36 | "events": ["processed", "outlier", "contains_url"], | |
37 | 37 | "rooms": ["is_public"], |
38 | 38 | "event_edges": ["is_state"], |
39 | 39 | "presence_list": ["accepted"], |
91 | 91 | |
92 | 92 | _simple_select_onecol_txn = SQLBaseStore.__dict__["_simple_select_onecol_txn"] |
93 | 93 | _simple_select_onecol = SQLBaseStore.__dict__["_simple_select_onecol"] |
94 | _simple_select_one = SQLBaseStore.__dict__["_simple_select_one"] | |
95 | _simple_select_one_txn = SQLBaseStore.__dict__["_simple_select_one_txn"] | |
94 | 96 | _simple_select_one_onecol = SQLBaseStore.__dict__["_simple_select_one_onecol"] |
95 | _simple_select_one_onecol_txn = SQLBaseStore.__dict__["_simple_select_one_onecol_txn"] | |
97 | _simple_select_one_onecol_txn = SQLBaseStore.__dict__[ | |
98 | "_simple_select_one_onecol_txn" | |
99 | ] | |
96 | 100 | |
97 | 101 | _simple_update_one = SQLBaseStore.__dict__["_simple_update_one"] |
98 | 102 | _simple_update_one_txn = SQLBaseStore.__dict__["_simple_update_one_txn"] |
157 | 161 | def setup_table(self, table): |
158 | 162 | if table in APPEND_ONLY_TABLES: |
159 | 163 | # It's safe to just carry on inserting. |
160 | next_chunk = yield self.postgres_store._simple_select_one_onecol( | |
164 | row = yield self.postgres_store._simple_select_one( | |
161 | 165 | table="port_from_sqlite3", |
162 | 166 | keyvalues={"table_name": table}, |
163 | retcol="rowid", | |
167 | retcols=("forward_rowid", "backward_rowid"), | |
164 | 168 | allow_none=True, |
165 | 169 | ) |
166 | 170 | |
167 | 171 | total_to_port = None |
168 | if next_chunk is None: | |
172 | if row is None: | |
169 | 173 | if table == "sent_transactions": |
170 | next_chunk, already_ported, total_to_port = ( | |
174 | forward_chunk, already_ported, total_to_port = ( | |
171 | 175 | yield self._setup_sent_transactions() |
172 | 176 | ) |
177 | backward_chunk = 0 | |
173 | 178 | else: |
174 | 179 | yield self.postgres_store._simple_insert( |
175 | 180 | table="port_from_sqlite3", |
176 | values={"table_name": table, "rowid": 1} | |
181 | values={ | |
182 | "table_name": table, | |
183 | "forward_rowid": 1, | |
184 | "backward_rowid": 0, | |
185 | } | |
177 | 186 | ) |
178 | 187 | |
179 | next_chunk = 1 | |
188 | forward_chunk = 1 | |
189 | backward_chunk = 0 | |
180 | 190 | already_ported = 0 |
191 | else: | |
192 | forward_chunk = row["forward_rowid"] | |
193 | backward_chunk = row["backward_rowid"] | |
181 | 194 | |
182 | 195 | if total_to_port is None: |
183 | 196 | already_ported, total_to_port = yield self._get_total_count_to_port( |
184 | table, next_chunk | |
197 | table, forward_chunk, backward_chunk | |
185 | 198 | ) |
186 | 199 | else: |
187 | 200 | def delete_all(txn): |
195 | 208 | |
196 | 209 | yield self.postgres_store._simple_insert( |
197 | 210 | table="port_from_sqlite3", |
198 | values={"table_name": table, "rowid": 0} | |
199 | ) | |
200 | ||
201 | next_chunk = 1 | |
211 | values={ | |
212 | "table_name": table, | |
213 | "forward_rowid": 1, | |
214 | "backward_rowid": 0, | |
215 | } | |
216 | ) | |
217 | ||
218 | forward_chunk = 1 | |
219 | backward_chunk = 0 | |
202 | 220 | |
203 | 221 | already_ported, total_to_port = yield self._get_total_count_to_port( |
204 | table, next_chunk | |
205 | ) | |
206 | ||
207 | defer.returnValue((table, already_ported, total_to_port, next_chunk)) | |
222 | table, forward_chunk, backward_chunk | |
223 | ) | |
224 | ||
225 | defer.returnValue( | |
226 | (table, already_ported, total_to_port, forward_chunk, backward_chunk) | |
227 | ) | |
208 | 228 | |
209 | 229 | @defer.inlineCallbacks |
210 | def handle_table(self, table, postgres_size, table_size, next_chunk): | |
230 | def handle_table(self, table, postgres_size, table_size, forward_chunk, | |
231 | backward_chunk): | |
211 | 232 | if not table_size: |
212 | 233 | return |
213 | 234 | |
214 | 235 | self.progress.add_table(table, postgres_size, table_size) |
215 | 236 | |
216 | 237 | if table == "event_search": |
217 | yield self.handle_search_table(postgres_size, table_size, next_chunk) | |
238 | yield self.handle_search_table( | |
239 | postgres_size, table_size, forward_chunk, backward_chunk | |
240 | ) | |
218 | 241 | return |
219 | 242 | |
220 | select = ( | |
243 | forward_select = ( | |
221 | 244 | "SELECT rowid, * FROM %s WHERE rowid >= ? ORDER BY rowid LIMIT ?" |
222 | 245 | % (table,) |
223 | 246 | ) |
224 | 247 | |
248 | backward_select = ( | |
249 | "SELECT rowid, * FROM %s WHERE rowid <= ? ORDER BY rowid LIMIT ?" | |
250 | % (table,) | |
251 | ) | |
252 | ||
253 | do_forward = [True] | |
254 | do_backward = [True] | |
255 | ||
225 | 256 | while True: |
226 | 257 | def r(txn): |
227 | txn.execute(select, (next_chunk, self.batch_size,)) | |
228 | rows = txn.fetchall() | |
229 | headers = [column[0] for column in txn.description] | |
230 | ||
231 | return headers, rows | |
232 | ||
233 | headers, rows = yield self.sqlite_store.runInteraction("select", r) | |
234 | ||
235 | if rows: | |
236 | next_chunk = rows[-1][0] + 1 | |
237 | ||
258 | forward_rows = [] | |
259 | backward_rows = [] | |
260 | if do_forward[0]: | |
261 | txn.execute(forward_select, (forward_chunk, self.batch_size,)) | |
262 | forward_rows = txn.fetchall() | |
263 | if not forward_rows: | |
264 | do_forward[0] = False | |
265 | ||
266 | if do_backward[0]: | |
267 | txn.execute(backward_select, (backward_chunk, self.batch_size,)) | |
268 | backward_rows = txn.fetchall() | |
269 | if not backward_rows: | |
270 | do_backward[0] = False | |
271 | ||
272 | if forward_rows or backward_rows: | |
273 | headers = [column[0] for column in txn.description] | |
274 | else: | |
275 | headers = None | |
276 | ||
277 | return headers, forward_rows, backward_rows | |
278 | ||
279 | headers, frows, brows = yield self.sqlite_store.runInteraction( | |
280 | "select", r | |
281 | ) | |
282 | ||
283 | if frows or brows: | |
284 | if frows: | |
285 | forward_chunk = max(row[0] for row in frows) + 1 | |
286 | if brows: | |
287 | backward_chunk = min(row[0] for row in brows) - 1 | |
288 | ||
289 | rows = frows + brows | |
238 | 290 | self._convert_rows(table, headers, rows) |
239 | 291 | |
240 | 292 | def insert(txn): |
246 | 298 | txn, |
247 | 299 | table="port_from_sqlite3", |
248 | 300 | keyvalues={"table_name": table}, |
249 | updatevalues={"rowid": next_chunk}, | |
301 | updatevalues={ | |
302 | "forward_rowid": forward_chunk, | |
303 | "backward_rowid": backward_chunk, | |
304 | }, | |
250 | 305 | ) |
251 | 306 | |
252 | 307 | yield self.postgres_store.execute(insert) |
258 | 313 | return |
259 | 314 | |
260 | 315 | @defer.inlineCallbacks |
261 | def handle_search_table(self, postgres_size, table_size, next_chunk): | |
316 | def handle_search_table(self, postgres_size, table_size, forward_chunk, | |
317 | backward_chunk): | |
262 | 318 | select = ( |
263 | 319 | "SELECT es.rowid, es.*, e.origin_server_ts, e.stream_ordering" |
264 | 320 | " FROM event_search as es" |
269 | 325 | |
270 | 326 | while True: |
271 | 327 | def r(txn): |
272 | txn.execute(select, (next_chunk, self.batch_size,)) | |
328 | txn.execute(select, (forward_chunk, self.batch_size,)) | |
273 | 329 | rows = txn.fetchall() |
274 | 330 | headers = [column[0] for column in txn.description] |
275 | 331 | |
278 | 334 | headers, rows = yield self.sqlite_store.runInteraction("select", r) |
279 | 335 | |
280 | 336 | if rows: |
281 | next_chunk = rows[-1][0] + 1 | |
337 | forward_chunk = rows[-1][0] + 1 | |
282 | 338 | |
283 | 339 | # We have to treat event_search differently since it has a |
284 | 340 | # different structure in the two different databases. |
311 | 367 | txn, |
312 | 368 | table="port_from_sqlite3", |
313 | 369 | keyvalues={"table_name": "event_search"}, |
314 | updatevalues={"rowid": next_chunk}, | |
370 | updatevalues={ | |
371 | "forward_rowid": forward_chunk, | |
372 | "backward_rowid": backward_chunk, | |
373 | }, | |
315 | 374 | ) |
316 | 375 | |
317 | 376 | yield self.postgres_store.execute(insert) |
322 | 381 | |
323 | 382 | else: |
324 | 383 | return |
325 | ||
326 | 384 | |
327 | 385 | def setup_db(self, db_config, database_engine): |
328 | 386 | db_conn = database_engine.module.connect( |
394 | 452 | txn.execute( |
395 | 453 | "CREATE TABLE port_from_sqlite3 (" |
396 | 454 | " table_name varchar(100) NOT NULL UNIQUE," |
397 | " rowid bigint NOT NULL" | |
455 | " forward_rowid bigint NOT NULL," | |
456 | " backward_rowid bigint NOT NULL" | |
398 | 457 | ")" |
399 | 458 | ) |
459 | ||
460 | # The old port script created a table with just a "rowid" column. | |
461 | # We want people to be able to rerun this script from an old port | |
462 | # so that they can pick up any missing events that were not | |
463 | # ported across. | |
464 | def alter_table(txn): | |
465 | txn.execute( | |
466 | "ALTER TABLE IF EXISTS port_from_sqlite3" | |
467 | " RENAME rowid TO forward_rowid" | |
468 | ) | |
469 | txn.execute( | |
470 | "ALTER TABLE IF EXISTS port_from_sqlite3" | |
471 | " ADD backward_rowid bigint NOT NULL DEFAULT 0" | |
472 | ) | |
473 | ||
474 | try: | |
475 | yield self.postgres_store.runInteraction( | |
476 | "alter_table", alter_table | |
477 | ) | |
478 | except Exception as e: | |
479 | logger.info("Failed to create port table: %s", e) | |
400 | 480 | |
401 | 481 | try: |
402 | 482 | yield self.postgres_store.runInteraction( |
457 | 537 | @defer.inlineCallbacks |
458 | 538 | def _setup_sent_transactions(self): |
459 | 539 | # Only save things from the last day |
460 | yesterday = int(time.time()*1000) - 86400000 | |
540 | yesterday = int(time.time() * 1000) - 86400000 | |
461 | 541 | |
462 | 542 | # And save the max transaction id from each destination |
463 | 543 | select = ( |
513 | 593 | |
514 | 594 | yield self.postgres_store._simple_insert( |
515 | 595 | table="port_from_sqlite3", |
516 | values={"table_name": "sent_transactions", "rowid": next_chunk} | |
596 | values={ | |
597 | "table_name": "sent_transactions", | |
598 | "forward_rowid": next_chunk, | |
599 | "backward_rowid": 0, | |
600 | } | |
517 | 601 | ) |
518 | 602 | |
519 | 603 | def get_sent_table_size(txn): |
534 | 618 | defer.returnValue((next_chunk, inserted_rows, total_count)) |
535 | 619 | |
536 | 620 | @defer.inlineCallbacks |
537 | def _get_remaining_count_to_port(self, table, next_chunk): | |
538 | rows = yield self.sqlite_store.execute_sql( | |
621 | def _get_remaining_count_to_port(self, table, forward_chunk, backward_chunk): | |
622 | frows = yield self.sqlite_store.execute_sql( | |
539 | 623 | "SELECT count(*) FROM %s WHERE rowid >= ?" % (table,), |
540 | next_chunk, | |
541 | ) | |
542 | ||
543 | defer.returnValue(rows[0][0]) | |
624 | forward_chunk, | |
625 | ) | |
626 | ||
627 | brows = yield self.sqlite_store.execute_sql( | |
628 | "SELECT count(*) FROM %s WHERE rowid <= ?" % (table,), | |
629 | backward_chunk, | |
630 | ) | |
631 | ||
632 | defer.returnValue(frows[0][0] + brows[0][0]) | |
544 | 633 | |
545 | 634 | @defer.inlineCallbacks |
546 | 635 | def _get_already_ported_count(self, table): |
551 | 640 | defer.returnValue(rows[0][0]) |
552 | 641 | |
553 | 642 | @defer.inlineCallbacks |
554 | def _get_total_count_to_port(self, table, next_chunk): | |
643 | def _get_total_count_to_port(self, table, forward_chunk, backward_chunk): | |
555 | 644 | remaining, done = yield defer.gatherResults( |
556 | 645 | [ |
557 | self._get_remaining_count_to_port(table, next_chunk), | |
646 | self._get_remaining_count_to_port(table, forward_chunk, backward_chunk), | |
558 | 647 | self._get_already_ported_count(table), |
559 | 648 | ], |
560 | 649 | consumeErrors=True, |
685 | 774 | color = curses.color_pair(2) if perc == 100 else curses.color_pair(1) |
686 | 775 | |
687 | 776 | self.stdscr.addstr( |
688 | i+2, left_margin + max_len - len(table), | |
777 | i + 2, left_margin + max_len - len(table), | |
689 | 778 | table, |
690 | 779 | curses.A_BOLD | color, |
691 | 780 | ) |
693 | 782 | size = 20 |
694 | 783 | |
695 | 784 | progress = "[%s%s]" % ( |
696 | "#" * int(perc*size/100), | |
697 | " " * (size - int(perc*size/100)), | |
785 | "#" * int(perc * size / 100), | |
786 | " " * (size - int(perc * size / 100)), | |
698 | 787 | ) |
699 | 788 | |
700 | 789 | self.stdscr.addstr( |
701 | i+2, left_margin + max_len + middle_space, | |
790 | i + 2, left_margin + max_len + middle_space, | |
702 | 791 | "%s %3d%% (%d/%d)" % (progress, perc, data["num_done"], data["total"]), |
703 | 792 | ) |
704 | 793 | |
705 | 794 | if self.finished: |
706 | 795 | self.stdscr.addstr( |
707 | rows-1, 0, | |
796 | rows - 1, 0, | |
708 | 797 | "Press any key to exit...", |
709 | 798 | ) |
710 | 799 |
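
The port script changes above replace the single saved ``rowid`` checkpoint
with a ``forward_rowid``/``backward_rowid`` pair, so each batch scans outwards
in both directions and backfilled events (whose negative stream orderings give
them rowids below the starting point) are no longer skipped. The chunk
bookkeeping, reduced to a sketch::

    def advance_chunks(frows, brows, forward_chunk, backward_chunk):
        # frows: batch with rowid >= forward_chunk (ascending scan);
        # brows: batch with rowid <= backward_chunk (descending scan).
        if frows:
            forward_chunk = max(row[0] for row in frows) + 1
        if brows:
            backward_chunk = min(row[0] for row in brows) - 1
        return frows + brows, forward_chunk, backward_chunk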
115 | 115 | authorization_headers = [] |
116 | 116 | |
117 | 117 | for key, sig in signed_json["signatures"][origin_name].items(): |
118 | authorization_headers.append(bytes( | |
119 | "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % ( | |
120 | origin_name, key, sig, | |
121 | ) | |
122 | )) | |
118 | header = "X-Matrix origin=%s,key=\"%s\",sig=\"%s\"" % ( | |
119 | origin_name, key, sig, | |
120 | ) | |
121 | authorization_headers.append(bytes(header)) | |
122 | sys.stderr.write(header) | |
123 | sys.stderr.write("\n") | |
123 | 124 | |
124 | 125 | result = requests.get( |
125 | 126 | lookup(destination, path), |
126 | 127 | headers={"Authorization": authorization_headers[0]}, |
127 | 128 | verify=False, |
128 | 129 | ) |
130 | sys.stderr.write("Status Code: %d\n" % (result.status_code,)) | |
129 | 131 | return result.json() |
130 | 132 | |
131 | 133 | |
140 | 142 | ) |
141 | 143 | |
142 | 144 | json.dump(result, sys.stdout) |
145 | print "" | |
143 | 146 | |
144 | 147 | if __name__ == "__main__": |
145 | 148 | main() |
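
For reference, the header this builds looks like the following (illustrative
values)::

    Authorization: X-Matrix origin=example.com,key="ed25519:auto",sig="<base64 signature>"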
15 | 15 | |
16 | 16 | [flake8] |
17 | 17 | max-line-length = 90 |
18 | ignore = W503 ; W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it. | |
19 | ||
20 | [pep8] | |
21 | max-line-length = 90 | |
18 | # W503 requires that binary operators be at the end, not start, of lines. Erik doesn't like it. | |
19 | ignore = W503 |
15 | 15 | """ This is a reference implementation of a Matrix home server. |
16 | 16 | """ |
17 | 17 | |
18 | __version__ = "0.16.1-r1" | |
18 | __version__ = "0.17.0" |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | |
15 | import logging | |
16 | ||
17 | import pymacaroons | |
15 | 18 | from canonicaljson import encode_canonical_json |
16 | 19 | from signedjson.key import decode_verify_key_bytes |
17 | 20 | from signedjson.sign import verify_signed_json, SignatureVerifyException |
18 | ||
19 | 21 | from twisted.internet import defer |
20 | ||
22 | from unpaddedbase64 import decode_base64 | |
23 | ||
24 | import synapse.types | |
21 | 25 | from synapse.api.constants import EventTypes, Membership, JoinRules |
22 | 26 | from synapse.api.errors import AuthError, Codes, SynapseError, EventSizeError |
23 | from synapse.types import Requester, UserID, get_domain_from_id | |
27 | from synapse.types import UserID, get_domain_from_id | |
28 | from synapse.util.logcontext import preserve_context_over_fn | |
24 | 29 | from synapse.util.logutils import log_function |
25 | from synapse.util.logcontext import preserve_context_over_fn | |
26 | 30 | from synapse.util.metrics import Measure |
27 | from unpaddedbase64 import decode_base64 | |
28 | ||
29 | import logging | |
30 | import pymacaroons | |
31 | 31 | |
32 | 32 | logger = logging.getLogger(__name__) |
33 | 33 | |
62 | 62 | "user_id = ", |
63 | 63 | ]) |
64 | 64 | |
65 | def check(self, event, auth_events): | |
65 | def check(self, event, auth_events, do_sig_check=True): | |
66 | 66 | """ Checks if this event is correctly authed. |
67 | 67 | |
68 | 68 | Args: |
78 | 78 | |
79 | 79 | if not hasattr(event, "room_id"): |
80 | 80 | raise AuthError(500, "Event has no room_id: %s" % event) |
81 | ||
82 | sender_domain = get_domain_from_id(event.sender) | |
83 | ||
84 | # Check the sender's domain has signed the event | |
85 | if do_sig_check and not event.signatures.get(sender_domain): | |
86 | raise AuthError(403, "Event not signed by sending server") | |
87 | ||
81 | 88 | if auth_events is None: |
82 | 89 | # Oh, we don't know what the state of the room was, so we |
83 | 90 | # are trusting that this is allowed (at least for now) |
85 | 92 | return True |
86 | 93 | |
87 | 94 | if event.type == EventTypes.Create: |
95 | room_id_domain = get_domain_from_id(event.room_id) | |
96 | if room_id_domain != sender_domain: | |
97 | raise AuthError( | |
98 | 403, | |
99 | "Creation event's room_id domain does not match sender's" | |
100 | ) | |
88 | 101 | # FIXME |
89 | 102 | return True |
90 | 103 | |
107 | 120 | |
108 | 121 | # FIXME: Temp hack |
109 | 122 | if event.type == EventTypes.Aliases: |
123 | if not event.is_state(): | |
124 | raise AuthError( | |
125 | 403, | |
126 | "Alias event must be a state event", | |
127 | ) | |
128 | if not event.state_key: | |
129 | raise AuthError( | |
130 | 403, | |
131 | "Alias event must have non-empty state_key" | |
132 | ) | |
133 | sender_domain = get_domain_from_id(event.sender) | |
134 | if event.state_key != sender_domain: | |
135 | raise AuthError( | |
136 | 403, | |
137 | "Alias event's state_key does not match sender's domain" | |
138 | ) | |
110 | 139 | return True |
111 | 140 | |
112 | 141 | logger.debug( |
346 | 375 | if Membership.INVITE == membership and "third_party_invite" in event.content: |
347 | 376 | if not self._verify_third_party_invite(event, auth_events): |
348 | 377 | raise AuthError(403, "You are not invited to this room.") |
378 | if target_banned: | |
379 | raise AuthError( | |
380 | 403, "%s is banned from the room" % (target_user_id,) | |
381 | ) | |
349 | 382 | return True |
350 | 383 | |
351 | 384 | if Membership.JOIN != membership: |
536 | 569 | Args: |
537 | 570 | request - An HTTP request with an access_token query parameter. |
538 | 571 | Returns: |
539 | tuple of: | |
540 | UserID (str) | |
541 | Access token ID (str) | |
572 | defer.Deferred: resolves to a ``synapse.types.Requester`` object | |
542 | 573 | Raises: |
543 | 574 | AuthError if no user by that token exists or the token is invalid. |
544 | 575 | """ |
547 | 578 | user_id = yield self._get_appservice_user_id(request.args) |
548 | 579 | if user_id: |
549 | 580 | request.authenticated_entity = user_id |
550 | defer.returnValue( | |
551 | Requester(UserID.from_string(user_id), "", False) | |
552 | ) | |
581 | defer.returnValue(synapse.types.create_requester(user_id)) | |
553 | 582 | |
554 | 583 | access_token = request.args["access_token"][0] |
555 | 584 | user_info = yield self.get_user_by_access_token(access_token, rights) |
556 | 585 | user = user_info["user"] |
557 | 586 | token_id = user_info["token_id"] |
558 | 587 | is_guest = user_info["is_guest"] |
588 | ||
589 | # device_id may not be present if get_user_by_access_token has been | |
590 | # stubbed out. | |
591 | device_id = user_info.get("device_id") | |
559 | 592 | |
560 | 593 | ip_addr = self.hs.get_ip_from_request(request) |
561 | 594 | user_agent = request.requestHeaders.getRawHeaders( |
568 | 601 | user=user, |
569 | 602 | access_token=access_token, |
570 | 603 | ip=ip_addr, |
571 | user_agent=user_agent | |
604 | user_agent=user_agent, | |
605 | device_id=device_id, | |
572 | 606 | ) |
573 | 607 | |
574 | 608 | if is_guest and not allow_guest: |
578 | 612 | |
579 | 613 | request.authenticated_entity = user.to_string() |
580 | 614 | |
581 | defer.returnValue(Requester(user, token_id, is_guest)) | |
615 | defer.returnValue(synapse.types.create_requester( | |
616 | user, token_id, is_guest, device_id)) | |
582 | 617 | except KeyError: |
583 | 618 | raise AuthError( |
584 | 619 | self.TOKEN_NOT_FOUND_HTTP_STATUS, "Missing access token.", |
628 | 663 | except AuthError: |
629 | 664 | # TODO(daniel): Remove this fallback when all existing access tokens |
630 | 665 | # have been re-issued as macaroons. |
666 | if self.hs.config.expire_access_token: | |
667 | raise | |
631 | 668 | ret = yield self._look_up_user_by_access_token(token) |
669 | ||
632 | 670 | defer.returnValue(ret) |
633 | 671 | |
634 | 672 | @defer.inlineCallbacks |
663 | 701 | "user": user, |
664 | 702 | "is_guest": True, |
665 | 703 | "token_id": None, |
704 | "device_id": None, | |
666 | 705 | } |
667 | 706 | elif rights == "delete_pusher": |
668 | 707 | # We don't store these tokens in the database |
670 | 709 | "user": user, |
671 | 710 | "is_guest": False, |
672 | 711 | "token_id": None, |
712 | "device_id": None, | |
673 | 713 | } |
674 | 714 | else: |
675 | # This codepath exists so that we can actually return a | |
676 | # token ID, because we use token IDs in place of device | |
677 | # identifiers throughout the codebase. | |
678 | # TODO(daniel): Remove this fallback when device IDs are | |
679 | # properly implemented. | |
715 | # This codepath exists for several reasons: | |
716 | # * so that we can actually return a token ID, which is used | |
717 | # in some parts of the schema (where we probably ought to | |
718 | # use device IDs instead) | |
719 | # * the only way we currently have to invalidate an | |
720 | # access_token is by removing it from the database, so we | |
721 | # have to check here that it is still in the db | |
722 | # * some attributes (notably device_id) aren't stored in the | |
723 | # macaroon. They probably should be. | |
724 | # TODO: build the dictionary from the macaroon once the | |
725 | # above are fixed | |
680 | 726 | ret = yield self._look_up_user_by_access_token(macaroon_str) |
681 | 727 | if ret["user"] != user: |
682 | 728 | logger.error( |
750 | 796 | self.TOKEN_NOT_FOUND_HTTP_STATUS, "Unrecognised access token.", |
751 | 797 | errcode=Codes.UNKNOWN_TOKEN |
752 | 798 | ) |
799 | # we use ret.get() below because *lots* of unit tests stub out | |
800 | # get_user_by_access_token in a way where it only returns a couple of | |
801 | # the fields. | |
753 | 802 | user_info = { |
754 | 803 | "user": UserID.from_string(ret.get("name")), |
755 | 804 | "token_id": ret.get("token_id", None), |
756 | 805 | "is_guest": False, |
806 | "device_id": ret.get("device_id"), | |
757 | 807 | } |
758 | 808 | defer.returnValue(user_info) |
759 | 809 |
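
The headline security fix is the new sender-domain signature requirement in
``check``: an event received over federation must carry a signature from the
server implied by its ``sender``. Reduced to a sketch::

    from synapse.types import get_domain_from_id

    def is_signed_by_sender(event):
        # Event signatures are keyed by server name; an event missing a
        # signature from its sender's own domain is now rejected with a 403.
        sender_domain = get_domain_from_id(event.sender)
        return bool(event.signatures.get(sender_domain))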
41 | 41 | TOO_LARGE = "M_TOO_LARGE" |
42 | 42 | EXCLUSIVE = "M_EXCLUSIVE" |
43 | 43 | THREEPID_AUTH_FAILED = "M_THREEPID_AUTH_FAILED" |
44 | THREEPID_IN_USE = "THREEPID_IN_USE" | |
44 | THREEPID_IN_USE = "M_THREEPID_IN_USE" | |
45 | THREEPID_NOT_FOUND = "M_THREEPID_NOT_FOUND" | |
45 | 46 | INVALID_USERNAME = "M_INVALID_USERNAME" |
47 | SERVER_NOT_TRUSTED = "M_SERVER_NOT_TRUSTED" | |
46 | 48 | |
47 | 49 | |
48 | 50 | class CodeMessageException(RuntimeError): |
190 | 190 | def __init__(self, filter_json): |
191 | 191 | self.filter_json = filter_json |
192 | 192 | |
193 | self.types = self.filter_json.get("types", None) | |
194 | self.not_types = self.filter_json.get("not_types", []) | |
195 | ||
196 | self.rooms = self.filter_json.get("rooms", None) | |
197 | self.not_rooms = self.filter_json.get("not_rooms", []) | |
198 | ||
199 | self.senders = self.filter_json.get("senders", None) | |
200 | self.not_senders = self.filter_json.get("not_senders", []) | |
201 | ||
202 | self.contains_url = self.filter_json.get("contains_url", None) | |
203 | ||
193 | 204 | def check(self, event): |
194 | 205 | """Checks whether the filter matches the given event. |
195 | 206 | |
208 | 219 | event.get("room_id", None), |
209 | 220 | sender, |
210 | 221 | event.get("type", None), |
222 | "url" in event.get("content", {}) | |
211 | 223 | ) |
212 | 224 | |
213 | def check_fields(self, room_id, sender, event_type): | |
225 | def check_fields(self, room_id, sender, event_type, contains_url): | |
214 | 226 | """Checks whether the filter matches the given event fields. |
215 | 227 | |
216 | 228 | Returns: |
224 | 236 | |
225 | 237 | for name, match_func in literal_keys.items(): |
226 | 238 | not_name = "not_%s" % (name,) |
227 | disallowed_values = self.filter_json.get(not_name, []) | |
239 | disallowed_values = getattr(self, not_name) | |
228 | 240 | if any(map(match_func, disallowed_values)): |
229 | 241 | return False |
230 | 242 | |
231 | allowed_values = self.filter_json.get(name, None) | |
243 | allowed_values = getattr(self, name) | |
232 | 244 | if allowed_values is not None: |
233 | 245 | if not any(map(match_func, allowed_values)): |
234 | 246 | return False |
247 | ||
248 | contains_url_filter = self.filter_json.get("contains_url") | |
249 | if contains_url_filter is not None: | |
250 | if contains_url_filter != contains_url: | |
251 | return False | |
235 | 252 | |
236 | 253 | return True |
237 | 254 |
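
With the new field, a filter such as ``{"contains_url": true}`` matches only
events whose content carries a ``url`` key. A small usage sketch (assuming the
``Filter`` class from ``synapse.api.filtering`` at this version)::

    from synapse.api.filtering import Filter

    f = Filter({"contains_url": True})

    # An m.image event includes content["url"], so it passes the filter.
    assert f.check({
        "sender": "@alice:example.com",
        "type": "m.room.message",
        "room_id": "!room:example.com",
        "content": {"msgtype": "m.image", "url": "mxc://example.com/abc"},
    })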
15 | 15 | import sys |
16 | 16 | sys.dont_write_bytecode = True |
17 | 17 | |
18 | from synapse.python_dependencies import ( | |
19 | check_requirements, MissingRequirementError | |
20 | ) # NOQA | |
18 | from synapse import python_dependencies # noqa: E402 | |
21 | 19 | |
22 | 20 | try: |
23 | check_requirements() | |
24 | except MissingRequirementError as e: | |
21 | python_dependencies.check_requirements() | |
22 | except python_dependencies.MissingRequirementError as e: | |
25 | 23 | message = "\n".join([ |
26 | 24 | "Missing Requirement: %s" % (e.message,), |
27 | 25 | "To install run:", |
0 | #!/usr/bin/env python | |
1 | # -*- coding: utf-8 -*- | |
2 | # Copyright 2016 OpenMarket Ltd | |
3 | # | |
4 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
5 | # you may not use this file except in compliance with the License. | |
6 | # You may obtain a copy of the License at | |
7 | # | |
8 | # http://www.apache.org/licenses/LICENSE-2.0 | |
9 | # | |
10 | # Unless required by applicable law or agreed to in writing, software | |
11 | # distributed under the License is distributed on an "AS IS" BASIS, | |
12 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
13 | # See the License for the specific language governing permissions and | |
14 | # limitations under the License. | |
15 | ||
16 | import synapse | |
17 | ||
18 | from synapse.config._base import ConfigError | |
19 | from synapse.config.homeserver import HomeServerConfig | |
20 | from synapse.config.logger import setup_logging | |
21 | from synapse.http.site import SynapseSite | |
22 | from synapse.metrics.resource import MetricsResource, METRICS_PREFIX | |
23 | from synapse.replication.slave.storage._base import BaseSlavedStore | |
24 | from synapse.replication.slave.storage.events import SlavedEventStore | |
25 | from synapse.replication.slave.storage.keys import SlavedKeyStore | |
26 | from synapse.replication.slave.storage.room import RoomStore | |
27 | from synapse.replication.slave.storage.transactions import TransactionStore | |
28 | from synapse.replication.slave.storage.directory import DirectoryStore | |
29 | from synapse.server import HomeServer | |
30 | from synapse.storage.engines import create_engine | |
31 | from synapse.util.async import sleep | |
32 | from synapse.util.httpresourcetree import create_resource_tree | |
33 | from synapse.util.logcontext import LoggingContext | |
34 | from synapse.util.manhole import manhole | |
35 | from synapse.util.rlimit import change_resource_limit | |
36 | from synapse.util.versionstring import get_version_string | |
37 | from synapse.api.urls import FEDERATION_PREFIX | |
38 | from synapse.federation.transport.server import TransportLayerServer | |
39 | from synapse.crypto import context_factory | |
40 | ||
41 | ||
42 | from twisted.internet import reactor, defer | |
43 | from twisted.web.resource import Resource | |
44 | ||
45 | from daemonize import Daemonize | |
46 | ||
47 | import sys | |
48 | import logging | |
49 | import gc | |
50 | ||
51 | logger = logging.getLogger("synapse.app.federation_reader") | |
52 | ||
53 | ||
54 | class FederationReaderSlavedStore( | |
55 | SlavedEventStore, | |
56 | SlavedKeyStore, | |
57 | RoomStore, | |
58 | DirectoryStore, | |
59 | TransactionStore, | |
60 | BaseSlavedStore, | |
61 | ): | |
62 | pass | |
63 | ||
64 | ||
65 | class FederationReaderServer(HomeServer): | |
66 | def get_db_conn(self, run_new_connection=True): | |
67 | # Any param beginning with cp_ is a parameter for adbapi, and should | |
68 | # not be passed to the database engine. | |
69 | db_params = { | |
70 | k: v for k, v in self.db_config.get("args", {}).items() | |
71 | if not k.startswith("cp_") | |
72 | } | |
73 | db_conn = self.database_engine.module.connect(**db_params) | |
74 | ||
75 | if run_new_connection: | |
76 | self.database_engine.on_new_connection(db_conn) | |
77 | return db_conn | |
78 | ||
79 | def setup(self): | |
80 | logger.info("Setting up.") | |
81 | self.datastore = FederationReaderSlavedStore(self.get_db_conn(), self) | |
82 | logger.info("Finished setting up.") | |
83 | ||
84 | def _listen_http(self, listener_config): | |
85 | port = listener_config["port"] | |
86 | bind_address = listener_config.get("bind_address", "") | |
87 | site_tag = listener_config.get("tag", port) | |
88 | resources = {} | |
89 | for res in listener_config["resources"]: | |
90 | for name in res["names"]: | |
91 | if name == "metrics": | |
92 | resources[METRICS_PREFIX] = MetricsResource(self) | |
93 | elif name == "federation": | |
94 | resources.update({ | |
95 | FEDERATION_PREFIX: TransportLayerServer(self), | |
96 | }) | |
97 | ||
98 | root_resource = create_resource_tree(resources, Resource()) | |
99 | reactor.listenTCP( | |
100 | port, | |
101 | SynapseSite( | |
102 | "synapse.access.http.%s" % (site_tag,), | |
103 | site_tag, | |
104 | listener_config, | |
105 | root_resource, | |
106 | ), | |
107 | interface=bind_address | |
108 | ) | |
109 | logger.info("Synapse federation reader now listening on port %d", port) | |
110 | ||
111 | def start_listening(self, listeners): | |
112 | for listener in listeners: | |
113 | if listener["type"] == "http": | |
114 | self._listen_http(listener) | |
115 | elif listener["type"] == "manhole": | |
116 | reactor.listenTCP( | |
117 | listener["port"], | |
118 | manhole( | |
119 | username="matrix", | |
120 | password="rabbithole", | |
121 | globals={"hs": self}, | |
122 | ), | |
123 | interface=listener.get("bind_address", '127.0.0.1') | |
124 | ) | |
125 | else: | |
126 | logger.warn("Unrecognized listener type: %s", listener["type"]) | |
127 | ||
128 | @defer.inlineCallbacks | |
129 | def replicate(self): | |
130 | http_client = self.get_simple_http_client() | |
131 | store = self.get_datastore() | |
132 | replication_url = self.config.worker_replication_url | |
133 | ||
134 | while True: | |
135 | try: | |
136 | args = store.stream_positions() | |
137 | args["timeout"] = 30000 | |
138 | result = yield http_client.get_json(replication_url, args=args) | |
139 | yield store.process_replication(result) | |
140 | except: | |
141 | logger.exception("Error replicating from %r", replication_url) | |
142 | yield sleep(5) | |
143 | ||
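The replicate() loop above is a long-poll: each request reports the worker's current position on every replication stream and asks the master to hold the request (here up to 30 seconds) until new rows arrive. An illustrative sketch, with hypothetical stream names and positions:

    # Hypothetical positions as returned by store.stream_positions():
    args = {"events": "1042", "caches": "77"}
    args["timeout"] = 30000  # master may hold the request for up to 30s
    # The worker then issues roughly
    #   GET <worker_replication_url>?events=1042&caches=77&timeout=30000
    # and feeds the JSON body to store.process_replication().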
144 | ||
145 | def start(config_options): | |
146 | try: | |
147 | config = HomeServerConfig.load_config( | |
148 | "Synapse federation reader", config_options | |
149 | ) | |
150 | except ConfigError as e: | |
151 | sys.stderr.write("\n" + e.message + "\n") | |
152 | sys.exit(1) | |
153 | ||
154 | assert config.worker_app == "synapse.app.federation_reader" | |
155 | ||
156 | setup_logging(config.worker_log_config, config.worker_log_file) | |
157 | ||
158 | database_engine = create_engine(config.database_config) | |
159 | ||
160 | tls_server_context_factory = context_factory.ServerContextFactory(config) | |
161 | ||
162 | ss = FederationReaderServer( | |
163 | config.server_name, | |
164 | db_config=config.database_config, | |
165 | tls_server_context_factory=tls_server_context_factory, | |
166 | config=config, | |
167 | version_string="Synapse/" + get_version_string(synapse), | |
168 | database_engine=database_engine, | |
169 | ) | |
170 | ||
171 | ss.setup() | |
172 | ss.get_handlers() | |
173 | ss.start_listening(config.worker_listeners) | |
174 | ||
175 | def run(): | |
176 | with LoggingContext("run"): | |
177 | logger.info("Running") | |
178 | change_resource_limit(config.soft_file_limit) | |
179 | if config.gc_thresholds: | |
180 | gc.set_threshold(*config.gc_thresholds) | |
181 | reactor.run() | |
182 | ||
183 | def start(): | |
184 | ss.get_datastore().start_profiling() | |
185 | ss.replicate() | |
186 | ||
187 | reactor.callWhenRunning(start) | |
188 | ||
189 | if config.worker_daemonize: | |
190 | daemon = Daemonize( | |
191 | app="synapse-federation-reader", | |
192 | pid=config.worker_pid_file, | |
193 | action=run, | |
194 | auto_close_fds=False, | |
195 | verbose=True, | |
196 | logger=logger, | |
197 | ) | |
198 | daemon.start() | |
199 | else: | |
200 | run() | |
201 | ||
202 | ||
203 | if __name__ == '__main__': | |
204 | with LoggingContext("main"): | |
205 | start(sys.argv[1:]) |
50 | 50 | from synapse.config.homeserver import HomeServerConfig |
51 | 51 | from synapse.crypto import context_factory |
52 | 52 | from synapse.util.logcontext import LoggingContext |
53 | from synapse.metrics import register_memory_metrics | |
53 | 54 | from synapse.metrics.resource import MetricsResource, METRICS_PREFIX |
54 | 55 | from synapse.replication.resource import ReplicationResource, REPLICATION_PREFIX |
55 | 56 | from synapse.federation.transport.server import TransportLayerServer |
146 | 147 | MEDIA_PREFIX: media_repo, |
147 | 148 | LEGACY_MEDIA_PREFIX: media_repo, |
148 | 149 | CONTENT_REPO_PREFIX: ContentRepoResource( |
149 | self, self.config.uploads_path, self.auth, self.content_addr | |
150 | self, self.config.uploads_path | |
150 | 151 | ), |
151 | 152 | }) |
152 | 153 | |
283 | 284 | # check any extra requirements we have now we have a config |
284 | 285 | check_requirements(config) |
285 | 286 | |
286 | version_string = get_version_string("Synapse", synapse) | |
287 | version_string = "Synapse/" + get_version_string(synapse) | |
287 | 288 | |
288 | 289 | logger.info("Server hostname: %s", config.server_name) |
289 | 290 | logger.info("Server version: %s", version_string) |
300 | 301 | db_config=config.database_config, |
301 | 302 | tls_server_context_factory=tls_server_context_factory, |
302 | 303 | config=config, |
303 | content_addr=config.content_addr, | |
304 | 304 | version_string=version_string, |
305 | 305 | database_engine=database_engine, |
306 | 306 | ) |
334 | 334 | hs.get_datastore().start_profiling() |
335 | 335 | hs.get_datastore().start_doing_background_updates() |
336 | 336 | hs.get_replication_layer().start_get_pdu_cache() |
337 | ||
338 | register_memory_metrics(hs) | |
337 | 339 | |
338 | 340 | reactor.callWhenRunning(start) |
339 | 341 |
272 | 272 | config.server_name, |
273 | 273 | db_config=config.database_config, |
274 | 274 | config=config, |
275 | version_string=get_version_string("Synapse", synapse), | |
275 | version_string="Synapse/" + get_version_string(synapse), | |
276 | 276 | database_engine=database_engine, |
277 | 277 | ) |
278 | 278 |
423 | 423 | config.server_name, |
424 | 424 | db_config=config.database_config, |
425 | 425 | config=config, |
426 | version_string=get_version_string("Synapse", synapse), | |
426 | version_string="Synapse/" + get_version_string(synapse), | |
427 | 427 | database_engine=database_engine, |
428 | 428 | application_service_handler=SynchrotronApplicationService(), |
429 | 429 | ) |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | |
15 | from ._base import Config | |
15 | from ._base import Config, ConfigError | |
16 | ||
17 | ||
18 | MISSING_LDAP3 = ( | |
19 | "Missing ldap3 library. This is required for LDAP Authentication." | |
20 | ) | |
21 | ||
22 | ||
23 | class LDAPMode(object): | |
24 | SIMPLE = "simple" | |
25 | SEARCH = "search" | |
26 | ||
27 | LIST = (SIMPLE, SEARCH) | |
16 | 28 | |
17 | 29 | |
18 | 30 | class LDAPConfig(Config): |
19 | 31 | def read_config(self, config): |
20 | ldap_config = config.get("ldap_config", None) | |
21 | if ldap_config: | |
22 | self.ldap_enabled = ldap_config.get("enabled", False) | |
23 | self.ldap_server = ldap_config["server"] | |
24 | self.ldap_port = ldap_config["port"] | |
25 | self.ldap_tls = ldap_config.get("tls", False) | |
26 | self.ldap_search_base = ldap_config["search_base"] | |
27 | self.ldap_search_property = ldap_config["search_property"] | |
28 | self.ldap_email_property = ldap_config["email_property"] | |
29 | self.ldap_full_name_property = ldap_config["full_name_property"] | |
30 | else: | |
31 | self.ldap_enabled = False | |
32 | self.ldap_server = None | |
33 | self.ldap_port = None | |
34 | self.ldap_tls = False | |
35 | self.ldap_search_base = None | |
36 | self.ldap_search_property = None | |
37 | self.ldap_email_property = None | |
38 | self.ldap_full_name_property = None | |
32 | ldap_config = config.get("ldap_config", {}) | |
33 | ||
34 | self.ldap_enabled = ldap_config.get("enabled", False) | |
35 | ||
36 | if self.ldap_enabled: | |
37 | # verify dependencies are available | |
38 | try: | |
39 | import ldap3 | |
40 | ldap3 # to stop unused lint | |
41 | except ImportError: | |
42 | raise ConfigError(MISSING_LDAP3) | |
43 | ||
44 | self.ldap_mode = LDAPMode.SIMPLE | |
45 | ||
46 | # verify config sanity | |
47 | self.require_keys(ldap_config, [ | |
48 | "uri", | |
49 | "base", | |
50 | "attributes", | |
51 | ]) | |
52 | ||
53 | self.ldap_uri = ldap_config["uri"] | |
54 | self.ldap_start_tls = ldap_config.get("start_tls", False) | |
55 | self.ldap_base = ldap_config["base"] | |
56 | self.ldap_attributes = ldap_config["attributes"] | |
57 | ||
58 | if "bind_dn" in ldap_config: | |
59 | self.ldap_mode = LDAPMode.SEARCH | |
60 | self.require_keys(ldap_config, [ | |
61 | "bind_dn", | |
62 | "bind_password", | |
63 | ]) | |
64 | ||
65 | self.ldap_bind_dn = ldap_config["bind_dn"] | |
66 | self.ldap_bind_password = ldap_config["bind_password"] | |
67 | self.ldap_filter = ldap_config.get("filter", None) | |
68 | ||
69 | # verify attribute lookup | |
70 | self.require_keys(ldap_config['attributes'], [ | |
71 | "uid", | |
72 | "name", | |
73 | "mail", | |
74 | ]) | |
75 | ||
76 | def require_keys(self, config, required): | |
77 | missing = [key for key in required if key not in config] | |
78 | if missing: | |
79 | raise ConfigError( | |
80 | "LDAP enabled but missing required config values: {}".format( | |
81 | ", ".join(missing) | |
82 | ) | |
83 | ) | |
39 | 84 | |
40 | 85 | def default_config(self, **kwargs): |
41 | 86 | return """\ |
42 | 87 | # ldap_config: |
43 | 88 | # enabled: true |
44 | # server: "ldap://localhost" | |
45 | # port: 389 | |
46 | # tls: false | |
47 | # search_base: "ou=Users,dc=example,dc=com" | |
48 | # search_property: "cn" | |
49 | # email_property: "email" | |
50 | # full_name_property: "givenName" | |
89 | # uri: "ldap://ldap.example.com:389" | |
90 | # start_tls: true | |
91 | # base: "ou=users,dc=example,dc=com" | |
92 | # attributes: | |
93 | # uid: "cn" | |
94 | # mail: "email" | |
95 | # name: "givenName" | |
96 | # #bind_dn: | |
97 | # #bind_password: | |
98 | # #filter: "(objectClass=posixAccount)" | |
51 | 99 | """ |
22 | 22 | def read_config(self, config): |
23 | 23 | password_config = config.get("password_config", {}) |
24 | 24 | self.password_enabled = password_config.get("enabled", True) |
25 | self.password_pepper = password_config.get("pepper", "") | |
25 | 26 | |
26 | 27 | def default_config(self, config_dir_path, server_name, **kwargs): |
27 | 28 | return """ |
28 | 29 | # Enable password for login. |
29 | 30 | password_config: |
30 | 31 | enabled: true |
32 | # Uncomment and change to a secret random string for extra security. | |
33 | # DO NOT CHANGE THIS AFTER INITIAL SETUP! | |
34 | #pepper: "" | |
31 | 35 | """ |
106 | 106 | ] |
107 | 107 | }) |
108 | 108 | |
109 | # Attempt to guess the content_addr for the v0 content repository | |
110 | content_addr = config.get("content_addr") | |
111 | if not content_addr: | |
112 | for listener in self.listeners: | |
113 | if listener["type"] == "http" and not listener.get("tls", False): | |
114 | unsecure_port = listener["port"] | |
115 | break | |
116 | else: | |
117 | raise RuntimeError("Could not determine 'content_addr'") | |
118 | ||
119 | host = self.server_name | |
120 | if ':' not in host: | |
121 | host = "%s:%d" % (host, unsecure_port) | |
122 | else: | |
123 | host = host.split(':')[0] | |
124 | host = "%s:%d" % (host, unsecure_port) | |
125 | content_addr = "http://%s" % (host,) | |
126 | ||
127 | self.content_addr = content_addr | |
128 | ||
129 | 109 | def default_config(self, server_name, **kwargs): |
130 | 110 | if ":" in server_name: |
131 | 111 | bind_port = int(server_name.split(":")[1]) |
168 | 148 | # room directory. |
169 | 149 | # secondary_directory_servers: |
170 | 150 | # - matrix.org |
171 | # - vector.im | |
172 | 151 | |
173 | 152 | # List of ports that Synapse should listen on, their purpose and their |
174 | 153 | # configuration. |
76 | 76 | def __init__(self): |
77 | 77 | self.remote_key = defer.Deferred() |
78 | 78 | self.host = None |
79 | self._peer = None | |
79 | 80 | |
80 | 81 | def connectionMade(self): |
81 | self.host = self.transport.getHost() | |
82 | logger.debug("Connected to %s", self.host) | |
82 | self._peer = self.transport.getPeer() | |
83 | logger.debug("Connected to %s", self._peer) | |
84 | ||
83 | 85 | self.sendCommand(b"GET", self.path) |
84 | 86 | if self.host: |
85 | 87 | self.sendHeader(b"Host", self.host) |
123 | 125 | self.timer.cancel() |
124 | 126 | |
125 | 127 | def on_timeout(self): |
126 | logger.debug("Timeout waiting for response from %s", self.host) | |
128 | logger.debug( | |
129 | "Timeout waiting for response from %s: %s", | |
130 | self.host, self._peer, | |
131 | ) | |
127 | 132 | self.errback(IOError("Timeout waiting for response")) |
128 | 133 | self.transport.abortConnection() |
129 | 134 | |
132 | 137 | def protocol(self): |
133 | 138 | protocol = SynapseKeyClientProtocol() |
134 | 139 | protocol.path = self.path |
140 | protocol.host = self.host | |
135 | 141 | return protocol |
43 | 43 | logger = logging.getLogger(__name__) |
44 | 44 | |
45 | 45 | |
46 | KeyGroup = namedtuple("KeyGroup", ("server_name", "group_id", "key_ids")) | |
46 | VerifyKeyRequest = namedtuple("VerifyRequest", ( | |
47 | "server_name", "key_ids", "json_object", "deferred" | |
48 | )) | |
49 | """ | |
50 | A request for a verify key to verify a JSON object. | |
51 | ||
52 | Attributes: | |
53 | server_name(str): The name of the server to verify against. | |
54 | key_ids(set(str)): The set of key_ids that could be used to verify the | |
55 | JSON object | |
56 | json_object(dict): The JSON object to verify. | |
57 | deferred(twisted.internet.defer.Deferred): | |
58 | A deferred that resolves to a (server_name, key_id, verify_key) | |
59 | tuple once a verify key has been fetched | |
60 | """ | |
47 | 61 | |
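An illustrative construction of one of these requests (all values are placeholders):

    from twisted.internet import defer

    request = VerifyKeyRequest(
        server_name="remote.example.com",  # server whose signature we check
        key_ids={"ed25519:a_key_id"},      # any one of these keys may verify it
        json_object={"signatures": {}},    # placeholder signed JSON
        deferred=defer.Deferred(),         # fires with (server_name, key_id, verify_key)
    )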
48 | 62 | |
49 | 63 | class Keyring(object): |
73 | 87 | list of deferreds indicating success or failure to verify each |
74 | 88 | json object's signature for the given server_name. |
75 | 89 | """ |
76 | group_id_to_json = {} | |
77 | group_id_to_group = {} | |
78 | group_ids = [] | |
79 | ||
80 | next_group_id = 0 | |
81 | deferreds = {} | |
90 | verify_requests = [] | |
82 | 91 | |
83 | 92 | for server_name, json_object in server_and_json: |
84 | 93 | logger.debug("Verifying for %s", server_name) |
85 | group_id = next_group_id | |
86 | next_group_id += 1 | |
87 | group_ids.append(group_id) | |
88 | 94 | |
89 | 95 | key_ids = signature_ids(json_object, server_name) |
90 | 96 | if not key_ids: |
91 | deferreds[group_id] = defer.fail(SynapseError( | |
97 | deferred = defer.fail(SynapseError( | |
92 | 98 | 400, |
93 | 99 | "Not signed with a supported algorithm", |
94 | 100 | Codes.UNAUTHORIZED, |
95 | 101 | )) |
96 | 102 | else: |
97 | deferreds[group_id] = defer.Deferred() | |
98 | ||
99 | group = KeyGroup(server_name, group_id, key_ids) | |
100 | ||
101 | group_id_to_group[group_id] = group | |
102 | group_id_to_json[group_id] = json_object | |
103 | deferred = defer.Deferred() | |
104 | ||
105 | verify_request = VerifyKeyRequest( | |
106 | server_name, key_ids, json_object, deferred | |
107 | ) | |
108 | ||
109 | verify_requests.append(verify_request) | |
103 | 110 | |
104 | 111 | @defer.inlineCallbacks |
105 | def handle_key_deferred(group, deferred): | |
106 | server_name = group.server_name | |
112 | def handle_key_deferred(verify_request): | |
113 | server_name = verify_request.server_name | |
107 | 114 | try: |
108 | _, _, key_id, verify_key = yield deferred | |
115 | _, key_id, verify_key = yield verify_request.deferred | |
109 | 116 | except IOError as e: |
110 | 117 | logger.warn( |
111 | 118 | "Got IOError when downloading keys for %s: %s %s", |
127 | 134 | Codes.UNAUTHORIZED, |
128 | 135 | ) |
129 | 136 | |
130 | json_object = group_id_to_json[group.group_id] | |
137 | json_object = verify_request.json_object | |
131 | 138 | |
132 | 139 | try: |
133 | 140 | verify_signed_json(json_object, server_name, verify_key) |
156 | 163 | |
157 | 164 | # Actually start fetching keys. |
158 | 165 | wait_on_deferred.addBoth( |
159 | lambda _: self.get_server_verify_keys(group_id_to_group, deferreds) | |
166 | lambda _: self.get_server_verify_keys(verify_requests) | |
160 | 167 | ) |
161 | 168 | |
162 | 169 | # When we've finished fetching all the keys for a given server_name, |
163 | 170 | # resolve the deferred passed to `wait_for_previous_lookups` so that |
164 | 171 | # any lookups waiting will proceed. |
165 | server_to_gids = {} | |
166 | ||
167 | def remove_deferreds(res, server_name, group_id): | |
168 | server_to_gids[server_name].discard(group_id) | |
169 | if not server_to_gids[server_name]: | |
172 | server_to_request_ids = {} | |
173 | ||
174 | def remove_deferreds(res, server_name, verify_request): | |
175 | request_id = id(verify_request) | |
176 | server_to_request_ids[server_name].discard(request_id) | |
177 | if not server_to_request_ids[server_name]: | |
170 | 178 | d = server_to_deferred.pop(server_name, None) |
171 | 179 | if d: |
172 | 180 | d.callback(None) |
173 | 181 | return res |
174 | 182 | |
175 | for g_id, deferred in deferreds.items(): | |
176 | server_name = group_id_to_group[g_id].server_name | |
177 | server_to_gids.setdefault(server_name, set()).add(g_id) | |
178 | deferred.addBoth(remove_deferreds, server_name, g_id) | |
183 | for verify_request in verify_requests: | |
184 | server_name = verify_request.server_name | |
185 | request_id = id(verify_request) | |
186 | server_to_request_ids.setdefault(server_name, set()).add(request_id) | |
187 | verify_request.deferred.addBoth(remove_deferreds, server_name, verify_request) | |
179 | 188 | |
180 | 189 | # Pass those keys to handle_key_deferred so that the json object |
181 | 190 | # signatures can be verified |
182 | 191 | return [ |
183 | preserve_context_over_fn( | |
184 | handle_key_deferred, | |
185 | group_id_to_group[g_id], | |
186 | deferreds[g_id], | |
187 | ) | |
188 | for g_id in group_ids | |
192 | preserve_context_over_fn(handle_key_deferred, verify_request) | |
193 | for verify_request in verify_requests | |
189 | 194 | ] |
190 | 195 | |
191 | 196 | @defer.inlineCallbacks |
219 | 224 | |
220 | 225 | d.addBoth(rm, server_name) |
221 | 226 | |
222 | def get_server_verify_keys(self, group_id_to_group, group_id_to_deferred): | |
227 | def get_server_verify_keys(self, verify_requests): | |
223 | 228 | """Takes a dict of KeyGroups and tries to find at least one key for |
224 | 229 | each group. |
225 | 230 | """ |
236 | 241 | merged_results = {} |
237 | 242 | |
238 | 243 | missing_keys = {} |
239 | for group in group_id_to_group.values(): | |
240 | missing_keys.setdefault(group.server_name, set()).update( | |
241 | group.key_ids | |
244 | for verify_request in verify_requests: | |
245 | missing_keys.setdefault(verify_request.server_name, set()).update( | |
246 | verify_request.key_ids | |
242 | 247 | ) |
243 | 248 | |
244 | 249 | for fn in key_fetch_fns: |
245 | 250 | results = yield fn(missing_keys.items()) |
246 | 251 | merged_results.update(results) |
247 | 252 | |
248 | # We now need to figure out which groups we have keys for | |
249 | # and which we don't | |
250 | missing_groups = {} | |
251 | for group in group_id_to_group.values(): | |
252 | for key_id in group.key_ids: | |
253 | if key_id in merged_results[group.server_name]: | |
253 | # We now need to figure out which verify requests we have keys | |
254 | # for and which we don't | |
255 | missing_keys = {} | |
256 | requests_missing_keys = [] | |
257 | for verify_request in verify_requests: | |
258 | server_name = verify_request.server_name | |
259 | result_keys = merged_results[server_name] | |
260 | ||
261 | if verify_request.deferred.called: | |
262 | # We've already called this deferred, which probably | |
263 | # means that we've already found a key for it. | |
264 | continue | |
265 | ||
266 | for key_id in verify_request.key_ids: | |
267 | if key_id in result_keys: | |
254 | 268 | with PreserveLoggingContext(): |
255 | group_id_to_deferred[group.group_id].callback(( | |
256 | group.group_id, | |
257 | group.server_name, | |
269 | verify_request.deferred.callback(( | |
270 | server_name, | |
258 | 271 | key_id, |
259 | merged_results[group.server_name][key_id], | |
272 | result_keys[key_id], | |
260 | 273 | )) |
261 | 274 | break |
262 | 275 | else: |
263 | missing_groups.setdefault( | |
264 | group.server_name, [] | |
265 | ).append(group) | |
266 | ||
267 | if not missing_groups: | |
276 | # The else block is only reached if the loop above | |
277 | # doesn't break. | |
278 | missing_keys.setdefault(server_name, set()).update( | |
279 | verify_request.key_ids | |
280 | ) | |
281 | requests_missing_keys.append(verify_request) | |
282 | ||
283 | if not missing_keys: | |
268 | 284 | break |
269 | 285 | |
270 | missing_keys = { | |
271 | server_name: set( | |
272 | key_id for group in groups for key_id in group.key_ids | |
273 | ) | |
274 | for server_name, groups in missing_groups.items() | |
275 | } | |
276 | ||
277 | for group in missing_groups.values(): | |
278 | group_id_to_deferred[group.group_id].errback(SynapseError( | |
286 | for verify_request in requests_missing_keys: | |
287 | verify_request.deferred.errback(SynapseError( | |
279 | 288 | 401, |
280 | 289 | "No key for %s with id %s" % ( |
281 | group.server_name, group.key_ids, | |
290 | verify_request.server_name, verify_request.key_ids, | |
282 | 291 | ), |
283 | 292 | Codes.UNAUTHORIZED, |
284 | 293 | )) |
285 | 294 | |
286 | 295 | def on_err(err): |
287 | for deferred in group_id_to_deferred.values(): | |
288 | if not deferred.called: | |
289 | deferred.errback(err) | |
296 | for verify_request in verify_requests: | |
297 | if not verify_request.deferred.called: | |
298 | verify_request.deferred.errback(err) | |
290 | 299 | |
291 | 300 | do_iterations().addErrback(on_err) |
292 | ||
293 | return group_id_to_deferred | |
294 | 301 | |
295 | 302 | @defer.inlineCallbacks |
296 | 303 | def get_keys_from_store(self, server_name_and_key_ids): |
446 | 453 | ) |
447 | 454 | |
448 | 455 | processed_response = yield self.process_v2_response( |
449 | perspective_name, response | |
456 | perspective_name, response, only_from_server=False | |
450 | 457 | ) |
451 | 458 | |
452 | 459 | for server_name, response_keys in processed_response.items(): |
526 | 533 | |
527 | 534 | @defer.inlineCallbacks |
528 | 535 | def process_v2_response(self, from_server, response_json, |
529 | requested_ids=[]): | |
536 | requested_ids=[], only_from_server=True): | |
530 | 537 | time_now_ms = self.clock.time_msec() |
531 | 538 | response_keys = {} |
532 | 539 | verify_keys = {} |
550 | 557 | |
551 | 558 | results = {} |
552 | 559 | server_name = response_json["server_name"] |
560 | if only_from_server: | |
561 | if server_name != from_server: | |
562 | raise ValueError( | |
563 | "Expected a response for server %r not %r" % ( | |
564 | from_server, server_name | |
565 | ) | |
566 | ) | |
553 | 567 | for key_id in response_json["signatures"].get(server_name, {}): |
554 | 568 | if key_id not in response_json["verify_keys"]: |
555 | 569 | raise ValueError( |
235 | 235 | # TODO: Rate limit the number of times we try and get the same event. |
236 | 236 | |
237 | 237 | if self._get_pdu_cache: |
238 | e = self._get_pdu_cache.get(event_id) | |
239 | if e: | |
240 | defer.returnValue(e) | |
238 | ev = self._get_pdu_cache.get(event_id) | |
239 | if ev: | |
240 | defer.returnValue(ev) | |
241 | 241 | |
242 | 242 | pdu = None |
243 | 243 | for destination in destinations: |
268 | 268 | |
269 | 269 | break |
270 | 270 | |
271 | except SynapseError: | |
271 | except SynapseError as e: | |
272 | 272 | logger.info( |
273 | 273 | "Failed to get PDU %s from %s because %s", |
274 | 274 | event_id, destination, e, |
313 | 313 | Deferred: Results in a list of PDUs. |
314 | 314 | """ |
315 | 315 | |
316 | try: | |
317 | # First we try asking for just the IDs, as that's far quicker if | |
318 | # we have most of the state and auth_chain already. | |
319 | # However, this may 404 if the other side has an old synapse. | |
320 | result = yield self.transport_layer.get_room_state_ids( | |
321 | destination, room_id, event_id=event_id, | |
322 | ) | |
323 | ||
324 | state_event_ids = result["pdu_ids"] | |
325 | auth_event_ids = result.get("auth_chain_ids", []) | |
326 | ||
327 | fetched_events, failed_to_fetch = yield self.get_events( | |
328 | [destination], room_id, set(state_event_ids + auth_event_ids) | |
329 | ) | |
330 | ||
331 | if failed_to_fetch: | |
332 | logger.warn("Failed to get %r", failed_to_fetch) | |
333 | ||
334 | event_map = { | |
335 | ev.event_id: ev for ev in fetched_events | |
336 | } | |
337 | ||
338 | pdus = [event_map[e_id] for e_id in state_event_ids if e_id in event_map] | |
339 | auth_chain = [ | |
340 | event_map[e_id] for e_id in auth_event_ids if e_id in event_map | |
341 | ] | |
342 | ||
343 | auth_chain.sort(key=lambda e: e.depth) | |
344 | ||
345 | defer.returnValue((pdus, auth_chain)) | |
346 | except HttpResponseException as e: | |
347 | if e.code == 400 or e.code == 404: | |
348 | logger.info("Failed to use get_room_state_ids API, falling back") | |
349 | else: | |
350 | raise e | |
351 | ||
316 | 352 | result = yield self.transport_layer.get_room_state( |
317 | 353 | destination, room_id, event_id=event_id, |
318 | 354 | ) |
326 | 362 | for p in result.get("auth_chain", []) |
327 | 363 | ] |
328 | 364 | |
365 | seen_events = yield self.store.get_events([ | |
366 | ev.event_id for ev in itertools.chain(pdus, auth_chain) | |
367 | ]) | |
368 | ||
329 | 369 | signed_pdus = yield self._check_sigs_and_hash_and_fetch( |
330 | destination, pdus, outlier=True | |
370 | destination, | |
371 | [p for p in pdus if p.event_id not in seen_events], | |
372 | outlier=True | |
373 | ) | |
374 | signed_pdus.extend( | |
375 | seen_events[p.event_id] for p in pdus if p.event_id in seen_events | |
331 | 376 | ) |
332 | 377 | |
333 | 378 | signed_auth = yield self._check_sigs_and_hash_and_fetch( |
334 | destination, auth_chain, outlier=True | |
379 | destination, | |
380 | [p for p in auth_chain if p.event_id not in seen_events], | |
381 | outlier=True | |
382 | ) | |
383 | signed_auth.extend( | |
384 | seen_events[p.event_id] for p in auth_chain if p.event_id in seen_events | |
335 | 385 | ) |
336 | 386 | |
337 | 387 | signed_auth.sort(key=lambda e: e.depth) |
338 | 388 | |
339 | 389 | defer.returnValue((signed_pdus, signed_auth)) |
390 | ||
391 | @defer.inlineCallbacks | |
392 | def get_events(self, destinations, room_id, event_ids, return_local=True): | |
393 | """Fetch events from some remote destinations, checking if we already | |
394 | have them. | |
395 | ||
396 | Args: | |
397 | destinations (list) | |
398 | room_id (str) | |
399 | event_ids (list) | |
400 | return_local (bool): Whether to include events we already have in | |
401 | the DB in the returned list of events | |
402 | ||
403 | Returns: | |
404 | Deferred: A deferred resolving to a 2-tuple where the first is a list of | |
405 | events and the second is a list of event ids that we failed to fetch. | |
406 | """ | |
407 | if return_local: | |
408 | seen_events = yield self.store.get_events(event_ids) | |
409 | signed_events = seen_events.values() | |
410 | else: | |
411 | seen_events = yield self.store.have_events(event_ids) | |
412 | signed_events = [] | |
413 | ||
414 | failed_to_fetch = set() | |
415 | ||
416 | missing_events = set(event_ids) | |
417 | for k in seen_events: | |
418 | missing_events.discard(k) | |
419 | ||
420 | if not missing_events: | |
421 | defer.returnValue((signed_events, failed_to_fetch)) | |
422 | ||
423 | def random_server_list(): | |
424 | srvs = list(destinations) | |
425 | random.shuffle(srvs) | |
426 | return srvs | |
427 | ||
428 | batch_size = 20 | |
429 | missing_events = list(missing_events) | |
430 | for i in xrange(0, len(missing_events), batch_size): | |
431 | batch = set(missing_events[i:i + batch_size]) | |
432 | ||
433 | deferreds = [ | |
434 | self.get_pdu( | |
435 | destinations=random_server_list(), | |
436 | event_id=e_id, | |
437 | ) | |
438 | for e_id in batch | |
439 | ] | |
440 | ||
441 | res = yield defer.DeferredList(deferreds, consumeErrors=True) | |
442 | for success, result in res: | |
443 | if success: | |
444 | signed_events.append(result) | |
445 | batch.discard(result.event_id) | |
446 | ||
447 | # We removed all events we successfully fetched from `batch` | |
448 | failed_to_fetch.update(batch) | |
449 | ||
450 | defer.returnValue((signed_events, failed_to_fetch)) | |
340 | 451 | |
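A hedged usage sketch of get_events (destinations and event IDs are placeholders):

    events, failed_to_fetch = yield self.get_events(
        destinations=["serverA.example.com", "serverB.example.com"],
        room_id="!room:example.com",
        event_ids=["$event1:example.com", "$event2:example.com"],
    )
    # `events` holds everything found locally or fetched remotely;
    # `failed_to_fetch` holds the IDs no listed destination could supply.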
341 | 452 | @defer.inlineCallbacks |
342 | 453 | @log_function |
413 | 524 | (destination, self.event_from_pdu_json(pdu_dict)) |
414 | 525 | ) |
415 | 526 | break |
416 | except CodeMessageException: | |
417 | raise | |
527 | except CodeMessageException as e: | |
528 | if not 500 <= e.code < 600: | |
529 | raise | |
530 | else: | |
531 | logger.warn( | |
532 | "Failed to make_%s via %s: %s", | |
533 | membership, destination, e.message | |
534 | ) | |
418 | 535 | except Exception as e: |
419 | 536 | logger.warn( |
420 | 537 | "Failed to make_%s via %s: %s", |
421 | 538 | membership, destination, e.message |
422 | 539 | ) |
423 | raise | |
424 | 540 | |
425 | 541 | raise RuntimeError("Failed to send to any server.") |
426 | 542 | |
492 | 608 | "auth_chain": signed_auth, |
493 | 609 | "origin": destination, |
494 | 610 | }) |
495 | except CodeMessageException: | |
496 | raise | |
611 | except CodeMessageException as e: | |
612 | if not 500 <= e.code < 600: | |
613 | raise | |
614 | else: | |
615 | logger.exception( | |
616 | "Failed to send_join via %s: %s", | |
617 | destination, e.message | |
618 | ) | |
497 | 619 | except Exception as e: |
498 | 620 | logger.exception( |
499 | 621 | "Failed to send_join via %s: %s", |
20 | 20 | |
21 | 21 | from synapse.util.async import Linearizer |
22 | 22 | from synapse.util.logutils import log_function |
23 | from synapse.util.caches.response_cache import ResponseCache | |
23 | 24 | from synapse.events import FrozenEvent |
24 | 25 | import synapse.metrics |
25 | 26 | |
26 | from synapse.api.errors import FederationError, SynapseError | |
27 | from synapse.api.errors import AuthError, FederationError, SynapseError | |
27 | 28 | |
28 | 29 | from synapse.crypto.event_signing import compute_event_signature |
29 | 30 | |
47 | 48 | def __init__(self, hs): |
48 | 49 | super(FederationServer, self).__init__(hs) |
49 | 50 | |
51 | self.auth = hs.get_auth() | |
52 | ||
50 | 53 | self._room_pdu_linearizer = Linearizer() |
54 | self._server_linearizer = Linearizer() | |
55 | ||
56 | # We cache responses to state queries, as they take a while and often | |
57 | # come in waves. | |
58 | self._state_resp_cache = ResponseCache(hs, timeout_ms=30000) | |
51 | 59 | |
52 | 60 | def set_handler(self, handler): |
53 | 61 | """Sets the handler that the replication layer will use to communicate |
88 | 96 | @defer.inlineCallbacks |
89 | 97 | @log_function |
90 | 98 | def on_backfill_request(self, origin, room_id, versions, limit): |
91 | pdus = yield self.handler.on_backfill_request( | |
92 | origin, room_id, versions, limit | |
93 | ) | |
94 | ||
95 | defer.returnValue((200, self._transaction_from_pdus(pdus).get_dict())) | |
99 | with (yield self._server_linearizer.queue((origin, room_id))): | |
100 | pdus = yield self.handler.on_backfill_request( | |
101 | origin, room_id, versions, limit | |
102 | ) | |
103 | ||
104 | res = self._transaction_from_pdus(pdus).get_dict() | |
105 | ||
106 | defer.returnValue((200, res)) | |
96 | 107 | |
97 | 108 | @defer.inlineCallbacks |
98 | 109 | @log_function |
183 | 194 | @defer.inlineCallbacks |
184 | 195 | @log_function |
185 | 196 | def on_context_state_request(self, origin, room_id, event_id): |
186 | if event_id: | |
187 | pdus = yield self.handler.get_state_for_pdu( | |
188 | origin, room_id, event_id, | |
189 | ) | |
190 | auth_chain = yield self.store.get_auth_chain( | |
191 | [pdu.event_id for pdu in pdus] | |
192 | ) | |
193 | ||
194 | for event in auth_chain: | |
195 | # We sign these again because there was a bug where we | |
196 | # incorrectly signed things the first time round | |
197 | if self.hs.is_mine_id(event.event_id): | |
198 | event.signatures.update( | |
199 | compute_event_signature( | |
200 | event, | |
201 | self.hs.hostname, | |
202 | self.hs.config.signing_key[0] | |
203 | ) | |
197 | if not event_id: | |
198 | raise NotImplementedError("Specify an event") | |
199 | ||
200 | in_room = yield self.auth.check_host_in_room(room_id, origin) | |
201 | if not in_room: | |
202 | raise AuthError(403, "Host not in room.") | |
203 | ||
204 | result = self._state_resp_cache.get((room_id, event_id)) | |
205 | if not result: | |
206 | with (yield self._server_linearizer.queue((origin, room_id))): | |
207 | resp = yield self._state_resp_cache.set( | |
208 | (room_id, event_id), | |
209 | self._on_context_state_request_compute(room_id, event_id) | |
210 | ) | |
211 | else: | |
212 | resp = yield result | |
213 | ||
214 | defer.returnValue((200, resp)) | |
215 | ||
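The caching pattern used above, sketched in isolation. Assumed ResponseCache semantics: get() returns the in-flight result for a key if one exists, and set() registers a new deferred that concurrent callers then share:

    result = cache.get(key)
    if not result:
        # First caller does the work once; later callers observe the result.
        result = cache.set(key, compute_expensive_response())
    response = yield result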
216 | @defer.inlineCallbacks | |
217 | def on_state_ids_request(self, origin, room_id, event_id): | |
218 | if not event_id: | |
219 | raise NotImplementedError("Specify an event") | |
220 | ||
221 | in_room = yield self.auth.check_host_in_room(room_id, origin) | |
222 | if not in_room: | |
223 | raise AuthError(403, "Host not in room.") | |
224 | ||
225 | pdus = yield self.handler.get_state_for_pdu( | |
226 | room_id, event_id, | |
227 | ) | |
228 | auth_chain = yield self.store.get_auth_chain( | |
229 | [pdu.event_id for pdu in pdus] | |
230 | ) | |
231 | ||
232 | defer.returnValue((200, { | |
233 | "pdu_ids": [pdu.event_id for pdu in pdus], | |
234 | "auth_chain_ids": [pdu.event_id for pdu in auth_chain], | |
235 | })) | |
236 | ||
237 | @defer.inlineCallbacks | |
238 | def _on_context_state_request_compute(self, room_id, event_id): | |
239 | pdus = yield self.handler.get_state_for_pdu( | |
240 | room_id, event_id, | |
241 | ) | |
242 | auth_chain = yield self.store.get_auth_chain( | |
243 | [pdu.event_id for pdu in pdus] | |
244 | ) | |
245 | ||
246 | for event in auth_chain: | |
247 | # We sign these again because there was a bug where we | |
248 | # incorrectly signed things the first time round | |
249 | if self.hs.is_mine_id(event.event_id): | |
250 | event.signatures.update( | |
251 | compute_event_signature( | |
252 | event, | |
253 | self.hs.hostname, | |
254 | self.hs.config.signing_key[0] | |
204 | 255 | ) |
205 | else: | |
206 | raise NotImplementedError("Specify an event") | |
207 | ||
208 | defer.returnValue((200, { | |
256 | ) | |
257 | ||
258 | defer.returnValue({ | |
209 | 259 | "pdus": [pdu.get_pdu_json() for pdu in pdus], |
210 | 260 | "auth_chain": [pdu.get_pdu_json() for pdu in auth_chain], |
211 | })) | |
261 | }) | |
212 | 262 | |
213 | 263 | @defer.inlineCallbacks |
214 | 264 | @log_function |
282 | 332 | |
283 | 333 | @defer.inlineCallbacks |
284 | 334 | def on_event_auth(self, origin, room_id, event_id): |
285 | time_now = self._clock.time_msec() | |
286 | auth_pdus = yield self.handler.on_event_auth(event_id) | |
287 | defer.returnValue((200, { | |
288 | "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus], | |
289 | })) | |
290 | ||
291 | @defer.inlineCallbacks | |
292 | def on_query_auth_request(self, origin, content, event_id): | |
335 | with (yield self._server_linearizer.queue((origin, room_id))): | |
336 | time_now = self._clock.time_msec() | |
337 | auth_pdus = yield self.handler.on_event_auth(event_id) | |
338 | res = { | |
339 | "auth_chain": [a.get_pdu_json(time_now) for a in auth_pdus], | |
340 | } | |
341 | defer.returnValue((200, res)) | |
342 | ||
343 | @defer.inlineCallbacks | |
344 | def on_query_auth_request(self, origin, content, room_id, event_id): | |
293 | 345 | """ |
294 | 346 | Content is a dict with keys:: |
295 | 347 | auth_chain (list): A list of events that give the auth chain. |
308 | 360 | Returns: |
309 | 361 | Deferred: Results in `dict` with the same format as `content` |
310 | 362 | """ |
311 | auth_chain = [ | |
312 | self.event_from_pdu_json(e) | |
313 | for e in content["auth_chain"] | |
314 | ] | |
315 | ||
316 | signed_auth = yield self._check_sigs_and_hash_and_fetch( | |
317 | origin, auth_chain, outlier=True | |
318 | ) | |
319 | ||
320 | ret = yield self.handler.on_query_auth( | |
321 | origin, | |
322 | event_id, | |
323 | signed_auth, | |
324 | content.get("rejects", []), | |
325 | content.get("missing", []), | |
326 | ) | |
327 | ||
328 | time_now = self._clock.time_msec() | |
329 | send_content = { | |
330 | "auth_chain": [ | |
331 | e.get_pdu_json(time_now) | |
332 | for e in ret["auth_chain"] | |
333 | ], | |
334 | "rejects": ret.get("rejects", []), | |
335 | "missing": ret.get("missing", []), | |
336 | } | |
363 | with (yield self._server_linearizer.queue((origin, room_id))): | |
364 | auth_chain = [ | |
365 | self.event_from_pdu_json(e) | |
366 | for e in content["auth_chain"] | |
367 | ] | |
368 | ||
369 | signed_auth = yield self._check_sigs_and_hash_and_fetch( | |
370 | origin, auth_chain, outlier=True | |
371 | ) | |
372 | ||
373 | ret = yield self.handler.on_query_auth( | |
374 | origin, | |
375 | event_id, | |
376 | signed_auth, | |
377 | content.get("rejects", []), | |
378 | content.get("missing", []), | |
379 | ) | |
380 | ||
381 | time_now = self._clock.time_msec() | |
382 | send_content = { | |
383 | "auth_chain": [ | |
384 | e.get_pdu_json(time_now) | |
385 | for e in ret["auth_chain"] | |
386 | ], | |
387 | "rejects": ret.get("rejects", []), | |
388 | "missing": ret.get("missing", []), | |
389 | } | |
337 | 390 | |
338 | 391 | defer.returnValue( |
339 | 392 | (200, send_content) |
340 | 393 | ) |
341 | 394 | |
342 | @defer.inlineCallbacks | |
343 | 395 | @log_function |
344 | 396 | def on_query_client_keys(self, origin, content): |
345 | query = [] | |
346 | for user_id, device_ids in content.get("device_keys", {}).items(): | |
347 | if not device_ids: | |
348 | query.append((user_id, None)) | |
349 | else: | |
350 | for device_id in device_ids: | |
351 | query.append((user_id, device_id)) | |
352 | ||
353 | results = yield self.store.get_e2e_device_keys(query) | |
354 | ||
355 | json_result = {} | |
356 | for user_id, device_keys in results.items(): | |
357 | for device_id, json_bytes in device_keys.items(): | |
358 | json_result.setdefault(user_id, {})[device_id] = json.loads( | |
359 | json_bytes | |
360 | ) | |
361 | ||
362 | defer.returnValue({"device_keys": json_result}) | |
397 | return self.on_query_request("client_keys", content) | |
363 | 398 | |
364 | 399 | @defer.inlineCallbacks |
365 | 400 | @log_function |
385 | 420 | @log_function |
386 | 421 | def on_get_missing_events(self, origin, room_id, earliest_events, |
387 | 422 | latest_events, limit, min_depth): |
388 | logger.info( | |
389 | "on_get_missing_events: earliest_events: %r, latest_events: %r," | |
390 | " limit: %d, min_depth: %d", | |
391 | earliest_events, latest_events, limit, min_depth | |
392 | ) | |
393 | missing_events = yield self.handler.on_get_missing_events( | |
394 | origin, room_id, earliest_events, latest_events, limit, min_depth | |
395 | ) | |
396 | ||
397 | if len(missing_events) < 5: | |
398 | logger.info("Returning %d events: %r", len(missing_events), missing_events) | |
399 | else: | |
400 | logger.info("Returning %d events", len(missing_events)) | |
401 | ||
402 | time_now = self._clock.time_msec() | |
423 | with (yield self._server_linearizer.queue((origin, room_id))): | |
424 | logger.info( | |
425 | "on_get_missing_events: earliest_events: %r, latest_events: %r," | |
426 | " limit: %d, min_depth: %d", | |
427 | earliest_events, latest_events, limit, min_depth | |
428 | ) | |
429 | missing_events = yield self.handler.on_get_missing_events( | |
430 | origin, room_id, earliest_events, latest_events, limit, min_depth | |
431 | ) | |
432 | ||
433 | if len(missing_events) < 5: | |
434 | logger.info( | |
435 | "Returning %d events: %r", len(missing_events), missing_events | |
436 | ) | |
437 | else: | |
438 | logger.info("Returning %d events", len(missing_events)) | |
439 | ||
440 | time_now = self._clock.time_msec() | |
403 | 441 | |
404 | 442 | defer.returnValue({ |
405 | 443 | "events": [ev.get_pdu_json(time_now) for ev in missing_events], |
566 | 604 | origin, pdu.room_id, pdu.event_id, |
567 | 605 | ) |
568 | 606 | except: |
569 | logger.warn("Failed to get state for event: %s", pdu.event_id) | |
607 | logger.exception("Failed to get state for event: %s", pdu.event_id) | |
570 | 608 | |
571 | 609 | yield self.handler.on_receive_pdu( |
572 | 610 | origin, |
54 | 54 | ) |
55 | 55 | |
56 | 56 | @log_function |
57 | def get_room_state_ids(self, destination, room_id, event_id): | |
58 | """ Requests all state for a given room from the given server at the | |
59 | given event. Returns the state's event_id's | |
60 | ||
61 | Args: | |
62 | destination (str): The host name of the remote home server we want | |
63 | to get the state from. | |
64 | room_id (str): The ID of the room we want the state of | |
65 | event_id (str): The event we want the state at. | |
66 | ||
67 | Returns: | |
68 | Deferred: Results in a dict received from the remote homeserver. | |
69 | """ | |
70 | logger.debug("get_room_state_ids dest=%s, room=%s", | |
71 | destination, room_id) | |
72 | ||
73 | path = PREFIX + "/state_ids/%s/" % room_id | |
74 | return self.client.get_json( | |
75 | destination, path=path, args={"event_id": event_id}, | |
76 | ) | |
77 | ||
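A hedged usage sketch of the new endpoint; the response shape matches what the federation client earlier in this diff reads out of it (all values are placeholders):

    result = yield self.transport_layer.get_room_state_ids(
        "remote.example.com", "!room:example.com",
        event_id="$anevent:example.com",
    )
    state_event_ids = result["pdu_ids"]                # the state at that event
    auth_event_ids = result.get("auth_chain_ids", [])  # its auth chain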
78 | @log_function | |
57 | 79 | def get_event(self, destination, event_id, timeout=None): |
58 | 80 | """ Requests the pdu with give id and origin from the given server. |
59 | 81 |
17 | 17 | from synapse.api.urls import FEDERATION_PREFIX as PREFIX |
18 | 18 | from synapse.api.errors import Codes, SynapseError |
19 | 19 | from synapse.http.server import JsonResource |
20 | from synapse.http.servlet import parse_json_object_from_request, parse_string | |
20 | from synapse.http.servlet import parse_json_object_from_request | |
21 | 21 | from synapse.util.ratelimitutils import FederationRateLimiter |
22 | from synapse.util.versionstring import get_version_string | |
22 | 23 | |
23 | 24 | import functools |
24 | 25 | import logging |
25 | import simplejson as json | |
26 | 26 | import re |
27 | import synapse | |
27 | 28 | |
28 | 29 | |
29 | 30 | logger = logging.getLogger(__name__) |
59 | 60 | ) |
60 | 61 | |
61 | 62 | |
63 | class AuthenticationError(SynapseError): | |
64 | """There was a problem authenticating the request""" | |
65 | pass | |
66 | ||
67 | ||
68 | class NoAuthenticationError(AuthenticationError): | |
69 | """The request had no authentication information""" | |
70 | pass | |
71 | ||
72 | ||
62 | 73 | class Authenticator(object): |
63 | 74 | def __init__(self, hs): |
64 | 75 | self.keyring = hs.get_keyring() |
66 | 77 | |
67 | 78 | # A method just so we can pass 'self' as the authenticator to the Servlets |
68 | 79 | @defer.inlineCallbacks |
69 | def authenticate_request(self, request): | |
80 | def authenticate_request(self, request, content): | |
70 | 81 | json_request = { |
71 | 82 | "method": request.method, |
72 | 83 | "uri": request.uri, |
74 | 85 | "signatures": {}, |
75 | 86 | } |
76 | 87 | |
77 | content = None | |
88 | if content is not None: | |
89 | json_request["content"] = content | |
90 | ||
78 | 91 | origin = None |
79 | ||
80 | if request.method in ["PUT", "POST"]: | |
81 | # TODO: Handle other method types? other content types? | |
82 | try: | |
83 | content_bytes = request.content.read() | |
84 | content = json.loads(content_bytes) | |
85 | json_request["content"] = content | |
86 | except: | |
87 | raise SynapseError(400, "Unable to parse JSON", Codes.BAD_JSON) | |
88 | 92 | |
89 | 93 | def parse_auth_header(header_str): |
90 | 94 | try: |
102 | 106 | sig = strip_quotes(param_dict["sig"]) |
103 | 107 | return (origin, key, sig) |
104 | 108 | except: |
105 | raise SynapseError( | |
109 | raise AuthenticationError( | |
106 | 110 | 400, "Malformed Authorization header", Codes.UNAUTHORIZED |
107 | 111 | ) |
108 | 112 | |
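    # For reference, the Authorization header parsed here uses the X-Matrix
    # scheme (illustrative values, not a real signature):
    #
    #   Authorization: X-Matrix origin=origin.example.com,key="ed25519:key1",sig="<base64 signature>"
    #
    # for which parse_auth_header() returns
    # ("origin.example.com", "ed25519:key1", "<base64 signature>").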
109 | 113 | auth_headers = request.requestHeaders.getRawHeaders(b"Authorization") |
110 | 114 | |
111 | 115 | if not auth_headers: |
112 | raise SynapseError( | |
116 | raise NoAuthenticationError( | |
113 | 117 | 401, "Missing Authorization headers", Codes.UNAUTHORIZED, |
114 | 118 | ) |
115 | 119 | |
120 | 124 | json_request["signatures"].setdefault(origin, {})[key] = sig |
121 | 125 | |
122 | 126 | if not json_request["signatures"]: |
123 | raise SynapseError( | |
127 | raise NoAuthenticationError( | |
124 | 128 | 401, "Missing Authorization headers", Codes.UNAUTHORIZED, |
125 | 129 | ) |
126 | 130 | |
129 | 133 | logger.info("Request from %s", origin) |
130 | 134 | request.authenticated_entity = origin |
131 | 135 | |
132 | defer.returnValue((origin, content)) | |
136 | defer.returnValue(origin) | |
133 | 137 | |
134 | 138 | |
135 | 139 | class BaseFederationServlet(object): |
140 | REQUIRE_AUTH = True | |
141 | ||
136 | 142 | def __init__(self, handler, authenticator, ratelimiter, server_name, |
137 | 143 | room_list_handler): |
138 | 144 | self.handler = handler |
140 | 146 | self.ratelimiter = ratelimiter |
141 | 147 | self.room_list_handler = room_list_handler |
142 | 148 | |
143 | def _wrap(self, code): | |
149 | def _wrap(self, func): | |
144 | 150 | authenticator = self.authenticator |
145 | 151 | ratelimiter = self.ratelimiter |
146 | 152 | |
147 | 153 | @defer.inlineCallbacks |
148 | @functools.wraps(code) | |
149 | def new_code(request, *args, **kwargs): | |
154 | @functools.wraps(func) | |
155 | def new_func(request, *args, **kwargs): | |
156 | content = None | |
157 | if request.method in ["PUT", "POST"]: | |
158 | # TODO: Handle other method types? other content types? | |
159 | content = parse_json_object_from_request(request) | |
160 | ||
150 | 161 | try: |
151 | (origin, content) = yield authenticator.authenticate_request(request) | |
152 | with ratelimiter.ratelimit(origin) as d: | |
153 | yield d | |
154 | response = yield code( | |
155 | origin, content, request.args, *args, **kwargs | |
156 | ) | |
162 | origin = yield authenticator.authenticate_request(request, content) | |
163 | except NoAuthenticationError: | |
164 | origin = None | |
165 | if self.REQUIRE_AUTH: | |
166 | logger.exception("authenticate_request failed") | |
167 | raise | |
157 | 168 | except: |
158 | 169 | logger.exception("authenticate_request failed") |
159 | 170 | raise |
171 | ||
172 | if origin: | |
173 | with ratelimiter.ratelimit(origin) as d: | |
174 | yield d | |
175 | response = yield func( | |
176 | origin, content, request.args, *args, **kwargs | |
177 | ) | |
178 | else: | |
179 | response = yield func( | |
180 | origin, content, request.args, *args, **kwargs | |
181 | ) | |
182 | ||
160 | 183 | defer.returnValue(response) |
161 | 184 | |
162 | 185 | # Extra logic that functools.wraps() doesn't finish |
163 | new_code.__self__ = code.__self__ | |
164 | ||
165 | return new_code | |
186 | new_func.__self__ = func.__self__ | |
187 | ||
188 | return new_func | |
166 | 189 | |
167 | 190 | def register(self, server): |
168 | 191 | pattern = re.compile("^" + PREFIX + self.PATH + "$") |
270 | 293 | ) |
271 | 294 | |
272 | 295 | |
296 | class FederationStateIdsServlet(BaseFederationServlet): | |
297 | PATH = "/state_ids/(?P<room_id>[^/]*)/" | |
298 | ||
299 | def on_GET(self, origin, content, query, room_id): | |
300 | return self.handler.on_state_ids_request( | |
301 | origin, | |
302 | room_id, | |
303 | query.get("event_id", [None])[0], | |
304 | ) | |
305 | ||
306 | ||
273 | 307 | class FederationBackfillServlet(BaseFederationServlet): |
274 | 308 | PATH = "/backfill/(?P<context>[^/]*)/" |
275 | 309 | |
366 | 400 | class FederationClientKeysQueryServlet(BaseFederationServlet): |
367 | 401 | PATH = "/user/keys/query" |
368 | 402 | |
369 | @defer.inlineCallbacks | |
370 | 403 | def on_POST(self, origin, content, query): |
371 | response = yield self.handler.on_query_client_keys(origin, content) | |
372 | defer.returnValue((200, response)) | |
404 | return self.handler.on_query_client_keys(origin, content) | |
373 | 405 | |
374 | 406 | |
375 | 407 | class FederationClientKeysClaimServlet(BaseFederationServlet): |
387 | 419 | @defer.inlineCallbacks |
388 | 420 | def on_POST(self, origin, content, query, context, event_id): |
389 | 421 | new_content = yield self.handler.on_query_auth_request( |
390 | origin, content, event_id | |
422 | origin, content, context, event_id | |
391 | 423 | ) |
392 | 424 | |
393 | 425 | defer.returnValue((200, new_content)) |
419 | 451 | class On3pidBindServlet(BaseFederationServlet): |
420 | 452 | PATH = "/3pid/onbind" |
421 | 453 | |
422 | @defer.inlineCallbacks | |
423 | def on_POST(self, request): | |
424 | content = parse_json_object_from_request(request) | |
454 | REQUIRE_AUTH = False | |
455 | ||
456 | @defer.inlineCallbacks | |
457 | def on_POST(self, origin, content, query): | |
425 | 458 | if "invites" in content: |
426 | 459 | last_exception = None |
427 | 460 | for invite in content["invites"]: |
443 | 476 | raise last_exception |
444 | 477 | defer.returnValue((200, {})) |
445 | 478 | |
446 | # Avoid doing remote HS authorization checks which are done by default by | |
447 | # BaseFederationServlet. | |
448 | def _wrap(self, code): | |
449 | return code | |
450 | ||
451 | 479 | |
452 | 480 | class OpenIdUserInfo(BaseFederationServlet): |
453 | 481 | """ |
468 | 496 | |
469 | 497 | PATH = "/openid/userinfo" |
470 | 498 | |
471 | @defer.inlineCallbacks | |
472 | def on_GET(self, request): | |
473 | token = parse_string(request, "access_token") | |
499 | REQUIRE_AUTH = False | |
500 | ||
501 | @defer.inlineCallbacks | |
502 | def on_GET(self, origin, content, query): | |
503 | token = query.get("access_token", [None])[0] | |
474 | 504 | if token is None: |
475 | 505 | defer.returnValue((401, { |
476 | 506 | "errcode": "M_MISSING_TOKEN", "error": "Access Token required" |
486 | 516 | })) |
487 | 517 | |
488 | 518 | defer.returnValue((200, {"sub": user_id})) |
489 | ||
490 | # Avoid doing remote HS authorization checks which are done by default by | |
491 | # BaseFederationServlet. | |
492 | def _wrap(self, code): | |
493 | return code | |
494 | 519 | |
495 | 520 | |
496 | 521 | class PublicRoomList(BaseFederationServlet): |
532 | 557 | defer.returnValue((200, data)) |
533 | 558 | |
534 | 559 | |
560 | class FederationVersionServlet(BaseFederationServlet): | |
561 | PATH = "/version" | |
562 | ||
563 | REQUIRE_AUTH = False | |
564 | ||
565 | def on_GET(self, origin, content, query): | |
566 | return defer.succeed((200, { | |
567 | "server": { | |
568 | "name": "Synapse", | |
569 | "version": get_version_string(synapse) | |
570 | }, | |
571 | })) | |
572 | ||
573 | ||
535 | 574 | SERVLET_CLASSES = ( |
536 | 575 | FederationSendServlet, |
537 | 576 | FederationPullServlet, |
538 | 577 | FederationEventServlet, |
539 | 578 | FederationStateServlet, |
579 | FederationStateIdsServlet, | |
540 | 580 | FederationBackfillServlet, |
541 | 581 | FederationQueryServlet, |
542 | 582 | FederationMakeJoinServlet, |
554 | 594 | On3pidBindServlet, |
555 | 595 | OpenIdUserInfo, |
556 | 596 | PublicRoomList, |
597 | FederationVersionServlet, | |
557 | 598 | ) |
558 | 599 | |
559 | 600 |
30 | 30 | |
31 | 31 | class Handlers(object): |
32 | 32 | |
33 | """ A collection of all the event handlers. | |
33 | """ Deprecated. A collection of handlers. | |
34 | 34 | |
35 | There's no need to lazily create these; we'll just make them all eagerly | |
36 | at construction time. | |
35 | Historically, most of the classes whose names ended in "Handler" were | |
36 | accessed through this class. | |
37 | ||
38 | However, this makes it painful to unit test the handlers and to run | |
39 | cut-down versions of synapse that use only specific handlers, because | |
40 | getting at any one handler required creating all of them. So some of the | |
41 | handlers have been lifted out of the Handlers object and are now accessed | |
42 | directly through the homeserver object itself. | |
43 | ||
44 | Any new handlers should follow the new pattern of being accessed through | |
45 | the homeserver object and should not be added to the Handlers object. | |
46 | ||
47 | The remaining handlers should be moved out of the handlers object. | |
37 | 48 | """ |
38 | 49 | |
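Both access patterns side by side; hs is a synapse.server.HomeServer, and get_device_handler is one of the handlers already lifted out elsewhere in this changeset:

    # New pattern: ask the homeserver object directly.
    device_handler = hs.get_device_handler()

    # Deprecated pattern: reach through the Handlers collection.
    room_member_handler = hs.get_handlers().room_member_handler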
39 | 50 | def __init__(self, hs): |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | |
15 | import logging | |
16 | ||
15 | 17 | from twisted.internet import defer |
16 | 18 | |
19 | import synapse.types | |
20 | from synapse.api.constants import Membership, EventTypes | |
17 | 21 | from synapse.api.errors import LimitExceededError |
18 | from synapse.api.constants import Membership, EventTypes | |
19 | from synapse.types import UserID, Requester | |
20 | ||
21 | ||
22 | import logging | |
22 | from synapse.types import UserID | |
23 | 23 | |
24 | 24 | |
25 | 25 | logger = logging.getLogger(__name__) |
30 | 30 | Common base class for the event handlers. |
31 | 31 | |
32 | 32 | Attributes: |
33 | store (synapse.storage.events.StateStore): | |
33 | store (synapse.storage.DataStore): | |
34 | 34 | state_handler (synapse.state.StateHandler): |
35 | 35 | """ |
36 | 36 | |
37 | 37 | def __init__(self, hs): |
38 | """ | |
39 | Args: | |
40 | hs (synapse.server.HomeServer): | |
41 | """ | |
38 | 42 | self.store = hs.get_datastore() |
39 | 43 | self.auth = hs.get_auth() |
40 | 44 | self.notifier = hs.get_notifier() |
119 | 123 | # and having homeservers have their own users leave keeps more |
120 | 124 | # of that decision-making and control local to the guest-having |
121 | 125 | # homeserver. |
122 | requester = Requester(target_user, "", True) | |
126 | requester = synapse.types.create_requester( | |
127 | target_user, is_guest=True) | |
123 | 128 | handler = self.hs.get_handlers().room_member_handler |
124 | 129 | yield handler.update_membership( |
125 | 130 | requester, |
19 | 19 | from synapse.types import UserID |
20 | 20 | from synapse.api.errors import AuthError, LoginError, Codes, StoreError, SynapseError |
21 | 21 | from synapse.util.async import run_on_reactor |
22 | from synapse.config.ldap import LDAPMode | |
22 | 23 | |
23 | 24 | from twisted.web.client import PartialDownloadError |
24 | 25 | |
27 | 28 | import pymacaroons |
28 | 29 | import simplejson |
29 | 30 | |
31 | try: | |
32 | import ldap3 | |
33 | except ImportError: | |
34 | ldap3 = None | |
35 | pass | |
36 | ||
30 | 37 | import synapse.util.stringutils as stringutils |
31 | 38 | |
32 | 39 | |
37 | 44 | SESSION_EXPIRE_MS = 48 * 60 * 60 * 1000 |
38 | 45 | |
39 | 46 | def __init__(self, hs): |
47 | """ | |
48 | Args: | |
49 | hs (synapse.server.HomeServer): | |
50 | """ | |
40 | 51 | super(AuthHandler, self).__init__(hs) |
41 | 52 | self.checkers = { |
42 | 53 | LoginType.PASSWORD: self._check_password_auth, |
49 | 60 | self.INVALID_TOKEN_HTTP_STATUS = 401 |
50 | 61 | |
51 | 62 | self.ldap_enabled = hs.config.ldap_enabled |
52 | self.ldap_server = hs.config.ldap_server | |
53 | self.ldap_port = hs.config.ldap_port | |
54 | self.ldap_tls = hs.config.ldap_tls | |
55 | self.ldap_search_base = hs.config.ldap_search_base | |
56 | self.ldap_search_property = hs.config.ldap_search_property | |
57 | self.ldap_email_property = hs.config.ldap_email_property | |
58 | self.ldap_full_name_property = hs.config.ldap_full_name_property | |
59 | ||
60 | if self.ldap_enabled is True: | |
61 | import ldap | |
62 | logger.info("Import ldap version: %s", ldap.__version__) | |
63 | if self.ldap_enabled: | |
64 | if not ldap3: | |
65 | raise RuntimeError( | |
66 | 'Missing ldap3 library. This is required for LDAP Authentication.' | |
67 | ) | |
68 | self.ldap_mode = hs.config.ldap_mode | |
69 | self.ldap_uri = hs.config.ldap_uri | |
70 | self.ldap_start_tls = hs.config.ldap_start_tls | |
71 | self.ldap_base = hs.config.ldap_base | |
72 | self.ldap_filter = hs.config.ldap_filter | |
73 | self.ldap_attributes = hs.config.ldap_attributes | |
74 | if self.ldap_mode == LDAPMode.SEARCH: | |
75 | self.ldap_bind_dn = hs.config.ldap_bind_dn | |
76 | self.ldap_bind_password = hs.config.ldap_bind_password | |
63 | 77 | |
64 | 78 | self.hs = hs # FIXME better possibility to access registrationHandler later? |
79 | self.device_handler = hs.get_device_handler() | |
65 | 80 | |
66 | 81 | @defer.inlineCallbacks |
67 | 82 | def check_auth(self, flows, clientdict, clientip): |
219 | 234 | sess = self._get_session_info(session_id) |
220 | 235 | return sess.setdefault('serverdict', {}).get(key, default) |
221 | 236 | |
222 | @defer.inlineCallbacks | |
223 | 237 | def _check_password_auth(self, authdict, _): |
224 | 238 | if "user" not in authdict or "password" not in authdict: |
225 | 239 | raise LoginError(400, "", Codes.MISSING_PARAM) |
229 | 243 | if not user_id.startswith('@'): |
230 | 244 | user_id = UserID.create(user_id, self.hs.hostname).to_string() |
231 | 245 | |
232 | if not (yield self._check_password(user_id, password)): | |
233 | logger.warn("Failed password login for user %s", user_id) | |
234 | raise LoginError(403, "", errcode=Codes.FORBIDDEN) | |
235 | ||
236 | defer.returnValue(user_id) | |
246 | return self._check_password(user_id, password) | |
237 | 247 | |
238 | 248 | @defer.inlineCallbacks |
239 | 249 | def _check_recaptcha(self, authdict, clientip): |
269 | 279 | data = pde.response |
270 | 280 | resp_body = simplejson.loads(data) |
271 | 281 | |
272 | if 'success' in resp_body and resp_body['success']: | |
273 | defer.returnValue(True) | |
282 | if 'success' in resp_body: | |
283 | # Note that we do NOT check the hostname here: we explicitly | |
284 | # intend the CAPTCHA to be presented by whatever client the | |
285 | # user is using; we just care that they have completed a CAPTCHA. | |
286 | logger.info( | |
287 | "%s reCAPTCHA from hostname %s", | |
288 | "Successful" if resp_body['success'] else "Failed", | |
289 | resp_body.get('hostname') | |
290 | ) | |
291 | if resp_body['success']: | |
292 | defer.returnValue(True) | |
274 | 293 | raise LoginError(401, "", errcode=Codes.UNAUTHORIZED) |
275 | 294 | |
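The body handled above is the standard Google reCAPTCHA verify response; a sketch of the two shapes the code distinguishes (the hostname is logged but deliberately not validated):

    # Successful verification: the code reaches defer.returnValue(True).
    resp_body = {"success": True, "hostname": "example.com"}
    # Failed verification: falls through to LoginError(401).
    resp_body = {"success": False, "hostname": "example.com"}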
276 | 295 | @defer.inlineCallbacks |
337 | 356 | |
338 | 357 | return self.sessions[session_id] |
339 | 358 | |
340 | @defer.inlineCallbacks | |
341 | def login_with_password(self, user_id, password): | |
359 | def validate_password_login(self, user_id, password): | |
342 | 360 | """ |
343 | 361 | Authenticates the user with their username and password. |
344 | 362 | |
345 | 363 | Used only by the v1 login API. |
346 | 364 | |
347 | 365 | Args: |
348 | user_id (str): User ID | |
366 | user_id (str): complete @user:id | |
349 | 367 | password (str): Password |
350 | 368 | Returns: |
369 | defer.Deferred: (str) canonical user id | |
370 | Raises: | |
371 | StoreError if there was a problem accessing the database | |
372 | LoginError if there was an authentication problem. | |
373 | """ | |
374 | return self._check_password(user_id, password) | |
375 | ||
376 | @defer.inlineCallbacks | |
377 | def get_login_tuple_for_user_id(self, user_id, device_id=None, | |
378 | initial_display_name=None): | |
379 | """ | |
380 | Gets login tuple for the user with the given user ID. | |
381 | ||
382 | Creates a new access/refresh token for the user. | |
383 | ||
384 | The user is assumed to have been authenticated by some other | |
385 | mechanism (e.g. CAS), and the user_id converted to the canonical case. | |
386 | ||
387 | The device will be recorded in the table if it is not there already. | |
388 | ||
389 | Args: | |
390 | user_id (str): canonical User ID | |
391 | device_id (str|None): the device ID to associate with the tokens. | |
392 | None to leave the tokens unassociated with a device (deprecated: | |
393 | we should always have a device ID) | |
394 | initial_display_name (str): display name to associate with the | |
395 | device if it needs re-registering | |
396 | Returns: | |
351 | 397 | A tuple of: |
352 | The user's ID. | |
353 | 398 | The access token for the user's session. |
354 | 399 | The refresh token for the user's session. |
355 | 400 | Raises: |
356 | 401 | StoreError if there was a problem storing the token. |
357 | 402 | LoginError if there was an authentication problem. |
358 | 403 | """ |
359 | ||
360 | if not (yield self._check_password(user_id, password)): | |
361 | logger.warn("Failed password login for user %s", user_id) | |
362 | raise LoginError(403, "", errcode=Codes.FORBIDDEN) | |
363 | ||
364 | logger.info("Logging in user %s", user_id) | |
365 | access_token = yield self.issue_access_token(user_id) | |
366 | refresh_token = yield self.issue_refresh_token(user_id) | |
367 | defer.returnValue((user_id, access_token, refresh_token)) | |
368 | ||
369 | @defer.inlineCallbacks | |
370 | def get_login_tuple_for_user_id(self, user_id): | |
371 | """ | |
372 | Gets login tuple for the user with the given user ID. | |
373 | The user is assumed to have been authenticated by some other | |
374 | machanism (e.g. CAS) | |
375 | ||
376 | Args: | |
377 | user_id (str): User ID | |
378 | Returns: | |
379 | A tuple of: | |
380 | The user's ID. | |
381 | The access token for the user's session. | |
382 | The refresh token for the user's session. | |
383 | Raises: | |
384 | StoreError if there was a problem storing the token. | |
385 | LoginError if there was an authentication problem. | |
386 | """ | |
387 | user_id, ignored = yield self._find_user_id_and_pwd_hash(user_id) | |
388 | ||
389 | logger.info("Logging in user %s", user_id) | |
390 | access_token = yield self.issue_access_token(user_id) | |
391 | refresh_token = yield self.issue_refresh_token(user_id) | |
392 | defer.returnValue((user_id, access_token, refresh_token)) | |
393 | ||
394 | @defer.inlineCallbacks | |
395 | def does_user_exist(self, user_id): | |
404 | logger.info("Logging in user %s on device %s", user_id, device_id) | |
405 | access_token = yield self.issue_access_token(user_id, device_id) | |
406 | refresh_token = yield self.issue_refresh_token(user_id, device_id) | |
407 | ||
408 | # the device *should* have been registered before we got here; however, | |
409 | # it's possible we raced against a DELETE operation. The thing we | |
410 | # really don't want is active access_tokens without a record of the | |
411 | # device, so we double-check it here. | |
412 | if device_id is not None: | |
413 | yield self.device_handler.check_device_registered( | |
414 | user_id, device_id, initial_display_name | |
415 | ) | |
416 | ||
417 | defer.returnValue((access_token, refresh_token)) | |
418 | ||
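Putting the two new methods together, a v1 password login now looks roughly like the sketch below (hedged: the auth_handler wiring is assumed and the display name is a placeholder):

    from twisted.internet import defer

    @defer.inlineCallbacks
    def v1_password_login(auth_handler, user_id, password, device_id):
        # validate_password_login resolves to the canonical user id,
        # or raises LoginError on a bad password.
        canonical_uid = yield auth_handler.validate_password_login(
            user_id, password
        )
        access_token, refresh_token = (
            yield auth_handler.get_login_tuple_for_user_id(
                canonical_uid, device_id=device_id,
                initial_display_name="sketch device",  # placeholder
            )
        )
        defer.returnValue((canonical_uid, access_token, refresh_token))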
419 | @defer.inlineCallbacks | |
420 | def check_user_exists(self, user_id): | |
421 | """ | |
422 | Checks to see if a user with the given id exists. Will check case | |
423 | insensitively, but return None if there are multiple inexact matches. | |
424 | ||
425 | Args: | |
426 | user_id (str): complete @user:id | |
427 | ||
428 | Returns: | |
429 | defer.Deferred: (str) canonical_user_id, or None if zero or | |
430 | multiple matches | |
431 | """ | |
396 | 432 | try: |
397 | yield self._find_user_id_and_pwd_hash(user_id) | |
398 | defer.returnValue(True) | |
433 | res = yield self._find_user_id_and_pwd_hash(user_id) | |
434 | defer.returnValue(res[0]) | |
399 | 435 | except LoginError: |
400 | defer.returnValue(False) | |
436 | defer.returnValue(None) | |
401 | 437 | |
402 | 438 | @defer.inlineCallbacks |
403 | 439 | def _find_user_id_and_pwd_hash(self, user_id): |
427 | 463 | |
428 | 464 | @defer.inlineCallbacks |
429 | 465 | def _check_password(self, user_id, password): |
430 | """ | |
431 | Returns: | |
432 | True if the user_id successfully authenticated | |
466 | """Authenticate a user against the LDAP and local databases. | |
467 | ||
468 | user_id is checked case insensitively against the local database, but | |
469 | will throw if there are multiple inexact matches. | |
470 | ||
471 | Args: | |
472 | user_id (str): complete @user:id | |
473 | Returns: | |
474 | (str) the canonical_user_id | |
475 | Raises: | |
476 | LoginError if the password was incorrect | |
433 | 477 | """ |
434 | 478 | valid_ldap = yield self._check_ldap_password(user_id, password) |
435 | 479 | if valid_ldap: |
480 | defer.returnValue(user_id) | |
481 | ||
482 | result = yield self._check_local_password(user_id, password) | |
483 | defer.returnValue(result) | |
484 | ||
485 | @defer.inlineCallbacks | |
486 | def _check_local_password(self, user_id, password): | |
487 | """Authenticate a user against the local password database. | |
488 | ||
489 | user_id is checked case insensitively, but will throw if there are | |
490 | multiple inexact matches. | |
491 | ||
492 | Args: | |
493 | user_id (str): complete @user:id | |
494 | Returns: | |
495 | (str) the canonical_user_id | |
496 | Raises: | |
497 | LoginError if the password was incorrect | |
498 | """ | |
499 | user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id) | |
500 | result = self.validate_hash(password, password_hash) | |
501 | if not result: | |
502 | logger.warn("Failed password login for user %s", user_id) | |
503 | raise LoginError(403, "", errcode=Codes.FORBIDDEN) | |
504 | defer.returnValue(user_id) | |
505 | ||
506 | @defer.inlineCallbacks | |
507 | def _check_ldap_password(self, user_id, password): | |
508 | """ Attempt to authenticate a user against an LDAP server | |
509 | and register an account if none exists. | |
510 | ||
511 | Returns: | |
512 | True if authentication against LDAP was successful | |
513 | """ | |
514 | ||
515 | if not ldap3 or not self.ldap_enabled: | |
516 | defer.returnValue(False) | |
517 | ||
518 | if self.ldap_mode not in LDAPMode.LIST: | |
519 | raise RuntimeError( | |
520 | 'Invalid ldap mode specified: {mode}'.format( | |
521 | mode=self.ldap_mode | |
522 | ) | |
523 | ) | |
524 | ||
525 | try: | |
526 | server = ldap3.Server(self.ldap_uri) | |
527 | logger.debug( | |
528 | "Attempting ldap connection with %s", | |
529 | self.ldap_uri | |
530 | ) | |
531 | ||
532 | localpart = UserID.from_string(user_id).localpart | |
533 | if self.ldap_mode == LDAPMode.SIMPLE: | |
534 | # bind with the local user's ldap credentials | |
535 | bind_dn = "{prop}={value},{base}".format( | |
536 | prop=self.ldap_attributes['uid'], | |
537 | value=localpart, | |
538 | base=self.ldap_base | |
539 | ) | |
540 | conn = ldap3.Connection(server, bind_dn, password) | |
541 | logger.debug( | |
542 | "Established ldap connection in simple mode: %s", | |
543 | conn | |
544 | ) | |
545 | ||
546 | if self.ldap_start_tls: | |
547 | conn.start_tls() | |
548 | logger.debug( | |
549 | "Upgraded ldap connection in simple mode through StartTLS: %s", | |
550 | conn | |
551 | ) | |
552 | ||
553 | conn.bind() | |
554 | ||
555 | elif self.ldap_mode == LDAPMode.SEARCH: | |
556 | # connect with preconfigured credentials and search for local user | |
557 | conn = ldap3.Connection( | |
558 | server, | |
559 | self.ldap_bind_dn, | |
560 | self.ldap_bind_password | |
561 | ) | |
562 | logger.debug( | |
563 | "Established ldap connection in search mode: %s", | |
564 | conn | |
565 | ) | |
566 | ||
567 | if self.ldap_start_tls: | |
568 | conn.start_tls() | |
569 | logger.debug( | |
570 | "Upgraded ldap connection in search mode through StartTLS: %s", | |
571 | conn | |
572 | ) | |
573 | ||
574 | conn.bind() | |
575 | ||
576 | # find matching dn | |
577 | query = "({prop}={value})".format( | |
578 | prop=self.ldap_attributes['uid'], | |
579 | value=localpart | |
580 | ) | |
581 | if self.ldap_filter: | |
582 | query = "(&{query}{filter})".format( | |
583 | query=query, | |
584 | filter=self.ldap_filter | |
585 | ) | |
586 | logger.debug("ldap search filter: %s", query) | |
587 | result = conn.search(self.ldap_base, query) | |
588 | ||
589 | if result and len(conn.response) == 1: | |
590 | # found exactly one result | |
591 | user_dn = conn.response[0]['dn'] | |
592 | logger.debug('ldap search found dn: %s', user_dn) | |
593 | ||
594 | # unbind and reconnect, rebind with found dn | |
595 | conn.unbind() | |
596 | conn = ldap3.Connection( | |
597 | server, | |
598 | user_dn, | |
599 | password, | |
600 | auto_bind=True | |
601 | ) | |
602 | else: | |
603 | # found 0 or > 1 results, abort! | |
604 | logger.warn( | |
605 | "ldap search returned unexpected (%d!=1) number of results", | |
606 | len(conn.response) | |
607 | ) | |
608 | defer.returnValue(False) | |
609 | ||
610 | logger.info( | |
611 | "User authenticated against ldap server: %s", | |
612 | conn | |
613 | ) | |
614 | ||
615 | # check for existing account, if none exists, create one | |
616 | if not (yield self.check_user_exists(user_id)): | |
617 | # query user metadata for account creation | |
618 | query = "({prop}={value})".format( | |
619 | prop=self.ldap_attributes['uid'], | |
620 | value=localpart | |
621 | ) | |
622 | ||
623 | if self.ldap_mode == LDAPMode.SEARCH and self.ldap_filter: | |
624 | query = "(&{filter}{user_filter})".format( | |
625 | filter=query, | |
626 | user_filter=self.ldap_filter | |
627 | ) | |
628 | logger.debug("ldap registration filter: %s", query) | |
629 | ||
630 | result = conn.search( | |
631 | search_base=self.ldap_base, | |
632 | search_filter=query, | |
633 | attributes=[ | |
634 | self.ldap_attributes['name'], | |
635 | self.ldap_attributes['mail'] | |
636 | ] | |
637 | ) | |
638 | ||
639 | if len(conn.response) == 1: | |
640 | attrs = conn.response[0]['attributes'] | |
641 | mail = attrs[self.ldap_attributes['mail']][0] | |
642 | name = attrs[self.ldap_attributes['name']][0] | |
643 | ||
644 | # create account | |
645 | registration_handler = self.hs.get_handlers().registration_handler | |
646 | user_id, access_token = ( | |
647 | yield registration_handler.register(localpart=localpart) | |
648 | ) | |
649 | ||
650 | # TODO: bind email, set displayname with data from ldap directory | |
651 | ||
652 | logger.info( | |
653 | "ldap registration successful: %s: %s (%s, %s)", | |
654 | user_id, | |
655 | localpart, | |
656 | name, | |
657 | mail | |
658 | ) | |
659 | else: | |
660 | logger.warn( | |
661 | "ldap registration failed: unexpected (%d!=1) number of results", | |
662 | len(conn.response) | |
663 | ) | |
664 | defer.returnValue(False) | |
665 | ||
436 | 666 | defer.returnValue(True) |
437 | ||
438 | valid_local_password = yield self._check_local_password(user_id, password) | |
439 | if valid_local_password: | |
440 | defer.returnValue(True) | |
441 | ||
442 | defer.returnValue(False) | |
443 | ||
444 | @defer.inlineCallbacks | |
445 | def _check_local_password(self, user_id, password): | |
446 | try: | |
447 | user_id, password_hash = yield self._find_user_id_and_pwd_hash(user_id) | |
448 | defer.returnValue(self.validate_hash(password, password_hash)) | |
449 | except LoginError: | |
667 | except ldap3.core.exceptions.LDAPException as e: | |
668 | logger.warn("Error during ldap authentication: %s", e) | |
450 | 669 | defer.returnValue(False) |
451 | 670 | |
452 | 671 | @defer.inlineCallbacks |
453 | def _check_ldap_password(self, user_id, password): | |
454 | if not self.ldap_enabled: | |
455 | logger.debug("LDAP not configured") | |
456 | defer.returnValue(False) | |
457 | ||
458 | import ldap | |
459 | ||
460 | logger.info("Authenticating %s with LDAP" % user_id) | |
461 | try: | |
462 | ldap_url = "%s:%s" % (self.ldap_server, self.ldap_port) | |
463 | logger.debug("Connecting LDAP server at %s" % ldap_url) | |
464 | l = ldap.initialize(ldap_url) | |
465 | if self.ldap_tls: | |
466 | logger.debug("Initiating TLS") | |
467 | self._connection.start_tls_s() | |
468 | ||
469 | local_name = UserID.from_string(user_id).localpart | |
470 | ||
471 | dn = "%s=%s, %s" % ( | |
472 | self.ldap_search_property, | |
473 | local_name, | |
474 | self.ldap_search_base) | |
475 | logger.debug("DN for LDAP authentication: %s" % dn) | |
476 | ||
477 | l.simple_bind_s(dn.encode('utf-8'), password.encode('utf-8')) | |
478 | ||
479 | if not (yield self.does_user_exist(user_id)): | |
480 | handler = self.hs.get_handlers().registration_handler | |
481 | user_id, access_token = ( | |
482 | yield handler.register(localpart=local_name) | |
483 | ) | |
484 | ||
485 | defer.returnValue(True) | |
486 | except ldap.LDAPError, e: | |
487 | logger.warn("LDAP error: %s", e) | |
488 | defer.returnValue(False) | |
489 | ||
490 | @defer.inlineCallbacks | |
491 | def issue_access_token(self, user_id): | |
672 | def issue_access_token(self, user_id, device_id=None): | |
492 | 673 | access_token = self.generate_access_token(user_id) |
493 | yield self.store.add_access_token_to_user(user_id, access_token) | |
674 | yield self.store.add_access_token_to_user(user_id, access_token, | |
675 | device_id) | |
494 | 676 | defer.returnValue(access_token) |
495 | 677 | |
496 | 678 | @defer.inlineCallbacks |
497 | def issue_refresh_token(self, user_id): | |
679 | def issue_refresh_token(self, user_id, device_id=None): | |
498 | 680 | refresh_token = self.generate_refresh_token(user_id) |
499 | yield self.store.add_refresh_token_to_user(user_id, refresh_token) | |
681 | yield self.store.add_refresh_token_to_user(user_id, refresh_token, | |
682 | device_id) | |
500 | 683 | defer.returnValue(refresh_token) |
501 | 684 | |
502 | def generate_access_token(self, user_id, extra_caveats=None): | |
685 | def generate_access_token(self, user_id, extra_caveats=None, | |
686 | duration_in_ms=(60 * 60 * 1000)): | |
503 | 687 | extra_caveats = extra_caveats or [] |
504 | 688 | macaroon = self._generate_base_macaroon(user_id) |
505 | 689 | macaroon.add_first_party_caveat("type = access") |
506 | 690 | now = self.hs.get_clock().time_msec() |
507 | expiry = now + (60 * 60 * 1000) | |
691 | expiry = now + duration_in_ms | |
508 | 692 | macaroon.add_first_party_caveat("time < %d" % (expiry,)) |
509 | 693 | for caveat in extra_caveats: |
510 | 694 | macaroon.add_first_party_caveat(caveat) |
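For reference, the macaroon produced above carries its expiry as a first-party caveat alongside the base caveats. A standalone sketch with pymacaroons (the location, identifier, key and timestamp are placeholders, not Synapse's real secrets):

    import pymacaroons

    macaroon = pymacaroons.Macaroon(
        location="example.com",       # placeholder server name
        identifier="key",             # placeholder identifier
        key="macaroon_secret_key",    # placeholder signing secret
    )
    macaroon.add_first_party_caveat("gen = 1")
    macaroon.add_first_party_caveat("user_id = @alice:example.com")
    macaroon.add_first_party_caveat("type = access")
    macaroon.add_first_party_caveat("time < %d" % (1470000000000,))
    token = macaroon.serialize()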
612 | 796 | Returns: |
613 | 797 | Hashed password (str). |
614 | 798 | """ |
615 | return bcrypt.hashpw(password, bcrypt.gensalt(self.bcrypt_rounds)) | |
799 | return bcrypt.hashpw(password + self.hs.config.password_pepper, | |
800 | bcrypt.gensalt(self.bcrypt_rounds)) | |
616 | 801 | |
617 | 802 | def validate_hash(self, password, stored_hash): |
618 | 803 | """Validates that self.hash(password) == stored_hash. |
625 | 810 | Whether self.hash(password) == stored_hash (bool). |
626 | 811 | """ |
627 | 812 | if stored_hash: |
628 | return bcrypt.hashpw(password, stored_hash.encode('utf-8')) == stored_hash | |
813 | return bcrypt.hashpw(password + self.hs.config.password_pepper, | |
814 | stored_hash.encode('utf-8')) == stored_hash | |
629 | 815 | else: |
630 | 816 | return False |
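A quick sketch of the peppered scheme above: the pepper is a server-side secret appended to every password before bcrypt, so leaked hashes alone cannot be cracked without the config file. (Values are placeholders; Python 2 str semantics, matching the codebase.)

    import bcrypt

    pepper = "configured-secret-pepper"   # placeholder for config.password_pepper
    bcrypt_rounds = 12                    # placeholder work factor

    stored = bcrypt.hashpw("s3cret" + pepper, bcrypt.gensalt(bcrypt_rounds))
    # validation re-hashes with the stored value as the salt:
    assert bcrypt.hashpw("s3cret" + pepper, stored) == stored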
0 | # -*- coding: utf-8 -*- | |
1 | # Copyright 2016 OpenMarket Ltd | |
2 | # | |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
4 | # you may not use this file except in compliance with the License. | |
5 | # You may obtain a copy of the License at | |
6 | # | |
7 | # http://www.apache.org/licenses/LICENSE-2.0 | |
8 | # | |
9 | # Unless required by applicable law or agreed to in writing, software | |
10 | # distributed under the License is distributed on an "AS IS" BASIS, | |
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 | # See the License for the specific language governing permissions and | |
13 | # limitations under the License. | |
14 | ||
15 | from synapse.api import errors | |
16 | from synapse.util import stringutils | |
17 | from twisted.internet import defer | |
18 | from ._base import BaseHandler | |
19 | ||
20 | import logging | |
21 | ||
22 | logger = logging.getLogger(__name__) | |
23 | ||
24 | ||
25 | class DeviceHandler(BaseHandler): | |
26 | def __init__(self, hs): | |
27 | super(DeviceHandler, self).__init__(hs) | |
28 | ||
29 | @defer.inlineCallbacks | |
30 | def check_device_registered(self, user_id, device_id, | |
31 | initial_device_display_name=None): | |
32 | """ | |
33 | If the given device has not been registered, register it with the | |
34 | supplied display name. | |
35 | ||
36 | If no device_id is supplied, we make one up. | |
37 | ||
38 | Args: | |
39 | user_id (str): @user:id | |
40 | device_id (str | None): device id supplied by client | |
41 | initial_device_display_name (str | None): device display name from | |
42 | client | |
43 | Returns: | |
44 | str: device id (generated if none was supplied) | |
45 | """ | |
46 | if device_id is not None: | |
47 | yield self.store.store_device( | |
48 | user_id=user_id, | |
49 | device_id=device_id, | |
50 | initial_device_display_name=initial_device_display_name, | |
51 | ignore_if_known=True, | |
52 | ) | |
53 | defer.returnValue(device_id) | |
54 | ||
55 | # if the device id is not specified, we'll autogen one, but loop a few | |
56 | # times in case of a clash. | |
57 | attempts = 0 | |
58 | while attempts < 5: | |
59 | try: | |
60 | device_id = stringutils.random_string_with_symbols(16) | |
61 | yield self.store.store_device( | |
62 | user_id=user_id, | |
63 | device_id=device_id, | |
64 | initial_device_display_name=initial_device_display_name, | |
65 | ignore_if_known=False, | |
66 | ) | |
67 | defer.returnValue(device_id) | |
68 | except errors.StoreError: | |
69 | attempts += 1 | |
70 | ||
71 | raise errors.StoreError(500, "Couldn't generate a device ID.") | |
72 | ||
73 | @defer.inlineCallbacks | |
74 | def get_devices_by_user(self, user_id): | |
75 | """ | |
76 | Retrieve the given user's devices | |
77 | ||
78 | Args: | |
79 | user_id (str): | |
80 | Returns: | |
81 | defer.Deferred: list[dict[str, X]]: info on each device | |
82 | """ | |
83 | ||
84 | device_map = yield self.store.get_devices_by_user(user_id) | |
85 | ||
86 | ips = yield self.store.get_last_client_ip_by_device( | |
87 | devices=((user_id, device_id) for device_id in device_map.keys()) | |
88 | ) | |
89 | ||
90 | devices = device_map.values() | |
91 | for device in devices: | |
92 | _update_device_from_client_ips(device, ips) | |
93 | ||
94 | defer.returnValue(devices) | |
95 | ||
96 | @defer.inlineCallbacks | |
97 | def get_device(self, user_id, device_id): | |
98 | """ Retrieve the given device | |
99 | ||
100 | Args: | |
101 | user_id (str): | |
102 | device_id (str): | |
103 | ||
104 | Returns: | |
105 | defer.Deferred: dict[str, X]: info on the device | |
106 | Raises: | |
107 | errors.NotFoundError: if the device was not found | |
108 | """ | |
109 | try: | |
110 | device = yield self.store.get_device(user_id, device_id) | |
111 | except errors.StoreError: | |
112 | raise errors.NotFoundError | |
113 | ips = yield self.store.get_last_client_ip_by_device( | |
114 | devices=((user_id, device_id),) | |
115 | ) | |
116 | _update_device_from_client_ips(device, ips) | |
117 | defer.returnValue(device) | |
118 | ||
119 | @defer.inlineCallbacks | |
120 | def delete_device(self, user_id, device_id): | |
121 | """ Delete the given device | |
122 | ||
123 | Args: | |
124 | user_id (str): | |
125 | device_id (str): | |
126 | ||
127 | Returns: | |
128 | defer.Deferred: | |
129 | """ | |
130 | ||
131 | try: | |
132 | yield self.store.delete_device(user_id, device_id) | |
133 | except errors.StoreError as e: | |
134 | if e.code == 404: | |
135 | # no match | |
136 | pass | |
137 | else: | |
138 | raise | |
139 | ||
140 | yield self.store.user_delete_access_tokens( | |
141 | user_id, device_id=device_id, | |
142 | delete_refresh_tokens=True, | |
143 | ) | |
144 | ||
145 | yield self.store.delete_e2e_keys_by_device( | |
146 | user_id=user_id, device_id=device_id | |
147 | ) | |
148 | ||
149 | @defer.inlineCallbacks | |
150 | def update_device(self, user_id, device_id, content): | |
151 | """ Update the given device | |
152 | ||
153 | Args: | |
154 | user_id (str): | |
155 | device_id (str): | |
156 | content (dict): body of update request | |
157 | ||
158 | Returns: | |
159 | defer.Deferred: | |
160 | """ | |
161 | ||
162 | try: | |
163 | yield self.store.update_device( | |
164 | user_id, | |
165 | device_id, | |
166 | new_display_name=content.get("display_name") | |
167 | ) | |
168 | except errors.StoreError as e: | |
169 | if e.code == 404: | |
170 | raise errors.NotFoundError() | |
171 | else: | |
172 | raise | |
173 | ||
174 | ||
175 | def _update_device_from_client_ips(device, client_ips): | |
176 | ip = client_ips.get((device["user_id"], device["device_id"]), {}) | |
177 | device.update({ | |
178 | "last_seen_ts": ip.get("last_seen"), | |
179 | "last_seen_ip": ip.get("ip"), | |
180 | }) |
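To make the surface of the new handler concrete, a hedged usage sketch (hs is assumed to be a fully configured HomeServer, and the generator must run under the reactor):

    from twisted.internet import defer

    @defer.inlineCallbacks
    def device_lifecycle(hs):
        handler = hs.get_device_handler()
        # registers the device; a 16-char id is generated when None is passed
        device_id = yield handler.check_device_registered(
            "@alice:example.com", None,
            initial_device_display_name="Alice's phone",
        )
        # each entry carries last_seen_ts/last_seen_ip merged from client IPs
        devices = yield handler.get_devices_by_user("@alice:example.com")
        yield handler.update_device(
            "@alice:example.com", device_id, {"display_name": "Alice's tablet"}
        )
        yield handler.delete_device("@alice:example.com", device_id)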
0 | # -*- coding: utf-8 -*- | |
1 | # Copyright 2016 OpenMarket Ltd | |
2 | # | |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
4 | # you may not use this file except in compliance with the License. | |
5 | # You may obtain a copy of the License at | |
6 | # | |
7 | # http://www.apache.org/licenses/LICENSE-2.0 | |
8 | # | |
9 | # Unless required by applicable law or agreed to in writing, software | |
10 | # distributed under the License is distributed on an "AS IS" BASIS, | |
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 | # See the License for the specific language governing permissions and | |
13 | # limitations under the License. | |
14 | ||
15 | import collections | |
16 | import json | |
17 | import logging | |
18 | ||
19 | from twisted.internet import defer | |
20 | ||
21 | from synapse.api import errors | |
22 | import synapse.types | |
23 | ||
24 | logger = logging.getLogger(__name__) | |
25 | ||
26 | ||
27 | class E2eKeysHandler(object): | |
28 | def __init__(self, hs): | |
29 | self.store = hs.get_datastore() | |
30 | self.federation = hs.get_replication_layer() | |
31 | self.is_mine_id = hs.is_mine_id | |
32 | self.server_name = hs.hostname | |
33 | ||
34 | # doesn't really work as part of the generic query API, because the | |
35 | # query request requires an object POST, but we abuse the | |
36 | # "query handler" interface. | |
37 | self.federation.register_query_handler( | |
38 | "client_keys", self.on_federation_query_client_keys | |
39 | ) | |
40 | ||
41 | @defer.inlineCallbacks | |
42 | def query_devices(self, query_body): | |
43 | """ Handle a device key query from a client | |
44 | ||
45 | { | |
46 | "device_keys": { | |
47 | "<user_id>": ["<device_id>"] | |
48 | } | |
49 | } | |
50 | -> | |
51 | { | |
52 | "device_keys": { | |
53 | "<user_id>": { | |
54 | "<device_id>": { | |
55 | ... | |
56 | } | |
57 | } | |
58 | } | |
59 | } | |
60 | """ | |
61 | device_keys_query = query_body.get("device_keys", {}) | |
62 | ||
63 | # separate users by domain. | |
64 | # make a map from domain to user_id to device_ids | |
65 | queries_by_domain = collections.defaultdict(dict) | |
66 | for user_id, device_ids in device_keys_query.items(): | |
67 | user = synapse.types.UserID.from_string(user_id) | |
68 | queries_by_domain[user.domain][user_id] = device_ids | |
69 | ||
70 | # do the queries | |
71 | # TODO: do these in parallel | |
72 | results = {} | |
73 | for destination, destination_query in queries_by_domain.items(): | |
74 | if destination == self.server_name: | |
75 | res = yield self.query_local_devices(destination_query) | |
76 | else: | |
77 | res = yield self.federation.query_client_keys( | |
78 | destination, {"device_keys": destination_query} | |
79 | ) | |
80 | res = res["device_keys"] | |
81 | for user_id, keys in res.items(): | |
82 | if user_id in destination_query: | |
83 | results[user_id] = keys | |
84 | ||
85 | defer.returnValue((200, {"device_keys": results})) | |
86 | ||
87 | @defer.inlineCallbacks | |
88 | def query_local_devices(self, query): | |
89 | """Get E2E device keys for local users | |
90 | ||
91 | Args: | |
92 | query (dict[string, list[string]|None): map from user_id to a list | |
93 | of devices to query (None for all devices) | |
94 | ||
95 | Returns: | |
96 | defer.Deferred: (resolves to dict[string, dict[string, dict]]): | |
97 | map from user_id -> device_id -> device details | |
98 | """ | |
99 | local_query = [] | |
100 | ||
101 | result_dict = {} | |
102 | for user_id, device_ids in query.items(): | |
103 | if not self.is_mine_id(user_id): | |
104 | logger.warning("Request for keys for non-local user %s", | |
105 | user_id) | |
106 | raise errors.SynapseError(400, "Not a user here") | |
107 | ||
108 | if not device_ids: | |
109 | local_query.append((user_id, None)) | |
110 | else: | |
111 | for device_id in device_ids: | |
112 | local_query.append((user_id, device_id)) | |
113 | ||
114 | # make sure that each queried user appears in the result dict | |
115 | result_dict[user_id] = {} | |
116 | ||
117 | results = yield self.store.get_e2e_device_keys(local_query) | |
118 | ||
119 | # Build the result structure, un-jsonify the results, and add the | |
120 | # "unsigned" section | |
121 | for user_id, device_keys in results.items(): | |
122 | for device_id, device_info in device_keys.items(): | |
123 | r = json.loads(device_info["key_json"]) | |
124 | r["unsigned"] = {} | |
125 | display_name = device_info["device_display_name"] | |
126 | if display_name is not None: | |
127 | r["unsigned"]["device_display_name"] = display_name | |
128 | result_dict[user_id][device_id] = r | |
129 | ||
130 | defer.returnValue(result_dict) | |
131 | ||
132 | @defer.inlineCallbacks | |
133 | def on_federation_query_client_keys(self, query_body): | |
134 | """ Handle a device key query from a federated server | |
135 | """ | |
136 | device_keys_query = query_body.get("device_keys", {}) | |
137 | res = yield self.query_local_devices(device_keys_query) | |
138 | defer.returnValue({"device_keys": res}) |
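The query body accepted by query_devices above maps user ids to (possibly empty) lists of device ids; the domain bucketing can be sketched without synapse.types like so:

    import collections

    # Hypothetical request: all of Alice's devices, one of Bob's.
    query_body = {
        "device_keys": {
            "@alice:example.com": [],
            "@bob:remote.example.org": ["BOBPHONE"],
        }
    }

    queries_by_domain = collections.defaultdict(dict)
    for user_id, device_ids in query_body["device_keys"].items():
        domain = user_id.split(":", 1)[1]   # crude stand-in for UserID.from_string
        queries_by_domain[domain][user_id] = device_ids
    # local domains go to query_local_devices, the rest over federation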
123 | 123 | |
124 | 124 | try: |
125 | 125 | event_stream_id, max_stream_id = yield self._persist_auth_tree( |
126 | auth_chain, state, event | |
126 | origin, auth_chain, state, event | |
127 | 127 | ) |
128 | 128 | except AuthError as e: |
129 | 129 | raise FederationError( |
334 | 334 | state_events.update({s.event_id: s for s in state}) |
335 | 335 | events_to_state[e_id] = state |
336 | 336 | |
337 | required_auth = set( | |
338 | a_id | |
339 | for event in events + state_events.values() + auth_events.values() | |
340 | for a_id, _ in event.auth_events | |
341 | ) | |
342 | auth_events.update({ | |
343 | e_id: event_map[e_id] for e_id in required_auth if e_id in event_map | |
344 | }) | |
345 | missing_auth = required_auth - set(auth_events) | |
346 | failed_to_fetch = set() | |
347 | ||
348 | # Try and fetch any missing auth events from both DB and remote servers. | |
349 | # We repeatedly do this until we stop finding new auth events. | |
350 | while missing_auth - failed_to_fetch: | |
351 | logger.info("Missing auth for backfill: %r", missing_auth) | |
352 | ret_events = yield self.store.get_events(missing_auth - failed_to_fetch) | |
353 | auth_events.update(ret_events) | |
354 | ||
355 | required_auth.update( | |
356 | a_id for event in ret_events.values() for a_id, _ in event.auth_events | |
357 | ) | |
358 | missing_auth = required_auth - set(auth_events) | |
359 | ||
360 | if missing_auth - failed_to_fetch: | |
361 | logger.info( | |
362 | "Fetching missing auth for backfill: %r", | |
363 | missing_auth - failed_to_fetch | |
364 | ) | |
365 | ||
366 | results = yield defer.gatherResults( | |
367 | [ | |
368 | self.replication_layer.get_pdu( | |
369 | [dest], | |
370 | event_id, | |
371 | outlier=True, | |
372 | timeout=10000, | |
373 | ) | |
374 | for event_id in missing_auth - failed_to_fetch | |
375 | ], | |
376 | consumeErrors=True | |
377 | ).addErrback(unwrapFirstError) | |
378 | auth_events.update({a.event_id: a for a in results}) | |
379 | required_auth.update( | |
380 | a_id for event in results for a_id, _ in event.auth_events | |
381 | ) | |
382 | missing_auth = required_auth - set(auth_events) | |
383 | ||
384 | failed_to_fetch = missing_auth - set(auth_events) | |
385 | ||
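The intent of one pass of the retry loop above: resolve missing auth events from the local store first, fall back to asking remote servers, and record anything unfetchable so it is not retried. A stripped-down, non-Twisted sketch (the fetchers are stand-ins for the store and the federation client):

    def gather_auth_events(required_auth, auth_events, get_local, get_remote):
        # try the database first
        missing = required_auth - set(auth_events)
        auth_events.update(get_local(missing))
        # then ask remote servers for whatever is still absent
        missing = required_auth - set(auth_events)
        if missing:
            auth_events.update(get_remote(missing))
        # whatever remains could not be fetched anywhere
        return required_auth - set(auth_events)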
337 | 386 | seen_events = yield self.store.have_events( |
338 | 387 | set(auth_events.keys()) | set(state_events.keys()) |
339 | 388 | ) |
340 | ||
341 | all_events = events + state_events.values() + auth_events.values() | |
342 | required_auth = set( | |
343 | a_id for event in all_events for a_id, _ in event.auth_events | |
344 | ) | |
345 | ||
346 | missing_auth = required_auth - set(auth_events) | |
347 | if missing_auth: | |
348 | logger.info("Missing auth for backfill: %r", missing_auth) | |
349 | results = yield defer.gatherResults( | |
350 | [ | |
351 | self.replication_layer.get_pdu( | |
352 | [dest], | |
353 | event_id, | |
354 | outlier=True, | |
355 | timeout=10000, | |
356 | ) | |
357 | for event_id in missing_auth | |
358 | ], | |
359 | consumeErrors=True | |
360 | ).addErrback(unwrapFirstError) | |
361 | auth_events.update({a.event_id: a for a in results}) | |
362 | 389 | |
363 | 390 | ev_infos = [] |
364 | 391 | for a in auth_events.values(): |
371 | 398 | (auth_events[a_id].type, auth_events[a_id].state_key): |
372 | 399 | auth_events[a_id] |
373 | 400 | for a_id, _ in a.auth_events |
401 | if a_id in auth_events | |
374 | 402 | } |
375 | 403 | }) |
376 | 404 | |
382 | 410 | (auth_events[a_id].type, auth_events[a_id].state_key): |
383 | 411 | auth_events[a_id] |
384 | 412 | for a_id, _ in event_map[e_id].auth_events |
413 | if a_id in auth_events | |
385 | 414 | } |
386 | 415 | }) |
387 | 416 | |
636 | 665 | pass |
637 | 666 | |
638 | 667 | event_stream_id, max_stream_id = yield self._persist_auth_tree( |
639 | auth_chain, state, event | |
668 | origin, auth_chain, state, event | |
640 | 669 | ) |
641 | 670 | |
642 | 671 | with PreserveLoggingContext(): |
687 | 716 | logger.warn("Failed to create join %r because %s", event, e) |
688 | 717 | raise e |
689 | 718 | |
690 | self.auth.check(event, auth_events=context.current_state) | |
719 | # The remote hasn't signed it yet, obviously. We'll do the full checks | |
720 | # when we get the event back in `on_send_join_request` | |
721 | self.auth.check(event, auth_events=context.current_state, do_sig_check=False) | |
691 | 722 | |
692 | 723 | defer.returnValue(event) |
693 | 724 | |
917 | 948 | ) |
918 | 949 | |
919 | 950 | try: |
920 | self.auth.check(event, auth_events=context.current_state) | |
951 | # The remote hasn't signed it yet, obviously. We'll do the full checks | |
952 | # when we get the event back in `on_send_leave_request` | |
953 | self.auth.check(event, auth_events=context.current_state, do_sig_check=False) | |
921 | 954 | except AuthError as e: |
922 | 955 | logger.warn("Failed to create new leave %r because %s", event, e) |
923 | 956 | raise e |
986 | 1019 | defer.returnValue(None) |
987 | 1020 | |
988 | 1021 | @defer.inlineCallbacks |
989 | def get_state_for_pdu(self, origin, room_id, event_id, do_auth=True): | |
1022 | def get_state_for_pdu(self, room_id, event_id): | |
990 | 1023 | yield run_on_reactor() |
991 | ||
992 | if do_auth: | |
993 | in_room = yield self.auth.check_host_in_room(room_id, origin) | |
994 | if not in_room: | |
995 | raise AuthError(403, "Host not in room.") | |
996 | 1024 | |
997 | 1025 | state_groups = yield self.store.get_state_groups( |
998 | 1026 | room_id, [event_id] |
1113 | 1141 | backfilled=backfilled, |
1114 | 1142 | ) |
1115 | 1143 | |
1116 | # this intentionally does not yield: we don't care about the result | |
1117 | # and don't need to wait for it. | |
1118 | preserve_fn(self.hs.get_pusherpool().on_new_notifications)( | |
1119 | event_stream_id, max_stream_id | |
1120 | ) | |
1144 | if not backfilled: | |
1145 | # this intentionally does not yield: we don't care about the result | |
1146 | # and don't need to wait for it. | |
1147 | preserve_fn(self.hs.get_pusherpool().on_new_notifications)( | |
1148 | event_stream_id, max_stream_id | |
1149 | ) | |
1121 | 1150 | |
1122 | 1151 | defer.returnValue((context, event_stream_id, max_stream_id)) |
1123 | 1152 | |
1149 | 1178 | ) |
1150 | 1179 | |
1151 | 1180 | @defer.inlineCallbacks |
1152 | def _persist_auth_tree(self, auth_events, state, event): | |
1181 | def _persist_auth_tree(self, origin, auth_events, state, event): | |
1153 | 1182 | """Checks the auth chain is valid (and passes auth checks) for the |
1154 | 1183 | state and event. Then persists the auth chain and state atomically. |
1155 | 1184 | Persists the event separately.
1185 | ||
1186 | Will attempt to fetch missing auth events. | |
1187 | ||
1188 | Args: | |
1189 | origin (str): Where the events came from | |
1190 | auth_events (list) | |
1191 | state (list) | |
1192 | event (Event) | |
1156 | 1193 | |
1157 | 1194 | Returns: |
1158 | 1195 | 2-tuple of (event_stream_id, max_stream_id) from the persist_event |
1166 | 1203 | |
1167 | 1204 | event_map = { |
1168 | 1205 | e.event_id: e |
1169 | for e in auth_events | |
1206 | for e in itertools.chain(auth_events, state, [event]) | |
1170 | 1207 | } |
1171 | 1208 | |
1172 | 1209 | create_event = None |
1175 | 1212 | create_event = e |
1176 | 1213 | break |
1177 | 1214 | |
1215 | missing_auth_events = set() | |
1216 | for e in itertools.chain(auth_events, state, [event]): | |
1217 | for e_id, _ in e.auth_events: | |
1218 | if e_id not in event_map: | |
1219 | missing_auth_events.add(e_id) | |
1220 | ||
1221 | for e_id in missing_auth_events: | |
1222 | m_ev = yield self.replication_layer.get_pdu( | |
1223 | [origin], | |
1224 | e_id, | |
1225 | outlier=True, | |
1226 | timeout=10000, | |
1227 | ) | |
1228 | if m_ev and m_ev.event_id == e_id: | |
1229 | event_map[e_id] = m_ev | |
1230 | else: | |
1231 | logger.info("Failed to find auth event %r", e_id) | |
1232 | ||
1178 | 1233 | for e in itertools.chain(auth_events, state, [event]): |
1179 | 1234 | auth_for_e = { |
1180 | 1235 | (event_map[e_id].type, event_map[e_id].state_key): event_map[e_id] |
1181 | 1236 | for e_id, _ in e.auth_events |
1237 | if e_id in event_map | |
1182 | 1238 | } |
1183 | 1239 | if create_event: |
1184 | 1240 | auth_for_e[(EventTypes.Create, "")] = create_event |
1412 | 1468 | local_view = dict(auth_events) |
1413 | 1469 | remote_view = dict(auth_events) |
1414 | 1470 | remote_view.update({ |
1415 | (d.type, d.state_key): d for d in different_events | |
1471 | (d.type, d.state_key): d for d in different_events if d | |
1416 | 1472 | }) |
1417 | 1473 | |
1418 | 1474 | new_state, prev_state = self.state_handler.resolve_events( |
20 | 20 | ) |
21 | 21 | from ._base import BaseHandler |
22 | 22 | from synapse.util.async import run_on_reactor |
23 | from synapse.api.errors import SynapseError | |
23 | from synapse.api.errors import SynapseError, Codes | |
24 | 24 | |
25 | 25 | import json |
26 | 26 | import logging |
40 | 40 | hs.config.use_insecure_ssl_client_just_for_testing_do_not_use |
41 | 41 | ) |
42 | 42 | |
43 | def _should_trust_id_server(self, id_server): | |
44 | if id_server not in self.trusted_id_servers: | |
45 | if self.trust_any_id_server_just_for_testing_do_not_use: | |
46 | logger.warn( | |
47 | "Trusting untrustworthy ID server %r even though it isn't" | |
48 | " in the trusted id list for testing because" | |
49 | " 'use_insecure_ssl_client_just_for_testing_do_not_use'" | |
50 | " is set in the config", | |
51 | id_server, | |
52 | ) | |
53 | else: | |
54 | return False | |
55 | return True | |
56 | ||
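A standalone rendering of the extracted helper above, with hypothetical defaults (the real list comes from the homeserver config):

    trusted_id_servers = {"matrix.org", "vector.im"}   # hypothetical config

    def should_trust_id_server(id_server, trust_any_for_testing=False):
        if id_server in trusted_id_servers:
            return True
        # the real handler logs a loud warning before trusting anyway
        return trust_any_for_testing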
43 | 57 | @defer.inlineCallbacks |
44 | 58 | def threepid_from_creds(self, creds): |
45 | 59 | yield run_on_reactor() |
58 | 72 | else: |
59 | 73 | raise SynapseError(400, "No client_secret in creds") |
60 | 74 | |
61 | if id_server not in self.trusted_id_servers: | |
62 | if self.trust_any_id_server_just_for_testing_do_not_use: | |
63 | logger.warn( | |
64 | "Trusting untrustworthy ID server %r even though it isn't" | |
65 | " in the trusted id list for testing because" | |
66 | " 'use_insecure_ssl_client_just_for_testing_do_not_use'" | |
67 | " is set in the config", | |
68 | id_server, | |
69 | ) | |
70 | else: | |
71 | logger.warn('%s is not a trusted ID server: rejecting 3pid ' + | |
72 | 'credentials', id_server) | |
73 | defer.returnValue(None) | |
75 | if not self._should_trust_id_server(id_server): | |
76 | logger.warn( | |
77 | '%s is not a trusted ID server: rejecting 3pid ' + | |
78 | 'credentials', id_server | |
79 | ) | |
80 | defer.returnValue(None) | |
74 | 81 | |
75 | 82 | data = {} |
76 | 83 | try: |
128 | 135 | def requestEmailToken(self, id_server, email, client_secret, send_attempt, **kwargs): |
129 | 136 | yield run_on_reactor() |
130 | 137 | |
138 | if not self._should_trust_id_server(id_server): | |
139 | raise SynapseError( | |
140 | 400, "Untrusted ID server '%s'" % id_server, | |
141 | Codes.SERVER_NOT_TRUSTED | |
142 | ) | |
143 | ||
131 | 144 | params = { |
132 | 145 | 'email': email, |
133 | 146 | 'client_secret': client_secret, |
25 | 25 | UserID, RoomAlias, RoomStreamToken, StreamToken, get_domain_from_id |
26 | 26 | ) |
27 | 27 | from synapse.util import unwrapFirstError |
28 | from synapse.util.async import concurrently_execute, run_on_reactor | |
28 | from synapse.util.async import concurrently_execute, run_on_reactor, ReadWriteLock | |
29 | 29 | from synapse.util.caches.snapshot_cache import SnapshotCache |
30 | 30 | from synapse.util.logcontext import preserve_fn |
31 | 31 | from synapse.visibility import filter_events_for_client |
49 | 49 | self.validator = EventValidator() |
50 | 50 | self.snapshot_cache = SnapshotCache() |
51 | 51 | |
52 | self.pagination_lock = ReadWriteLock() | |
53 | ||
54 | @defer.inlineCallbacks | |
55 | def purge_history(self, room_id, event_id): | |
56 | event = yield self.store.get_event(event_id) | |
57 | ||
58 | if event.room_id != room_id: | |
59 | raise SynapseError(400, "Event is for wrong room.") | |
60 | ||
61 | depth = event.depth | |
62 | ||
63 | with (yield self.pagination_lock.write(room_id)): | |
64 | yield self.store.delete_old_state(room_id, depth) | |
65 | ||
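The point of the ReadWriteLock above: purge_history takes the per-room write lock while /messages pagination (below) takes the read lock, so a purge waits for in-flight paginations to drain and blocks new ones until it finishes. The calling convention, as schematic method shapes:

    from twisted.internet import defer

    @defer.inlineCallbacks
    def _read_side(self, room_id):
        # many readers may hold the lock concurrently
        with (yield self.pagination_lock.read(room_id)):
            pass  # paginate here

    @defer.inlineCallbacks
    def _write_side(self, room_id, depth):
        # exclusive: waits for readers to drain, blocks new ones
        with (yield self.pagination_lock.write(room_id)):
            yield self.store.delete_old_state(room_id, depth)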
52 | 66 | @defer.inlineCallbacks |
53 | 67 | def get_messages(self, requester, room_id=None, pagin_config=None, |
54 | as_client_event=True): | |
68 | as_client_event=True, event_filter=None): | |
55 | 69 | """Get messages in a room. |
56 | 70 | |
57 | 71 | Args: |
60 | 74 | pagin_config (synapse.api.streams.PaginationConfig): The pagination |
61 | 75 | config rules to apply, if any. |
62 | 76 | as_client_event (bool): True to get events in client-server format. |
77 | event_filter (Filter): Filter to apply to results or None | |
63 | 78 | Returns: |
64 | 79 | dict: Pagination API results |
65 | 80 | """ |
66 | 81 | user_id = requester.user.to_string() |
67 | data_source = self.hs.get_event_sources().sources["room"] | |
68 | 82 | |
69 | 83 | if pagin_config.from_token: |
70 | 84 | room_token = pagin_config.from_token.room_key |
84 | 98 | |
85 | 99 | source_config = pagin_config.get_source_config("room") |
86 | 100 | |
87 | membership, member_event_id = yield self._check_in_room_or_world_readable( | |
88 | room_id, user_id | |
89 | ) | |
90 | ||
91 | if source_config.direction == 'b': | |
92 | # if we're going backwards, we might need to backfill. This | |
93 | # requires that we have a topo token. | |
94 | if room_token.topological: | |
95 | max_topo = room_token.topological | |
96 | else: | |
97 | max_topo = yield self.store.get_max_topological_token_for_stream_and_room( | |
98 | room_id, room_token.stream | |
101 | with (yield self.pagination_lock.read(room_id)): | |
102 | membership, member_event_id = yield self._check_in_room_or_world_readable( | |
103 | room_id, user_id | |
104 | ) | |
105 | ||
106 | if source_config.direction == 'b': | |
107 | # if we're going backwards, we might need to backfill. This | |
108 | # requires that we have a topo token. | |
109 | if room_token.topological: | |
110 | max_topo = room_token.topological | |
111 | else: | |
112 | max_topo = yield self.store.get_max_topological_token( | |
113 | room_id, room_token.stream | |
114 | ) | |
115 | ||
116 | if membership == Membership.LEAVE: | |
117 | # If they have left the room then clamp the token to be before | |
118 | # they left the room, to save the effort of loading from the | |
119 | # database. | |
120 | leave_token = yield self.store.get_topological_token_for_event( | |
121 | member_event_id | |
122 | ) | |
123 | leave_token = RoomStreamToken.parse(leave_token) | |
124 | if leave_token.topological < max_topo: | |
125 | source_config.from_key = str(leave_token) | |
126 | ||
127 | yield self.hs.get_handlers().federation_handler.maybe_backfill( | |
128 | room_id, max_topo | |
99 | 129 | ) |
100 | 130 | |
101 | if membership == Membership.LEAVE: | |
102 | # If they have left the room then clamp the token to be before | |
103 | # they left the room, to save the effort of loading from the | |
104 | # database. | |
105 | leave_token = yield self.store.get_topological_token_for_event( | |
106 | member_event_id | |
107 | ) | |
108 | leave_token = RoomStreamToken.parse(leave_token) | |
109 | if leave_token.topological < max_topo: | |
110 | source_config.from_key = str(leave_token) | |
111 | ||
112 | yield self.hs.get_handlers().federation_handler.maybe_backfill( | |
113 | room_id, max_topo | |
114 | ) | |
115 | ||
116 | events, next_key = yield data_source.get_pagination_rows( | |
117 | requester.user, source_config, room_id | |
118 | ) | |
119 | ||
120 | next_token = pagin_config.from_token.copy_and_replace( | |
121 | "room_key", next_key | |
122 | ) | |
131 | events, next_key = yield self.store.paginate_room_events( | |
132 | room_id=room_id, | |
133 | from_key=source_config.from_key, | |
134 | to_key=source_config.to_key, | |
135 | direction=source_config.direction, | |
136 | limit=source_config.limit, | |
137 | event_filter=event_filter, | |
138 | ) | |
139 | ||
140 | next_token = pagin_config.from_token.copy_and_replace( | |
141 | "room_key", next_key | |
142 | ) | |
123 | 143 | |
124 | 144 | if not events: |
125 | 145 | defer.returnValue({ |
127 | 147 | "start": pagin_config.from_token.to_string(), |
128 | 148 | "end": next_token.to_string(), |
129 | 149 | }) |
150 | ||
151 | if event_filter: | |
152 | events = event_filter.filter(events) | |
130 | 153 | |
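For illustration, a filter the new event_filter parameter could carry, limiting results to message events that contain a URL (construction via synapse.api.filtering.Filter is an assumption of this sketch):

    from synapse.api.filtering import Filter   # assumed construction site

    event_filter = Filter({
        "types": ["m.room.message"],
        "contains_url": True,
    })
    # applied post-pagination, exactly as above:
    events = event_filter.filter(events)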
131 | 154 | events = yield filter_events_for_client( |
132 | 155 | self.store, |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | |
15 | import logging | |
16 | ||
15 | 17 | from twisted.internet import defer |
16 | 18 | |
19 | import synapse.types | |
17 | 20 | from synapse.api.errors import SynapseError, AuthError, CodeMessageException |
18 | from synapse.types import UserID, Requester | |
19 | ||
21 | from synapse.types import UserID | |
20 | 22 | from ._base import BaseHandler |
21 | ||
22 | import logging | |
23 | 23 | |
24 | 24 | |
25 | 25 | logger = logging.getLogger(__name__) |
164 | 164 | try: |
165 | 165 | # Assume the user isn't a guest because we don't let guests set |
166 | 166 | # profile or avatar data. |
167 | requester = Requester(user, "", False) | |
167 | # XXX why are we recreating `requester` here for each room? | |
168 | # what was wrong with the `requester` we were passed? | |
169 | requester = synapse.types.create_requester(user) | |
168 | 170 | yield handler.update_membership( |
169 | 171 | requester, |
170 | 172 | user, |
13 | 13 | # limitations under the License. |
14 | 14 | |
15 | 15 | """Contains functions for registering clients.""" |
16 | import logging | |
17 | import urllib | |
18 | ||
16 | 19 | from twisted.internet import defer |
17 | 20 | |
18 | from synapse.types import UserID, Requester | |
21 | import synapse.types | |
19 | 22 | from synapse.api.errors import ( |
20 | 23 | AuthError, Codes, SynapseError, RegistrationError, InvalidCaptchaError |
21 | 24 | ) |
25 | from synapse.http.client import CaptchaServerHttpClient | |
26 | from synapse.types import UserID | |
27 | from synapse.util.async import run_on_reactor | |
22 | 28 | from ._base import BaseHandler |
23 | from synapse.util.async import run_on_reactor | |
24 | from synapse.http.client import CaptchaServerHttpClient | |
25 | ||
26 | import logging | |
27 | import urllib | |
28 | 29 | |
29 | 30 | logger = logging.getLogger(__name__) |
30 | 31 | |
48 | 49 | raise SynapseError( |
49 | 50 | 400, |
50 | 51 | "User ID can only contain characters a-z, 0-9, or '_-./'", |
52 | Codes.INVALID_USERNAME | |
53 | ) | |
54 | ||
55 | if localpart[0] == '_': | |
56 | raise SynapseError( | |
57 | 400, | |
58 | "User ID may not begin with _", | |
51 | 59 | Codes.INVALID_USERNAME |
52 | 60 | ) |
53 | 61 | |
89 | 97 | password=None, |
90 | 98 | generate_token=True, |
91 | 99 | guest_access_token=None, |
92 | make_guest=False | |
100 | make_guest=False, | |
101 | admin=False, | |
93 | 102 | ): |
94 | 103 | """Registers a new client on the server. |
95 | 104 | |
97 | 106 | localpart : The local part of the user ID to register. If None, |
98 | 107 | one will be generated. |
99 | 108 | password (str) : The password to assign to this user so they can |
100 | login again. This can be None which means they cannot login again | |
101 | via a password (e.g. the user is an application service user). | |
109 | login again. This can be None which means they cannot login again | |
110 | via a password (e.g. the user is an application service user). | |
111 | generate_token (bool): Whether a new access token should be | |
112 | generated. Having this be True should be considered deprecated, | |
113 | since it offers no means of associating a device_id with the | |
114 | access_token. Instead you should call auth_handler.issue_access_token | |
115 | after registration. | |
102 | 116 | Returns: |
103 | 117 | A tuple of (user_id, access_token). |
104 | 118 | Raises: |
140 | 154 | # If the user was a guest then they already have a profile |
141 | 155 | None if was_guest else user.localpart |
142 | 156 | ), |
157 | admin=admin, | |
143 | 158 | ) |
144 | 159 | else: |
145 | 160 | # autogen a sequential user ID |
193 | 208 | user_id, allowed_appservice=service |
194 | 209 | ) |
195 | 210 | |
196 | token = self.auth_handler().generate_access_token(user_id) | |
197 | 211 | yield self.store.register( |
198 | 212 | user_id=user_id, |
199 | token=token, | |
200 | 213 | password_hash="", |
201 | 214 | appservice_id=service_id, |
202 | 215 | create_profile_with_localpart=user.localpart, |
203 | 216 | ) |
204 | defer.returnValue((user_id, token)) | |
217 | defer.returnValue(user_id) | |
205 | 218 | |
206 | 219 | @defer.inlineCallbacks |
207 | 220 | def check_recaptcha(self, ip, private_key, challenge, response): |
357 | 370 | defer.returnValue(data) |
358 | 371 | |
359 | 372 | @defer.inlineCallbacks |
360 | def get_or_create_user(self, localpart, displayname, duration_seconds): | |
373 | def get_or_create_user(self, localpart, displayname, duration_in_ms, | |
374 | password_hash=None): | |
361 | 375 | """Creates a new user if the user does not exist, |
362 | 376 | else revokes all previous access tokens and generates a new one. |
363 | 377 | |
386 | 400 | |
387 | 401 | user = UserID(localpart, self.hs.hostname) |
388 | 402 | user_id = user.to_string() |
389 | token = self.auth_handler().generate_short_term_login_token( | |
390 | user_id, duration_seconds) | |
403 | token = self.auth_handler().generate_access_token( | |
404 | user_id, None, duration_in_ms) | |
391 | 405 | |
392 | 406 | if need_register: |
393 | 407 | yield self.store.register( |
394 | 408 | user_id=user_id, |
395 | 409 | token=token, |
396 | password_hash=None, | |
410 | password_hash=password_hash, | |
397 | 411 | create_profile_with_localpart=user.localpart, |
398 | 412 | ) |
399 | 413 | else: |
403 | 417 | if displayname is not None: |
404 | 418 | logger.info("setting user display name: %s -> %s", user_id, displayname) |
405 | 419 | profile_handler = self.hs.get_handlers().profile_handler |
420 | requester = synapse.types.create_requester(user) | |
406 | 421 | yield profile_handler.set_displayname( |
407 | user, Requester(user, token, False), displayname | |
422 | user, requester, displayname | |
408 | 423 | ) |
409 | 424 | |
410 | 425 | defer.returnValue((user_id, token)) |
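Given the deprecation note on generate_token above, the assumed replacement flow is to register without a token and then mint one bound to a device (the auth handler accessor here is an assumption):

    from twisted.internet import defer

    @defer.inlineCallbacks
    def register_with_device(hs, localpart, password, device_id):
        reg = hs.get_handlers().registration_handler
        auth = hs.get_handlers().auth_handler      # assumed accessor
        user_id, _ = yield reg.register(
            localpart=localpart, password=password, generate_token=False,
        )
        token = yield auth.issue_access_token(user_id, device_id=device_id)
        defer.returnValue((user_id, token))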
344 | 344 | class RoomListHandler(BaseHandler): |
345 | 345 | def __init__(self, hs): |
346 | 346 | super(RoomListHandler, self).__init__(hs) |
347 | self.response_cache = ResponseCache() | |
348 | self.remote_list_request_cache = ResponseCache() | |
347 | self.response_cache = ResponseCache(hs) | |
348 | self.remote_list_request_cache = ResponseCache(hs) | |
349 | 349 | self.remote_list_cache = {} |
350 | 350 | self.fetch_looping_call = hs.get_clock().looping_call( |
351 | 351 | self.fetch_all_remote_lists, REMOTE_ROOM_LIST_POLL_INTERVAL |
13 | 13 | # limitations under the License. |
14 | 14 | |
15 | 15 | |
16 | import logging | |
17 | ||
18 | from signedjson.key import decode_verify_key_bytes | |
19 | from signedjson.sign import verify_signed_json | |
16 | 20 | from twisted.internet import defer |
17 | ||
18 | from ._base import BaseHandler | |
19 | ||
20 | from synapse.types import UserID, RoomID, Requester | |
21 | from unpaddedbase64 import decode_base64 | |
22 | ||
23 | import synapse.types | |
21 | 24 | from synapse.api.constants import ( |
22 | 25 | EventTypes, Membership, |
23 | 26 | ) |
24 | 27 | from synapse.api.errors import AuthError, SynapseError, Codes |
28 | from synapse.types import UserID, RoomID | |
25 | 29 | from synapse.util.async import Linearizer |
26 | 30 | from synapse.util.distributor import user_left_room, user_joined_room |
27 | ||
28 | from signedjson.sign import verify_signed_json | |
29 | from signedjson.key import decode_verify_key_bytes | |
30 | ||
31 | from unpaddedbase64 import decode_base64 | |
32 | ||
33 | import logging | |
31 | from ._base import BaseHandler | |
34 | 32 | |
35 | 33 | logger = logging.getLogger(__name__) |
36 | 34 | |
314 | 312 | ) |
315 | 313 | assert self.hs.is_mine(sender), "Sender must be our own: %s" % (sender,) |
316 | 314 | else: |
317 | requester = Requester(target_user, None, False) | |
315 | requester = synapse.types.create_requester(target_user) | |
318 | 316 | |
319 | 317 | message_handler = self.hs.get_handlers().message_handler |
320 | 318 | prev_event = message_handler.deduplicate_state_event(event, context) |
137 | 137 | self.presence_handler = hs.get_presence_handler() |
138 | 138 | self.event_sources = hs.get_event_sources() |
139 | 139 | self.clock = hs.get_clock() |
140 | self.response_cache = ResponseCache() | |
140 | self.response_cache = ResponseCache(hs) | |
141 | 141 | |
142 | 142 | def wait_for_sync_for_user(self, sync_config, since_token=None, timeout=0, |
143 | 143 | full_state=False): |
204 | 204 | |
205 | 205 | def register_paths(self, method, path_patterns, callback): |
206 | 206 | for path_pattern in path_patterns: |
207 | logger.debug("Registering for %s %s", method, path_pattern.pattern) | |
207 | 208 | self.path_regexs.setdefault(method, []).append( |
208 | 209 | self._PathEntry(path_pattern, callback) |
209 | 210 | ) |
26 | 26 | from twisted.internet import reactor |
27 | 27 | |
28 | 28 | from .metric import ( |
29 | CounterMetric, CallbackMetric, DistributionMetric, CacheMetric | |
29 | CounterMetric, CallbackMetric, DistributionMetric, CacheMetric, | |
30 | MemoryUsageMetric, | |
30 | 31 | ) |
31 | 32 | |
32 | 33 | |
63 | 64 | |
64 | 65 | def register_cache(self, *args, **kwargs): |
65 | 66 | return self._register(CacheMetric, *args, **kwargs) |
67 | ||
68 | ||
69 | def register_memory_metrics(hs): | |
70 | try: | |
71 | import psutil | |
72 | process = psutil.Process() | |
73 | process.memory_info().rss  # probe: raises if psutil is too old | |
74 | except (ImportError, AttributeError): | |
75 | logger.warn( | |
76 | "psutil is not installed or incorrect version." | |
77 | " Disabling memory metrics." | |
78 | ) | |
79 | return | |
80 | metric = MemoryUsageMetric(hs, psutil) | |
81 | all_metrics.append(metric) | |
66 | 82 | |
67 | 83 | |
68 | 84 | def get_metrics_for(pkg_name): |
152 | 152 | """%s:total{name="%s"} %d""" % (self.name, self.cache_name, total), |
153 | 153 | """%s:size{name="%s"} %d""" % (self.name, self.cache_name, size), |
154 | 154 | ] |
155 | ||
156 | ||
157 | class MemoryUsageMetric(object): | |
158 | """Keeps track of the current memory usage, using psutil. | |
159 | ||
160 | The class will keep the current min/max/sum/counts of rss over the last | |
161 | WINDOW_SIZE_SEC, by polling UPDATE_HZ times per second | |
162 | """ | |
163 | ||
164 | UPDATE_HZ = 2 # number of times to get memory per second | |
165 | WINDOW_SIZE_SEC = 30 # the size of the window in seconds | |
166 | ||
167 | def __init__(self, hs, psutil): | |
168 | clock = hs.get_clock() | |
169 | self.memory_snapshots = [] | |
170 | ||
171 | self.process = psutil.Process() | |
172 | ||
173 | clock.looping_call(self._update_curr_values, 1000 / self.UPDATE_HZ) | |
174 | ||
175 | def _update_curr_values(self): | |
176 | max_size = self.UPDATE_HZ * self.WINDOW_SIZE_SEC | |
177 | self.memory_snapshots.append(self.process.memory_info().rss) | |
178 | self.memory_snapshots[:] = self.memory_snapshots[-max_size:] | |
179 | ||
180 | def render(self): | |
181 | if not self.memory_snapshots: | |
182 | return [] | |
183 | ||
184 | max_rss = max(self.memory_snapshots) | |
185 | min_rss = min(self.memory_snapshots) | |
186 | sum_rss = sum(self.memory_snapshots) | |
187 | len_rss = len(self.memory_snapshots) | |
188 | ||
189 | return [ | |
190 | "process_psutil_rss:max %d" % max_rss, | |
191 | "process_psutil_rss:min %d" % min_rss, | |
192 | "process_psutil_rss:total %d" % sum_rss, | |
193 | "process_psutil_rss:count %d" % len_rss, | |
194 | ] |
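A hedged smoke test of the renderer above; the clock is stubbed so only the snapshot window is exercised (psutil assumed installed, matching the guard in register_memory_metrics):

    import psutil

    class _FakeClock(object):
        def looping_call(self, f, interval_ms):
            pass   # the real clock invokes f UPDATE_HZ times per second

    class _FakeHS(object):
        def get_clock(self):
            return _FakeClock()

    m = MemoryUsageMetric(_FakeHS(), psutil)
    m._update_curr_values()          # take one rss snapshot by hand
    print(m.render())                # four "process_psutil_rss:*" lines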
13 | 13 | # limitations under the License. |
14 | 14 | |
15 | 15 | from twisted.internet import defer, reactor |
16 | from twisted.internet.error import AlreadyCalled, AlreadyCancelled | |
16 | 17 | |
17 | 18 | import logging |
18 | 19 | |
91 | 92 | |
92 | 93 | def on_stop(self): |
93 | 94 | if self.timed_call: |
94 | self.timed_call.cancel() | |
95 | try: | |
96 | self.timed_call.cancel() | |
97 | except (AlreadyCalled, AlreadyCancelled): | |
98 | pass | |
99 | self.timed_call = None | |
95 | 100 | |
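The try/except added above guards a race where the reactor fires the delayed call between the `if` check and `.cancel()`. The shape of the problem and the guard, in isolation:

    from twisted.internet import reactor
    from twisted.internet.error import AlreadyCalled, AlreadyCancelled

    timed_call = reactor.callLater(0, lambda: None)
    # ...time passes; the call may have fired already...
    try:
        timed_call.cancel()
    except (AlreadyCalled, AlreadyCancelled):
        pass   # it fired or was cancelled elsewhere; nothing to undo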
96 | 101 | @defer.inlineCallbacks |
97 | 102 | def on_new_notifications(self, min_stream_ordering, max_stream_ordering): |
139 | 144 | being run. |
140 | 145 | """ |
141 | 146 | start = 0 if INCLUDE_ALL_UNREAD_NOTIFS else self.last_stream_ordering |
142 | unprocessed = yield self.store.get_unread_push_actions_for_user_in_range( | |
143 | self.user_id, start, self.max_stream_ordering | |
144 | ) | |
147 | fn = self.store.get_unread_push_actions_for_user_in_range_for_email | |
148 | unprocessed = yield fn(self.user_id, start, self.max_stream_ordering) | |
145 | 149 | |
146 | 150 | soonest_due_at = None |
147 | 151 | |
189 | 193 | soonest_due_at = should_notify_at |
190 | 194 | |
191 | 195 | if self.timed_call is not None: |
192 | self.timed_call.cancel() | |
196 | try: | |
197 | self.timed_call.cancel() | |
198 | except (AlreadyCalled, AlreadyCancelled): | |
199 | pass | |
193 | 200 | self.timed_call = None |
194 | 201 | |
195 | 202 | if soonest_due_at is not None: |
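The cancel guard added here (and mirrored in the HTTP pusher below) is worth a
note: cancelling a Twisted DelayedCall that has already fired or been cancelled
raises, so the timer has to be torn down defensively. A minimal sketch of the
pattern:

    from twisted.internet import reactor
    from twisted.internet.error import AlreadyCalled, AlreadyCancelled

    def stop_timer(timed_call):
        # the DelayedCall may have fired (AlreadyCalled) or been cancelled
        # elsewhere (AlreadyCancelled) by now; either way it is dead, so
        # swallow the error and report the timer as stopped
        if timed_call:
            try:
                timed_call.cancel()
            except (AlreadyCalled, AlreadyCancelled):
                pass
        return None  # callers reset self.timed_call to this

    timed_call = reactor.callLater(30, lambda: None)
    timed_call = stop_timer(timed_call)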
15 | 15 | from synapse.push import PusherConfigException |
16 | 16 | |
17 | 17 | from twisted.internet import defer, reactor |
18 | from twisted.internet.error import AlreadyCalled, AlreadyCancelled | |
18 | 19 | |
19 | 20 | import logging |
20 | 21 | import push_rule_evaluator |
37 | 38 | self.hs = hs |
38 | 39 | self.store = self.hs.get_datastore() |
39 | 40 | self.clock = self.hs.get_clock() |
41 | self.state_handler = self.hs.get_state_handler() | |
40 | 42 | self.user_id = pusherdict['user_name'] |
41 | 43 | self.app_id = pusherdict['app_id'] |
42 | 44 | self.app_display_name = pusherdict['app_display_name'] |
107 | 109 | |
108 | 110 | def on_stop(self): |
109 | 111 | if self.timed_call: |
110 | self.timed_call.cancel() | |
112 | try: | |
113 | self.timed_call.cancel() | |
114 | except (AlreadyCalled, AlreadyCancelled): | |
115 | pass | |
116 | self.timed_call = None | |
111 | 117 | |
112 | 118 | @defer.inlineCallbacks |
113 | 119 | def _process(self): |
139 | 145 | run once per pusher. |
140 | 146 | """ |
141 | 147 | |
142 | unprocessed = yield self.store.get_unread_push_actions_for_user_in_range( | |
148 | fn = self.store.get_unread_push_actions_for_user_in_range_for_http | |
149 | unprocessed = yield fn( | |
143 | 150 | self.user_id, self.last_stream_ordering, self.max_stream_ordering |
144 | 151 | ) |
145 | 152 | |
236 | 243 | |
237 | 244 | @defer.inlineCallbacks |
238 | 245 | def _build_notification_dict(self, event, tweaks, badge): |
239 | ctx = yield push_tools.get_context_for_event(self.hs.get_datastore(), event) | |
246 | ctx = yield push_tools.get_context_for_event( | |
247 | self.state_handler, event, self.user_id | |
248 | ) | |
240 | 249 | |
241 | 250 | d = { |
242 | 251 | 'notification': { |
268 | 277 | if 'content' in event: |
269 | 278 | d['notification']['content'] = event.content |
270 | 279 | |
271 | if len(ctx['aliases']): | |
272 | d['notification']['room_alias'] = ctx['aliases'][0] | |
280 | # We no longer send aliases separately; instead, we send the human- | 
281 | # readable name of the room, which may be an alias. | 
273 | 282 | if 'sender_display_name' in ctx and len(ctx['sender_display_name']) > 0: |
274 | 283 | d['notification']['sender_display_name'] = ctx['sender_display_name'] |
275 | 284 | if 'name' in ctx and len(ctx['name']) > 0: |
13 | 13 | # limitations under the License. |
14 | 14 | |
15 | 15 | from twisted.internet import defer |
16 | from synapse.util.presentable_names import ( | |
17 | calculate_room_name, name_from_member_event | |
18 | ) | |
16 | 19 | |
17 | 20 | |
18 | 21 | @defer.inlineCallbacks |
44 | 47 | |
45 | 48 | |
46 | 49 | @defer.inlineCallbacks |
47 | def get_context_for_event(store, ev): | |
48 | name_aliases = yield store.get_room_name_and_aliases( | |
49 | ev.room_id | |
50 | def get_context_for_event(state_handler, ev, user_id): | |
51 | ctx = {} | |
52 | ||
53 | room_state = yield state_handler.get_current_state(ev.room_id) | |
54 | ||
55 | # we no longer bother setting room_alias, and make room_name the | |
56 | # human-readable name instead, be that m.room.name, an alias or | |
57 | # a list of people in the room | |
58 | name = calculate_room_name( | |
59 | room_state, user_id, fallback_to_single_member=False | |
50 | 60 | ) |
61 | if name: | |
62 | ctx['name'] = name | |
51 | 63 | |
52 | ctx = {'aliases': name_aliases[1]} | |
53 | if name_aliases[0] is not None: | |
54 | ctx['name'] = name_aliases[0] | |
55 | ||
56 | their_member_events_for_room = yield store.get_current_state( | |
57 | room_id=ev.room_id, | |
58 | event_type='m.room.member', | |
59 | state_key=ev.user_id | |
60 | ) | |
61 | for mev in their_member_events_for_room: | |
62 | if mev.content['membership'] == 'join' and 'displayname' in mev.content: | |
63 | dn = mev.content['displayname'] | |
64 | if dn is not None: | |
65 | ctx['sender_display_name'] = dn | |
64 | sender_state_event = room_state[("m.room.member", ev.sender)] | |
65 | ctx['sender_display_name'] = name_from_member_event(sender_state_event) | |
66 | 66 | |
67 | 67 | defer.returnValue(ctx) |
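Note that the rewritten helper reads the sender's membership event straight out
of the state map, which is keyed by (event_type, state_key) tuples. Roughly,
with a hypothetical state map (real values are event objects, not dicts):

    room_state = {
        ("m.room.name", ""): {"content": {"name": "The Room"}},
        ("m.room.member", "@alice:example.com"): {
            "content": {"membership": "join", "displayname": "Alice"},
        },
    }
    sender_state_event = room_state[("m.room.member", "@alice:example.com")]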
46 | 46 | "email.enable_notifs": { |
47 | 47 | "Jinja2>=2.8": ["Jinja2>=2.8"], |
48 | 48 | "bleach>=1.4.2": ["bleach>=1.4.2"], |
49 | }, | |
50 | "ldap": { | |
51 | "ldap3>=1.0": ["ldap3>=1.0"], | |
52 | }, | |
53 | "psutil": { | |
54 | "psutil>=2.0.0": ["psutil>=2.0.0"], | |
49 | 55 | }, |
50 | 56 | } |
51 | 57 |
0 | # -*- coding: utf-8 -*- | |
1 | # Copyright 2015, 2016 OpenMarket Ltd | |
2 | # | |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
4 | # you may not use this file except in compliance with the License. | |
5 | # You may obtain a copy of the License at | |
6 | # | |
7 | # http://www.apache.org/licenses/LICENSE-2.0 | |
8 | # | |
9 | # Unless required by applicable law or agreed to in writing, software | |
10 | # distributed under the License is distributed on an "AS IS" BASIS, | |
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 | # See the License for the specific language governing permissions and | |
13 | # limitations under the License. | |
14 | ||
15 | from ._base import BaseSlavedStore | |
16 | from synapse.storage.directory import DirectoryStore | |
17 | ||
18 | ||
19 | class DirectoryStore(BaseSlavedStore): | |
20 | get_aliases_for_room = DirectoryStore.__dict__[ | |
21 | "get_aliases_for_room" | |
22 | ].orig |
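The __dict__[...].orig dance deserves a comment: synapse's @cached decorator
keeps the undecorated function on an orig attribute, and going through the
class __dict__ sidesteps the descriptor protocol, so you get the wrapper object
itself rather than a bound method. A self-contained sketch of why that works
(toy decorator, not the real one):

    class cached(object):
        """Toy stand-in for synapse's caching descriptor."""
        def __init__(self, func):
            self.orig = func  # the undecorated function
            self.cache = {}

        def __get__(self, obj, objtype=None):
            def wrapper(*args):
                if args not in self.cache:
                    self.cache[args] = self.orig(obj, *args)
                return self.cache[args]
            return wrapper

    class Store(object):
        @cached
        def get_aliases_for_room(self, room_id):
            return ["#x:example.com"]  # stand-in for a database query

    # instance access returns the caching wrapper; __dict__ access returns
    # the descriptor, whose .orig is the plain, uncached function
    uncached = Store.__dict__["get_aliases_for_room"].orig
    print(uncached(Store(), "!room:example.com"))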
17 | 17 | from synapse.api.constants import EventTypes |
18 | 18 | from synapse.events import FrozenEvent |
19 | 19 | from synapse.storage import DataStore |
20 | from synapse.storage.room import RoomStore | |
21 | 20 | from synapse.storage.roommember import RoomMemberStore |
22 | 21 | from synapse.storage.event_federation import EventFederationStore |
23 | 22 | from synapse.storage.event_push_actions import EventPushActionsStore |
63 | 62 | |
64 | 63 | # Cached functions can't be accessed through a class instance so we need |
65 | 64 | # to reach inside the __dict__ to extract them. |
66 | get_room_name_and_aliases = RoomStore.__dict__["get_room_name_and_aliases"] | |
67 | 65 | get_rooms_for_user = RoomMemberStore.__dict__["get_rooms_for_user"] |
68 | 66 | get_users_in_room = RoomMemberStore.__dict__["get_users_in_room"] |
69 | 67 | get_latest_event_ids_in_room = EventFederationStore.__dict__[ |
94 | 92 | StreamStore.__dict__["get_recent_event_ids_for_room"] |
95 | 93 | ) |
96 | 94 | |
97 | get_unread_push_actions_for_user_in_range = ( | |
98 | DataStore.get_unread_push_actions_for_user_in_range.__func__ | |
95 | get_unread_push_actions_for_user_in_range_for_http = ( | |
96 | DataStore.get_unread_push_actions_for_user_in_range_for_http.__func__ | |
97 | ) | |
98 | get_unread_push_actions_for_user_in_range_for_email = ( | |
99 | DataStore.get_unread_push_actions_for_user_in_range_for_email.__func__ | |
99 | 100 | ) |
100 | 101 | get_push_action_users_in_range = ( |
101 | 102 | DataStore.get_push_action_users_in_range.__func__ |
143 | 144 | _get_events_around_txn = DataStore._get_events_around_txn.__func__ |
144 | 145 | _get_some_state_from_cache = DataStore._get_some_state_from_cache.__func__ |
145 | 146 | |
147 | get_backfill_events = DataStore.get_backfill_events.__func__ | |
148 | _get_backfill_events = DataStore._get_backfill_events.__func__ | |
149 | get_missing_events = DataStore.get_missing_events.__func__ | |
150 | _get_missing_events = DataStore._get_missing_events.__func__ | |
151 | ||
152 | get_auth_chain = DataStore.get_auth_chain.__func__ | |
153 | get_auth_chain_ids = DataStore.get_auth_chain_ids.__func__ | |
154 | _get_auth_chain_ids_txn = DataStore._get_auth_chain_ids_txn.__func__ | |
155 | ||
146 | 156 | def stream_positions(self): |
147 | 157 | result = super(SlavedEventStore, self).stream_positions() |
148 | 158 | result["events"] = self._stream_id_gen.get_current_token() |
201 | 211 | self.get_rooms_for_user.invalidate_all() |
202 | 212 | self.get_users_in_room.invalidate((event.room_id,)) |
203 | 213 | # self.get_joined_hosts_for_room.invalidate((event.room_id,)) |
204 | self.get_room_name_and_aliases.invalidate((event.room_id,)) | |
205 | 214 | |
206 | 215 | self._invalidate_get_event_cache(event.event_id) |
207 | 216 | |
245 | 254 | self._get_current_state_for_key.invalidate(( |
246 | 255 | event.room_id, event.type, event.state_key |
247 | 256 | )) |
248 | ||
249 | if event.type in [EventTypes.Name, EventTypes.Aliases]: | |
250 | self.get_room_name_and_aliases.invalidate( | |
251 | (event.room_id,) | |
252 | ) | |
253 | pass |
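Similarly, the DataStore.method.__func__ assignments borrow plain functions
from DataStore without inheriting from it: on Python 2 (as here), attribute
access on a class yields an unbound method, and __func__ unwraps it so it can
be re-bound onto the slaved class. Sketch:

    class Master(object):
        def get_thing(self, key):
            return "thing-%s" % (key,)

    class Slaved(object):
        # re-use Master.get_thing without subclassing Master
        # (Python 2 only: on Python 3 the class attribute is already
        # a plain function and has no __func__)
        get_thing = Master.get_thing.__func__

    assert Slaved().get_thing("x") == "thing-x"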
0 | # -*- coding: utf-8 -*- | |
1 | # Copyright 2015, 2016 OpenMarket Ltd | |
2 | # | |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
4 | # you may not use this file except in compliance with the License. | |
5 | # You may obtain a copy of the License at | |
6 | # | |
7 | # http://www.apache.org/licenses/LICENSE-2.0 | |
8 | # | |
9 | # Unless required by applicable law or agreed to in writing, software | |
10 | # distributed under the License is distributed on an "AS IS" BASIS, | |
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 | # See the License for the specific language governing permissions and | |
13 | # limitations under the License. | |
14 | ||
15 | from ._base import BaseSlavedStore | |
16 | from synapse.storage import DataStore | |
17 | from synapse.storage.keys import KeyStore | |
18 | ||
19 | ||
20 | class SlavedKeyStore(BaseSlavedStore): | |
21 | _get_server_verify_key = KeyStore.__dict__[ | |
22 | "_get_server_verify_key" | |
23 | ] | |
24 | ||
25 | get_server_verify_keys = DataStore.get_server_verify_keys.__func__ | |
26 | store_server_verify_key = DataStore.store_server_verify_key.__func__ | |
27 | ||
28 | get_server_certificate = DataStore.get_server_certificate.__func__ | |
29 | store_server_certificate = DataStore.store_server_certificate.__func__ | |
30 | ||
31 | get_server_keys_json = DataStore.get_server_keys_json.__func__ | |
32 | store_server_keys_json = DataStore.store_server_keys_json.__func__ |
0 | # -*- coding: utf-8 -*- | |
1 | # Copyright 2015, 2016 OpenMarket Ltd | |
2 | # | |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
4 | # you may not use this file except in compliance with the License. | |
5 | # You may obtain a copy of the License at | |
6 | # | |
7 | # http://www.apache.org/licenses/LICENSE-2.0 | |
8 | # | |
9 | # Unless required by applicable law or agreed to in writing, software | |
10 | # distributed under the License is distributed on an "AS IS" BASIS, | |
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 | # See the License for the specific language governing permissions and | |
13 | # limitations under the License. | |
14 | ||
15 | from ._base import BaseSlavedStore | |
16 | from synapse.storage import DataStore | |
17 | ||
18 | ||
19 | class RoomStore(BaseSlavedStore): | |
20 | get_public_room_ids = DataStore.get_public_room_ids.__func__ |
0 | # -*- coding: utf-8 -*- | |
1 | # Copyright 2015, 2016 OpenMarket Ltd | |
2 | # | |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
4 | # you may not use this file except in compliance with the License. | |
5 | # You may obtain a copy of the License at | |
6 | # | |
7 | # http://www.apache.org/licenses/LICENSE-2.0 | |
8 | # | |
9 | # Unless required by applicable law or agreed to in writing, software | |
10 | # distributed under the License is distributed on an "AS IS" BASIS, | |
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 | # See the License for the specific language governing permissions and | |
13 | # limitations under the License. | |
14 | ||
15 | from twisted.internet import defer | |
16 | from ._base import BaseSlavedStore | |
17 | from synapse.storage import DataStore | |
18 | from synapse.storage.transactions import TransactionStore | |
19 | ||
20 | ||
21 | class TransactionStore(BaseSlavedStore): | |
22 | get_destination_retry_timings = TransactionStore.__dict__[ | |
23 | "get_destination_retry_timings" | |
24 | ].orig | |
25 | _get_destination_retry_timings = DataStore._get_destination_retry_timings.__func__ | |
26 | ||
27 | # For now, don't record the destination retry timings | 
28 | def set_destination_retry_timings(*args, **kwargs): | |
29 | return defer.succeed(None) |
45 | 45 | account_data, |
46 | 46 | report_event, |
47 | 47 | openid, |
48 | devices, | |
48 | 49 | ) |
49 | 50 | |
50 | 51 | from synapse.http.server import JsonResource |
89 | 90 | account_data.register_servlets(hs, client_resource) |
90 | 91 | report_event.register_servlets(hs, client_resource) |
91 | 92 | openid.register_servlets(hs, client_resource) |
93 | devices.register_servlets(hs, client_resource) |
45 | 45 | defer.returnValue((200, ret)) |
46 | 46 | |
47 | 47 | |
48 | class PurgeMediaCacheRestServlet(ClientV1RestServlet): | |
49 | PATTERNS = client_path_patterns("/admin/purge_media_cache") | |
50 | ||
51 | def __init__(self, hs): | |
52 | self.media_repository = hs.get_media_repository() | |
53 | super(PurgeMediaCacheRestServlet, self).__init__(hs) | |
54 | ||
55 | @defer.inlineCallbacks | |
56 | def on_POST(self, request): | |
57 | requester = yield self.auth.get_user_by_req(request) | |
58 | is_admin = yield self.auth.is_server_admin(requester.user) | |
59 | ||
60 | if not is_admin: | |
61 | raise AuthError(403, "You are not a server admin") | |
62 | ||
63 | before_ts = request.args.get("before_ts", None) | |
64 | if not before_ts: | |
65 | raise SynapseError(400, "Missing 'before_ts' arg") | |
66 | ||
67 | logger.info("before_ts: %r", before_ts[0]) | |
68 | ||
69 | try: | |
70 | before_ts = int(before_ts[0]) | |
71 | except Exception: | |
72 | raise SynapseError(400, "Invalid 'before_ts' arg") | |
73 | ||
74 | ret = yield self.media_repository.delete_old_remote_media(before_ts) | |
75 | ||
76 | defer.returnValue((200, ret)) | |
77 | ||
78 | ||
79 | class PurgeHistoryRestServlet(ClientV1RestServlet): | |
80 | PATTERNS = client_path_patterns( | |
81 | "/admin/purge_history/(?P<room_id>[^/]*)/(?P<event_id>[^/]*)" | |
82 | ) | |
83 | ||
84 | @defer.inlineCallbacks | |
85 | def on_POST(self, request, room_id, event_id): | |
86 | requester = yield self.auth.get_user_by_req(request) | |
87 | is_admin = yield self.auth.is_server_admin(requester.user) | |
88 | ||
89 | if not is_admin: | |
90 | raise AuthError(403, "You are not a server admin") | |
91 | ||
92 | yield self.handlers.message_handler.purge_history(room_id, event_id) | |
93 | ||
94 | defer.returnValue((200, {})) | |
95 | ||
96 | ||
97 | class DeactivateAccountRestServlet(ClientV1RestServlet): | |
98 | PATTERNS = client_path_patterns("/admin/deactivate/(?P<target_user_id>[^/]*)") | |
99 | ||
100 | def __init__(self, hs): | |
101 | self.store = hs.get_datastore() | |
102 | super(DeactivateAccountRestServlet, self).__init__(hs) | |
103 | ||
104 | @defer.inlineCallbacks | |
105 | def on_POST(self, request, target_user_id): | |
106 | UserID.from_string(target_user_id) | |
107 | requester = yield self.auth.get_user_by_req(request) | |
108 | is_admin = yield self.auth.is_server_admin(requester.user) | |
109 | ||
110 | if not is_admin: | |
111 | raise AuthError(403, "You are not a server admin") | |
112 | ||
113 | # FIXME: Theoretically there is a race here wherein user resets password | |
114 | # using threepid. | |
115 | yield self.store.user_delete_access_tokens(target_user_id) | |
116 | yield self.store.user_delete_threepids(target_user_id) | |
117 | yield self.store.user_set_password_hash(target_user_id, None) | |
118 | ||
119 | defer.returnValue((200, {})) | |
120 | ||
121 | ||
48 | 122 | def register_servlets(hs, http_server): |
49 | 123 | WhoisRestServlet(hs).register(http_server) |
124 | PurgeMediaCacheRestServlet(hs).register(http_server) | |
125 | DeactivateAccountRestServlet(hs).register(http_server) | |
126 | PurgeHistoryRestServlet(hs).register(http_server) |
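For reference, a hypothetical invocation of the new purge endpoint; the URL
prefix follows the usual v1 client prefix, the token is illustrative, and
before_ts is parsed server-side with int() (synapse timestamps are milliseconds
since the epoch):

    import requests  # assumption: python-requests, for brevity

    resp = requests.post(
        "https://example.com:8448/_matrix/client/api/v1/admin/purge_media_cache",
        params={
            "access_token": "<admin access token>",  # must be a server admin
            "before_ts": "1470000000000",            # ms since epoch
        },
    )
    print(resp.json())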
51 | 51 | """ |
52 | 52 | |
53 | 53 | def __init__(self, hs): |
54 | """ | |
55 | Args: | |
56 | hs (synapse.server.HomeServer): | |
57 | """ | |
54 | 58 | self.hs = hs |
55 | 59 | self.handlers = hs.get_handlers() |
56 | 60 | self.builder_factory = hs.get_event_builder_factory() |
58 | 58 | self.servername = hs.config.server_name |
59 | 59 | self.http_client = hs.get_simple_http_client() |
60 | 60 | self.auth_handler = self.hs.get_auth_handler() |
61 | self.device_handler = self.hs.get_device_handler() | |
61 | 62 | |
62 | 63 | def on_GET(self, request): |
63 | 64 | flows = [] |
144 | 145 | ).to_string() |
145 | 146 | |
146 | 147 | auth_handler = self.auth_handler |
147 | user_id, access_token, refresh_token = yield auth_handler.login_with_password( | |
148 | user_id = yield auth_handler.validate_password_login( | |
148 | 149 | user_id=user_id, |
149 | password=login_submission["password"]) | |
150 | ||
150 | password=login_submission["password"], | |
151 | ) | |
152 | device_id = yield self._register_device(user_id, login_submission) | |
153 | access_token, refresh_token = ( | |
154 | yield auth_handler.get_login_tuple_for_user_id( | |
155 | user_id, device_id, | |
156 | login_submission.get("initial_device_display_name") | |
157 | ) | |
158 | ) | |
151 | 159 | result = { |
152 | 160 | "user_id": user_id, # may have changed |
153 | 161 | "access_token": access_token, |
154 | 162 | "refresh_token": refresh_token, |
155 | 163 | "home_server": self.hs.hostname, |
164 | "device_id": device_id, | |
156 | 165 | } |
157 | 166 | |
158 | 167 | defer.returnValue((200, result)) |
164 | 173 | user_id = ( |
165 | 174 | yield auth_handler.validate_short_term_login_token_and_get_user_id(token) |
166 | 175 | ) |
167 | user_id, access_token, refresh_token = ( | |
168 | yield auth_handler.get_login_tuple_for_user_id(user_id) | |
176 | device_id = yield self._register_device(user_id, login_submission) | |
177 | access_token, refresh_token = ( | |
178 | yield auth_handler.get_login_tuple_for_user_id( | |
179 | user_id, device_id, | |
180 | login_submission.get("initial_device_display_name") | |
181 | ) | |
169 | 182 | ) |
170 | 183 | result = { |
171 | 184 | "user_id": user_id, # may have changed |
172 | 185 | "access_token": access_token, |
173 | 186 | "refresh_token": refresh_token, |
174 | 187 | "home_server": self.hs.hostname, |
188 | "device_id": device_id, | |
175 | 189 | } |
176 | 190 | |
177 | 191 | defer.returnValue((200, result)) |
195 | 209 | |
196 | 210 | user_id = UserID.create(user, self.hs.hostname).to_string() |
197 | 211 | auth_handler = self.auth_handler |
198 | user_exists = yield auth_handler.does_user_exist(user_id) | |
199 | if user_exists: | |
200 | user_id, access_token, refresh_token = ( | |
201 | yield auth_handler.get_login_tuple_for_user_id(user_id) | |
212 | registered_user_id = yield auth_handler.check_user_exists(user_id) | |
213 | if registered_user_id: | |
214 | access_token, refresh_token = ( | |
215 | yield auth_handler.get_login_tuple_for_user_id( | |
216 | registered_user_id | |
217 | ) | |
202 | 218 | ) |
203 | 219 | result = { |
204 | "user_id": user_id, # may have changed | |
220 | "user_id": registered_user_id, # may have changed | |
205 | 221 | "access_token": access_token, |
206 | 222 | "refresh_token": refresh_token, |
207 | 223 | "home_server": self.hs.hostname, |
244 | 260 | |
245 | 261 | user_id = UserID.create(user, self.hs.hostname).to_string() |
246 | 262 | auth_handler = self.auth_handler |
247 | user_exists = yield auth_handler.does_user_exist(user_id) | |
248 | if user_exists: | |
249 | user_id, access_token, refresh_token = ( | |
250 | yield auth_handler.get_login_tuple_for_user_id(user_id) | |
263 | registered_user_id = yield auth_handler.check_user_exists(user_id) | |
264 | if registered_user_id: | |
265 | device_id = yield self._register_device( | |
266 | registered_user_id, login_submission | |
267 | ) | |
268 | access_token, refresh_token = ( | |
269 | yield auth_handler.get_login_tuple_for_user_id( | |
270 | registered_user_id, device_id, | |
271 | login_submission.get("initial_device_display_name") | |
272 | ) | |
251 | 273 | ) |
252 | 274 | result = { |
253 | "user_id": user_id, # may have changed | |
275 | "user_id": registered_user_id, | |
254 | 276 | "access_token": access_token, |
255 | 277 | "refresh_token": refresh_token, |
256 | 278 | "home_server": self.hs.hostname, |
257 | 279 | } |
258 | 280 | else: |
281 | # TODO: we should probably check that the register isn't going | |
282 | # to fonx/change our user_id before registering the device | |
283 | device_id = yield self._register_device(user_id, login_submission) | |
259 | 284 | user_id, access_token = ( |
260 | 285 | yield self.handlers.registration_handler.register(localpart=user) |
261 | 286 | ) |
293 | 318 | raise LoginError(401, "Invalid CAS response", errcode=Codes.UNAUTHORIZED) |
294 | 319 | |
295 | 320 | return (user, attributes) |
321 | ||
322 | def _register_device(self, user_id, login_submission): | |
323 | """Register a device for a user. | |
324 | ||
325 | This is called after the user's credentials have been validated, but | |
326 | before the access token has been issued. | |
327 | ||
328 | Args: | |
329 | (str) user_id: full canonical @user:id | |
330 | (object) login_submission: dictionary supplied to /login call, from | |
331 | which we pull device_id and initial_device_name | |
332 | Returns: | |
333 | defer.Deferred: (str) device_id | |
334 | """ | |
335 | device_id = login_submission.get("device_id") | |
336 | initial_display_name = login_submission.get( | |
337 | "initial_device_display_name") | |
338 | return self.device_handler.check_device_registered( | |
339 | user_id, device_id, initial_display_name | |
340 | ) | |
296 | 341 | |
297 | 342 | |
298 | 343 | class SAML2RestServlet(ClientV1RestServlet): |
413 | 458 | |
414 | 459 | user_id = UserID.create(user, self.hs.hostname).to_string() |
415 | 460 | auth_handler = self.auth_handler |
416 | user_exists = yield auth_handler.does_user_exist(user_id) | |
417 | if not user_exists: | |
418 | user_id, _ = ( | |
461 | registered_user_id = yield auth_handler.check_user_exists(user_id) | |
462 | if not registered_user_id: | |
463 | registered_user_id, _ = ( | |
419 | 464 | yield self.handlers.registration_handler.register(localpart=user) |
420 | 465 | ) |
421 | 466 | |
422 | login_token = auth_handler.generate_short_term_login_token(user_id) | |
467 | login_token = auth_handler.generate_short_term_login_token(registered_user_id) | |
423 | 468 | redirect_url = self.add_login_token_to_redirect_url(client_redirect_url, |
424 | 469 | login_token) |
425 | 470 | request.redirect(redirect_url) |
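To illustrate the reworked flow, a password login that also names its device
might submit the following (values hypothetical; device_id and
initial_device_display_name are both optional, and the server allocates a
device_id when none is given):

    login_submission = {
        "type": "m.login.password",
        "user": "@alice:example.com",
        "password": "s3cret",
        "device_id": "ABCDEFGH",                          # optional
        "initial_device_display_name": "Alice's laptop",  # optional
    }
    # the response now carries the device_id alongside the tokens:
    # {"user_id": ..., "access_token": ..., "refresh_token": ...,
    #  "home_server": ..., "device_id": "ABCDEFGH"}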
51 | 51 | PATTERNS = client_path_patterns("/register$", releases=(), include_in_unstable=False) |
52 | 52 | |
53 | 53 | def __init__(self, hs): |
54 | """ | |
55 | Args: | |
56 | hs (synapse.server.HomeServer): server | |
57 | """ | |
54 | 58 | super(RegisterRestServlet, self).__init__(hs) |
55 | 59 | # sessions are stored as: |
56 | 60 | # self.sessions = { |
59 | 63 | # TODO: persistent storage |
60 | 64 | self.sessions = {} |
61 | 65 | self.enable_registration = hs.config.enable_registration |
66 | self.auth_handler = hs.get_auth_handler() | |
62 | 67 | |
63 | 68 | def on_GET(self, request): |
64 | 69 | if self.hs.config.enable_registration_captcha: |
298 | 303 | user_localpart = register_json["user"].encode("utf-8") |
299 | 304 | |
300 | 305 | handler = self.handlers.registration_handler |
301 | (user_id, token) = yield handler.appservice_register( | |
306 | user_id = yield handler.appservice_register( | |
302 | 307 | user_localpart, as_token |
303 | 308 | ) |
309 | token = yield self.auth_handler.issue_access_token(user_id) | |
304 | 310 | self._remove_session(session) |
305 | 311 | defer.returnValue({ |
306 | 312 | "user_id": user_id, |
323 | 329 | raise SynapseError(400, "Shared secret registration is not enabled") |
324 | 330 | |
325 | 331 | user = register_json["user"].encode("utf-8") |
332 | password = register_json["password"].encode("utf-8") | |
333 | admin = register_json.get("admin", None) | |
334 | ||
335 | # It's important to check, as we use null bytes as HMAC field separators | 
336 | if "\x00" in user: | |
337 | raise SynapseError(400, "Invalid user") | |
338 | if "\x00" in password: | |
339 | raise SynapseError(400, "Invalid password") | |
326 | 340 | |
327 | 341 | # str() because otherwise hmac complains that 'unicode' does not |
328 | 342 | # have the buffer interface |
330 | 344 | |
331 | 345 | want_mac = hmac.new( |
332 | 346 | key=self.hs.config.registration_shared_secret, |
333 | msg=user, | |
334 | 347 | digestmod=sha1, |
335 | ).hexdigest() | |
336 | ||
337 | password = register_json["password"].encode("utf-8") | |
348 | ) | |
349 | want_mac.update(user) | |
350 | want_mac.update("\x00") | |
351 | want_mac.update(password) | |
352 | want_mac.update("\x00") | |
353 | want_mac.update("admin" if admin else "notadmin") | |
354 | want_mac = want_mac.hexdigest() | |
338 | 355 | |
339 | 356 | if compare_digest(want_mac, got_mac): |
340 | 357 | handler = self.handlers.registration_handler |
341 | 358 | user_id, token = yield handler.register( |
342 | 359 | localpart=user, |
343 | 360 | password=password, |
361 | admin=bool(admin), | |
344 | 362 | ) |
345 | 363 | self._remove_session(session) |
346 | 364 | defer.returnValue({ |
409 | 427 | raise SynapseError(400, "Failed to parse 'duration_seconds'") |
410 | 428 | if duration_seconds > self.direct_user_creation_max_duration: |
411 | 429 | duration_seconds = self.direct_user_creation_max_duration |
430 | password_hash = user_json["password_hash"].encode("utf-8") \ | |
431 | if user_json.get("password_hash") else None | |
412 | 432 | |
413 | 433 | handler = self.handlers.registration_handler |
414 | 434 | user_id, token = yield handler.get_or_create_user( |
415 | 435 | localpart=localpart, |
416 | 436 | displayname=displayname, |
417 | duration_seconds=duration_seconds | |
437 | duration_in_ms=(duration_seconds * 1000), | |
438 | password_hash=password_hash | |
418 | 439 | ) |
419 | 440 | |
420 | 441 | defer.returnValue({ |
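The new MAC construction in the shared-secret path is easy to get wrong
client-side, so here it is worked through in isolation (secret and credentials
made up): user, password and the admin flag are fed to HMAC-SHA1 keyed with the
registration shared secret, joined by NUL bytes, which is why NULs are rejected
in the inputs above.

    import hmac
    from hashlib import sha1

    mac = hmac.new(key="shared_secret_from_homeserver_yaml", digestmod=sha1)
    for part in ("alice", "\x00", "wonderland", "\x00", "notadmin"):
        mac.update(part)    # "admin" as the last part for admin accounts
    print(mac.hexdigest())  # goes in register_json["mac"]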
19 | 19 | from synapse.api.errors import SynapseError, Codes, AuthError |
20 | 20 | from synapse.streams.config import PaginationConfig |
21 | 21 | from synapse.api.constants import EventTypes, Membership |
22 | from synapse.api.filtering import Filter | |
22 | 23 | from synapse.types import UserID, RoomID, RoomAlias |
23 | 24 | from synapse.events.utils import serialize_event |
24 | 25 | from synapse.http.servlet import parse_json_object_from_request |
25 | 26 | |
26 | 27 | import logging |
27 | 28 | import urllib |
29 | import ujson as json | |
28 | 30 | |
29 | 31 | logger = logging.getLogger(__name__) |
30 | 32 | |
326 | 328 | request, default_limit=10, |
327 | 329 | ) |
328 | 330 | as_client_event = "raw" not in request.args |
331 | filter_bytes = request.args.get("filter", None) | |
332 | if filter_bytes: | |
333 | filter_json = urllib.unquote(filter_bytes[-1]).decode("UTF-8") | |
334 | event_filter = Filter(json.loads(filter_json)) | |
335 | else: | |
336 | event_filter = None | |
329 | 337 | handler = self.handlers.message_handler |
330 | 338 | msgs = yield handler.get_messages( |
331 | 339 | room_id=room_id, |
332 | 340 | requester=requester, |
333 | 341 | pagin_config=pagination_config, |
334 | as_client_event=as_client_event | |
342 | as_client_event=as_client_event, | |
343 | event_filter=event_filter, | |
335 | 344 | ) |
336 | 345 | |
337 | 346 | defer.returnValue((200, msgs)) |
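Client-side, the new filter param is a URL-encoded JSON object in the query
string; a hypothetical request for URL-bearing events only:

    import json
    import urllib  # Python 2, matching the urllib.unquote call above

    filter_json = json.dumps({"contains_url": True})
    query = "filter=" + urllib.quote(filter_json)
    # GET .../rooms/{room_id}/messages?from=...&filter=%7B%22contains_url%22%3A%20true%7D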
24 | 24 | logger = logging.getLogger(__name__) |
25 | 25 | |
26 | 26 | |
27 | def client_v2_patterns(path_regex, releases=(0,)): | |
27 | def client_v2_patterns(path_regex, releases=(0,), | |
28 | v2_alpha=True, | |
29 | unstable=True): | |
28 | 30 | """Creates a regex compiled client path with the correct client path |
29 | 31 | prefix. |
30 | 32 | |
34 | 36 | Returns: |
35 | 37 | SRE_Pattern |
36 | 38 | """ |
37 | patterns = [re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex)] | |
38 | unstable_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/unstable") | |
39 | patterns.append(re.compile("^" + unstable_prefix + path_regex)) | |
39 | patterns = [] | |
40 | if v2_alpha: | |
41 | patterns.append(re.compile("^" + CLIENT_V2_ALPHA_PREFIX + path_regex)) | |
42 | if unstable: | |
43 | unstable_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/unstable") | |
44 | patterns.append(re.compile("^" + unstable_prefix + path_regex)) | |
40 | 45 | for release in releases: |
41 | 46 | new_prefix = CLIENT_V2_ALPHA_PREFIX.replace("/v2_alpha", "/r%d" % release) |
42 | 47 | patterns.append(re.compile("^" + new_prefix + path_regex)) |
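The effect of the two new flags is easiest to see by example (assuming
CLIENT_V2_ALPHA_PREFIX is /_matrix/client/v2_alpha):

    # defaults: v2_alpha + unstable prefixes, plus any releases
    client_v2_patterns("/account/3pid$")
    #  -> ^/_matrix/client/v2_alpha/account/3pid$
    #     ^/_matrix/client/unstable/account/3pid$
    #     ^/_matrix/client/r0/account/3pid$

    # the device servlets below opt out of v2_alpha and releases,
    # leaving only the unstable prefix:
    client_v2_patterns("/devices$", releases=[], v2_alpha=False)
    #  -> ^/_matrix/client/unstable/devices$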
27 | 27 | logger = logging.getLogger(__name__) |
28 | 28 | |
29 | 29 | |
30 | class PasswordRequestTokenRestServlet(RestServlet): | |
31 | PATTERNS = client_v2_patterns("/account/password/email/requestToken$") | |
32 | ||
33 | def __init__(self, hs): | |
34 | super(PasswordRequestTokenRestServlet, self).__init__() | |
35 | self.hs = hs | |
36 | self.identity_handler = hs.get_handlers().identity_handler | |
37 | ||
38 | @defer.inlineCallbacks | |
39 | def on_POST(self, request): | |
40 | body = parse_json_object_from_request(request) | |
41 | ||
42 | required = ['id_server', 'client_secret', 'email', 'send_attempt'] | |
43 | absent = [] | |
44 | for k in required: | |
45 | if k not in body: | |
46 | absent.append(k) | |
47 | ||
48 | if absent: | |
49 | raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) | |
50 | ||
51 | existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( | |
52 | 'email', body['email'] | |
53 | ) | |
54 | ||
55 | if existingUid is None: | |
56 | raise SynapseError(400, "Email not found", Codes.THREEPID_NOT_FOUND) | |
57 | ||
58 | ret = yield self.identity_handler.requestEmailToken(**body) | |
59 | defer.returnValue((200, ret)) | |
60 | ||
61 | ||
30 | 62 | class PasswordRestServlet(RestServlet): |
31 | PATTERNS = client_v2_patterns("/account/password") | |
63 | PATTERNS = client_v2_patterns("/account/password$") | |
32 | 64 | |
33 | 65 | def __init__(self, hs): |
34 | 66 | super(PasswordRestServlet, self).__init__() |
88 | 120 | return 200, {} |
89 | 121 | |
90 | 122 | |
123 | class DeactivateAccountRestServlet(RestServlet): | |
124 | PATTERNS = client_v2_patterns("/account/deactivate$") | |
125 | ||
126 | def __init__(self, hs): | |
127 | self.hs = hs | |
128 | self.store = hs.get_datastore() | |
129 | self.auth = hs.get_auth() | |
130 | self.auth_handler = hs.get_auth_handler() | |
131 | super(DeactivateAccountRestServlet, self).__init__() | |
132 | ||
133 | @defer.inlineCallbacks | |
134 | def on_POST(self, request): | |
135 | body = parse_json_object_from_request(request) | |
136 | ||
137 | authed, result, params, _ = yield self.auth_handler.check_auth([ | |
138 | [LoginType.PASSWORD], | |
139 | ], body, self.hs.get_ip_from_request(request)) | |
140 | ||
141 | if not authed: | |
142 | defer.returnValue((401, result)) | |
143 | ||
144 | user_id = None | |
145 | requester = None | |
146 | ||
147 | if LoginType.PASSWORD in result: | |
148 | # if using password, they should also be logged in | |
149 | requester = yield self.auth.get_user_by_req(request) | |
150 | user_id = requester.user.to_string() | |
151 | if user_id != result[LoginType.PASSWORD]: | |
152 | raise LoginError(400, "", Codes.UNKNOWN) | |
153 | else: | |
154 | logger.error("Auth succeeded but no known type! %r", result.keys()) | 
155 | raise SynapseError(500, "", Codes.UNKNOWN) | |
156 | ||
157 | # FIXME: Theoretically there is a race here wherein user resets password | |
158 | # using threepid. | |
159 | yield self.store.user_delete_access_tokens(user_id) | |
160 | yield self.store.user_delete_threepids(user_id) | |
161 | yield self.store.user_set_password_hash(user_id, None) | |
162 | ||
163 | defer.returnValue((200, {})) | |
164 | ||
165 | ||
166 | class ThreepidRequestTokenRestServlet(RestServlet): | |
167 | PATTERNS = client_v2_patterns("/account/3pid/email/requestToken$") | |
168 | ||
169 | def __init__(self, hs): | |
170 | self.hs = hs | |
171 | super(ThreepidRequestTokenRestServlet, self).__init__() | |
172 | self.identity_handler = hs.get_handlers().identity_handler | |
173 | ||
174 | @defer.inlineCallbacks | |
175 | def on_POST(self, request): | |
176 | body = parse_json_object_from_request(request) | |
177 | ||
178 | required = ['id_server', 'client_secret', 'email', 'send_attempt'] | |
179 | absent = [] | |
180 | for k in required: | |
181 | if k not in body: | |
182 | absent.append(k) | |
183 | ||
184 | if absent: | |
185 | raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) | |
186 | ||
187 | existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( | |
188 | 'email', body['email'] | |
189 | ) | |
190 | ||
191 | if existingUid is not None: | |
192 | raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) | |
193 | ||
194 | ret = yield self.identity_handler.requestEmailToken(**body) | |
195 | defer.returnValue((200, ret)) | |
196 | ||
197 | ||
91 | 198 | class ThreepidRestServlet(RestServlet): |
92 | PATTERNS = client_v2_patterns("/account/3pid") | |
199 | PATTERNS = client_v2_patterns("/account/3pid$") | |
93 | 200 | |
94 | 201 | def __init__(self, hs): |
95 | 202 | super(ThreepidRestServlet, self).__init__() |
156 | 263 | |
157 | 264 | |
158 | 265 | def register_servlets(hs, http_server): |
266 | PasswordRequestTokenRestServlet(hs).register(http_server) | |
159 | 267 | PasswordRestServlet(hs).register(http_server) |
268 | DeactivateAccountRestServlet(hs).register(http_server) | |
269 | ThreepidRequestTokenRestServlet(hs).register(http_server) | |
160 | 270 | ThreepidRestServlet(hs).register(http_server) |
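Both requestToken servlets validate the same four fields before delegating to
the identity handler; a hypothetical request body:

    body = {
        "id_server": "matrix.org",      # illustrative identity server
        "client_secret": "some_secret",
        "email": "alice@example.com",
        "send_attempt": 1,
    }
    # missing any of these -> 400 M_MISSING_PARAM; the password flavour
    # requires the email to be registered already, the 3pid flavour
    # requires it not to be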
0 | # -*- coding: utf-8 -*- | |
1 | # Copyright 2015, 2016 OpenMarket Ltd | |
2 | # | |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
4 | # you may not use this file except in compliance with the License. | |
5 | # You may obtain a copy of the License at | |
6 | # | |
7 | # http://www.apache.org/licenses/LICENSE-2.0 | |
8 | # | |
9 | # Unless required by applicable law or agreed to in writing, software | |
10 | # distributed under the License is distributed on an "AS IS" BASIS, | |
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 | # See the License for the specific language governing permissions and | |
13 | # limitations under the License. | |
14 | ||
15 | import logging | |
16 | ||
17 | from twisted.internet import defer | |
18 | ||
19 | from synapse.http import servlet | |
20 | from ._base import client_v2_patterns | |
21 | ||
22 | logger = logging.getLogger(__name__) | |
23 | ||
24 | ||
25 | class DevicesRestServlet(servlet.RestServlet): | |
26 | PATTERNS = client_v2_patterns("/devices$", releases=[], v2_alpha=False) | |
27 | ||
28 | def __init__(self, hs): | |
29 | """ | |
30 | Args: | |
31 | hs (synapse.server.HomeServer): server | |
32 | """ | |
33 | super(DevicesRestServlet, self).__init__() | |
34 | self.hs = hs | |
35 | self.auth = hs.get_auth() | |
36 | self.device_handler = hs.get_device_handler() | |
37 | ||
38 | @defer.inlineCallbacks | |
39 | def on_GET(self, request): | |
40 | requester = yield self.auth.get_user_by_req(request) | |
41 | devices = yield self.device_handler.get_devices_by_user( | |
42 | requester.user.to_string() | |
43 | ) | |
44 | defer.returnValue((200, {"devices": devices})) | |
45 | ||
46 | ||
47 | class DeviceRestServlet(servlet.RestServlet): | |
48 | PATTERNS = client_v2_patterns("/devices/(?P<device_id>[^/]*)$", | |
49 | releases=[], v2_alpha=False) | |
50 | ||
51 | def __init__(self, hs): | |
52 | """ | |
53 | Args: | |
54 | hs (synapse.server.HomeServer): server | |
55 | """ | |
56 | super(DeviceRestServlet, self).__init__() | |
57 | self.hs = hs | |
58 | self.auth = hs.get_auth() | |
59 | self.device_handler = hs.get_device_handler() | |
60 | ||
61 | @defer.inlineCallbacks | |
62 | def on_GET(self, request, device_id): | |
63 | requester = yield self.auth.get_user_by_req(request) | |
64 | device = yield self.device_handler.get_device( | |
65 | requester.user.to_string(), | |
66 | device_id, | |
67 | ) | |
68 | defer.returnValue((200, device)) | |
69 | ||
70 | @defer.inlineCallbacks | |
71 | def on_DELETE(self, request, device_id): | |
72 | # XXX: it's not completely obvious we want to expose this endpoint. | |
73 | # It allows the client to delete access tokens, which feels like a | |
74 | # thing which merits extra auth. But if we want to do the interactive- | |
75 | # auth dance, we should really make it possible to delete more than one | |
76 | # device at a time. | |
77 | requester = yield self.auth.get_user_by_req(request) | |
78 | yield self.device_handler.delete_device( | |
79 | requester.user.to_string(), | |
80 | device_id, | |
81 | ) | |
82 | defer.returnValue((200, {})) | |
83 | ||
84 | @defer.inlineCallbacks | |
85 | def on_PUT(self, request, device_id): | |
86 | requester = yield self.auth.get_user_by_req(request) | |
87 | ||
88 | body = servlet.parse_json_object_from_request(request) | |
89 | yield self.device_handler.update_device( | |
90 | requester.user.to_string(), | |
91 | device_id, | |
92 | body | |
93 | ) | |
94 | defer.returnValue((200, {})) | |
95 | ||
96 | ||
97 | def register_servlets(hs, http_server): | |
98 | DevicesRestServlet(hs).register(http_server) | |
99 | DeviceRestServlet(hs).register(http_server) |
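Taken together, the servlets register the following surface (unstable prefix
only, per the pattern flags above; the PUT body field is an assumption based on
what update_device consumes):

    # GET    /_matrix/client/unstable/devices                 list own devices
    # GET    /_matrix/client/unstable/devices/{device_id}     fetch one device
    # PUT    /_matrix/client/unstable/devices/{device_id}     update metadata,
    #        e.g. {"display_name": "my new phone"}
    # DELETE /_matrix/client/unstable/devices/{device_id}     delete device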
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | |
15 | import logging | |
16 | ||
17 | import simplejson as json | |
18 | from canonicaljson import encode_canonical_json | |
15 | 19 | from twisted.internet import defer |
16 | 20 | |
21 | import synapse.api.errors | |
22 | import synapse.server | |
23 | import synapse.types | |
17 | 24 | from synapse.http.servlet import RestServlet, parse_json_object_from_request |
18 | 25 | from synapse.types import UserID |
19 | ||
20 | from canonicaljson import encode_canonical_json | |
21 | ||
22 | 26 | from ._base import client_v2_patterns |
23 | 27 | |
24 | import logging | |
25 | import simplejson as json | |
26 | ||
27 | 28 | logger = logging.getLogger(__name__) |
28 | 29 | |
29 | 30 | |
30 | 31 | class KeyUploadServlet(RestServlet): |
31 | 32 | """ |
32 | POST /keys/upload/<device_id> HTTP/1.1 | |
33 | POST /keys/upload HTTP/1.1 | |
33 | 34 | Content-Type: application/json |
34 | 35 | |
35 | 36 | { |
52 | 53 | }, |
53 | 54 | } |
54 | 55 | """ |
55 | PATTERNS = client_v2_patterns("/keys/upload/(?P<device_id>[^/]*)", releases=()) | |
56 | PATTERNS = client_v2_patterns("/keys/upload(/(?P<device_id>[^/]+))?$", | |
57 | releases=()) | |
56 | 58 | |
57 | 59 | def __init__(self, hs): |
60 | """ | |
61 | Args: | |
62 | hs (synapse.server.HomeServer): server | |
63 | """ | |
58 | 64 | super(KeyUploadServlet, self).__init__() |
59 | 65 | self.store = hs.get_datastore() |
60 | 66 | self.clock = hs.get_clock() |
61 | 67 | self.auth = hs.get_auth() |
68 | self.device_handler = hs.get_device_handler() | |
62 | 69 | |
63 | 70 | @defer.inlineCallbacks |
64 | 71 | def on_POST(self, request, device_id): |
65 | 72 | requester = yield self.auth.get_user_by_req(request) |
73 | ||
66 | 74 | user_id = requester.user.to_string() |
67 | # TODO: Check that the device_id matches that in the authentication | |
68 | # or derive the device_id from the authentication instead. | |
69 | 75 | |
70 | 76 | body = parse_json_object_from_request(request) |
77 | ||
78 | if device_id is not None: | |
79 | # passing the device_id here is deprecated; however, we allow it | |
80 | # for now for compatibility with older clients. | |
81 | if (requester.device_id is not None and | |
82 | device_id != requester.device_id): | |
83 | logger.warning("Client uploading keys for a different device " | |
84 | "(logged in as %s, uploading for %s)", | |
85 | requester.device_id, device_id) | |
86 | else: | |
87 | device_id = requester.device_id | |
88 | ||
89 | if device_id is None: | |
90 | raise synapse.api.errors.SynapseError( | |
91 | 400, | |
92 | "To upload keys, you must pass device_id when authenticating" | |
93 | ) | |
71 | 94 | |
72 | 95 | time_now = self.clock.time_msec() |
73 | 96 | |
101 | 124 | user_id, device_id, time_now, key_list |
102 | 125 | ) |
103 | 126 | |
104 | result = yield self.store.count_e2e_one_time_keys(user_id, device_id) | |
105 | defer.returnValue((200, {"one_time_key_counts": result})) | |
106 | ||
107 | @defer.inlineCallbacks | |
108 | def on_GET(self, request, device_id): | |
109 | requester = yield self.auth.get_user_by_req(request) | |
110 | user_id = requester.user.to_string() | |
127 | # the device should have been registered already, but it may have been | |
128 | # deleted due to a race with a DELETE request. Or we may be using an | |
129 | # old access_token without an associated device_id. Either way, we | |
130 | # need to double-check the device is registered to avoid ending up with | |
131 | # keys without a corresponding device. | |
132 | self.device_handler.check_device_registered(user_id, device_id) | |
111 | 133 | |
112 | 134 | result = yield self.store.count_e2e_one_time_keys(user_id, device_id) |
113 | 135 | defer.returnValue((200, {"one_time_key_counts": result})) |
161 | 183 | ) |
162 | 184 | |
163 | 185 | def __init__(self, hs): |
186 | """ | |
187 | Args: | |
188 | hs (synapse.server.HomeServer): | |
189 | """ | |
164 | 190 | super(KeyQueryServlet, self).__init__() |
165 | self.store = hs.get_datastore() | |
166 | 191 | self.auth = hs.get_auth() |
167 | self.federation = hs.get_replication_layer() | |
168 | self.is_mine = hs.is_mine | |
192 | self.e2e_keys_handler = hs.get_e2e_keys_handler() | |
169 | 193 | |
170 | 194 | @defer.inlineCallbacks |
171 | 195 | def on_POST(self, request, user_id, device_id): |
172 | 196 | yield self.auth.get_user_by_req(request) |
173 | 197 | body = parse_json_object_from_request(request) |
174 | result = yield self.handle_request(body) | |
198 | result = yield self.e2e_keys_handler.query_devices(body) | |
175 | 199 | defer.returnValue(result) |
176 | 200 | |
177 | 201 | @defer.inlineCallbacks |
180 | 204 | auth_user_id = requester.user.to_string() |
181 | 205 | user_id = user_id if user_id else auth_user_id |
182 | 206 | device_ids = [device_id] if device_id else [] |
183 | result = yield self.handle_request( | |
207 | result = yield self.e2e_keys_handler.query_devices( | |
184 | 208 | {"device_keys": {user_id: device_ids}} |
185 | 209 | ) |
186 | 210 | defer.returnValue(result) |
187 | ||
188 | @defer.inlineCallbacks | |
189 | def handle_request(self, body): | |
190 | local_query = [] | |
191 | remote_queries = {} | |
192 | for user_id, device_ids in body.get("device_keys", {}).items(): | |
193 | user = UserID.from_string(user_id) | |
194 | if self.is_mine(user): | |
195 | if not device_ids: | |
196 | local_query.append((user_id, None)) | |
197 | else: | |
198 | for device_id in device_ids: | |
199 | local_query.append((user_id, device_id)) | |
200 | else: | |
201 | remote_queries.setdefault(user.domain, {})[user_id] = list( | |
202 | device_ids | |
203 | ) | |
204 | results = yield self.store.get_e2e_device_keys(local_query) | |
205 | ||
206 | json_result = {} | |
207 | for user_id, device_keys in results.items(): | |
208 | for device_id, json_bytes in device_keys.items(): | |
209 | json_result.setdefault(user_id, {})[device_id] = json.loads( | |
210 | json_bytes | |
211 | ) | |
212 | ||
213 | for destination, device_keys in remote_queries.items(): | |
214 | remote_result = yield self.federation.query_client_keys( | |
215 | destination, {"device_keys": device_keys} | |
216 | ) | |
217 | for user_id, keys in remote_result["device_keys"].items(): | |
218 | if user_id in device_keys: | |
219 | json_result[user_id] = keys | |
220 | defer.returnValue((200, {"device_keys": json_result})) | |
221 | 211 | |
222 | 212 | |
223 | 213 | class OneTimeKeyServlet(RestServlet): |
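The device_id resolution in on_POST is the subtle part: the path parameter is
deprecated but still honoured, falling back to the device bound to the access
token. Distilled into a standalone sketch (not the servlet itself):

    def resolve_device_id(path_device_id, token_device_id):
        if path_device_id is not None:
            # deprecated: trust the path, even if it disagrees with the
            # token's device (the servlet logs a warning in that case)
            return path_device_id
        if token_device_id is None:
            raise ValueError(
                "To upload keys, you must pass device_id when authenticating"
            )
        return token_device_id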
40 | 40 | logger = logging.getLogger(__name__) |
41 | 41 | |
42 | 42 | |
43 | class RegisterRequestTokenRestServlet(RestServlet): | |
44 | PATTERNS = client_v2_patterns("/register/email/requestToken$") | |
45 | ||
46 | def __init__(self, hs): | |
47 | """ | |
48 | Args: | |
49 | hs (synapse.server.HomeServer): server | |
50 | """ | |
51 | super(RegisterRequestTokenRestServlet, self).__init__() | |
52 | self.hs = hs | |
53 | self.identity_handler = hs.get_handlers().identity_handler | |
54 | ||
55 | @defer.inlineCallbacks | |
56 | def on_POST(self, request): | |
57 | body = parse_json_object_from_request(request) | |
58 | ||
59 | required = ['id_server', 'client_secret', 'email', 'send_attempt'] | |
60 | absent = [] | |
61 | for k in required: | |
62 | if k not in body: | |
63 | absent.append(k) | |
64 | ||
65 | if len(absent) > 0: | |
66 | raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) | |
67 | ||
68 | existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( | |
69 | 'email', body['email'] | |
70 | ) | |
71 | ||
72 | if existingUid is not None: | |
73 | raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) | |
74 | ||
75 | ret = yield self.identity_handler.requestEmailToken(**body) | |
76 | defer.returnValue((200, ret)) | |
77 | ||
78 | ||
43 | 79 | class RegisterRestServlet(RestServlet): |
44 | PATTERNS = client_v2_patterns("/register") | |
80 | PATTERNS = client_v2_patterns("/register$") | |
45 | 81 | |
46 | 82 | def __init__(self, hs): |
83 | """ | |
84 | Args: | |
85 | hs (synapse.server.HomeServer): server | |
86 | """ | |
47 | 87 | super(RegisterRestServlet, self).__init__() |
88 | ||
48 | 89 | self.hs = hs |
49 | 90 | self.auth = hs.get_auth() |
50 | 91 | self.store = hs.get_datastore() |
51 | 92 | self.auth_handler = hs.get_auth_handler() |
52 | 93 | self.registration_handler = hs.get_handlers().registration_handler |
53 | 94 | self.identity_handler = hs.get_handlers().identity_handler |
95 | self.device_handler = hs.get_device_handler() | |
54 | 96 | |
55 | 97 | @defer.inlineCallbacks |
56 | 98 | def on_POST(self, request): |
68 | 110 | raise UnrecognizedRequestError( |
69 | 111 | "Do not understand membership kind: %s" % (kind,) |
70 | 112 | ) |
71 | ||
72 | if '/register/email/requestToken' in request.path: | |
73 | ret = yield self.onEmailTokenRequest(request) | |
74 | defer.returnValue(ret) | |
75 | 113 | |
76 | 114 | body = parse_json_object_from_request(request) |
77 | 115 | |
103 | 141 | # Set the desired user according to the AS API (which uses the |
104 | 142 | # 'user' key not 'username'). Since this is a new addition, we'll |
105 | 143 | # fallback to 'username' if they gave one. |
106 | if isinstance(body.get("user"), basestring): | |
107 | desired_username = body["user"] | |
108 | result = yield self._do_appservice_registration( | |
109 | desired_username, request.args["access_token"][0] | |
110 | ) | |
144 | desired_username = body.get("user", desired_username) | |
145 | ||
146 | if isinstance(desired_username, basestring): | |
147 | result = yield self._do_appservice_registration( | |
148 | desired_username, request.args["access_token"][0], body | |
149 | ) | |
111 | 150 | defer.returnValue((200, result)) # we throw for non 200 responses |
112 | 151 | return |
113 | 152 | |
116 | 155 | # FIXME: Should we really be determining if this is shared secret |
117 | 156 | # auth based purely on the 'mac' key? |
118 | 157 | result = yield self._do_shared_secret_registration( |
119 | desired_username, desired_password, body["mac"] | |
158 | desired_username, desired_password, body | |
120 | 159 | ) |
121 | 160 | defer.returnValue((200, result)) # we throw for non 200 responses |
122 | 161 | return |
156 | 195 | [LoginType.EMAIL_IDENTITY] |
157 | 196 | ] |
158 | 197 | |
159 | authed, result, params, session_id = yield self.auth_handler.check_auth( | |
198 | authed, auth_result, params, session_id = yield self.auth_handler.check_auth( | |
160 | 199 | flows, body, self.hs.get_ip_from_request(request) |
161 | 200 | ) |
162 | 201 | |
163 | 202 | if not authed: |
164 | defer.returnValue((401, result)) | |
203 | defer.returnValue((401, auth_result)) | |
165 | 204 | return |
166 | 205 | |
167 | 206 | if registered_user_id is not None: |
169 | 208 | "Already registered user ID %r for this session", |
170 | 209 | registered_user_id |
171 | 210 | ) |
172 | access_token = yield self.auth_handler.issue_access_token(registered_user_id) | |
173 | refresh_token = yield self.auth_handler.issue_refresh_token( | |
174 | registered_user_id | |
175 | ) | |
176 | defer.returnValue((200, { | |
177 | "user_id": registered_user_id, | |
178 | "access_token": access_token, | |
179 | "home_server": self.hs.hostname, | |
180 | "refresh_token": refresh_token, | |
181 | })) | |
182 | ||
183 | # NB: This may be from the auth handler and NOT from the POST | |
184 | if 'password' not in params: | |
185 | raise SynapseError(400, "Missing password.", Codes.MISSING_PARAM) | |
186 | ||
187 | desired_username = params.get("username", None) | |
188 | new_password = params.get("password", None) | |
189 | guest_access_token = params.get("guest_access_token", None) | |
190 | ||
191 | (user_id, token) = yield self.registration_handler.register( | |
192 | localpart=desired_username, | |
193 | password=new_password, | |
194 | guest_access_token=guest_access_token, | |
195 | ) | |
196 | ||
197 | # remember that we've now registered that user account, and with what | |
198 | # user ID (since the user may not have specified) | |
199 | self.auth_handler.set_session_data( | |
200 | session_id, "registered_user_id", user_id | |
201 | ) | |
202 | ||
203 | if result and LoginType.EMAIL_IDENTITY in result: | |
204 | threepid = result[LoginType.EMAIL_IDENTITY] | |
205 | ||
206 | for reqd in ['medium', 'address', 'validated_at']: | |
207 | if reqd not in threepid: | |
208 | logger.info("Can't add incomplete 3pid") | |
209 | else: | |
210 | yield self.auth_handler.add_threepid( | |
211 | user_id, | |
212 | threepid['medium'], | |
213 | threepid['address'], | |
214 | threepid['validated_at'], | |
215 | ) | |
216 | ||
217 | # And we add an email pusher for them by default, but only | |
218 | # if email notifications are enabled (so people don't start | |
219 | # getting mail spam where they weren't before if email | |
220 | # notifs are set up on a home server) | |
221 | if ( | |
222 | self.hs.config.email_enable_notifs and | |
223 | self.hs.config.email_notif_for_new_users | |
224 | ): | |
225 | # Pull the ID of the access token back out of the db | |
226 | # It would really make more sense for this to be passed | |
227 | # up when the access token is saved, but that's quite an | |
228 | # invasive change I'd rather do separately. | |
229 | user_tuple = yield self.store.get_user_by_access_token( | |
230 | token | |
231 | ) | |
232 | ||
233 | yield self.hs.get_pusherpool().add_pusher( | |
234 | user_id=user_id, | |
235 | access_token=user_tuple["token_id"], | |
236 | kind="email", | |
237 | app_id="m.email", | |
238 | app_display_name="Email Notifications", | |
239 | device_display_name=threepid["address"], | |
240 | pushkey=threepid["address"], | |
241 | lang=None, # We don't know a user's language here | |
242 | data={}, | |
243 | ) | |
244 | ||
245 | if 'bind_email' in params and params['bind_email']: | |
246 | logger.info("bind_email specified: binding") | |
247 | ||
248 | emailThreepid = result[LoginType.EMAIL_IDENTITY] | |
249 | threepid_creds = emailThreepid['threepid_creds'] | |
250 | logger.debug("Binding emails %s to %s" % ( | |
251 | emailThreepid, user_id | |
252 | )) | |
253 | yield self.identity_handler.bind_threepid(threepid_creds, user_id) | |
254 | else: | |
255 | logger.info("bind_email not specified: not binding email") | |
256 | ||
257 | result = yield self._create_registration_details(user_id, token) | |
258 | defer.returnValue((200, result)) | |
211 | # don't re-register the email address | |
212 | add_email = False | |
213 | else: | |
214 | # NB: This may be from the auth handler and NOT from the POST | |
215 | if 'password' not in params: | |
216 | raise SynapseError(400, "Missing password.", | |
217 | Codes.MISSING_PARAM) | |
218 | ||
219 | desired_username = params.get("username", None) | |
220 | new_password = params.get("password", None) | |
221 | guest_access_token = params.get("guest_access_token", None) | |
222 | ||
223 | (registered_user_id, _) = yield self.registration_handler.register( | |
224 | localpart=desired_username, | |
225 | password=new_password, | |
226 | guest_access_token=guest_access_token, | |
227 | generate_token=False, | |
228 | ) | |
229 | ||
230 | # remember that we've now registered that user account, and with | |
231 | # what user ID (since the user may not have specified) | |
232 | self.auth_handler.set_session_data( | |
233 | session_id, "registered_user_id", registered_user_id | |
234 | ) | |
235 | ||
236 | add_email = True | |
237 | ||
238 | return_dict = yield self._create_registration_details( | |
239 | registered_user_id, params | |
240 | ) | |
241 | ||
242 | if add_email and auth_result and LoginType.EMAIL_IDENTITY in auth_result: | |
243 | threepid = auth_result[LoginType.EMAIL_IDENTITY] | |
244 | yield self._register_email_threepid( | |
245 | registered_user_id, threepid, return_dict["access_token"], | |
246 | params.get("bind_email") | |
247 | ) | |
248 | ||
249 | defer.returnValue((200, return_dict)) | |
259 | 250 | |
260 | 251 | def on_OPTIONS(self, _): |
261 | 252 | return 200, {} |
262 | 253 | |
263 | 254 | @defer.inlineCallbacks |
264 | def _do_appservice_registration(self, username, as_token): | |
265 | (user_id, token) = yield self.registration_handler.appservice_register( | |
255 | def _do_appservice_registration(self, username, as_token, body): | |
256 | user_id = yield self.registration_handler.appservice_register( | |
266 | 257 | username, as_token |
267 | 258 | ) |
268 | defer.returnValue((yield self._create_registration_details(user_id, token))) | |
269 | ||
270 | @defer.inlineCallbacks | |
271 | def _do_shared_secret_registration(self, username, password, mac): | |
259 | defer.returnValue((yield self._create_registration_details(user_id, body))) | |
260 | ||
261 | @defer.inlineCallbacks | |
262 | def _do_shared_secret_registration(self, username, password, body): | |
272 | 263 | if not self.hs.config.registration_shared_secret: |
273 | 264 | raise SynapseError(400, "Shared secret registration is not enabled") |
274 | 265 | |
276 | 267 | |
277 | 268 | # str() because otherwise hmac complains that 'unicode' does not |
278 | 269 | # have the buffer interface |
279 | got_mac = str(mac) | |
270 | got_mac = str(body["mac"]) | |
280 | 271 | |
281 | 272 | want_mac = hmac.new( |
282 | 273 | key=self.hs.config.registration_shared_secret, |
289 | 280 | 403, "HMAC incorrect", |
290 | 281 | ) |
291 | 282 | |
292 | (user_id, token) = yield self.registration_handler.register( | |
293 | localpart=username, password=password | |
294 | ) | |
295 | defer.returnValue((yield self._create_registration_details(user_id, token))) | |
296 | ||
297 | @defer.inlineCallbacks | |
298 | def _create_registration_details(self, user_id, token): | |
299 | refresh_token = yield self.auth_handler.issue_refresh_token(user_id) | |
283 | (user_id, _) = yield self.registration_handler.register( | |
284 | localpart=username, password=password, generate_token=False, | |
285 | ) | |
286 | ||
287 | result = yield self._create_registration_details(user_id, body) | |
288 | defer.returnValue(result) | |
289 | ||
290 | @defer.inlineCallbacks | |
291 | def _register_email_threepid(self, user_id, threepid, token, bind_email): | |
292 | """Add an email address as a 3pid identifier | |
293 | ||
294 | Also adds an email pusher for the email address, if configured in the | |
295 | HS config | |
296 | ||
297 | Also optionally binds emails to the given user_id on the identity server | |
298 | ||
299 | Args: | |
300 | user_id (str): id of user | |
301 | threepid (object): m.login.email.identity auth response | |
302 | token (str): access_token for the user | |
303 | bind_email (bool): true if the client requested the email to be | |
304 | bound at the identity server | |
305 | Returns: | |
306 | defer.Deferred: | |
307 | """ | |
308 | reqd = ('medium', 'address', 'validated_at') | |
309 | if any(x not in threepid for x in reqd): | |
310 | logger.info("Can't add incomplete 3pid") | |
311 | defer.returnValue(None) | |
312 | ||
313 | yield self.auth_handler.add_threepid( | |
314 | user_id, | |
315 | threepid['medium'], | |
316 | threepid['address'], | |
317 | threepid['validated_at'], | |
318 | ) | |
319 | ||
320 | # And we add an email pusher for them by default, but only | |
321 | # if email notifications are enabled (so people don't start | |
322 | # getting mail spam where they weren't before if email | |
323 | # notifs are set up on a home server) | |
324 | if (self.hs.config.email_enable_notifs and | |
325 | self.hs.config.email_notif_for_new_users): | |
326 | # Pull the ID of the access token back out of the db | |
327 | # It would really make more sense for this to be passed | |
328 | # up when the access token is saved, but that's quite an | |
329 | # invasive change I'd rather do separately. | |
330 | user_tuple = yield self.store.get_user_by_access_token( | |
331 | token | |
332 | ) | |
333 | token_id = user_tuple["token_id"] | |
334 | ||
335 | yield self.hs.get_pusherpool().add_pusher( | |
336 | user_id=user_id, | |
337 | access_token=token_id, | |
338 | kind="email", | |
339 | app_id="m.email", | |
340 | app_display_name="Email Notifications", | |
341 | device_display_name=threepid["address"], | |
342 | pushkey=threepid["address"], | |
343 | lang=None, # We don't know a user's language here | |
344 | data={}, | |
345 | ) | |
346 | ||
347 | if bind_email: | |
348 | logger.info("bind_email specified: binding") | |
349 | logger.debug("Binding email %s to %s", | |
350 | threepid, user_id | |
351 | ) | |
352 | yield self.identity_handler.bind_threepid( | |
353 | threepid['threepid_creds'], user_id | |
354 | ) | |
355 | else: | |
356 | logger.info("bind_email not specified: not binding email") | |
357 | ||
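For orientation, the ``threepid`` object consumed above (the ``m.login.email.identity`` auth result) looks roughly like this; the values are illustrative:

    threepid = {
        "medium": "email",
        "address": "alice@example.com",
        "validated_at": 1470000000000,  # validation time in ms (assumed)
        # plus "threepid_creds": an opaque blob handed to bind_threepid
    }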
358 | @defer.inlineCallbacks | |
359 | def _create_registration_details(self, user_id, params): | |
360 | """Complete registration of newly-registered user | |
361 | ||
362 | Allocates device_id if one was not given; also creates access_token | |
363 | and refresh_token. | |
364 | ||
365 | Args: | |
366 | user_id (str): full canonical @user:id | |
367 | params (object): registration parameters, from which we pull | |
368 | device_id and initial_device_display_name | |
369 | Returns: | |
370 | defer.Deferred: (object) dictionary for response from /register | |
371 | """ | |
372 | device_id = yield self._register_device(user_id, params) | |
373 | ||
374 | access_token, refresh_token = ( | |
375 | yield self.auth_handler.get_login_tuple_for_user_id( | |
376 | user_id, device_id=device_id, | |
377 | initial_display_name=params.get("initial_device_display_name") | |
378 | ) | |
379 | ) | |
380 | ||
300 | 381 | defer.returnValue({ |
301 | 382 | "user_id": user_id, |
302 | "access_token": token, | |
383 | "access_token": access_token, | |
303 | 384 | "home_server": self.hs.hostname, |
304 | 385 | "refresh_token": refresh_token, |
386 | "device_id": device_id, | |
305 | 387 | }) |
306 | 388 | |
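Given the dict built above, a successful ``/register`` response now carries the newly allocated device as well as the tokens, roughly (token and device values illustrative):

    {
        "user_id": "@alice:example.com",
        "access_token": "<opaque access token>",
        "home_server": "example.com",
        "refresh_token": "<opaque refresh token>",
        "device_id": "GHTYAJCE",  # illustrative
    }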
307 | @defer.inlineCallbacks | |
308 | def onEmailTokenRequest(self, request): | |
309 | body = parse_json_object_from_request(request) | |
310 | ||
311 | required = ['id_server', 'client_secret', 'email', 'send_attempt'] | |
312 | absent = [] | |
313 | for k in required: | |
314 | if k not in body: | |
315 | absent.append(k) | |
316 | ||
317 | if len(absent) > 0: | |
318 | raise SynapseError(400, "Missing params: %r" % absent, Codes.MISSING_PARAM) | |
319 | ||
320 | existingUid = yield self.hs.get_datastore().get_user_id_by_threepid( | |
321 | 'email', body['email'] | |
322 | ) | |
323 | ||
324 | if existingUid is not None: | |
325 | raise SynapseError(400, "Email is already in use", Codes.THREEPID_IN_USE) | |
326 | ||
327 | ret = yield self.identity_handler.requestEmailToken(**body) | |
328 | defer.returnValue((200, ret)) | |
389 | def _register_device(self, user_id, params): | |
390 | """Register a device for a user. | |
391 | ||
392 | This is called after the user's credentials have been validated, but | |
393 | before the access token has been issued. | |
394 | ||
395 | Args: | |
396 | user_id (str): full canonical @user:id | |
397 | params (object): registration parameters, from which we pull | |
398 | device_id and initial_device_display_name | |
399 | Returns: | |
400 | defer.Deferred: (str) device_id | |
401 | """ | |
402 | # register the user's device | |
403 | device_id = params.get("device_id") | |
404 | initial_display_name = params.get("initial_device_display_name") | |
405 | device_id = self.device_handler.check_device_registered( | |
406 | user_id, device_id, initial_display_name | |
407 | ) | |
408 | return device_id | |
329 | 409 | |
330 | 410 | @defer.inlineCallbacks |
331 | 411 | def _do_guest_registration(self): |
335 | 415 | generate_token=False, |
336 | 416 | make_guest=True |
337 | 417 | ) |
338 | access_token = self.auth_handler.generate_access_token(user_id, ["guest = true"]) | |
418 | access_token = self.auth_handler.generate_access_token( | |
419 | user_id, ["guest = true"] | |
420 | ) | |
421 | # XXX the "guest" caveat is not copied by /tokenrefresh. That's ok | |
422 | # so long as we don't return a refresh_token here. | |
339 | 423 | defer.returnValue((200, { |
340 | 424 | "user_id": user_id, |
341 | 425 | "access_token": access_token, |
344 | 428 | |
345 | 429 | |
346 | 430 | def register_servlets(hs, http_server): |
431 | RegisterRequestTokenRestServlet(hs).register(http_server) | |
347 | 432 | RegisterRestServlet(hs).register(http_server) |
38 | 38 | try: |
39 | 39 | old_refresh_token = body["refresh_token"] |
40 | 40 | auth_handler = self.hs.get_auth_handler() |
41 | (user_id, new_refresh_token) = yield self.store.exchange_refresh_token( | |
42 | old_refresh_token, auth_handler.generate_refresh_token) | |
43 | new_access_token = yield auth_handler.issue_access_token(user_id) | |
41 | refresh_result = yield self.store.exchange_refresh_token( | |
42 | old_refresh_token, auth_handler.generate_refresh_token | |
43 | ) | |
44 | (user_id, new_refresh_token, device_id) = refresh_result | |
45 | new_access_token = yield auth_handler.issue_access_token( | |
46 | user_id, device_id | |
47 | ) | |
44 | 48 | defer.returnValue((200, { |
45 | 49 | "access_token": new_access_token, |
46 | 50 | "refresh_token": new_refresh_token, |
25 | 25 | |
26 | 26 | def on_GET(self, request): |
27 | 27 | return (200, { |
28 | "versions": ["r0.0.1"] | |
28 | "versions": [ | |
29 | "r0.0.1", | |
30 | "r0.1.0", | |
31 | "r0.2.0", | |
32 | ] | |
29 | 33 | }) |
30 | 34 | |
31 | 35 |
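Per the handler above, ``GET /_matrix/client/versions`` now advertises all three spec revisions:

    {
        "versions": [
            "r0.0.1",
            "r0.1.0",
            "r0.2.0",
        ]
    }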
14 | 14 | |
15 | 15 | from synapse.http.server import respond_with_json_bytes, finish_request |
16 | 16 | |
17 | from synapse.util.stringutils import random_string | |
18 | 17 | from synapse.api.errors import ( |
19 | cs_exception, SynapseError, CodeMessageException, Codes, cs_error | |
18 | Codes, cs_error | |
20 | 19 | ) |
21 | 20 | |
22 | 21 | from twisted.protocols.basic import FileSender |
23 | 22 | from twisted.web import server, resource |
24 | from twisted.internet import defer | |
25 | 23 | |
26 | 24 | import base64 |
27 | 25 | import simplejson as json |
49 | 47 | """ |
50 | 48 | isLeaf = True |
51 | 49 | |
52 | def __init__(self, hs, directory, auth, external_addr): | |
50 | def __init__(self, hs, directory): | |
53 | 51 | resource.Resource.__init__(self) |
54 | 52 | self.hs = hs |
55 | 53 | self.directory = directory |
56 | self.auth = auth | |
57 | self.external_addr = external_addr.rstrip('/') | |
58 | self.max_upload_size = hs.config.max_upload_size | |
59 | ||
60 | if not os.path.isdir(self.directory): | |
61 | os.mkdir(self.directory) | |
62 | logger.info("ContentRepoResource : Created %s directory.", | |
63 | self.directory) | |
64 | ||
65 | @defer.inlineCallbacks | |
66 | def map_request_to_name(self, request): | |
67 | # auth the user | |
68 | requester = yield self.auth.get_user_by_req(request) | |
69 | ||
70 | # namespace all file uploads on the user | |
71 | prefix = base64.urlsafe_b64encode( | |
72 | requester.user.to_string() | |
73 | ).replace('=', '') | |
74 | ||
75 | # use a random string for the main portion | |
76 | main_part = random_string(24) | |
77 | ||
78 | # suffix with a file extension if we can make one. This is nice to | |
79 | # provide a hint to clients on the file information. We will also reuse | |
80 | # this info to spit back the content type to the client. | |
81 | suffix = "" | |
82 | if request.requestHeaders.hasHeader("Content-Type"): | |
83 | content_type = request.requestHeaders.getRawHeaders( | |
84 | "Content-Type")[0] | |
85 | suffix = "." + base64.urlsafe_b64encode(content_type) | |
86 | if (content_type.split("/")[0].lower() in | |
87 | ["image", "video", "audio"]): | |
88 | file_ext = content_type.split("/")[-1] | |
89 | # be a little paranoid and only allow a-z | |
90 | file_ext = re.sub("[^a-z]", "", file_ext) | |
91 | suffix += "." + file_ext | |
92 | ||
93 | file_name = prefix + main_part + suffix | |
94 | file_path = os.path.join(self.directory, file_name) | |
95 | logger.info("User %s is uploading a file to path %s", | |
96 | request.user.user_id.to_string(), | |
97 | file_path) | |
98 | ||
99 | # keep trying to make a non-clashing file, with a sensible max attempts | |
100 | attempts = 0 | |
101 | while os.path.exists(file_path): | |
102 | main_part = random_string(24) | |
103 | file_name = prefix + main_part + suffix | |
104 | file_path = os.path.join(self.directory, file_name) | |
105 | attempts += 1 | |
106 | if attempts > 25: # really? Really? | |
107 | raise SynapseError(500, "Unable to create file.") | |
108 | ||
109 | defer.returnValue(file_path) | |
110 | 54 | |
111 | 55 | def render_GET(self, request): |
112 | 56 | # no auth here on purpose, to allow anyone to view, even across home |
154 | 98 | |
155 | 99 | return server.NOT_DONE_YET |
156 | 100 | |
157 | def render_POST(self, request): | |
158 | self._async_render(request) | |
159 | return server.NOT_DONE_YET | |
160 | ||
161 | 101 | def render_OPTIONS(self, request): |
162 | 102 | respond_with_json_bytes(request, 200, {}, send_cors=True) |
163 | 103 | return server.NOT_DONE_YET |
164 | ||
165 | @defer.inlineCallbacks | |
166 | def _async_render(self, request): | |
167 | try: | |
168 | # TODO: The checks here are a bit late. The content will have | |
169 | # already been uploaded to a tmp file at this point | |
170 | content_length = request.getHeader("Content-Length") | |
171 | if content_length is None: | |
172 | raise SynapseError( | |
173 | msg="Request must specify a Content-Length", code=400 | |
174 | ) | |
175 | if int(content_length) > self.max_upload_size: | |
176 | raise SynapseError( | |
177 | msg="Upload request body is too large", | |
178 | code=413, | |
179 | ) | |
180 | ||
181 | fname = yield self.map_request_to_name(request) | |
182 | ||
183 | # TODO I have a suspicious feeling this is just going to block | |
184 | with open(fname, "wb") as f: | |
185 | f.write(request.content.read()) | |
186 | ||
187 | # FIXME (erikj): These should use constants. | |
188 | file_name = os.path.basename(fname) | |
189 | # FIXME: we can't assume what the repo's public mounted path is | |
190 | # ...plus self-signed SSL won't work to remote clients anyway | |
191 | # ...and we can't assume that it's SSL anyway, as we might want to | |
192 | # serve it via the non-SSL listener... | |
193 | url = "%s/_matrix/content/%s" % ( | |
194 | self.external_addr, file_name | |
195 | ) | |
196 | ||
197 | respond_with_json_bytes(request, 200, | |
198 | json.dumps({"content_token": url}), | |
199 | send_cors=True) | |
200 | ||
201 | except CodeMessageException as e: | |
202 | logger.exception(e) | |
203 | respond_with_json_bytes(request, e.code, | |
204 | json.dumps(cs_exception(e))) | |
205 | except Exception as e: | |
206 | logger.error("Failed to store file: %s" % e) | |
207 | respond_with_json_bytes( | |
208 | request, | |
209 | 500, | |
210 | json.dumps({"error": "Internal server error"}), | |
211 | send_cors=True) |
64 | 64 | file_id[0:2], file_id[2:4], file_id[4:], |
65 | 65 | file_name |
66 | 66 | ) |
67 | ||
68 | def remote_media_thumbnail_dir(self, server_name, file_id): | |
69 | return os.path.join( | |
70 | self.base_path, "remote_thumbnail", server_name, | |
71 | file_id[0:2], file_id[2:4], file_id[4:], | |
72 | ) |
29 | 29 | |
30 | 30 | from twisted.internet import defer, threads |
31 | 31 | |
32 | from synapse.util.async import ObservableDeferred | |
32 | from synapse.util.async import Linearizer | |
33 | 33 | from synapse.util.stringutils import is_ascii |
34 | 34 | from synapse.util.logcontext import preserve_context_over_fn |
35 | 35 | |
36 | 36 | import os |
37 | import errno | |
38 | import shutil | |
37 | 39 | |
38 | 40 | import cgi |
39 | 41 | import logging |
42 | 44 | logger = logging.getLogger(__name__) |
43 | 45 | |
44 | 46 | |
47 | UPDATE_RECENTLY_ACCESSED_REMOTES_TS = 60 * 1000 | |
48 | ||
49 | ||
45 | 50 | class MediaRepository(object): |
46 | def __init__(self, hs, filepaths): | |
51 | def __init__(self, hs): | |
47 | 52 | self.auth = hs.get_auth() |
48 | 53 | self.client = MatrixFederationHttpClient(hs) |
49 | 54 | self.clock = hs.get_clock() |
51 | 56 | self.store = hs.get_datastore() |
52 | 57 | self.max_upload_size = hs.config.max_upload_size |
53 | 58 | self.max_image_pixels = hs.config.max_image_pixels |
54 | self.filepaths = filepaths | |
55 | self.downloads = {} | |
59 | self.filepaths = MediaFilePaths(hs.config.media_store_path) | |
56 | 60 | self.dynamic_thumbnails = hs.config.dynamic_thumbnails |
57 | 61 | self.thumbnail_requirements = hs.config.thumbnail_requirements |
62 | ||
63 | self.remote_media_linearizer = Linearizer() | |
64 | ||
65 | self.recently_accessed_remotes = set() | |
66 | ||
67 | self.clock.looping_call( | |
68 | self._update_recently_accessed_remotes, | |
69 | UPDATE_RECENTLY_ACCESSED_REMOTES_TS | |
70 | ) | |
71 | ||
72 | @defer.inlineCallbacks | |
73 | def _update_recently_accessed_remotes(self): | |
74 | media = self.recently_accessed_remotes | |
75 | self.recently_accessed_remotes = set() | |
76 | ||
77 | yield self.store.update_cached_last_access_time( | |
78 | media, self.clock.time_msec() | |
79 | ) | |
58 | 80 | |
59 | 81 | @staticmethod |
60 | 82 | def _makedirs(filepath): |
92 | 114 | |
93 | 115 | defer.returnValue("mxc://%s/%s" % (self.server_name, media_id)) |
94 | 116 | |
117 | @defer.inlineCallbacks | |
95 | 118 | def get_remote_media(self, server_name, media_id): |
96 | 119 | key = (server_name, media_id) |
97 | download = self.downloads.get(key) | |
98 | if download is None: | |
99 | download = self._get_remote_media_impl(server_name, media_id) | |
100 | download = ObservableDeferred( | |
101 | download, | |
102 | consumeErrors=True | |
103 | ) | |
104 | self.downloads[key] = download | |
105 | ||
106 | @download.addBoth | |
107 | def callback(media_info): | |
108 | del self.downloads[key] | |
109 | return media_info | |
110 | return download.observe() | |
120 | with (yield self.remote_media_linearizer.queue(key)): | |
121 | media_info = yield self._get_remote_media_impl(server_name, media_id) | |
122 | defer.returnValue(media_info) | |
111 | 123 | |
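The switch from ``ObservableDeferred`` deduplication to a ``Linearizer`` means concurrent requests for the same remote media are now serialized per ``(server_name, media_id)`` key rather than shared, which also lets ``delete_old_remote_media`` below take the same lock. A standalone sketch of the pattern (function names assumed):

    from twisted.internet import defer
    from synapse.util.async import Linearizer

    linearizer = Linearizer()

    @defer.inlineCallbacks
    def guarded(key, work):
        # Callers presenting the same key run one at a time, in arrival
        # order; callers with different keys proceed concurrently.
        with (yield linearizer.queue(key)):
            result = yield work()
        defer.returnValue(result)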
112 | 124 | @defer.inlineCallbacks |
113 | 125 | def _get_remote_media_impl(self, server_name, media_id): |
117 | 129 | if not media_info: |
118 | 130 | media_info = yield self._download_remote_file( |
119 | 131 | server_name, media_id |
132 | ) | |
133 | else: | |
134 | self.recently_accessed_remotes.add((server_name, media_id)) | |
135 | yield self.store.update_cached_last_access_time( | |
136 | [(server_name, media_id)], self.clock.time_msec() | |
120 | 137 | ) |
121 | 138 | defer.returnValue(media_info) |
122 | 139 | |
415 | 432 | "height": m_height, |
416 | 433 | }) |
417 | 434 | |
435 | @defer.inlineCallbacks | |
436 | def delete_old_remote_media(self, before_ts): | |
437 | old_media = yield self.store.get_remote_media_before(before_ts) | |
438 | ||
439 | deleted = 0 | |
440 | ||
441 | for media in old_media: | |
442 | origin = media["media_origin"] | |
443 | media_id = media["media_id"] | |
444 | file_id = media["filesystem_id"] | |
445 | key = (origin, media_id) | |
446 | ||
447 | logger.info("Deleting: %r", key) | |
448 | ||
449 | with (yield self.remote_media_linearizer.queue(key)): | |
450 | full_path = self.filepaths.remote_media_filepath(origin, file_id) | |
451 | try: | |
452 | os.remove(full_path) | |
453 | except OSError as e: | |
454 | logger.warn("Failed to remove file: %r", full_path) | |
455 | if e.errno == errno.ENOENT: | |
456 | pass | |
457 | else: | |
458 | continue | |
459 | ||
460 | thumbnail_dir = self.filepaths.remote_media_thumbnail_dir( | |
461 | origin, file_id | |
462 | ) | |
463 | shutil.rmtree(thumbnail_dir, ignore_errors=True) | |
464 | ||
465 | yield self.store.delete_remote_media(origin, media_id) | |
466 | deleted += 1 | |
467 | ||
468 | defer.returnValue({"deleted": deleted}) | |
469 | ||
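A hedged usage sketch for the new purge path; ``before_ts`` is assumed to be in milliseconds, matching the ``clock.time_msec()`` convention used elsewhere in this class:

    import time
    from twisted.internet import defer

    @defer.inlineCallbacks
    def purge_media_older_than_30_days(media_repo):
        before_ts = int(time.time() * 1000) - 30 * 24 * 3600 * 1000
        result = yield media_repo.delete_old_remote_media(before_ts)
        defer.returnValue(result)  # e.g. {"deleted": 12}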
418 | 470 | |
419 | 471 | class MediaRepositoryResource(Resource): |
420 | 472 | """File uploading and downloading. |
463 | 515 | |
464 | 516 | def __init__(self, hs): |
465 | 517 | Resource.__init__(self) |
466 | filepaths = MediaFilePaths(hs.config.media_store_path) | |
467 | ||
468 | media_repo = MediaRepository(hs, filepaths) | |
518 | ||
519 | media_repo = hs.get_media_repository() | |
469 | 520 | |
470 | 521 | self.putChild("upload", UploadResource(hs, media_repo)) |
471 | 522 | self.putChild("download", DownloadResource(hs, media_repo)) |
27 | 27 | ) |
28 | 28 | from synapse.util.async import ObservableDeferred |
29 | 29 | from synapse.util.stringutils import is_ascii |
30 | ||
31 | from copy import deepcopy | |
30 | 32 | |
31 | 33 | import os |
32 | 34 | import re |
328 | 330 | # ...or if they are within a <script/> or <style/> tag. |
329 | 331 | # This is a very very very coarse approximation to a plain text |
330 | 332 | # render of the page. |
331 | text_nodes = tree.xpath("//text()[not(ancestor::header | ancestor::nav | " | |
332 | "ancestor::aside | ancestor::footer | " | |
333 | "ancestor::script | ancestor::style)]" + | |
334 | "[ancestor::body]") | |
335 | text = '' | |
336 | for text_node in text_nodes: | |
337 | if len(text) < 500: | |
338 | text += text_node + ' ' | |
339 | else: | |
340 | break | |
341 | text = re.sub(r'[\t ]+', ' ', text) | |
342 | text = re.sub(r'[\t \r\n]*[\r\n]+', '\n', text) | |
343 | text = text.strip()[:500] | |
344 | og['og:description'] = text if text else None | |
333 | ||
334 | # We don't just use XPath here as that is slow on some machines. | |
335 | ||
336 | # We clone `tree` as we modify it. | |
337 | cloned_tree = deepcopy(tree.find("body")) | |
338 | ||
339 | TAGS_TO_REMOVE = ("header", "nav", "aside", "footer", "script", "style",) | |
340 | for el in cloned_tree.iter(TAGS_TO_REMOVE): | |
341 | el.getparent().remove(el) | |
342 | ||
343 | # Split all the text nodes into paragraphs (by splitting on new | |
344 | # lines) | |
345 | text_nodes = ( | |
346 | re.sub(r'\s+', '\n', el.text).strip() | |
347 | for el in cloned_tree.iter() | |
348 | if el.text and isinstance(el.tag, basestring)  # skips comment nodes, whose .tag is not a string | |
349 | ) | |
350 | og['og:description'] = summarize_paragraphs(text_nodes) | |
345 | 351 | |
346 | 352 | # TODO: delete the url downloads to stop diskfilling, |
347 | 353 | # as we only ever cared about its OG |
449 | 455 | content_type.startswith("application/xhtml") |
450 | 456 | ): |
451 | 457 | return True |
458 | ||
459 | ||
460 | def summarize_paragraphs(text_nodes, min_size=200, max_size=500): | |
461 | # Try to get a summary of between 200 and 500 characters, respecting | |
462 | # first paragraph and then word boundaries. | |
463 | # TODO: Respect sentences? | |
464 | ||
465 | description = '' | |
466 | ||
467 | # Keep adding paragraphs until we get to the MIN_SIZE. | |
468 | for text_node in text_nodes: | |
469 | if len(description) < min_size: | |
470 | text_node = re.sub(r'[\t \r\n]+', ' ', text_node) | |
471 | description += text_node + '\n\n' | |
472 | else: | |
473 | break | |
474 | ||
475 | description = description.strip() | |
476 | description = re.sub(r'[\t ]+', ' ', description) | |
477 | description = re.sub(r'[\t \r\n]*[\r\n]+', '\n\n', description) | |
478 | ||
479 | # If the concatenation of paragraphs to get above MIN_SIZE | |
480 | # took us over MAX_SIZE, then we need to truncate mid paragraph | |
481 | if len(description) > max_size: | |
482 | new_desc = "" | |
483 | ||
484 | # This splits the paragraph into words, but keeping the | |
485 | # (preceding) whitespace intact so we can easily concat | |
486 | # words back together. | |
487 | for match in re.finditer(r"\s*\S+", description): | |
488 | word = match.group() | |
489 | ||
490 | # Keep adding words while the total length is less than | |
491 | # MAX_SIZE. | |
492 | if len(word) + len(new_desc) < max_size: | |
493 | new_desc += word | |
494 | else: | |
495 | # At this point the next word *will* take us over | |
496 | # MAX_SIZE, but we also want to ensure that it's not | |
497 | # a huge word. If it is, add it anyway and we'll | |
498 | # truncate later. | |
499 | if len(new_desc) < min_size: | |
500 | new_desc += word | |
501 | break | |
502 | ||
503 | # Double check that we're not over the limit | |
504 | if len(new_desc) > max_size: | |
505 | new_desc = new_desc[:max_size] | |
506 | ||
507 | # We always add an ellipsis because at the very least | |
508 | # we chopped mid paragraph. | |
509 | description = new_desc.strip() + "…" | |
510 | return description if description else None |
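``summarize_paragraphs`` is a pure function, so its behaviour is easy to check in isolation; with the defaults, short inputs are concatenated with blank lines and anything past ``max_size`` is cut on a word boundary and suffixed with "…":

    paragraphs = [
        "Matrix is an open protocol for decentralised communication.",
        "Synapse is the reference homeserver implementation.",
    ]
    summary = summarize_paragraphs(paragraphs, min_size=200, max_size=500)
    # Both paragraphs together fall under max_size, so summary is simply
    # the two strings joined by a blank line, untruncated.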
18 | 18 | # partial one for unit test mocking. |
19 | 19 | |
20 | 20 | # Imports required for the default HomeServer() implementation |
21 | import logging | |
22 | ||
23 | from twisted.enterprise import adbapi | |
21 | 24 | from twisted.web.client import BrowserLikePolicyForHTTPS |
22 | from twisted.enterprise import adbapi | |
23 | ||
25 | ||
26 | from synapse.api.auth import Auth | |
27 | from synapse.api.filtering import Filtering | |
28 | from synapse.api.ratelimiting import Ratelimiter | |
29 | from synapse.appservice.api import ApplicationServiceApi | |
24 | 30 | from synapse.appservice.scheduler import ApplicationServiceScheduler |
25 | from synapse.appservice.api import ApplicationServiceApi | |
31 | from synapse.crypto.keyring import Keyring | |
32 | from synapse.events.builder import EventBuilderFactory | |
26 | 33 | from synapse.federation import initialize_http_replication |
27 | from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory | |
28 | from synapse.notifier import Notifier | |
29 | from synapse.api.auth import Auth | |
30 | 34 | from synapse.handlers import Handlers |
35 | from synapse.handlers.appservice import ApplicationServicesHandler | |
36 | from synapse.handlers.auth import AuthHandler | |
37 | from synapse.handlers.device import DeviceHandler | |
38 | from synapse.handlers.e2e_keys import E2eKeysHandler | |
31 | 39 | from synapse.handlers.presence import PresenceHandler |
40 | from synapse.handlers.room import RoomListHandler | |
32 | 41 | from synapse.handlers.sync import SyncHandler |
33 | 42 | from synapse.handlers.typing import TypingHandler |
34 | from synapse.handlers.room import RoomListHandler | |
35 | from synapse.handlers.auth import AuthHandler | |
36 | from synapse.handlers.appservice import ApplicationServicesHandler | |
43 | from synapse.http.client import SimpleHttpClient, InsecureInterceptableContextFactory | |
44 | from synapse.http.matrixfederationclient import MatrixFederationHttpClient | |
45 | from synapse.notifier import Notifier | |
46 | from synapse.push.pusherpool import PusherPool | |
47 | from synapse.rest.media.v1.media_repository import MediaRepository | |
37 | 48 | from synapse.state import StateHandler |
38 | 49 | from synapse.storage import DataStore |
50 | from synapse.streams.events import EventSources | |
39 | 51 | from synapse.util import Clock |
40 | 52 | from synapse.util.distributor import Distributor |
41 | from synapse.streams.events import EventSources | |
42 | from synapse.api.ratelimiting import Ratelimiter | |
43 | from synapse.crypto.keyring import Keyring | |
44 | from synapse.push.pusherpool import PusherPool | |
45 | from synapse.events.builder import EventBuilderFactory | |
46 | from synapse.api.filtering import Filtering | |
47 | ||
48 | from synapse.http.matrixfederationclient import MatrixFederationHttpClient | |
49 | ||
50 | import logging | |
51 | ||
52 | 53 | |
53 | 54 | logger = logging.getLogger(__name__) |
54 | 55 | |
90 | 91 | 'typing_handler', |
91 | 92 | 'room_list_handler', |
92 | 93 | 'auth_handler', |
94 | 'device_handler', | |
95 | 'e2e_keys_handler', | |
93 | 96 | 'application_service_api', |
94 | 97 | 'application_service_scheduler', |
95 | 98 | 'application_service_handler', |
112 | 115 | 'filtering', |
113 | 116 | 'http_client_context_factory', |
114 | 117 | 'simple_http_client', |
118 | 'media_repository', | |
115 | 119 | ] |
116 | 120 | |
117 | 121 | def __init__(self, hostname, **kwargs): |
194 | 198 | def build_auth_handler(self): |
195 | 199 | return AuthHandler(self) |
196 | 200 | |
201 | def build_device_handler(self): | |
202 | return DeviceHandler(self) | |
203 | ||
204 | def build_e2e_keys_handler(self): | |
205 | return E2eKeysHandler(self) | |
206 | ||
197 | 207 | def build_application_service_api(self): |
198 | 208 | return ApplicationServiceApi(self) |
199 | 209 | |
231 | 241 | name, |
232 | 242 | **self.db_config.get("args", {}) |
233 | 243 | ) |
244 | ||
245 | def build_media_repository(self): | |
246 | return MediaRepository(self) | |
234 | 247 | |
235 | 248 | def remove_pusher(self, app_id, push_key, user_id): |
236 | 249 | return self.get_pusherpool().remove_pusher(app_id, push_key, user_id) |
0 | import synapse.handlers | |
1 | import synapse.handlers.auth | |
2 | import synapse.handlers.device | |
3 | import synapse.handlers.e2e_keys | |
4 | import synapse.storage | |
5 | import synapse.state | |
6 | ||
7 | class HomeServer(object): | |
8 | def get_auth_handler(self) -> synapse.handlers.auth.AuthHandler: | |
9 | pass | |
10 | ||
11 | def get_datastore(self) -> synapse.storage.DataStore: | |
12 | pass | |
13 | ||
14 | def get_device_handler(self) -> synapse.handlers.device.DeviceHandler: | |
15 | pass | |
16 | ||
17 | def get_e2e_keys_handler(self) -> synapse.handlers.e2e_keys.E2eKeysHandler: | |
18 | pass | |
19 | ||
20 | def get_handlers(self) -> synapse.handlers.Handlers: | |
21 | pass | |
22 | ||
23 | def get_state_handler(self) -> synapse.state.StateHandler: | |
24 | pass |
378 | 378 | try: |
379 | 379 | # FIXME: hs.get_auth() is bad style, but we need to do it to |
380 | 380 | # get around circular deps. |
381 | self.hs.get_auth().check(event, auth_events) | |
381 | # The signatures have already been checked at this point | |
382 | self.hs.get_auth().check(event, auth_events, do_sig_check=False) | |
382 | 383 | prev_event = event |
383 | 384 | except AuthError: |
384 | 385 | return prev_event |
390 | 391 | try: |
391 | 392 | # FIXME: hs.get_auth() is bad style, but we need to do it to |
392 | 393 | # get around circular deps. |
393 | self.hs.get_auth().check(event, auth_events) | |
394 | # The signatures have already been checked at this point | |
395 | self.hs.get_auth().check(event, auth_events, do_sig_check=False) | |
394 | 396 | return event |
395 | 397 | except AuthError: |
396 | 398 | pass |
13 | 13 | # limitations under the License. |
14 | 14 | |
15 | 15 | from twisted.internet import defer |
16 | ||
17 | from synapse.storage.devices import DeviceStore | |
16 | 18 | from .appservice import ( |
17 | 19 | ApplicationServiceStore, ApplicationServiceTransactionStore |
18 | 20 | ) |
79 | 81 | EventPushActionsStore, |
80 | 82 | OpenIdStore, |
81 | 83 | ClientIpStore, |
84 | DeviceStore, | |
82 | 85 | ): |
83 | 86 | |
84 | 87 | def __init__(self, db_conn, hs): |
91 | 94 | extra_tables=[("local_invites", "stream_id")] |
92 | 95 | ) |
93 | 96 | self._backfill_id_gen = StreamIdGenerator( |
94 | db_conn, "events", "stream_ordering", step=-1 | |
97 | db_conn, "events", "stream_ordering", step=-1, | |
98 | extra_tables=[("ex_outlier_stream", "event_stream_ordering")] | |
95 | 99 | ) |
96 | 100 | self._receipts_id_gen = StreamIdGenerator( |
97 | 101 | db_conn, "receipts_linearized", "stream_id" |
596 | 596 | more rows, returning the result as a list of dicts. |
597 | 597 | |
598 | 598 | Args: |
599 | table : string giving the table name | |
600 | keyvalues : dict of column names and values to select the rows with, | |
601 | or None to not apply a WHERE clause. | |
602 | retcols : list of strings giving the names of the columns to return | |
599 | table (str): the table name | |
600 | keyvalues (dict[str, Any] | None): | |
601 | column names and values to select the rows with, or None to not | |
602 | apply a WHERE clause. | |
603 | retcols (iterable[str]): the names of the columns to return | |
604 | Returns: | |
605 | defer.Deferred: resolves to list[dict[str, Any]] | |
603 | 606 | """ |
604 | 607 | return self.runInteraction( |
605 | 608 | desc, |
614 | 617 | |
615 | 618 | Args: |
616 | 619 | txn : Transaction object |
617 | table : string giving the table name | |
618 | keyvalues : dict of column names and values to select the rows with | |
619 | retcols : list of strings giving the names of the columns to return | |
620 | table (str): the table name | |
621 | keyvalues (dict[str, T] | None): | |
622 | column names and values to select the rows with, or None to not | |
623 | apply a WHERE clause. | |
624 | retcols (iterable[str]): the names of the columns to return | |
620 | 625 | """ |
621 | 626 | if keyvalues: |
622 | 627 | sql = "SELECT %s FROM %s WHERE %s" % ( |
805 | 810 | raise StoreError(404, "No row found") |
806 | 811 | if txn.rowcount > 1: |
807 | 812 | raise StoreError(500, "more than one row matched") |
813 | ||
814 | def _simple_delete(self, table, keyvalues, desc): | |
815 | return self.runInteraction( | |
816 | desc, self._simple_delete_txn, table, keyvalues | |
817 | ) | |
808 | 818 | |
809 | 819 | @staticmethod |
810 | 820 | def _simple_delete_txn(txn, table, keyvalues): |
13 | 13 | # limitations under the License. |
14 | 14 | |
15 | 15 | from ._base import SQLBaseStore |
16 | from . import engines | |
16 | 17 | |
17 | 18 | from twisted.internet import defer |
18 | 19 | |
86 | 87 | |
87 | 88 | @defer.inlineCallbacks |
88 | 89 | def start_doing_background_updates(self): |
90 | assert self._background_update_timer is None, \ | |
91 | "background updates already running" | |
92 | ||
93 | logger.info("Starting background schema updates") | |
94 | ||
89 | 95 | while True: |
90 | if self._background_update_timer is not None: | |
91 | return | |
92 | ||
93 | 96 | sleep = defer.Deferred() |
94 | 97 | self._background_update_timer = self._clock.call_later( |
95 | 98 | self.BACKGROUND_UPDATE_INTERVAL_MS / 1000., sleep.callback, None |
100 | 103 | self._background_update_timer = None |
101 | 104 | |
102 | 105 | try: |
103 | result = yield self.do_background_update( | |
106 | result = yield self.do_next_background_update( | |
104 | 107 | self.BACKGROUND_UPDATE_DURATION_MS |
105 | 108 | ) |
106 | 109 | except: |
107 | 110 | logger.exception("Error doing update") |
108 | ||
109 | if result is None: | |
110 | logger.info( | |
111 | "No more background updates to do." | |
112 | " Unscheduling background update task." | |
113 | ) | |
114 | return | |
111 | else: | |
112 | if result is None: | |
113 | logger.info( | |
114 | "No more background updates to do." | |
115 | " Unscheduling background update task." | |
116 | ) | |
117 | defer.returnValue(None) | |
115 | 118 | |
116 | 119 | @defer.inlineCallbacks |
117 | def do_background_update(self, desired_duration_ms): | |
118 | """Does some amount of work on a background update | |
120 | def do_next_background_update(self, desired_duration_ms): | |
121 | """Does some amount of work on the next queued background update | |
122 | ||
119 | 123 | Args: |
120 | 124 | desired_duration_ms(float): How long we want to spend |
121 | 125 | updating. |
134 | 138 | self._background_update_queue.append(update['update_name']) |
135 | 139 | |
136 | 140 | if not self._background_update_queue: |
141 | # no work left to do | |
137 | 142 | defer.returnValue(None) |
138 | 143 | |
144 | # pop from the front, and add back to the back | |
139 | 145 | update_name = self._background_update_queue.pop(0) |
140 | 146 | self._background_update_queue.append(update_name) |
147 | ||
148 | res = yield self._do_background_update(update_name, desired_duration_ms) | |
149 | defer.returnValue(res) | |
150 | ||
151 | @defer.inlineCallbacks | |
152 | def _do_background_update(self, update_name, desired_duration_ms): | |
153 | logger.info("Starting update batch on background update '%s'", | |
154 | update_name) | |
141 | 155 | |
142 | 156 | update_handler = self._background_update_handlers[update_name] |
143 | 157 | |
201 | 215 | """ |
202 | 216 | self._background_update_handlers[update_name] = update_handler |
203 | 217 | |
218 | def register_background_index_update(self, update_name, index_name, | |
219 | table, columns): | |
220 | """Helper for store classes to do a background index addition | |
221 | ||
222 | To use: | |
223 | ||
224 | 1. use a schema delta file to add a background update. Example: | |
225 | INSERT INTO background_updates (update_name, progress_json) VALUES | |
226 | ('my_new_index', '{}'); | |
227 | ||
228 | 2. In the Store constructor, call this method | |
229 | ||
230 | Args: | |
231 | update_name (str): update_name to register for | |
232 | index_name (str): name of index to add | |
233 | table (str): table to add index to | |
234 | columns (list[str]): columns/expressions to include in index | |
235 | """ | |
236 | ||
237 | # if this is postgres, we add the indexes concurrently. Otherwise | |
238 | # we fall back to doing it inline | |
239 | if isinstance(self.database_engine, engines.PostgresEngine): | |
240 | conc = True | |
241 | else: | |
242 | conc = False | |
243 | ||
244 | sql = "CREATE INDEX %(conc)s %(name)s ON %(table)s (%(columns)s)" \ | |
245 | % { | |
246 | "conc": "CONCURRENTLY" if conc else "", | |
247 | "name": index_name, | |
248 | "table": table, | |
249 | "columns": ", ".join(columns), | |
250 | } | |
251 | ||
252 | def create_index_concurrently(conn): | |
253 | conn.rollback() | |
254 | # postgres insists on autocommit for the index | |
255 | conn.set_session(autocommit=True) | |
256 | c = conn.cursor() | |
257 | c.execute(sql) | |
258 | conn.set_session(autocommit=False) | |
259 | ||
260 | def create_index(conn): | |
261 | c = conn.cursor() | |
262 | c.execute(sql) | |
263 | ||
264 | @defer.inlineCallbacks | |
265 | def updater(progress, batch_size): | |
266 | logger.info("Adding index %s to %s", index_name, table) | |
267 | if conc: | |
268 | yield self.runWithConnection(create_index_concurrently) | |
269 | else: | |
270 | yield self.runWithConnection(create_index) | |
271 | yield self._end_background_update(update_name) | |
272 | defer.returnValue(1) | |
273 | ||
274 | self.register_background_update_handler(update_name, updater) | |
275 | ||
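A minimal sketch of the two-step pattern the docstring describes; the update, index, table, and column names here are hypothetical:

    # 1. Schema delta, applied once:
    #      INSERT INTO background_updates (update_name, progress_json)
    #      VALUES ('my_new_index', '{}');
    #
    # 2. In the store's constructor:
    self.register_background_index_update(
        "my_new_index",
        index_name="my_new_index",
        table="some_table",
        columns=["user_id", "ts"],
    )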
204 | 276 | def start_background_update(self, update_name, progress): |
205 | 277 | """Starts a background update running. |
206 | 278 |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | |
15 | from ._base import SQLBaseStore, Cache | |
15 | import logging | |
16 | 16 | |
17 | 17 | from twisted.internet import defer |
18 | 18 | |
19 | from ._base import Cache | |
20 | from . import background_updates | |
21 | ||
22 | logger = logging.getLogger(__name__) | |
19 | 23 | |
20 | 24 | # Number of msec of granularity to store the user IP 'last seen' time. Smaller |
21 | 25 | # times give more inserts into the database even for readonly API hits |
23 | 27 | LAST_SEEN_GRANULARITY = 120 * 1000 |
24 | 28 | |
25 | 29 | |
26 | class ClientIpStore(SQLBaseStore): | |
27 | ||
30 | class ClientIpStore(background_updates.BackgroundUpdateStore): | |
28 | 31 | def __init__(self, hs): |
29 | 32 | self.client_ip_last_seen = Cache( |
30 | 33 | name="client_ip_last_seen", |
33 | 36 | |
34 | 37 | super(ClientIpStore, self).__init__(hs) |
35 | 38 | |
39 | self.register_background_index_update( | |
40 | "user_ips_device_index", | |
41 | index_name="user_ips_device_id", | |
42 | table="user_ips", | |
43 | columns=["user_id", "device_id", "last_seen"], | |
44 | ) | |
45 | ||
36 | 46 | @defer.inlineCallbacks |
37 | def insert_client_ip(self, user, access_token, ip, user_agent): | |
47 | def insert_client_ip(self, user, access_token, ip, user_agent, device_id): | |
38 | 48 | now = int(self._clock.time_msec()) |
39 | 49 | key = (user.to_string(), access_token, ip) |
40 | 50 | |
58 | 68 | "access_token": access_token, |
59 | 69 | "ip": ip, |
60 | 70 | "user_agent": user_agent, |
71 | "device_id": device_id, | |
61 | 72 | }, |
62 | 73 | values={ |
63 | 74 | "last_seen": now, |
65 | 76 | desc="insert_client_ip", |
66 | 77 | lock=False, |
67 | 78 | ) |
79 | ||
80 | @defer.inlineCallbacks | |
81 | def get_last_client_ip_by_device(self, devices): | |
82 | """For each device_id listed, give the user_ip it was last seen on | |
83 | ||
84 | Args: | |
85 | devices (iterable[(str, str)]): list of (user_id, device_id) pairs | |
86 | ||
87 | Returns: | |
88 | defer.Deferred: resolves to a dict, where the keys | |
89 | are (user_id, device_id) tuples. The values are also dicts, with | |
90 | keys giving the column names | |
91 | """ | |
92 | ||
93 | res = yield self.runInteraction( | |
94 | "get_last_client_ip_by_device", | |
95 | self._get_last_client_ip_by_device_txn, | |
96 | retcols=( | |
97 | "user_id", | |
98 | "access_token", | |
99 | "ip", | |
100 | "user_agent", | |
101 | "device_id", | |
102 | "last_seen", | |
103 | ), | |
104 | devices=devices | |
105 | ) | |
106 | ||
107 | ret = {(d["user_id"], d["device_id"]): d for d in res} | |
108 | defer.returnValue(ret) | |
109 | ||
110 | @classmethod | |
111 | def _get_last_client_ip_by_device_txn(cls, txn, devices, retcols): | |
112 | where_clauses = [] | |
113 | bindings = [] | |
114 | for (user_id, device_id) in devices: | |
115 | if device_id is None: | |
116 | where_clauses.append("(user_id = ? AND device_id IS NULL)") | |
117 | bindings.extend((user_id, )) | |
118 | else: | |
119 | where_clauses.append("(user_id = ? AND device_id = ?)") | |
120 | bindings.extend((user_id, device_id)) | |
121 | ||
122 | inner_select = ( | |
123 | "SELECT MAX(last_seen) mls, user_id, device_id FROM user_ips " | |
124 | "WHERE %(where)s " | |
125 | "GROUP BY user_id, device_id" | |
126 | ) % { | |
127 | "where": " OR ".join(where_clauses), | |
128 | } | |
129 | ||
130 | sql = ( | |
131 | "SELECT %(retcols)s FROM user_ips " | |
132 | "JOIN (%(inner_select)s) ips ON" | |
133 | " user_ips.last_seen = ips.mls AND" | |
134 | " user_ips.user_id = ips.user_id AND" | |
135 | " (user_ips.device_id = ips.device_id OR" | |
136 | " (user_ips.device_id IS NULL AND ips.device_id IS NULL)" | |
137 | " )" | |
138 | ) % { | |
139 | "retcols": ",".join("user_ips." + c for c in retcols), | |
140 | "inner_select": inner_select, | |
141 | } | |
142 | ||
143 | txn.execute(sql, bindings) | |
144 | return cls.cursor_to_dict(txn) |
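A hedged usage sketch for the new query; ``None`` as a device_id matches rows whose device_id is NULL, as handled in the WHERE clause above:

    from twisted.internet import defer

    @defer.inlineCallbacks
    def last_seen_example(store):
        res = yield store.get_last_client_ip_by_device([
            ("@alice:example.com", "DEVICE1"),
            ("@bob:example.com", None),
        ])
        # res maps (user_id, device_id) tuples to row dicts with keys
        # such as "ip", "user_agent" and "last_seen".
        defer.returnValue(res)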
0 | # -*- coding: utf-8 -*- | |
1 | # Copyright 2016 OpenMarket Ltd | |
2 | # | |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
4 | # you may not use this file except in compliance with the License. | |
5 | # You may obtain a copy of the License at | |
6 | # | |
7 | # http://www.apache.org/licenses/LICENSE-2.0 | |
8 | # | |
9 | # Unless required by applicable law or agreed to in writing, software | |
10 | # distributed under the License is distributed on an "AS IS" BASIS, | |
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 | # See the License for the specific language governing permissions and | |
13 | # limitations under the License. | |
14 | import logging | |
15 | ||
16 | from twisted.internet import defer | |
17 | ||
18 | from synapse.api.errors import StoreError | |
19 | from ._base import SQLBaseStore | |
20 | ||
21 | logger = logging.getLogger(__name__) | |
22 | ||
23 | ||
24 | class DeviceStore(SQLBaseStore): | |
25 | @defer.inlineCallbacks | |
26 | def store_device(self, user_id, device_id, | |
27 | initial_device_display_name, | |
28 | ignore_if_known=True): | |
29 | """Ensure the given device is known; add it to the store if not | |
30 | ||
31 | Args: | |
32 | user_id (str): id of user associated with the device | |
33 | device_id (str): id of device | |
34 | initial_device_display_name (str): initial displayname of the | |
35 | device | |
36 | ignore_if_known (bool): ignore integrity errors which mean the | |
37 | device is already known | |
38 | Returns: | |
39 | defer.Deferred | |
40 | Raises: | |
41 | StoreError: if ignore_if_known is False and the device was already | |
42 | known | |
43 | """ | |
44 | try: | |
45 | yield self._simple_insert( | |
46 | "devices", | |
47 | values={ | |
48 | "user_id": user_id, | |
49 | "device_id": device_id, | |
50 | "display_name": initial_device_display_name | |
51 | }, | |
52 | desc="store_device", | |
53 | or_ignore=ignore_if_known, | |
54 | ) | |
55 | except Exception as e: | |
56 | logger.error("store_device with device_id=%s failed: %s", | |
57 | device_id, e) | |
58 | raise StoreError(500, "Problem storing device.") | |
59 | ||
60 | def get_device(self, user_id, device_id): | |
61 | """Retrieve a device. | |
62 | ||
63 | Args: | |
64 | user_id (str): The ID of the user which owns the device | |
65 | device_id (str): The ID of the device to retrieve | |
66 | Returns: | |
67 | defer.Deferred for a dict containing the device information | |
68 | Raises: | |
69 | StoreError: if the device is not found | |
70 | """ | |
71 | return self._simple_select_one( | |
72 | table="devices", | |
73 | keyvalues={"user_id": user_id, "device_id": device_id}, | |
74 | retcols=("user_id", "device_id", "display_name"), | |
75 | desc="get_device", | |
76 | ) | |
77 | ||
78 | def delete_device(self, user_id, device_id): | |
79 | """Delete a device. | |
80 | ||
81 | Args: | |
82 | user_id (str): The ID of the user which owns the device | |
83 | device_id (str): The ID of the device to delete | |
84 | Returns: | |
85 | defer.Deferred | |
86 | """ | |
87 | return self._simple_delete_one( | |
88 | table="devices", | |
89 | keyvalues={"user_id": user_id, "device_id": device_id}, | |
90 | desc="delete_device", | |
91 | ) | |
92 | ||
93 | def update_device(self, user_id, device_id, new_display_name=None): | |
94 | """Update a device. | |
95 | ||
96 | Args: | |
97 | user_id (str): The ID of the user which owns the device | |
98 | device_id (str): The ID of the device to update | |
99 | new_display_name (str|None): new displayname for device; None | |
100 | to leave unchanged | |
101 | Raises: | |
102 | StoreError: if the device is not found | |
103 | Returns: | |
104 | defer.Deferred | |
105 | """ | |
106 | updates = {} | |
107 | if new_display_name is not None: | |
108 | updates["display_name"] = new_display_name | |
109 | if not updates: | |
110 | return defer.succeed(None) | |
111 | return self._simple_update_one( | |
112 | table="devices", | |
113 | keyvalues={"user_id": user_id, "device_id": device_id}, | |
114 | updatevalues=updates, | |
115 | desc="update_device", | |
116 | ) | |
117 | ||
118 | @defer.inlineCallbacks | |
119 | def get_devices_by_user(self, user_id): | |
120 | """Retrieve all of a user's registered devices. | |
121 | ||
122 | Args: | |
123 | user_id (str): | |
124 | Returns: | |
125 | defer.Deferred: resolves to a dict from device_id to a dict | |
126 | containing "device_id", "user_id" and "display_name" for each | |
127 | device. | |
128 | """ | |
129 | devices = yield self._simple_select_list( | |
130 | table="devices", | |
131 | keyvalues={"user_id": user_id}, | |
132 | retcols=("user_id", "device_id", "display_name"), | |
133 | desc="get_devices_by_user" | |
134 | ) | |
135 | ||
136 | defer.returnValue({d["device_id"]: d for d in devices}) |
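A hedged sketch exercising the new store's full lifecycle, matching the method contracts above:

    from twisted.internet import defer

    @defer.inlineCallbacks
    def device_lifecycle_example(store):
        yield store.store_device("@alice:example.com", "DEV1", "Alice's phone")
        device = yield store.get_device("@alice:example.com", "DEV1")
        # device == {"user_id": "@alice:example.com", "device_id": "DEV1",
        #            "display_name": "Alice's phone"}
        yield store.update_device("@alice:example.com", "DEV1",
                                  new_display_name="Work phone")
        yield store.delete_device("@alice:example.com", "DEV1")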
11 | 11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | import collections | |
15 | ||
16 | import twisted.internet.defer | |
14 | 17 | |
15 | 18 | from ._base import SQLBaseStore |
16 | 19 | |
35 | 38 | query_list(list): List of pairs of user_ids and device_ids. |
36 | 39 | Returns: |
37 | 40 | Dict mapping from user-id to dict mapping from device_id to |
38 | key json byte strings. | |
41 | dict containing "key_json", "device_display_name". | |
39 | 42 | """ |
40 | def _get_e2e_device_keys(txn): | |
41 | result = {} | |
42 | for user_id, device_id in query_list: | |
43 | user_result = result.setdefault(user_id, {}) | |
44 | keyvalues = {"user_id": user_id} | |
45 | if device_id: | |
46 | keyvalues["device_id"] = device_id | |
47 | rows = self._simple_select_list_txn( | |
48 | txn, table="e2e_device_keys_json", | |
49 | keyvalues=keyvalues, | |
50 | retcols=["device_id", "key_json"] | |
51 | ) | |
52 | for row in rows: | |
53 | user_result[row["device_id"]] = row["key_json"] | |
54 | return result | |
55 | return self.runInteraction("get_e2e_device_keys", _get_e2e_device_keys) | |
43 | if not query_list: | |
44 | return {} | |
45 | ||
46 | return self.runInteraction( | |
47 | "get_e2e_device_keys", self._get_e2e_device_keys_txn, query_list | |
48 | ) | |
49 | ||
50 | def _get_e2e_device_keys_txn(self, txn, query_list): | |
51 | query_clauses = [] | |
52 | query_params = [] | |
53 | ||
54 | for (user_id, device_id) in query_list: | |
55 | query_clause = "k.user_id = ?" | |
56 | query_params.append(user_id) | |
57 | ||
58 | if device_id: | |
59 | query_clause += " AND k.device_id = ?" | |
60 | query_params.append(device_id) | |
61 | ||
62 | query_clauses.append(query_clause) | |
63 | ||
64 | sql = ( | |
65 | "SELECT k.user_id, k.device_id, " | |
66 | " d.display_name AS device_display_name, " | |
67 | " k.key_json" | |
68 | " FROM e2e_device_keys_json k" | |
69 | " LEFT JOIN devices d ON d.user_id = k.user_id" | |
70 | " AND d.device_id = k.device_id" | |
71 | " WHERE %s" | |
72 | ) % ( | |
73 | " OR ".join("(" + q + ")" for q in query_clauses) | |
74 | ) | |
75 | ||
76 | txn.execute(sql, query_params) | |
77 | rows = self.cursor_to_dict(txn) | |
78 | ||
79 | result = collections.defaultdict(dict) | |
80 | for row in rows: | |
81 | result[row["user_id"]][row["device_id"]] = row | |
82 | ||
83 | return result | |
56 | 84 | |
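The rewritten query lets one transaction cover both an exact device and a whole user's devices; a hedged sketch of a call and its result shape:

    from twisted.internet import defer

    @defer.inlineCallbacks
    def query_keys_example(store):
        result = yield store.get_e2e_device_keys([
            ("@alice:example.com", "DEV1"),  # one specific device
            ("@bob:example.com", None),      # all of bob's devices
        ])
        # result[user_id][device_id] is a row dict including "key_json"
        # and "device_display_name" (None when there is no matching
        # devices row, since the join above is a LEFT JOIN).
        defer.returnValue(result)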
57 | 85 | def add_e2e_one_time_keys(self, user_id, device_id, time_now, key_list): |
58 | 86 | def _add_e2e_one_time_keys(txn): |
122 | 150 | return self.runInteraction( |
123 | 151 | "claim_e2e_one_time_keys", _claim_e2e_one_time_keys |
124 | 152 | ) |
153 | ||
154 | @twisted.internet.defer.inlineCallbacks | |
155 | def delete_e2e_keys_by_device(self, user_id, device_id): | |
156 | yield self._simple_delete( | |
157 | table="e2e_device_keys_json", | |
158 | keyvalues={"user_id": user_id, "device_id": device_id}, | |
159 | desc="delete_e2e_device_keys_by_device" | |
160 | ) | |
161 | yield self._simple_delete( | |
162 | table="e2e_one_time_keys_json", | |
163 | keyvalues={"user_id": user_id, "device_id": device_id}, | |
164 | desc="delete_e2e_one_time_keys_by_device" | |
165 | ) |
15 | 15 | from ._base import SQLBaseStore |
16 | 16 | from twisted.internet import defer |
17 | 17 | from synapse.util.caches.descriptors import cachedInlineCallbacks |
18 | from synapse.types import RoomStreamToken | |
19 | from .stream import lower_bound | |
18 | 20 | |
19 | 21 | import logging |
20 | 22 | import ujson as json |
72 | 74 | |
73 | 75 | stream_ordering = results[0][0] |
74 | 76 | topological_ordering = results[0][1] |
77 | token = RoomStreamToken( | |
78 | topological_ordering, stream_ordering | |
79 | ) | |
75 | 80 | |
76 | 81 | sql = ( |
77 | 82 | "SELECT sum(notif), sum(highlight)" |
79 | 84 | " WHERE" |
80 | 85 | " user_id = ?" |
81 | 86 | " AND room_id = ?" |
82 | " AND (" | |
83 | " topological_ordering > ?" | |
84 | " OR (topological_ordering = ? AND stream_ordering > ?)" | |
85 | ")" | |
86 | ) | |
87 | txn.execute(sql, ( | |
88 | user_id, room_id, | |
89 | topological_ordering, topological_ordering, stream_ordering | |
90 | )) | |
87 | " AND %s" | |
88 | ) % (lower_bound(token, self.database_engine, inclusive=False),) | |
89 | ||
90 | txn.execute(sql, (user_id, room_id)) | |
91 | 91 | row = txn.fetchone() |
92 | 92 | if row: |
93 | 93 | return { |
116 | 116 | defer.returnValue(ret) |
117 | 117 | |
118 | 118 | @defer.inlineCallbacks |
119 | def get_unread_push_actions_for_user_in_range(self, user_id, | |
120 | min_stream_ordering, | |
121 | max_stream_ordering=None, | |
122 | limit=20): | |
119 | def get_unread_push_actions_for_user_in_range_for_http( | |
120 | self, user_id, min_stream_ordering, max_stream_ordering, limit=20 | |
121 | ): | |
122 | """Get a list of the most recent unread push actions for a given user, | |
123 | within the given stream ordering range. Called by the httppusher. | |
124 | ||
125 | Args: | |
126 | user_id (str): The user to fetch push actions for. | |
127 | min_stream_ordering (int): The exclusive lower bound on the | |
128 | stream ordering of event push actions to fetch. | |
129 | max_stream_ordering (int): The inclusive upper bound on the | |
130 | stream ordering of event push actions to fetch. | |
131 | limit (int): The maximum number of rows to return. | |
132 | Returns: | |
133 | A promise which resolves to a list of dicts with the keys "event_id", | |
134 | "room_id", "stream_ordering", "actions". | |
135 | The list will be ordered by ascending stream_ordering. | |
136 | The list will have between 0 and `limit` entries. | |
137 | """ | |
138 | # find rooms that have a read receipt in them and return the next | |
139 | # push actions | |
123 | 140 | def get_after_receipt(txn): |
124 | sql = ( | |
125 | "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions, " | |
126 | "e.received_ts " | |
127 | "FROM (" | |
128 | " SELECT room_id, user_id, " | |
129 | " max(topological_ordering) as topological_ordering, " | |
130 | " max(stream_ordering) as stream_ordering " | |
131 | " FROM events" | |
132 | " NATURAL JOIN receipts_linearized WHERE receipt_type = 'm.read'" | |
133 | " GROUP BY room_id, user_id" | |
141 | # find rooms that have a read receipt in them and return the next | |
142 | # push actions | |
143 | sql = ( | |
144 | "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions" | |
145 | " FROM (" | |
146 | " SELECT room_id," | |
147 | " MAX(topological_ordering) as topological_ordering," | |
148 | " MAX(stream_ordering) as stream_ordering" | |
149 | " FROM events" | |
150 | " INNER JOIN receipts_linearized USING (room_id, event_id)" | |
151 | " WHERE receipt_type = 'm.read' AND user_id = ?" | |
152 | " GROUP BY room_id" | |
153 | ") AS rl," | |
154 | " event_push_actions AS ep" | |
155 | " WHERE" | |
156 | " ep.room_id = rl.room_id" | |
157 | " AND (" | |
158 | " ep.topological_ordering > rl.topological_ordering" | |
159 | " OR (" | |
160 | " ep.topological_ordering = rl.topological_ordering" | |
161 | " AND ep.stream_ordering > rl.stream_ordering" | |
162 | " )" | |
163 | " )" | |
164 | " AND ep.user_id = ?" | |
165 | " AND ep.stream_ordering > ?" | |
166 | " AND ep.stream_ordering <= ?" | |
167 | " ORDER BY ep.stream_ordering ASC LIMIT ?" | |
168 | ) | |
169 | args = [ | |
170 | user_id, user_id, | |
171 | min_stream_ordering, max_stream_ordering, limit, | |
172 | ] | |
173 | txn.execute(sql, args) | |
174 | return txn.fetchall() | |
175 | after_read_receipt = yield self.runInteraction( | |
176 | "get_unread_push_actions_for_user_in_range_http_arr", get_after_receipt | |
177 | ) | |
178 | ||
179 | # Some rooms have push actions but no read receipt from this user | |
180 | # (e.g. rooms they have been invited to), so also fetch push actions | |
181 | # for rooms without any read receipts. | |
182 | def get_no_receipt(txn): | |
183 | sql = ( | |
184 | "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," | |
185 | " e.received_ts" | |
186 | " FROM event_push_actions AS ep" | |
187 | " INNER JOIN events AS e USING (room_id, event_id)" | |
188 | " WHERE" | |
189 | " ep.room_id NOT IN (" | |
190 | " SELECT room_id FROM receipts_linearized" | |
191 | " WHERE receipt_type = 'm.read' AND user_id = ?" | |
192 | " GROUP BY room_id" | |
193 | " )" | |
194 | " AND ep.user_id = ?" | |
195 | " AND ep.stream_ordering > ?" | |
196 | " AND ep.stream_ordering <= ?" | |
197 | " ORDER BY ep.stream_ordering ASC LIMIT ?" | |
198 | ) | |
199 | args = [ | |
200 | user_id, user_id, | |
201 | min_stream_ordering, max_stream_ordering, limit, | |
202 | ] | |
203 | txn.execute(sql, args) | |
204 | return txn.fetchall() | |
205 | no_read_receipt = yield self.runInteraction( | |
206 | "get_unread_push_actions_for_user_in_range_http_nrr", get_no_receipt | |
207 | ) | |
208 | ||
209 | notifs = [ | |
210 | { | |
211 | "event_id": row[0], | |
212 | "room_id": row[1], | |
213 | "stream_ordering": row[2], | |
214 | "actions": json.loads(row[3]), | |
215 | } for row in after_read_receipt + no_read_receipt | |
216 | ] | |
217 | ||
218 | # Now sort it so it's ordered correctly, since currently it will | |
219 | # contain results from the first query, correctly ordered, followed | |
220 | # by results from the second query, but we want them all ordered | |
221 | # by stream_ordering, oldest first. | |
222 | notifs.sort(key=lambda r: r['stream_ordering']) | |
223 | ||
224 | # Take only up to the limit. We have to stop at the limit because | |
225 | # one of the subqueries may have hit the limit. | |
226 | defer.returnValue(notifs[:limit]) | |
227 | ||
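A hedged sketch of how a pusher might consume the HTTP variant, using the bounds documented above (exclusive lower, inclusive upper):

    from twisted.internet import defer

    @defer.inlineCallbacks
    def fetch_new_notifs(store, user_id, last_stream_ordering, max_stream_ordering):
        notifs = yield store.get_unread_push_actions_for_user_in_range_for_http(
            user_id, last_stream_ordering, max_stream_ordering, limit=20
        )
        # notifs: at most 20 dicts with "event_id", "room_id",
        # "stream_ordering" and "actions", ascending by stream_ordering.
        defer.returnValue(notifs)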
228 | @defer.inlineCallbacks | |
229 | def get_unread_push_actions_for_user_in_range_for_email( | |
230 | self, user_id, min_stream_ordering, max_stream_ordering, limit=20 | |
231 | ): | |
232 | """Get a list of the most recent unread push actions for a given user, | |
233 | within the given stream ordering range. Called by the emailpusher. | |
234 | ||
235 | Args: | |
236 | user_id (str): The user to fetch push actions for. | |
237 | min_stream_ordering (int): The exclusive lower bound on the | |
238 | stream ordering of event push actions to fetch. | |
239 | max_stream_ordering (int): The inclusive upper bound on the | |
240 | stream ordering of event push actions to fetch. | |
241 | limit (int): The maximum number of rows to return. | |
242 | Returns: | |
243 | A promise which resolves to a list of dicts with the keys "event_id", | |
244 | "room_id", "stream_ordering", "actions", "received_ts". | |
245 | The list will be ordered by descending received_ts. | |
246 | The list will have between 0 and `limit` entries. | |
247 | """ | |
248 | # find rooms that have a read receipt in them and return the most recent | |
249 | # push actions | |
250 | def get_after_receipt(txn): | |
251 | sql = ( | |
252 | "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," | |
253 | " e.received_ts" | |
254 | " FROM (" | |
255 | " SELECT room_id," | |
256 | " MAX(topological_ordering) as topological_ordering," | |
257 | " MAX(stream_ordering) as stream_ordering" | |
258 | " FROM events" | |
259 | " INNER JOIN receipts_linearized USING (room_id, event_id)" | |
260 | " WHERE receipt_type = 'm.read' AND user_id = ?" | |
261 | " GROUP BY room_id" | |
134 | 262 | ") AS rl," |
135 | 263 | " event_push_actions AS ep" |
136 | 264 | " INNER JOIN events AS e USING (room_id, event_id)" |
143 | 271 | " AND ep.stream_ordering > rl.stream_ordering" |
144 | 272 | " )" |
145 | 273 | " )" |
274 | " AND ep.user_id = ?" | |
146 | 275 | " AND ep.stream_ordering > ?" |
147 | " AND ep.user_id = ?" | |
148 | " AND ep.user_id = rl.user_id" | |
149 | ) | |
150 | args = [min_stream_ordering, user_id] | |
151 | if max_stream_ordering is not None: | |
152 | sql += " AND ep.stream_ordering <= ?" | |
153 | args.append(max_stream_ordering) | |
154 | sql += " ORDER BY ep.stream_ordering ASC LIMIT ?" | |
155 | args.append(limit) | |
276 | " AND ep.stream_ordering <= ?" | |
277 | " ORDER BY ep.stream_ordering DESC LIMIT ?" | |
278 | ) | |
279 | args = [ | |
280 | user_id, user_id, | |
281 | min_stream_ordering, max_stream_ordering, limit, | |
282 | ] | |
156 | 283 | txn.execute(sql, args) |
157 | 284 | return txn.fetchall() |
158 | 285 | after_read_receipt = yield self.runInteraction( |
159 | "get_unread_push_actions_for_user_in_range", get_after_receipt | |
160 | ) | |
161 | ||
286 | "get_unread_push_actions_for_user_in_range_email_arr", get_after_receipt | |
287 | ) | |
288 | ||
289 | # Some rooms have push actions but no read receipt from this user | |
290 | # (e.g. rooms they have been invited to), so also fetch push actions | |
291 | # for rooms without any read receipts. | |
162 | 292 | def get_no_receipt(txn): |
163 | 293 | sql = ( |
164 | 294 | "SELECT ep.event_id, ep.room_id, ep.stream_ordering, ep.actions," |
165 | 295 | " e.received_ts" |
166 | 296 | " FROM event_push_actions AS ep" |
167 | " JOIN events e ON ep.room_id = e.room_id AND ep.event_id = e.event_id" | |
168 | " WHERE ep.room_id not in (" | |
169 | " SELECT room_id FROM events NATURAL JOIN receipts_linearized" | |
170 | " WHERE receipt_type = 'm.read' AND user_id = ?" | |
171 | " GROUP BY room_id" | |
172 | ") AND ep.user_id = ? AND ep.stream_ordering > ?" | |
173 | ) | |
174 | args = [user_id, user_id, min_stream_ordering] | |
175 | if max_stream_ordering is not None: | |
176 | sql += " AND ep.stream_ordering <= ?" | |
177 | args.append(max_stream_ordering) | |
178 | sql += " ORDER BY ep.stream_ordering ASC" | |
297 | " INNER JOIN events AS e USING (room_id, event_id)" | |
298 | " WHERE" | |
299 | " ep.room_id NOT IN (" | |
300 | " SELECT room_id FROM receipts_linearized" | |
301 | " WHERE receipt_type = 'm.read' AND user_id = ?" | |
302 | " GROUP BY room_id" | |
303 | " )" | |
304 | " AND ep.user_id = ?" | |
305 | " AND ep.stream_ordering > ?" | |
306 | " AND ep.stream_ordering <= ?" | |
307 | " ORDER BY ep.stream_ordering DESC LIMIT ?" | |
308 | ) | |
309 | args = [ | |
310 | user_id, user_id, | |
311 | min_stream_ordering, max_stream_ordering, limit, | |
312 | ] | |
179 | 313 | txn.execute(sql, args) |
180 | 314 | return txn.fetchall() |
181 | 315 | no_read_receipt = yield self.runInteraction( |
182 | "get_unread_push_actions_for_user_in_range", get_no_receipt | |
183 | ) | |
184 | ||
185 | defer.returnValue([ | |
316 | "get_unread_push_actions_for_user_in_range_email_nrr", get_no_receipt | |
317 | ) | |
318 | ||
319 | # Make a list of dicts from the two sets of results. | |
320 | notifs = [ | |
186 | 321 | { |
187 | 322 | "event_id": row[0], |
188 | 323 | "room_id": row[1], |
190 | 325 | "actions": json.loads(row[3]), |
191 | 326 | "received_ts": row[4], |
192 | 327 | } for row in after_read_receipt + no_read_receipt |
193 | ]) | |
328 | ] | |
329 | ||
330 | # Now sort it so it's ordered correctly, since currently it will | |
331 | # contain results from the first query, correctly ordered, followed | |
332 | # by results from the second query, but we want them all ordered | |
333 | # by received_ts (most recent first) | |
334 | notifs.sort(key=lambda r: -(r['received_ts'] or 0)) | |
335 | ||
336 | # Now return the first `limit` entries. | |
337 | defer.returnValue(notifs[:limit]) | |
194 | 338 | |
195 | 339 | @defer.inlineCallbacks |
196 | 340 | def get_time_of_last_push_action_before(self, stream_ordering): |
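The email variant above mirrors the HTTP-pusher variant earlier in this file, with the two-query pattern (rooms with a read receipt vs. rooms without) shared between them; only the final ordering differs. A minimal caller sketch, assuming `store` is an instance of this storage class (the surrounding email-pusher plumbing is hypothetical):

    from twisted.internet import defer

    @defer.inlineCallbacks
    def collect_email_digest(store, user_id, last_seen, max_stream_ordering):
        # Fetch up to 20 unread notifications newest-first; each entry is a
        # dict with event_id, room_id, stream_ordering, actions, received_ts.
        notifs = yield store.get_unread_push_actions_for_user_in_range_for_email(
            user_id, last_seen, max_stream_ordering, limit=20,
        )
        defer.returnValue(notifs)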
22 | 22 | from synapse.util.logcontext import preserve_fn, PreserveLoggingContext |
23 | 23 | from synapse.util.logutils import log_function |
24 | 24 | from synapse.api.constants import EventTypes |
25 | from synapse.api.errors import SynapseError | |
25 | 26 | |
26 | 27 | from canonicaljson import encode_canonical_json |
27 | from collections import deque, namedtuple | |
28 | from collections import deque, namedtuple, OrderedDict | |
29 | from functools import wraps | |
28 | 30 | |
29 | 31 | import synapse |
30 | 32 | import synapse.metrics |
148 | 150 | _EventCacheEntry = namedtuple("_EventCacheEntry", ("event", "redacted_event")) |
149 | 151 | |
150 | 152 | |
153 | def _retry_on_integrity_error(func): | |
154 | """Wraps a database function so that it gets retried on IntegrityError, | |
155 | with `delete_existing=True` passed in. | |
156 | ||
157 | Args: | |
158 | func: function that returns a Deferred and accepts a `delete_existing` arg | |
159 | """ | |
160 | @wraps(func) | |
161 | @defer.inlineCallbacks | |
162 | def f(self, *args, **kwargs): | |
163 | try: | |
164 | res = yield func(self, *args, **kwargs) | |
165 | except self.database_engine.module.IntegrityError: | |
166 | logger.exception("IntegrityError, retrying.") | |
167 | res = yield func(self, *args, delete_existing=True, **kwargs) | |
168 | defer.returnValue(res) | |
169 | ||
170 | return f | |
171 | ||
172 | ||
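The decorator is applied to `_persist_events` and `_persist_event` below. Its behaviour, parameterised over the exception type purely for illustration (`retry_once_on` is a hypothetical name; the real decorator hard-codes the engine's IntegrityError), is a single retry with `delete_existing=True`:

    from functools import wraps
    from twisted.internet import defer

    def retry_once_on(exc_type):
        # Illustrative analogue of _retry_on_integrity_error: one retry,
        # passing delete_existing=True so the second attempt purges any
        # half-written rows before re-inserting.
        def decorator(func):
            @wraps(func)
            @defer.inlineCallbacks
            def wrapped(self, *args, **kwargs):
                try:
                    res = yield func(self, *args, **kwargs)
                except exc_type:
                    res = yield func(self, *args, delete_existing=True, **kwargs)
                defer.returnValue(res)
            return wrapped
        return decorator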
151 | 173 | class EventsStore(SQLBaseStore): |
152 | 174 | EVENT_ORIGIN_SERVER_TS_NAME = "event_origin_server_ts" |
175 | EVENT_FIELDS_SENDER_URL_UPDATE_NAME = "event_fields_sender_url" | |
153 | 176 | |
154 | 177 | def __init__(self, hs): |
155 | 178 | super(EventsStore, self).__init__(hs) |
156 | 179 | self._clock = hs.get_clock() |
157 | 180 | self.register_background_update_handler( |
158 | 181 | self.EVENT_ORIGIN_SERVER_TS_NAME, self._background_reindex_origin_server_ts |
182 | ) | |
183 | self.register_background_update_handler( | |
184 | self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, | |
185 | self._background_reindex_fields_sender, | |
159 | 186 | ) |
160 | 187 | |
161 | 188 | self._event_persist_queue = _EventPeristenceQueue() |
222 | 249 | |
223 | 250 | self._event_persist_queue.handle_queue(room_id, persisting_queue) |
224 | 251 | |
252 | @_retry_on_integrity_error | |
225 | 253 | @defer.inlineCallbacks |
226 | def _persist_events(self, events_and_contexts, backfilled=False): | |
254 | def _persist_events(self, events_and_contexts, backfilled=False, | |
255 | delete_existing=False): | |
227 | 256 | if not events_and_contexts: |
228 | 257 | return |
229 | 258 | |
266 | 295 | self._persist_events_txn, |
267 | 296 | events_and_contexts=chunk, |
268 | 297 | backfilled=backfilled, |
298 | delete_existing=delete_existing, | |
269 | 299 | ) |
270 | 300 | persist_event_counter.inc_by(len(chunk)) |
271 | 301 | |
302 | @_retry_on_integrity_error | |
272 | 303 | @defer.inlineCallbacks |
273 | 304 | @log_function |
274 | def _persist_event(self, event, context, current_state=None, backfilled=False): | |
305 | def _persist_event(self, event, context, current_state=None, backfilled=False, | |
306 | delete_existing=False): | |
275 | 307 | try: |
276 | 308 | with self._stream_id_gen.get_next() as stream_ordering: |
277 | 309 | with self._state_groups_id_gen.get_next() as state_group_id: |
284 | 316 | context=context, |
285 | 317 | current_state=current_state, |
286 | 318 | backfilled=backfilled, |
319 | delete_existing=delete_existing, | |
287 | 320 | ) |
288 | 321 | persist_event_counter.inc() |
289 | 322 | except _RollbackButIsFineException: |
316 | 349 | ) |
317 | 350 | |
318 | 351 | if not events and not allow_none: |
319 | raise RuntimeError("Could not find event %s" % (event_id,)) | |
352 | raise SynapseError(404, "Could not find event %s" % (event_id,)) | |
320 | 353 | |
321 | 354 | defer.returnValue(events[0] if events else None) |
322 | 355 | |
346 | 379 | defer.returnValue({e.event_id: e for e in events}) |
347 | 380 | |
348 | 381 | @log_function |
349 | def _persist_event_txn(self, txn, event, context, current_state, backfilled=False): | |
382 | def _persist_event_txn(self, txn, event, context, current_state, backfilled=False, | |
383 | delete_existing=False): | |
350 | 384 | # We purposefully do this first since if we include a `current_state` |
351 | 385 | # key, we *want* to update the `current_state_events` table |
352 | 386 | if current_state: |
354 | 388 | txn.call_after(self.get_rooms_for_user.invalidate_all) |
355 | 389 | txn.call_after(self.get_users_in_room.invalidate, (event.room_id,)) |
356 | 390 | txn.call_after(self.get_joined_hosts_for_room.invalidate, (event.room_id,)) |
357 | txn.call_after(self.get_room_name_and_aliases.invalidate, (event.room_id,)) | |
358 | 391 | |
359 | 392 | # Add an entry to the current_state_resets table to record the point |
360 | 393 | # where we clobbered the current state |
387 | 420 | txn, |
388 | 421 | [(event, context)], |
389 | 422 | backfilled=backfilled, |
423 | delete_existing=delete_existing, | |
390 | 424 | ) |
391 | 425 | |
392 | 426 | @log_function |
393 | def _persist_events_txn(self, txn, events_and_contexts, backfilled): | |
427 | def _persist_events_txn(self, txn, events_and_contexts, backfilled, | |
428 | delete_existing=False): | |
429 | """Insert some number of room events into the necessary database tables. | |
430 | ||
431 | Rejected events are only inserted into the events table, the event_json table, | |
432 | and the rejections table. Things reading from those tables will need to check | |
433 | whether the event was rejected. | |
434 | ||
435 | If delete_existing is True then existing events will be purged from the | |
436 | database before insertion. This is useful when retrying due to IntegrityError. | |
437 | """ | |
438 | # Ensure that we don't have the same event twice. | |
439 | # Pick the earliest non-outlier if there is one, else the earliest one. | |
440 | new_events_and_contexts = OrderedDict() | |
441 | for event, context in events_and_contexts: | |
442 | prev_event_context = new_events_and_contexts.get(event.event_id) | |
443 | if prev_event_context: | |
444 | if not event.internal_metadata.is_outlier(): | |
445 | if prev_event_context[0].internal_metadata.is_outlier(): | |
446 | # To ensure correct ordering we pop, as OrderedDict is | |
447 | # ordered by first insertion. | |
448 | new_events_and_contexts.pop(event.event_id, None) | |
449 | new_events_and_contexts[event.event_id] = (event, context) | |
450 | else: | |
451 | new_events_and_contexts[event.event_id] = (event, context) | |
452 | ||
453 | events_and_contexts = new_events_and_contexts.values() | |
454 | ||
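The deduplication above keeps, for each event_id, the earliest non-outlier copy if one exists, otherwise the earliest copy. A standalone sketch of the same rule over toy (event_id, is_outlier) pairs:

    from collections import OrderedDict

    def dedup(pairs):
        # pairs: iterable of (event_id, is_outlier). A later non-outlier
        # replaces an earlier outlier; everything else keeps first-seen order.
        out = OrderedDict()
        for ev_id, is_outlier in pairs:
            prev = out.get(ev_id)
            if prev is not None:
                if not is_outlier and prev[1]:
                    out.pop(ev_id, None)  # pop then re-insert, as above
                    out[ev_id] = (ev_id, is_outlier)
            else:
                out[ev_id] = (ev_id, is_outlier)
        return list(out.values())

    # dedup([("$a", True), ("$b", False), ("$a", False)])
    # -> [("$b", False), ("$a", False)]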
394 | 455 | depth_updates = {} |
395 | 456 | for event, context in events_and_contexts: |
396 | 457 | # Remove any existing cache entries for the event_ids
401 | 462 | event.room_id, event.internal_metadata.stream_ordering, |
402 | 463 | ) |
403 | 464 | |
404 | if not event.internal_metadata.is_outlier(): | |
465 | if not event.internal_metadata.is_outlier() and not context.rejected: | |
405 | 466 | depth_updates[event.room_id] = max( |
406 | 467 | event.depth, depth_updates.get(event.room_id, event.depth) |
407 | 468 | ) |
408 | ||
409 | if context.push_actions: | |
410 | self._set_push_actions_for_event_and_users_txn( | |
411 | txn, event, context.push_actions | |
412 | ) | |
413 | ||
414 | if event.type == EventTypes.Redaction and event.redacts is not None: | |
415 | self._remove_push_actions_for_event_id_txn( | |
416 | txn, event.room_id, event.redacts | |
417 | ) | |
418 | 469 | |
419 | 470 | for room_id, depth in depth_updates.items(): |
420 | 471 | self._update_min_depth_for_room_txn(txn, room_id, depth) |
425 | 476 | ), |
426 | 477 | [event.event_id for event, _ in events_and_contexts] |
427 | 478 | ) |
479 | ||
428 | 480 | have_persisted = { |
429 | 481 | event_id: outlier |
430 | 482 | for event_id, outlier in txn.fetchall() |
431 | 483 | } |
432 | 484 | |
433 | event_map = {} | |
434 | 485 | to_remove = set() |
435 | 486 | for event, context in events_and_contexts: |
436 | # Handle the case of the list including the same event multiple | |
437 | # times. The tricky thing here is when they differ by whether | |
438 | # they are an outlier. | |
439 | if event.event_id in event_map: | |
440 | other = event_map[event.event_id] | |
441 | ||
442 | if not other.internal_metadata.is_outlier(): | |
487 | if context.rejected: | |
488 | # If the event is rejected then we don't care if the event | |
489 | # was an outlier or not. | |
490 | if event.event_id in have_persisted: | |
491 | # If we have already seen the event then ignore it. | |
443 | 492 | to_remove.add(event) |
444 | continue | |
445 | elif not event.internal_metadata.is_outlier(): | |
446 | to_remove.add(event) | |
447 | continue | |
448 | else: | |
449 | to_remove.add(other) | |
450 | ||
451 | event_map[event.event_id] = event | |
493 | continue | |
452 | 494 | |
453 | 495 | if event.event_id not in have_persisted: |
454 | 496 | continue |
457 | 499 | |
458 | 500 | outlier_persisted = have_persisted[event.event_id] |
459 | 501 | if not event.internal_metadata.is_outlier() and outlier_persisted: |
502 | # We received a copy of an event that we had already stored as | |
503 | # an outlier in the database. We now have some state at that | |
504 | # event, so we need to update the state_groups table with that state. | |
505 | ||
506 | # insert into the state_group, state_groups_state and | |
507 | # event_to_state_groups tables. | |
460 | 508 | self._store_mult_state_groups_txn(txn, ((event, context),)) |
461 | 509 | |
462 | 510 | metadata_json = encode_json( |
472 | 520 | (metadata_json, event.event_id,) |
473 | 521 | ) |
474 | 522 | |
523 | # Add an entry to the ex_outlier_stream table to replicate the | |
524 | # change in outlier status to our workers. | |
475 | 525 | stream_order = event.internal_metadata.stream_ordering |
476 | 526 | state_group_id = context.state_group or context.new_state_group_id |
477 | 527 | self._simple_insert_txn( |
493 | 543 | (False, event.event_id,) |
494 | 544 | ) |
495 | 545 | |
546 | # Update the event_backward_extremities table now that this | |
547 | # event isn't an outlier any more. | |
496 | 548 | self._update_extremeties(txn, [event]) |
497 | 549 | |
498 | 550 | events_and_contexts = [ |
500 | 552 | ] |
501 | 553 | |
502 | 554 | if not events_and_contexts: |
555 | # Make sure we don't pass an empty list to functions that expect to | |
556 | # be storing at least one element. | |
503 | 557 | return |
504 | 558 | |
505 | self._store_mult_state_groups_txn(txn, events_and_contexts) | |
506 | ||
507 | self._handle_mult_prev_events( | |
508 | txn, | |
509 | events=[event for event, _ in events_and_contexts], | |
510 | ) | |
511 | ||
512 | for event, _ in events_and_contexts: | |
513 | if event.type == EventTypes.Name: | |
514 | self._store_room_name_txn(txn, event) | |
515 | elif event.type == EventTypes.Topic: | |
516 | self._store_room_topic_txn(txn, event) | |
517 | elif event.type == EventTypes.Message: | |
518 | self._store_room_message_txn(txn, event) | |
519 | elif event.type == EventTypes.Redaction: | |
520 | self._store_redaction(txn, event) | |
521 | elif event.type == EventTypes.RoomHistoryVisibility: | |
522 | self._store_history_visibility_txn(txn, event) | |
523 | elif event.type == EventTypes.GuestAccess: | |
524 | self._store_guest_access_txn(txn, event) | |
525 | ||
526 | self._store_room_members_txn( | |
527 | txn, | |
528 | [ | |
529 | event | |
530 | for event, _ in events_and_contexts | |
531 | if event.type == EventTypes.Member | |
532 | ], | |
533 | backfilled=backfilled, | |
534 | ) | |
559 | # From this point onwards the events are only events that we haven't | |
560 | # seen before. | |
535 | 561 | |
536 | 562 | def event_dict(event): |
537 | 563 | return { |
542 | 568 | "redacted_because", |
543 | 569 | ] |
544 | 570 | } |
571 | ||
572 | if delete_existing: | |
573 | # For paranoia reasons, we go and delete all the existing entries | |
574 | # for these events so we can reinsert them. | |
575 | # This gets around any problems with some tables already having | |
576 | # entries. | |
577 | ||
578 | logger.info("Deleting existing") | |
579 | ||
580 | for table in ( | |
581 | "events", | |
582 | "event_auth", | |
583 | "event_json", | |
584 | "event_content_hashes", | |
585 | "event_destinations", | |
586 | "event_edge_hashes", | |
587 | "event_edges", | |
588 | "event_forward_extremities", | |
589 | "event_push_actions", | |
590 | "event_reference_hashes", | |
591 | "event_search", | |
592 | "event_signatures", | |
593 | "event_to_state_groups", | |
594 | "guest_access", | |
595 | "history_visibility", | |
596 | "local_invites", | |
597 | "room_names", | |
598 | "state_events", | |
599 | "rejections", | |
600 | "redactions", | |
601 | "room_memberships", | |
602 | "state_events" | |
603 | ): | |
604 | txn.executemany( | |
605 | "DELETE FROM %s WHERE event_id = ?" % (table,), | |
606 | [(ev.event_id,) for ev, _ in events_and_contexts] | |
607 | ) | |
545 | 608 | |
546 | 609 | self._simple_insert_many_txn( |
547 | 610 | txn, |
575 | 638 | "content": encode_json(event.content).decode("UTF-8"), |
576 | 639 | "origin_server_ts": int(event.origin_server_ts), |
577 | 640 | "received_ts": self._clock.time_msec(), |
641 | "sender": event.sender, | |
642 | "contains_url": ( | |
643 | "url" in event.content | |
644 | and isinstance(event.content["url"], basestring) | |
645 | ), | |
578 | 646 | } |
579 | 647 | for event, _ in events_and_contexts |
580 | 648 | ], |
581 | 649 | ) |
582 | 650 | |
583 | if context.rejected: | |
584 | self._store_rejections_txn( | |
585 | txn, event.event_id, context.rejected | |
586 | ) | |
651 | # Remove the rejected events from the list now that we've added them | |
652 | # to the events table and the events_json table. | |
653 | to_remove = set() | |
654 | for event, context in events_and_contexts: | |
655 | if context.rejected: | |
656 | # Insert the event_id into the rejections table | |
657 | self._store_rejections_txn( | |
658 | txn, event.event_id, context.rejected | |
659 | ) | |
660 | to_remove.add(event) | |
661 | ||
662 | events_and_contexts = [ | |
663 | ec for ec in events_and_contexts if ec[0] not in to_remove | |
664 | ] | |
665 | ||
666 | if not events_and_contexts: | |
667 | # Make sure we don't pass an empty list to functions that expect to | |
668 | # be storing at least one element. | |
669 | return | |
670 | ||
671 | # From this point onwards the events are only ones that weren't rejected. | |
672 | ||
673 | for event, context in events_and_contexts: | |
674 | # Insert all the push actions into the event_push_actions table. | |
675 | if context.push_actions: | |
676 | self._set_push_actions_for_event_and_users_txn( | |
677 | txn, event, context.push_actions | |
678 | ) | |
679 | ||
680 | if event.type == EventTypes.Redaction and event.redacts is not None: | |
681 | # Remove the entries in the event_push_actions table for the | |
682 | # redacted event. | |
683 | self._remove_push_actions_for_event_id_txn( | |
684 | txn, event.room_id, event.redacts | |
685 | ) | |
587 | 686 | |
588 | 687 | self._simple_insert_many_txn( |
589 | 688 | txn, |
599 | 698 | ], |
600 | 699 | ) |
601 | 700 | |
701 | # Insert into the state_groups, state_groups_state, and | |
702 | # event_to_state_groups tables. | |
703 | self._store_mult_state_groups_txn(txn, events_and_contexts) | |
704 | ||
705 | # Update the event_forward_extremities, event_backward_extremities and | |
706 | # event_edges tables. | |
707 | self._handle_mult_prev_events( | |
708 | txn, | |
709 | events=[event for event, _ in events_and_contexts], | |
710 | ) | |
711 | ||
712 | for event, _ in events_and_contexts: | |
713 | if event.type == EventTypes.Name: | |
714 | # Insert into the room_names and event_search tables. | |
715 | self._store_room_name_txn(txn, event) | |
716 | elif event.type == EventTypes.Topic: | |
717 | # Insert into the topics table and event_search table. | |
718 | self._store_room_topic_txn(txn, event) | |
719 | elif event.type == EventTypes.Message: | |
720 | # Insert into the event_search table. | |
721 | self._store_room_message_txn(txn, event) | |
722 | elif event.type == EventTypes.Redaction: | |
723 | # Insert into the redactions table. | |
724 | self._store_redaction(txn, event) | |
725 | elif event.type == EventTypes.RoomHistoryVisibility: | |
726 | # Insert into the event_search table. | |
727 | self._store_history_visibility_txn(txn, event) | |
728 | elif event.type == EventTypes.GuestAccess: | |
729 | # Insert into the event_search table. | |
730 | self._store_guest_access_txn(txn, event) | |
731 | ||
732 | # Insert into the room_memberships table. | |
733 | self._store_room_members_txn( | |
734 | txn, | |
735 | [ | |
736 | event | |
737 | for event, _ in events_and_contexts | |
738 | if event.type == EventTypes.Member | |
739 | ], | |
740 | backfilled=backfilled, | |
741 | ) | |
742 | ||
743 | # Insert event_reference_hashes table. | |
602 | 744 | self._store_event_reference_hashes_txn( |
603 | 745 | txn, [event for event, _ in events_and_contexts] |
604 | 746 | ) |
643 | 785 | ], |
644 | 786 | ) |
645 | 787 | |
788 | # Prefill the event cache | |
646 | 789 | self._add_to_cache(txn, events_and_contexts) |
647 | 790 | |
648 | 791 | if backfilled: |
655 | 798 | # Outlier events shouldn't clobber the current state. |
656 | 799 | continue |
657 | 800 | |
658 | if context.rejected: | |
659 | # If the event failed it's auth checks then it shouldn't | |
660 | # clobbler the current state. | |
661 | continue | |
662 | ||
663 | 801 | txn.call_after( |
664 | 802 | self._get_current_state_for_key.invalidate, |
665 | 803 | (event.room_id, event.type, event.state_key,) |
666 | 804 | ) |
667 | ||
668 | if event.type in [EventTypes.Name, EventTypes.Aliases]: | |
669 | txn.call_after( | |
670 | self.get_room_name_and_aliases.invalidate, | |
671 | (event.room_id,) | |
672 | ) | |
673 | 805 | |
674 | 806 | self._simple_upsert_txn( |
675 | 807 | txn, |
1121 | 1253 | defer.returnValue(ret) |
1122 | 1254 | |
1123 | 1255 | @defer.inlineCallbacks |
1256 | def _background_reindex_fields_sender(self, progress, batch_size): | |
1257 | target_min_stream_id = progress["target_min_stream_id_inclusive"] | |
1258 | max_stream_id = progress["max_stream_id_exclusive"] | |
1259 | rows_inserted = progress.get("rows_inserted", 0) | |
1260 | ||
1261 | INSERT_CLUMP_SIZE = 1000 | |
1262 | ||
1263 | def reindex_txn(txn): | |
1264 | sql = ( | |
1265 | "SELECT stream_ordering, event_id, json FROM events" | |
1266 | " INNER JOIN event_json USING (event_id)" | |
1267 | " WHERE ? <= stream_ordering AND stream_ordering < ?" | |
1268 | " ORDER BY stream_ordering DESC" | |
1269 | " LIMIT ?" | |
1270 | ) | |
1271 | ||
1272 | txn.execute(sql, (target_min_stream_id, max_stream_id, batch_size)) | |
1273 | ||
1274 | rows = txn.fetchall() | |
1275 | if not rows: | |
1276 | return 0 | |
1277 | ||
1278 | min_stream_id = rows[-1][0] | |
1279 | ||
1280 | update_rows = [] | |
1281 | for row in rows: | |
1282 | try: | |
1283 | event_id = row[1] | |
1284 | event_json = json.loads(row[2]) | |
1285 | sender = event_json["sender"] | |
1286 | content = event_json["content"] | |
1287 | ||
1288 | contains_url = "url" in content | |
1289 | if contains_url: | |
1290 | contains_url &= isinstance(content["url"], basestring) | |
1291 | except (KeyError, AttributeError): | |
1292 | # If the event is missing a necessary field then | |
1293 | # skip over it. | |
1294 | continue | |
1295 | ||
1296 | update_rows.append((sender, contains_url, event_id)) | |
1297 | ||
1298 | sql = ( | |
1299 | "UPDATE events SET sender = ?, contains_url = ? WHERE event_id = ?" | |
1300 | ) | |
1301 | ||
1302 | for index in range(0, len(update_rows), INSERT_CLUMP_SIZE): | |
1303 | clump = update_rows[index:index + INSERT_CLUMP_SIZE] | |
1304 | txn.executemany(sql, clump) | |
1305 | ||
1306 | progress = { | |
1307 | "target_min_stream_id_inclusive": target_min_stream_id, | |
1308 | "max_stream_id_exclusive": min_stream_id, | |
1309 | "rows_inserted": rows_inserted + len(rows) | |
1310 | } | |
1311 | ||
1312 | self._background_update_progress_txn( | |
1313 | txn, self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, progress | |
1314 | ) | |
1315 | ||
1316 | return len(rows) | |
1317 | ||
1318 | result = yield self.runInteraction( | |
1319 | self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME, reindex_txn | |
1320 | ) | |
1321 | ||
1322 | if not result: | |
1323 | yield self._end_background_update(self.EVENT_FIELDS_SENDER_URL_UPDATE_NAME) | |
1324 | ||
1325 | defer.returnValue(result) | |
1326 | ||
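For reference, the progress dict that `_background_reindex_fields_sender` receives and re-persists has the shape seeded by the schema delta later in this diff: the exclusive upper bound walks downwards towards the fixed inclusive minimum until no rows remain, at which point the handler returns 0 and the update is ended. Values below are illustrative only:

    progress = {
        "target_min_stream_id_inclusive": 0,    # fixed lower bound
        "max_stream_id_exclusive": 100000,      # shrinks after every batch
        "rows_inserted": 0,                     # running total
    }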
1327 | @defer.inlineCallbacks | |
1124 | 1328 | def _background_reindex_origin_server_ts(self, progress, batch_size): |
1125 | 1329 | target_min_stream_id = progress["target_min_stream_id_inclusive"] |
1126 | 1330 | max_stream_id = progress["max_stream_id_exclusive"] |
1287 | 1491 | ) |
1288 | 1492 | return self.runInteraction("get_all_new_events", get_all_new_events_txn) |
1289 | 1493 | |
1494 | def delete_old_state(self, room_id, topological_ordering): | |
1495 | return self.runInteraction( | |
1496 | "delete_old_state", | |
1497 | self._delete_old_state_txn, room_id, topological_ordering | |
1498 | ) | |
1499 | ||
1500 | def _delete_old_state_txn(self, txn, room_id, topological_ordering): | |
1501 | """Deletes old room state | |
1502 | """ | |
1503 | ||
1504 | # Tables that should be pruned: | |
1505 | # event_auth | |
1506 | # event_backward_extremities | |
1507 | # event_content_hashes | |
1508 | # event_destinations | |
1509 | # event_edge_hashes | |
1510 | # event_edges | |
1511 | # event_forward_extremities | |
1512 | # event_json | |
1513 | # event_push_actions | |
1514 | # event_reference_hashes | |
1515 | # event_search | |
1516 | # event_signatures | |
1517 | # event_to_state_groups | |
1518 | # events | |
1519 | # rejections | |
1520 | # room_depth | |
1521 | # state_groups | |
1522 | # state_groups_state | |
1523 | ||
1524 | # First ensure that we're not about to delete all the forward extremities | |
1525 | txn.execute( | |
1526 | "SELECT e.event_id, e.depth FROM events as e " | |
1527 | "INNER JOIN event_forward_extremities as f " | |
1528 | "ON e.event_id = f.event_id " | |
1529 | "AND e.room_id = f.room_id " | |
1530 | "WHERE f.room_id = ?", | |
1531 | (room_id,) | |
1532 | ) | |
1533 | rows = txn.fetchall() | |
1534 | max_depth = max(row[1] for row in rows)  # rows are (event_id, depth) | |
1535 | ||
1536 | if max_depth <= topological_ordering: | |
1537 | # We need to ensure we don't delete all the events from the database | |
1538 | # otherwise we wouldn't be able to send any events (due to not | |
1539 | # having any forward extremities left) | |
1540 | raise SynapseError( | |
1541 | 400, "topological_ordering is greater than forward extremities" | |
1542 | ) | |
1543 | ||
1544 | txn.execute( | |
1545 | "SELECT event_id, state_key FROM events" | |
1546 | " LEFT JOIN state_events USING (room_id, event_id)" | |
1547 | " WHERE room_id = ? AND topological_ordering < ?", | |
1548 | (room_id, topological_ordering,) | |
1549 | ) | |
1550 | event_rows = txn.fetchall() | |
1551 | ||
1552 | # We calculate the new entries for the backward extremities by finding | |
1553 | # all events that point to events that are to be purged | |
1554 | txn.execute( | |
1555 | "SELECT DISTINCT e.event_id FROM events as e" | |
1556 | " INNER JOIN event_edges as ed ON e.event_id = ed.prev_event_id" | |
1557 | " INNER JOIN events as e2 ON e2.event_id = ed.event_id" | |
1558 | " WHERE e.room_id = ? AND e.topological_ordering < ?" | |
1559 | " AND e2.topological_ordering >= ?", | |
1560 | (room_id, topological_ordering, topological_ordering) | |
1561 | ) | |
1562 | new_backwards_extrems = txn.fetchall() | |
1563 | ||
1564 | txn.execute( | |
1565 | "DELETE FROM event_backward_extremities WHERE room_id = ?", | |
1566 | (room_id,) | |
1567 | ) | |
1568 | ||
1569 | # Update backward extremities | |
1570 | txn.executemany( | |
1571 | "INSERT INTO event_backward_extremities (room_id, event_id)" | |
1572 | " VALUES (?, ?)", | |
1573 | [ | |
1574 | (room_id, event_id) for event_id, in new_backwards_extrems | |
1575 | ] | |
1576 | ) | |
1577 | ||
1578 | # Get all state groups that are only referenced by events that are | |
1579 | # to be deleted. | |
1580 | txn.execute( | |
1581 | "SELECT state_group FROM event_to_state_groups" | |
1582 | " INNER JOIN events USING (event_id)" | |
1583 | " WHERE state_group IN (" | |
1584 | " SELECT DISTINCT state_group FROM events" | |
1585 | " INNER JOIN event_to_state_groups USING (event_id)" | |
1586 | " WHERE room_id = ? AND topological_ordering < ?" | |
1587 | " )" | |
1588 | " GROUP BY state_group HAVING MAX(topological_ordering) < ?", | |
1589 | (room_id, topological_ordering, topological_ordering) | |
1590 | ) | |
1591 | state_rows = txn.fetchall() | |
1592 | txn.executemany( | |
1593 | "DELETE FROM state_groups_state WHERE state_group = ?", | |
1594 | state_rows | |
1595 | ) | |
1596 | txn.executemany( | |
1597 | "DELETE FROM state_groups WHERE id = ?", | |
1598 | state_rows | |
1599 | ) | |
1600 | # Delete the event-to-state-group mappings for all purged events | |
1601 | txn.executemany( | |
1602 | "DELETE FROM event_to_state_groups WHERE event_id = ?", | |
1603 | [(event_id,) for event_id, _ in event_rows] | |
1604 | ) | |
1605 | ||
1606 | txn.execute( | |
1607 | "UPDATE room_depth SET min_depth = ? WHERE room_id = ?", | |
1608 | (topological_ordering, room_id,) | |
1609 | ) | |
1610 | ||
1611 | # Delete all remote non-state events | |
1612 | to_delete = [ | |
1613 | (event_id,) for event_id, state_key in event_rows | |
1614 | if state_key is None and not self.hs.is_mine_id(event_id) | |
1615 | ] | |
1616 | for table in ( | |
1617 | "events", | |
1618 | "event_json", | |
1619 | "event_auth", | |
1620 | "event_content_hashes", | |
1621 | "event_destinations", | |
1622 | "event_edge_hashes", | |
1623 | "event_edges", | |
1624 | "event_forward_extremities", | |
1625 | "event_push_actions", | |
1626 | "event_reference_hashes", | |
1627 | "event_search", | |
1628 | "event_signatures", | |
1629 | "rejections", | |
1630 | ): | |
1631 | txn.executemany( | |
1632 | "DELETE FROM %s WHERE event_id = ?" % (table,), | |
1633 | to_delete | |
1634 | ) | |
1635 | ||
1636 | txn.executemany( | |
1637 | "DELETE FROM events WHERE event_id = ?", | |
1638 | to_delete | |
1639 | ) | |
1640 | # Mark all state and own events as outliers | |
1641 | txn.executemany( | |
1642 | "UPDATE events SET outlier = ?" | |
1643 | " WHERE event_id = ?", | |
1644 | [ | |
1645 | (True, event_id,) for event_id, state_key in event_rows | |
1646 | if state_key is not None or self.hs.is_mine_id(event_id) | |
1647 | ] | |
1648 | ) | |
1649 | ||
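A minimal sketch of driving the purge, e.g. from the purge-history admin API added in this release (the wrapper shown is hypothetical; only `delete_old_state` comes from the diff above):

    from twisted.internet import defer

    @defer.inlineCallbacks
    def purge_history(store, room_id, topological_ordering):
        # Deletes events strictly before topological_ordering; raises
        # SynapseError(400) if that would remove every forward extremity.
        yield store.delete_old_state(room_id, topological_ordering)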
1290 | 1650 | |
1291 | 1651 | AllNewEventsResult = namedtuple("AllNewEventsResult", [ |
1292 | 1652 | "new_forward_events", "new_backfill_events", |
21 | 21 | from signedjson.key import decode_verify_key_bytes |
22 | 22 | import hashlib |
23 | 23 | |
24 | import logging | |
25 | ||
26 | logger = logging.getLogger(__name__) | |
27 | ||
24 | 28 | |
25 | 29 | class KeyStore(SQLBaseStore): |
26 | 30 | """Persistence for signature verification keys and tls X.509 certificates |
73 | 77 | ) |
74 | 78 | |
75 | 79 | @cachedInlineCallbacks() |
76 | def get_all_server_verify_keys(self, server_name): | |
77 | rows = yield self._simple_select_list( | |
80 | def _get_server_verify_key(self, server_name, key_id): | |
81 | verify_key_bytes = yield self._simple_select_one_onecol( | |
78 | 82 | table="server_signature_keys", |
79 | 83 | keyvalues={ |
80 | 84 | "server_name": server_name, |
81 | }, | |
82 | retcols=["key_id", "verify_key"], | |
83 | desc="get_all_server_verify_keys", | |
84 | ) | |
85 | ||
86 | defer.returnValue({ | |
87 | row["key_id"]: decode_verify_key_bytes( | |
88 | row["key_id"], str(row["verify_key"]) | |
89 | ) | |
90 | for row in rows | |
91 | }) | |
85 | "key_id": key_id, | |
86 | }, | |
87 | retcol="verify_key", | |
88 | desc="_get_server_verify_key", | |
89 | allow_none=True, | |
90 | ) | |
91 | ||
92 | if verify_key_bytes: | |
93 | defer.returnValue(decode_verify_key_bytes( | |
94 | key_id, str(verify_key_bytes) | |
95 | )) | |
92 | 96 | |
93 | 97 | @defer.inlineCallbacks |
94 | 98 | def get_server_verify_keys(self, server_name, key_ids): |
100 | 104 | Returns: |
101 | 105 | (list of VerifyKey): The verification keys. |
102 | 106 | """ |
103 | keys = yield self.get_all_server_verify_keys(server_name) | |
104 | defer.returnValue({ | |
105 | k: keys[k] | |
106 | for k in key_ids | |
107 | if k in keys and keys[k] | |
108 | }) | |
107 | keys = {} | |
108 | for key_id in key_ids: | |
109 | key = yield self._get_server_verify_key(server_name, key_id) | |
110 | if key: | |
111 | keys[key_id] = key | |
112 | defer.returnValue(keys) | |
109 | 113 | |
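With the change above, each key id is cached independently via `_get_server_verify_key` rather than one cache entry holding every key for a server. A usage sketch (server name and key ids are made up):

    from twisted.internet import defer

    @defer.inlineCallbacks
    def lookup(store):
        # Returns a dict of key_id -> VerifyKey; unknown ids are simply omitted.
        keys = yield store.get_server_verify_keys(
            "example.org", ["ed25519:key1", "ed25519:key2"],
        )
        defer.returnValue(keys)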
110 | 114 | @defer.inlineCallbacks |
111 | 115 | def store_server_verify_key(self, server_name, from_server, time_now_ms, |
131 | 135 | }, |
132 | 136 | desc="store_server_verify_key", |
133 | 137 | ) |
134 | ||
135 | self.get_all_server_verify_keys.invalidate((server_name,)) | |
136 | 138 | |
137 | 139 | def store_server_keys_json(self, server_name, key_id, from_server, |
138 | 140 | ts_now_ms, ts_expires_ms, key_json_bytes): |
156 | 156 | "created_ts": time_now_ms, |
157 | 157 | "upload_name": upload_name, |
158 | 158 | "filesystem_id": filesystem_id, |
159 | "last_access_ts": time_now_ms, | |
159 | 160 | }, |
160 | 161 | desc="store_cached_remote_media", |
161 | 162 | ) |
163 | ||
164 | def update_cached_last_access_time(self, origin_id_tuples, time_ts): | |
165 | def update_cache_txn(txn): | |
166 | sql = ( | |
167 | "UPDATE remote_media_cache SET last_access_ts = ?" | |
168 | " WHERE media_origin = ? AND media_id = ?" | |
169 | ) | |
170 | ||
171 | txn.executemany(sql, ( | |
172 | (time_ts, media_origin, media_id) | |
173 | for media_origin, media_id in origin_id_tuples | |
174 | )) | |
175 | ||
176 | return self.runInteraction("update_cached_last_access_time", update_cache_txn) | |
162 | 177 | |
163 | 178 | def get_remote_media_thumbnails(self, origin, media_id): |
164 | 179 | return self._simple_select_list( |
189 | 204 | }, |
190 | 205 | desc="store_remote_media_thumbnail", |
191 | 206 | ) |
207 | ||
208 | def get_remote_media_before(self, before_ts): | |
209 | sql = ( | |
210 | "SELECT media_origin, media_id, filesystem_id" | |
211 | " FROM remote_media_cache" | |
212 | " WHERE last_access_ts < ?" | |
213 | ) | |
214 | ||
215 | return self._execute( | |
216 | "get_remote_media_before", self.cursor_to_dict, sql, before_ts | |
217 | ) | |
218 | ||
219 | def delete_remote_media(self, media_origin, media_id): | |
220 | def delete_remote_media_txn(txn): | |
221 | self._simple_delete_txn( | |
222 | txn, | |
223 | "remote_media_cache", | |
224 | keyvalues={ | |
225 | "media_origin": media_origin, "media_id": media_id | |
226 | }, | |
227 | ) | |
228 | self._simple_delete_txn( | |
229 | txn, | |
230 | "remote_media_cache_thumbnails", | |
231 | keyvalues={ | |
232 | "media_origin": media_origin, "media_id": media_id | |
233 | }, | |
234 | ) | |
235 | return self.runInteraction("delete_remote_media", delete_remote_media_txn) |
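Together these helpers back the purge_media_cache admin API from this release: find remote media whose last_access_ts predates a cutoff, then delete the rows (the on-disk files are handled elsewhere). A sketch of one purge pass, assuming `store` is this media store class:

    from twisted.internet import defer

    @defer.inlineCallbacks
    def purge_remote_media(store, before_ts):
        old_media = yield store.get_remote_media_before(before_ts)
        for media in old_media:
            # Each row is a dict with media_origin, media_id, filesystem_id.
            yield store.delete_remote_media(
                media["media_origin"], media["media_id"],
            )
        defer.returnValue(len(old_media))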
24 | 24 | |
25 | 25 | # Remember to update this number every time a change is made to database |
26 | 26 | # schema files, so the users will be informed on server restarts. |
27 | SCHEMA_VERSION = 32 | |
27 | SCHEMA_VERSION = 33 | |
28 | 28 | |
29 | 29 | dir_path = os.path.abspath(os.path.dirname(__file__)) |
30 | 30 |
17 | 17 | from twisted.internet import defer |
18 | 18 | |
19 | 19 | from synapse.api.errors import StoreError, Codes |
20 | ||
21 | from ._base import SQLBaseStore | |
20 | from synapse.storage import background_updates | |
22 | 21 | from synapse.util.caches.descriptors import cached, cachedInlineCallbacks |
23 | 22 | |
24 | 23 | |
25 | class RegistrationStore(SQLBaseStore): | |
24 | class RegistrationStore(background_updates.BackgroundUpdateStore): | |
26 | 25 | |
27 | 26 | def __init__(self, hs): |
28 | 27 | super(RegistrationStore, self).__init__(hs) |
29 | 28 | |
30 | 29 | self.clock = hs.get_clock() |
31 | 30 | |
32 | @defer.inlineCallbacks | |
33 | def add_access_token_to_user(self, user_id, token): | |
31 | self.register_background_index_update( | |
32 | "access_tokens_device_index", | |
33 | index_name="access_tokens_device_id", | |
34 | table="access_tokens", | |
35 | columns=["user_id", "device_id"], | |
36 | ) | |
37 | ||
38 | self.register_background_index_update( | |
39 | "refresh_tokens_device_index", | |
40 | index_name="refresh_tokens_device_id", | |
41 | table="refresh_tokens", | |
42 | columns=["user_id", "device_id"], | |
43 | ) | |
44 | ||
45 | @defer.inlineCallbacks | |
46 | def add_access_token_to_user(self, user_id, token, device_id=None): | |
34 | 47 | """Adds an access token for the given user. |
35 | 48 | |
36 | 49 | Args: |
37 | 50 | user_id (str): The user ID. |
38 | 51 | token (str): The new access token to add. |
52 | device_id (str): ID of the device to associate with the access | |
53 | token | |
39 | 54 | Raises: |
40 | 55 | StoreError if there was a problem adding this. |
41 | 56 | """ |
46 | 61 | { |
47 | 62 | "id": next_id, |
48 | 63 | "user_id": user_id, |
49 | "token": token | |
64 | "token": token, | |
65 | "device_id": device_id, | |
50 | 66 | }, |
51 | 67 | desc="add_access_token_to_user", |
52 | 68 | ) |
53 | 69 | |
54 | 70 | @defer.inlineCallbacks |
55 | def add_refresh_token_to_user(self, user_id, token): | |
71 | def add_refresh_token_to_user(self, user_id, token, device_id=None): | |
56 | 72 | """Adds a refresh token for the given user. |
57 | 73 | |
58 | 74 | Args: |
59 | 75 | user_id (str): The user ID. |
60 | 76 | token (str): The new refresh token to add. |
77 | device_id (str): ID of the device to associate with the access | |
78 | token | |
61 | 79 | Raises: |
62 | 80 | StoreError if there was a problem adding this. |
63 | 81 | """ |
68 | 86 | { |
69 | 87 | "id": next_id, |
70 | 88 | "user_id": user_id, |
71 | "token": token | |
89 | "token": token, | |
90 | "device_id": device_id, | |
72 | 91 | }, |
73 | 92 | desc="add_refresh_token_to_user", |
74 | 93 | ) |
75 | 94 | |
76 | 95 | @defer.inlineCallbacks |
77 | def register(self, user_id, token, password_hash, | |
96 | def register(self, user_id, token=None, password_hash=None, | |
78 | 97 | was_guest=False, make_guest=False, appservice_id=None, |
79 | create_profile_with_localpart=None): | |
98 | create_profile_with_localpart=None, admin=False): | |
80 | 99 | """Attempts to register an account. |
81 | 100 | |
82 | 101 | Args: |
83 | 102 | user_id (str): The desired user ID to register. |
84 | token (str): The desired access token to use for this user. | |
103 | token (str): The desired access token to use for this user. If this | |
104 | is not None, the given access token is associated with the user | |
105 | id. | |
85 | 106 | password_hash (str): Optional. The password hash for this user. |
86 | 107 | was_guest (bool): Optional. Whether this is a guest account being |
87 | 108 | upgraded to a non-guest account. |
103 | 124 | make_guest, |
104 | 125 | appservice_id, |
105 | 126 | create_profile_with_localpart, |
127 | admin | |
106 | 128 | ) |
107 | 129 | self.get_user_by_id.invalidate((user_id,)) |
108 | 130 | self.is_guest.invalidate((user_id,)) |
117 | 139 | make_guest, |
118 | 140 | appservice_id, |
119 | 141 | create_profile_with_localpart, |
142 | admin, | |
120 | 143 | ): |
121 | 144 | now = int(self.clock.time()) |
122 | 145 | |
124 | 147 | |
125 | 148 | try: |
126 | 149 | if was_guest: |
127 | txn.execute("UPDATE users SET" | |
128 | " password_hash = ?," | |
129 | " upgrade_ts = ?," | |
130 | " is_guest = ?" | |
131 | " WHERE name = ?", | |
132 | [password_hash, now, 1 if make_guest else 0, user_id]) | |
150 | # Ensure that the guest user actually exists | |
151 | # ``allow_none=False`` makes this raise an exception | |
152 | # if the row isn't in the database. | |
153 | self._simple_select_one_txn( | |
154 | txn, | |
155 | "users", | |
156 | keyvalues={ | |
157 | "name": user_id, | |
158 | "is_guest": 1, | |
159 | }, | |
160 | retcols=("name",), | |
161 | allow_none=False, | |
162 | ) | |
163 | ||
164 | self._simple_update_one_txn( | |
165 | txn, | |
166 | "users", | |
167 | keyvalues={ | |
168 | "name": user_id, | |
169 | "is_guest": 1, | |
170 | }, | |
171 | updatevalues={ | |
172 | "password_hash": password_hash, | |
173 | "upgrade_ts": now, | |
174 | "is_guest": 1 if make_guest else 0, | |
175 | "appservice_id": appservice_id, | |
176 | "admin": 1 if admin else 0, | |
177 | } | |
178 | ) | |
133 | 179 | else: |
134 | txn.execute("INSERT INTO users " | |
135 | "(" | |
136 | " name," | |
137 | " password_hash," | |
138 | " creation_ts," | |
139 | " is_guest," | |
140 | " appservice_id" | |
141 | ") " | |
142 | "VALUES (?,?,?,?,?)", | |
143 | [ | |
144 | user_id, | |
145 | password_hash, | |
146 | now, | |
147 | 1 if make_guest else 0, | |
148 | appservice_id, | |
149 | ]) | |
180 | self._simple_insert_txn( | |
181 | txn, | |
182 | "users", | |
183 | values={ | |
184 | "name": user_id, | |
185 | "password_hash": password_hash, | |
186 | "creation_ts": now, | |
187 | "is_guest": 1 if make_guest else 0, | |
188 | "appservice_id": appservice_id, | |
189 | "admin": 1 if admin else 0, | |
190 | } | |
191 | ) | |
150 | 192 | except self.database_engine.module.IntegrityError: |
151 | 193 | raise StoreError( |
152 | 194 | 400, "User ID already taken.", errcode=Codes.USER_IN_USE |
208 | 250 | self.get_user_by_id.invalidate((user_id,)) |
209 | 251 | |
210 | 252 | @defer.inlineCallbacks |
211 | def user_delete_access_tokens(self, user_id, except_token_ids=[]): | |
212 | def f(txn): | |
213 | sql = "SELECT token FROM access_tokens WHERE user_id = ?" | |
253 | def user_delete_access_tokens(self, user_id, except_token_ids=[], | |
254 | device_id=None, | |
255 | delete_refresh_tokens=False): | |
256 | """ | |
257 | Invalidate access/refresh tokens belonging to a user | |
258 | ||
259 | Args: | |
260 | user_id (str): ID of user the tokens belong to | |
261 | except_token_ids (list[str]): list of access_tokens which should | |
262 | *not* be deleted | |
263 | device_id (str|None): ID of device the tokens are associated with. | |
264 | If None, tokens associated with any device (or no device) will | |
265 | be deleted | |
266 | delete_refresh_tokens (bool): True to delete refresh tokens as | |
267 | well as access tokens. | |
268 | Returns: | |
269 | defer.Deferred: | |
270 | """ | |
271 | def f(txn, table, except_tokens, call_after_delete): | |
272 | sql = "SELECT token FROM %s WHERE user_id = ?" % table | |
214 | 273 | clauses = [user_id] |
215 | 274 | |
216 | if except_token_ids: | |
275 | if device_id is not None: | |
276 | sql += " AND device_id = ?" | |
277 | clauses.append(device_id) | |
278 | ||
279 | if except_tokens: | |
217 | 280 | sql += " AND id NOT IN (%s)" % ( |
218 | ",".join(["?" for _ in except_token_ids]), | |
281 | ",".join(["?" for _ in except_tokens]), | |
219 | 282 | ) |
220 | clauses += except_token_ids | |
283 | clauses += except_tokens | |
221 | 284 | |
222 | 285 | txn.execute(sql, clauses) |
223 | 286 | |
226 | 289 | n = 100 |
227 | 290 | chunks = [rows[i:i + n] for i in xrange(0, len(rows), n)] |
228 | 291 | for chunk in chunks: |
229 | for row in chunk: | |
230 | txn.call_after(self.get_user_by_access_token.invalidate, (row[0],)) | |
292 | if call_after_delete: | |
293 | for row in chunk: | |
294 | txn.call_after(call_after_delete, (row[0],)) | |
231 | 295 | |
232 | 296 | txn.execute( |
233 | "DELETE FROM access_tokens WHERE token in (%s)" % ( | |
297 | "DELETE FROM %s WHERE token in (%s)" % ( | |
298 | table, | |
234 | 299 | ",".join(["?" for _ in chunk]), |
235 | 300 | ), [r[0] for r in chunk] |
236 | 301 | ) |
237 | 302 | |
238 | yield self.runInteraction("user_delete_access_tokens", f) | |
303 | # delete refresh tokens first, to stop new access tokens being | |
304 | # allocated while our backs are turned | |
305 | if delete_refresh_tokens: | |
306 | yield self.runInteraction( | |
307 | "user_delete_access_tokens", f, | |
308 | table="refresh_tokens", | |
309 | except_tokens=[], | |
310 | call_after_delete=None, | |
311 | ) | |
312 | ||
313 | yield self.runInteraction( | |
314 | "user_delete_access_tokens", f, | |
315 | table="access_tokens", | |
316 | except_tokens=except_token_ids, | |
317 | call_after_delete=self.get_user_by_access_token.invalidate, | |
318 | ) | |
239 | 319 | |
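A sketch of the per-device logout flow this enables (caller name is hypothetical; the keyword arguments match the signature above):

    from twisted.internet import defer

    @defer.inlineCallbacks
    def log_out_device(store, user_id, device_id, current_token_id):
        # Refresh tokens are dropped first (handled internally), then access
        # tokens for this device, sparing the token used for this request.
        yield store.user_delete_access_tokens(
            user_id,
            except_token_ids=[current_token_id],
            device_id=device_id,
            delete_refresh_tokens=True,
        )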
240 | 320 | def delete_access_token(self, access_token): |
241 | 321 | def f(txn): |
258 | 338 | Args: |
259 | 339 | token (str): The access token of a user. |
260 | 340 | Returns: |
261 | dict: Including the name (user_id) and the ID of their access token. | |
262 | Raises: | |
263 | StoreError if no user was found. | |
341 | defer.Deferred: None, if the token did not match, otherwise dict | |
342 | including the keys `name`, `is_guest`, `device_id`, `token_id`. | |
264 | 343 | """ |
265 | 344 | return self.runInteraction( |
266 | 345 | "get_user_by_access_token", |
269 | 348 | ) |
270 | 349 | |
271 | 350 | def exchange_refresh_token(self, refresh_token, token_generator): |
272 | """Exchange a refresh token for a new access token and refresh token. | |
351 | """Exchange a refresh token for a new one. | |
273 | 352 | |
274 | 353 | Doing so invalidates the old refresh token - refresh tokens are single |
275 | 354 | use. |
276 | 355 | |
277 | 356 | Args: |
278 | token (str): The refresh token of a user. | |
357 | refresh_token (str): The refresh token of a user. | |
279 | 358 | token_generator (fn: str -> str): Function which, when given a |
280 | 359 | user ID, returns a unique refresh token for that user. This |
281 | 360 | function must never return the same value twice. |
282 | 361 | Returns: |
283 | tuple of (user_id, refresh_token) | |
362 | tuple of (user_id, new_refresh_token, device_id) | |
284 | 363 | Raises: |
285 | 364 | StoreError if no user was found with that refresh token. |
286 | 365 | """ |
292 | 371 | ) |
293 | 372 | |
294 | 373 | def _exchange_refresh_token(self, txn, old_token, token_generator): |
295 | sql = "SELECT user_id FROM refresh_tokens WHERE token = ?" | |
374 | sql = "SELECT user_id, device_id FROM refresh_tokens WHERE token = ?" | |
296 | 375 | txn.execute(sql, (old_token,)) |
297 | 376 | rows = self.cursor_to_dict(txn) |
298 | 377 | if not rows: |
299 | 378 | raise StoreError(403, "Did not recognize refresh token") |
300 | 379 | user_id = rows[0]["user_id"] |
380 | device_id = rows[0]["device_id"] | |
301 | 381 | |
302 | 382 | # TODO(danielwh): Maybe perform a validation on the macaroon that |
303 | 383 | # macaroon.user_id == user_id. |
306 | 386 | sql = "UPDATE refresh_tokens SET token = ? WHERE token = ?" |
307 | 387 | txn.execute(sql, (new_token, old_token,)) |
308 | 388 | |
309 | return user_id, new_token | |
389 | return user_id, new_token, device_id | |
310 | 390 | |
311 | 391 | @defer.inlineCallbacks |
312 | 392 | def is_server_admin(self, user): |
334 | 414 | |
335 | 415 | def _query_for_auth(self, txn, token): |
336 | 416 | sql = ( |
337 | "SELECT users.name, users.is_guest, access_tokens.id as token_id" | |
417 | "SELECT users.name, users.is_guest, access_tokens.id as token_id," | |
418 | " access_tokens.device_id" | |
338 | 419 | " FROM users" |
339 | 420 | " INNER JOIN access_tokens on users.name = access_tokens.user_id" |
340 | 421 | " WHERE token = ?" |
383 | 464 | defer.returnValue(ret['user_id']) |
384 | 465 | defer.returnValue(None) |
385 | 466 | |
467 | def user_delete_threepids(self, user_id): | |
468 | return self._simple_delete( | |
469 | "user_threepids", | |
470 | keyvalues={ | |
471 | "user_id": user_id, | |
472 | }, | |
473 | desc="user_delete_threepids", | |
474 | ) | |
475 | ||
386 | 476 | @defer.inlineCallbacks |
387 | 477 | def count_all_users(self): |
388 | 478 | """Counts all users registered on the homeserver.""" |
17 | 17 | from synapse.api.errors import StoreError |
18 | 18 | |
19 | 19 | from ._base import SQLBaseStore |
20 | from synapse.util.caches.descriptors import cachedInlineCallbacks | |
21 | 20 | from .engines import PostgresEngine, Sqlite3Engine |
22 | 21 | |
23 | 22 | import collections |
191 | 190 | # This should be unreachable. |
192 | 191 | raise Exception("Unrecognized database engine") |
193 | 192 | |
194 | @cachedInlineCallbacks() | |
195 | def get_room_name_and_aliases(self, room_id): | |
196 | def get_room_name(txn): | |
197 | sql = ( | |
198 | "SELECT name FROM room_names" | |
199 | " INNER JOIN current_state_events USING (room_id, event_id)" | |
200 | " WHERE room_id = ?" | |
201 | " LIMIT 1" | |
202 | ) | |
203 | ||
204 | txn.execute(sql, (room_id,)) | |
205 | rows = txn.fetchall() | |
206 | if rows: | |
207 | return rows[0][0] | |
208 | else: | |
209 | return None | |
210 | ||
212 | ||
213 | def get_room_aliases(txn): | |
214 | sql = ( | |
215 | "SELECT content FROM current_state_events" | |
216 | " INNER JOIN events USING (room_id, event_id)" | |
217 | " WHERE room_id = ?" | |
218 | ) | |
219 | txn.execute(sql, (room_id,)) | |
220 | return [row[0] for row in txn.fetchall()] | |
221 | ||
222 | name = yield self.runInteraction("get_room_name", get_room_name) | |
223 | alias_contents = yield self.runInteraction("get_room_aliases", get_room_aliases) | |
224 | ||
225 | aliases = [] | |
226 | ||
227 | for c in alias_contents: | |
228 | try: | |
229 | content = json.loads(c) | |
230 | except: | |
231 | continue | |
232 | ||
233 | aliases.extend(content.get('aliases', [])) | |
234 | ||
235 | defer.returnValue((name, aliases)) | |
236 | ||
237 | 193 | def add_event_report(self, room_id, event_id, user_id, reason, content, |
238 | 194 | received_ts): |
239 | 195 | next_id = self._event_reports_id_gen.get_next() |
0 | /* Copyright 2016 OpenMarket Ltd | |
1 | * | |
2 | * Licensed under the Apache License, Version 2.0 (the "License"); | |
3 | * you may not use this file except in compliance with the License. | |
4 | * You may obtain a copy of the License at | |
5 | * | |
6 | * http://www.apache.org/licenses/LICENSE-2.0 | |
7 | * | |
8 | * Unless required by applicable law or agreed to in writing, software | |
9 | * distributed under the License is distributed on an "AS IS" BASIS, | |
10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
11 | * See the License for the specific language governing permissions and | |
12 | * limitations under the License. | |
13 | */ | |
14 | ||
15 | INSERT INTO background_updates (update_name, progress_json) VALUES | |
16 | ('access_tokens_device_index', '{}'); |
0 | /* Copyright 2016 OpenMarket Ltd | |
1 | * | |
2 | * Licensed under the Apache License, Version 2.0 (the "License"); | |
3 | * you may not use this file except in compliance with the License. | |
4 | * You may obtain a copy of the License at | |
5 | * | |
6 | * http://www.apache.org/licenses/LICENSE-2.0 | |
7 | * | |
8 | * Unless required by applicable law or agreed to in writing, software | |
9 | * distributed under the License is distributed on an "AS IS" BASIS, | |
10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
11 | * See the License for the specific language governing permissions and | |
12 | * limitations under the License. | |
13 | */ | |
14 | ||
15 | CREATE TABLE devices ( | |
16 | user_id TEXT NOT NULL, | |
17 | device_id TEXT NOT NULL, | |
18 | display_name TEXT, | |
19 | CONSTRAINT device_uniqueness UNIQUE (user_id, device_id) | |
20 | ); |
0 | /* Copyright 2016 OpenMarket Ltd | |
1 | * | |
2 | * Licensed under the Apache License, Version 2.0 (the "License"); | |
3 | * you may not use this file except in compliance with the License. | |
4 | * You may obtain a copy of the License at | |
5 | * | |
6 | * http://www.apache.org/licenses/LICENSE-2.0 | |
7 | * | |
8 | * Unless required by applicable law or agreed to in writing, software | |
9 | * distributed under the License is distributed on an "AS IS" BASIS, | |
10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
11 | * See the License for the specific language governing permissions and | |
12 | * limitations under the License. | |
13 | */ | |
14 | ||
15 | -- make sure that we have a device record for each set of E2E keys, so that the | |
16 | -- user can delete them if they like. | |
17 | INSERT INTO devices | |
18 | SELECT user_id, device_id, NULL FROM e2e_device_keys_json; |
0 | /* Copyright 2016 OpenMarket Ltd | |
1 | * | |
2 | * Licensed under the Apache License, Version 2.0 (the "License"); | |
3 | * you may not use this file except in compliance with the License. | |
4 | * You may obtain a copy of the License at | |
5 | * | |
6 | * http://www.apache.org/licenses/LICENSE-2.0 | |
7 | * | |
8 | * Unless required by applicable law or agreed to in writing, software | |
9 | * distributed under the License is distributed on an "AS IS" BASIS, | |
10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
11 | * See the License for the specific language governing permissions and | |
12 | * limitations under the License. | |
13 | */ | |
14 | ||
15 | -- a previous version of the "devices_for_e2e_keys" delta set all the device | |
16 | -- names to "unknown device". This wasn't terribly helpful | |
17 | UPDATE devices | |
18 | SET display_name = NULL | |
19 | WHERE display_name = 'unknown device'; |
0 | # Copyright 2016 OpenMarket Ltd | |
1 | # | |
2 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
3 | # you may not use this file except in compliance with the License. | |
4 | # You may obtain a copy of the License at | |
5 | # | |
6 | # http://www.apache.org/licenses/LICENSE-2.0 | |
7 | # | |
8 | # Unless required by applicable law or agreed to in writing, software | |
9 | # distributed under the License is distributed on an "AS IS" BASIS, | |
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
11 | # See the License for the specific language governing permissions and | |
12 | # limitations under the License. | |
13 | ||
14 | from synapse.storage.prepare_database import get_statements | |
15 | ||
16 | import logging | |
17 | import ujson | |
18 | ||
19 | logger = logging.getLogger(__name__) | |
20 | ||
21 | ||
22 | ALTER_TABLE = """ | |
23 | ALTER TABLE events ADD COLUMN sender TEXT; | |
24 | ALTER TABLE events ADD COLUMN contains_url BOOLEAN; | |
25 | """ | |
26 | ||
27 | ||
28 | def run_create(cur, database_engine, *args, **kwargs): | |
29 | for statement in get_statements(ALTER_TABLE.splitlines()): | |
30 | cur.execute(statement) | |
31 | ||
32 | cur.execute("SELECT MIN(stream_ordering) FROM events") | |
33 | rows = cur.fetchall() | |
34 | min_stream_id = rows[0][0] | |
35 | ||
36 | cur.execute("SELECT MAX(stream_ordering) FROM events") | |
37 | rows = cur.fetchall() | |
38 | max_stream_id = rows[0][0] | |
39 | ||
40 | if min_stream_id is not None and max_stream_id is not None: | |
41 | progress = { | |
42 | "target_min_stream_id_inclusive": min_stream_id, | |
43 | "max_stream_id_exclusive": max_stream_id + 1, | |
44 | "rows_inserted": 0, | |
45 | } | |
46 | progress_json = ujson.dumps(progress) | |
47 | ||
48 | sql = ( | |
49 | "INSERT into background_updates (update_name, progress_json)" | |
50 | " VALUES (?, ?)" | |
51 | ) | |
52 | ||
53 | sql = database_engine.convert_param_style(sql) | |
54 | ||
55 | cur.execute(sql, ("event_fields_sender_url", progress_json)) | |
56 | ||
57 | ||
58 | def run_upgrade(cur, database_engine, *args, **kwargs): | |
59 | pass |
0 | /* Copyright 2016 OpenMarket Ltd | |
1 | * | |
2 | * Licensed under the Apache License, Version 2.0 (the "License"); | |
3 | * you may not use this file except in compliance with the License. | |
4 | * You may obtain a copy of the License at | |
5 | * | |
6 | * http://www.apache.org/licenses/LICENSE-2.0 | |
7 | * | |
8 | * Unless required by applicable law or agreed to in writing, software | |
9 | * distributed under the License is distributed on an "AS IS" BASIS, | |
10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
11 | * See the License for the specific language governing permissions and | |
12 | * limitations under the License. | |
13 | */ | |
14 | ||
15 | ALTER TABLE refresh_tokens ADD COLUMN device_id TEXT; |
0 | /* Copyright 2016 OpenMarket Ltd | |
1 | * | |
2 | * Licensed under the Apache License, Version 2.0 (the "License"); | |
3 | * you may not use this file except in compliance with the License. | |
4 | * You may obtain a copy of the License at | |
5 | * | |
6 | * http://www.apache.org/licenses/LICENSE-2.0 | |
7 | * | |
8 | * Unless required by applicable law or agreed to in writing, software | |
9 | * distributed under the License is distributed on an "AS IS" BASIS, | |
10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
11 | * See the License for the specific language governing permissions and | |
12 | * limitations under the License. | |
13 | */ | |
14 | ||
15 | INSERT INTO background_updates (update_name, progress_json) VALUES | |
16 | ('refresh_tokens_device_index', '{}'); |
0 | # Copyright 2016 OpenMarket Ltd | |
1 | # | |
2 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
3 | # you may not use this file except in compliance with the License. | |
4 | # You may obtain a copy of the License at | |
5 | # | |
6 | # http://www.apache.org/licenses/LICENSE-2.0 | |
7 | # | |
8 | # Unless required by applicable law or agreed to in writing, software | |
9 | # distributed under the License is distributed on an "AS IS" BASIS, | |
10 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
11 | # See the License for the specific language governing permissions and | |
12 | # limitations under the License. | |
13 | ||
14 | import time | |
15 | ||
16 | ||
17 | ALTER_TABLE = "ALTER TABLE remote_media_cache ADD COLUMN last_access_ts BIGINT" | |
18 | ||
19 | ||
20 | def run_create(cur, database_engine, *args, **kwargs): | |
21 | cur.execute(ALTER_TABLE) | |
22 | ||
23 | ||
24 | def run_upgrade(cur, database_engine, *args, **kwargs): | |
25 | cur.execute( | |
26 | database_engine.convert_param_style( | |
27 | "UPDATE remote_media_cache SET last_access_ts = ?" | |
28 | ), | |
29 | (int(time.time() * 1000),) | |
30 | ) |
0 | /* Copyright 2016 OpenMarket Ltd | |
1 | * | |
2 | * Licensed under the Apache License, Version 2.0 (the "License"); | |
3 | * you may not use this file except in compliance with the License. | |
4 | * You may obtain a copy of the License at | |
5 | * | |
6 | * http://www.apache.org/licenses/LICENSE-2.0 | |
7 | * | |
8 | * Unless required by applicable law or agreed to in writing, software | |
9 | * distributed under the License is distributed on an "AS IS" BASIS, | |
10 | * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
11 | * See the License for the specific language governing permissions and | |
12 | * limitations under the License. | |
13 | */ | |
14 | ||
15 | INSERT INTO background_updates (update_name, progress_json) VALUES | |
16 | ('user_ips_device_index', '{}'); |
39 | 39 | from synapse.api.constants import EventTypes |
40 | 40 | from synapse.types import RoomStreamToken |
41 | 41 | from synapse.util.logcontext import preserve_fn |
42 | from synapse.storage.engines import PostgresEngine, Sqlite3Engine | |
42 | 43 | |
43 | 44 | import logging |
44 | 45 | |
53 | 54 | _TOPOLOGICAL_TOKEN = "topological" |
54 | 55 | |
55 | 56 | |
56 | def lower_bound(token): | |
57 | def lower_bound(token, engine, inclusive=False): | |
58 | inclusive = "=" if inclusive else "" | |
57 | 59 | if token.topological is None: |
58 | return "(%d < %s)" % (token.stream, "stream_ordering") | |
60 | return "(%d <%s %s)" % (token.stream, inclusive, "stream_ordering") | |
59 | 61 | else: |
60 | return "(%d < %s OR (%d = %s AND %d < %s))" % ( | |
62 | if isinstance(engine, PostgresEngine): | |
63 | # Postgres doesn't optimise ``(x < a) OR (x=a AND y<b)`` as well | |
64 | # as it optimises ``(x,y) < (a,b)`` on multicolumn indexes. So we | |
65 |             # use the latter form when running against postgres. | |
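 |             # For example, a token with topological=5 and stream=3 renders | |
 |             # (with the default inclusive=False) as | |
 |             # ((5,3) < (topological_ordering,stream_ordering)). | |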
66 | return "((%d,%d) <%s (%s,%s))" % ( | |
67 | token.topological, token.stream, inclusive, | |
68 | "topological_ordering", "stream_ordering", | |
69 | ) | |
70 | return "(%d < %s OR (%d = %s AND %d <%s %s))" % ( | |
61 | 71 | token.topological, "topological_ordering", |
62 | 72 | token.topological, "topological_ordering", |
63 | token.stream, "stream_ordering", | |
64 | ) | |
65 | ||
66 | ||
67 | def upper_bound(token): | |
73 | token.stream, inclusive, "stream_ordering", | |
74 | ) | |
75 | ||
76 | ||
77 | def upper_bound(token, engine, inclusive=True): | |
78 | inclusive = "=" if inclusive else "" | |
68 | 79 | if token.topological is None: |
69 | return "(%d >= %s)" % (token.stream, "stream_ordering") | |
80 | return "(%d >%s %s)" % (token.stream, inclusive, "stream_ordering") | |
70 | 81 | else: |
71 | return "(%d > %s OR (%d = %s AND %d >= %s))" % ( | |
82 | if isinstance(engine, PostgresEngine): | |
83 | # Postgres doesn't optimise ``(x > a) OR (x=a AND y>b)`` as well | |
84 | # as it optimises ``(x,y) > (a,b)`` on multicolumn indexes. So we | |
85 |             # use the latter form when running against postgres. | |
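 |             # For example, a token with topological=5 and stream=3 renders | |
 |             # (with the default inclusive=True) as | |
 |             # ((5,3) >= (topological_ordering,stream_ordering)). | |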
86 | return "((%d,%d) >%s (%s,%s))" % ( | |
87 | token.topological, token.stream, inclusive, | |
88 | "topological_ordering", "stream_ordering", | |
89 | ) | |
90 | return "(%d > %s OR (%d = %s AND %d >%s %s))" % ( | |
72 | 91 | token.topological, "topological_ordering", |
73 | 92 | token.topological, "topological_ordering", |
74 | token.stream, "stream_ordering", | |
75 | ) | |
93 | token.stream, inclusive, "stream_ordering", | |
94 | ) | |
95 | ||
96 | ||
97 | def filter_to_clause(event_filter): | |
98 | # NB: This may create SQL clauses that don't optimise well (and we don't | |
99 | # have indices on all possible clauses). E.g. it may create | |
100 | # "room_id == X AND room_id != X", which postgres doesn't optimise. | |
101 | ||
102 | if not event_filter: | |
103 | return "", [] | |
104 | ||
105 | clauses = [] | |
106 | args = [] | |
107 | ||
108 | if event_filter.types: | |
109 | clauses.append( | |
110 | "(%s)" % " OR ".join("type = ?" for _ in event_filter.types) | |
111 | ) | |
112 | args.extend(event_filter.types) | |
113 | ||
114 | for typ in event_filter.not_types: | |
115 | clauses.append("type != ?") | |
116 | args.append(typ) | |
117 | ||
118 | if event_filter.senders: | |
119 | clauses.append( | |
120 | "(%s)" % " OR ".join("sender = ?" for _ in event_filter.senders) | |
121 | ) | |
122 | args.extend(event_filter.senders) | |
123 | ||
124 | for sender in event_filter.not_senders: | |
125 | clauses.append("sender != ?") | |
126 | args.append(sender) | |
127 | ||
128 | if event_filter.rooms: | |
129 | clauses.append( | |
130 | "(%s)" % " OR ".join("room_id = ?" for _ in event_filter.rooms) | |
131 | ) | |
132 | args.extend(event_filter.rooms) | |
133 | ||
134 | for room_id in event_filter.not_rooms: | |
135 | clauses.append("room_id != ?") | |
136 | args.append(room_id) | |
137 | ||
138 | if event_filter.contains_url: | |
139 | clauses.append("contains_url = ?") | |
140 | args.append(event_filter.contains_url) | |
141 | ||
142 | return " AND ".join(clauses), args | |
76 | 143 | |
77 | 144 | |
78 | 145 | class StreamStore(SQLBaseStore): |
300 | 367 | |
301 | 368 | @defer.inlineCallbacks |
302 | 369 | def paginate_room_events(self, room_id, from_key, to_key=None, |
303 | direction='b', limit=-1): | |
370 | direction='b', limit=-1, event_filter=None): | |
304 | 371 | # Tokens really represent positions between elements, but we use |
305 | 372 | # the convention of pointing to the event before the gap. Hence |
306 | 373 | # we have a bit of asymmetry when it comes to equalities. |
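 |         # For example, when paginating backwards the from_key bound is | |
 |         # inclusive (upper_bound defaults to inclusive=True) while the | |
 |         # to_key bound is exclusive (lower_bound defaults to inclusive=False). | |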
307 | 374 | args = [False, room_id] |
308 | 375 | if direction == 'b': |
309 | 376 | order = "DESC" |
310 | bounds = upper_bound(RoomStreamToken.parse(from_key)) | |
377 | bounds = upper_bound( | |
378 | RoomStreamToken.parse(from_key), self.database_engine | |
379 | ) | |
311 | 380 | if to_key: |
312 | bounds = "%s AND %s" % ( | |
313 | bounds, lower_bound(RoomStreamToken.parse(to_key)) | |
314 | ) | |
381 | bounds = "%s AND %s" % (bounds, lower_bound( | |
382 | RoomStreamToken.parse(to_key), self.database_engine | |
383 | )) | |
315 | 384 | else: |
316 | 385 | order = "ASC" |
317 | bounds = lower_bound(RoomStreamToken.parse(from_key)) | |
386 | bounds = lower_bound( | |
387 | RoomStreamToken.parse(from_key), self.database_engine | |
388 | ) | |
318 | 389 | if to_key: |
319 | bounds = "%s AND %s" % ( | |
320 | bounds, upper_bound(RoomStreamToken.parse(to_key)) | |
321 | ) | |
390 | bounds = "%s AND %s" % (bounds, upper_bound( | |
391 | RoomStreamToken.parse(to_key), self.database_engine | |
392 | )) | |
393 | ||
394 | filter_clause, filter_args = filter_to_clause(event_filter) | |
395 | ||
396 | if filter_clause: | |
397 | bounds += " AND " + filter_clause | |
398 | args.extend(filter_args) | |
322 | 399 | |
323 | 400 | if int(limit) > 0: |
324 | 401 | args.append(int(limit)) |
486 | 563 | row["topological_ordering"], row["stream_ordering"],) |
487 | 564 | ) |
488 | 565 | |
489 | def get_max_topological_token_for_stream_and_room(self, room_id, stream_key): | |
566 | def get_max_topological_token(self, room_id, stream_key): | |
490 | 567 | sql = ( |
491 | 568 | "SELECT max(topological_ordering) FROM events" |
492 | 569 | " WHERE room_id = ? AND stream_ordering < ?" |
493 | 570 | ) |
494 | 571 | return self._execute( |
495 | "get_max_topological_token_for_stream_and_room", None, | |
572 | "get_max_topological_token", None, | |
496 | 573 | sql, room_id, stream_key, |
497 | 574 | ).addCallback( |
498 | 575 | lambda r: r[0][0] if r else 0 |
585 | 662 | retcols=["stream_ordering", "topological_ordering"], |
586 | 663 | ) |
587 | 664 | |
588 | stream_ordering = results["stream_ordering"] | |
589 | topological_ordering = results["topological_ordering"] | |
590 | ||
591 | query_before = ( | |
592 | "SELECT topological_ordering, stream_ordering, event_id FROM events" | |
593 | " WHERE room_id = ? AND (topological_ordering < ?" | |
594 | " OR (topological_ordering = ? AND stream_ordering < ?))" | |
595 | " ORDER BY topological_ordering DESC, stream_ordering DESC" | |
596 | " LIMIT ?" | |
597 | ) | |
598 | ||
599 | query_after = ( | |
600 | "SELECT topological_ordering, stream_ordering, event_id FROM events" | |
601 | " WHERE room_id = ? AND (topological_ordering > ?" | |
602 | " OR (topological_ordering = ? AND stream_ordering > ?))" | |
603 | " ORDER BY topological_ordering ASC, stream_ordering ASC" | |
604 | " LIMIT ?" | |
605 | ) | |
606 | ||
607 | txn.execute( | |
608 | query_before, | |
609 | ( | |
610 | room_id, topological_ordering, topological_ordering, | |
611 | stream_ordering, before_limit, | |
612 | ) | |
613 | ) | |
665 | token = RoomStreamToken( | |
666 | results["topological_ordering"], | |
667 | results["stream_ordering"], | |
668 | ) | |
669 | ||
670 | if isinstance(self.database_engine, Sqlite3Engine): | |
671 |         # SQLite3 doesn't optimise ``(x < a) OR (x = a AND y < b)`` well, | |
672 |         # so we pass it to SQLite3 as the UNION ALL of the two queries. | |
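 |         # (Each branch of the UNION ALL is then a simple range query, which | |
 |         # SQLite plans much more efficiently.) | |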
673 | ||
674 | query_before = ( | |
675 | "SELECT topological_ordering, stream_ordering, event_id FROM events" | |
676 | " WHERE room_id = ? AND topological_ordering < ?" | |
677 | " UNION ALL" | |
678 | " SELECT topological_ordering, stream_ordering, event_id FROM events" | |
679 | " WHERE room_id = ? AND topological_ordering = ? AND stream_ordering < ?" | |
680 | " ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?" | |
681 | ) | |
682 | before_args = ( | |
683 | room_id, token.topological, | |
684 | room_id, token.topological, token.stream, | |
685 | before_limit, | |
686 | ) | |
687 | ||
688 | query_after = ( | |
689 | "SELECT topological_ordering, stream_ordering, event_id FROM events" | |
690 | " WHERE room_id = ? AND topological_ordering > ?" | |
691 | " UNION ALL" | |
692 | " SELECT topological_ordering, stream_ordering, event_id FROM events" | |
693 | " WHERE room_id = ? AND topological_ordering = ? AND stream_ordering > ?" | |
694 | " ORDER BY topological_ordering ASC, stream_ordering ASC LIMIT ?" | |
695 | ) | |
696 | after_args = ( | |
697 | room_id, token.topological, | |
698 | room_id, token.topological, token.stream, | |
699 | after_limit, | |
700 | ) | |
701 | else: | |
702 | query_before = ( | |
703 | "SELECT topological_ordering, stream_ordering, event_id FROM events" | |
704 | " WHERE room_id = ? AND %s" | |
705 | " ORDER BY topological_ordering DESC, stream_ordering DESC LIMIT ?" | |
706 | ) % (upper_bound(token, self.database_engine, inclusive=False),) | |
707 | ||
708 | before_args = (room_id, before_limit) | |
709 | ||
710 | query_after = ( | |
711 | "SELECT topological_ordering, stream_ordering, event_id FROM events" | |
712 | " WHERE room_id = ? AND %s" | |
713 | " ORDER BY topological_ordering ASC, stream_ordering ASC LIMIT ?" | |
714 | ) % (lower_bound(token, self.database_engine, inclusive=False),) | |
715 | ||
716 | after_args = (room_id, after_limit) | |
717 | ||
718 | txn.execute(query_before, before_args) | |
614 | 719 | |
615 | 720 | rows = self.cursor_to_dict(txn) |
616 | 721 | events_before = [r["event_id"] for r in rows] |
622 | 727 | )) |
623 | 728 | else: |
624 | 729 | start_token = str(RoomStreamToken( |
625 | topological_ordering, | |
626 | stream_ordering - 1, | |
730 | token.topological, | |
731 | token.stream - 1, | |
627 | 732 | )) |
628 | 733 | |
629 | txn.execute( | |
630 | query_after, | |
631 | ( | |
632 | room_id, topological_ordering, topological_ordering, | |
633 | stream_ordering, after_limit, | |
634 | ) | |
635 | ) | |
734 | txn.execute(query_after, after_args) | |
636 | 735 | |
637 | 736 | rows = self.cursor_to_dict(txn) |
638 | 737 | events_after = [r["event_id"] for r in rows] |
643 | 742 | rows[-1]["stream_ordering"], |
644 | 743 | )) |
645 | 744 | else: |
646 | end_token = str(RoomStreamToken( | |
647 | topological_ordering, | |
648 | stream_ordering, | |
649 | )) | |
745 | end_token = str(token) | |
650 | 746 | |
651 | 747 | return { |
652 | 748 | "before": { |
23 | 23 | |
24 | 24 | import itertools |
25 | 25 | import logging |
26 | import ujson as json | |
26 | 27 | |
27 | 28 | logger = logging.getLogger(__name__) |
28 | 29 | |
100 | 101 | ) |
101 | 102 | |
102 | 103 | if result and result["response_code"]: |
103 | return result["response_code"], result["response_json"] | |
104 | return result["response_code"], json.loads(str(result["response_json"])) | |
104 | 105 | else: |
105 | 106 | return None |
106 | 107 |
17 | 17 | from collections import namedtuple |
18 | 18 | |
19 | 19 | |
20 | Requester = namedtuple("Requester", ["user", "access_token_id", "is_guest"]) | |
20 | Requester = namedtuple("Requester", | |
21 | ["user", "access_token_id", "is_guest", "device_id"]) | |
22 | """ | |
23 | Represents the user making a request | |
24 | ||
25 | Attributes: | |
26 | user (UserID): id of the user making the request | |
27 | access_token_id (int|None): *ID* of the access token used for this | |
28 | request, or None if it came via the appservice API or similar | |
29 | is_guest (bool): True if the user making this request is a guest user | |
30 | device_id (str|None): device_id which was set at authentication time | |
31 | """ | |
32 | ||
33 | ||
34 | def create_requester(user_id, access_token_id=None, is_guest=False, | |
35 | device_id=None): | |
36 | """ | |
37 | Create a new ``Requester`` object | |
38 | ||
39 | Args: | |
40 | user_id (str|UserID): id of the user making the request | |
41 | access_token_id (int|None): *ID* of the access token used for this | |
42 | request, or None if it came via the appservice API or similar | |
43 | is_guest (bool): True if the user making this request is a guest user | |
44 | device_id (str|None): device_id which was set at authentication time | |
45 | ||
46 | Returns: | |
47 | Requester | |
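 | ||
 |     Example: | |
 |         create_requester("@alice:example", device_id="ABCDEF") | |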
48 | """ | |
49 | if not isinstance(user_id, UserID): | |
50 | user_id = UserID.from_string(user_id) | |
51 | return Requester(user_id, access_token_id, is_guest, device_id) | |
21 | 52 | |
22 | 53 | |
23 | 54 | def get_domain_from_id(string): |
193 | 193 | self.key_to_defer.pop(key, None) |
194 | 194 | |
195 | 195 | defer.returnValue(_ctx_manager()) |
196 | ||
197 | ||
198 | class ReadWriteLock(object): | |
199 | """A deferred style read write lock. | |
200 | ||
201 | Example: | |
202 | ||
203 | with (yield read_write_lock.read("test_key")): | |
204 | # do some work | |
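 | ||
 |         with (yield read_write_lock.write("test_key")): | |
 |             # do some exclusive work | |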
205 | """ | |
206 | ||
207 | # IMPLEMENTATION NOTES | |
208 | # | |
209 | # We track the most recent queued reader and writer deferreds (which get | |
210 | # resolved when they release the lock). | |
211 | # | |
212 |     # Read: We know it's safe to acquire a read lock when the latest writer has | |
213 |     # been resolved. The new reader is appended to the list of latest readers. | |
214 |     # | |
215 |     # Write: We know it's safe to acquire the write lock when both the latest | |
216 |     # writer and readers have been resolved. The new writer replaces the latest | |
217 | # writer. | |
218 | ||
219 | def __init__(self): | |
220 | # Latest readers queued | |
221 | self.key_to_current_readers = {} | |
222 | ||
223 | # Latest writer queued | |
224 | self.key_to_current_writer = {} | |
225 | ||
226 | @defer.inlineCallbacks | |
227 | def read(self, key): | |
228 | new_defer = defer.Deferred() | |
229 | ||
230 | curr_readers = self.key_to_current_readers.setdefault(key, set()) | |
231 | curr_writer = self.key_to_current_writer.get(key, None) | |
232 | ||
233 | curr_readers.add(new_defer) | |
234 | ||
235 | # We wait for the latest writer to finish writing. We can safely ignore | |
236 | # any existing readers... as they're readers. | |
237 | yield curr_writer | |
238 | ||
239 | @contextmanager | |
240 | def _ctx_manager(): | |
241 | try: | |
242 | yield | |
243 | finally: | |
244 | new_defer.callback(None) | |
245 | self.key_to_current_readers.get(key, set()).discard(new_defer) | |
246 | ||
247 | defer.returnValue(_ctx_manager()) | |
248 | ||
249 | @defer.inlineCallbacks | |
250 | def write(self, key): | |
251 | new_defer = defer.Deferred() | |
252 | ||
253 | curr_readers = self.key_to_current_readers.get(key, set()) | |
254 | curr_writer = self.key_to_current_writer.get(key, None) | |
255 | ||
256 | # We wait on all latest readers and writer. | |
257 | to_wait_on = list(curr_readers) | |
258 | if curr_writer: | |
259 | to_wait_on.append(curr_writer) | |
260 | ||
261 | # We can clear the list of current readers since the new writer waits | |
262 | # for them to finish. | |
263 | curr_readers.clear() | |
264 | self.key_to_current_writer[key] = new_defer | |
265 | ||
266 | yield defer.gatherResults(to_wait_on) | |
267 | ||
268 | @contextmanager | |
269 | def _ctx_manager(): | |
270 | try: | |
271 | yield | |
272 | finally: | |
273 | new_defer.callback(None) | |
274 | if self.key_to_current_writer[key] == new_defer: | |
275 | self.key_to_current_writer.pop(key) | |
276 | ||
277 | defer.returnValue(_ctx_manager()) |
23 | 23 | used rather than trying to compute a new response. |
24 | 24 | """ |
25 | 25 | |
26 | def __init__(self): | |
26 | def __init__(self, hs, timeout_ms=0): | |
27 | 27 | self.pending_result_cache = {} # Requests that haven't finished yet. |
28 | ||
29 | self.clock = hs.get_clock() | |
30 | self.timeout_sec = timeout_ms / 1000. | |
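 |         # When timeout_ms is nonzero, completed results are kept in the | |
 |         # cache for that long after they resolve, so requests arriving | |
 |         # shortly afterwards can still reuse the response. | |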
28 | 31 | |
29 | 32 | def get(self, key): |
30 | 33 | result = self.pending_result_cache.get(key) |
38 | 41 | self.pending_result_cache[key] = result |
39 | 42 | |
40 | 43 | def remove(r): |
41 | self.pending_result_cache.pop(key, None) | |
44 | if self.timeout_sec: | |
45 | self.clock.call_later( | |
46 | self.timeout_sec, | |
47 | self.pending_result_cache.pop, key, None, | |
48 | ) | |
49 | else: | |
50 | self.pending_result_cache.pop(key, None) | |
42 | 51 | return r |
43 | 52 | |
44 | 53 | result.addBoth(remove) |
83 | 83 | |
84 | 84 | if context != self.start_context: |
85 | 85 | logger.warn( |
86 | "Context have unexpectedly changed from '%s' to '%s'. (%r)", | |
86 | "Context has unexpectedly changed from '%s' to '%s'. (%r)", | |
87 | 87 | context, self.start_context, self.name |
88 | 88 | ) |
89 | 89 | return |
24 | 24 | ALL_ALONE = "Empty Room" |
25 | 25 | |
26 | 26 | |
27 | def calculate_room_name(room_state, user_id, fallback_to_members=True): | |
27 | def calculate_room_name(room_state, user_id, fallback_to_members=True, | |
28 | fallback_to_single_member=True): | |
28 | 29 | """ |
29 | 30 | Works out a user-facing name for the given room as per Matrix |
30 | 31 | spec recommendations. |
81 | 82 | ): |
82 | 83 | if ("m.room.member", my_member_event.sender) in room_state: |
83 | 84 | inviter_member_event = room_state[("m.room.member", my_member_event.sender)] |
84 | return "Invite from %s" % (name_from_member_event(inviter_member_event),) | |
85 | if fallback_to_single_member: | |
86 | return "Invite from %s" % (name_from_member_event(inviter_member_event),) | |
87 | else: | |
88 | return None | |
85 | 89 | else: |
86 | 90 | return "Room Invite" |
87 | 91 | |
128 | 132 | return name_from_member_event(all_members[0]) |
129 | 133 | else: |
130 | 134 | return ALL_ALONE |
135 | elif len(other_members) == 1 and not fallback_to_single_member: | |
136 | return None | |
131 | 137 | else: |
132 | 138 | return descriptor_from_member_events(other_members) |
133 | 139 |
127 | 127 | ) |
128 | 128 | |
129 | 129 | valid_err_code = False |
130 | if exc_type is CodeMessageException: | |
130 | if exc_type is not None and issubclass(exc_type, CodeMessageException): | |
131 | 131 | valid_err_code = 0 <= exc_val.code < 500 |
132 | 132 | |
133 | 133 | if exc_type is None or valid_err_code: |
20 | 20 | logger = logging.getLogger(__name__) |
21 | 21 | |
22 | 22 | |
23 | def get_version_string(name, module): | |
23 | def get_version_string(module): | |
24 | 24 | try: |
25 | 25 | null = open(os.devnull, 'w') |
26 | 26 | cwd = os.path.dirname(os.path.abspath(module.__file__)) |
73 | 73 | ) |
74 | 74 | |
75 | 75 | return ( |
76 | "%s/%s (%s)" % ( | |
77 | name, module.__version__, git_version, | |
76 | "%s (%s)" % ( | |
77 | module.__version__, git_version, | |
78 | 78 | ) |
79 | 79 | ).encode("ascii") |
80 | 80 | except Exception as e: |
81 | 81 | logger.info("Failed to check for git repository: %s", e) |
82 | 82 | |
83 | return ("%s/%s" % (name, module.__version__,)).encode("ascii") | |
83 | return module.__version__.encode("ascii") |
44 | 44 | user_info = { |
45 | 45 | "name": self.test_user, |
46 | 46 | "token_id": "ditto", |
47 | "device_id": "device", | |
47 | 48 | } |
48 | 49 | self.store.get_user_by_access_token = Mock(return_value=user_info) |
49 | 50 | |
142 | 143 | # TODO(danielwh): Remove this mock when we remove the |
143 | 144 | # get_user_by_access_token fallback. |
144 | 145 | self.store.get_user_by_access_token = Mock( |
145 | return_value={"name": "@baldrick:matrix.org"} | |
146 | return_value={ | |
147 | "name": "@baldrick:matrix.org", | |
148 | "device_id": "device", | |
149 | } | |
146 | 150 | ) |
147 | 151 | |
148 | 152 | user_id = "@baldrick:matrix.org" |
156 | 160 | user_info = yield self.auth.get_user_from_macaroon(macaroon.serialize()) |
157 | 161 | user = user_info["user"] |
158 | 162 | self.assertEqual(UserID.from_string(user_id), user) |
163 | ||
164 | # TODO: device_id should come from the macaroon, but currently comes | |
165 | # from the db. | |
166 | self.assertEqual(user_info["device_id"], "device") | |
159 | 167 | |
160 | 168 | @defer.inlineCallbacks |
161 | 169 | def test_get_guest_user_from_macaroon(self): |
280 | 288 | macaroon.add_first_party_caveat("gen = 1") |
281 | 289 | macaroon.add_first_party_caveat("type = access") |
282 | 290 | macaroon.add_first_party_caveat("user_id = %s" % (user,)) |
283 | macaroon.add_first_party_caveat("time < 1") # ms | |
291 | macaroon.add_first_party_caveat("time < -2000") # ms | |
284 | 292 | |
285 | 293 | self.hs.clock.now = 5000 # seconds |
286 | 294 | self.hs.config.expire_access_token = True |
292 | 300 | yield self.auth.get_user_from_macaroon(macaroon.serialize()) |
293 | 301 | self.assertEqual(401, cm.exception.code) |
294 | 302 | self.assertIn("Invalid macaroon", cm.exception.msg) |
303 | ||
304 | @defer.inlineCallbacks | |
305 | def test_get_user_from_macaroon_with_valid_duration(self): | |
306 | # TODO(danielwh): Remove this mock when we remove the | |
307 | # get_user_by_access_token fallback. | |
308 | self.store.get_user_by_access_token = Mock( | |
309 | return_value={"name": "@baldrick:matrix.org"} | |
310 | ) | |
315 | ||
316 | user_id = "@baldrick:matrix.org" | |
317 | macaroon = pymacaroons.Macaroon( | |
318 | location=self.hs.config.server_name, | |
319 | identifier="key", | |
320 | key=self.hs.config.macaroon_secret_key) | |
321 | macaroon.add_first_party_caveat("gen = 1") | |
322 | macaroon.add_first_party_caveat("type = access") | |
323 | macaroon.add_first_party_caveat("user_id = %s" % (user_id,)) | |
324 | macaroon.add_first_party_caveat("time < 900000000") # ms | |
325 | ||
326 | self.hs.clock.now = 5000 # seconds | |
327 | self.hs.config.expire_access_token = True | |
328 | ||
329 | user_info = yield self.auth.get_user_from_macaroon(macaroon.serialize()) | |
330 | user = user_info["user"] | |
331 | self.assertEqual(UserID.from_string(user_id), user) |
0 | # -*- coding: utf-8 -*- | |
1 | # Copyright 2016 OpenMarket Ltd | |
2 | # | |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
4 | # you may not use this file except in compliance with the License. | |
5 | # You may obtain a copy of the License at | |
6 | # | |
7 | # http://www.apache.org/licenses/LICENSE-2.0 | |
8 | # | |
9 | # Unless required by applicable law or agreed to in writing, software | |
10 | # distributed under the License is distributed on an "AS IS" BASIS, | |
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 | # See the License for the specific language governing permissions and | |
13 | # limitations under the License. | |
14 | ||
15 | from twisted.internet import defer | |
16 | ||
17 | import synapse.api.errors | |
18 | import synapse.handlers.device | |
19 | ||
20 | import synapse.storage | |
21 | from synapse import types | |
22 | from tests import unittest, utils | |
23 | ||
24 | user1 = "@boris:aaa" | |
25 | user2 = "@theresa:bbb" | |
26 | ||
27 | ||
28 | class DeviceTestCase(unittest.TestCase): | |
29 | def __init__(self, *args, **kwargs): | |
30 | super(DeviceTestCase, self).__init__(*args, **kwargs) | |
31 | self.store = None # type: synapse.storage.DataStore | |
32 | self.handler = None # type: synapse.handlers.device.DeviceHandler | |
33 | self.clock = None # type: utils.MockClock | |
34 | ||
35 | @defer.inlineCallbacks | |
36 | def setUp(self): | |
37 | hs = yield utils.setup_test_homeserver(handlers=None) | |
38 | self.handler = synapse.handlers.device.DeviceHandler(hs) | |
39 | self.store = hs.get_datastore() | |
40 | self.clock = hs.get_clock() | |
41 | ||
42 | @defer.inlineCallbacks | |
43 | def test_device_is_created_if_doesnt_exist(self): | |
44 | res = yield self.handler.check_device_registered( | |
45 | user_id="boris", | |
46 | device_id="fco", | |
47 | initial_device_display_name="display name" | |
48 | ) | |
49 | self.assertEqual(res, "fco") | |
50 | ||
51 | dev = yield self.handler.store.get_device("boris", "fco") | |
52 | self.assertEqual(dev["display_name"], "display name") | |
53 | ||
54 | @defer.inlineCallbacks | |
55 | def test_device_is_preserved_if_exists(self): | |
56 | res1 = yield self.handler.check_device_registered( | |
57 | user_id="boris", | |
58 | device_id="fco", | |
59 | initial_device_display_name="display name" | |
60 | ) | |
61 | self.assertEqual(res1, "fco") | |
62 | ||
63 | res2 = yield self.handler.check_device_registered( | |
64 | user_id="boris", | |
65 | device_id="fco", | |
66 | initial_device_display_name="new display name" | |
67 | ) | |
68 | self.assertEqual(res2, "fco") | |
69 | ||
70 | dev = yield self.handler.store.get_device("boris", "fco") | |
71 | self.assertEqual(dev["display_name"], "display name") | |
72 | ||
73 | @defer.inlineCallbacks | |
74 | def test_device_id_is_made_up_if_unspecified(self): | |
75 | device_id = yield self.handler.check_device_registered( | |
76 | user_id="theresa", | |
77 | device_id=None, | |
78 | initial_device_display_name="display" | |
79 | ) | |
80 | ||
81 | dev = yield self.handler.store.get_device("theresa", device_id) | |
82 | self.assertEqual(dev["display_name"], "display") | |
83 | ||
84 | @defer.inlineCallbacks | |
85 | def test_get_devices_by_user(self): | |
86 | yield self._record_users() | |
87 | ||
88 | res = yield self.handler.get_devices_by_user(user1) | |
89 | self.assertEqual(3, len(res)) | |
90 | device_map = { | |
91 | d["device_id"]: d for d in res | |
92 | } | |
93 | self.assertDictContainsSubset({ | |
94 | "user_id": user1, | |
95 | "device_id": "xyz", | |
96 | "display_name": "display 0", | |
97 | "last_seen_ip": None, | |
98 | "last_seen_ts": None, | |
99 | }, device_map["xyz"]) | |
100 | self.assertDictContainsSubset({ | |
101 | "user_id": user1, | |
102 | "device_id": "fco", | |
103 | "display_name": "display 1", | |
104 | "last_seen_ip": "ip1", | |
105 | "last_seen_ts": 1000000, | |
106 | }, device_map["fco"]) | |
107 | self.assertDictContainsSubset({ | |
108 | "user_id": user1, | |
109 | "device_id": "abc", | |
110 | "display_name": "display 2", | |
111 | "last_seen_ip": "ip3", | |
112 | "last_seen_ts": 3000000, | |
113 | }, device_map["abc"]) | |
114 | ||
115 | @defer.inlineCallbacks | |
116 | def test_get_device(self): | |
117 | yield self._record_users() | |
118 | ||
119 | res = yield self.handler.get_device(user1, "abc") | |
120 | self.assertDictContainsSubset({ | |
121 | "user_id": user1, | |
122 | "device_id": "abc", | |
123 | "display_name": "display 2", | |
124 | "last_seen_ip": "ip3", | |
125 | "last_seen_ts": 3000000, | |
126 | }, res) | |
127 | ||
128 | @defer.inlineCallbacks | |
129 | def test_delete_device(self): | |
130 | yield self._record_users() | |
131 | ||
132 | # delete the device | |
133 | yield self.handler.delete_device(user1, "abc") | |
134 | ||
135 | # check the device was deleted | |
136 | with self.assertRaises(synapse.api.errors.NotFoundError): | |
137 | yield self.handler.get_device(user1, "abc") | |
138 | ||
139 | # we'd like to check the access token was invalidated, but that's a | |
140 | # bit of a PITA. | |
141 | ||
142 | @defer.inlineCallbacks | |
143 | def test_update_device(self): | |
144 | yield self._record_users() | |
145 | ||
146 | update = {"display_name": "new display"} | |
147 | yield self.handler.update_device(user1, "abc", update) | |
148 | ||
149 | res = yield self.handler.get_device(user1, "abc") | |
150 | self.assertEqual(res["display_name"], "new display") | |
151 | ||
152 | @defer.inlineCallbacks | |
153 | def test_update_unknown_device(self): | |
154 | update = {"display_name": "new_display"} | |
155 | with self.assertRaises(synapse.api.errors.NotFoundError): | |
156 | yield self.handler.update_device("user_id", "unknown_device_id", | |
157 | update) | |
158 | ||
159 | @defer.inlineCallbacks | |
160 | def _record_users(self): | |
161 | # check this works for both devices which have a recorded client_ip, | |
162 | # and those which don't. | |
163 | yield self._record_user(user1, "xyz", "display 0") | |
164 | yield self._record_user(user1, "fco", "display 1", "token1", "ip1") | |
165 | yield self._record_user(user1, "abc", "display 2", "token2", "ip2") | |
166 | yield self._record_user(user1, "abc", "display 2", "token3", "ip3") | |
167 | ||
168 |         yield self._record_user(user2, "def", "display", "token4", "ip4") | |
169 | ||
170 | @defer.inlineCallbacks | |
171 | def _record_user(self, user_id, device_id, display_name, | |
172 | access_token=None, ip=None): | |
173 | device_id = yield self.handler.check_device_registered( | |
174 | user_id=user_id, | |
175 | device_id=device_id, | |
176 | initial_device_display_name=display_name | |
177 | ) | |
178 | ||
179 | if ip is not None: | |
180 | yield self.store.insert_client_ip( | |
181 | types.UserID.from_string(user_id), | |
182 | access_token, ip, "user_agent", device_id) | |
183 | self.clock.advance_time(1000) |
0 | # -*- coding: utf-8 -*- | |
1 | # Copyright 2016 OpenMarket Ltd | |
2 | # | |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
4 | # you may not use this file except in compliance with the License. | |
5 | # You may obtain a copy of the License at | |
6 | # | |
7 | # http://www.apache.org/licenses/LICENSE-2.0 | |
8 | # | |
9 | # Unless required by applicable law or agreed to in writing, software | |
10 | # distributed under the License is distributed on an "AS IS" BASIS, | |
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 | # See the License for the specific language governing permissions and | |
13 | # limitations under the License. | |
14 | ||
15 | import mock | |
16 | from twisted.internet import defer | |
17 | ||
18 | import synapse.api.errors | |
19 | import synapse.handlers.e2e_keys | |
20 | ||
21 | import synapse.storage | |
22 | from tests import unittest, utils | |
23 | ||
24 | ||
25 | class E2eKeysHandlerTestCase(unittest.TestCase): | |
26 | def __init__(self, *args, **kwargs): | |
27 | super(E2eKeysHandlerTestCase, self).__init__(*args, **kwargs) | |
28 | self.hs = None # type: synapse.server.HomeServer | |
29 | self.handler = None # type: synapse.handlers.e2e_keys.E2eKeysHandler | |
30 | ||
31 | @defer.inlineCallbacks | |
32 | def setUp(self): | |
33 | self.hs = yield utils.setup_test_homeserver( | |
34 | handlers=None, | |
35 | replication_layer=mock.Mock(), | |
36 | ) | |
37 | self.handler = synapse.handlers.e2e_keys.E2eKeysHandler(self.hs) | |
38 | ||
39 | @defer.inlineCallbacks | |
40 | def test_query_local_devices_no_devices(self): | |
41 | """If the user has no devices, we expect an empty list. | |
42 | """ | |
43 | local_user = "@boris:" + self.hs.hostname | |
44 | res = yield self.handler.query_local_devices({local_user: None}) | |
45 | self.assertDictEqual(res, {local_user: {}}) |
18 | 18 | |
19 | 19 | from mock import Mock, NonCallableMock |
20 | 20 | |
21 | import synapse.types | |
21 | 22 | from synapse.api.errors import AuthError |
22 | 23 | from synapse.handlers.profile import ProfileHandler |
23 | 24 | from synapse.types import UserID |
24 | 25 | |
25 | from tests.utils import setup_test_homeserver, requester_for_user | |
26 | from tests.utils import setup_test_homeserver | |
26 | 27 | |
27 | 28 | |
28 | 29 | class ProfileHandlers(object): |
85 | 86 | def test_set_my_name(self): |
86 | 87 | yield self.handler.set_displayname( |
87 | 88 | self.frank, |
88 | requester_for_user(self.frank), | |
89 | synapse.types.create_requester(self.frank), | |
89 | 90 | "Frank Jr." |
90 | 91 | ) |
91 | 92 | |
98 | 99 | def test_set_my_name_noauth(self): |
99 | 100 | d = self.handler.set_displayname( |
100 | 101 | self.frank, |
101 | requester_for_user(self.bob), | |
102 | synapse.types.create_requester(self.bob), | |
102 | 103 | "Frank Jr." |
103 | 104 | ) |
104 | 105 | |
143 | 144 | @defer.inlineCallbacks |
144 | 145 | def test_set_my_avatar(self): |
145 | 146 | yield self.handler.set_avatar_url( |
146 | self.frank, requester_for_user(self.frank), "http://my.server/pic.gif" | |
147 | self.frank, synapse.types.create_requester(self.frank), | |
148 | "http://my.server/pic.gif" | |
147 | 149 | ) |
148 | 150 | |
149 | 151 | self.assertEquals( |
41 | 41 | http_client=None, |
42 | 42 | expire_access_token=True) |
43 | 43 | self.auth_handler = Mock( |
44 | generate_short_term_login_token=Mock(return_value='secret')) | |
44 | generate_access_token=Mock(return_value='secret')) | |
45 | 45 | self.hs.handlers = RegistrationHandlers(self.hs) |
46 | 46 | self.handler = self.hs.get_handlers().registration_handler |
47 | 47 | self.hs.get_handlers().profile_handler = Mock() |
48 | 48 | self.mock_handler = Mock(spec=[ |
49 | "generate_short_term_login_token", | |
49 | "generate_access_token", | |
50 | 50 | ]) |
51 | 51 | self.hs.get_auth_handler = Mock(return_value=self.auth_handler) |
52 | 52 |
56 | 56 | |
57 | 57 | def tearDown(self): |
58 | 58 | [unpatch() for unpatch in self.unpatches] |
59 | ||
60 | @defer.inlineCallbacks | |
61 | def test_room_name_and_aliases(self): | |
62 | create = yield self.persist(type="m.room.create", key="", creator=USER_ID) | |
63 | yield self.persist(type="m.room.member", key=USER_ID, membership="join") | |
64 | yield self.persist(type="m.room.name", key="", name="name1") | |
65 | yield self.persist( | |
66 | type="m.room.aliases", key="blue", aliases=["#1:blue"] | |
67 | ) | |
68 | yield self.replicate() | |
69 | yield self.check( | |
70 | "get_room_name_and_aliases", (ROOM_ID,), ("name1", ["#1:blue"]) | |
71 | ) | |
72 | ||
73 | # Set the room name. | |
74 | yield self.persist(type="m.room.name", key="", name="name2") | |
75 | yield self.replicate() | |
76 | yield self.check( | |
77 | "get_room_name_and_aliases", (ROOM_ID,), ("name2", ["#1:blue"]) | |
78 | ) | |
79 | ||
80 | # Set the room aliases. | |
81 | yield self.persist( | |
82 | type="m.room.aliases", key="blue", aliases=["#2:blue"] | |
83 | ) | |
84 | yield self.replicate() | |
85 | yield self.check( | |
86 | "get_room_name_and_aliases", (ROOM_ID,), ("name2", ["#2:blue"]) | |
87 | ) | |
88 | ||
89 | # Leave and join the room clobbering the state. | |
90 | yield self.persist(type="m.room.member", key=USER_ID, membership="leave") | |
91 | yield self.persist( | |
92 | type="m.room.member", key=USER_ID, membership="join", | |
93 | reset_state=[create] | |
94 | ) | |
95 | yield self.replicate() | |
96 | ||
97 | yield self.check( | |
98 | "get_room_name_and_aliases", (ROOM_ID,), (None, []) | |
99 | ) | |
100 | 59 | |
101 | 60 | @defer.inlineCallbacks |
102 | 61 | def test_room_members(self): |
12 | 12 | # See the License for the specific language governing permissions and |
13 | 13 | # limitations under the License. |
14 | 14 | |
15 | import contextlib | |
16 | import json | |
17 | ||
18 | from mock import Mock, NonCallableMock | |
19 | from twisted.internet import defer | |
20 | ||
21 | import synapse.types | |
15 | 22 | from synapse.replication.resource import ReplicationResource |
16 | from synapse.types import Requester, UserID | |
17 | ||
18 | from twisted.internet import defer | |
23 | from synapse.types import UserID | |
19 | 24 | from tests import unittest |
20 | from tests.utils import setup_test_homeserver, requester_for_user | |
21 | from mock import Mock, NonCallableMock | |
22 | import json | |
23 | import contextlib | |
25 | from tests.utils import setup_test_homeserver | |
24 | 26 | |
25 | 27 | |
26 | 28 | class ReplicationResourceCase(unittest.TestCase): |
60 | 62 | def test_events_and_state(self): |
61 | 63 | get = self.get(events="-1", state="-1", timeout="0") |
62 | 64 | yield self.hs.get_handlers().room_creation_handler.create_room( |
63 | Requester(self.user, "", False), {} | |
65 | synapse.types.create_requester(self.user), {} | |
64 | 66 | ) |
65 | 67 | code, body = yield get |
66 | 68 | self.assertEquals(code, 200) |
143 | 145 | def send_text_message(self, room_id, message): |
144 | 146 | handler = self.hs.get_handlers().message_handler |
145 | 147 | event = yield handler.create_and_send_nonmember_event( |
146 | requester_for_user(self.user), | |
148 | synapse.types.create_requester(self.user), | |
147 | 149 | { |
148 | 150 | "type": "m.room.message", |
149 | 151 | "content": {"body": "message", "msgtype": "m.text"}, |
156 | 158 | @defer.inlineCallbacks |
157 | 159 | def create_room(self): |
158 | 160 | result = yield self.hs.get_handlers().room_creation_handler.create_room( |
159 | Requester(self.user, "", False), {} | |
161 | synapse.types.create_requester(self.user), {} | |
160 | 162 | ) |
161 | 163 | defer.returnValue(result["room_id"]) |
162 | 164 |
13 | 13 | # limitations under the License. |
14 | 14 | |
15 | 15 | """Tests REST events for /profile paths.""" |
16 | from tests import unittest | |
16 | from mock import Mock | |
17 | 17 | from twisted.internet import defer |
18 | 18 | |
19 | from mock import Mock | |
20 | ||
19 | import synapse.types | |
20 | from synapse.api.errors import SynapseError, AuthError | |
21 | from synapse.rest.client.v1 import profile | |
22 | from tests import unittest | |
21 | 23 | from ....utils import MockHttpResource, setup_test_homeserver |
22 | ||
23 | from synapse.api.errors import SynapseError, AuthError | |
24 | from synapse.types import Requester, UserID | |
25 | ||
26 | from synapse.rest.client.v1 import profile | |
27 | 24 | |
28 | 25 | myid = "@1234ABCD:test" |
29 | 26 | PATH_PREFIX = "/_matrix/client/api/v1" |
51 | 48 | ) |
52 | 49 | |
53 | 50 | def _get_user_by_req(request=None, allow_guest=False): |
54 | return Requester(UserID.from_string(myid), "", False) | |
51 | return synapse.types.create_requester(myid) | |
55 | 52 | |
56 | 53 | hs.get_v1auth().get_user_by_req = _get_user_by_req |
57 | 54 |
29 | 29 | self.registration_handler = Mock() |
30 | 30 | self.identity_handler = Mock() |
31 | 31 | self.login_handler = Mock() |
32 | self.device_handler = Mock() | |
32 | 33 | |
33 | 34 | # do the dance to hook it up to the hs global |
34 | 35 | self.handlers = Mock( |
41 | 42 | self.hs.get_auth = Mock(return_value=self.auth) |
42 | 43 | self.hs.get_handlers = Mock(return_value=self.handlers) |
43 | 44 | self.hs.get_auth_handler = Mock(return_value=self.auth_handler) |
45 | self.hs.get_device_handler = Mock(return_value=self.device_handler) | |
44 | 46 | self.hs.config.enable_registration = True |
45 | 47 | |
46 | 48 | # init the thing we're testing |
60 | 62 | "id": "1234" |
61 | 63 | } |
62 | 64 | self.registration_handler.appservice_register = Mock( |
63 | return_value=(user_id, token) | |
65 | return_value=user_id | |
64 | 66 | ) |
67 | self.auth_handler.get_login_tuple_for_user_id = Mock( | |
68 | return_value=(token, "kermits_refresh_token") | |
69 | ) | |
70 | ||
65 | 71 | (code, result) = yield self.servlet.on_POST(self.request) |
66 | 72 | self.assertEquals(code, 200) |
67 | 73 | det_data = { |
68 | 74 | "user_id": user_id, |
69 | 75 | "access_token": token, |
76 | "refresh_token": "kermits_refresh_token", | |
70 | 77 | "home_server": self.hs.hostname |
71 | 78 | } |
72 | 79 | self.assertDictContainsSubset(det_data, result) |
104 | 111 | def test_POST_user_valid(self): |
105 | 112 | user_id = "@kermit:muppet" |
106 | 113 | token = "kermits_access_token" |
114 | device_id = "frogfone" | |
107 | 115 | self.request_data = json.dumps({ |
108 | 116 | "username": "kermit", |
109 | "password": "monkey" | |
117 | "password": "monkey", | |
118 | "device_id": device_id, | |
110 | 119 | }) |
111 | 120 | self.registration_handler.check_username = Mock(return_value=True) |
112 | 121 | self.auth_result = (True, None, { |
113 | 122 | "username": "kermit", |
114 | 123 | "password": "monkey" |
115 | 124 | }, None) |
116 | self.registration_handler.register = Mock(return_value=(user_id, token)) | |
125 | self.registration_handler.register = Mock(return_value=(user_id, None)) | |
126 | self.auth_handler.get_login_tuple_for_user_id = Mock( | |
127 | return_value=(token, "kermits_refresh_token") | |
128 | ) | |
129 | self.device_handler.check_device_registered = \ | |
130 | Mock(return_value=device_id) | |
117 | 131 | |
118 | 132 | (code, result) = yield self.servlet.on_POST(self.request) |
119 | 133 | self.assertEquals(code, 200) |
120 | 134 | det_data = { |
121 | 135 | "user_id": user_id, |
122 | 136 | "access_token": token, |
123 | "home_server": self.hs.hostname | |
137 | "refresh_token": "kermits_refresh_token", | |
138 | "home_server": self.hs.hostname, | |
139 | "device_id": device_id, | |
124 | 140 | } |
125 | 141 | self.assertDictContainsSubset(det_data, result) |
126 | 142 | self.assertIn("refresh_token", result) |
143 |         self.auth_handler.get_login_tuple_for_user_id.assert_called_once_with( | |
144 |             user_id, device_id=device_id, initial_device_display_name=None) | |
127 | 145 | |
128 | 146 | def test_POST_disabled_registration(self): |
129 | 147 | self.hs.config.enable_registration = False |
29 | 29 | def create_room(self, room): |
30 | 30 | builder = self.event_builder_factory.new({ |
31 | 31 | "type": EventTypes.Create, |
32 | "sender": "", | |
32 | 33 | "room_id": room.to_string(), |
33 | 34 | "content": {}, |
34 | 35 | }) |
9 | 9 | |
10 | 10 | @defer.inlineCallbacks |
11 | 11 | def setUp(self): |
12 | hs = yield setup_test_homeserver() | |
12 | hs = yield setup_test_homeserver() # type: synapse.server.HomeServer | |
13 | 13 | self.store = hs.get_datastore() |
14 | 14 | self.clock = hs.get_clock() |
15 | 15 | |
19 | 19 | "test_update", self.update_handler |
20 | 20 | ) |
21 | 21 | |
22 |         # run the real background updates, to get them out of the way | |
23 | # (perhaps we should run them as part of the test HS setup, since we | |
24 | # run all of the other schema setup stuff there?) | |
25 | while True: | |
26 | res = yield self.store.do_next_background_update(1000) | |
27 | if res is None: | |
28 | break | |
29 | ||
22 | 30 | @defer.inlineCallbacks |
23 | 31 | def test_do_background_update(self): |
24 | 32 | desired_count = 1000 |
25 | 33 | duration_ms = 42 |
26 | 34 | |
35 | # first step: make a bit of progress | |
27 | 36 | @defer.inlineCallbacks |
28 | 37 | def update(progress, count): |
29 | 38 | self.clock.advance_time_msec(count * duration_ms) |
41 | 50 | yield self.store.start_background_update("test_update", {"my_key": 1}) |
42 | 51 | |
43 | 52 | self.update_handler.reset_mock() |
44 | result = yield self.store.do_background_update( | |
53 | result = yield self.store.do_next_background_update( | |
45 | 54 | duration_ms * desired_count |
46 | 55 | ) |
47 | 56 | self.assertIsNotNone(result) |
49 | 58 | {"my_key": 1}, self.store.DEFAULT_BACKGROUND_BATCH_SIZE |
50 | 59 | ) |
51 | 60 | |
61 | # second step: complete the update | |
52 | 62 | @defer.inlineCallbacks |
53 | 63 | def update(progress, count): |
54 | 64 | yield self.store._end_background_update("test_update") |
55 | 65 | defer.returnValue(count) |
56 | 66 | |
57 | 67 | self.update_handler.side_effect = update |
58 | ||
59 | 68 | self.update_handler.reset_mock() |
60 | result = yield self.store.do_background_update( | |
69 | result = yield self.store.do_next_background_update( | |
61 | 70 | duration_ms * desired_count |
62 | 71 | ) |
63 | 72 | self.assertIsNotNone(result) |
65 | 74 | {"my_key": 2}, desired_count |
66 | 75 | ) |
67 | 76 | |
77 | # third step: we don't expect to be called any more | |
68 | 78 | self.update_handler.reset_mock() |
69 | result = yield self.store.do_background_update( | |
79 | result = yield self.store.do_next_background_update( | |
70 | 80 | duration_ms * desired_count |
71 | 81 | ) |
72 | 82 | self.assertIsNone(result) |
0 | # -*- coding: utf-8 -*- | |
1 | # Copyright 2016 OpenMarket Ltd | |
2 | # | |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
4 | # you may not use this file except in compliance with the License. | |
5 | # You may obtain a copy of the License at | |
6 | # | |
7 | # http://www.apache.org/licenses/LICENSE-2.0 | |
8 | # | |
9 | # Unless required by applicable law or agreed to in writing, software | |
10 | # distributed under the License is distributed on an "AS IS" BASIS, | |
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 | # See the License for the specific language governing permissions and | |
13 | # limitations under the License. | |
14 | ||
15 | from twisted.internet import defer | |
16 | ||
17 | import synapse.server | |
18 | import synapse.storage | |
19 | import synapse.types | |
20 | import tests.unittest | |
21 | import tests.utils | |
22 | ||
23 | ||
24 | class ClientIpStoreTestCase(tests.unittest.TestCase): | |
25 | def __init__(self, *args, **kwargs): | |
26 | super(ClientIpStoreTestCase, self).__init__(*args, **kwargs) | |
27 | self.store = None # type: synapse.storage.DataStore | |
28 | self.clock = None # type: tests.utils.MockClock | |
29 | ||
30 | @defer.inlineCallbacks | |
31 | def setUp(self): | |
32 | hs = yield tests.utils.setup_test_homeserver() | |
33 | self.store = hs.get_datastore() | |
34 | self.clock = hs.get_clock() | |
35 | ||
36 | @defer.inlineCallbacks | |
37 | def test_insert_new_client_ip(self): | |
38 | self.clock.now = 12345678 | |
39 | user_id = "@user:id" | |
40 | yield self.store.insert_client_ip( | |
41 | synapse.types.UserID.from_string(user_id), | |
42 | "access_token", "ip", "user_agent", "device_id", | |
43 | ) | |
44 | ||
45 | # deliberately use an iterable here to make sure that the lookup | |
46 | # method doesn't iterate it twice | |
47 | device_list = iter(((user_id, "device_id"),)) | |
48 | result = yield self.store.get_last_client_ip_by_device(device_list) | |
49 | ||
50 | r = result[(user_id, "device_id")] | |
51 | self.assertDictContainsSubset( | |
52 | { | |
53 | "user_id": user_id, | |
54 | "device_id": "device_id", | |
55 | "access_token": "access_token", | |
56 | "ip": "ip", | |
57 | "user_agent": "user_agent", | |
58 | "last_seen": 12345678000, | |
59 | }, | |
60 | r | |
61 | ) |
0 | # -*- coding: utf-8 -*- | |
1 | # Copyright 2016 OpenMarket Ltd | |
2 | # | |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
4 | # you may not use this file except in compliance with the License. | |
5 | # You may obtain a copy of the License at | |
6 | # | |
7 | # http://www.apache.org/licenses/LICENSE-2.0 | |
8 | # | |
9 | # Unless required by applicable law or agreed to in writing, software | |
10 | # distributed under the License is distributed on an "AS IS" BASIS, | |
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 | # See the License for the specific language governing permissions and | |
13 | # limitations under the License. | |
14 | ||
15 | from twisted.internet import defer | |
16 | ||
17 | import synapse.api.errors | |
18 | import tests.unittest | |
19 | import tests.utils | |
20 | ||
21 | ||
22 | class DeviceStoreTestCase(tests.unittest.TestCase): | |
23 | def __init__(self, *args, **kwargs): | |
24 | super(DeviceStoreTestCase, self).__init__(*args, **kwargs) | |
25 | self.store = None # type: synapse.storage.DataStore | |
26 | ||
27 | @defer.inlineCallbacks | |
28 | def setUp(self): | |
29 | hs = yield tests.utils.setup_test_homeserver() | |
30 | ||
31 | self.store = hs.get_datastore() | |
32 | ||
33 | @defer.inlineCallbacks | |
34 | def test_store_new_device(self): | |
35 | yield self.store.store_device( | |
36 | "user_id", "device_id", "display_name" | |
37 | ) | |
38 | ||
39 | res = yield self.store.get_device("user_id", "device_id") | |
40 | self.assertDictContainsSubset({ | |
41 | "user_id": "user_id", | |
42 | "device_id": "device_id", | |
43 | "display_name": "display_name", | |
44 | }, res) | |
45 | ||
46 | @defer.inlineCallbacks | |
47 | def test_get_devices_by_user(self): | |
48 | yield self.store.store_device( | |
49 | "user_id", "device1", "display_name 1" | |
50 | ) | |
51 | yield self.store.store_device( | |
52 | "user_id", "device2", "display_name 2" | |
53 | ) | |
54 | yield self.store.store_device( | |
55 | "user_id2", "device3", "display_name 3" | |
56 | ) | |
57 | ||
58 | res = yield self.store.get_devices_by_user("user_id") | |
59 | self.assertEqual(2, len(res.keys())) | |
60 | self.assertDictContainsSubset({ | |
61 | "user_id": "user_id", | |
62 | "device_id": "device1", | |
63 | "display_name": "display_name 1", | |
64 | }, res["device1"]) | |
65 | self.assertDictContainsSubset({ | |
66 | "user_id": "user_id", | |
67 | "device_id": "device2", | |
68 | "display_name": "display_name 2", | |
69 | }, res["device2"]) | |
70 | ||
71 | @defer.inlineCallbacks | |
72 | def test_update_device(self): | |
73 | yield self.store.store_device( | |
74 | "user_id", "device_id", "display_name 1" | |
75 | ) | |
76 | ||
77 | res = yield self.store.get_device("user_id", "device_id") | |
78 | self.assertEqual("display_name 1", res["display_name"]) | |
79 | ||
80 | # do a no-op first | |
81 | yield self.store.update_device( | |
82 | "user_id", "device_id", | |
83 | ) | |
84 | res = yield self.store.get_device("user_id", "device_id") | |
85 | self.assertEqual("display_name 1", res["display_name"]) | |
86 | ||
87 | # do the update | |
88 | yield self.store.update_device( | |
89 | "user_id", "device_id", | |
90 | new_display_name="display_name 2", | |
91 | ) | |
92 | ||
93 | # check it worked | |
94 | res = yield self.store.get_device("user_id", "device_id") | |
95 | self.assertEqual("display_name 2", res["display_name"]) | |
96 | ||
97 | @defer.inlineCallbacks | |
98 | def test_update_unknown_device(self): | |
99 | with self.assertRaises(synapse.api.errors.StoreError) as cm: | |
100 | yield self.store.update_device( | |
101 | "user_id", "unknown_device_id", | |
102 | new_display_name="display_name 2", | |
103 | ) | |
104 | self.assertEqual(404, cm.exception.code) |
0 | # -*- coding: utf-8 -*- | |
1 | # Copyright 2016 OpenMarket Ltd | |
2 | # | |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
4 | # you may not use this file except in compliance with the License. | |
5 | # You may obtain a copy of the License at | |
6 | # | |
7 | # http://www.apache.org/licenses/LICENSE-2.0 | |
8 | # | |
9 | # Unless required by applicable law or agreed to in writing, software | |
10 | # distributed under the License is distributed on an "AS IS" BASIS, | |
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 | # See the License for the specific language governing permissions and | |
13 | # limitations under the License. | |
14 | ||
15 | from twisted.internet import defer | |
16 | ||
17 | import tests.unittest | |
18 | import tests.utils | |
19 | ||
20 | ||
21 | class EndToEndKeyStoreTestCase(tests.unittest.TestCase): | |
22 | def __init__(self, *args, **kwargs): | |
23 | super(EndToEndKeyStoreTestCase, self).__init__(*args, **kwargs) | |
24 | self.store = None # type: synapse.storage.DataStore | |
25 | ||
26 | @defer.inlineCallbacks | |
27 | def setUp(self): | |
28 | hs = yield tests.utils.setup_test_homeserver() | |
29 | ||
30 | self.store = hs.get_datastore() | |
31 | ||
32 | @defer.inlineCallbacks | |
33 | def test_key_without_device_name(self): | |
34 | now = 1470174257070 | |
35 | json = '{ "key": "value" }' | |
36 | ||
37 | yield self.store.set_e2e_device_keys( | |
38 | "user", "device", now, json) | |
39 | ||
40 | res = yield self.store.get_e2e_device_keys((("user", "device"),)) | |
41 | self.assertIn("user", res) | |
42 | self.assertIn("device", res["user"]) | |
43 | dev = res["user"]["device"] | |
44 | self.assertDictContainsSubset({ | |
45 | "key_json": json, | |
46 | "device_display_name": None, | |
47 | }, dev) | |
48 | ||
49 | @defer.inlineCallbacks | |
50 | def test_get_key_with_device_name(self): | |
51 | now = 1470174257070 | |
52 | json = '{ "key": "value" }' | |
53 | ||
54 | yield self.store.set_e2e_device_keys( | |
55 | "user", "device", now, json) | |
56 | yield self.store.store_device( | |
57 | "user", "device", "display_name" | |
58 | ) | |
59 | ||
60 | res = yield self.store.get_e2e_device_keys((("user", "device"),)) | |
61 | self.assertIn("user", res) | |
62 | self.assertIn("device", res["user"]) | |
63 | dev = res["user"]["device"] | |
64 | self.assertDictContainsSubset({ | |
65 | "key_json": json, | |
66 | "device_display_name": "display_name", | |
67 | }, dev) | |
68 | ||
69 | @defer.inlineCallbacks | |
70 | def test_multiple_devices(self): | |
71 | now = 1470174257070 | |
72 | ||
73 | yield self.store.set_e2e_device_keys( | |
74 | "user1", "device1", now, 'json11') | |
75 | yield self.store.set_e2e_device_keys( | |
76 | "user1", "device2", now, 'json12') | |
77 | yield self.store.set_e2e_device_keys( | |
78 | "user2", "device1", now, 'json21') | |
79 | yield self.store.set_e2e_device_keys( | |
80 | "user2", "device2", now, 'json22') | |
81 | ||
82 | res = yield self.store.get_e2e_device_keys((("user1", "device1"), | |
83 | ("user2", "device2"))) | |
84 | self.assertIn("user1", res) | |
85 | self.assertIn("device1", res["user1"]) | |
86 | self.assertNotIn("device2", res["user1"]) | |
87 | self.assertIn("user2", res) | |
88 | self.assertNotIn("device1", res["user2"]) | |
89 | self.assertIn("device2", res["user2"]) |
0 | # -*- coding: utf-8 -*- | |
1 | # Copyright 2016 OpenMarket Ltd | |
2 | # | |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
4 | # you may not use this file except in compliance with the License. | |
5 | # You may obtain a copy of the License at | |
6 | # | |
7 | # http://www.apache.org/licenses/LICENSE-2.0 | |
8 | # | |
9 | # Unless required by applicable law or agreed to in writing, software | |
10 | # distributed under the License is distributed on an "AS IS" BASIS, | |
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 | # See the License for the specific language governing permissions and | |
13 | # limitations under the License. | |
14 | ||
15 | from twisted.internet import defer | |
16 | ||
17 | import tests.unittest | |
18 | import tests.utils | |
19 | ||
20 | USER_ID = "@user:example.com" | |
21 | ||
22 | ||
23 | class EventPushActionsStoreTestCase(tests.unittest.TestCase): | |
24 | ||
25 | @defer.inlineCallbacks | |
26 | def setUp(self): | |
27 | hs = yield tests.utils.setup_test_homeserver() | |
28 | self.store = hs.get_datastore() | |
29 | ||
30 | @defer.inlineCallbacks | |
31 | def test_get_unread_push_actions_for_user_in_range_for_http(self): | |
32 | yield self.store.get_unread_push_actions_for_user_in_range_for_http( | |
33 | USER_ID, 0, 1000, 20 | |
34 | ) | |
35 | ||
36 | @defer.inlineCallbacks | |
37 | def test_get_unread_push_actions_for_user_in_range_for_email(self): | |
38 | yield self.store.get_unread_push_actions_for_user_in_range_for_email( | |
39 | USER_ID, 0, 1000, 20 | |
40 | ) |
36 | 36 | |
37 | 37 | @defer.inlineCallbacks |
38 | 38 | def test_count_daily_messages(self): |
39 | self.db_pool.runQuery("DELETE FROM stats_reporting") | |
39 | yield self.db_pool.runQuery("DELETE FROM stats_reporting") | |
40 | 40 | |
41 | 41 | self.hs.clock.now = 100 |
42 | 42 | |
59 | 59 | # it isn't old enough. |
60 | 60 | count = yield self.store.count_daily_messages() |
61 | 61 | self.assertIsNone(count) |
62 | self._assert_stats_reporting(1, self.hs.clock.now) | |
62 | yield self._assert_stats_reporting(1, self.hs.clock.now) | |
63 | 63 | |
64 | 64 | # Already reported yesterday, two new events from today. |
65 | 65 | yield self.event_injector.inject_message(room, user, "Yeah they are!") |
67 | 67 | self.hs.clock.now += 60 * 60 * 24 |
68 | 68 | count = yield self.store.count_daily_messages() |
69 | 69 | self.assertEqual(2, count) # 2 since yesterday |
70 | self._assert_stats_reporting(3, self.hs.clock.now) # 3 ever | |
70 | yield self._assert_stats_reporting(3, self.hs.clock.now) # 3 ever | |
71 | 71 | |
72 | 72 | # Last reported too recently. |
73 | 73 | yield self.event_injector.inject_message(room, user, "Who could disagree?") |
74 | 74 | self.hs.clock.now += 60 * 60 * 22 |
75 | 75 | count = yield self.store.count_daily_messages() |
76 | 76 | self.assertIsNone(count) |
77 | self._assert_stats_reporting(4, self.hs.clock.now) | |
77 | yield self._assert_stats_reporting(4, self.hs.clock.now) | |
78 | 78 | |
79 | 79 | # Last reported too long ago |
80 | 80 | yield self.event_injector.inject_message(room, user, "No one.") |
81 | 81 | self.hs.clock.now += 60 * 60 * 26 |
82 | 82 | count = yield self.store.count_daily_messages() |
83 | 83 | self.assertIsNone(count) |
84 | self._assert_stats_reporting(5, self.hs.clock.now) | |
84 | yield self._assert_stats_reporting(5, self.hs.clock.now) | |
85 | 85 | |
86 | 86 | # And now let's actually report something |
87 | 87 | yield self.event_injector.inject_message(room, user, "Indeed.") |
91 | 91 | self.hs.clock.now += (60 * 60 * 24) + 50 |
92 | 92 | count = yield self.store.count_daily_messages() |
93 | 93 | self.assertEqual(3, count) |
94 | self._assert_stats_reporting(8, self.hs.clock.now) | |
94 | yield self._assert_stats_reporting(8, self.hs.clock.now) | |
95 | 95 | |
96 | 96 | @defer.inlineCallbacks |
97 | 97 | def _get_last_stream_token(self): |
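The added yields in this hunk matter: _assert_stats_reporting is itself an
@defer.inlineCallbacks helper, so calling it without yield merely creates a
Deferred and drops it, and any assertion failure inside it is swallowed
instead of failing the test. A standalone illustration of the two call shapes
(sketch only, not code from this repository):

    from twisted.internet import defer

    @defer.inlineCallbacks
    def _check(store):
        count = yield store.count()
        assert count == 3  # raised inside the generator

    @defer.inlineCallbacks
    def test_broken_vs_fixed(store):
        _check(store)        # Deferred dropped: a failure here never surfaces
        yield _check(store)  # correct: the failure propagates to the runner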
37 | 37 | "BcDeFgHiJkLmNoPqRsTuVwXyZa" |
38 | 38 | ] |
39 | 39 | self.pwhash = "{xx1}123456789" |
40 | self.device_id = "akgjhdjklgshg" | |
40 | 41 | |
41 | 42 | @defer.inlineCallbacks |
42 | 43 | def test_register(self): |
63 | 64 | @defer.inlineCallbacks |
64 | 65 | def test_add_tokens(self): |
65 | 66 | yield self.store.register(self.user_id, self.tokens[0], self.pwhash) |
66 | yield self.store.add_access_token_to_user(self.user_id, self.tokens[1]) | |
67 | yield self.store.add_access_token_to_user(self.user_id, self.tokens[1], | |
68 | self.device_id) | |
67 | 69 | |
68 | 70 | result = yield self.store.get_user_by_access_token(self.tokens[1]) |
69 | 71 | |
70 | 72 | self.assertDictContainsSubset( |
71 | 73 | { |
72 | 74 | "name": self.user_id, |
75 | "device_id": self.device_id, | |
73 | 76 | }, |
74 | 77 | result |
75 | 78 | ) |
79 | 82 | @defer.inlineCallbacks |
80 | 83 | def test_exchange_refresh_token_valid(self): |
81 | 84 | uid = stringutils.random_string(32) |
85 | device_id = stringutils.random_string(16) | |
82 | 86 | generator = TokenGenerator() |
83 | 87 | last_token = generator.generate(uid) |
84 | 88 | |
85 | 89 | self.db_pool.runQuery( |
86 | "INSERT INTO refresh_tokens(user_id, token) VALUES(?,?)", | |
87 | (uid, last_token,)) | |
90 | "INSERT INTO refresh_tokens(user_id, token, device_id) " | |
91 | "VALUES(?,?,?)", | |
92 | (uid, last_token, device_id)) | |
88 | 93 | |
89 | (found_user_id, refresh_token) = yield self.store.exchange_refresh_token( | |
90 | last_token, generator.generate) | |
94 | (found_user_id, refresh_token, device_id) = \ | |
95 | yield self.store.exchange_refresh_token(last_token, | |
96 | generator.generate) | |
91 | 97 | self.assertEqual(uid, found_user_id) |
92 | 98 | |
93 | 99 | rows = yield self.db_pool.runQuery( |
94 | "SELECT token FROM refresh_tokens WHERE user_id = ?", (uid, )) | |
95 | self.assertEqual([(refresh_token,)], rows) | |
100 | "SELECT token, device_id FROM refresh_tokens WHERE user_id = ?", | |
101 | (uid, )) | |
102 | self.assertEqual([(refresh_token, device_id)], rows) | |
96 | 103 | # We issued token 1, then exchanged it for token 2 |
97 | 104 | expected_refresh_token = u"%s-%d" % (uid, 2,) |
98 | 105 | self.assertEqual(expected_refresh_token, refresh_token) |
120 | 127 | with self.assertRaises(StoreError): |
121 | 128 | yield self.store.exchange_refresh_token(last_token, generator.generate) |
122 | 129 | |
130 | @defer.inlineCallbacks | |
131 | def test_user_delete_access_tokens(self): | |
132 | # add some tokens | |
133 | generator = TokenGenerator() | |
134 | refresh_token = generator.generate(self.user_id) | |
135 | yield self.store.register(self.user_id, self.tokens[0], self.pwhash) | |
136 | yield self.store.add_access_token_to_user(self.user_id, self.tokens[1], | |
137 | self.device_id) | |
138 | yield self.store.add_refresh_token_to_user(self.user_id, refresh_token, | |
139 | self.device_id) | |
140 | ||
141 | # now delete some | |
142 | yield self.store.user_delete_access_tokens( | |
143 | self.user_id, device_id=self.device_id, delete_refresh_tokens=True) | |
144 | ||
145 | # check they were deleted | |
146 | user = yield self.store.get_user_by_access_token(self.tokens[1]) | |
147 | self.assertIsNone(user, "access token was not deleted by device_id") | |
148 | with self.assertRaises(StoreError): | |
149 | yield self.store.exchange_refresh_token(refresh_token, | |
150 | generator.generate) | |
151 | ||
152 | # check the one not associated with the device was not deleted | |
153 | user = yield self.store.get_user_by_access_token(self.tokens[0]) | |
154 | self.assertEqual(self.user_id, user["name"]) | |
155 | ||
156 | # now delete the rest | |
157 | yield self.store.user_delete_access_tokens( | |
158 | self.user_id, delete_refresh_tokens=True) | |
159 | ||
160 | user = yield self.store.get_user_by_access_token(self.tokens[0]) | |
161 | self.assertIsNone(user, | |
162 | "access token was not deleted without device_id") | |
163 | ||
123 | 164 | |
124 | 165 | class TokenGenerator: |
125 | 166 | def __init__(self): |
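test_user_delete_access_tokens exercises both halves of the deletion
contract: with device_id given, only that device's access and refresh tokens
are removed and tokens created without a device survive; without device_id,
everything for the user goes. Revoking a single device therefore reduces to
the following (a sketch using only the call exercised above):

    @defer.inlineCallbacks
    def revoke_device(store, user_id, device_id):
        # Deletes only tokens bound to this device; device-less tokens
        # survive, as the test verifies.
        yield store.user_delete_access_tokens(
            user_id, device_id=device_id, delete_refresh_tokens=True)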
0 | # -*- coding: utf-8 -*- | |
1 | # Copyright 2014-2016 OpenMarket Ltd | |
2 | # | |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
4 | # you may not use this file except in compliance with the License. | |
5 | # You may obtain a copy of the License at | |
6 | # | |
7 | # http://www.apache.org/licenses/LICENSE-2.0 | |
8 | # | |
9 | # Unless required by applicable law or agreed to in writing, software | |
10 | # distributed under the License is distributed on an "AS IS" BASIS, | |
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 | # See the License for the specific language governing permissions and | |
13 | # limitations under the License. | |
14 | ||
15 | from . import unittest | |
16 | ||
17 | from synapse.rest.media.v1.preview_url_resource import summarize_paragraphs | |
18 | ||
19 | ||
20 | class PreviewTestCase(unittest.TestCase): | |
21 | ||
22 | def test_long_summarize(self): | |
23 | example_paras = [ | |
24 | """Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami: | |
25 | Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in | |
26 | Troms county, Norway. The administrative centre of the municipality is | |
27 | the city of Tromsø. Outside of Norway, Tromso and Tromsö are | |
28 | alternative spellings of the city.Tromsø is considered the northernmost | |
29 | city in the world with a population above 50,000. The most populous town | |
30 | north of it is Alta, Norway, with a population of 14,272 (2013).""", | |
31 | ||
32 | """Tromsø lies in Northern Norway. The municipality has a population of | |
33 | (2015) 72,066, but with an annual influx of students it has over 75,000 | |
34 | most of the year. It is the largest urban area in Northern Norway and the | |
35 | third largest north of the Arctic Circle (following Murmansk and Norilsk). | |
36 | Most of Tromsø, including the city centre, is located on the island of | |
37 | Tromsøya, 350 kilometres (217 mi) north of the Arctic Circle. In 2012, | |
38 | Tromsøya had a population of 36,088. Substantial parts of the urban area | |
39 | are also situated on the mainland to the east, and on parts of Kvaløya—a | |
40 | large island to the west. Tromsøya is connected to the mainland by the Tromsø | |
41 | Bridge and the Tromsøysund Tunnel, and to the island of Kvaløya by the | |
42 | Sandnessund Bridge. Tromsø Airport connects the city to many destinations | |
43 | in Europe. The city is warmer than most other places located on the same | |
44 | latitude, due to the warming effect of the Gulf Stream.""", | |
45 | ||
46 | """The city centre of Tromsø contains the highest number of old wooden | |
47 | houses in Northern Norway, the oldest house dating from 1789. The Arctic | |
48 | Cathedral, a modern church from 1965, is probably the most famous landmark | |
49 | in Tromsø. The city is a cultural centre for its region, with several | |
50 | festivals taking place in the summer. Some of Norway's best-known | |
51 | musicians, Torbjørn Brundtland and Svein Berge of the electronica duo | |
52 | Röyksopp and Lene Marlin grew up and started their careers in Tromsø. | |
53 | Noted electronic musician Geir Jenssen also hails from Tromsø.""", | |
54 | ] | |
55 | ||
56 | desc = summarize_paragraphs(example_paras, min_size=200, max_size=500) | |
57 | ||
58 | self.assertEquals( | |
59 | desc, | |
60 | "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" | |
61 | " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" | |
62 | " Troms county, Norway. The administrative centre of the municipality is" | |
63 | " the city of Tromsø. Outside of Norway, Tromso and Tromsö are" | |
64 | " alternative spellings of the city.Tromsø is considered the northernmost" | |
65 | " city in the world with a population above 50,000. The most populous town" | |
66 | " north of it is Alta, Norway, with a population of 14,272 (2013)." | |
67 | ) | |
68 | ||
69 | desc = summarize_paragraphs(example_paras[1:], min_size=200, max_size=500) | |
70 | ||
71 | self.assertEquals( | |
72 | desc, | |
73 | "Tromsø lies in Northern Norway. The municipality has a population of" | |
74 | " (2015) 72,066, but with an annual influx of students it has over 75,000" | |
75 | " most of the year. It is the largest urban area in Northern Norway and the" | |
76 | " third largest north of the Arctic Circle (following Murmansk and Norilsk)." | |
77 | " Most of Tromsø, including the city centre, is located on the island of" | |
78 | " Tromsøya, 350 kilometres (217 mi) north of the Arctic Circle. In 2012," | |
79 | " Tromsøya had a population of 36,088. Substantial parts of the…" | |
80 | ) | |
81 | ||
82 | def test_short_summarize(self): | |
83 | example_paras = [ | |
84 | "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" | |
85 | " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" | |
86 | " Troms county, Norway.", | |
87 | ||
88 | "Tromsø lies in Northern Norway. The municipality has a population of" | |
89 | " (2015) 72,066, but with an annual influx of students it has over 75,000" | |
90 | " most of the year.", | |
91 | ||
92 | "The city centre of Tromsø contains the highest number of old wooden" | |
93 | " houses in Northern Norway, the oldest house dating from 1789. The Arctic" | |
94 | " Cathedral, a modern church from 1965, is probably the most famous landmark" | |
95 | " in Tromsø.", | |
96 | ] | |
97 | ||
98 | desc = summarize_paragraphs(example_paras, min_size=200, max_size=500) | |
99 | ||
100 | self.assertEquals( | |
101 | desc, | |
102 | "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" | |
103 | " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" | |
104 | " Troms county, Norway.\n" | |
105 | "\n" | |
106 | "Tromsø lies in Northern Norway. The municipality has a population of" | |
107 | " (2015) 72,066, but with an annual influx of students it has over 75,000" | |
108 | " most of the year." | |
109 | ) | |
110 | ||
111 | def test_small_then_large_summarize(self): | |
112 | example_paras = [ | |
113 | "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" | |
114 | " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" | |
115 | " Troms county, Norway.", | |
116 | ||
117 | "Tromsø lies in Northern Norway. The municipality has a population of" | |
118 | " (2015) 72,066, but with an annual influx of students it has over 75,000" | |
119 | " most of the year." | |
120 | " The city centre of Tromsø contains the highest number of old wooden" | |
121 | " houses in Northern Norway, the oldest house dating from 1789. The Arctic" | |
122 | " Cathedral, a modern church from 1965, is probably the most famous landmark" | |
123 | " in Tromsø.", | |
124 | ] | |
125 | ||
126 | desc = summarize_paragraphs(example_paras, min_size=200, max_size=500) | |
127 | self.assertEquals( | |
128 | desc, | |
129 | "Tromsø (Norwegian pronunciation: [ˈtrʊmsœ] ( listen); Northern Sami:" | |
130 | " Romsa; Finnish: Tromssa[2] Kven: Tromssa) is a city and municipality in" | |
131 | " Troms county, Norway.\n" | |
132 | "\n" | |
133 | "Tromsø lies in Northern Norway. The municipality has a population of" | |
134 | " (2015) 72,066, but with an annual influx of students it has over 75,000" | |
135 | " most of the year. The city centre of Tromsø contains the highest number" | |
136 | " of old wooden houses in Northern Norway, the oldest house dating from" | |
137 | " 1789. The Arctic Cathedral, a modern church…" | |
138 | ) |
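Taken together, the three cases document the contract of summarize_paragraphs:
paragraphs are appended, separated by a blank line, until the summary reaches
min_size, and a summary longer than max_size is cut back to a word boundary
and terminated with "…". A simplified restatement of that contract (an
illustrative sketch, not the implementation in preview_url_resource, which
also collapses the embedded newlines in the fixtures to single spaces):

    def summarize_paragraphs_sketch(paras, min_size=200, max_size=500):
        desc = ""
        for p in paras:
            # Join whole paragraphs until the minimum size is reached.
            desc = p if not desc else desc + "\n\n" + p
            if len(desc) >= min_size:
                break
        if len(desc) > max_size:
            # Trim to max_size, back off to a word boundary, add an ellipsis.
            desc = desc[:max_size].rsplit(" ", 1)[0] + "…"
        return desc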
16 | 16 | |
17 | 17 | import logging |
18 | 18 | |
19 | ||
20 | 19 | # logging doesn't have a "don't log anything at all EVARRRR" setting,
21 | 20 | # but since the highest value is 50, 1000000 should do ;) |
22 | 21 | NEVER = 1000000 |
23 | 22 | |
24 | logging.getLogger().addHandler(logging.StreamHandler()) | |
23 | handler = logging.StreamHandler() | |
24 | handler.setFormatter(logging.Formatter( | |
25 | "%(levelname)s:%(name)s:%(message)s [%(pathname)s:%(lineno)d]" | |
26 | )) | |
27 | logging.getLogger().addHandler(handler) | |
25 | 28 | logging.getLogger().setLevel(NEVER) |
29 | logging.getLogger("synapse.storage.SQL").setLevel(NEVER) | |
30 | logging.getLogger("synapse.storage.txn").setLevel(NEVER) | |
26 | 31 | |
27 | 32 | |
28 | 33 | def around(target): |
69 | 74 | return ret |
70 | 75 | |
71 | 76 | logging.getLogger().setLevel(level) |
72 | # Don't set SQL logging | |
73 | logging.getLogger("synapse.storage").setLevel(old_level) | |
74 | 77 | return orig() |
75 | 78 | |
76 | 79 | def assertObjectHasAttributes(self, attrs, obj): |
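NEVER works because the highest standard logging level, CRITICAL, is 50, so a
threshold of 1000000 rejects every record. Pinning the two synapse.storage.*
loggers to the same level globally is what lets this hunk drop the per-test
"Don't set SQL logging" restore further down: raising the root level for a
single test no longer lets SQL logging flood back in. The suppression trick
in isolation:

    import logging

    NEVER = 1000000  # far above CRITICAL (50), so no record ever passes
    logging.getLogger().setLevel(NEVER)
    logging.critical("dropped")  # suppressed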
0 | # -*- coding: utf-8 -*- | |
1 | # Copyright 2016 OpenMarket Ltd | |
2 | # | |
3 | # Licensed under the Apache License, Version 2.0 (the "License"); | |
4 | # you may not use this file except in compliance with the License. | |
5 | # You may obtain a copy of the License at | |
6 | # | |
7 | # http://www.apache.org/licenses/LICENSE-2.0 | |
8 | # | |
9 | # Unless required by applicable law or agreed to in writing, software | |
10 | # distributed under the License is distributed on an "AS IS" BASIS, | |
11 | # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
12 | # See the License for the specific language governing permissions and | |
13 | # limitations under the License. | |
14 | ||
15 | ||
16 | from tests import unittest | |
17 | ||
18 | from synapse.util.async import ReadWriteLock | |
19 | ||
20 | ||
21 | class ReadWriteLockTestCase(unittest.TestCase): | |
22 | ||
23 | def _assert_called_before_not_after(self, lst, first_false): | |
24 | for i, d in enumerate(lst[:first_false]): | |
25 | self.assertTrue(d.called, msg="%d was unexpectedly false" % i) | |
26 | ||
27 | for i, d in enumerate(lst[first_false:]): | |
28 | self.assertFalse( | |
29 | d.called, msg="%d was unexpectedly true" % (i + first_false) | |
30 | ) | |
31 | ||
32 | def test_rwlock(self): | |
33 | rwlock = ReadWriteLock() | |
34 | ||
35 | key = object() | |
36 | ||
37 | ds = [ | |
38 | rwlock.read(key), # 0 | |
39 | rwlock.read(key), # 1 | |
40 | rwlock.write(key), # 2 | |
41 | rwlock.write(key), # 3 | |
42 | rwlock.read(key), # 4 | |
43 | rwlock.read(key), # 5 | |
44 | rwlock.write(key), # 6 | |
45 | ] | |
46 | ||
47 | self._assert_called_before_not_after(ds, 2) | |
48 | ||
49 | with ds[0].result: | |
50 | self._assert_called_before_not_after(ds, 2) | |
51 | self._assert_called_before_not_after(ds, 2) | |
52 | ||
53 | with ds[1].result: | |
54 | self._assert_called_before_not_after(ds, 2) | |
55 | self._assert_called_before_not_after(ds, 3) | |
56 | ||
57 | with ds[2].result: | |
58 | self._assert_called_before_not_after(ds, 3) | |
59 | self._assert_called_before_not_after(ds, 4) | |
60 | ||
61 | with ds[3].result: | |
62 | self._assert_called_before_not_after(ds, 4) | |
63 | self._assert_called_before_not_after(ds, 6) | |
64 | ||
65 | with ds[5].result: | |
66 | self._assert_called_before_not_after(ds, 6) | |
67 | self._assert_called_before_not_after(ds, 6) | |
68 | ||
69 | with ds[4].result: | |
70 | self._assert_called_before_not_after(ds, 6) | |
71 | self._assert_called_before_not_after(ds, 7) | |
72 | ||
73 | with ds[6].result: | |
74 | pass | |
75 | ||
76 | d = rwlock.write(key) | |
77 | self.assertTrue(d.called) | |
78 | with d.result: | |
79 | pass | |
80 | ||
81 | d = rwlock.read(key) | |
82 | self.assertTrue(d.called) | |
83 | with d.result: | |
84 | pass |
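The test drives ReadWriteLock through bare Deferreds, checking exactly when
each one fires: readers 0 and 1 run together, each writer runs alone, and
requests on a key are served in arrival order. Ordinary callers use the
Deferred's result as a context manager (a sketch following the same pattern
the test reaches via ds[i].result):

    from twisted.internet import defer

    @defer.inlineCallbacks
    def with_exclusive_access(rwlock, key):
        # Waits for all earlier readers/writers on this key and blocks
        # everything queued behind it until the block exits.
        with (yield rwlock.write(key)):
            pass  # exclusive work goes here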
19 | 19 | from synapse.storage.engines import create_engine |
20 | 20 | from synapse.server import HomeServer |
21 | 21 | from synapse.federation.transport import server |
22 | from synapse.types import Requester | |
23 | 22 | from synapse.util.ratelimitutils import FederationRateLimiter |
24 | 23 | |
25 | 24 | from synapse.util.logcontext import LoggingContext |
55 | 54 | |
56 | 55 | config.use_frozen_dicts = True |
57 | 56 | config.database_config = {"name": "sqlite3"} |
57 | config.ldap_enabled = False | |
58 | 58 | |
59 | 59 | if "clock" not in kargs: |
60 | 60 | kargs["clock"] = MockClock() |
510 | 510 | "call(%s)" % _format_call(c[0], c[1]) for c in calls |
511 | 511 | ]) |
512 | 512 | ) |
513 | ||
514 | ||
515 | def requester_for_user(user): | |
516 | return Requester(user, None, False) |